1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-18 14:10:28 +00:00

enable sdpa for nonzero dropout

This commit is contained in:
Driss Guessous 2023-03-05 19:29:29 +00:00
parent ae3a8d5fdd
commit 6170531b8a

View File

@ -49,10 +49,10 @@ class CausalSelfAttention(nn.Module):
self.n_head = config.n_head self.n_head = config.n_head
self.n_embd = config.n_embd self.n_embd = config.n_embd
self.dropout = config.dropout self.dropout = config.dropout
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0 self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash: if not self.flash:
print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0") print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
# causal mask to ensure that attention is only applied to the left in the input sequence # causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size))