1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-18 14:10:28 +00:00

candidate changes to apis, have to think through more

This commit is contained in:
Andrej Karpathy 2023-01-01 01:29:48 +00:00
parent 7c6ea8409e
commit 2febf4463c
8 changed files with 111 additions and 19 deletions

View File

@ -42,6 +42,16 @@ $ python sample.py
Training on 1 A100 40GB GPU overnight currently gets loss ~3.74, training on 4 gets ~3.60. Random chance at init is -ln(1/50257) = 10.82. Which brings us to baselines: Training on 1 A100 40GB GPU overnight currently gets loss ~3.74, training on 4 gets ~3.60. Random chance at init is -ln(1/50257) = 10.82. Which brings us to baselines:
## finetuning
For an example of how to finetune a GPT on new text go to `data/shakespeare` and look at `prepare.py` to download the tiny shakespeare dataset and render it into a `train.bin` and `val.bin`. Unlike OpenWebText this will run in seconds. Finetuning takes very little time, e.g. on a single GPT just a few minutes. Run an example finetuning like:
```
$ python train.py finetune_shakespeare
```
This will load the config parameter overrides in `config/finetune_shakespeare.py` (I didn't tune them much though). Basically, we initialize from a GPT2 checkpoint with `init_from` and train as normal, except shorter and with a small learning rate. The best checkpoint (lowest validation loss) will be in the `out_dir` directory, e.g. in `out-shakespeare` by default, per the config file. You can then run the code in `sample.py` to generate infinite Shakespeare. Note that you'll have to edit it to point to the correct `out_dir`.
## baselines ## baselines
OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext. We can get the numbers as follows: OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext. We can get the numbers as follows:

View File

@ -0,0 +1,22 @@
import time
out_dir = 'out-shakespeare'
eval_interval = 200
wandb_log = False # feel free to turn on
wandb_project = 'shakespeare'
wandb_run_name = 'ft-' + str(time.time())
compile_model = False # takes too little time to finetune, not worth it
# save a nice and overfit checkpoint that
# will only speak Shakespeare and forgets
# everything else about the world #dark
always_save_checkpoint = True
dataset = 'shakespeare'
init_from = 'gpt2-xl'
batch_size = 1
block_size = 512
learning_rate = 1e-5
max_iters = 1000
decay_lr = False

View File

@ -35,6 +35,7 @@ enc = tiktoken.get_encoding("gpt2")
def process(example): def process(example):
ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
# note: I think eot should be prepended not appended... hmm. it's called "eot" though...
out = {'ids': ids, 'len': len(ids)} out = {'ids': ids, 'len': len(ids)}
return out return out

View File

@ -0,0 +1,32 @@
import os
import requests
import tiktoken
import numpy as np
# download the tiny shakespeare dataset
if not os.path.exists('input.txt'):
data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
with open('input.txt', 'w') as f:
f.write(requests.get(data_url).text)
with open('input.txt', 'r') as f:
data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]
# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids)} tokens")
print(f"val has {len(val_ids)} tokens")
# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile('train.bin')
val_ids.tofile('val.bin')
# train.bin has 301,966 tokens
# val.bin has 36,059 tokens

View File

@ -0,0 +1,9 @@
# tiny shakespeare
Tiny shakespeare, of the good old char-rnn fame :)
After running `prepare.py`:
- train.bin has 301,966 tokens
- val.bin has 36,059 tokens

View File

