mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-22 16:10:28 +00:00

Merge pull request from nat/patch-1

Strip unwanted prefix from state keys when loading model in sample.py
Andrej 2023-01-04 16:46:32 -08:00 committed by GitHub
commit 529c967a65

sample.py

@@ -29,7 +29,12 @@ ckpt_path = os.path.join(out_dir, 'ckpt.pt')
 checkpoint = torch.load(ckpt_path, map_location=device)
 gptconf = GPTConfig(**checkpoint['model_args'])
 model = GPT(gptconf)
-model.load_state_dict(checkpoint['model'])
+state_dict = checkpoint['model']
+unwanted_prefix = '_orig_mod.'
+for k,v in list(state_dict.items()):
+    if k.startswith(unwanted_prefix):
+        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+model.load_state_dict(state_dict)
 model.eval()
 model.to(device)
 if compile:
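
For context, here is a minimal standalone sketch (not part of the commit, assuming PyTorch 2.x behavior) of why the '_orig_mod.' prefix shows up and how stripping it works: torch.compile wraps the module in an OptimizedModule that stores the original module under `_orig_mod`, so a state_dict saved from a compiled model has prefixed keys and will not load into a plain, uncompiled instance until the prefix is removed. The nn.Linear module below is a hypothetical stand-in for the real GPT model.

import torch
import torch.nn as nn

# Stand-in model; in nanoGPT this would be the GPT module itself.
model = nn.Linear(4, 4)

# torch.compile (PyTorch >= 2.0) returns an OptimizedModule whose state_dict
# keys carry the '_orig_mod.' prefix of the wrapped module.
compiled = torch.compile(model)
state_dict = compiled.state_dict()
print(list(state_dict.keys()))  # e.g. ['_orig_mod.weight', '_orig_mod.bias']

# Strip the prefix, mirroring the fix in sample.py above, so the checkpoint
# loads into an uncompiled instance of the same model class.
unwanted_prefix = '_orig_mod.'
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

fresh = nn.Linear(4, 4)
fresh.load_state_dict(state_dict)  # succeeds now that the keys match

Note that list(state_dict.items()) snapshots the entries first, so the dict can be safely mutated (keys popped and re-inserted) while iterating.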