1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-20 23:20:30 +00:00

Merge pull request from drisspg/enable_sdpa_with_nonzero_dropout

Enable sdpa for nonzero dropout
This commit is contained in:
Andrej 2023-03-06 21:47:20 -08:00 committed by GitHub
commit 0d8fbd11ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -49,10 +49,10 @@ class CausalSelfAttention(nn.Module):
self.n_head = config.n_head
self.n_embd = config.n_embd
self.dropout = config.dropout
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))