diff --git a/c4gzparse_nanogpt.py b/c4gzparse_nanogpt.py
new file mode 100644
index 0000000..e89a0c9
--- /dev/null
+++ b/c4gzparse_nanogpt.py
@@ -0,0 +1,56 @@
+import os
+from tqdm import tqdm
+import numpy as np
+import tiktoken
+import json
+import gzip
+
+enc = tiktoken.get_encoding("gpt2")
+
+if __name__ == '__main__':
+    # stats in the comments here are from nanoGPT's openwebtext prepare.py (54GB in huggingface .cache dir, 8,013,769 documents); this script reads a single C4 shard instead
+    dataset = []
+
+    with gzip.open("c4-train.00000-of-01024.json.gz", "r") as file:
+        while line := file.readline():
+            try:
+                dataset.append(json.loads(line))
+            except json.JSONDecodeError: # json.loads raises JSONDecodeError (not EOFError) on a bad line
+                pass
+
+    # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
+    def process(example):
+        ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
+        ids.insert(0, enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
+        # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
+        out = {"ids": ids, "len": len(ids)}
+        return out
+
+    # tokenize the dataset
+    tokenized = [ process(x) for x in dataset ]
+    divider = len(tokenized) // 100
+    tokenized = {
+        "val": tokenized[:divider],
+        "train": tokenized[divider:]
+    }
+
+    # concatenate all the ids in each dataset into one large file we can use for training
+    for split, dset in tokenized.items():
+        arr_len = sum((d['len'] for d in dset))
+        filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
+        dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
+        arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
+        total_batches = 1024
+
+        idx = 0
+        for d in tqdm(dset, desc=f'writing {filename}'):
+            arr[idx : idx + d["len"]] = d["ids"]
+            idx += d["len"]
+        arr.flush()
+
+    # for reference, nanoGPT's full openwebtext run gives: train.bin ~17GB, val.bin ~8.5MB
+    # train has ~9B tokens (9,035,582,198)
+    # val has ~4M tokens (4,434,897)
+
+    # to read the bin files later, e.g. with numpy:
+    # m = np.memmap('train.bin', dtype=np.uint16, mode='r')
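For reference, a minimal sketch (not part of the patch) of reading one of the resulting .bin files back and decoding a few tokens, assuming train.bin sits in the working directory:

    import numpy as np
    import tiktoken

    enc = tiktoken.get_encoding("gpt2")
    m = np.memmap('train.bin', dtype=np.uint16, mode='r')  # one long stream of gpt2 token ids
    print(len(m))                                          # total token count in the split
    print(enc.decode(m[:64].tolist()))                     # first 64 tokens back as text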
diff --git a/compare_ckpts.py b/compare_ckpts.py
new file mode 100644
index 0000000..12c555a
--- /dev/null
+++ b/compare_ckpts.py
@@ -0,0 +1,69 @@
+import torch
+from pathlib import Path
+from collections import defaultdict
+import matplotlib.pyplot as plt
+import torch.nn.functional as F
+
+def compute_differences(m1, m2):
+    groups = {"mlp.c_fc.weight": defaultdict(lambda: 0.0), "attn.c_attn.weight": defaultdict(lambda: 0.0)}
+    for k, v1 in m1["model"].items():
+        for cat in groups.keys():
+            if cat in k:
+                diff = torch.flatten(v1 - m2["model"][k])
+                #groups[cat]["l1"] += torch.linalg.norm(diff, dim=None, ord=1).item()
+                groups[cat]["l2"] += torch.linalg.norm(diff, dim=None, ord=2).item()
+                #groups[cat]["cosine"] += F.cosine_similarity(v1.flatten(), m2["model"][k].flatten(), dim=-1).item()
+    return groups
+
+def gradnorm(m):
+    x = 0
+    for key, state in m["optimizer"]["state"].items():
+        #x += torch.linalg.norm(state["exp_avg"], dim=None, ord=2).item()
+        x += torch.mean(state["exp_avg_sq"]).item()
+    return x
+
+def flatten(xs, out=None, prefix=""):
+    if out is None: out = {}
+    for k, v in xs.items():
+        longk = (prefix + " " + k).strip()
+        if isinstance(v, dict): flatten(v, out, longk)
+        else:
+            out[longk] = v
+    return out
+
+xs = []
+ys = defaultdict(list)
+for step in range(500, 3500, 500):
+    file = f"ckpt{step}.pt"
+    m_baseline = torch.load(Path("fixed-seed1") / file)
+    m_sameseed = torch.load(Path("fixed-seed1-1") / file)
+    m_sameseed2 = torch.load(Path("fixed-seed1-2") / file)
+    m_other = torch.load(Path("fixed-seed2") / file)
+    m_baseline_resumed = torch.load(Path("fixed-seed1-res1500") / file)
+    xs.append(step)
+
+    """
+    comparisons = {
+        "same seed": compute_differences(m_baseline, m_sameseed),
+        "same seed 2": compute_differences(m_baseline, m_sameseed2),
+        "same seed resume at 1500": compute_differences(m_baseline, m_baseline_resumed),
+        "other seed": compute_differences(m_baseline, m_other),
+    }
+    """
+    comparisons = {
+        "baseline": gradnorm(m_baseline),
+        "same seed": gradnorm(m_sameseed),
+        "same seed 2": gradnorm(m_sameseed2),
+        "other seed": gradnorm(m_other)
+    }
+
+    for k, v in flatten(comparisons).items():
+        ys[k].append(v)
+
+plt.figure(figsize=(12, 10))
+plt.xlabel("step")
+plt.ylabel("gradnorm")
+for k, v in ys.items():
+    plt.plot(xs, v, label=k)
+plt.legend()
+plt.savefig("x.png")
\ No newline at end of file
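compare_ckpts.py plots aggregate L2 and optimizer-moment statistics; the blunter question ("did two same-seed runs produce bit-identical weights?") can be asked directly. A minimal sketch, assuming two of the checkpoint directories used above and that a step-3000 checkpoint exists:

    import torch
    from pathlib import Path

    a = torch.load(Path("fixed-seed1") / "ckpt3000.pt", map_location="cpu")
    b = torch.load(Path("fixed-seed1-1") / "ckpt3000.pt", map_location="cpu")

    # checkpoints store the raw state_dict under the 'model' key (see train.py below)
    mismatched = [k for k, v in a["model"].items() if not torch.equal(v, b["model"][k])]
    print("bit-identical" if not mismatched else f"{len(mismatched)} tensors differ, e.g. {mismatched[0]}")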
diff --git a/config/train_gpt2.py b/config/train_gpt2.py
index 8f19273..6f57112 100644
--- a/config/train_gpt2.py
+++ b/config/train_gpt2.py
@@ -2,7 +2,7 @@
 # launch as the following (e.g. in a screen session) and wait ~5 days:
 # $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
 
-wandb_log = True
+wandb_log = False
 wandb_project = 'owt'
 wandb_run_name='gpt2-124M'
diff --git a/config/train_shakespeare_char.py b/config/train_shakespeare_char.py
index 41c81df..003792c 100644
--- a/config/train_shakespeare_char.py
+++ b/config/train_shakespeare_char.py
@@ -35,3 +35,4 @@ warmup_iters = 100 # not super necessary potentially
 # on macbook also add
 # device = 'cpu' # run on cpu only
 # compile = False # do not torch compile the model
+compile = False
diff --git a/fixed-cosine-plot.png b/fixed-cosine-plot.png
new file mode 100644
index 0000000..87ad195
Binary files /dev/null and b/fixed-cosine-plot.png differ
diff --git a/fixed-l1norm-plot.png b/fixed-l1norm-plot.png
new file mode 100644
index 0000000..1442c86
Binary files /dev/null and b/fixed-l1norm-plot.png differ
diff --git a/fixed-l2norm-plot.png b/fixed-l2norm-plot.png
new file mode 100644
index 0000000..d0b6a4d
Binary files /dev/null and b/fixed-l2norm-plot.png differ
diff --git a/grad-norms-l1-what.png b/grad-norms-l1-what.png
new file mode 100644
index 0000000..17b5f78
Binary files /dev/null and b/grad-norms-l1-what.png differ
diff --git a/grad-norms-mean-avgsq.png b/grad-norms-mean-avgsq.png
new file mode 100644
index 0000000..4ddffe5
Binary files /dev/null and b/grad-norms-mean-avgsq.png differ
diff --git a/grad-norms.png b/grad-norms.png
new file mode 100644
index 0000000..791b902
Binary files /dev/null and b/grad-norms.png differ
diff --git a/l1-new.png b/l1-new.png
new file mode 100644
index 0000000..53f14d1
Binary files /dev/null and b/l1-new.png differ
diff --git a/l1norm.png b/l1norm.png
new file mode 100644
index 0000000..53da383
Binary files /dev/null and b/l1norm.png differ
diff --git a/l2-new.png b/l2-new.png
new file mode 100644
index 0000000..29f14b9
Binary files /dev/null and b/l2-new.png differ
diff --git a/l2norm.png b/l2norm.png
new file mode 100644
index 0000000..7fd7bbb
Binary files /dev/null and b/l2norm.png differ
diff --git a/model.py b/model.py
index c698f8b..2e50172 100644
--- a/model.py
+++ b/model.py
@@ -298,7 +298,7 @@ class GPT(nn.Module):
         flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
         # express our flops throughput as ratio of A100 bfloat16 peak flops
         flops_achieved = flops_per_iter * (1.0/dt) # per second
-        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
+        flops_promised = 70e12 # RTX 3090 BF16 FP32 accumulate
         mfu = flops_achieved / flops_promised
         return mfu
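To put the new denominator in perspective: MFU is just achieved flops divided by the hardware's promised peak, so the same measured throughput now reports a roughly 4.5x higher MFU (312/70 ≈ 4.46). Illustrative numbers only:

    flops_achieved = 20e12          # hypothetical measured flops/sec
    print(flops_achieved / 312e12)  # ~0.064 -> 6.4% MFU against A100 BF16 peak
    print(flops_achieved / 70e12)   # ~0.286 -> 28.6% MFU against RTX 3090 BF16 (FP32 accumulate) peak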
diff --git a/train.py b/train.py
index 951bda9..2dc22eb 100644
--- a/train.py
+++ b/train.py
@@ -28,6 +28,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed import init_process_group, destroy_process_group
 
 from model import GPTConfig, GPT
+import random
+
+seed = 1
+torch.manual_seed(seed)
+random.seed(seed)
+np.random.seed(seed)
 
 # -----------------------------------------------------------------------------
 # default config values designed to train a gpt2 (124M) on OpenWebText
@@ -38,34 +44,49 @@ log_interval = 1
 eval_iters = 200
 eval_only = False # if True, script exits right after the first eval
 always_save_checkpoint = True # if True, always save a checkpoint after each eval
-init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
+init_from = 'resume' # 'scratch' or 'resume' or 'gpt2*'
 # wandb logging
-wandb_log = False # disabled by default
-wandb_project = 'owt'
-wandb_run_name = 'gpt2' # 'run' + str(time.time())
 # data
 dataset = 'openwebtext'
-gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
-batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
+
+wandb_log = False
+wandb_project = 'owt'
+wandb_run_name='gpt2'
+
+# these make the total batch size be ~0.5M
+# 8 batch size * 1024 block size * 64 gradaccum = 524,288
+batch_size = 8
+block_size = 1024
+gradient_accumulation_steps = 8 * 8
+
+# this makes the total number of tokens be ~1.6B (3000 iters * 524,288 tokens per iter)
+max_iters = 3000
+lr_decay_iters = 3000
+
+# eval stuff
+eval_interval = 500
+eval_iters = 200
+log_interval = 10
+
+# weight decay
+weight_decay = 1e-1
+
 block_size = 1024
 # model
-n_layer = 12
-n_head = 12
-n_embd = 768
+n_layer = 6
+n_head = 8
+n_embd = 512
 dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
 bias = False # do we use bias inside LayerNorm and Linear layers?
 # adamw optimizer
 learning_rate = 6e-4 # max learning rate
-max_iters = 600000 # total number of training iterations
-weight_decay = 1e-1
 beta1 = 0.9
 beta2 = 0.95
 grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
 # learning rate decay settings
 decay_lr = True # whether to decay the learning rate
-warmup_iters = 2000 # how many steps to warm up for
-lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
-min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
+warmup_iters = 500 # how many steps to warm up for
+min_lr = 6e-4 # set equal to learning_rate, i.e. effectively no decay after warmup (nanoGPT default is learning_rate/10 per Chinchilla)
 # DDP settings
 backend = 'nccl' # 'nccl', 'gloo', etc.
@@ -74,7 +95,6 @@ dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported
 compile = True # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
-exec(open('configurator.py').read()) # overrides from command line or config file
 config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # -----------------------------------------------------------------------------
@@ -103,7 +123,6 @@ print(f"tokens per iteration will be: {tokens_per_iter:,}")
 
 if master_process:
     os.makedirs(out_dir, exist_ok=True)
-torch.manual_seed(1337 + seed_offset)
 torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
 torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
 device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
@@ -112,15 +131,16 @@ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torc
 ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
 
 # poor man's data loader
-data_dir = os.path.join('data', dataset)
-def get_batch(split):
+data_dir = "."
+def get_batch(split, step):
     # We recreate np.memmap every batch to avoid a memory leak, as per
     # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
     if split == 'train':
         data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
     else:
         data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
-    ix = torch.randint(len(data) - block_size, (batch_size,))
+    d_rng = random.Random(f"{split}-{step}-{seed}")
+    ix = [ d_rng.randint(0, len(data) - block_size - 1) for _ in range(batch_size) ] # randint is inclusive at both ends, so -1 keeps y in bounds
     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
     if device_type == 'cuda':
@@ -158,7 +178,7 @@ if init_from == 'scratch':
 elif init_from == 'resume':
     print(f"Resuming training from {out_dir}")
     # resume training from a checkpoint.
-    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
+    ckpt_path = os.path.join(out_dir, 'ckpt1500.pt')
     checkpoint = torch.load(ckpt_path, map_location=device)
     checkpoint_model_args = checkpoint['model_args']
     # force these config attributes to be equal otherwise we can't even resume training
@@ -213,13 +233,13 @@ if ddp:
 
 # helps estimate an arbitrarily accurate loss over either split using many batches
 @torch.no_grad()
-def estimate_loss():
+def estimate_loss(step):
     out = {}
     model.eval()
     for split in ['train', 'val']:
         losses = torch.zeros(eval_iters)
         for k in range(eval_iters):
-            X, Y = get_batch(split)
+            X, Y = get_batch(split, f"eval-{step}-{k}") # vary the batch with k; with a constant step every eval batch would be identical
             with ctx:
                 logits, loss = model(X, Y)
             losses[k] = loss.item()
@@ -247,9 +267,9 @@ if wandb_log and master_process:
     wandb.init(project=wandb_project, name=wandb_run_name, config=config)
 
 # training loop
-X, Y = get_batch('train') # fetch the very first batch
-t0 = time.time()
+X, Y = get_batch('train', f"{iter_num}-{0}") # fetch the very first batch
 local_iter_num = 0 # number of iterations in the lifetime of this process
+t0 = time.time()
 raw_model = model.module if ddp else model # unwrap DDP container if needed
 running_mfu = -1.0
 while True:
@@ -261,7 +281,7 @@ while True:
 
     # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
-        losses = estimate_loss()
+        losses = estimate_loss(iter_num)
         print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
         if wandb_log:
             wandb.log({
@@ -273,17 +293,16 @@
             })
         if losses['val'] < best_val_loss or always_save_checkpoint:
             best_val_loss = losses['val']
-            if iter_num > 0:
-                checkpoint = {
-                    'model': raw_model.state_dict(),
-                    'optimizer': optimizer.state_dict(),
-                    'model_args': model_args,
-                    'iter_num': iter_num,
-                    'best_val_loss': best_val_loss,
-                    'config': config,
-                }
-                print(f"saving checkpoint to {out_dir}")
-                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
+            checkpoint = {
+                'model': raw_model.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'model_args': model_args,
+                'iter_num': iter_num,
+                'best_val_loss': best_val_loss,
+                'config': config,
+            }
+            print(f"saving checkpoint to {out_dir}")
+            torch.save(checkpoint, os.path.join(out_dir, f'ckpt{iter_num}.pt'))
     if iter_num == 0 and eval_only:
         break
@@ -300,7 +319,7 @@ while True:
             logits, loss = model(X, Y)
             loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
         # immediately async prefetch next batch while model is doing the forward pass on the GPU
-        X, Y = get_batch('train')
get_batch('train', f"{iter_num}-{micro_step + 1}") # backward pass, with gradient scaling if training in fp16 scaler.scale(loss).backward() # clip the gradient diff --git a/x.png b/x.png new file mode 100644 index 0000000..4ddffe5 Binary files /dev/null and b/x.png differ