Mirror of https://github.com/osmarks/nanogpt-experiments.git, synced 2025-10-31 15:23:01 +00:00
	Merge pull request #106 from YassineYousfi/master
use the ``enabled`` arg in GradScaler
This commit is contained in:

	train.py — 20 changed lines
@@ -173,11 +173,8 @@ if block_size < model.config.block_size:
     model.crop_block_size(block_size)
 model.to(device)
 
-# initialize a GradScaler if data type is float16
-scaler = None
-if dtype == 'float16':
-    print(f"Initializing Gradient Scaler to account for dtype: {dtype}")
-    scaler = torch.cuda.amp.GradScaler()
+# initialize a GradScaler. If enabled=False scaler is a no-op
+scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
 
 # optimizer
 optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2))
@@ -283,17 +280,14 @@ while True:
         with ctx:
             logits, loss = model(X, Y)
         # backward pass, with gradient scaling if training in fp16
-        scaler.scale(loss).backward() if scaler else loss.backward()
+        scaler.scale(loss).backward()
     # clip the gradient
     if grad_clip != 0.0:
-        scaler.unscale_(optimizer) if scaler else None
+        scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
-    # step the optimizer
-    if scaler:
-        scaler.step(optimizer)
-        scaler.update()
-    else:
-        optimizer.step()
+    # step the optimizer and scaler if training in fp16
+    scaler.step(optimizer)
+    scaler.update()
     # flush the gradients as soon as we can, no need for this memory anymore
     optimizer.zero_grad(set_to_none=True)
 
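The change works because GradScaler constructed with enabled=False turns scale(), unscale_(), step(), and update() into no-ops (step() simply calls optimizer.step()), so one code path serves fp16 and non-fp16 training alike. Below is a minimal, self-contained sketch of that pattern; it is not code from this repo — the toy model, data, and config values are made up for illustration, and it assumes a CUDA device is available.

import torch

device = 'cuda'
dtype = 'float16'  # hypothetical config value; only float16 actually needs a scaler
grad_clip = 1.0

model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# a real GradScaler only when training in fp16; otherwise all its calls are no-ops
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16,
                         enabled=(dtype == 'float16'))

X = torch.randn(8, 10, device=device)
Y = torch.randn(8, 1, device=device)

with ctx:
    loss = torch.nn.functional.mse_loss(model(X), Y)
scaler.scale(loss).backward()      # same as loss.backward() when the scaler is disabled
if grad_clip != 0.0:
    scaler.unscale_(optimizer)     # no-op when disabled
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
scaler.step(optimizer)             # falls back to optimizer.step() when disabled
scaler.update()
optimizer.zero_grad(set_to_none=True)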