Mirror of https://github.com/osmarks/nanogpt-experiments.git (synced 2025-10-31 15:23:01 +00:00)
	merge, make cleaner, careful with gradient clipping when using grad scaler fp16 training
Author: Andrej Karpathy

train.py | 30 changed lines
@@ -70,7 +70,7 @@ min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchi
 backend = 'nccl' # 'nccl', 'gloo', etc.
 # system
 device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
-dtype = 'bfloat16' # 'float32' or 'bfloat16'
+dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
 compile = True # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
@@ -98,8 +98,8 @@ torch.manual_seed(1337 + seed_offset)
 torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
 torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
 device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
-# note: float16 would require us to change the code to use a GradScaler
-ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16}[dtype]
+# note: float16 data type will automatically use a GradScaler
+ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
 ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
 
 # poor man's data loader, TODO evaluate need for actual DataLoader
@@ -173,6 +173,12 @@ if block_size < model.config.block_size:
     model.crop_block_size(block_size)
 model.to(device)
 
+# initialize a GradScaler if data type is float16
+scaler = None
+if dtype == 'float16':
+    print(f"Initializing Gradient Scaler to account for dtype: {dtype}")
+    scaler = torch.cuda.amp.GradScaler()
+
 # optimizer
 optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2))
 if init_from == 'resume':
@@ -188,6 +194,7 @@ if compile:
 if ddp:
     model = DDP(model, device_ids=[ddp_local_rank])
 
+# helps estimate an arbitrarily accurate loss over either split using many batches
 @torch.no_grad()
 def estimate_loss():
     out = {}
@@ -263,7 +270,9 @@ while True:
         break
 
     # forward backward update, with optional gradient accumulation to simulate larger batch size
+    # and using the GradScaler if data type is float16
     for micro_step in range(gradient_accumulation_steps):
+        # fetch a batch
        X, Y = get_batch('train')
         if ddp:
             # in DDP training we only need to sync gradients at the last micro step.
@@ -273,10 +282,19 @@ while True:
             model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
         with ctx:
             logits, loss = model(X, Y)
-        loss.backward()
-    if grad_clip != 0:
+        # backward pass, with gradient scaling if training in fp16
+        scaler.scale(loss).backward() if scaler else loss.backward()
+    # clip the gradient
+    if grad_clip != 0.0:
+        scaler.unscale_(optimizer) if scaler else None
         torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
-    optimizer.step()
+    # step the optimizer
+    if scaler:
+        scaler.step(optimizer)
+        scaler.update()
+    else:
+        optimizer.step()
+    # flush the gradients as soon as we can, no need for this memory anymore
     optimizer.zero_grad(set_to_none=True)
 
     # timing and logging
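Below is a minimal, self-contained sketch of the fp16 pattern this commit wires into train.py: scale the loss for the backward pass, unscale the gradients before clipping (the "careful with gradient clipping" part of the commit message), then step the optimizer through the scaler. The toy model, data, and hyperparameters are placeholders for illustration only, not taken from nanoGPT.

import torch

device = 'cuda'
model = torch.nn.Linear(64, 64).to(device)       # stand-in for the GPT model
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
scaler = torch.cuda.amp.GradScaler()             # only needed for float16; bfloat16/float32 can skip it
grad_clip = 1.0

for step in range(10):
    x = torch.randn(8, 64, device=device)
    y = torch.randn(8, 64, device=device)
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    # backward on the scaled loss so small fp16 gradients do not underflow to zero
    scaler.scale(loss).backward()
    if grad_clip != 0.0:
        # gradients must be unscaled before clipping, otherwise clip_grad_norm_
        # compares the clip threshold against scaled (inflated) gradient norms
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)   # skips the update if inf/NaN gradients were detected
    scaler.update()          # adjusts the loss scale for the next iteration
    optimizer.zero_grad(set_to_none=True)

Note that scaler.step() detects that unscale_() has already been called and will not unscale twice, and it silently skips the optimizer step when the scaled gradients contain inf/NaN. With this change, fp16 training in nanoGPT should be selectable through the existing config override mechanism (e.g. a --dtype=float16 flag via configurator.py), while the default stays bfloat16.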