mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-11-10 20:09:58 +00:00
try bring back mingpt init
This commit is contained in:
parent
3cb3fc059c
commit
23a0bfac20
18
model.py
18
model.py
@ -121,10 +121,28 @@ class GPT(nn.Module):
|
|||||||
# not 100% sure what this is, so far seems to be harmless. TODO investigate
|
# not 100% sure what this is, so far seems to be harmless. TODO investigate
|
||||||
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
|
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
|
||||||
|
|
||||||
|
# init all weights
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
# apply special scaled init to the residual projections, per GPT-2 paper
|
||||||
|
for pn, p in self.named_parameters():
|
||||||
|
if pn.endswith('c_proj.weight'):
|
||||||
|
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
|
||||||
|
|
||||||
# report number of parameters
|
# report number of parameters
|
||||||
n_params = sum(p.numel() for p in self.parameters())
|
n_params = sum(p.numel() for p in self.parameters())
|
||||||
print("number of parameters: %.2fM" % (n_params/1e6,))
|
print("number of parameters: %.2fM" % (n_params/1e6,))
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||||
|
if module.bias is not None:
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
|
elif isinstance(module, nn.Embedding):
|
||||||
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
|
torch.nn.init.ones_(module.weight)
|
||||||
|
|
||||||
def forward(self, idx, targets=None):
|
def forward(self, idx, targets=None):
|
||||||
device = idx.device
|
device = idx.device
|
||||||
b, t = idx.size()
|
b, t = idx.size()
|
||||||
|
Loading…
Reference in New Issue
Block a user