mirror of https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
fix bug with loading GPT-2 parameters: the assert gets incorrectly tripped because .attn.bias can be missing, since that buffer is now only present when flash attention is unavailable
This commit is contained in:
parent 038ce89438
commit d995c22128

model.py (10 changed lines)
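For context on the commit message: in nanoGPT's CausalSelfAttention, the causal-mask buffer named "bias" is registered only when PyTorch's flash attention kernel is unavailable, roughly as below (a paraphrased sketch, not the verbatim model.py code):

import torch
import torch.nn as nn

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # (projections etc. elided)
        # flash attention (PyTorch >= 2.0) builds the causal mask internally
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            # only the slow path needs an explicit causal mask, so the "bias"
            # buffer (and hence the ".attn.bias" state_dict key) exists only
            # when flash attention is unavailable
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

Because the key may or may not exist on either side, the fix filters ".attn.bias" out of both key lists before comparing them, as the diff below shows.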
@@ -229,18 +229,22 @@ class GPT(nn.Module):
         config = GPTConfig(block_size=1024, bias=True, **config_args) # note: force bias=True, as in gpt2 models
         model = GPT(config)
         sd = model.state_dict()
+        sd_keys = sd.keys()
+        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

         # init a huggingface/transformers model
         model_hf = GPT2LMHeadModel.from_pretrained(model_type)
         sd_hf = model_hf.state_dict()

         # copy while ensuring all of the parameters are aligned and match in names and shapes
-        keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these
+        sd_keys_hf = sd_hf.keys()
+        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
+        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
         transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
         # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
         # this means that we have to transpose these weights when we import them
-        assert len(keys) == len(sd)
-        for k in keys:
+        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
+        for k in sd_keys_hf:
             if any(k.endswith(w) for w in transposed):
                 # special treatment for the Conv1D weights we need to transpose
                 assert sd_hf[k].shape[::-1] == sd[k].shape
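With both key lists filtered the same way, loading a pretrained checkpoint should pass the key-count assert whether or not flash attention is available; a minimal usage sketch (assuming nanoGPT's GPT.from_pretrained classmethod shown in the hunk header):

from model import GPT

# copies the huggingface GPT-2 weights into nanoGPT's module; the
# ".attn.bias" mask buffers on either side are ignored during the copy
model = GPT.from_pretrained('gpt2')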