From cf9991488629b1b072c49bf261d04b0c8a3207a3 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 15 Jan 2023 17:49:55 +0000 Subject: [PATCH] add gradient accumulation support to simulate larger batch sizes. ty @VHellendoorn for original PR --- train.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/train.py b/train.py index 90482d0..58619ed 100644 --- a/train.py +++ b/train.py @@ -38,7 +38,8 @@ wandb_project = 'owt' wandb_run_name = 'gpt2' # 'run' + str(time.time()) # data dataset = 'openwebtext' -batch_size = 12 +gradient_accumulation_steps = 1 # used to simulate larger batch sizes +batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size block_size = 1024 # model n_layer = 12 @@ -217,6 +218,7 @@ while True: else: lr = learning_rate + # evaluate the loss on train/val sets and write checkpoints if iter_num % eval_interval == 0 and gpu_id == 0: losses = estimate_loss() print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") @@ -244,20 +246,27 @@ while True: if iter_num == 0 and eval_only: break - X, Y = get_batch('train') - with ctx: - logits, loss = model(X, Y) - + # forward backward update, with optional gradient accumulation to simulate larger batch size optimizer.zero_grad(set_to_none=True) - loss.backward() - # TODO: gradient clipping evaluate need for + for micro_step in range(gradient_accumulation_steps): + X, Y = get_batch('train') + if ddp: + # in DDP training we only need to sync gradients at the last micro step. + # the official way to do this is with model.no_sync() context manager, but + # I really dislike that this bloats the code and forces us to repeat code + # looking at the source of that context manager, it just toggles this variable + model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) + with ctx: + logits, loss = model(X, Y) + loss.backward() optimizer.step() + # timing and logging t1 = time.time() dt = t1 - t0 t0 = t1 if iter_num % log_interval == 0 and gpu_id == 0: - lossf = loss.item() # loss as float. TODO CPU-GPU sync: profile, make sure not slow af + lossf = loss.item() # loss as float. TODO note CPU-GPU sync! profile, make sure not too slow print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms") iter_num += 1