mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
Merge pull request #19 from nat/patch-1
Strip unwanted prefix from state keys when loading model in sample.py
This commit is contained in:
commit
529c967a65
@ -29,7 +29,12 @@ ckpt_path = os.path.join(out_dir, 'ckpt.pt')
|
|||||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||||
gptconf = GPTConfig(**checkpoint['model_args'])
|
gptconf = GPTConfig(**checkpoint['model_args'])
|
||||||
model = GPT(gptconf)
|
model = GPT(gptconf)
|
||||||
model.load_state_dict(checkpoint['model'])
|
state_dict = checkpoint['model']
|
||||||
|
unwanted_prefix = '_orig_mod.'
|
||||||
|
for k,v in list(state_dict.items()):
|
||||||
|
if k.startswith(unwanted_prefix):
|
||||||
|
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if compile:
|
if compile:
|
||||||
|
Loading…
Reference in New Issue
Block a user