1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-18 14:10:28 +00:00

Merge pull request #106 from YassineYousfi/master

use the ``enabled`` arg in GradScaler
This commit is contained in:
Andrej 2023-02-02 17:23:22 -08:00 committed by GitHub
commit 7d44bdf6b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -173,11 +173,8 @@ if block_size < model.config.block_size:
model.crop_block_size(block_size) model.crop_block_size(block_size)
model.to(device) model.to(device)
# initialize a GradScaler if data type is float16 # initialize a GradScaler. If enabled=False scaler is a no-op
scaler = None scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
if dtype == 'float16':
print(f"Initializing Gradient Scaler to account for dtype: {dtype}")
scaler = torch.cuda.amp.GradScaler()
# optimizer # optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2)) optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2))
@ -283,17 +280,14 @@ while True:
with ctx: with ctx:
logits, loss = model(X, Y) logits, loss = model(X, Y)
# backward pass, with gradient scaling if training in fp16 # backward pass, with gradient scaling if training in fp16
scaler.scale(loss).backward() if scaler else loss.backward() scaler.scale(loss).backward()
# clip the gradient # clip the gradient
if grad_clip != 0.0: if grad_clip != 0.0:
scaler.unscale_(optimizer) if scaler else None scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
# step the optimizer # step the optimizer and scaler if training in fp16
if scaler:
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
else:
optimizer.step()
# flush the gradients as soon as we can, no need for this memory anymore # flush the gradients as soon as we can, no need for this memory anymore
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)