diff --git a/sample.py b/sample.py
index 8b42c8f..670759b 100644
--- a/sample.py
+++ b/sample.py
@@ -11,7 +11,7 @@ from model import GPTConfig, GPT
 # -----------------------------------------------------------------------------
 init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
 out_dir = 'out' # ignored if init_from is not 'resume'
-start = "\n" # or "<|endoftext|>" or whatever you like
+start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
 num_samples = 10 # number of samples to draw
 max_new_tokens = 500 # number of tokens generated in each sample
 temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
@@ -74,6 +74,9 @@ else:
     decode = lambda l: enc.decode(l)

 # encode the beginning of the prompt
+if start.startswith('FILE:'):
+    with open(start[5:], 'r', encoding='utf-8') as f:
+        start = f.read()
 start_ids = encode(start)
 x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
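
Usage note (not part of the patch): the change lets the prompt be read from disk via a "FILE:" prefix on start, presumably set from the command line as --start=FILE:prompt.txt through nanoGPT's config overrides. Below is a minimal standalone sketch of that convention; prompt.txt and the hard-coded start value are illustrative assumptions, not part of sample.py.

# Illustrative sketch of the FILE: prefix handling added in the diff above.
# Assumes a local prompt.txt exists; in sample.py the value would normally
# arrive via a --start=FILE:prompt.txt override rather than being hard-coded.
start = "FILE:prompt.txt"
if start.startswith('FILE:'):
    # strip the 5-character 'FILE:' prefix and load the prompt text from the file
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
# 'start' now holds the file contents and is tokenized exactly as an inline
# prompt would be, e.g. start_ids = encode(start)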