
candidate changes to apis, have to think through more

Andrej Karpathy
2023-01-01 01:29:48 +00:00
parent 7c6ea8409e
commit 2febf4463c
8 changed files with 111 additions and 19 deletions


@@ -31,6 +31,7 @@ eval_interval = 500
 log_interval = 1
 eval_iters = 50
 eval_only = False # if True, script exits right after the first eval
+always_save_checkpoint = False # if True, always save a checkpoint after each eval
 # wandb logging
 wandb_log = False # disabled by default
 wandb_entity = 'karpathy'
@@ -138,6 +139,7 @@ elif init_from == 'resume':
     checkpoint_model_args = checkpoint['model_args']
     for k, v in model_args.items():
         assert checkpoint_model_args[k] == v, "for now"
+        # TODO: think through how passed in params should interact with checkpoint params
     gptconf = GPTConfig(**model_args)
     model = GPT(gptconf)
     model.load_state_dict(checkpoint['model'])
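The TODO marks an open question rather than a fix: on resume, every passed-in model_args value must currently equal the checkpoint's value or the assert fires. One possible direction, shown only as a hedged sketch that this commit does not implement, is to let the checkpoint dictate the structural args while keeping the caller's values for everything else:

# Hedged sketch of one way the resume-time interaction could be resolved;
# not part of this commit, which only adds the TODO.
def merge_resume_args(model_args, checkpoint_model_args,
                      structural=('n_layer', 'n_head', 'n_embd', 'block_size')):
    merged = dict(model_args)
    for k in structural:
        # structural sizes must match the saved weights, so take them from the checkpoint
        merged[k] = checkpoint_model_args[k]
    return merged

# usage with hypothetical values: dropout stays as passed in, sizes come from the checkpoint
ckpt_args = dict(n_layer=12, n_head=12, n_embd=768, block_size=1024, dropout=0.0)
cli_args  = dict(n_layer=12, n_head=12, n_embd=768, block_size=512, dropout=0.1)
print(merge_resume_args(cli_args, ckpt_args))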
@@ -146,9 +148,14 @@ elif init_from == 'resume':
 elif init_from.startswith('gpt2'):
     print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
     # initialize from OpenAI GPT-2 weights
-    model = GPT.from_pretrained(init_from)
+    override_args = dict(dropout=dropout)
+    model = GPT.from_pretrained(init_from, override_args)
+    # read off and override the GPT sizing model args from the model config
+    model_args['n_layer'] = model.config.n_layer
+    model_args['n_head'] = model.config.n_head
+    model_args['n_embd'] = model.config.n_embd
 # crop down the model block size if desired
-if block_size < model.block_size:
+if block_size < model.config.block_size:
     model.crop_block_size(block_size)
 model.to(device)
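The extra override_args parameter to GPT.from_pretrained is the API change the commit message is tentative about. The matching model.py hunk is not shown in this section, so the following is only a minimal sketch, under the assumption that override_args is a plain dict and that only non-structural knobs such as dropout are meant to pass through on top of the fixed GPT-2 sizing:

# Sketch only (assumed behavior, not the actual model.py diff): merge caller
# overrides with the fixed sizing args of the published GPT-2 checkpoints.
def build_pretrained_config(model_type, override_args=None):
    override_args = override_args or {}
    # sizing args are dictated by the pretrained weights and cannot be overridden
    config_args = {
        'gpt2':        dict(n_layer=12, n_head=12, n_embd=768),   # 124M params
        'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024),  # 350M params
        'gpt2-large':  dict(n_layer=36, n_head=20, n_embd=1280),  # 774M params
        'gpt2-xl':     dict(n_layer=48, n_head=25, n_embd=1600),  # 1558M params
    }[model_type]
    assert all(k == 'dropout' for k in override_args), "only dropout override assumed here"
    config_args.update(override_args)
    return config_args

# usage mirroring the hunk above, where override_args = dict(dropout=dropout)
print(build_pretrained_config('gpt2', dict(dropout=0.1)))

Reading n_layer/n_head/n_embd back out of model.config right after loading keeps model_args, and therefore the metadata saved into checkpoints, consistent with the weights that were actually loaded.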
@@ -227,7 +234,7 @@ while True:
                 "val/loss": losses['val'],
                 "lr": lr,
             })
-        if losses['val'] < best_val_loss:
+        if losses['val'] < best_val_loss or always_save_checkpoint:
             best_val_loss = losses['val']
             raw_model = model.module if ddp else model
             if iter_num > 0:
@@ -238,6 +245,7 @@ while True:
                     'iter_num': iter_num,
                     'best_val_loss': best_val_loss,
                 }
+                print(f"saving checkpoint to {out_dir}")
                 torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
     if iter_num == 0 and eval_only:
         break
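Combined with the always_save_checkpoint flag added at the top of the config, the checkpoint branch now runs after every eval when the flag is set, not only when val loss improves. A small self-contained sketch of the decision, with hypothetical loss values:

# Self-contained sketch of the new save condition from the two hunks above.
def should_save(val_loss, best_val_loss, always_save_checkpoint):
    # save on improvement, or on every eval when the flag forces it
    return val_loss < best_val_loss or always_save_checkpoint

print(should_save(1.52, 1.48, always_save_checkpoint=False))  # False: val loss got worse
print(should_save(1.52, 1.48, always_save_checkpoint=True))   # True: flag forces a save
print(should_save(1.41, 1.48, always_save_checkpoint=False))  # True: val loss improved

Note that in the hunk best_val_loss is reassigned inside the same branch, so with the flag enabled it tracks the most recent eval rather than the minimum seen so far.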
@@ -260,7 +268,8 @@ while True:
     iter_num += 1
     # termination conditions
-    if iter_num >= max_iters:
+    if iter_num > max_iters:
         break
-destroy_process_group()
+if ddp:
+    destroy_process_group()
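Two behavioral notes on this last hunk: switching >= to > lets the iteration numbered max_iters itself run, so an eval or checkpoint that lands exactly at max_iters is presumably no longer skipped; and guarding destroy_process_group() makes teardown symmetric with setup, since the process group only exists when the run was launched under DDP and destroying it unconditionally would raise in a single-GPU run. A minimal sketch of that guarded pattern, assuming the usual torch.distributed calls and a RANK-based launch check (not this script's exact setup code):

# Sketch of the guarded setup/teardown pattern; the RANK check and backend
# choice are assumptions, not lines from this commit.
import os
import torch.distributed as dist

ddp = int(os.environ.get('RANK', -1)) != -1  # torchrun sets RANK; a plain `python train.py` does not
if ddp:
    dist.init_process_group(backend='nccl')

# ... training loop runs here ...

if ddp:
    dist.destroy_process_group()  # only tear down what was actually set up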