@ -90,7 +90,7 @@ class Block(nn.Module):
x = x + self.mlp(self.ln_2(x)) x = x + self.mlp(self.ln_2(x))
return x return x
@dataclass(frozen=True) @dataclass
class GPTConfig: class GPTConfig:
block_size: int = 1024 block_size: int = 1024
vocab_size: int = 50257 vocab_size: int = 50257
@ -105,7 +105,7 @@ class GPT(nn.Module):
super().__init__() super().__init__()
assert config.vocab_size is not None assert config.vocab_size is not None
assert config.block_size is not None assert config.block_size is not None
self.block_size = config.block_size self.config = config
self.transformer = nn.ModuleDict(dict( self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd), wte = nn.Embedding(config.vocab_size, config.n_embd),
@ -123,7 +123,7 @@ class GPT(nn.Module):
def forward(self, idx, targets=None): def forward(self, idx, targets=None):
device = idx.device device = idx.device
b, t = idx.size() b, t = idx.size()
assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}" assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
# forward the GPT model itself # forward the GPT model itself
@ -146,27 +146,36 @@ class GPT(nn.Module):
# model surgery to decrease the block size if necessary # model surgery to decrease the block size if necessary
# e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
# but want to use a smaller block size for some smaller, simpler model # but want to use a smaller block size for some smaller, simpler model
assert block_size <= self.block_size assert block_size <= self.config.block_size
self.block_size = block_size self.config.block_size = block_size
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
for block in self.transformer.h: for block in self.transformer.h:
block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
@classmethod @classmethod
def from_pretrained(cls, model_type): def from_pretrained(cls, model_type, override_args):
assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
# only dropout can be overridden see more notes below
assert all(k == 'dropout' for k in override_args)
from transformers import GPT2LMHeadModel from transformers import GPT2LMHeadModel
print("loading weights from pretrained gpt: %s" % model_type) print("loading weights from pretrained gpt: %s" % model_type)
layer_config = { # n_layer, n_head and n_embd are determined from model_type
config_args = {
'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}[model_type] }[model_type]
# we can override the dropout rate
if 'dropout' in override_args:
config_args['dropout'] = override_args['dropout']
# block_size is always 1024 for GPT model checkpoints
# if one wants a lower block_size it has to be done through model surgery
# later, by calling crop_block_size()
# create a from-scratch initialized minGPT model # create a from-scratch initialized minGPT model
config = GPTConfig(block_size=1024, **layer_config) config = GPTConfig(block_size=1024, **config_args)
model = GPT(config) model = GPT(config)
sd = model.state_dict() sd = model.state_dict()
@ -248,7 +257,7 @@ class GPT(nn.Module):
""" """
for _ in range(max_new_tokens): for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at block_size # if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:] idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
# forward the model to get the logits for the index in the sequence # forward the model to get the logits for the index in the sequence
logits, _ = self(idx_cond) logits, _ = self(idx_cond)
# pluck the logits at the final step and scale by desired temperature # pluck the logits at the final step and scale by desired temperature

View File

@ -21,18 +21,18 @@ model = GPT(gptconf)
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
model.eval() model.eval()
model.to(device) model.to(device)
model = torch.compile(model) # requires PyTorch 2.0 #model = torch.compile(model) # requires PyTorch 2.0 (optional)
enc = tiktoken.get_encoding("gpt2") enc = tiktoken.get_encoding("gpt2")
#start = enc.encode("\n") start = enc.encode("\n") # user choice on what token to start with
start = [enc.eot_token] #start = [enc.eot_token]
x = (torch.tensor(start, dtype=torch.long, device=device)[None, ...]) x = (torch.tensor(start, dtype=torch.long, device=device)[None, ...])
for k in range(1): for k in range(10):
with torch.no_grad(): with torch.no_grad():
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
y = model.generate(x, 300, temperature=0.8, top_k=200) y = model.generate(x, 500, temperature=0.8, top_k=200)
print(enc.decode(y[0].tolist())) print(enc.decode(y[0].tolist()))
print('---------------') print('---------------')

View File

@ -31,6 +31,7 @@ eval_interval = 500
log_interval = 1 log_interval = 1
eval_iters = 50 eval_iters = 50
eval_only = False # if True, script exits right after the first eval eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = False # if True, always save a checkpoint after each eval
# wandb logging # wandb logging
wandb_log = False # disabled by default wandb_log = False # disabled by default
wandb_entity = 'karpathy' wandb_entity = 'karpathy'
@ -138,6 +139,7 @@ elif init_from == 'resume':
checkpoint_model_args = checkpoint['model_args'] checkpoint_model_args = checkpoint['model_args']
for k, v in model_args.items(): for k, v in model_args.items():
assert checkpoint_model_args[k] == v, "for now" assert checkpoint_model_args[k] == v, "for now"
# TODO: think through how passed in params should interact with checkpoint params
gptconf = GPTConfig(**model_args) gptconf = GPTConfig(**model_args)
model = GPT(gptconf) model = GPT(gptconf)
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
@ -146,9 +148,14 @@ elif init_from == 'resume':
elif init_from.startswith('gpt2'): elif init_from.startswith('gpt2'):
print(f"Initializing from OpenAI GPT-2 weights: {init_from}") print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
# initialize from OpenAI GPT-2 weights # initialize from OpenAI GPT-2 weights
model = GPT.from_pretrained(init_from) override_args = dict(dropout=dropout)
model = GPT.from_pretrained(init_from, override_args)
# read off and override the GPT sizing model args from the model config
model_args['n_layer'] = model.config.n_layer
model_args['n_head'] = model.config.n_head
model_args['n_embd'] = model.config.n_embd
# crop down the model block size if desired # crop down the model block size if desired
if block_size < model.block_size: if block_size < model.config.block_size:
model.crop_block_size(block_size) model.crop_block_size(block_size)
model.to(device) model.to(device)
@ -227,7 +234,7 @@ while True:
"val/loss": losses['val'], "val/loss": losses['val'],
"lr": lr, "lr": lr,
}) })
if losses['val'] < best_val_loss: if losses['val'] < best_val_loss or always_save_checkpoint:
best_val_loss = losses['val'] best_val_loss = losses['val']
raw_model = model.module if ddp else model raw_model = model.module if ddp else model
if iter_num > 0: if iter_num > 0:
@ -238,6 +245,7 @@ while True:
'iter_num': iter_num, 'iter_num': iter_num,
'best_val_loss': best_val_loss, 'best_val_loss': best_val_loss,
} }
print(f"saving checkpoint to {out_dir}")
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
if iter_num == 0 and eval_only: if iter_num == 0 and eval_only:
break break
@ -260,7 +268,8 @@ while True:
iter_num += 1 iter_num += 1
# termination conditions # termination conditions
if iter_num >= max_iters: if iter_num > max_iters:
break break
destroy_process_group() if ddp:
destroy_process_group()