1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00

make more accurate the way in which we count parameters. previous count incorrectly included the positional encoding params, when typically only the number of weight parameters is reported for these models

This commit is contained in:
Andrej Karpathy 2023-02-04 23:51:18 +00:00
parent 3341b4cecc
commit 34720df284

View File

@ -152,8 +152,19 @@ class GPT(nn.Module):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
# report number of parameters
print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
def get_num_params(self, non_embedding=True):
"""
Return the number of parameters in the model.
For non-embedding count (default), the position embeddings get subtracted.
The token embeddings would too, except due to the parameter sharing these
params are actually used as weights in the final layer, so we include them.
"""
n_params = sum(p.numel() for p in self.parameters())
print("number of parameters: %.2fM" % (n_params/1e6,))
if non_embedding:
n_params -= self.transformer.wpe.weight.numel()
return n_params
def _init_weights(self, module):
if isinstance(module, nn.Linear):