From 39ae397a9344e202c366200f6605b310449c1b2e Mon Sep 17 00:00:00 2001
From: Alexander Pivovarov
Date: Wed, 17 May 2023 02:30:18 +0000
Subject: [PATCH] Remove pos unsqueeze(0)

---
 model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/model.py b/model.py
index 1e5e1fd..ae90db6 100644
--- a/model.py
+++ b/model.py
@@ -178,11 +178,11 @@ class GPT(nn.Module):
         device = idx.device
         b, t = idx.size()
         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
-        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
+        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
 
         # forward the GPT model itself
         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
+        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
         x = self.transformer.drop(tok_emb + pos_emb)
         for block in self.transformer.h:
             x = block(x)
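
Note (not part of the patch): the change works because PyTorch broadcasting treats a (t, n_embd) addend the same as a (1, t, n_embd) one when added to a (b, t, n_embd) tensor. A minimal standalone sketch below illustrates this; it uses standalone nn.Embedding layers and made-up dimensions to stand in for transformer.wte / transformer.wpe, not the actual GPT module.

# Sketch: dropping unsqueeze(0) leaves the result unchanged, because
# (b, t, n_embd) + (t, n_embd) broadcasts the same as (b, t, n_embd) + (1, t, n_embd).
import torch
import torch.nn as nn

b, t, n_embd, vocab_size, block_size = 2, 5, 8, 50, 16

wte = nn.Embedding(vocab_size, n_embd)  # token embedding table, stand-in for transformer.wte
wpe = nn.Embedding(block_size, n_embd)  # position embedding table, stand-in for transformer.wpe

idx = torch.randint(0, vocab_size, (b, t))
tok_emb = wte(idx)                                           # (b, t, n_embd)

pos_old = torch.arange(0, t, dtype=torch.long).unsqueeze(0)  # (1, t), pre-patch shape
pos_new = torch.arange(0, t, dtype=torch.long)               # (t,), post-patch shape

x_old = tok_emb + wpe(pos_old)  # (b, t, n_embd) + (1, t, n_embd)
x_new = tok_emb + wpe(pos_new)  # (b, t, n_embd) + (t, n_embd), broadcasts over the batch dim

assert torch.equal(x_old, x_new)  # identical results either way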