mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00

adding a lightweight configurator that may be a terrible mistake lol. also adding configs to evaluate the baseline GPT2 versions released by OpenAI on OWT. we have some ways to go to match those numbers atm

This commit is contained in:
Andrej Karpathy 2022-12-28 23:31:23 +00:00
parent c9fe00c0e9
commit 5d2b4807bf
6 changed files with 96 additions and 2 deletions
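The "lightweight configurator" mentioned in the commit message works by exec-ing a config file from `config/` and then overriding module-level globals from `--key=value` arguments (see the train.py hunk below). The following is a minimal, self-contained sketch of that pattern, not the committed code; the defaults `batch_size` and `eval_only` here are illustrative stand-ins:

```
import os
import sys
from ast import literal_eval

# illustrative defaults; in train.py these are the module-level config globals
batch_size = 12
eval_only = False

for arg in sys.argv[1:]:
    if '=' not in arg:
        # bare argument: treat it as the name of a config file under config/
        config_file = os.path.join('config', arg + '.py')
        with open(config_file) as f:
            exec(f.read())  # e.g. config/eval_gpt2.py simply assigns new values
    else:
        # --key=value argument: override an existing global, keeping its type
        key, val = arg[2:].split('=')
        try:
            val = literal_eval(val)  # "8" -> 8, "True" -> True, ...
        except (SyntaxError, ValueError):
            pass                     # not a Python literal, keep it as a string
        assert key in globals() and type(val) == type(globals()[key])
        globals()[key] = val

print(f"batch_size={batch_size}, eval_only={eval_only}")
```

Running e.g. `python train.py eval_gpt2 --batch_size=16` therefore first applies `config/eval_gpt2.py` and then bumps `batch_size` to 16, exactly the order described in the configurator comment below.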

README.md
View File

@@ -34,3 +34,26 @@ Once some checkpoints are written to the output directory `out`, we're ready to
$ python sample.py
```
Training on 1 GPU overnight currently gets loss ~3.74. Random chance at init is -ln(1/50257) = 10.82. Which brings us to baselines.
## baselines
OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext. We can get the numbers as follows:
```
$ python train.py eval_gpt2
$ python train.py eval_gpt2_medium
$ python train.py eval_gpt2_large
$ python train.py eval_gpt2_xl
```
and observe the following losses on train and val:
| model | params | train loss | val loss |
| ------| ------ | ---------- | -------- |
| gpt2 | 124M | 3.11 | 3.12 |
| gpt2-medium | 350M | 2.85 | 2.84 |
| gpt2-large | 774M | 2.66 | 2.67 |
| gpt2-xl | 1558M | 2.56 | 2.54 |
I briefly tried finetuning gpt2 a bit more on our OWT and didn't notice dramatic improvements, suggesting that OWT is not much different from WT in terms of the data distribution, but this needs a more thorough attempt once the code is in a better place.
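As a quick check of the random-chance figure quoted above (a uniform guess over GPT-2's 50257-token vocabulary), the cross-entropy is just -ln(1/50257):

```
import math

# cross-entropy of a uniform prediction over the 50257-token GPT-2 vocabulary
print(-math.log(1 / 50257))  # ~10.82
```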

config/eval_gpt2.py Normal file
View File

@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=12, n_head=12, n_embd=768
# 124M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get a good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2'

config/eval_gpt2_large.py Normal file
View File

@@ -0,0 +1,8 @@
# evaluate the base gpt2-large
# n_layer=36, n_head=20, n_embd=1280
# 774M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get a good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-large'

config/eval_gpt2_medium.py Normal file
View File

@@ -0,0 +1,8 @@
# evaluate the base gpt2-medium
# n_layer=24, n_head=16, n_embd=1024
# 350M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get a good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-medium'

config/eval_gpt2_xl.py Normal file
View File

@@ -0,0 +1,8 @@
# evaluate the base gpt2-xl
# n_layer=48, n_head=25, n_embd=1600
# 1558M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get a good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-xl'
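With these settings, each evaluation averages the loss over `eval_iters` random batches per split. Assuming the standard GPT-2 block size of 1024 (not set in these configs, so this is only an estimate), that works out to roughly 4M tokens per split:

```
# rough token count seen by one eval pass per split,
# assuming block_size = 1024 (GPT-2's context length; not specified in the config)
eval_iters, batch_size, block_size = 500, 8, 1024
print(eval_iters * batch_size * block_size)  # 4,096,000 tokens
```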

train.py
View File

@@ -4,20 +4,25 @@ The text is assumed to pre-tokenized and inside files train.pt and val.pt
"""
import os
import sys
import time
import math
from ast import literal_eval
import numpy as np
import torch
import wandb
from model import GPTConfig, GPT
# -----------------------------------------------------------------------------
# default config values
# I/O
out_dir = 'out'
eval_interval = 500
log_interval = 1
eval_iters = 50
eval_only = False # if True, script exits right after the first eval
# wandb logging
wandb_log = False # disabled by default
wandb_entity = 'karpathy'
@@ -45,6 +50,38 @@ warmup_iters = 2000 # how many steps to warm up for
lr_decay_iters = 320000 # how many steps to decay the learning rate for
min_lr = 1e-5 # minimum learning rate
# -----------------------------------------------------------------------------
# poor man's Configurator. Potentially a bad idea. Example usage:
# python train.py override_file --batch_size=32
# this will first run config/override_file.py, then override batch_size to 32
for arg in sys.argv[1:]:
if '=' not in arg:
# assume it's the name of a config file
assert not arg.startswith('--')
config_file = os.path.join('config', arg + '.py')
print(f"Overriding config with {config_file}:")
with open(config_file) as f:
print(f.read())
exec(open(config_file).read())
else:
# assume it's a --key=value argument
assert arg.startswith('--')
key, val = arg.split('=')
key = key[2:]
if key in globals():
try:
# attempt to eval it (e.g. if it's a bool, a number, etc.)
attempt = literal_eval(val)
except (SyntaxError, ValueError):
# if that goes wrong, just use the string
attempt = val
# ensure the types match ok
assert type(attempt) == type(globals()[key])
# cross fingers
print(f"Overriding: {key} = {attempt}")
globals()[key] = attempt
else:
raise ValueError(f"Unknown config key: {key}")
# -----------------------------------------------------------------------------
os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337)
@@ -88,7 +125,7 @@ elif init_from.startswith('gpt2'):
model.to(device)
@torch.no_grad()
def estimate_loss():
out = {}
model.eval()
for split in ['train', 'val']:
@@ -166,6 +203,8 @@ while True:
'iter_num': iter_num,
}
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
if iter_num == 0 and eval_only:
break
X, Y = get_batch('train')
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
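For context on the two changes above — `estimate_loss()` now reads the global `eval_iters` set by the config instead of taking it as a parameter, and `eval_only` breaks out of the training loop right after the first evaluation at `iter_num == 0` — here is a rough sketch of how such an eval helper is typically structured. It assumes `model(X, Y)` returns a `(logits, loss)` pair and that `get_batch(split)` yields input/target tensors; this is a reconstruction, not the file's exact body:

```
@torch.no_grad()
def estimate_loss():
    # average the loss over eval_iters batches for each split,
    # using the module-level eval_iters set by the config/configurator
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
```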