1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00

try bring back mingpt init

This commit is contained in:
Andrej Karpathy 2023-01-27 16:52:18 +00:00
parent 3cb3fc059c
commit 23a0bfac20

View File

@ -121,10 +121,28 @@ class GPT(nn.Module):
# not 100% sure what this is, so far seems to be harmless. TODO investigate # not 100% sure what this is, so far seems to be harmless. TODO investigate
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
# init all weights
self.apply(self._init_weights)
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith('c_proj.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
# report number of parameters # report number of parameters
n_params = sum(p.numel() for p in self.parameters()) n_params = sum(p.numel() for p in self.parameters())
print("number of parameters: %.2fM" % (n_params/1e6,)) print("number of parameters: %.2fM" % (n_params/1e6,))
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
def forward(self, idx, targets=None): def forward(self, idx, targets=None):
device = idx.device device = idx.device
b, t = idx.size() b, t = idx.size()