diff --git a/train.py b/train.py index c4aaaa4..c8f3594 100644 --- a/train.py +++ b/train.py @@ -70,7 +70,7 @@ min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchi backend = 'nccl' # 'nccl', 'gloo', etc. # system device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks -dtype = 'bfloat16' # 'float32' or 'bfloat16' +dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler compile = True # use PyTorch 2.0 to compile the model to be faster # ----------------------------------------------------------------------------- config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] @@ -98,8 +98,8 @@ torch.manual_seed(1337 + seed_offset) torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast -# note: float16 would require us to change the code to use a GradScaler -ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16}[dtype] +# note: float16 data type will automatically use a GradScaler +ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) # poor man's data loader, TODO evaluate need for actual DataLoader @@ -173,6 +173,12 @@ if block_size < model.config.block_size: model.crop_block_size(block_size) model.to(device) +# initialize a GradScaler if data type is float16 +scaler = None +if dtype == 'float16': + print(f"Initializing Gradient Scaler to account for dtype: {dtype}") + scaler = torch.cuda.amp.GradScaler() + # optimizer optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2)) if init_from == 'resume': @@ -188,6 +194,7 @@ if compile: if ddp: model = DDP(model, device_ids=[ddp_local_rank]) +# helps estimate an arbitrarily accurate loss over either split using many batches @torch.no_grad() def estimate_loss(): out = {} @@ -263,7 +270,9 @@ while True: break # forward backward update, with optional gradient accumulation to simulate larger batch size + # and using the GradScaler if data type is float16 for micro_step in range(gradient_accumulation_steps): + # fetch a batch X, Y = get_batch('train') if ddp: # in DDP training we only need to sync gradients at the last micro step. @@ -273,10 +282,19 @@ while True: model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) with ctx: logits, loss = model(X, Y) - loss.backward() - if grad_clip != 0: + # backward pass, with gradient scaling if training in fp16 + scaler.scale(loss).backward() if scaler else loss.backward() + # clip the gradient + if grad_clip != 0.0: + scaler.unscale_(optimizer) if scaler else None torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) - optimizer.step() + # step the optimizer + if scaler: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + # flush the gradients as soon as we can, no need for this memory anymore optimizer.zero_grad(set_to_none=True) # timing and logging