Mirror of https://github.com/osmarks/nanogpt-experiments.git (synced 2024-12-18 14:10:28 +00:00)
properly resume training, also loading iter_num and best_val_loss from checkpoints
commit 682a0ac8f1
parent f88aa2c2fe
train.py (12 lines changed: 9 additions, 3 deletions)
@@ -118,15 +118,20 @@ def get_batch(split):
     x, y = x.to(device), y.to(device)
     return x, y
 
+# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
+iter_num = 0
+best_val_loss = 1e9
+
 # model init
 model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout)
 if init_from == 'scratch':
     # init a new model from scratch
+    print("Initializing a new model from scratch")
     gptconf = GPTConfig(**model_args)
     model = GPT(gptconf)
 elif init_from == 'resume':
+    print(f"Resuming training from {out_dir}")
     # resume training from a checkpoint.
-    # TODO: should we also resume iter_num and best_val_loss?
     ckpt_path = os.path.join(out_dir, 'ckpt.pt')
     checkpoint = torch.load(ckpt_path, map_location=device)
     checkpoint_model_args = checkpoint['model_args']
@@ -135,7 +140,10 @@ elif init_from == 'resume':
     gptconf = GPTConfig(**model_args)
     model = GPT(gptconf)
     model.load_state_dict(checkpoint['model'])
+    iter_num = checkpoint['iter_num']
+    best_val_loss = checkpoint['best_val_loss']
 elif init_from.startswith('gpt2'):
+    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
     # initialize from OpenAI GPT-2 weights
     model = GPT.from_pretrained(init_from)
     # crop down the model block size if desired
@@ -191,8 +199,6 @@ if wandb_log and gpu_id == 0:
     }
 
 # training loop
-iter_num = 0
-best_val_loss = 1e9
 t0 = time.time()
 while True:
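For this resume path to work, the checkpoint written during training must already carry the keys the load side reads back. Below is a minimal sketch of the matching save side, in the spirit of nanoGPT's checkpointing; the exact dict layout is an assumption here, since the save code is not part of this diff:

import os
import torch

# hypothetical save-side counterpart (key names assumed to match the load side above)
checkpoint = {
    'model': model.state_dict(),        # read back via model.load_state_dict(...)
    'model_args': model_args,           # lets resume rebuild the same GPTConfig
    'iter_num': iter_num,               # new: training continues counting from here
    'best_val_loss': best_val_loss,     # new: a resumed run won't clobber a better checkpoint
}
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))

This is also why the third hunk deletes iter_num = 0 and best_val_loss = 1e9 just above the training loop: with the values now set once at the top (and overridden from the checkpoint on resume), re-zeroing them there would make a resumed run restart its iteration count, and any iter_num-based schedules, from scratch. Resuming is then just a matter of launching train.py with init_from = 'resume' and the same out_dir.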