Mirror of https://github.com/osmarks/nanogpt-experiments.git (synced 2025-10-31 15:23:01 +00:00)

	try bring back mingpt init

Changed files: model.py (18 additions, 0 deletions)

@@ -121,10 +121,28 @@ class GPT(nn.Module):
         # not 100% sure what this is, so far seems to be harmless. TODO investigate
         self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
 
+        # init all weights
+        self.apply(self._init_weights)
+        # apply special scaled init to the residual projections, per GPT-2 paper
+        for pn, p in self.named_parameters():
+            if pn.endswith('c_proj.weight'):
+                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
         # report number of parameters
         n_params = sum(p.numel() for p in self.parameters())
         print("number of parameters: %.2fM" % (n_params/1e6,))
 
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        elif isinstance(module, nn.LayerNorm):
+            torch.nn.init.zeros_(module.bias)
+            torch.nn.init.ones_(module.weight)
+
     def forward(self, idx, targets=None):
         device = idx.device
         b, t = idx.size()
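
For context, here is a minimal standalone sketch (not part of the commit) of the minGPT/GPT-2 init scheme this diff reintroduces: nn.Module.apply visits every submodule recursively, and the residual projection weights (c_proj.weight) are then re-initialized with a std scaled down by 1/sqrt(2 * n_layer), so the variance contributed by the 2 * n_layer residual additions stays roughly constant with depth. The n_layer and n_embd values and the toy block below are placeholders, not taken from the repository.

import math
import torch
import torch.nn as nn

# hypothetical toy settings for illustration only
n_layer = 12
n_embd = 64

def init_weights(module):
    # same scheme as the _init_weights method added in the diff above
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)

# a stand-in residual block; apply() walks all submodules recursively
block = nn.ModuleDict(dict(
    c_fc=nn.Linear(n_embd, 4 * n_embd),
    c_proj=nn.Linear(4 * n_embd, n_embd),
    ln=nn.LayerNorm(n_embd),
))
block.apply(init_weights)

# residual projections get a smaller std, mirroring the named_parameters() loop
for pn, p in block.named_parameters():
    if pn.endswith('c_proj.weight'):
        nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layer))

print(block.c_fc.weight.std().item())    # ~0.02
print(block.c_proj.weight.std().item())  # ~0.02 / sqrt(24), i.e. ~0.004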
Andrej Karpathy