Mirror of https://github.com/osmarks/nanogpt-experiments.git, synced 2024-12-18 06:00:29 +00:00
Merge pull request #275 from apivovarov/rm_unsqueeze
Remove pos unsqueeze(0)
Commit 18ee6b62b6

Changed files: model.py · 4 changed lines (2 additions, 2 deletions)
@@ -178,11 +178,11 @@ class GPT(nn.Module):
         device = idx.device
         b, t = idx.size()
         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
-        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
+        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

         # forward the GPT model itself
         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
+        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
         x = self.transformer.drop(tok_emb + pos_emb)
         for block in self.transformer.h:
             x = block(x)
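The removed unsqueeze(0) is redundant because PyTorch broadcasting already aligns the shapes: adding a (t, n_embd) tensor to a (b, t, n_embd) tensor gives the same result as adding a (1, t, n_embd) tensor. A minimal sketch, not part of the commit, with illustrative shapes:

import torch

# Illustrative shapes only (not taken from the commit).
b, t, n_embd = 2, 5, 8
tok_emb = torch.randn(b, t, n_embd)   # token embeddings, shape (b, t, n_embd)
pos_emb = torch.randn(t, n_embd)      # position embeddings, shape (t, n_embd)

# Broadcasting adds the position embeddings to every sequence in the batch,
# so an explicit leading batch dimension of size 1 changes nothing.
assert torch.equal(tok_emb + pos_emb, tok_emb + pos_emb.unsqueeze(0))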