From 7399dfe39d24579cc1fb391475864ad28f19f6f7 Mon Sep 17 00:00:00 2001
From: Yassine Yousfi
Date: Mon, 10 Apr 2023 22:56:22 -0700
Subject: [PATCH] dont always dropout!

---
 model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model.py b/model.py
index 0858f80..287aadd 100644
--- a/model.py
+++ b/model.py
@@ -69,7 +69,7 @@ class CausalSelfAttention(nn.Module):
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
+            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
         else:
             # manual implementation of attention
             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
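
The patch gates dropout_p on self.training because torch.nn.functional.scaled_dot_product_attention applies dropout whenever dropout_p > 0, regardless of the module's train/eval mode; the manual branch uses nn.Dropout, which already disables itself under model.eval(). Below is a minimal, self-contained sketch of the fixed pattern; TinyCausalAttention and its parameters are hypothetical stand-ins for illustration, not the nanoGPT module itself.

    # Sketch: gate attention dropout on self.training so eval is deterministic.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyCausalAttention(nn.Module):  # hypothetical, for illustration only
        def __init__(self, n_embd: int, n_head: int, dropout: float):
            super().__init__()
            assert n_embd % n_head == 0
            self.n_head = n_head
            self.dropout = dropout
            self.qkv = nn.Linear(n_embd, 3 * n_embd)
            self.proj = nn.Linear(n_embd, n_embd)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            B, T, C = x.shape
            q, k, v = self.qkv(x).split(C, dim=2)
            # reshape to (B, nh, T, hs)
            q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            # dropout only while training; 0.0 at eval time (the point of the patch)
            y = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
            y = y.transpose(1, 2).contiguous().view(B, T, C)
            return self.proj(y)

    # Usage: in eval mode, repeated forward passes match because dropout_p is 0.
    attn = TinyCausalAttention(n_embd=32, n_head=4, dropout=0.1).eval()
    x = torch.randn(1, 8, 32)
    with torch.no_grad():
        assert torch.allclose(attn(x), attn(x))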