Mirror of https://github.com/osmarks/nanogpt-experiments.git (synced 2024-11-10 20:09:58 +00:00)
add estimation of model flops utilization (MFU), a very commonly tracked metric that expresses the achieved flops throughput as a fraction of A100 bfloat16 peak flops (312 TFLOPS). this gives us a sense of the hardware utilization we're achieving.
parent 580902617c
commit ab0718a7dd
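For a rough sense of what the new estimate_mfu() computes, here is a minimal standalone sketch of the same arithmetic. The model shapes, batch size and iteration time below are made-up examples (roughly GPT-2 124M scale), not measurements from this commit:

# standalone sketch of the MFU arithmetic; all numbers below are assumptions
N = 124e6                      # parameter count (assumed, ~GPT-2 small)
L, H, Q, T = 12, 12, 64, 1024  # n_layer, n_head, head size, block_size (assumed)
fwdbwd_per_iter = 12           # forward/backward passes per iteration (assumed)
dt = 0.5                       # measured seconds per iteration (assumed)

flops_per_token = 6*N + 12*L*H*Q*T      # PaLM Appendix B style estimate
flops_per_fwdbwd = flops_per_token * T  # one fwd/bwd pass processes T tokens
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
flops_achieved = flops_per_iter / dt    # flops per second
flops_promised = 312e12                 # A100 bfloat16 peak, 312 TFLOPS
print(f"MFU: {100 * flops_achieved / flops_promised:.2f}%")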
bench.py (4 changed lines)

@@ -111,5 +111,7 @@ else:
             print(f"{k}/{num_steps} loss: {lossf:.4f}")
         torch.cuda.synchronize()
         t1 = time.time()
+        dt = t1-t0
+        mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt)
         if stage == 1:
-            print(f"time per iteration: {(t1-t0)/num_steps*1000:.4f}ms")
+            print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%")
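A note on the `batch_size * 1 * num_steps` argument above: bench.py times a whole window of num_steps iterations at once, so it scales the fwd/bwd count and the elapsed time together, which leaves the resulting MFU unchanged. A toy check, with made-up numbers (the per-sequence flops, batch size and timing are assumptions, not values from this commit):

flops_per_seq = 1e12                 # assumed flops for one fwd/bwd pass over one sequence
batch_size, num_steps = 12, 20       # assumed
dt_per_iter = 0.5                    # assumed seconds for a single iteration
dt_window = dt_per_iter * num_steps  # seconds for the whole benchmark window

peak = 312e12
mfu_per_iter = flops_per_seq * batch_size / dt_per_iter / peak
mfu_window = flops_per_seq * batch_size * num_steps / dt_window / peak
print(mfu_per_iter, mfu_window)      # identical (up to float rounding)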
model.py (16 changed lines)

@@ -328,6 +328,22 @@ class GPT(nn.Module):
         return optimizer

+    def estimate_mfu(self, fwdbwd_per_iter, dt):
+        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
+        # first estimate the number of flops we do per iteration.
+        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
+        N = self.get_num_params()
+        cfg = self.config
+        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
+        flops_per_token = 6*N + 12*L*H*Q*T
+        flops_per_fwdbwd = flops_per_token * T
+        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
+        # express our flops throughput as ratio of A100 bfloat16 peak flops
+        flops_achieved = flops_per_iter * (1.0/dt) # per second
+        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
+        mfu = flops_achieved / flops_promised
+        return mfu
+
     @torch.no_grad()
     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
         """
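Of the two terms in flops_per_token above, 6*N accounts for the matmuls against the model parameters (forward plus backward) and 12*L*H*Q*T for attention over the T-token context. A quick look at their relative size, with assumed GPT-2 (124M)-like shapes (illustrative numbers, not from this commit):

N = 124e6                      # assumed parameter count
L, H, Q, T = 12, 12, 64, 1024  # assumed n_layer, n_head, head size, block_size
dense = 6*N                    # parameter-matmul flops per token
attn = 12*L*H*Q*T              # attention flops per token
print(f"dense: {dense:.3e}  attention: {attn:.3e}  attention share: {attn/(dense+attn):.1%}")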
train.py (11 changed lines)

@@ -84,12 +84,14 @@ if ddp:
     init_process_group(backend=backend)
     ddp_rank = int(os.environ['RANK'])
     ddp_local_rank = int(os.environ['LOCAL_RANK'])
+    world_size = int(os.environ['WORLD_SIZE']) # total number of training processes
     device = f'cuda:{ddp_local_rank}'
     torch.cuda.set_device(device)
     master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
     seed_offset = ddp_rank # each process gets a different seed
 else:
     # if not ddp, we are running on a single gpu, and one process
+    world_size = 1
     master_process = True
     seed_offset = 0

@@ -237,6 +239,8 @@ if wandb_log and master_process:
 # training loop
 X, Y = get_batch('train') # fetch the very first batch
 t0 = time.time()
+local_iter_num = 0 # number of iterations in the lifetime of this process
+running_mfu = -1.0
 while True:

     # determine and set the learning rate for this iteration

@@ -254,6 +258,7 @@ while True:
                 "train/loss": losses['train'],
                 "val/loss": losses['val'],
                 "lr": lr,
+                "mfu": running_mfu*100, # convert to percentage
             })
         if losses['val'] < best_val_loss or always_save_checkpoint:
             best_val_loss = losses['val']

@@ -303,8 +308,12 @@ while True:
     t0 = t1
     if iter_num % log_interval == 0 and master_process:
         lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
-        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
+        if local_iter_num >= 5: # let the training loop settle a bit
+            mfu = model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt)
+            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
+        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
     iter_num += 1
+    local_iter_num += 1

     # termination conditions
     if iter_num > max_iters:
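The running_mfu update in the last hunk is an exponential moving average (and the first few iterations are skipped so the loop can settle), so the logged value smooths over per-iteration noise. A toy illustration with made-up MFU samples:

running_mfu = -1.0                            # sentinel: no estimate yet
for mfu in [0.30, 0.34, 0.28, 0.33, 0.31]:    # assumed per-iteration MFU readings
    running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
    print(f"{running_mfu:.4f}")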