From 8f85b8334749d1b6c3af3e11f4874bc1d0c0cad2 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Thu, 12 Jan 2023 06:02:50 +0000
Subject: [PATCH] inference time mini-optimization low-hanging fruit

ty @jxtps for raising: when we are running inference we can apply lm_head on
only the very last token

---
 model.py | 9 ++++++---
 train.py | 2 +-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/model.py b/model.py
index 236117c..18cf8b2 100644
--- a/model.py
+++ b/model.py
@@ -133,12 +133,15 @@ class GPT(nn.Module):
         for block in self.transformer.h:
             x = block(x)
         x = self.transformer.ln_f(x)
-        logits = self.lm_head(x)
 
-        # if we are given some desired targets also calculate the loss
-        loss = None
         if targets is not None:
+            # if we are given some desired targets also calculate the loss
+            logits = self.lm_head(x)
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+        else:
+            # inference-time mini-optimization: only forward the lm_head on the very last position
+            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
+            loss = None
 
         return logits, loss
 
diff --git a/train.py b/train.py
index fb98981..b888310 100644
--- a/train.py
+++ b/train.py
@@ -102,7 +102,7 @@ def get_batch(split):
 iter_num = 0
 best_val_loss = 1e9
 
-# model init
+# model init. TODO: fix bug we should also propagate the correct vocab_size to the model_args
 model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout)
 if init_from == 'scratch':
     # init a new model from scratch
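
The practical effect of the model.py change is that at inference time the logits come back with shape (B, 1, vocab_size) instead of (B, T, vocab_size), so sampling code that only reads the last time step keeps working while skipping T-1 redundant lm_head projections. Below is a minimal sketch of the calling side, assuming a GPT `model` built from this model.py and a (B, T) tensor `idx` of token ids; both names are illustrative and not part of the patch:

    import torch
    import torch.nn.functional as F

    model.eval()
    with torch.no_grad():
        # no targets passed -> the else branch above runs, so lm_head is applied
        # only to the last position and logits has shape (B, 1, vocab_size)
        logits, _ = model(idx)
        # sampling still indexes the last time step, which is now the only one
        probs = F.softmax(logits[:, -1, :], dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # (B, 1) sampled token ids
        idx = torch.cat((idx, next_id), dim=1)              # append and keep generating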