mirror of
				https://github.com/osmarks/nanogpt-experiments.git
				synced 2025-10-26 12:57:41 +00:00 
			
		
		
		
	bugfix we have to call the raw_model's estimate_mfu ty @jprobichaud for original PR
This commit is contained in:
		
							
								
								
									
										4
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								train.py
									
									
									
									
									
								
							| @@ -240,6 +240,7 @@ if wandb_log and master_process: | |||||||
| X, Y = get_batch('train') # fetch the very first batch | X, Y = get_batch('train') # fetch the very first batch | ||||||
| t0 = time.time() | t0 = time.time() | ||||||
| local_iter_num = 0 # number of iterations in the lifetime of this process | local_iter_num = 0 # number of iterations in the lifetime of this process | ||||||
|  | raw_model = model.module if ddp else model # unwrap DDP container if needed | ||||||
| running_mfu = -1.0 | running_mfu = -1.0 | ||||||
| while True: | while True: | ||||||
|  |  | ||||||
| @@ -262,7 +263,6 @@ while True: | |||||||
|             }) |             }) | ||||||
|         if losses['val'] < best_val_loss or always_save_checkpoint: |         if losses['val'] < best_val_loss or always_save_checkpoint: | ||||||
|             best_val_loss = losses['val'] |             best_val_loss = losses['val'] | ||||||
|             raw_model = model.module if ddp else model |  | ||||||
|             if iter_num > 0: |             if iter_num > 0: | ||||||
|                 checkpoint = { |                 checkpoint = { | ||||||
|                     'model': raw_model.state_dict(), |                     'model': raw_model.state_dict(), | ||||||
| @@ -309,7 +309,7 @@ while True: | |||||||
|     if iter_num % log_interval == 0 and master_process: |     if iter_num % log_interval == 0 and master_process: | ||||||
|         lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point |         lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point | ||||||
|         if local_iter_num >= 5: # let the training loop settle a bit |         if local_iter_num >= 5: # let the training loop settle a bit | ||||||
|             mfu = model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt) |             mfu = raw_model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt) | ||||||
|             running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu |             running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu | ||||||
|         print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") |         print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") | ||||||
|     iter_num += 1 |     iter_num += 1 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Andrej Karpathy
					Andrej Karpathy