1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-09-21 11:49:46 +00:00

add gradient accumulation support to simulate larger batch sizes. ty @VHellendoorn for original PR

This commit is contained in:
Andrej Karpathy 2023-01-15 17:49:55 +00:00
parent 89da79eee1
commit cf99914886

View File

@ -38,7 +38,8 @@ wandb_project = 'owt'
wandb_run_name = 'gpt2' # 'run' + str(time.time())
# data
dataset = 'openwebtext'
batch_size = 12
gradient_accumulation_steps = 1 # used to simulate larger batch sizes
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
block_size = 1024
# model
n_layer = 12
@ -217,6 +218,7 @@ while True:
else:
lr = learning_rate
# evaluate the loss on train/val sets and write checkpoints
if iter_num % eval_interval == 0 and gpu_id == 0:
losses = estimate_loss()
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
@ -244,20 +246,27 @@ while True:
if iter_num == 0 and eval_only:
break
X, Y = get_batch('train')
with ctx:
logits, loss = model(X, Y)
# forward backward update, with optional gradient accumulation to simulate larger batch size
optimizer.zero_grad(set_to_none=True)
loss.backward()
# TODO: gradient clipping evaluate need for
for micro_step in range(gradient_accumulation_steps):
X, Y = get_batch('train')
if ddp:
# in DDP training we only need to sync gradients at the last micro step.
# the official way to do this is with model.no_sync() context manager, but
# I really dislike that this bloats the code and forces us to repeat code
# looking at the source of that context manager, it just toggles this variable
model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
with ctx:
logits, loss = model(X, Y)
loss.backward()
optimizer.step()
# timing and logging
t1 = time.time()
dt = t1 - t0
t0 = t1
if iter_num % log_interval == 0 and gpu_id == 0:
lossf = loss.item() # loss as float. TODO CPU-GPU sync: profile, make sure not slow af
lossf = loss.item() # loss as float. TODO note CPU-GPU sync! profile, make sure not too slow
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
iter_num += 1