mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2025-10-21 02:27:40 +00:00
add support for DDP training. the scaling timings right now do not look good by default, have to dig more into
This commit is contained in:
79
train.py
79
train.py
@@ -1,6 +1,12 @@
|
||||
"""
|
||||
Train a GPT model on a dataset of text. One GPU version.
|
||||
The text is assumed to pre-tokenized and inside files train.pt and val.pt
|
||||
This training script can be run both on a single gpu in debug mode,
|
||||
and also in a larger training run with distributed data parallel (ddp).
|
||||
|
||||
To run in debug mode example:
|
||||
$ python train.py --batch_size=32 --other=args
|
||||
|
||||
To run DDP on 4 gpus on one node, example:
|
||||
$ torchrun --standalone --nproc_per_node=4 train.py
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -9,9 +15,11 @@ import time
|
||||
import math
|
||||
from ast import literal_eval
|
||||
|
||||
import wandb
|
||||
import numpy as np
|
||||
import torch
|
||||
import wandb
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.distributed import init_process_group, destroy_process_group
|
||||
|
||||
from model import GPTConfig, GPT
|
||||
|
||||
@@ -49,9 +57,11 @@ decay_lr = True # whether to decay the learning rate
|
||||
warmup_iters = 2000 # how many steps to warm up for
|
||||
lr_decay_iters = 320000 # how many steps to decay the learning rate for
|
||||
min_lr = 1e-5 # minimum learning rate
|
||||
# DDP settings
|
||||
backend = 'nccl' # 'nccl', 'gloo', etc.
|
||||
# -----------------------------------------------------------------------------
|
||||
# poor man's Configurator. Potentially a bad idea. Example usage:
|
||||
# python train.py override_file --batch_size=32
|
||||
# $ python train.py override_file --batch_size=32
|
||||
# this will first run config/override_file.py, then override batch_size to 32
|
||||
for arg in sys.argv[1:]:
|
||||
if '=' not in arg:
|
||||
@@ -71,7 +81,7 @@ for arg in sys.argv[1:]:
|
||||
try:
|
||||
# attempt to eval it it (e.g. if bool, number, or etc)
|
||||
attempt = literal_eval(val)
|
||||
except SyntaxError:
|
||||
except (SyntaxError, ValueError):
|
||||
# if that goes wrong, just use the string
|
||||
attempt = val
|
||||
# ensure the types match ok
|
||||
@@ -82,13 +92,21 @@ for arg in sys.argv[1:]:
|
||||
else:
|
||||
raise ValueError(f"Unknown config key: {key}")
|
||||
# -----------------------------------------------------------------------------
|
||||
ddp = int(os.environ.get('LOCAL_RANK', -1)) != -1 # is this a ddp run?
|
||||
if ddp:
|
||||
init_process_group(backend=backend)
|
||||
gpu_id = int(os.environ["LOCAL_RANK"])
|
||||
device = f"cuda:{gpu_id}"
|
||||
else:
|
||||
gpu_id = 0 # gpu_id 0 means this is the (single) master process, basically
|
||||
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
torch.manual_seed(1337)
|
||||
if gpu_id == 0:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
|
||||
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
||||
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
||||
|
||||
# poor man's data loader, TODO use real DataLoader...
|
||||
# poor man's data loader, TODO evaluate need for actual DataLoader
|
||||
data_dir = os.path.join('data', dataset)
|
||||
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
|
||||
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
|
||||
@@ -101,16 +119,16 @@ def get_batch(split):
|
||||
return x, y
|
||||
|
||||
# model init
|
||||
# TODO I don't love this whole part/API yet
|
||||
model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout)
|
||||
if init_from == 'scratch':
|
||||
# init a new model from scratch
|
||||
gptconf = GPTConfig(**model_args)
|
||||
model = GPT(gptconf)
|
||||
elif init_from == 'resume':
|
||||
# resume training from a checkpoint. TODO: do we resume iter_num etc too? (yes...)
|
||||
# resume training from a checkpoint.
|
||||
# TODO: should we also resume iter_num and best_val_loss?
|
||||
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
|
||||
checkpoint = torch.load(ckpt_path)
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
checkpoint_model_args = checkpoint['model_args']
|
||||
for k, v in model_args.items():
|
||||
assert checkpoint_model_args[k] == v, "for now"
|
||||
@@ -120,10 +138,20 @@ elif init_from == 'resume':
|
||||
elif init_from.startswith('gpt2'):
|
||||
# initialize from OpenAI GPT-2 weights
|
||||
model = GPT.from_pretrained(init_from)
|
||||
if block_size < model.block_size:
|
||||
model.crop_block_size(block_size)
|
||||
# crop down the model block size if desired
|
||||
if block_size < model.block_size:
|
||||
model.crop_block_size(block_size)
|
||||
model.to(device)
|
||||
|
||||
# optimizer
|
||||
optimizer = model.configure_optimizers(weight_decay, learning_rate, betas)
|
||||
if init_from == 'resume':
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
|
||||
# wrap model into DDP container
|
||||
if ddp:
|
||||
model = DDP(model, device_ids=[gpu_id])
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_loss():
|
||||
out = {}
|
||||
@@ -139,11 +167,6 @@ def estimate_loss():
|
||||
model.train()
|
||||
return out
|
||||
|
||||
# optimizer
|
||||
optimizer = model.configure_optimizers(weight_decay, learning_rate, betas)
|
||||
if init_from == 'resume':
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
|
||||
# learning rate decay scheduler (cosine with warmup)
|
||||
def get_lr(iter):
|
||||
# 1) linear warmup for warmup_iters steps
|
||||
@@ -155,11 +178,11 @@ def get_lr(iter):
|
||||
# 3) in between, use cosine decay down to min learning rate
|
||||
decay_ratio = (iter - warmup_iters) / (lr_decay_iters - warmup_iters)
|
||||
assert 0 <= decay_ratio <= 1
|
||||
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # ranges 0..1
|
||||
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
|
||||
return min_lr + coeff * (learning_rate - min_lr)
|
||||
|
||||
# logging
|
||||
if wandb_log:
|
||||
if wandb_log and gpu_id == 0:
|
||||
wandb.init(project=wandb_project, entity=wandb_entity, name=wandb_run_name)
|
||||
wandb.config = {
|
||||
"batch_size": batch_size,
|
||||
@@ -169,7 +192,6 @@ if wandb_log:
|
||||
|
||||
# training loop
|
||||
iter_num = 0
|
||||
num_tokens = 0
|
||||
best_val_loss = 1e9
|
||||
t0 = time.time()
|
||||
while True:
|
||||
@@ -182,25 +204,26 @@ while True:
|
||||
else:
|
||||
lr = learning_rate
|
||||
|
||||
if iter_num % eval_interval == 0:
|
||||
if iter_num % eval_interval == 0 and gpu_id == 0:
|
||||
losses = estimate_loss()
|
||||
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
|
||||
if wandb_log:
|
||||
wandb.log({
|
||||
"iter": iter_num,
|
||||
"num_tokens": num_tokens,
|
||||
"train/loss": losses['train'],
|
||||
"val/loss": losses['val'],
|
||||
"lr": lr,
|
||||
})
|
||||
if losses['val'] < best_val_loss:
|
||||
best_val_loss = losses['val']
|
||||
if iter_num > 0: # don't save checkpoints on very first iteration...
|
||||
raw_model = model.module if ddp else model
|
||||
if iter_num > 0:
|
||||
checkpoint = {
|
||||
'model': model.state_dict(),
|
||||
'model': raw_model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'model_args': model_args,
|
||||
'iter_num': iter_num,
|
||||
'best_val_loss': best_val_loss,
|
||||
}
|
||||
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
|
||||
if iter_num == 0 and eval_only:
|
||||
@@ -212,19 +235,19 @@ while True:
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
loss.backward()
|
||||
# TODO: gradient clipping
|
||||
# TODO: gradient clipping evaluate need for
|
||||
optimizer.step()
|
||||
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
t0 = t1
|
||||
if iter_num % log_interval == 0:
|
||||
if iter_num % log_interval == 0 and gpu_id == 0:
|
||||
lossf = loss.item() # loss as float. TODO CPU-GPU sync: profile, make sure not slow af
|
||||
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
|
||||
iter_num += 1
|
||||
num_tokens += X.numel()
|
||||
|
||||
# termination conditions
|
||||
if iter_num >= max_iters:
|
||||
break
|
||||
|
||||
destroy_process_group()
|
||||
|
Reference in New Issue
Block a user