1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-09-21 11:49:46 +00:00

fix bug... if topk > vocab_size, torch.topk will throw error

This commit is contained in:
Andrej Karpathy 2023-01-14 03:57:00 +00:00
parent 57735f532d
commit 91d02510ce

View File

@ -277,7 +277,7 @@ class GPT(nn.Module):
logits = logits[:, -1, :] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, top_k)
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)