mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
fix bug... if topk > vocab_size, torch.topk will throw error
This commit is contained in:
parent
57735f532d
commit
91d02510ce
2
model.py
2
model.py
@ -277,7 +277,7 @@ class GPT(nn.Module):
|
|||||||
logits = logits[:, -1, :] / temperature
|
logits = logits[:, -1, :] / temperature
|
||||||
# optionally crop the logits to only the top k options
|
# optionally crop the logits to only the top k options
|
||||||
if top_k is not None:
|
if top_k is not None:
|
||||||
v, _ = torch.topk(logits, top_k)
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||||
# apply softmax to convert logits to (normalized) probabilities
|
# apply softmax to convert logits to (normalized) probabilities
|
||||||
probs = F.softmax(logits, dim=-1)
|
probs = F.softmax(logits, dim=-1)
|
||||||
|
Loading…
Reference in New Issue
Block a user