1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-18 14:10:28 +00:00

add support for character-level language models, a new character-level shakespeare dataset, a new config file that shows how to train a character-level baby GPT on it, and adjust the sample function to figure out if it should decode with characters or GPT2 bpe tokens. The current implementation is a bit hacky and basically assumes just these two possibilities. In the future we may want to support more general encoders or decoders.

This commit is contained in:
Andrej Karpathy 2023-01-11 05:27:19 +00:00
parent c2a402f7f7
commit d17350a31d
5 changed files with 137 additions and 4 deletions

View File

@ -0,0 +1,36 @@
# train a miniature character-level shakespeare model
# good for debugging and playing on macbooks and such
out_dir = 'out-shakespeare-char'
eval_interval = 250 # keep frequent because we'll overfit
eval_iters = 200
log_interval = 10 # don't print too too often
# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = True
wandb_log = False # override via command line if you like
wandb_project = 'shakespeare-char'
wandb_run_name = 'mini-gpt'
dataset = 'shakespeare_char'
batch_size = 64
block_size = 128 # context of up to 128 previous characters
# baby GPT model :)
n_layer = 4
n_head = 4
n_embd = 128
dropout = 0.0
learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
lr_decay_iters = 5000 # make equal to max_iters usually
min_lr = 1e-4 # learning_rate / 10 usually
beta2 = 0.99 # make a bit bigger because number of tokens per iter is small
warmup_iters = 100 # not super necessary potentially
# on macbook also add
# device = 'cpu' # run on cpu only
# compile = False # do not torch compile the model

View File

@ -0,0 +1,67 @@
"""
Prepare the Shakespeare dataset for character-level language modeling.
So instead of encoding with GPT-2 BPE tokens, we just map characters to ints.
Will save train.bin, val.bin containing the ids, and meta.pkl containing the
encoder and decoder and some other related info.
"""
import os
import pickle
import requests
import numpy as np
# download the tiny shakespeare dataset
if not os.path.exists('input.txt'):
data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
with open('input.txt', 'w') as f:
f.write(requests.get(data_url).text)
with open('input.txt', 'r') as f:
data = f.read()
print("length of dataset in characters: ", len(data))
# get all the unique characters that occur in this text
chars = sorted(list(set(data)))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print("vocab size:", vocab_size)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
def encode(s):
return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
# create the train and test splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]
# encode both to integers
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids)} tokens")
print(f"val has {len(val_ids)} tokens")
# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile('train.bin')
val_ids.tofile('val.bin')
# save the meta information as well, to help us encode/decode later
meta = {
'vocab_size': vocab_size,
'itos': itos,
'stoi': stoi,
}
with open('meta.pkl', 'wb') as f:
pickle.dump(meta, f)
# length of dataset in characters: 1115394
# all the unique characters:
# !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
# vocab size: 65
# train has 1003854 tokens
# val has 111540 tokens

View File

@ -0,0 +1,9 @@
# tiny shakespeare, character-level
Tiny shakespeare, of the good old char-rnn fame :) Treated on character-level.
After running `prepare.py`:
- train.bin has 1,003,854 tokens
- val.bin has 111,540 tokens

View File

@ -2,6 +2,7 @@
Sample from a trained model Sample from a trained model
""" """
import os import os
import pickle
from contextlib import nullcontext from contextlib import nullcontext
import torch import torch
import tiktoken import tiktoken
@ -45,9 +46,28 @@ model.to(device)
if compile: if compile:
model = torch.compile(model) # requires PyTorch 2.0 (optional) model = torch.compile(model) # requires PyTorch 2.0 (optional)
# look for the meta pickle in case it is available in the dataset folder
load_meta = False
if 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
load_meta = os.path.exists(meta_path)
if load_meta:
print(f"Loading meta from {meta_path}...")
with open(meta_path, 'rb') as f:
meta = pickle.load(f)
# TODO want to make this more general to arbitrary encoder/decoder schemes
stoi, itos = meta['stoi'], meta['itos']
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
else:
# ok let's assume gpt-2 encodings by default
print("No meta.pkl found, assuming GPT-2 encodings...")
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
decode = lambda l: enc.decode(l)
# encode the beginning of the prompt # encode the beginning of the prompt
enc = tiktoken.get_encoding("gpt2") start_ids = encode(start)
start_ids = enc.encode(start, allowed_special={"<|endoftext|>"})
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
# run generation # run generation
@ -55,5 +75,5 @@ with torch.no_grad():
with ctx: with ctx:
for k in range(num_samples): for k in range(num_samples):
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(enc.decode(y[0].tolist())) print(decode(y[0].tolist()))
print('---------------') print('---------------')

View File

@ -225,6 +225,7 @@ while True:
'model_args': model_args, 'model_args': model_args,
'iter_num': iter_num, 'iter_num': iter_num,
'best_val_loss': best_val_loss, 'best_val_loss': best_val_loss,
'config': config,
} }
print(f"saving checkpoint to {out_dir}") print(f"saving checkpoint to {out_dir}")
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))