mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 06:00:29 +00:00
try bring back mingpt init
This commit is contained in:
parent
3cb3fc059c
commit
23a0bfac20
18
model.py
18
model.py
@ -121,10 +121,28 @@ class GPT(nn.Module):
|
||||
# not 100% sure what this is, so far seems to be harmless. TODO investigate
|
||||
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
|
||||
|
||||
# init all weights
|
||||
self.apply(self._init_weights)
|
||||
# apply special scaled init to the residual projections, per GPT-2 paper
|
||||
for pn, p in self.named_parameters():
|
||||
if pn.endswith('c_proj.weight'):
|
||||
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
|
||||
|
||||
# report number of parameters
|
||||
n_params = sum(p.numel() for p in self.parameters())
|
||||
print("number of parameters: %.2fM" % (n_params/1e6,))
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
torch.nn.init.ones_(module.weight)
|
||||
|
||||
def forward(self, idx, targets=None):
|
||||
device = idx.device
|
||||
b, t = idx.size()
|
||||
|
Loading…
Reference in New Issue
Block a user