From ea4de192e01fbcb8e6ef84babd09d819e6533cc1 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Mon, 2 Jan 2023 02:11:39 +0000
Subject: [PATCH] reshuffle args inside sample.py

---
 sample.py | 32 ++++++++++++++++++++++----------
 1 file changed, 22 insertions(+), 10 deletions(-)

diff --git a/sample.py b/sample.py
index b68296a..e245efc 100644
--- a/sample.py
+++ b/sample.py
@@ -6,33 +6,45 @@ import torch
 import tiktoken
 from model import GPTConfig, GPT
 
+# -----------------------------------------------------------------------------
+# todo make these overridable like in train.py
+out_dir = 'out'
 device = 'cuda:2'
-torch.manual_seed(1337)
+compile = False
+start = "\n" # or "<|endoftext|>" or whatever you like
+num_samples = 10 # number of samples to draw
+max_new_tokens = 500 # number of tokens generated in each sample
+temperature = 0.8 # higher temperature (up to 1) is more random, lower (down to 0) means more greedy
+top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
+seed = 1337
+# -----------------------------------------------------------------------------
+
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
 torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
 torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
 
-out_dir = 'out'
+# model
 ckpt_path = os.path.join(out_dir, 'ckpt.pt')
 checkpoint = torch.load(ckpt_path, map_location=device)
-
-# model
 gptconf = GPTConfig(**checkpoint['model_args'])
 model = GPT(gptconf)
 model.load_state_dict(checkpoint['model'])
 model.eval()
 model.to(device)
-#model = torch.compile(model) # requires PyTorch 2.0 (optional)
+if compile:
+    model = torch.compile(model) # requires PyTorch 2.0 (optional)
 
+# encode the beginning of the prompt
 enc = tiktoken.get_encoding("gpt2")
-start = enc.encode("\n") # user choice on what token to start with
-#start = [enc.eot_token]
-x = (torch.tensor(start, dtype=torch.long, device=device)[None, ...])
+start_ids = enc.encode(start, allowed_special={"<|endoftext|>"})
+x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
 
-for k in range(10):
+for k in range(num_samples):
 
     with torch.no_grad():
         with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            y = model.generate(x, 500, temperature=0.8, top_k=200)
+            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
 
     print(enc.decode(y[0].tolist()))
     print('---------------')
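
Notes (review commentary, not part of the patch):

1) On the "todo make these overridable like in train.py": one way to honor that
todo is a poor-man's command-line override that rebinds the module-level
defaults declared in the new config block. This is only a sketch under the
assumption that any --key=value pair should replace a variable already defined
above; it is not necessarily the mechanism train.py actually uses.

    # sketch: paste below the config block in sample.py, then run e.g.
    #   $ python sample.py --num_samples=3 --temperature=0.9
    import sys
    from ast import literal_eval

    for arg in sys.argv[1:]:
        assert arg.startswith('--'), f"expected --key=value, got {arg}"
        key, val = arg[2:].split('=', 1)
        assert key in globals(), f"unknown config key: {key}"
        try:
            val = literal_eval(val)  # '3' -> 3, '0.9' -> 0.9, 'False' -> False
        except (ValueError, SyntaxError):
            pass  # leave the value as a string (e.g. --start=hello)
        globals()[key] = val  # rebind the module-level default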
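2) On temperature and top_k: the comments in the config block describe their
effect; inside model.generate they are applied to the last-position logits
before sampling. A minimal sketch of that step (paraphrased, not a verbatim
copy of model.py):

    import torch
    import torch.nn.functional as F

    @torch.no_grad()
    def sample_next_token(logits, temperature=0.8, top_k=200):
        # logits: (batch, vocab_size) at the final time step
        logits = logits / temperature  # <1.0 sharpens the distribution, >1.0 flattens it
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('inf')  # clamp everything below the k-th logit
        probs = F.softmax(logits, dim=-1)  # -inf logits become exactly 0 probability
        return torch.multinomial(probs, num_samples=1)  # (batch, 1) sampled token ids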
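3) On allowed_special={"<|endoftext|>"}: tiktoken refuses to encode text that
matches a special token by default, so this argument is what makes the new
start = "<|endoftext|>" option work, replacing the old commented-out
start = [enc.eot_token] path. Quick check (50256 is gpt2's end-of-text id):

    import tiktoken

    enc = tiktoken.get_encoding("gpt2")
    # enc.encode("<|endoftext|>")  # raises ValueError: special tokens are disallowed by default
    print(enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))  # [50256]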