mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
if dropout > 0.0 disable Flash until pytorch fix. don't assert fail sigh
This commit is contained in:
parent
d8b1a94519
commit
1e87509e47
5
model.py
5
model.py
@ -49,9 +49,9 @@ class CausalSelfAttention(nn.Module):
|
|||||||
self.n_embd = config.n_embd
|
self.n_embd = config.n_embd
|
||||||
self.dropout = config.dropout
|
self.dropout = config.dropout
|
||||||
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
|
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
|
||||||
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0
|
||||||
if not self.flash:
|
if not self.flash:
|
||||||
print("WARNING: using slow attention, install PyTorch nightly for fast Flash Attention")
|
print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
|
||||||
# causal mask to ensure that attention is only applied to the left in the input sequence
|
# causal mask to ensure that attention is only applied to the left in the input sequence
|
||||||
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
|
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
|
||||||
.view(1, 1, config.block_size, config.block_size))
|
.view(1, 1, config.block_size, config.block_size))
|
||||||
@ -68,7 +68,6 @@ class CausalSelfAttention(nn.Module):
|
|||||||
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
||||||
if self.flash:
|
if self.flash:
|
||||||
# efficient attention using Flash Attention CUDA kernels
|
# efficient attention using Flash Attention CUDA kernels
|
||||||
assert self.dropout == 0.0, "need dropout=0.0 for now, PyTorch team is working on fix in #92917"
|
|
||||||
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
|
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
|
||||||
else:
|
else:
|
||||||
# manual implementation of attention
|
# manual implementation of attention
|
||||||
|
Loading…
Reference in New Issue
Block a user