mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
simplify configure_optimizers by a lot
This commit is contained in:
parent
196160b849
commit
7fe4a099ad
66
model.py
66
model.py
@ -268,60 +268,28 @@ class GPT(nn.Module):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
|
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
|
||||||
"""
|
# start with all of the candidate parameters
|
||||||
This long function is unfortunately doing something very simple and is being very defensive:
|
|
||||||
We are separating out all parameters of the model into two buckets: those that will experience
|
|
||||||
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
|
|
||||||
We are then returning the PyTorch optimizer object.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# separate out all parameters to those that will and won't experience regularizing weight decay
|
|
||||||
decay = set()
|
|
||||||
no_decay = set()
|
|
||||||
whitelist_weight_modules = (torch.nn.Linear, )
|
|
||||||
blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
|
|
||||||
for mn, m in self.named_modules():
|
|
||||||
for pn, p in m.named_parameters():
|
|
||||||
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
|
|
||||||
# random note: because named_modules and named_parameters are recursive
|
|
||||||
# we will see the same tensors p many many times. but doing it this way
|
|
||||||
# allows us to know which parent module any tensor p belongs to...
|
|
||||||
if pn.endswith('bias'):
|
|
||||||
# all biases will not be decayed
|
|
||||||
no_decay.add(fpn)
|
|
||||||
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
|
||||||
# weights of whitelist modules will be weight decayed
|
|
||||||
decay.add(fpn)
|
|
||||||
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
|
||||||
# weights of blacklist modules will NOT be weight decayed
|
|
||||||
no_decay.add(fpn)
|
|
||||||
|
|
||||||
# subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
|
|
||||||
# will appear in the no_decay and decay sets respectively after the above.
|
|
||||||
# In addition, because named_parameters() doesn't return duplicates, it
|
|
||||||
# will only return the first occurence, key'd by 'transformer.wte.weight', below.
|
|
||||||
# so let's manually remove 'lm_head.weight' from decay set. This will include
|
|
||||||
# this tensor into optimization via transformer.wte.weight only, and not decayed.
|
|
||||||
decay.remove('lm_head.weight')
|
|
||||||
|
|
||||||
# validate that we considered every parameter
|
|
||||||
param_dict = {pn: p for pn, p in self.named_parameters()}
|
param_dict = {pn: p for pn, p in self.named_parameters()}
|
||||||
inter_params = decay & no_decay
|
# filter out those that do not require grad
|
||||||
union_params = decay | no_decay
|
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
|
||||||
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
|
# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
|
||||||
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
|
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
|
||||||
% (str(param_dict.keys() - union_params), )
|
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
|
||||||
|
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
|
||||||
# create the pytorch optimizer object
|
|
||||||
optim_groups = [
|
optim_groups = [
|
||||||
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
|
{'params': decay_params, 'weight_decay': weight_decay},
|
||||||
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
|
{'params': nodecay_params, 'weight_decay': 0.0}
|
||||||
]
|
]
|
||||||
# new PyTorch nightly has a new 'fused' option for AdamW that is much faster
|
num_decay_params = sum(p.numel() for p in decay_params)
|
||||||
use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
|
num_nodecay_params = sum(p.numel() for p in nodecay_params)
|
||||||
print(f"using fused AdamW: {use_fused}")
|
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
|
||||||
|
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
|
||||||
|
# Create AdamW optimizer and use the fused version if it is available
|
||||||
|
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
|
||||||
|
use_fused = fused_available and device_type == 'cuda'
|
||||||
extra_args = dict(fused=True) if use_fused else dict()
|
extra_args = dict(fused=True) if use_fused else dict()
|
||||||
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
|
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
|
||||||
|
print(f"using fused AdamW: {use_fused}")
|
||||||
|
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user