Mirror of https://github.com/osmarks/nanogpt-experiments.git (synced 2024-11-10 20:09:58 +00:00)
use the new fused AdamW from pytorch nightly, if available
This commit is contained in:
parent 7d44bdf6b5
commit e170e40872
model.py (6 changed lines)

@@ -8,6 +8,7 @@ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gp
 """
 
 import math
+import inspect
 from dataclasses import dataclass
 
 import torch
@@ -307,7 +308,10 @@ class GPT(nn.Module):
             {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
-        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+        # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
+        extra_args = dict(fused=True) if 'fused' in inspect.signature(torch.optim.AdamW).parameters else dict()
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+
         return optimizer
 
     @torch.no_grad()
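For context, a minimal self-contained sketch of the feature-detection pattern this commit relies on: check whether the installed torch.optim.AdamW accepts a 'fused' keyword (only present in newer PyTorch builds) and pass fused=True only when it does. The toy model, device handling, and the learning rate / betas values below are illustrative placeholders, not part of the commit, and the extra CUDA check is a conservative guard, since the fused kernel on the PyTorch builds this commit targets only accepts CUDA tensors.

import inspect
import torch

# Hypothetical toy model, just so there are parameters to optimize.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(16, 16).to(device)

# Feature-detect the 'fused' kwarg: it only exists in sufficiently new PyTorch
# builds, and the fused kernel expects CUDA tensors, so gate on both.
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()

optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, betas=(0.9, 0.95), **extra_args)
print(f"using fused AdamW: {use_fused}")

# One illustrative optimization step.
loss = model(torch.randn(4, 16, device=device)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

On older PyTorch versions the signature check simply comes back False and the ordinary AdamW implementation is used, which is why the change is safe to ship unconditionally.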