mirror of
				https://github.com/osmarks/nanogpt-experiments.git
				synced 2025-10-31 15:23:01 +00:00 
			
		
		
		
	Merge pull request #195 from drisspg/enable_sdpa_with_nonzero_dropout
Enable sdpa for nonzero dropout
This commit is contained in:
		
							
								
								
									
										6
									
								
								model.py
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								model.py
									
									
									
									
									
								
							| @@ -49,10 +49,10 @@ class CausalSelfAttention(nn.Module): | |||||||
|         self.n_head = config.n_head |         self.n_head = config.n_head | ||||||
|         self.n_embd = config.n_embd |         self.n_embd = config.n_embd | ||||||
|         self.dropout = config.dropout |         self.dropout = config.dropout | ||||||
|         # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary |         # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 | ||||||
|         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0 |         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') | ||||||
|         if not self.flash: |         if not self.flash: | ||||||
|             print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0") |             print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") | ||||||
|             # causal mask to ensure that attention is only applied to the left in the input sequence |             # causal mask to ensure that attention is only applied to the left in the input sequence | ||||||
|             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) |             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) | ||||||
|                                         .view(1, 1, config.block_size, config.block_size)) |                                         .view(1, 1, config.block_size, config.block_size)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Andrej
					Andrej