mirror of https://github.com/osmarks/nanogpt-experiments.git (synced 2024-12-18 14:10:28 +00:00)
use the new fused AdamW from pytorch nightly, if available
commit e170e40872 (parent 7d44bdf6b5)
model.py · 6 changed lines
@@ -8,6 +8,7 @@ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gp
 """
 
 import math
+import inspect
 from dataclasses import dataclass
 
 import torch
@@ -307,7 +308,10 @@ class GPT(nn.Module):
             {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
-        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+        # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
+        extra_args = dict(fused=True) if 'fused' in inspect.signature(torch.optim.AdamW).parameters else dict()
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
 
         return optimizer
 
     @torch.no_grad()
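The pattern the commit adds can be sketched in isolation as below. This is a minimal standalone example, not the repo's code: `model`, `device`, `learning_rate`, `betas`, and `use_fused` are placeholders chosen for illustration, and the CUDA check is an extra safeguard not present in the commit (the fused kernel targets GPU tensors). `inspect.signature` is used to pass `fused=True` only when the installed PyTorch build actually exposes that keyword, so the same code keeps working on older releases.

import inspect
import torch

# placeholder model and hyperparameters for illustration only
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(10, 10).to(device)
learning_rate = 6e-4
betas = (0.9, 0.95)

# request the fused kernel only if this PyTorch build has the 'fused' option
# and the parameters actually live on the GPU
use_fused = device == 'cuda' and 'fused' in inspect.signature(torch.optim.AdamW).parameters
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=betas, **extra_args)
print(f"using fused AdamW: {use_fused}")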