diff --git a/model.py b/model.py
index f18996c..044c668 100644
--- a/model.py
+++ b/model.py
@@ -115,6 +115,10 @@ class GPT(nn.Module):
             ln_f = nn.LayerNorm(config.n_embd),
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        # with weight tying when using torch.compile() some warnings get generated:
+        # "UserWarning: functional_call was passed multiple values for tied weights.
+        # This behavior is deprecated and will be an error in future versions"
+        # not 100% sure what this is, so far seems to be harmless. TODO investigate
         self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

         # report number of parameters
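For context, here is a minimal sketch of the weight-tying pattern this diff applies. `TinyLM` and its dimensions are hypothetical stand-ins, not part of model.py; the one essential line is the parameter assignment, which makes the embedding and output projection share a single tensor.

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    # hypothetical toy module illustrating weight tying, not the real GPT class
    def __init__(self, vocab_size=16, n_embd=8):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)            # weight: (vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # weight: (vocab_size, n_embd)
        # tie the weights: both modules now reference one shared parameter,
        # the same assignment the diff above adds to GPT.__init__
        self.wte.weight = self.lm_head.weight

    def forward(self, idx):
        return self.lm_head(self.wte(idx))

model = TinyLM()
assert model.wte.weight is model.lm_head.weight  # a single shared tensor

# Compiling a module that contains tied parameters is the situation the
# added comment describes; on PyTorch >= 2.0 this may emit the quoted
# "functional_call was passed multiple values for tied weights" UserWarning.
# compiled = torch.compile(model)
```

The tying works because `nn.Embedding(vocab_size, n_embd)` and `nn.Linear(n_embd, vocab_size)` both store their weight as a `(vocab_size, n_embd)` matrix, so one parameter can serve as both the input embedding table and the output projection.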