mirror of https://github.com/osmarks/nanogpt-experiments.git
fix a minor bug where we have to scale the loss to account for gradient accumulation, which sums gradients across micro-steps before the optimizer step. note that this is not a major bug because AdamW is scale invariant; however, it did affect gradient clipping
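In gradient accumulation, each micro-step calls backward() and its gradient is summed into the parameters' .grad buffers, so N unscaled micro-steps yield a gradient roughly N times larger than one equivalent large batch would. Dividing each micro-step loss by the number of accumulation steps turns that sum into an average. Below is a minimal stand-alone sketch of the fixed pattern, not the actual train.py; the toy model, optimizer, get_batch helper and clip threshold are illustrative assumptions:

import torch

# illustrative stand-ins, not nanoGPT's model or data pipeline
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

def get_batch():
    # random data standing in for real (X, Y) token batches
    return torch.randn(4, 16), torch.randn(4, 1)

gradient_accumulation_steps = 8
grad_clip = 1.0

optimizer.zero_grad(set_to_none=True)
for micro_step in range(gradient_accumulation_steps):
    X, Y = get_batch()
    loss = torch.nn.functional.mse_loss(model(X), Y)
    # scale the loss so the summed gradients average over micro-steps,
    # matching the gradient of one batch gradient_accumulation_steps times larger
    loss = loss / gradient_accumulation_steps
    loss.backward()  # gradients accumulate (sum) into .grad

# clipping now sees gradients on the same scale regardless of how many
# micro-steps were accumulated, which is the point of this fix
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()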
1 changed file: train.py (6 lines changed)
@@ -93,6 +93,7 @@ else:
     master_process = True
     seed_offset = 0
     gradient_accumulation_steps *= 8 # simulate 8 gpus
+print("total number of tokens per iteration:", batch_size * block_size * gradient_accumulation_steps)

 if master_process:
     os.makedirs(out_dir, exist_ok=True)
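The print added above reports how many tokens one optimizer step consumes. As an illustrative worked example (not necessarily the values used in this repo): with batch_size=12, block_size=1024 and gradient_accumulation_steps=40, it would report 12 * 1024 * 40 = 491,520 tokens per iteration.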
@@ -287,6 +288,7 @@ while True:
             model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
         with ctx:
             logits, loss = model(X, Y)
+            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
         # immediately async prefetch next batch while model is doing the forward pass on the GPU
         X, Y = get_batch('train')
         # backward pass, with gradient scaling if training in fp16
@@ -306,7 +308,9 @@ while True:
     dt = t1 - t0
     t0 = t1
     if iter_num % log_interval == 0 and master_process:
-        lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
+        # get loss as float. note: this is a CPU-GPU sync point
+        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
+        lossf = loss.item() * gradient_accumulation_steps
         if local_iter_num >= 5: # let the training loop settle a bit
             mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
             running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
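To see why the unscaled accumulated loss affected gradient clipping even though AdamW is largely scale invariant, the toy comparison below may help; the model, data and helper function are illustrative assumptions rather than anything from train.py. Because the scaled loss is just the unscaled loss divided by the number of micro-steps, the accumulated gradient norm without scaling is exactly that many times larger, so a fixed clip_grad_norm_ threshold clips far more aggressively than it would on an equivalent single large batch.

import torch

def accumulated_grad_norm(scale_loss, steps=8, seed=0):
    # run `steps` micro-steps of gradient accumulation and return the grad norm
    torch.manual_seed(seed)
    model = torch.nn.Linear(16, 1)
    data = [(torch.randn(4, 16), torch.randn(4, 1)) for _ in range(steps)]
    for X, Y in data:
        loss = torch.nn.functional.mse_loss(model(X), Y)
        if scale_loss:
            loss = loss / steps  # the fix from this commit
        loss.backward()
    grads = [p.grad.detach().flatten() for p in model.parameters()]
    return torch.cat(grads).norm().item()

# the unscaled norm is exactly `steps` times the scaled norm, so a fixed
# clip threshold such as 1.0 bites much earlier without the scaling
print("unscaled grad norm:", accumulated_grad_norm(scale_loss=False))
print("scaled grad norm:  ", accumulated_grad_norm(scale_loss=True))

AdamW divides each update by a running estimate of the gradient's magnitude, so a constant scale factor mostly cancels out of the update direction; an absolute threshold like grad_clip has no such property, which is why clipping is where the bug actually mattered.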
Andrej Karpathy