mirror of https://github.com/osmarks/nanogpt-experiments.git
synced 2025-11-04 01:03:02 +00:00
fix bug... if top_k > vocab_size, torch.topk will throw an error
model.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
							@@ -277,7 +277,7 @@ class GPT(nn.Module):
             logits = logits[:, -1, :] / temperature
             # optionally crop the logits to only the top k options
             if top_k is not None:
-                v, _ = torch.topk(logits, top_k)
+                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                 logits[logits < v[:, [-1]]] = -float('Inf')
             # apply softmax to convert logits to (normalized) probabilities
             probs = F.softmax(logits, dim=-1)
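For context, the failure mode the commit guards against: torch.topk raises a RuntimeError ("selected index k out of range") when k exceeds the size of the selected dimension, which here is the vocabulary size. Below is a minimal standalone sketch of the patched sampling step, not the repository's code; the vocab_size value and tensor shapes are illustrative, and only PyTorch is assumed.

import torch
import torch.nn.functional as F

vocab_size = 50
logits = torch.randn(1, vocab_size)          # next-token logits for a batch of 1
top_k = 200                                  # caller requests more options than exist

# Unclamped, torch.topk(logits, top_k) would raise a RuntimeError here,
# because top_k > logits.size(-1). The commit clamps k first:
k = min(top_k, logits.size(-1))
v, _ = torch.topk(logits, k)
logits[logits < v[:, [-1]]] = -float('Inf')  # mask everything below the k-th largest logit
probs = F.softmax(logits, dim=-1)            # still a valid distribution over the vocab

With the clamp in place, requesting top_k >= vocab_size degrades gracefully to plain temperature sampling: v[:, [-1]] is then the global minimum of the logits, so nothing gets masked.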