mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-11-10 20:09:58 +00:00
fix bug... if topk > vocab_size, torch.topk will throw error
This commit is contained in:
parent
57735f532d
commit
91d02510ce
2
model.py
2
model.py
@ -277,7 +277,7 @@ class GPT(nn.Module):
|
||||
logits = logits[:, -1, :] / temperature
|
||||
# optionally crop the logits to only the top k options
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, top_k)
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
# apply softmax to convert logits to (normalized) probabilities
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
|
Loading…
Reference in New Issue
Block a user