mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 06:00:29 +00:00
add gradient accumulation support to simulate larger batch sizes. ty @VHellendoorn for original PR
This commit is contained in:
parent
89da79eee1
commit
cf99914886
25
train.py
25
train.py
@ -38,7 +38,8 @@ wandb_project = 'owt'
|
||||
wandb_run_name = 'gpt2' # 'run' + str(time.time())
|
||||
# data
|
||||
dataset = 'openwebtext'
|
||||
batch_size = 12
|
||||
gradient_accumulation_steps = 1 # used to simulate larger batch sizes
|
||||
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
||||
block_size = 1024
|
||||
# model
|
||||
n_layer = 12
|
||||
@ -217,6 +218,7 @@ while True:
|
||||
else:
|
||||
lr = learning_rate
|
||||
|
||||
# evaluate the loss on train/val sets and write checkpoints
|
||||
if iter_num % eval_interval == 0 and gpu_id == 0:
|
||||
losses = estimate_loss()
|
||||
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
|
||||
@ -244,20 +246,27 @@ while True:
|
||||
if iter_num == 0 and eval_only:
|
||||
break
|
||||
|
||||
X, Y = get_batch('train')
|
||||
with ctx:
|
||||
logits, loss = model(X, Y)
|
||||
|
||||
# forward backward update, with optional gradient accumulation to simulate larger batch size
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
loss.backward()
|
||||
# TODO: gradient clipping evaluate need for
|
||||
for micro_step in range(gradient_accumulation_steps):
|
||||
X, Y = get_batch('train')
|
||||
if ddp:
|
||||
# in DDP training we only need to sync gradients at the last micro step.
|
||||
# the official way to do this is with model.no_sync() context manager, but
|
||||
# I really dislike that this bloats the code and forces us to repeat code
|
||||
# looking at the source of that context manager, it just toggles this variable
|
||||
model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
|
||||
with ctx:
|
||||
logits, loss = model(X, Y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# timing and logging
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
t0 = t1
|
||||
if iter_num % log_interval == 0 and gpu_id == 0:
|
||||
lossf = loss.item() # loss as float. TODO CPU-GPU sync: profile, make sure not slow af
|
||||
lossf = loss.item() # loss as float. TODO note CPU-GPU sync! profile, make sure not too slow
|
||||
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
|
||||
iter_num += 1
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user