Mirror of https://github.com/osmarks/nanogpt-experiments.git, synced 2024-11-10 20:09:58 +00:00
inference-time mini-optimization, low-hanging fruit (ty @jxtps for raising): when we are running inference we can apply lm_head on only the very last token
commit 8f85b83347
parent e21cbf887f
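Why this helps: at inference time a sampling loop only ever consumes the logits at the last position, so projecting all T positions through lm_head is wasted work. A minimal sketch of such a loop (simplified from nanoGPT's generate; temperature and block-size cropping are omitted here):

    import torch
    import torch.nn.functional as F

    @torch.no_grad()
    def generate(model, idx, max_new_tokens):
        # idx is (B, T); model returns (logits, loss) as in the diff below
        for _ in range(max_new_tokens):
            logits, _ = model(idx)            # after this commit: (B, 1, vocab_size)
            logits = logits[:, -1, :]         # only the last position is ever used
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx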
 model.py | 9 ++++++---
@@ -133,12 +133,15 @@ class GPT(nn.Module):
         for block in self.transformer.h:
             x = block(x)
         x = self.transformer.ln_f(x)
-        logits = self.lm_head(x)
 
-        # if we are given some desired targets also calculate the loss
-        loss = None
         if targets is not None:
+            # if we are given some desired targets also calculate the loss
+            logits = self.lm_head(x)
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+        else:
+            # inference-time mini-optimization: only forward the lm_head on the very last position
+            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
+            loss = None
 
         return logits, loss
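A quick note on the x[:, [-1], :] indexing: using the Python list [-1] (rather than the bare integer -1) keeps the time dimension, so logits stay a 3-D tensor of shape (B, 1, vocab_size) and downstream code that indexes the time axis keeps working. A standalone check, with toy sizes that are not from the commit:

    import torch
    import torch.nn as nn

    B, T, C, V = 2, 8, 16, 100            # toy batch/time/embed/vocab sizes
    x = torch.randn(B, T, C)              # stand-in for the final hidden states
    lm_head = nn.Linear(C, V, bias=False)

    logits_last = lm_head(x[:, [-1], :])  # list index preserves the time dim
    print(logits_last.shape)              # torch.Size([2, 1, 100])

    # Matches the last position of the full projection, at 1/T of the lm_head cost:
    logits_full = lm_head(x)              # (B, T, V)
    print(torch.allclose(logits_full[:, -1:, :], logits_last))  # True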
 train.py | 2 +-
@@ -102,7 +102,7 @@ def get_batch(split):
 iter_num = 0
 best_val_loss = 1e9
 
-# model init
+# model init. TODO: fix bug we should also propagate the correct vocab_size to the model_args
 model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout)
 if init_from == 'scratch':
     # init a new model from scratch
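On the TODO above: one plausible shape of the fix (an assumption for illustration, not this commit's code) is to read the vocabulary size from the dataset's meta.pkl when it exists and pass it through model_args, falling back to the GPT-2 vocab size of 50257 otherwise. The names dataset, n_layer, n_head, n_embd, block_size, and dropout are assumed to be in scope from train.py's config:

    import os, pickle

    meta_path = os.path.join('data', dataset, 'meta.pkl')
    meta_vocab_size = None
    if os.path.exists(meta_path):
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
        meta_vocab_size = meta['vocab_size']   # assumes the prepare script wrote this key

    model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd,
                      block_size=block_size, dropout=dropout,
                      vocab_size=meta_vocab_size if meta_vocab_size is not None else 50257)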