Mirror of https://github.com/osmarks/nanogpt-experiments.git
inference-time mini-optimization, low-hanging fruit, ty @jxtps for raising: when we are running inference we can apply lm_head to only the very last token
This commit is contained in:
parent e21cbf887f
commit 8f85b83347

model.py (9 changed lines: 6 additions, 3 deletions)
@@ -133,12 +133,15 @@ class GPT(nn.Module):
         for block in self.transformer.h:
             x = block(x)
         x = self.transformer.ln_f(x)
-        logits = self.lm_head(x)
 
-        # if we are given some desired targets also calculate the loss
-        loss = None
         if targets is not None:
+            # if we are given some desired targets also calculate the loss
+            logits = self.lm_head(x)
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+        else:
+            # inference-time mini-optimization: only forward the lm_head on the very last position
+            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
+            loss = None
 
         return logits, loss
 
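Why this is safe: sampling only ever consumes the logits at the last time step, and the lm_head projection is independent per position, so slicing before the projection changes nothing except the amount of matmul work (a factor of T). A minimal self-contained sketch checking this (illustrative sizes and random tensors, not the actual nanoGPT model):

import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, n_embd, vocab_size = 2, 8, 32, 100       # illustrative sizes, not nanoGPT defaults
lm_head = nn.Linear(n_embd, vocab_size, bias=False)
x = torch.randn(B, T, n_embd)                  # stand-in for the ln_f output

logits_all = lm_head(x)                        # (B, T, vocab_size): what the training loss needs
logits_last = lm_head(x[:, [-1], :])           # (B, 1, vocab_size): all that sampling needs
# indexing with the list [-1] (not the int -1) keeps the time dimension
assert logits_last.shape == (B, 1, vocab_size)
assert torch.allclose(logits_all[:, [-1], :], logits_last)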
train.py (2 changed lines: 1 addition, 1 deletion)
@@ -102,7 +102,7 @@ def get_batch(split):
 iter_num = 0
 best_val_loss = 1e9
 
-# model init
+# model init. TODO: fix bug we should also propagate the correct vocab_size to the model_args
 model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout)
 if init_from == 'scratch':
     # init a new model from scratch
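On the new TODO in train.py: the bug is that model_args never receives the dataset's actual vocab_size, so the model is built with the config default. One possible fix, sketched under the assumption that the dataset's prepare script wrote a meta.pkl containing 'vocab_size' (as nanoGPT's character-level prepare.py does); the config globals from train.py are defined locally here with hypothetical values so the snippet runs standalone:

import os
import pickle

# stand-ins for train.py's config globals (hypothetical values for illustration)
dataset = 'shakespeare_char'
n_layer, n_head, n_embd, block_size, dropout = 6, 6, 384, 256, 0.2

vocab_size = 50257                             # GPT-2 BPE size as a fallback
meta_path = os.path.join('data', dataset, 'meta.pkl')
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    vocab_size = meta['vocab_size']            # the dataset's true vocabulary size

# propagate vocab_size alongside the other model hyperparameters
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd,
                  block_size=block_size, dropout=dropout, vocab_size=vocab_size)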