mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2025-10-27 05:17:41 +00:00
Candidate changes to the APIs; these still need to be thought through more
This commit is contained in:
19
train.py
19
train.py
@@ -31,6 +31,7 @@ eval_interval = 500
|
||||
log_interval = 1
|
||||
eval_iters = 50
|
||||
eval_only = False # if True, script exits right after the first eval
|
||||
always_save_checkpoint = False # if True, always save a checkpoint after each eval
|
||||
# wandb logging
|
||||
wandb_log = False # disabled by default
|
||||
wandb_entity = 'karpathy'
|
||||
@@ -138,6 +139,7 @@ elif init_from == 'resume':
|
||||
checkpoint_model_args = checkpoint['model_args']
|
||||
for k, v in model_args.items():
|
||||
assert checkpoint_model_args[k] == v, "for now"
|
||||
# TODO: think through how passed in params should interact with checkpoint params
|
||||
gptconf = GPTConfig(**model_args)
|
||||
model = GPT(gptconf)
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
@@ -146,9 +148,14 @@ elif init_from == 'resume':
|
||||
elif init_from.startswith('gpt2'):
|
||||
print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
|
||||
# initialize from OpenAI GPT-2 weights
|
||||
model = GPT.from_pretrained(init_from)
|
||||
override_args = dict(dropout=dropout)
|
||||
model = GPT.from_pretrained(init_from, override_args)
|
||||
# read off and override the GPT sizing model args from the model config
|
||||
model_args['n_layer'] = model.config.n_layer
|
||||
model_args['n_head'] = model.config.n_head
|
||||
model_args['n_embd'] = model.config.n_embd
|
||||
# crop down the model block size if desired
|
||||
if block_size < model.block_size:
|
||||
if block_size < model.config.block_size:
|
||||
model.crop_block_size(block_size)
|
||||
model.to(device)
|
||||
|
||||
@@ -227,7 +234,7 @@ while True:
|
||||
"val/loss": losses['val'],
|
||||
"lr": lr,
|
||||
})
|
||||
if losses['val'] < best_val_loss:
|
||||
if losses['val'] < best_val_loss or always_save_checkpoint:
|
||||
best_val_loss = losses['val']
|
||||
raw_model = model.module if ddp else model
|
||||
if iter_num > 0:
|
||||
@@ -238,6 +245,7 @@ while True:
|
||||
'iter_num': iter_num,
|
||||
'best_val_loss': best_val_loss,
|
||||
}
|
||||
print(f"saving checkpoint to {out_dir}")
|
||||
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
|
||||
if iter_num == 0 and eval_only:
|
||||
break
|
||||
@@ -260,7 +268,8 @@ while True:
|
||||
iter_num += 1
|
||||
|
||||
# termination conditions
|
||||
if iter_num >= max_iters:
|
||||
if iter_num > max_iters:
|
||||
break
|
||||
|
||||
destroy_process_group()
|
||||
if ddp:
|
||||
destroy_process_group()
|
||||
|
||||
Reference in New Issue
Block a user