diff --git a/train.py b/train.py
index 5db73a8..581635d 100644
--- a/train.py
+++ b/train.py
@@ -12,6 +12,7 @@ $ torchrun --standalone --nproc_per_node=4 train.py
 import os
 import time
 import math
+from contextlib import nullcontext
 
 import numpy as np
 import torch
@@ -56,11 +57,14 @@ min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchi
 # DDP settings
 backend = 'nccl' # 'nccl', 'gloo', etc.
 # system
-device = 'cuda'
+device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
+dtype = 'bfloat16' # 'float32' or 'bfloat16'
 compile = True # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
 exec(open('configurator.py').read()) # overrides from command line or config file
 # -----------------------------------------------------------------------------
+
+# various inits, derived attributes, I/O setup
 ddp = int(os.environ.get('LOCAL_RANK', -1)) != -1 # is this a ddp run?
 if ddp:
     init_process_group(backend=backend)
@@ -74,6 +78,10 @@ if gpu_id == 0:
 torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
 torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
 torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
+# note: float16 would require us to change the code to use a GradScaler
+ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16}[dtype]
+ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
 
 # poor man's data loader, TODO evaluate need for actual DataLoader
 data_dir = os.path.join('data', dataset)
@@ -156,7 +164,7 @@ def estimate_loss():
         losses = torch.zeros(eval_iters)
         for k in range(eval_iters):
             X, Y = get_batch(split)
-            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            with ctx:
                 logits, loss = model(X, Y)
             losses[k] = loss.item()
         out[split] = losses.mean()
@@ -226,7 +234,7 @@ while True:
         break
 
     X, Y = get_batch('train')
-    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+    with ctx:
         logits, loss = model(X, Y)
 
     optimizer.zero_grad(set_to_none=True)
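
The patch's "# note: float16 would require us to change the code to use a GradScaler" comment is worth unpacking. Unlike bfloat16, float16 has a narrow exponent range, so small gradients can underflow to zero; a GradScaler multiplies the loss before backward() and unscales the gradients before the optimizer step. Below is a minimal sketch, not part of this patch, of what that float16 path could look like. The 'float16' dtype option, the stand-in model/optimizer, and the dummy loss are illustrative assumptions; torch.cuda.amp.GradScaler and its scale/step/update calls are the standard PyTorch API.

# sketch only: float16 training with gradient scaling (NOT part of this patch)
import torch
import torch.nn as nn
from contextlib import nullcontext

device = 'cuda'
dtype = 'float16' # hypothetical third option next to 'float32'/'bfloat16'
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# GradScaler is a no-op when enabled=False, so one code path can serve all dtypes
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

model = nn.Linear(16, 16).to(device) # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randn(8, 16, device=device)

with ctx:
    loss = model(x).square().mean() # stand-in forward pass + loss
optimizer.zero_grad(set_to_none=True)
scaler.scale(loss).backward() # scale the loss up to avoid fp16 gradient underflow
scaler.step(optimizer) # unscales gradients; skips the step if they contain inf/nan
scaler.update() # adapt the scale factor for the next iteration

The design point behind the patch is that bfloat16 keeps float32's exponent range and therefore needs no scaler, which is why the diff can route everything through a single autocast context (or nullcontext on CPU) with no change to the backward pass.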