mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
add flash attention support, resolving last few issues but for now seems to work ok
This commit is contained in:
parent
0e90ee9d48
commit
ae06d0b15a
15
model.py
15
model.py
@ -45,11 +45,16 @@ class CausalSelfAttention(nn.Module):
|
|||||||
# regularization
|
# regularization
|
||||||
self.attn_dropout = nn.Dropout(config.dropout)
|
self.attn_dropout = nn.Dropout(config.dropout)
|
||||||
self.resid_dropout = nn.Dropout(config.dropout)
|
self.resid_dropout = nn.Dropout(config.dropout)
|
||||||
|
self.n_head = config.n_head
|
||||||
|
self.n_embd = config.n_embd
|
||||||
|
self.dropout = config.dropout
|
||||||
|
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
|
||||||
|
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
||||||
|
if not self.flash:
|
||||||
|
print("WARNING: using slow attention, install PyTorch nightly for fast Flash Attention")
|
||||||
# causal mask to ensure that attention is only applied to the left in the input sequence
|
# causal mask to ensure that attention is only applied to the left in the input sequence
|
||||||
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
|
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
|
||||||
.view(1, 1, config.block_size, config.block_size))
|
.view(1, 1, config.block_size, config.block_size))
|
||||||
self.n_head = config.n_head
|
|
||||||
self.n_embd = config.n_embd
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
||||||
@ -61,6 +66,12 @@ class CausalSelfAttention(nn.Module):
|
|||||||
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||||
|
|
||||||
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
||||||
|
if self.flash:
|
||||||
|
# efficient attention using Flash Attention CUDA kernels
|
||||||
|
assert self.dropout == 0.0, "need dropout=0.0 for now, PyTorch team is working on fix in #92917"
|
||||||
|
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
|
||||||
|
else:
|
||||||
|
# manual implementation of attention
|
||||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||||
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
|
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
|
||||||
att = F.softmax(att, dim=-1)
|
att = F.softmax(att, dim=-1)
|
||||||
|
Loading…
Reference in New Issue
Block a user