From d995c221282ea45eb43abe7acf67ba834d5f1b60 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Wed, 1 Feb 2023 02:05:34 +0000
Subject: [PATCH] fix bug with loading GPT-2 parameters: the assert gets
 incorrectly tripped because the .attn.bias buffer is now only optionally
 present, depending on whether flash attention is used

---
 model.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/model.py b/model.py
index e6683a0..53a963a 100644
--- a/model.py
+++ b/model.py
@@ -229,18 +229,22 @@ class GPT(nn.Module):
         config = GPTConfig(block_size=1024, bias=True, **config_args) # note: force bias=True, as in gpt2 models
         model = GPT(config)
         sd = model.state_dict()
+        sd_keys = sd.keys()
+        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
 
         # init a huggingface/transformers model
         model_hf = GPT2LMHeadModel.from_pretrained(model_type)
         sd_hf = model_hf.state_dict()
 
         # copy while ensuring all of the parameters are aligned and match in names and shapes
-        keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these
+        sd_keys_hf = sd_hf.keys()
+        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
+        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
         transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
         # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
         # this means that we have to transpose these weights when we import them
-        assert len(keys) == len(sd)
-        for k in keys:
+        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
+        for k in sd_keys_hf:
             if any(k.endswith(w) for w in transposed):
                 # special treatment for the Conv1D weights we need to transpose
                 assert sd_hf[k].shape[::-1] == sd[k].shape
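
A note on the root cause, with a minimal sketch (not part of the patch; the
module name and sizes below are illustrative): buffers registered via
register_buffer() appear in state_dict() right alongside real parameters, and
nanoGPT's CausalSelfAttention registers its causal-mask buffer, named "bias",
only on the non-flash code path. The model's key count therefore depends on
whether flash attention is available, which is what tripped the old
assert len(keys) == len(sd):

    import torch
    import torch.nn as nn

    class AttentionSketch(nn.Module):
        # hypothetical stand-in for nanoGPT's CausalSelfAttention
        def __init__(self, flash: bool, block_size: int = 8):
            super().__init__()
            self.c_proj = nn.Linear(8, 8)
            if not flash:
                # non-flash path: the causal mask is registered as a buffer
                # named "bias"; buffers land in state_dict() like parameters
                self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size)))

    print(sorted(AttentionSketch(flash=True).state_dict()))
    # ['c_proj.bias', 'c_proj.weight']
    print(sorted(AttentionSketch(flash=False).state_dict()))
    # ['bias', 'c_proj.bias', 'c_proj.weight']  <- extra key changes the length

Filtering '.attn.bias' keys out of both state dicts, as the patch does above,
makes the length comparison insensitive to that buffer.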