1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-09-21 03:39:44 +00:00

Fix GPT.crop_block_size when flash attention is available

This commit is contained in:
Kirill 2023-03-24 14:51:02 +03:00 committed by GitHub
parent a82b33b525
commit c3f254844d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -207,7 +207,8 @@ class GPT(nn.Module):
         self.config.block_size = block_size
         self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
         for block in self.transformer.h:
-            block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
+            if hasattr(block.attn, 'bias'):
+                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

     @classmethod
     def from_pretrained(cls, model_type, override_args=None):