From 2b9e168736c69e969f3d4cf71c0939a9fce24a56 Mon Sep 17 00:00:00 2001 From: Nat Friedman Date: Wed, 4 Jan 2023 16:34:00 -0800 Subject: [PATCH] Strip unwanted prefix from state keys when loading model --- sample.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sample.py b/sample.py index e245efc..618cacb 100644 --- a/sample.py +++ b/sample.py @@ -29,7 +29,12 @@ ckpt_path = os.path.join(out_dir, 'ckpt.pt') checkpoint = torch.load(ckpt_path, map_location=device) gptconf = GPTConfig(**checkpoint['model_args']) model = GPT(gptconf) -model.load_state_dict(checkpoint['model']) +state_dict = checkpoint['model'] +unwanted_prefix = '_orig_mod.' +for k,v in list(state_dict.items()): + if k.startswith(unwanted_prefix): + state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) +model.load_state_dict(state_dict) model.eval() model.to(device) if compile: