diff --git a/model.py b/model.py
index 1b32cdf..0858f80 100644
--- a/model.py
+++ b/model.py
@@ -49,10 +49,10 @@ class CausalSelfAttention(nn.Module):
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.dropout = config.dropout
-        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
-        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0
+        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
+        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
         if not self.flash:
-            print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
+            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
         # causal mask to ensure that attention is only applied to the left in the input sequence
         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                      .view(1, 1, config.block_size, config.block_size))
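
Not part of this diff, but for context: a minimal sketch of how the `self.flash` flag is typically consumed later in the attention forward pass, assuming the usual (batch, n_head, seq_len, head_dim) layout for q/k/v and the registered `bias` buffer as the causal mask. The helper name `attend` and its exact signature are illustrative; only `torch.nn.functional.scaled_dot_product_attention` is the real PyTorch >= 2.0 API.

```python
import math
import torch
import torch.nn.functional as F

def attend(q, k, v, flash, bias, dropout_p=0.0):
    """q, k, v: (batch, n_head, seq_len, head_dim); bias: (1, 1, block_size, block_size)."""
    T = q.size(-2)
    if flash:
        # Fused kernel available in PyTorch >= 2.0; applies the causal mask
        # and attention dropout internally.
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True)
    # Slow fallback: materialize the full attention matrix and mask it manually.
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(bias[:, :, :T, :T] == 0, float('-inf'))
    att = F.softmax(att, dim=-1)
    att = F.dropout(att, p=dropout_p)
    return att @ v
```

The point of the change above is that the fused path no longer needs `dropout == 0.0`, since `scaled_dot_product_attention` accepts a `dropout_p` argument directly, so the only remaining gate is whether the function exists in the installed PyTorch.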