diff --git a/sample.py b/sample.py index e245efc..618cacb 100644 --- a/sample.py +++ b/sample.py @@ -29,7 +29,12 @@ ckpt_path = os.path.join(out_dir, 'ckpt.pt') checkpoint = torch.load(ckpt_path, map_location=device) gptconf = GPTConfig(**checkpoint['model_args']) model = GPT(gptconf) -model.load_state_dict(checkpoint['model']) +state_dict = checkpoint['model'] +unwanted_prefix = '_orig_mod.' +for k,v in list(state_dict.items()): + if k.startswith(unwanted_prefix): + state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) +model.load_state_dict(state_dict) model.eval() model.to(device) if compile: