mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00

use the new fused AdamW from pytorch nightly, if available

Andrej Karpathy 2023-02-03 17:56:51 +00:00
parent 7d44bdf6b5
commit e170e40872


@@ -8,6 +8,7 @@ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gp
 """
 
 import math
+import inspect
 from dataclasses import dataclass
 
 import torch
@@ -307,7 +308,10 @@ class GPT(nn.Module):
             {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
-        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+        # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
+        extra_args = dict(fused=True) if 'fused' in inspect.signature(torch.optim.AdamW).parameters else dict()
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
         return optimizer
 
     @torch.no_grad()
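
The pattern the diff relies on is a simple feature probe: inspect.signature(torch.optim.AdamW).parameters reveals whether the installed PyTorch build accepts a fused keyword, so the same code runs on both stable and nightly releases. Below is a minimal, self-contained sketch of that probe outside the repo. The toy model, device handling, and hyperparameter values are illustrative assumptions, not taken from the commit; the extra CUDA check reflects that the fused kernel expects CUDA floating-point parameters.

import inspect
import torch

# Probe the installed torch.optim.AdamW for the 'fused' keyword argument.
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters

# Toy model and hyperparameters for illustration only (not from the commit).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(16, 16).to(device)

# The fused implementation expects CUDA floating-point parameters,
# so only enable it when the kwarg exists and the model is on the GPU.
use_fused = fused_available and device == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95),
                              weight_decay=1e-1, **extra_args)
print(f"fused AdamW available: {fused_available}, enabled: {use_fused}")

On older PyTorch versions the probe simply yields an empty extra_args dict, so the optimizer is constructed exactly as before; no version pinning or try/except around the import is needed.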