From 40f4d6ff70aa59022b365906527752a0e56fb890 Mon Sep 17 00:00:00 2001
From: Yassine Yousfi
Date: Tue, 31 Jan 2023 21:12:49 -0800
Subject: [PATCH] use the enabled arg in GradScaler

---
 train.py | 20 +++++++-------------
 1 file changed, 7 insertions(+), 13 deletions(-)

diff --git a/train.py b/train.py
index 77e713b..f8848e9 100644
--- a/train.py
+++ b/train.py
@@ -173,11 +173,8 @@ if block_size < model.config.block_size:
     model.crop_block_size(block_size)
 model.to(device)
 
-# initialize a GradScaler if data type is float16
-scaler = None
-if dtype == 'float16':
-    print(f"Initializing Gradient Scaler to account for dtype: {dtype}")
-    scaler = torch.cuda.amp.GradScaler()
+# initialize a GradScaler. If enabled=False scaler is a no-op
+scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
 
 # optimizer
 optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2))
@@ -283,17 +280,14 @@ while True:
     with ctx:
         logits, loss = model(X, Y)
     # backward pass, with gradient scaling if training in fp16
-    scaler.scale(loss).backward() if scaler else loss.backward()
+    scaler.scale(loss).backward()
     # clip the gradient
     if grad_clip != 0.0:
-        scaler.unscale_(optimizer) if scaler else None
+        scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
-    # step the optimizer
-    if scaler:
-        scaler.step(optimizer)
-        scaler.update()
-    else:
-        optimizer.step()
+    # step the optimizer and scaler if training in fp16
+    scaler.step(optimizer)
+    scaler.update()
     # flush the gradients as soon as we can, no need for this memory anymore
     optimizer.zero_grad(set_to_none=True)
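
As a hedged illustration (not part of the patch): the branching can be dropped because a GradScaler constructed with enabled=False makes scale(), unscale_(), step(), and update() behave as no-ops, with step() simply calling optimizer.step(). Below is a minimal self-contained sketch of the resulting training-step pattern; the toy model, optimizer, random data, and the device/dtype setup are assumptions for illustration only, not code from train.py:

    from contextlib import nullcontext
    import torch

    # assumed setup for the sketch; train.py configures these from its own args
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = 'float16' if device == 'cuda' else 'float32'
    grad_clip = 1.0

    # toy stand-ins for the real model/optimizer in train.py
    model = torch.nn.Linear(16, 1).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # autocast context for fp16, plain nullcontext otherwise
    ctx = nullcontext() if dtype == 'float32' else torch.autocast(device_type=device, dtype=torch.float16)
    # enabled=False makes every scaler call below a no-op
    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

    X = torch.randn(8, 16, device=device)
    Y = torch.randn(8, 1, device=device)

    with ctx:
        loss = torch.nn.functional.mse_loss(model(X), Y)
    scaler.scale(loss).backward()      # scale() returns the loss unchanged when disabled
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)     # no-op when disabled, so it is safe to call unconditionally
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)             # falls back to optimizer.step() when disabled
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

The same four scaler calls therefore cover fp32, bf16, and fp16 runs, which is what allows the diff above to delete the if/else paths.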