From 67166079c9df2fbdc1528ad159b7f54b4a11fc37 Mon Sep 17 00:00:00 2001
From: Clive Chan
Date: Thu, 19 Jan 2023 22:10:44 -0800
Subject: [PATCH] Zero-grad more aggressively to save memory

---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index f1b0bc1..f6e5fd0 100644
--- a/train.py
+++ b/train.py
@@ -259,7 +259,6 @@ while True:
         break
 
     # forward backward update, with optional gradient accumulation to simulate larger batch size
-    optimizer.zero_grad(set_to_none=True)
     for micro_step in range(gradient_accumulation_steps):
         X, Y = get_batch('train')
         if ddp:
@@ -272,6 +271,7 @@ while True:
            logits, loss = model(X, Y)
         loss.backward()
     optimizer.step()
+    optimizer.zero_grad(set_to_none=True)
 
     # timing and logging
     t1 = time.time()
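For context, here is a minimal sketch (not part of the patch) of the loop ordering this change produces: `optimizer.zero_grad(set_to_none=True)` is called immediately after `optimizer.step()`, so the `.grad` tensors are released before the next forward pass instead of being held across it. The model, data, and hyperparameters below are placeholders, not the values used in train.py.

```python
import torch
import torch.nn as nn

# stand-in model and optimizer, for illustration only
model = nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
gradient_accumulation_steps = 4

for step in range(10):
    # gradients were already zeroed at the end of the previous iteration,
    # so no zero_grad is needed before the accumulation loop
    for micro_step in range(gradient_accumulation_steps):
        X = torch.randn(8, 16)   # placeholder batch
        Y = torch.randn(8, 1)
        loss = nn.functional.mse_loss(model(X), Y)
        loss.backward()          # gradients accumulate across micro-steps
    optimizer.step()
    # zeroing with set_to_none=True drops the .grad tensors right after the
    # update, so their memory is not occupied during the next forward pass
    optimizer.zero_grad(set_to_none=True)
```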