1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-18 14:10:28 +00:00

properly resume training, also loading iter_num and best_val_loss from checkpoints

This commit is contained in:
Andrej Karpathy 2022-12-29 18:23:15 +00:00
parent f88aa2c2fe
commit 682a0ac8f1

View File

@@ -118,15 +118,20 @@ def get_batch(split):
x, y = x.to(device), y.to(device) x, y = x.to(device), y.to(device)
return x, y return x, y
# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
iter_num = 0
best_val_loss = 1e9
# model init # model init
model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout) model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout)
if init_from == 'scratch': if init_from == 'scratch':
# init a new model from scratch # init a new model from scratch
print("Initializing a new model from scratch")
gptconf = GPTConfig(**model_args) gptconf = GPTConfig(**model_args)
model = GPT(gptconf) model = GPT(gptconf)
elif init_from == 'resume': elif init_from == 'resume':
print(f"Resuming training from {out_dir}")
# resume training from a checkpoint. # resume training from a checkpoint.
# TODO: should we also resume iter_num and best_val_loss?
ckpt_path = os.path.join(out_dir, 'ckpt.pt') ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device) checkpoint = torch.load(ckpt_path, map_location=device)
checkpoint_model_args = checkpoint['model_args'] checkpoint_model_args = checkpoint['model_args']
@@ -135,7 +140,10 @@ elif init_from == 'resume':
gptconf = GPTConfig(**model_args) gptconf = GPTConfig(**model_args)
model = GPT(gptconf) model = GPT(gptconf)
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
iter_num = checkpoint['iter_num']
best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'): elif init_from.startswith('gpt2'):
print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
# initialize from OpenAI GPT-2 weights # initialize from OpenAI GPT-2 weights
model = GPT.from_pretrained(init_from) model = GPT.from_pretrained(init_from)
# crop down the model block size if desired # crop down the model block size if desired
@@ -191,8 +199,6 @@ if wandb_log and gpu_id == 0:
} }
# training loop # training loop
iter_num = 0
best_val_loss = 1e9
t0 = time.time() t0 = time.time()
while True: while True: