
add support for DDP training. the scaling timings right now do not look good by default, have to dig more into

Andrej Karpathy 2022-12-29 05:06:07 +00:00
parent ee6459f1d0
commit dea1507252
2 changed files with 67 additions and 38 deletions

README.md

@@ -1,34 +1,40 @@
 # nanoGPT
 
-The cleanest, fastest repository for training/finetuning medium-sized GPTs. Still under active development, currently trying to reproduce GPT-2 on OpenWebText dataset. The code itself is tiny, plain and readable. At the moment `train.py` is a ~200-line boilerplate training loop and `model.py` a ~300-line GPT model definition, which can also load the GPT-2 weights from OpenAI.
+The cleanest, smallest, fastest repository for training/finetuning medium-sized GPTs. Still under active development, currently working to reproduce GPT-2 on OpenWebText dataset. The code itself aims by design to be plain and readable: `train.py` is a ~300-line boilerplate training loop and `model.py` a ~300-line GPT model definition, which can optionally load the GPT-2 weights from OpenAI. That's it.
 
 ## install
 
-We need a few dependencies:
+Dependencies:
 
-- [pytorch](https://pytorch.org), of course
-- numpy
-- `pip install datasets` for huggingface datasets
-- `pip install tiktoken` for OpenAI's fast bpe code
-- `pip install wandb` for optional logging
+- [pytorch](https://pytorch.org) <3
+- numpy <3
+- `pip install datasets` for huggingface datasets <3
+- `pip install tiktoken` for OpenAI's fast bpe code <3
+- `pip install wandb` for optional logging <3
 
 ## usage
 
-To render a dataset we first tokenize some documents into one giant array of indices. E.g. for OpenWebText see:
+To render a dataset we first tokenize some documents into one simple long 1D array of indices. E.g. for OpenWebText see:
 
 ```
 $ cd data/openwebtext
 $ python prepare.py
 ```
 
-To download and tokenize the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and `val.bin` which holds the GPT2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. The training script currently by default tries to reproduce the smallest GPT-2 released by OpenAI, i.e. the 124M version of GPT-2. We can train as follows, though I encourage you to read the code and see all of the settings and paths up top in the file:
+To download and tokenize the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. This will create a `train.bin` and `val.bin` which holds the GPT2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. The training script currently by default tries to reproduce the smallest GPT-2 released by OpenAI, i.e. the 124M version of GPT-2. We can demo train as follows on a single device, though I encourage you to read the code and see all of the settings and paths up top in the file:
 
 ```
 $ python train.py
 ```
 
-Once some checkpoints are written to the output directory `out`, we can sample from the model:
+To train using PyTorch Distributed Data Parallel (DDP) run the script with torchrun. For example to train on a node with 4 GPUs run:
+
+```
+$ torchrun --standalone --nproc_per_node=4 train.py
+```
+
+Once some checkpoints are written to the output directory (e.g. `./out` by default), we can sample from the model:
 
 ```
 $ python sample.py
 ```
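For context on the torchrun launch: torchrun spawns one Python process per GPU and communicates rank information through environment variables such as `LOCAL_RANK` and `WORLD_SIZE`, which `train.py` reads in the diff below. A minimal sketch of what each spawned process sees, assuming the 4-GPU launch above:

```python
import os

# torchrun sets these for every process it spawns; train.py keys off LOCAL_RANK
local_rank = int(os.environ.get("LOCAL_RANK", -1))  # 0..3 here; -1 when not launched via torchrun
world_size = int(os.environ.get("WORLD_SIZE", 1))   # 4 for --nproc_per_node=4
print(f"process {local_rank} of {world_size} will train on cuda:{local_rank}")
```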

train.py

@@ -1,6 +1,12 @@
 """
-Train a GPT model on a dataset of text. One GPU version.
-The text is assumed to pre-tokenized and inside files train.pt and val.pt
+This training script can be run both on a single gpu in debug mode,
+and also in a larger training run with distributed data parallel (ddp).
+
+To run in debug mode example:
+$ python train.py --batch_size=32 --other=args
+
+To run DDP on 4 gpus on one node, example:
+$ torchrun --standalone --nproc_per_node=4 train.py
 """
 
 import os
@@ -9,9 +15,11 @@ import time
 import math
 from ast import literal_eval
 
+import wandb
 import numpy as np
 import torch
-import wandb
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed import init_process_group, destroy_process_group
 
 from model import GPTConfig, GPT
@@ -49,9 +57,11 @@ decay_lr = True # whether to decay the learning rate
 warmup_iters = 2000 # how many steps to warm up for
 lr_decay_iters = 320000 # how many steps to decay the learning rate for
 min_lr = 1e-5 # minimum learning rate
+# DDP settings
+backend = 'nccl' # 'nccl', 'gloo', etc.
 # -----------------------------------------------------------------------------
 # poor man's Configurator. Potentially a bad idea. Example usage:
-# python train.py override_file --batch_size=32
+# $ python train.py override_file --batch_size=32
 # this will first run config/override_file.py, then override batch_size to 32
 for arg in sys.argv[1:]:
     if '=' not in arg:
@@ -71,7 +81,7 @@ for arg in sys.argv[1:]:
             try:
                 # attempt to eval it it (e.g. if bool, number, or etc)
                 attempt = literal_eval(val)
-            except SyntaxError:
+            except (SyntaxError, ValueError):
                 # if that goes wrong, just use the string
                 attempt = val
             # ensure the types match ok
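The widened `except (SyntaxError, ValueError)` above matters because `ast.literal_eval` raises `ValueError` (not `SyntaxError`) on bare words. A standalone sketch of the coercion step, using a hypothetical `coerce` helper name:

```python
from ast import literal_eval

def coerce(val: str):
    # hypothetical helper isolating the configurator's value-coercion step
    try:
        # bools, numbers, tuples, ... parse as Python literals
        return literal_eval(val)  # e.g. "32" -> 32, "True" -> True, "3e-4" -> 0.0003
    except (SyntaxError, ValueError):
        # literal_eval("openwebtext") raises ValueError, so an override like
        # --dataset=openwebtext would crash under the old bare `except SyntaxError:`
        return val

assert coerce("32") == 32
assert coerce("openwebtext") == "openwebtext"
```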
@@ -82,13 +92,21 @@
         else:
             raise ValueError(f"Unknown config key: {key}")
 # -----------------------------------------------------------------------------
+ddp = int(os.environ.get('LOCAL_RANK', -1)) != -1 # is this a ddp run?
+if ddp:
+    init_process_group(backend=backend)
+    gpu_id = int(os.environ["LOCAL_RANK"])
+    device = f"cuda:{gpu_id}"
+else:
+    gpu_id = 0 # gpu_id 0 means this is the (single) master process, basically
 
-os.makedirs(out_dir, exist_ok=True)
-torch.manual_seed(1337)
+if gpu_id == 0:
+    os.makedirs(out_dir, exist_ok=True)
+torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
 torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
 torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
 
-# poor man's data loader, TODO use real DataLoader...
+# poor man's data loader, TODO evaluate need for actual DataLoader
 data_dir = os.path.join('data', dataset)
 train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
 val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
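The memmap arrays above feed `get_batch`, whose body sits outside this diff. For orientation, a sketch of such a sampler in the style of the script (reconstructed here, using the script's globals `block_size`, `batch_size` and `device`, so treat the details as assumptions):

```python
import numpy as np
import torch

def get_batch(split):
    # choose the memmap and draw batch_size random crop offsets
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # x is a block of token ids, y is the same block shifted right by one
    x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
    return x.to(device), y.to(device)
```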
@@ -101,16 +119,16 @@ def get_batch(split):
     return x, y
 
 # model init
-# TODO I don't love this whole part/API yet
 model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout)
 if init_from == 'scratch':
     # init a new model from scratch
     gptconf = GPTConfig(**model_args)
     model = GPT(gptconf)
 elif init_from == 'resume':
-    # resume training from a checkpoint. TODO: do we resume iter_num etc too? (yes...)
+    # resume training from a checkpoint.
+    # TODO: should we also resume iter_num and best_val_loss?
     ckpt_path = os.path.join(out_dir, 'ckpt.pt')
-    checkpoint = torch.load(ckpt_path)
+    checkpoint = torch.load(ckpt_path, map_location=device)
     checkpoint_model_args = checkpoint['model_args']
     for k, v in model_args.items():
         assert checkpoint_model_args[k] == v, "for now"
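Note the `map_location=device` added to `torch.load`: under DDP every rank executes the resume path, and without it each rank would first deserialize the checkpoint tensors onto the device they were saved from (typically `cuda:0`) before any later move. A small sketch of the per-rank load, under the same assumptions:

```python
# each rank loads checkpoint tensors straight onto its own GPU
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=f"cuda:{gpu_id}")
model.load_state_dict(checkpoint['model'])
```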
@@ -120,10 +138,20 @@ elif init_from == 'resume':
 elif init_from.startswith('gpt2'):
     # initialize from OpenAI GPT-2 weights
     model = GPT.from_pretrained(init_from)
-if block_size < model.block_size:
+# crop down the model block size if desired
+if block_size < model.block_size:
     model.crop_block_size(block_size)
 model.to(device)
 
+# optimizer
+optimizer = model.configure_optimizers(weight_decay, learning_rate, betas)
+if init_from == 'resume':
+    optimizer.load_state_dict(checkpoint['optimizer'])
+
+# wrap model into DDP container
+if ddp:
+    model = DDP(model, device_ids=[gpu_id])
+
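The ordering in this hunk (configure the optimizer, then wrap in DDP) works because DDP only hooks gradient synchronization: each rank keeps its own optimizer, and since the all-reduced gradients are identical on every rank, the replicas stay in lockstep. A self-contained sketch of the same pattern with a toy model (hypothetical, not nanoGPT's GPT), runnable under the torchrun command from the docstring:

```python
import os
import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

# run with: torchrun --standalone --nproc_per_node=4 ddp_sketch.py
init_process_group(backend='nccl')
gpu_id = int(os.environ['LOCAL_RANK'])
device = f'cuda:{gpu_id}'

model = torch.nn.Linear(16, 16).to(device)      # toy stand-in for the GPT
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
model = DDP(model, device_ids=[gpu_id])         # wrap after the optimizer is made, as above

x = torch.randn(8, 16, device=device)
loss = model(x).square().mean()
loss.backward()    # DDP all-reduces (averages) gradients across ranks here
optimizer.step()   # every rank applies the same averaged gradient

destroy_process_group()
```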
 @torch.no_grad()
 def estimate_loss():
     out = {}
@@ -139,11 +167,6 @@ def estimate_loss():
     model.train()
     return out
 
-# optimizer
-optimizer = model.configure_optimizers(weight_decay, learning_rate, betas)
-if init_from == 'resume':
-    optimizer.load_state_dict(checkpoint['optimizer'])
-
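While the optimizer lines move up: for orientation, the `estimate_loss` whose tail shows above averages the loss over several batches per split with the model in eval mode. Reconstructed roughly (the `eval_iters` count comes from the config section, not shown in this diff):

```python
@torch.no_grad()
def estimate_loss():
    # average the loss over eval_iters random batches for each split
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
```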
 # learning rate decay scheduler (cosine with warmup)
 def get_lr(iter):
     # 1) linear warmup for warmup_iters steps
@@ -155,11 +178,11 @@ def get_lr(iter):
     # 3) in between, use cosine decay down to min learning rate
     decay_ratio = (iter - warmup_iters) / (lr_decay_iters - warmup_iters)
     assert 0 <= decay_ratio <= 1
-    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # ranges 0..1
+    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
     return min_lr + coeff * (learning_rate - min_lr)
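Assembling `get_lr` from the context lines in the two hunks above: linear warmup to `learning_rate` over `warmup_iters`, cosine decay down to `min_lr` by `lr_decay_iters`, flat afterwards (the middle branch is reconstructed, not shown in the diff):

```python
import math

def get_lr(iter):
    # 1) linear warmup for warmup_iters steps
    if iter < warmup_iters:
        return learning_rate * iter / warmup_iters
    # 2) past lr_decay_iters, hold at the minimum learning rate
    if iter > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (iter - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)
```

With the defaults above, the rate ramps from 0 to the peak at iter 2000 and bottoms out at `min_lr = 1e-5` by iter 320000.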
 
 # logging
-if wandb_log:
+if wandb_log and gpu_id == 0:
     wandb.init(project=wandb_project, entity=wandb_entity, name=wandb_run_name)
     wandb.config = {
         "batch_size": batch_size,
@@ -169,7 +192,6 @@ if wandb_log:
 
 # training loop
 iter_num = 0
-num_tokens = 0
 best_val_loss = 1e9
 t0 = time.time()
 while True:
@@ -182,25 +204,26 @@
     else:
         lr = learning_rate
 
-    if iter_num % eval_interval == 0:
+    if iter_num % eval_interval == 0 and gpu_id == 0:
         losses = estimate_loss()
         print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
         if wandb_log:
             wandb.log({
                 "iter": iter_num,
-                "num_tokens": num_tokens,
                 "train/loss": losses['train'],
                 "val/loss": losses['val'],
                 "lr": lr,
             })
         if losses['val'] < best_val_loss:
             best_val_loss = losses['val']
-            if iter_num > 0: # don't save checkpoints on very first iteration...
+            raw_model = model.module if ddp else model
+            if iter_num > 0:
                 checkpoint = {
-                    'model': model.state_dict(),
+                    'model': raw_model.state_dict(),
                     'optimizer': optimizer.state_dict(),
                     'model_args': model_args,
                     'iter_num': iter_num,
+                    'best_val_loss': best_val_loss,
                 }
                 torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
     if iter_num == 0 and eval_only:
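The new `raw_model = model.module if ddp else model` exists because DDP wraps the network and prefixes every state-dict key with `module.`; saving the wrapper's `state_dict()` would produce checkpoints that no longer load into a bare GPT. A quick illustration with a toy module (the DDP half is commented out since it needs an initialized process group):

```python
import torch

net = torch.nn.Linear(4, 4)
print(list(net.state_dict()))    # ['weight', 'bias']

# inside an initialized process group:
# wrapped = DDP(net.cuda(), device_ids=[0])
# list(wrapped.state_dict())     # ['module.weight', 'module.bias']
# hence the checkpoint saves raw_model.state_dict(), never the wrapper's
```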
@@ -212,19 +235,19 @@
     optimizer.zero_grad(set_to_none=True)
     loss.backward()
-    # TODO: gradient clipping
+    # TODO: gradient clipping evaluate need for
    optimizer.step()
 
     t1 = time.time()
     dt = t1 - t0
     t0 = t1
-    if iter_num % log_interval == 0:
+    if iter_num % log_interval == 0 and gpu_id == 0:
         lossf = loss.item() # loss as float. TODO CPU-GPU sync: profile, make sure not slow af
         print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
     iter_num += 1
-    num_tokens += X.numel()
 
     # termination conditions
     if iter_num >= max_iters:
         break
+
+destroy_process_group()
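Two closing notes. `destroy_process_group()` gives the NCCL backend a clean shutdown once the loop exits. And the gradient-clipping TODO above is conventionally filled with `torch.nn.utils.clip_grad_norm_` between `backward()` and `step()`; a sketch of that, not part of this commit, with a hypothetical `grad_clip` setting:

```python
# hypothetical: grad_clip = 1.0 would live with the other config values up top
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
optimizer.step()
```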