mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2025-01-18 13:12:53 +00:00
tie the weights of lm_head.weight and transformer.wte.weight, i.e. the last linear layer of decoder and the token embeddings.
This commit is contained in:
parent
32b4f08d9d
commit
7c8288552b
16
model.py
16
model.py
@ -115,9 +115,10 @@ class GPT(nn.Module):
|
||||
ln_f = nn.LayerNorm(config.n_embd),
|
||||
))
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
self.lm_head.weight = self.transformer.wte.weight # https://paperswithcode.com/method/weight-tying
|
||||
|
||||
# report number of parameters (note we don't count the decoder parameters in lm_head)
|
||||
n_params = sum(p.numel() for p in self.transformer.parameters())
|
||||
# report number of parameters
|
||||
n_params = sum(p.numel() for p in self.parameters())
|
||||
print("number of parameters: %.2fM" % (n_params/1e6,))
|
||||
|
||||
def forward(self, idx, targets=None):
|
||||
@ -156,8 +157,9 @@ class GPT(nn.Module):
|
||||
block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_type, override_args):
|
||||
def from_pretrained(cls, model_type, override_args=None):
|
||||
assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
|
||||
override_args = override_args or {} # default to empty dict
|
||||
# only dropout can be overridden see more notes below
|
||||
assert all(k == 'dropout' for k in override_args)
|
||||
from transformers import GPT2LMHeadModel
|
||||
@ -235,6 +237,14 @@ class GPT(nn.Module):
|
||||
# weights of blacklist modules will NOT be weight decayed
|
||||
no_decay.add(fpn)
|
||||
|
||||
# subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
|
||||
# will appear in the no_decay and decay sets respectively after the above.
|
||||
# In addition, because named_parameters() doesn't return duplicates, it
|
||||
# will only return the first occurence, key'd by 'transformer.wte.weight', below.
|
||||
# so let's manually remove 'lm_head.weight' from decay set. This will include
|
||||
# this tensor into optimization via transformer.wte.weight only, and not decayed.
|
||||
decay.remove('lm_head.weight')
|
||||
|
||||
# validate that we considered every parameter
|
||||
param_dict = {pn: p for pn, p in self.named_parameters()}
|
||||
inter_params = decay & no_decay
|
||||
|
Loading…
Reference in New Issue
Block a user