mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00

Merge pull request #240 from YassineYousfi/master

don't dropout in eval mode
Andrej 2023-04-12 22:43:59 -07:00 committed by GitHub
commit 01e48ec1ab


@@ -69,7 +69,7 @@ class CausalSelfAttention(nn.Module):
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
+            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
         else:
             # manual implementation of attention
             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))