diff --git a/config/train_shakespeare_char.py b/config/train_shakespeare_char.py new file mode 100644 index 0000000..975832b --- /dev/null +++ b/config/train_shakespeare_char.py @@ -0,0 +1,36 @@ +# train a miniature character-level shakespeare model +# good for debugging and playing on macbooks and such + +out_dir = 'out-shakespeare-char' +eval_interval = 250 # keep frequent because we'll overfit +eval_iters = 200 +log_interval = 10 # don't print too too often + +# we expect to overfit on this small dataset, so only save when val improves +always_save_checkpoint = True + +wandb_log = False # override via command line if you like +wandb_project = 'shakespeare-char' +wandb_run_name = 'mini-gpt' + +dataset = 'shakespeare_char' +batch_size = 64 +block_size = 128 # context of up to 128 previous characters + +# baby GPT model :) +n_layer = 4 +n_head = 4 +n_embd = 128 +dropout = 0.0 + +learning_rate = 1e-3 # with baby networks can afford to go a bit higher +max_iters = 5000 +lr_decay_iters = 5000 # make equal to max_iters usually +min_lr = 1e-4 # learning_rate / 10 usually +beta2 = 0.99 # make a bit bigger because number of tokens per iter is small + +warmup_iters = 100 # not super necessary potentially + +# on macbook also add +# device = 'cpu' # run on cpu only +# compile = False # do not torch compile the model diff --git a/data/shakespeare_char/prepare.py b/data/shakespeare_char/prepare.py new file mode 100644 index 0000000..6759b2f --- /dev/null +++ b/data/shakespeare_char/prepare.py @@ -0,0 +1,67 @@ +""" +Prepare the Shakespeare dataset for character-level language modeling. +So instead of encoding with GPT-2 BPE tokens, we just map characters to ints. +Will save train.bin, val.bin containing the ids, and meta.pkl containing the +encoder and decoder and some other related info. +""" +import os +import pickle +import requests +import numpy as np + +# download the tiny shakespeare dataset +if not os.path.exists('input.txt'): + data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' + with open('input.txt', 'w') as f: + f.write(requests.get(data_url).text) + +with open('input.txt', 'r') as f: + data = f.read() +print("length of dataset in characters: ", len(data)) + +# get all the unique characters that occur in this text +chars = sorted(list(set(data))) +vocab_size = len(chars) +print("all the unique characters:", ''.join(chars)) +print("vocab size:", vocab_size) + +# create a mapping from characters to integers +stoi = { ch:i for i,ch in enumerate(chars) } +itos = { i:ch for i,ch in enumerate(chars) } +def encode(s): + return [stoi[c] for c in s] # encoder: take a string, output a list of integers +def decode(l): + ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string + +# create the train and test splits +n = len(data) +train_data = data[:int(n*0.9)] +val_data = data[int(n*0.9):] + +# encode both to integers +train_ids = encode(train_data) +val_ids = encode(val_data) +print(f"train has {len(train_ids)} tokens") +print(f"val has {len(val_ids)} tokens") + +# export to bin files +train_ids = np.array(train_ids, dtype=np.uint16) +val_ids = np.array(val_ids, dtype=np.uint16) +train_ids.tofile('train.bin') +val_ids.tofile('val.bin') + +# save the meta information as well, to help us encode/decode later +meta = { + 'vocab_size': vocab_size, + 'itos': itos, + 'stoi': stoi, +} +with open('meta.pkl', 'wb') as f: + pickle.dump(meta, f) + +# length of dataset in characters: 1115394 +# all the unique characters: +# !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz +# vocab size: 65 +# train has 1003854 tokens +# val has 111540 tokens diff --git a/data/shakespeare_char/readme.md b/data/shakespeare_char/readme.md new file mode 100644 index 0000000..d597b79 --- /dev/null +++ b/data/shakespeare_char/readme.md @@ -0,0 +1,9 @@ + +# tiny shakespeare, character-level + +Tiny shakespeare, of the good old char-rnn fame :) Treated on character-level. + +After running `prepare.py`: + +- train.bin has 1,003,854 tokens +- val.bin has 111,540 tokens diff --git a/sample.py b/sample.py index 0fcf19a..6ff0ea2 100644 --- a/sample.py +++ b/sample.py @@ -2,6 +2,7 @@ Sample from a trained model """ import os +import pickle from contextlib import nullcontext import torch import tiktoken @@ -45,9 +46,28 @@ model.to(device) if compile: model = torch.compile(model) # requires PyTorch 2.0 (optional) +# look for the meta pickle in case it is available in the dataset folder +load_meta = False +if 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these... + meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl') + load_meta = os.path.exists(meta_path) +if load_meta: + print(f"Loading meta from {meta_path}...") + with open(meta_path, 'rb') as f: + meta = pickle.load(f) + # TODO want to make this more general to arbitrary encoder/decoder schemes + stoi, itos = meta['stoi'], meta['itos'] + encode = lambda s: [stoi[c] for c in s] + decode = lambda l: ''.join([itos[i] for i in l]) +else: + # ok let's assume gpt-2 encodings by default + print("No meta.pkl found, assuming GPT-2 encodings...") + enc = tiktoken.get_encoding("gpt2") + encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"}) + decode = lambda l: enc.decode(l) + # encode the beginning of the prompt -enc = tiktoken.get_encoding("gpt2") -start_ids = enc.encode(start, allowed_special={"<|endoftext|>"}) +start_ids = encode(start) x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) # run generation @@ -55,5 +75,5 @@ with torch.no_grad(): with ctx: for k in range(num_samples): y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) - print(enc.decode(y[0].tolist())) - print('---------------') + print(decode(y[0].tolist())) + print('---------------') diff --git a/train.py b/train.py index 37bd3fb..fb98981 100644 --- a/train.py +++ b/train.py @@ -225,6 +225,7 @@ while True: 'model_args': model_args, 'iter_num': iter_num, 'best_val_loss': best_val_loss, + 'config': config, } print(f"saving checkpoint to {out_dir}") torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))