1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00

tie the weights of lm_head.weight and transformer.wte.weight, i.e. the last linear layer of decoder and the token embeddings.

This commit is contained in:
Andrej Karpathy 2023-01-14 01:00:55 +00:00
parent 32b4f08d9d
commit 7c8288552b

View File

@ -115,9 +115,10 @@ class GPT(nn.Module):
ln_f = nn.LayerNorm(config.n_embd),
))
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.lm_head.weight = self.transformer.wte.weight # https://paperswithcode.com/method/weight-tying
# report number of parameters (note we don't count the decoder parameters in lm_head)
n_params = sum(p.numel() for p in self.transformer.parameters())
# report number of parameters
n_params = sum(p.numel() for p in self.parameters())
print("number of parameters: %.2fM" % (n_params/1e6,))
def forward(self, idx, targets=None):
@ -156,8 +157,9 @@ class GPT(nn.Module):
block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
@classmethod
def from_pretrained(cls, model_type, override_args):
def from_pretrained(cls, model_type, override_args=None):
assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
override_args = override_args or {} # default to empty dict
# only dropout can be overridden see more notes below
assert all(k == 'dropout' for k in override_args)
from transformers import GPT2LMHeadModel
@ -235,6 +237,14 @@ class GPT(nn.Module):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
# will appear in the no_decay and decay sets respectively after the above.
# In addition, because named_parameters() doesn't return duplicates, it
# will only return the first occurence, key'd by 'transformer.wte.weight', below.
# so let's manually remove 'lm_head.weight' from decay set. This will include
# this tensor into optimization via transformer.wte.weight only, and not decayed.
decay.remove('lm_head.weight')
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay