Mirror of https://github.com/osmarks/nanogpt-experiments.git, synced 2024-12-18 06:00:29 +00:00
Merge pull request #195 from drisspg/enable_sdpa_with_nonzero_dropout
Enable sdpa for nonzero dropout
Commit 0d8fbd11ae
model.py: 6 changed lines
@@ -49,10 +49,10 @@ class CausalSelfAttention(nn.Module):
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.dropout = config.dropout
-        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
-        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0
+        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
+        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
         if not self.flash:
-            print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
+            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))
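For context, the reason the dropout == 0.0 gate can be dropped is that torch.nn.functional.scaled_dot_product_attention in PyTorch >= 2.0 takes a dropout_p argument and applies attention dropout inside the fused kernel. Below is a minimal, self-contained sketch of such a call with nonzero dropout; the tensor shapes, the dropout value, and the train/eval gating shown are illustrative assumptions, not code copied from model.py.

import torch
import torch.nn.functional as F

# Illustrative shapes only: (batch, n_head, seq_len, head_dim)
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)

training = True
dropout = 0.1  # nonzero attention dropout is now fine on the fused path

# PyTorch >= 2.0 fused attention; dropout_p is applied inside the kernel.
# Zero it at eval time, e.g. dropout if training else 0.0.
y = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    dropout_p=dropout if training else 0.0,
    is_causal=True,
)
print(y.shape)  # torch.Size([2, 4, 16, 32])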