1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-18 14:10:28 +00:00

bugfix we have to call the raw_model's estimate_mfu ty @jprobichaud for original PR

This commit is contained in:
Andrej Karpathy 2023-02-06 19:55:35 +00:00
parent f83dd034e1
commit ab21d6c15d

View File

@ -240,6 +240,7 @@ if wandb_log and master_process:
X, Y = get_batch('train') # fetch the very first batch X, Y = get_batch('train') # fetch the very first batch
t0 = time.time() t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0 running_mfu = -1.0
while True: while True:
@ -262,7 +263,6 @@ while True:
}) })
if losses['val'] < best_val_loss or always_save_checkpoint: if losses['val'] < best_val_loss or always_save_checkpoint:
best_val_loss = losses['val'] best_val_loss = losses['val']
raw_model = model.module if ddp else model
if iter_num > 0: if iter_num > 0:
checkpoint = { checkpoint = {
'model': raw_model.state_dict(), 'model': raw_model.state_dict(),
@ -309,7 +309,7 @@ while True:
if iter_num % log_interval == 0 and master_process: if iter_num % log_interval == 0 and master_process:
lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
if local_iter_num >= 5: # let the training loop settle a bit if local_iter_num >= 5: # let the training loop settle a bit
mfu = model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt) mfu = raw_model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt)
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
iter_num += 1 iter_num += 1