# Patch summary (model.py, class GPT): add explicit weight initialization.
#
# What the single hunk below does:
#   * Right after the existing weight-tying assignment it calls
#     self.apply(self._init_weights), then re-draws every parameter whose name
#     ends in 'c_proj.weight' from normal(mean=0.0, std=0.02/sqrt(2 * config.n_layer)).
#     The in-code comment attributes this scaled init to the GPT-2 paper.
#   * It adds the method _init_weights(self, module):
#       - nn.Linear:    weight ~ normal(0.0, 0.02); bias zeroed when one exists
#       - nn.Embedding: weight ~ normal(0.0, 0.02)
#       - nn.LayerNorm: bias zeroed, weight set to ones
#
# NOTE(review): the nn.LayerNorm branch touches module.bias / module.weight without
# the `is not None` guard that the nn.Linear branch has. Presumably every LayerNorm
# in this model keeps its affine parameters -- confirm against the rest of model.py.
# NOTE(review): the added lines use `math` and `torch`; their imports are outside
# this hunk -- confirm model.py imports both.
# NOTE(review): the edited statements presumably sit in GPT.__init__ (the enclosing
# `def` is not part of the hunk) -- confirm before relying on `config` being in scope.
#
diff --git a/model.py b/model.py
index 044c668..ec18243 100644
--- a/model.py
+++ b/model.py
@@ -121,10 +121,28 @@ class GPT(nn.Module):
         # not 100% sure what this is, so far seems to be harmless. TODO investigate
         self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
 
+        # init all weights
+        self.apply(self._init_weights)
+        # apply special scaled init to the residual projections, per GPT-2 paper
+        for pn, p in self.named_parameters():
+            if pn.endswith('c_proj.weight'):
+                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
         # report number of parameters
         n_params = sum(p.numel() for p in self.parameters())
         print("number of parameters: %.2fM" % (n_params/1e6,))
 
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        elif isinstance(module, nn.LayerNorm):
+            torch.nn.init.zeros_(module.bias)
+            torch.nn.init.ones_(module.weight)
+
     def forward(self, idx, targets=None):
         device = idx.device
         b, t = idx.size()