Mirror of https://github.com/osmarks/nanogpt-experiments.git (synced 2024-12-18 14:10:28 +00:00)
Merge pull request #106 from YassineYousfi/master
use the ``enabled`` arg in GradScaler
Commit: 7d44bdf6b5
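
For context (not part of the commit): PyTorch documents that a GradScaler constructed with enabled=False simply passes values through, so scale() returns the loss unchanged, unscale_() and update() do nothing, and step() just calls optimizer.step(). A minimal toy sketch of that behavior, runnable on CPU:

import torch

# Toy demonstration (not from train.py): a disabled GradScaler is a no-op,
# so the exact same calls work whether or not we train in float16.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # as if dtype != 'float16'

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # scale() returns the loss unchanged when disabled
scaler.unscale_(optimizer)     # no-op when disabled
scaler.step(optimizer)         # just calls optimizer.step() when disabled
scaler.update()                # no-op when disabled
optimizer.zero_grad(set_to_none=True)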
train.py (16 changed lines)
@@ -173,11 +173,8 @@ if block_size < model.config.block_size:
     model.crop_block_size(block_size)
 model.to(device)
 
-# initialize a GradScaler if data type is float16
-scaler = None
-if dtype == 'float16':
-    print(f"Initializing Gradient Scaler to account for dtype: {dtype}")
-    scaler = torch.cuda.amp.GradScaler()
+# initialize a GradScaler. If enabled=False scaler is a no-op
+scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
 
 # optimizer
 optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2))
@@ -283,17 +280,14 @@ while True:
     with ctx:
         logits, loss = model(X, Y)
     # backward pass, with gradient scaling if training in fp16
-    scaler.scale(loss).backward() if scaler else loss.backward()
+    scaler.scale(loss).backward()
     # clip the gradient
     if grad_clip != 0.0:
-        scaler.unscale_(optimizer) if scaler else None
+        scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
-    # step the optimizer
-    if scaler:
-        scaler.step(optimizer)
-        scaler.update()
-    else:
-        optimizer.step()
+    # step the optimizer and scaler if training in fp16
+    scaler.step(optimizer)
+    scaler.update()
     # flush the gradients as soon as we can, no need for this memory anymore
     optimizer.zero_grad(set_to_none=True)
 
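
Taken together, the two hunks leave a single code path: the same scale/clip/step/update sequence runs whether or not dtype is 'float16', because the scaler silently disables itself otherwise. Below is a minimal self-contained sketch of that resulting shape, using a toy linear model and stand-in config values (device, dtype, grad_clip), not the repository's actual train.py:

import torch
from contextlib import nullcontext

# Hypothetical stand-ins for train.py's config values; the real script reads
# these from its configuration, and the model here is a toy module, not GPT.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'float16' if device == 'cuda' else 'float32'
grad_clip = 1.0

ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type='cuda', dtype=ptdtype)

model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# a single GradScaler; it is a no-op unless we actually train in float16
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

X = torch.randn(8, 16, device=device)
Y = torch.randn(8, 1, device=device)

with ctx:
    loss = torch.nn.functional.mse_loss(model(X), Y)
# backward pass, with gradient scaling only if training in fp16
scaler.scale(loss).backward()
# clip the gradient (unscale first so the clip threshold applies to real grads)
if grad_clip != 0.0:
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
# step the optimizer (and the scaler, when enabled)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)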