mirror of
				https://github.com/osmarks/nanogpt-experiments.git
				synced 2025-10-30 06:43:04 +00:00 
			
		
		
		
	Add a patch to strip the mysterious unwanted `_orig_mod.` prefix from state dict keys when resuming from a checkpoint (cause not yet understood; maybe remove later)
This commit is contained in:
		
							
								
								
									
										9
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								train.py
									
									
									
									
									
								
							| @@ -142,7 +142,14 @@ elif init_from == 'resume': | |||||||
|         # TODO: think through how passed in params should interact with checkpoint params |         # TODO: think through how passed in params should interact with checkpoint params | ||||||
|     gptconf = GPTConfig(**model_args) |     gptconf = GPTConfig(**model_args) | ||||||
|     model = GPT(gptconf) |     model = GPT(gptconf) | ||||||
|     model.load_state_dict(checkpoint['model']) |     state_dict = checkpoint['model'] | ||||||
|  |     # fix the keys of the state dictionary :( | ||||||
|  |     # honestly no idea how checkpoints sometimes get this prefix, have to debug more | ||||||
|  |     unwanted_prefix = '_orig_mod.' | ||||||
|  |     for k,v in list(state_dict.items()): | ||||||
|  |         if k.startswith(unwanted_prefix): | ||||||
|  |             state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) | ||||||
|  |     model.load_state_dict(state_dict) | ||||||
|     iter_num = checkpoint['iter_num'] |     iter_num = checkpoint['iter_num'] | ||||||
|     best_val_loss = checkpoint['best_val_loss'] |     best_val_loss = checkpoint['best_val_loss'] | ||||||
| elif init_from.startswith('gpt2'): | elif init_from.startswith('gpt2'): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Andrej Karpathy
					Andrej Karpathy