mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
add a patch that strips a mysterious unwanted prefix from state dict keys (root cause unknown; maybe remove later)
This commit is contained in:
parent
41184a27f5
commit
ec9b1f8182
9
train.py
9
train.py
@ -142,7 +142,14 @@ elif init_from == 'resume':
|
|||||||
# TODO: think through how passed in params should interact with checkpoint params
|
# TODO: think through how passed in params should interact with checkpoint params
|
||||||
gptconf = GPTConfig(**model_args)
|
gptconf = GPTConfig(**model_args)
|
||||||
model = GPT(gptconf)
|
model = GPT(gptconf)
|
||||||
model.load_state_dict(checkpoint['model'])
|
state_dict = checkpoint['model']
|
||||||
|
# fix the keys of the state dictionary :(
|
||||||
|
# not yet clear how checkpoints sometimes end up with this prefix; needs more debugging (possibly added when the model is wrapped, e.g. by torch.compile -- TODO confirm)
|
||||||
|
unwanted_prefix = '_orig_mod.'
|
||||||
|
for k,v in list(state_dict.items()):
|
||||||
|
if k.startswith(unwanted_prefix):
|
||||||
|
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
iter_num = checkpoint['iter_num']
|
iter_num = checkpoint['iter_num']
|
||||||
best_val_loss = checkpoint['best_val_loss']
|
best_val_loss = checkpoint['best_val_loss']
|
||||||
elif init_from.startswith('gpt2'):
|
elif init_from.startswith('gpt2'):
|
||||||
|
Loading…
Reference in New Issue
Block a user