diff --git a/model.py b/model.py
index 4bee2c5..86f47e4 100644
--- a/model.py
+++ b/model.py
@@ -152,8 +152,19 @@ class GPT(nn.Module):
                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
 
         # report number of parameters
+        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
+
+    def get_num_params(self, non_embedding=True):
+        """
+        Return the number of parameters in the model.
+        For non-embedding count (default), the position embeddings get subtracted.
+        The token embeddings would too, except due to the parameter sharing these
+        params are actually used as weights in the final layer, so we include them.
+        """
         n_params = sum(p.numel() for p in self.parameters())
-        print("number of parameters: %.2fM" % (n_params/1e6,))
+        if non_embedding:
+            n_params -= self.transformer.wpe.weight.numel()
+        return n_params
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
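
A minimal usage sketch of the new helper (not part of the patch). It assumes the GPT and GPTConfig classes from model.py; the config field names and values below are illustrative assumptions, not taken from this diff:

# Sketch: compare the full parameter count with the non-embedding count.
# Assumes GPT/GPTConfig from model.py; config values here are only examples.
from model import GPT, GPTConfig

config = GPTConfig(n_layer=12, n_head=12, n_embd=768,
                   block_size=1024, vocab_size=50304)
model = GPT(config)

total = model.get_num_params(non_embedding=False)  # every parameter, embeddings included
active = model.get_num_params()                    # default: position embeddings subtracted

# The difference is exactly the position-embedding table (block_size * n_embd).
# The token embeddings remain in both counts because, via weight tying, they also
# serve as the final projection weights.
print("total: %.2fM, non-embedding: %.2fM" % (total/1e6, active/1e6))
print("difference:", total - active, "==", config.block_size * config.n_embd)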