Mirror of https://github.com/osmarks/nanogpt-experiments.git, synced 2024-12-18 14:10:28 +00:00
tie the weights of lm_head.weight and transformer.wte.weight, i.e. the last linear layer of the decoder and the token embeddings.
This commit is contained in:
parent 32b4f08d9d
commit 7c8288552b

1 changed file: model.py (16 lines changed)

model.py
@@ -115,9 +115,10 @@ class GPT(nn.Module):
             ln_f = nn.LayerNorm(config.n_embd),
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.lm_head.weight = self.transformer.wte.weight # https://paperswithcode.com/method/weight-tying

-        # report number of parameters (note we don't count the decoder parameters in lm_head)
-        n_params = sum(p.numel() for p in self.transformer.parameters())
+        # report number of parameters
+        n_params = sum(p.numel() for p in self.parameters())
         print("number of parameters: %.2fM" % (n_params/1e6,))

     def forward(self, idx, targets=None):
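The hunk above is the whole mechanism: assigning the embedding's Parameter to lm_head.weight makes the output projection and the token embedding share one tensor, so gradients from both uses accumulate into it and the model sheds a vocab_size x n_embd matrix. A minimal sketch of the same idea, outside the repo (TinyLM and its sizes are made up for illustration):

import torch.nn as nn

class TinyLM(nn.Module):  # hypothetical toy model, not nanoGPT's GPT class
    def __init__(self, vocab_size=50257, n_embd=768):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight  # tie: both attributes now reference one Parameter

    def forward(self, idx):
        return self.lm_head(self.wte(idx))  # embed, then project back to vocab logits

m = TinyLM()
assert m.lm_head.weight is m.wte.weight  # one shared tensor, not a copy
# parameters() deduplicates shared tensors, so counting over self.parameters()
# counts the tied matrix exactly once.
print("params: %.2fM" % (sum(p.numel() for p in m.parameters()) / 1e6))

This is also why the count switches from self.transformer.parameters() to self.parameters(): with the head tied to the embedding, iterating the whole model no longer double counts anything, so the old "don't count lm_head" workaround can go.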
@@ -156,8 +157,9 @@ class GPT(nn.Module):
             block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

     @classmethod
-    def from_pretrained(cls, model_type, override_args):
+    def from_pretrained(cls, model_type, override_args=None):
         assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+        override_args = override_args or {} # default to empty dict
         # only dropout can be overridden see more notes below
         assert all(k == 'dropout' for k in override_args)
         from transformers import GPT2LMHeadModel
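The second hunk just makes override_args optional: the None default plus "override_args = override_args or {}" normalizes a missing argument to an empty dict, so the dropout-only assert can iterate it safely (and a mutable {} default argument is avoided). With this change a caller can now do either of:

model = GPT.from_pretrained('gpt2')                      # no override dict needed anymore
model = GPT.from_pretrained('gpt2', dict(dropout=0.1))   # dropout can still be overridden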
@@ -235,6 +237,14 @@ class GPT(nn.Module):
                     # weights of blacklist modules will NOT be weight decayed
                     no_decay.add(fpn)

+        # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
+        # will appear in the no_decay and decay sets respectively after the above.
+        # In addition, because named_parameters() doesn't return duplicates, it
+        # will only return the first occurence, key'd by 'transformer.wte.weight', below.
+        # so let's manually remove 'lm_head.weight' from decay set. This will include
+        # this tensor into optimization via transformer.wte.weight only, and not decayed.
+        decay.remove('lm_head.weight')
+
         # validate that we considered every parameter
         param_dict = {pn: p for pn, p in self.named_parameters()}
         inter_params = decay & no_decay
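The new comments describe the one bookkeeping wrinkle tying introduces: the set-building loop sees both names ('lm_head.weight' via the Linear whitelist, 'transformer.wte.weight' via the Embedding blacklist), but named_parameters() reports a shared tensor only once, under the first name registered, so one key has to be dropped before the validation step. A self-contained sketch of that behaviour (the Tied class below is a made-up stand-in, not the repo's GPT):

import torch.nn as nn

class Tied(nn.Module):  # hypothetical stand-in for a weight-tied model
    def __init__(self, vocab=100, d=16):
        super().__init__()
        self.wte = nn.Embedding(vocab, d)
        self.lm_head = nn.Linear(d, vocab, bias=False)
        self.lm_head.weight = self.wte.weight  # tie both attributes to one Parameter

m = Tied()
print([n for n, _ in m.named_parameters()])  # ['wte.weight'] -- 'lm_head.weight' is deduplicated away

# A decay/no_decay split built from per-module names would hold both keys,
# while a dict keyed by named_parameters() holds only the surviving one:
decay, no_decay = {'lm_head.weight'}, {'wte.weight'}  # what the loop would produce for this toy model
decay.remove('lm_head.weight')                        # mirrors the commit's fix
param_dict = {pn: p for pn, p in m.named_parameters()}
assert set(param_dict.keys()) == (decay | no_decay)   # every parameter accounted for

Leaving 'lm_head.weight' in the decay set would trip the "considered every parameter" validation, since no entry of param_dict carries that name; the tied tensor is optimized (and excluded from weight decay) only under 'transformer.wte.weight'.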