diff --git a/model.py b/model.py index 6c80679..1e5e1fd 100644 --- a/model.py +++ b/model.py @@ -268,60 +268,28 @@ class GPT(nn.Module): return model def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): - """ - This long function is unfortunately doing something very simple and is being very defensive: - We are separating out all parameters of the model into two buckets: those that will experience - weight decay for regularization and those that won't (biases, and layernorm/embedding weights). - We are then returning the PyTorch optimizer object. - """ - - # separate out all parameters to those that will and won't experience regularizing weight decay - decay = set() - no_decay = set() - whitelist_weight_modules = (torch.nn.Linear, ) - blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding) - for mn, m in self.named_modules(): - for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name - # random note: because named_modules and named_parameters are recursive - # we will see the same tensors p many many times. but doing it this way - # allows us to know which parent module any tensor p belongs to... - if pn.endswith('bias'): - # all biases will not be decayed - no_decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): - # weights of whitelist modules will be weight decayed - decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): - # weights of blacklist modules will NOT be weight decayed - no_decay.add(fpn) - - # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they - # will appear in the no_decay and decay sets respectively after the above. - # In addition, because named_parameters() doesn't return duplicates, it - # will only return the first occurence, key'd by 'transformer.wte.weight', below. - # so let's manually remove 'lm_head.weight' from decay set. This will include - # this tensor into optimization via transformer.wte.weight only, and not decayed. - decay.remove('lm_head.weight') - - # validate that we considered every parameter + # start with all of the candidate parameters param_dict = {pn: p for pn, p in self.named_parameters()} - inter_params = decay & no_decay - union_params = decay | no_decay - assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) - - # create the pytorch optimizer object + # filter out those that do not require grad + param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. + # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. + decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + {'params': decay_params, 'weight_decay': weight_decay}, + {'params': nodecay_params, 'weight_decay': 0.0} ] - # new PyTorch nightly has a new 'fused' option for AdamW that is much faster - use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters) - print(f"using fused AdamW: {use_fused}") + num_decay_params = sum(p.numel() for p in decay_params) + num_nodecay_params = sum(p.numel() for p in nodecay_params) + print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") + print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + # Create AdamW optimizer and use the fused version if it is available + fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and device_type == 'cuda' extra_args = dict(fused=True) if use_fused else dict() optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + print(f"using fused AdamW: {use_fused}") return optimizer