diff --git a/README.md b/README.md
index 95b3fd3..b017587 100644
--- a/README.md
+++ b/README.md
@@ -34,3 +34,26 @@ Once some checkpoints are written to the output directory `out`, we're ready to
 $ python sample.py
 ```

+Training on 1 GPU overnight currently gets a loss of ~3.74. Random chance at init is -ln(1/50257) ≈ 10.82, which brings us to baselines.
+
+## baselines
+
+OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext. We can get the numbers as follows:
+
+```
+$ python train.py eval_gpt2
+$ python train.py eval_gpt2_medium
+$ python train.py eval_gpt2_large
+$ python train.py eval_gpt2_xl
+```
+
+and observe the following losses on train and val:
+
+| model | params | train loss | val loss |
+| ----------- | ------ | ---------- | -------- |
+| gpt2 | 124M | 3.11 | 3.12 |
+| gpt2-medium | 350M | 2.85 | 2.84 |
+| gpt2-large | 774M | 2.66 | 2.67 |
+| gpt2-xl | 1558M | 2.56 | 2.54 |
+
+I briefly tried finetuning gpt2 a bit more on our OWT and didn't notice dramatic improvements, suggesting that OWT (OpenWebText) is not much different from WT (WebText) in terms of data distribution, but this deserves a more thorough attempt once the code is in a better place.
diff --git a/config/eval_gpt2.py b/config/eval_gpt2.py
new file mode 100644
index 0000000..53978cb
--- /dev/null
+++ b/config/eval_gpt2.py
@@ -0,0 +1,8 @@
+# evaluate the base gpt2
+# n_layer=12, n_head=12, n_embd=768
+# 124M parameters
+batch_size = 8
+eval_iters = 500 # use more iterations to get a good estimate
+eval_only = True
+wandb_log = False
+init_from = 'gpt2'
diff --git a/config/eval_gpt2_large.py b/config/eval_gpt2_large.py
new file mode 100644
index 0000000..4cbeaef
--- /dev/null
+++ b/config/eval_gpt2_large.py
@@ -0,0 +1,8 @@
+# evaluate the base gpt2-large
+# n_layer=36, n_head=20, n_embd=1280
+# 774M parameters
+batch_size = 8
+eval_iters = 500 # use more iterations to get a good estimate
+eval_only = True
+wandb_log = False
+init_from = 'gpt2-large'
diff --git a/config/eval_gpt2_medium.py b/config/eval_gpt2_medium.py
new file mode 100644
index 0000000..9d0db11
--- /dev/null
+++ b/config/eval_gpt2_medium.py
@@ -0,0 +1,8 @@
+# evaluate the base gpt2-medium
+# n_layer=24, n_head=16, n_embd=1024
+# 350M parameters
+batch_size = 8
+eval_iters = 500 # use more iterations to get a good estimate
+eval_only = True
+wandb_log = False
+init_from = 'gpt2-medium'
diff --git a/config/eval_gpt2_xl.py b/config/eval_gpt2_xl.py
new file mode 100644
index 0000000..1bae34f
--- /dev/null
+++ b/config/eval_gpt2_xl.py
@@ -0,0 +1,8 @@
+# evaluate the base gpt2-xl
+# n_layer=48, n_head=25, n_embd=1600
+# 1558M parameters
+batch_size = 8
+eval_iters = 500 # use more iterations to get a good estimate
+eval_only = True
+wandb_log = False
+init_from = 'gpt2-xl'
diff --git a/train.py b/train.py
index 494a80c..7162bb3 100644
--- a/train.py
+++ b/train.py
@@ -4,20 +4,25 @@ The text is assumed to pre-tokenized and inside files train.pt and val.pt
 """

 import os
+import sys
 import time
 import math
+from ast import literal_eval

 import numpy as np
 import torch
 import wandb

 from model import GPTConfig, GPT
+
 # -----------------------------------------------------------------------------
-# settings, todo argparse or something
+# default config values
 # I/O
 out_dir = 'out'
 eval_interval = 500
 log_interval = 1
+eval_iters = 50
+eval_only = False # if True, script exits right after the first eval
 # wandb logging
 wandb_log = False # disabled by default
 wandb_entity = 'karpathy'
@@ -45,6 +50,38 @@ warmup_iters = 2000 # how many steps to warm up for
 lr_decay_iters = 320000 # how many steps to decay the learning rate for
 min_lr = 1e-5 # minimum learning rate
 # -----------------------------------------------------------------------------
+# poor man's Configurator. Potentially a bad idea. Example usage:
+# python train.py override_file --batch_size=32
+# this will first run config/override_file.py, then override batch_size to 32
+for arg in sys.argv[1:]:
+    if '=' not in arg:
+        # assume it's the name of a config file
+        assert not arg.startswith('--')
+        config_file = os.path.join('config', arg + '.py')
+        print(f"Overriding config with {config_file}:")
+        with open(config_file) as f:
+            print(f.read())
+        exec(open(config_file).read())
+    else:
+        # assume it's a --key=value argument
+        assert arg.startswith('--')
+        key, val = arg.split('=', 1)
+        key = key[2:]
+        if key in globals():
+            try:
+                # attempt to eval it (e.g. if it's a bool or a number)
+                attempt = literal_eval(val)
+            except (SyntaxError, ValueError):
+                # if that goes wrong, just use the string as-is
+                attempt = val
+            # ensure the types match
+            assert type(attempt) == type(globals()[key])
+            # cross fingers
+            print(f"Overriding: {key} = {attempt}")
+            globals()[key] = attempt
+        else:
+            raise ValueError(f"Unknown config key: {key}")
+# -----------------------------------------------------------------------------
 os.makedirs(out_dir, exist_ok=True)

 torch.manual_seed(1337)
@@ -88,7 +125,7 @@ elif init_from.startswith('gpt2'):
 model.to(device)

 @torch.no_grad()
-def estimate_loss(eval_iters=50):
+def estimate_loss():
     out = {}
     model.eval()
     for split in ['train', 'val']:
@@ -166,6 +203,8 @@ while True:
             'iter_num': iter_num,
         }
         torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
+        if iter_num == 0 and eval_only:
+            break

     X, Y = get_batch('train')
     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
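A quick check of the random-chance figure quoted in the README hunk above: an untrained model's initial loss should be the cross-entropy of a uniform guess over GPT-2's 50257-token vocabulary. A minimal verification in Python:

```
import math

# cross-entropy (in nats) of a uniform distribution over the GPT-2 vocab;
# an untrained model should start near this loss
vocab_size = 50257
print(-math.log(1 / vocab_size))  # ~10.82
```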
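The parameter counts noted in the four eval configs ("124M parameters" and so on) can likewise be sanity-checked with a back-of-envelope estimate. The sketch below is illustrative only, not code from this repo; it assumes GPT-2's weight-tied token embeddings and 1024-token block size, and ignores biases and LayerNorm parameters:

```
# per transformer block: 4*d^2 for the attention projections (q, k, v, out)
# plus 8*d^2 for the 4x-wide MLP, i.e. ~12*d^2 total, biases ignored
def approx_params(n_layer, n_embd, vocab_size=50257, block_size=1024):
    wte = vocab_size * n_embd  # token embeddings (tied with the LM head)
    wpe = block_size * n_embd  # learned position embeddings
    return wte + wpe + n_layer * 12 * n_embd * n_embd

for name, n_layer, n_embd in [('gpt2', 12, 768), ('gpt2-medium', 24, 1024),
                              ('gpt2-large', 36, 1280), ('gpt2-xl', 48, 1600)]:
    print(f"{name}: ~{approx_params(n_layer, n_embd) / 1e6:.0f}M")
# prints ~124M, ~355M, ~773M, ~1557M, close to the quoted 124M/350M/774M/1558M
```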
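The configurator added to train.py also composes config files with command-line overrides. A hypothetical invocation (the flags here are just example overrides of existing config keys, not commands from the diff):

```
$ python train.py eval_gpt2 --batch_size=16 --eval_iters=250
```

This first exec()s config/eval_gpt2.py, then applies each --key=value override: the value round-trips through ast.literal_eval so 16 stays an int, and the type assert rejects overrides whose type doesn't match the default. Since eval_gpt2.py sets eval_only = True, the loop breaks right after the first estimate_loss() evaluation.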