From ec9b1f81827f43bfc1502c3bf0cda5f1b521a474 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Mon, 2 Jan 2023 01:25:02 +0000
Subject: [PATCH] add a patch to fix mysterious unwanted prefix in state dict? maybe remove later

---
 train.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 85ce586..778e1a3 100644
--- a/train.py
+++ b/train.py
@@ -142,7 +142,14 @@ elif init_from == 'resume':
         # TODO: think through how passed in params should interact with checkpoint params
     gptconf = GPTConfig(**model_args)
     model = GPT(gptconf)
-    model.load_state_dict(checkpoint['model'])
+    state_dict = checkpoint['model']
+    # fix the keys of the state dictionary :(
+    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
+    unwanted_prefix = '_orig_mod.'
+    for k,v in list(state_dict.items()):
+        if k.startswith(unwanted_prefix):
+            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+    model.load_state_dict(state_dict)
     iter_num = checkpoint['iter_num']
     best_val_loss = checkpoint['best_val_loss']
 elif init_from.startswith('gpt2'):