From 7c8288552b3673574e0649e031963b8e7e8d4981 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Sat, 14 Jan 2023 01:00:55 +0000
Subject: [PATCH] tie the weights of lm_head.weight and transformer.wte.weight,
 i.e. the last linear layer of the decoder and the token embeddings.

---
 model.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/model.py b/model.py
index 18cf8b2..a301379 100644
--- a/model.py
+++ b/model.py
@@ -115,9 +115,10 @@ class GPT(nn.Module):
             ln_f = nn.LayerNorm(config.n_embd),
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.lm_head.weight = self.transformer.wte.weight # https://paperswithcode.com/method/weight-tying
 
-        # report number of parameters (note we don't count the decoder parameters in lm_head)
-        n_params = sum(p.numel() for p in self.transformer.parameters())
+        # report number of parameters
+        n_params = sum(p.numel() for p in self.parameters())
         print("number of parameters: %.2fM" % (n_params/1e6,))
 
     def forward(self, idx, targets=None):
@@ -156,8 +157,9 @@ class GPT(nn.Module):
             block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
 
     @classmethod
-    def from_pretrained(cls, model_type, override_args):
+    def from_pretrained(cls, model_type, override_args=None):
         assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+        override_args = override_args or {} # default to empty dict
         # only dropout can be overridden see more notes below
         assert all(k == 'dropout' for k in override_args)
         from transformers import GPT2LMHeadModel
@@ -235,6 +237,14 @@ class GPT(nn.Module):
                     # weights of blacklist modules will NOT be weight decayed
                     no_decay.add(fpn)
 
+        # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
+        # will appear in the no_decay and decay sets respectively after the above.
+        # In addition, because named_parameters() doesn't return duplicates, it
+        # will only return the first occurrence, keyed by 'transformer.wte.weight', below.
+        # So let's manually remove 'lm_head.weight' from the decay set. This tensor will
+        # then enter the optimization via transformer.wte.weight only, and will not be decayed.
+        decay.remove('lm_head.weight')
+
         # validate that we considered every parameter
         param_dict = {pn: p for pn, p in self.named_parameters()}
         inter_params = decay & no_decay
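
Note (not part of the commit): below is a minimal, self-contained sketch of the same weight-tying pattern, to illustrate why the parameter-counting and optimizer-grouping changes above are needed. The TinyLM class and its sizes are invented for illustration; only the tying assignment mirrors the patch.

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=65, n_embd=32):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)               # token embedding table
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # decoder head
        self.lm_head.weight = self.wte.weight                     # tie: both names now refer to one Parameter

    def forward(self, idx):
        return self.lm_head(self.wte(idx))                        # logits over the vocabulary

m = TinyLM()
assert m.lm_head.weight is m.wte.weight                           # a single shared tensor
# parameters()/named_parameters() deduplicate tied tensors, so the shared weight is counted
# once and reported only under the first name it was registered with ('wte.weight');
# 'lm_head.weight' never appears, which is why the patch drops it from the decay set.
print(sum(p.numel() for p in m.parameters()))                     # counts the tied weight once
print([name for name, _ in m.named_parameters()])                 # ['wte.weight']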