diff --git a/bench.py b/bench.py
index 66a083b..25da9f2 100644
--- a/bench.py
+++ b/bench.py
@@ -111,5 +111,7 @@ else:
         print(f"{k}/{num_steps} loss: {lossf:.4f}")
     torch.cuda.synchronize()
     t1 = time.time()
+    dt = t1-t0
+    mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt)
     if stage == 1:
-        print(f"time per iteration: {(t1-t0)/num_steps*1000:.4f}ms")
+        print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%")
diff --git a/model.py b/model.py
index 86f47e4..5a7f4dd 100644
--- a/model.py
+++ b/model.py
@@ -328,6 +328,22 @@ class GPT(nn.Module):
 
         return optimizer
 
+    def estimate_mfu(self, fwdbwd_per_iter, dt):
+        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
+        # first estimate the number of flops we do per iteration.
+        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
+        N = self.get_num_params()
+        cfg = self.config
+        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
+        flops_per_token = 6*N + 12*L*H*Q*T
+        flops_per_fwdbwd = flops_per_token * T
+        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
+        # express our flops throughput as ratio of A100 bfloat16 peak flops
+        flops_achieved = flops_per_iter * (1.0/dt) # per second
+        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
+        mfu = flops_achieved / flops_promised
+        return mfu
+
     @torch.no_grad()
     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
         """
diff --git a/train.py b/train.py
index 1e5cb77..0867991 100644
--- a/train.py
+++ b/train.py
@@ -84,12 +84,14 @@ if ddp:
     init_process_group(backend=backend)
     ddp_rank = int(os.environ['RANK'])
     ddp_local_rank = int(os.environ['LOCAL_RANK'])
+    world_size = int(os.environ['WORLD_SIZE']) # total number of training processes
     device = f'cuda:{ddp_local_rank}'
     torch.cuda.set_device(device)
     master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
     seed_offset = ddp_rank # each process gets a different seed
 else:
     # if not ddp, we are running on a single gpu, and one process
+    world_size = 1
     master_process = True
     seed_offset = 0
 
@@ -237,6 +239,8 @@ if wandb_log and master_process:
 # training loop
 X, Y = get_batch('train') # fetch the very first batch
 t0 = time.time()
+local_iter_num = 0 # number of iterations in the lifetime of this process
+running_mfu = -1.0
 while True:
 
     # determine and set the learning rate for this iteration
@@ -254,6 +258,7 @@ while True:
                 "train/loss": losses['train'],
                 "val/loss": losses['val'],
                 "lr": lr,
+                "mfu": running_mfu*100, # convert to percentage
             })
         if losses['val'] < best_val_loss or always_save_checkpoint:
             best_val_loss = losses['val']
@@ -303,8 +308,12 @@ while True:
     t0 = t1
     if iter_num % log_interval == 0 and master_process:
         lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
-        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
+        if local_iter_num >= 5: # let the training loop settle a bit
+            mfu = model.estimate_mfu(batch_size * world_size * gradient_accumulation_steps, dt)
+            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
+        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
     iter_num += 1
+    local_iter_num += 1
 
     # termination conditions
     if iter_num > max_iters: