diff --git a/model.py b/model.py
index 1e5e1fd..ae90db6 100644
--- a/model.py
+++ b/model.py
@@ -178,11 +178,11 @@ class GPT(nn.Module):
         device = idx.device
         b, t = idx.size()
         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
-        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
+        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
 
         # forward the GPT model itself
         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
+        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
         x = self.transformer.drop(tok_emb + pos_emb)
         for block in self.transformer.h:
             x = block(x)
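
A minimal sketch of the broadcasting this change relies on (the sizes below are illustrative placeholders, not taken from any particular config): adding a (t, n_embd) tensor to a (b, t, n_embd) tensor broadcasts across the batch dimension, so the explicit unsqueeze(0) is redundant.

    import torch

    # illustrative sizes, hypothetical and for demonstration only
    b, t, n_embd = 2, 5, 8
    tok_emb = torch.randn(b, t, n_embd)  # token embeddings, shape (b, t, n_embd)
    pos_emb = torch.randn(t, n_embd)     # position embeddings, shape (t, n_embd)

    # broadcasting aligns trailing dims: (t, n_embd) behaves as (1, t, n_embd)
    x_new = tok_emb + pos_emb                # post-change addition
    x_old = tok_emb + pos_emb.unsqueeze(0)   # pre-change addition

    assert x_new.shape == (b, t, n_embd)
    assert torch.equal(x_new, x_old)  # both paths yield the identical tensor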