mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-18 14:10:28 +00:00

if dropout > 0.0, disable Flash until the PyTorch fix lands, instead of assert-failing. sigh

Author: Andrej Karpathy
Date:   2023-02-02 23:22:56 +00:00
parent d8b1a94519
commit 1e87509e47

model.py

@@ -49,9 +49,9 @@ class CausalSelfAttention(nn.Module):
         self.n_embd = config.n_embd
         self.dropout = config.dropout
         # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
-        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0
         if not self.flash:
-            print("WARNING: using slow attention, install PyTorch nightly for fast Flash Attention")
+            print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))
@@ -68,7 +68,6 @@ class CausalSelfAttention(nn.Module):
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            assert self.dropout == 0.0, "need dropout=0.0 for now, PyTorch team is working on fix in #92917"
             y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
         else:
             # manual implementation of attention
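
For context, the sketch below (not part of the commit) illustrates the two code paths the `self.flash` flag now selects between: the fused `torch.nn.functional.scaled_dot_product_attention` kernel versus the manual masked-softmax attention used as the fallback. It assumes a PyTorch build that exposes `scaled_dot_product_attention` (nightly at the time of this commit, stable from 2.0) and uses made-up tensor shapes following the model's (B, nh, T, hs) convention. With dropout_p=0.0, which is exactly the case the new gate allows onto the fast path, both paths should agree numerically.

# Hypothetical standalone example comparing the fused and manual attention paths;
# not code from the repository.
import math
import torch
import torch.nn.functional as F

B, nh, T, hs = 2, 4, 8, 16  # batch, heads, sequence length, head size
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# fast path: fused kernel, causal masking handled internally
y_flash = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)

# slow path: explicit causal mask, softmax over keys, then weighted sum of values
bias = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(bias[:, :, :T, :T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y_manual = att @ v

print(torch.allclose(y_flash, y_manual, atol=1e-5))  # expected: True

When dropout is non-zero, the new `self.flash` gate simply routes the model to the manual path until the upstream fix referenced in the removed assert (#92917) lands, rather than failing the assert at runtime.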