1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-18 14:10:28 +00:00

who needs a dataloader? overlap the prefetching of the next batch with GPU compute, ehiding the data loading latency entirely. this saves about 1ms lol

This commit is contained in:
Andrej Karpathy 2023-02-04 02:52:48 +00:00
parent 46428d3142
commit 3fd4c0c5ef
2 changed files with 11 additions and 7 deletions

View File

@ -39,7 +39,7 @@ if real_data:
ix = torch.randint(len(data) - block_size, (batch_size,)) ix = torch.randint(len(data) - block_size, (batch_size,))
x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
x, y = x.to(device), y.to(device) x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
return x, y return x, y
else: else:
# alternatively, if fixed data is desired to not care about data loading # alternatively, if fixed data is desired to not care about data loading
@ -76,14 +76,15 @@ if profile:
record_shapes=False, record_shapes=False,
profile_memory=False, profile_memory=False,
with_stack=False, # incurs an additional overhead, disable if not needed with_stack=False, # incurs an additional overhead, disable if not needed
with_flops=False, with_flops=True,
with_modules=False, # only for torchscript models atm with_modules=False, # only for torchscript models atm
) as prof: ) as prof:
for k in range(num_steps):
X, Y = get_batch('train') X, Y = get_batch('train')
for k in range(num_steps):
with ctx: with ctx:
logits, loss = model(X, Y) logits, loss = model(X, Y)
X, Y = get_batch('train')
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@ -98,10 +99,11 @@ else:
torch.cuda.synchronize() torch.cuda.synchronize()
for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark
t0 = time.time() t0 = time.time()
for k in range(num_steps):
X, Y = get_batch('train') X, Y = get_batch('train')
for k in range(num_steps):
with ctx: with ctx:
logits, loss = model(X, Y) logits, loss = model(X, Y)
X, Y = get_batch('train')
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
loss.backward() loss.backward()
optimizer.step() optimizer.step()

View File

@ -111,7 +111,8 @@ def get_batch(split):
ix = torch.randint(len(data) - block_size, (batch_size,)) ix = torch.randint(len(data) - block_size, (batch_size,))
x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
x, y = x.to(device), y.to(device) # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
return x, y return x, y
# init these up here, can override if init_from='resume' (i.e. from a checkpoint) # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
@ -227,6 +228,7 @@ if wandb_log and master_process:
wandb.init(project=wandb_project, name=wandb_run_name, config=config) wandb.init(project=wandb_project, name=wandb_run_name, config=config)
# training loop # training loop
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time() t0 = time.time()
while True: while True:
@ -269,8 +271,6 @@ while True:
# forward backward update, with optional gradient accumulation to simulate larger batch size # forward backward update, with optional gradient accumulation to simulate larger batch size
# and using the GradScaler if data type is float16 # and using the GradScaler if data type is float16
for micro_step in range(gradient_accumulation_steps): for micro_step in range(gradient_accumulation_steps):
# fetch a batch
X, Y = get_batch('train')
if ddp: if ddp:
# in DDP training we only need to sync gradients at the last micro step. # in DDP training we only need to sync gradients at the last micro step.
# the official way to do this is with model.no_sync() context manager, but # the official way to do this is with model.no_sync() context manager, but
@ -279,6 +279,8 @@ while True:
model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
with ctx: with ctx:
logits, loss = model(X, Y) logits, loss = model(X, Y)
# immediately async prefetch next batch while model is doing the forward pass on the GPU
X, Y = get_batch('train')
# backward pass, with gradient scaling if training in fp16 # backward pass, with gradient scaling if training in fp16
scaler.scale(loss).backward() scaler.scale(loss).backward()
# clip the gradient # clip the gradient