mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
bugfix we have to call the raw_model's estimate_mfu ty @jprobichaud for original PR
This commit is contained in:
parent
f83dd034e1
commit
ab21d6c15d
4
train.py
4
train.py
@ -240,6 +240,7 @@ if wandb_log and master_process:
|
|||||||
X, Y = get_batch('train') # fetch the very first batch
|
X, Y = get_batch('train') # fetch the very first batch
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
local_iter_num = 0 # number of iterations in the lifetime of this process
|
local_iter_num = 0 # number of iterations in the lifetime of this process
|
||||||
|
raw_model = model.module if ddp else model # unwrap DDP container if needed
|
||||||
running_mfu = -1.0
|
running_mfu = -1.0
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
@ -262,7 +263,6 @@ while True:
|
|||||||
})
|
})
|
||||||
if losses['val'] < best_val_loss or always_save_checkpoint:
|
if losses['val'] < best_val_loss or always_save_checkpoint:
|
||||||
best_val_loss = losses['val']
|
best_val_loss = losses['val']
|
||||||
raw_model = model.module if ddp else model
|
|
||||||
if iter_num > 0:
|
if iter_num > 0:
|
||||||
checkpoint = {
|
checkpoint = {
|
||||||
'model': raw_model.state_dict(),
|
'model': raw_model.state_dict(),
|
||||||
@ -309,7 +309,7 @@ while True:
|
|||||||
if iter_num % log_interval == 0 and master_process:
|
if iter_num % log_interval == 0 and master_process:
|
||||||
lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
|
lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
|
||||||
if local_iter_num >= 5: # let the training loop settle a bit
|
if local_iter_num >= 5: # let the training loop settle a bit
|
||||||
mfu = model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt)
|
mfu = raw_model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt)
|
||||||
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
|
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
|
||||||
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
|
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
|
||||||
iter_num += 1
|
iter_num += 1
|
||||||
|
Loading…
Reference in New Issue
Block a user