From e170e40872cce7ce9426eda80cd8653d807cb48a Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Fri, 3 Feb 2023 17:56:51 +0000
Subject: [PATCH] use the new fused AdamW from pytorch nightly, if available

---
 model.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/model.py b/model.py
index a10c850..f934ef1 100644
--- a/model.py
+++ b/model.py
@@ -8,6 +8,7 @@ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gp
 """
 
 import math
+import inspect
 from dataclasses import dataclass
 
 import torch
@@ -307,7 +308,10 @@ class GPT(nn.Module):
             {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
-        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+        # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
+        extra_args = dict(fused=True) if 'fused' in inspect.signature(torch.optim.AdamW).parameters else dict()
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+
        return optimizer

    @torch.no_grad()
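
Note: below is a minimal standalone sketch of the feature-detection pattern this patch uses, not part of the patch itself. It probes the installed torch.optim.AdamW signature for the 'fused' keyword so the same code runs on both stable PyTorch and nightly builds. The placeholder parameter and the lr/betas values are illustrative, and the CUDA gate is an added assumption since the fused kernel targets CUDA tensors.

    import inspect
    import torch

    # placeholder parameter; a real model's parameters would go here
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    params = [torch.nn.Parameter(torch.randn(10, 10, device=device))]

    # only pass fused=True when the installed AdamW accepts it and we are on CUDA
    use_fused = device == 'cuda' and 'fused' in inspect.signature(torch.optim.AdamW).parameters
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(params, lr=6e-4, betas=(0.9, 0.95), **extra_args)
    print(f"using fused AdamW: {use_fused}")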