mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
adding a lightweight configurator that may be a terrible mistake lol. also adding configs to evaluate the baseline GPT2 versions released by OpenAI on OWT. we have some ways to go to match those numbers atm
This commit is contained in:
parent
c9fe00c0e9
commit
5d2b4807bf
23
README.md
23
README.md
@ -34,3 +34,26 @@ Once some checkpoints are written to the output directory `out`, we're ready to
|
||||
$ python sample.py
|
||||
```
|
||||
|
||||
Training on 1 GPU overnight currently gets loss ~3.74. Random chance at init is -ln(1/50257) = 10.82. Which brings us to baselines.
|
||||
|
||||
## baselines
|
||||
|
||||
OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext. We can get the numbers as follows:
|
||||
|
||||
```
|
||||
$ python train.py eval_gpt2
|
||||
$ python train.py eval_gpt2_medium
|
||||
$ python train.py eval_gpt2_large
|
||||
$ python train.py eval_gpt2_xl
|
||||
```
|
||||
|
||||
and observe the following losses on train and val:
|
||||
|
||||
| model | params | train loss | val loss |
|
||||
| ------| ------ | ---------- | -------- |
|
||||
| gpt2 | 124M | 3.11 | 3.12 |
|
||||
| gpt2-medium | 350M | 2.85 | 2.84 |
|
||||
| gpt2-large | 774M | 2.66 | 2.67 |
|
||||
| gpt2-xl | 1558M | 2.56 | 2.54 |
|
||||
|
||||
I briefly tried finetuning gpt2 a bit more on our OWT and didn't notice dramatic improvements, suggesting that OWT is not much much different from WT in terms of the data distribution, but this needs a bit more thorough attempt once the code is in a better place.
|
||||
|
8
config/eval_gpt2.py
Normal file
8
config/eval_gpt2.py
Normal file
@ -0,0 +1,8 @@
|
||||
# evaluate the base gpt2
|
||||
# n_layer=12, n_head=12, n_embd=768
|
||||
# 124M parameters
|
||||
batch_size = 8
|
||||
eval_iters = 500 # use more iterations to get good estimate
|
||||
eval_only = True
|
||||
wandb_log = False
|
||||
init_from = 'gpt2'
|
8
config/eval_gpt2_large.py
Normal file
8
config/eval_gpt2_large.py
Normal file
@ -0,0 +1,8 @@
|
||||
# evaluate the base gpt2
|
||||
# n_layer=36, n_head=20, n_embd=1280
|
||||
# 774M parameters
|
||||
batch_size = 8
|
||||
eval_iters = 500 # use more iterations to get good estimate
|
||||
eval_only = True
|
||||
wandb_log = False
|
||||
init_from = 'gpt2-large'
|
8
config/eval_gpt2_medium.py
Normal file
8
config/eval_gpt2_medium.py
Normal file
@ -0,0 +1,8 @@
|
||||
# evaluate the base gpt2
|
||||
# n_layer=24, n_head=16, n_embd=1024
|
||||
# 350M parameters
|
||||
batch_size = 8
|
||||
eval_iters = 500 # use more iterations to get good estimate
|
||||
eval_only = True
|
||||
wandb_log = False
|
||||
init_from = 'gpt2-medium'
|
8
config/eval_gpt2_xl.py
Normal file
8
config/eval_gpt2_xl.py
Normal file
@ -0,0 +1,8 @@
|
||||
# evaluate the base gpt2
|
||||
# n_layer=48, n_head=25, n_embd=1600
|
||||
# 1558M parameters
|
||||
batch_size = 8
|
||||
eval_iters = 500 # use more iterations to get good estimate
|
||||
eval_only = True
|
||||
wandb_log = False
|
||||
init_from = 'gpt2-xl'
|
43
train.py
43
train.py
@ -4,20 +4,25 @@ The text is assumed to pre-tokenized and inside files train.pt and val.pt
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import math
|
||||
from ast import literal_eval
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
from model import GPTConfig, GPT
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# settings, todo argparse or something
|
||||
# default config values
|
||||
# I/O
|
||||
out_dir = 'out'
|
||||
eval_interval = 500
|
||||
log_interval = 1
|
||||
eval_iters = 50
|
||||
eval_only = False # if True, script exits right after the first eval
|
||||
# wandb logging
|
||||
wandb_log = False # disabled by default
|
||||
wandb_entity = 'karpathy'
|
||||
@ -45,6 +50,38 @@ warmup_iters = 2000 # how many steps to warm up for
|
||||
lr_decay_iters = 320000 # how many steps to decay the learning rate for
|
||||
min_lr = 1e-5 # minimum learning rate
|
||||
# -----------------------------------------------------------------------------
|
||||
# poor man's Configurator. Potentially a bad idea. Example usage:
|
||||
# python train.py override_file --batch_size=32
|
||||
# this will first run config/override_file.py, then override batch_size to 32
|
||||
for arg in sys.argv[1:]:
|
||||
if '=' not in arg:
|
||||
# assume it's the name of a config file
|
||||
assert not arg.startswith('--')
|
||||
config_file = os.path.join('config', arg + '.py')
|
||||
print(f"Overriding config with {config_file}:")
|
||||
with open(config_file) as f:
|
||||
print(f.read())
|
||||
exec(open(config_file).read())
|
||||
else:
|
||||
# assume it's a --key=value argument
|
||||
assert arg.startswith('--')
|
||||
key, val = arg.split('=')
|
||||
key = key[2:]
|
||||
if key in globals():
|
||||
try:
|
||||
# attempt to eval it it (e.g. if bool, number, or etc)
|
||||
attempt = literal_eval(val)
|
||||
except SyntaxError:
|
||||
# if that goes wrong, just use the string
|
||||
attempt = val
|
||||
# ensure the types match ok
|
||||
assert type(attempt) == type(globals()[key])
|
||||
# cross fingers
|
||||
print(f"Overriding: {key} = {attempt}")
|
||||
globals()[key] = attempt
|
||||
else:
|
||||
raise ValueError(f"Unknown config key: {key}")
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
torch.manual_seed(1337)
|
||||
@ -88,7 +125,7 @@ elif init_from.startswith('gpt2'):
|
||||
model.to(device)
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_loss(eval_iters=50):
|
||||
def estimate_loss():
|
||||
out = {}
|
||||
model.eval()
|
||||
for split in ['train', 'val']:
|
||||
@ -166,6 +203,8 @@ while True:
|
||||
'iter_num': iter_num,
|
||||
}
|
||||
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
|
||||
if iter_num == 0 and eval_only:
|
||||
break
|
||||
|
||||
X, Y = get_batch('train')
|
||||
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
|
Loading…
Reference in New Issue
Block a user