From ae06d0b15a9111cbe2ce66b0f1be9ae29c1ecbbe Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 30 Jan 2023 23:18:26 +0000 Subject: [PATCH] add flash attention support, resolving last few issues but for now seems to work ok --- model.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/model.py b/model.py index 62d0cfb..e6683a0 100644 --- a/model.py +++ b/model.py @@ -45,11 +45,16 @@ class CausalSelfAttention(nn.Module): # regularization self.attn_dropout = nn.Dropout(config.dropout) self.resid_dropout = nn.Dropout(config.dropout) - # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) - .view(1, 1, config.block_size, config.block_size)) self.n_head = config.n_head self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary + self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + if not self.flash: + print("WARNING: using slow attention, install PyTorch nightly for fast Flash Attention") + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) + .view(1, 1, config.block_size, config.block_size)) def forward(self, x): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) @@ -61,11 +66,17 @@ class CausalSelfAttention(nn.Module): v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) - att = F.softmax(att, dim=-1) - att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + assert self.dropout == 0.0, "need dropout=0.0 for now, PyTorch team is working on fix in #92917" + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection