diff --git a/bench.py b/bench.py
index 589c7ae..294c824 100644
--- a/bench.py
+++ b/bench.py
@@ -39,7 +39,7 @@ if real_data:
         ix = torch.randint(len(data) - block_size, (batch_size,))
         x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
         y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
-        x, y = x.to(device), y.to(device)
+        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
         return x, y
 else:
     # alternatively, if fixed data is desired to not care about data loading
@@ -76,14 +76,15 @@ if profile:
         record_shapes=False,
         profile_memory=False,
         with_stack=False, # incurs an additional overhead, disable if not needed
-        with_flops=False,
+        with_flops=True,
         with_modules=False, # only for torchscript models atm
     ) as prof:
 
+        X, Y = get_batch('train')
         for k in range(num_steps):
-            X, Y = get_batch('train')
             with ctx:
                 logits, loss = model(X, Y)
+            X, Y = get_batch('train')
             optimizer.zero_grad(set_to_none=True)
             loss.backward()
             optimizer.step()
@@ -98,10 +99,11 @@ else:
     torch.cuda.synchronize()
     for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark
         t0 = time.time()
+        X, Y = get_batch('train')
         for k in range(num_steps):
-            X, Y = get_batch('train')
             with ctx:
                 logits, loss = model(X, Y)
+            X, Y = get_batch('train')
             optimizer.zero_grad(set_to_none=True)
             loss.backward()
             optimizer.step()
diff --git a/train.py b/train.py
index 7881be0..a66fa91 100644
--- a/train.py
+++ b/train.py
@@ -111,7 +111,8 @@ def get_batch(split):
     ix = torch.randint(len(data) - block_size, (batch_size,))
     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
-    x, y = x.to(device), y.to(device)
+    # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
+    x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
     return x, y
 
 # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
@@ -227,6 +228,7 @@ if wandb_log and master_process:
     wandb.init(project=wandb_project, name=wandb_run_name, config=config)
 
 # training loop
+X, Y = get_batch('train') # fetch the very first batch
 t0 = time.time()
 while True:
 
@@ -269,8 +271,6 @@ while True:
     # forward backward update, with optional gradient accumulation to simulate larger batch size
     # and using the GradScaler if data type is float16
     for micro_step in range(gradient_accumulation_steps):
-        # fetch a batch
-        X, Y = get_batch('train')
         if ddp:
             # in DDP training we only need to sync gradients at the last micro step.
             # the official way to do this is with model.no_sync() context manager, but
@@ -279,6 +279,8 @@ while True:
             model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
         with ctx:
             logits, loss = model(X, Y)
+        # immediately async prefetch next batch while model is doing the forward pass on the GPU
+        X, Y = get_batch('train')
         # backward pass, with gradient scaling if training in fp16
         scaler.scale(loss).backward()
     # clip the gradient
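
For reference, a minimal standalone sketch of the overlap this change is after: pinning the host tensors lets the host-to-device copy run asynchronously (non_blocking=True), and fetching the next batch right after the forward call lets that copy proceed while the GPU is still working through the queued kernels. The tiny model, random token data, and simplified get_batch below are illustrative stand-ins only, not the repo's actual model or data pipeline.

# sketch of the pinned-memory + prefetch pattern (assumed, simplified setup)
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size, block_size, vocab_size = 4, 64, 50257

# stand-in model: embedding + linear head, just enough to produce logits
model = torch.nn.Sequential(
    torch.nn.Embedding(vocab_size, 128),
    torch.nn.Linear(128, vocab_size),
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

def get_batch():
    # random tokens as stand-in data; pin_memory() allows the .to(device)
    # copy to run asynchronously (non_blocking=True) on CUDA
    x = torch.randint(vocab_size, (batch_size, block_size))
    y = torch.randint(vocab_size, (batch_size, block_size))
    if device == 'cuda':
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

X, Y = get_batch()  # fetch the very first batch before entering the loop
for step in range(10):
    logits = model(X)
    loss = torch.nn.functional.cross_entropy(logits.view(-1, vocab_size), Y.view(-1))
    # queue the next host->device copy right after the forward pass is launched;
    # with pinned memory and non_blocking=True it can overlap the GPU compute
    X, Y = get_batch()
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()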