
add the estimation of model flops utilization (MFU), a commonly tracked metric that expresses token throughput as a fraction of A100 bfloat16 peak flops (312 TFLOPS). this gives us a sense of the hardware utilization we're achieving
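In other words, MFU is the ratio of the flops the training loop actually sustains to the accelerator's advertised peak. A minimal sketch of that ratio, with made-up throughput numbers (only the 312 TFLOPS peak comes from the commit message):

# MFU = flops achieved per second / peak flops per second
flops_per_iter = 2.9e15   # hypothetical flops for one fwd+bwd iteration
dt = 12.0                 # hypothetical wall-clock seconds per iteration
flops_achieved = flops_per_iter / dt
flops_promised = 312e12   # A100 bfloat16 peak, per the commit message
mfu = flops_achieved / flops_promised
print(f"MFU: {mfu*100:.2f}%")  # ~77% with these made-up numbers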

Author: Andrej Karpathy
Date: 2023-02-05 00:48:58 +00:00
Parent: 580902617c
Commit: ab0718a7dd
3 changed files with 29 additions and 2 deletions


@@ -84,12 +84,14 @@ if ddp:
     init_process_group(backend=backend)
     ddp_rank = int(os.environ['RANK'])
     ddp_local_rank = int(os.environ['LOCAL_RANK'])
+    world_size = int(os.environ['WORLD_SIZE']) # total number of training processes
     device = f'cuda:{ddp_local_rank}'
     torch.cuda.set_device(device)
     master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
     seed_offset = ddp_rank # each process gets a different seed
 else:
     # if not ddp, we are running on a single gpu, and one process
+    world_size = 1
     master_process = True
     seed_offset = 0
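For context on the hunk above: RANK, LOCAL_RANK and WORLD_SIZE are the environment variables that torchrun exports into every process it spawns; a bare python train.py sets none of them, which is why the else branch pins world_size to 1. A minimal sketch of that contract, mirroring the same env-var detection style nanoGPT uses:

import os

# under torchrun --nproc_per_node=8 train.py, each process sees RANK,
# LOCAL_RANK and WORLD_SIZE in its environment; plain python train.py
# sets none of them, so RANK is absent and we run single-process
ddp = int(os.environ.get('RANK', -1)) != -1
world_size = int(os.environ['WORLD_SIZE']) if ddp else 1
print(f"ddp={ddp}, world_size={world_size}")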
@@ -237,6 +239,8 @@ if wandb_log and master_process:
 # training loop
 X, Y = get_batch('train') # fetch the very first batch
 t0 = time.time()
+local_iter_num = 0 # number of iterations in the lifetime of this process
+running_mfu = -1.0
 while True:
 
     # determine and set the learning rate for this iteration
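The -1.0 that initializes running_mfu is a sentinel meaning "no estimate yet": the first real measurement replaces it outright, and every later sample is blended in with the 0.9/0.1 exponential moving average used further down in this diff. A self-contained sketch of that pattern, with invented sample values:

def update_running_mfu(running, new):
    # first measurement replaces the sentinel; afterwards smooth with an
    # exponential moving average (coefficients taken from this diff)
    return new if running == -1.0 else 0.9*running + 0.1*new

running_mfu = -1.0
for mfu in [0.40, 0.44, 0.42]:  # hypothetical per-iteration MFU samples
    running_mfu = update_running_mfu(running_mfu, mfu)
print(f"{running_mfu:.4f}")  # prints 0.4056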
@@ -254,6 +258,7 @@ while True:
"train/loss": losses['train'],
"val/loss": losses['val'],
"lr": lr,
"mfu": running_mfu*100, # convert to percentage
})
if losses['val'] < best_val_loss or always_save_checkpoint:
best_val_loss = losses['val']
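Since the logging hunk reports running_mfu as a percentage, a quick conversion shows what such a reading means in absolute terms (the 30% figure here is hypothetical):

# illustrative only: convert a logged MFU percentage back into TFLOPS
mfu_pct = 30.0                         # hypothetical value from the wandb chart
achieved_tflops = mfu_pct / 100 * 312  # A100 bf16 peak is 312 TFLOPS
print(f"{achieved_tflops:.1f} TFLOPS sustained")  # -> 93.6 TFLOPS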
@@ -303,8 +308,12 @@ while True:
     dt = t1 - t0
     t0 = t1
     if iter_num % log_interval == 0 and master_process:
         lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
-        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
+        if local_iter_num >= 5: # let the training loop settle a bit
+            mfu = model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt)
+            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
+        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
     iter_num += 1
+    local_iter_num += 1
     # termination conditions
     if iter_num > max_iters:
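The model.estimate_mfu(fwdbwd_per_iter, dt) method called above is defined in model.py, whose diff is not shown on this page. A plausible sketch of what it computes, following the standard PaLM-style flops accounting (roughly 6*N flops per token for the parameters plus an attention term; the exact formula and config attribute names here are assumptions, not the verbatim method):

def estimate_mfu(self, fwdbwd_per_iter, dt):
    """estimate model flops utilization (MFU) relative to A100 bf16 peak"""
    # flops per token: 6*N for the weights (fwd+bwd) plus the attention
    # term 12*L*H*Q*T (see the PaLM paper, Appendix B)
    N = self.get_num_params()
    cfg = self.config
    L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size
    flops_per_token = 6 * N + 12 * L * H * Q * T
    flops_per_fwdbwd = flops_per_token * T               # one full sequence
    flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter  # whole batch, all grad-accum steps
    # express achieved throughput as a fraction of the advertised peak
    flops_achieved = flops_per_iter / dt                 # flops per second
    flops_promised = 312e12                              # A100 bfloat16 peak: 312 TFLOPS
    return flops_achieved / flops_promised

With batch_size * world_size * gradient_accumulation_steps passed as fwdbwd_per_iter, the call counts every sequence processed across all GPUs in one optimizer step, so the resulting MFU is an aggregate figure for the whole job.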