mirror of https://github.com/osmarks/nanogpt-experiments.git, synced 2024-12-18 14:10:28 +00:00
who needs a dataloader? overlap the prefetching of the next batch with GPU compute, hiding the data loading latency entirely. this saves about 1ms lol
This commit is contained in:
parent 46428d3142
commit 3fd4c0c5ef
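Editor's note: the pattern the commit implements, reduced to a minimal sketch. The get_batch body mirrors the repo's, but the toy data array, the Embedding+Linear model, the AdamW settings, and the batch/block sizes below are purely illustrative stand-ins, not the repo's code. The first batch is fetched before the loop, and the next batch is fetched right after the forward pass has been launched, so CPU-side indexing and the pinned, non-blocking host-to-device copy overlap with GPU compute.

import numpy as np
import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
block_size, batch_size = 64, 8
data = np.random.randint(0, 50257, size=10_000, dtype=np.uint16)  # stand-in for train.bin
model = torch.nn.Sequential(torch.nn.Embedding(50257, 128), torch.nn.Linear(128, 50257)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

def get_batch():
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device == 'cuda':
        # pinned host memory lets the H2D copy run asynchronously w.r.t. the CPU
        return x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    return x.to(device), y.to(device)

X, Y = get_batch()                      # fetch the very first batch before the loop
for step in range(10):
    logits = model(X)                   # forward pass is enqueued on the GPU...
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1))
    X, Y = get_batch()                  # ...while the CPU assembles and ships the next batch
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()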
bench.py: 10 changed lines
@@ -39,7 +39,7 @@ if real_data:
         ix = torch.randint(len(data) - block_size, (batch_size,))
         x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
         y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
-        x, y = x.to(device), y.to(device)
+        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
         return x, y
 else:
     # alternatively, if fixed data is desired to not care about data loading
@@ -76,14 +76,15 @@ if profile:
         record_shapes=False,
         profile_memory=False,
         with_stack=False, # incurs an additional overhead, disable if not needed
-        with_flops=False,
+        with_flops=True,
         with_modules=False, # only for torchscript models atm
     ) as prof:
 
+        X, Y = get_batch('train')
         for k in range(num_steps):
-            X, Y = get_batch('train')
             with ctx:
                 logits, loss = model(X, Y)
+            X, Y = get_batch('train')
             optimizer.zero_grad(set_to_none=True)
             loss.backward()
             optimizer.step()
@@ -98,10 +99,11 @@ else:
     torch.cuda.synchronize()
     for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark
         t0 = time.time()
+        X, Y = get_batch('train')
         for k in range(num_steps):
-            X, Y = get_batch('train')
             with ctx:
                 logits, loss = model(X, Y)
+            X, Y = get_batch('train')
             optimizer.zero_grad(set_to_none=True)
             loss.backward()
             optimizer.step()
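Editor's note: the commit message quotes a saving of roughly 1 ms per step. A rough, hypothetical way to see the mechanism on your own hardware (this script is not part of the repo, and the batch shape, vocabulary size, and variable names are illustrative) is to time how quickly the copy call returns control to the CPU for a pageable, blocking copy versus a pinned, non_blocking one. Absolute numbers depend on the GPU and the PCIe link.

import time
import torch

assert torch.cuda.is_available(), "this sketch needs a CUDA device"
x = torch.randint(0, 50257, (12, 1024), dtype=torch.int64)  # roughly one token batch

def measure(src, non_blocking):
    torch.cuda.synchronize()
    t0 = time.time()
    dst = src.to('cuda', non_blocking=non_blocking)
    t_return = time.time() - t0   # how long the CPU was blocked by the .to() call
    torch.cuda.synchronize()
    t_done = time.time() - t0     # when the data has actually arrived on the GPU
    return t_return, t_done

r, d = measure(x, non_blocking=False)
print(f"pageable, blocking:    call returned after {r*1e6:.0f} us, data ready after {d*1e6:.0f} us")
r, d = measure(x.pin_memory(), non_blocking=True)
print(f"pinned, non_blocking:  call returned after {r*1e6:.0f} us, data ready after {d*1e6:.0f} us")

With the pinned, non-blocking copy the call typically returns almost immediately, which is exactly the time the training loop reinvests in assembling the next batch while the GPU is busy.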
train.py: 8 changed lines
@@ -111,7 +111,8 @@ def get_batch(split):
     ix = torch.randint(len(data) - block_size, (batch_size,))
     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
-    x, y = x.to(device), y.to(device)
+    # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
+    x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
     return x, y
 
 # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
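Editor's note on the added comment above: non_blocking=True only yields a truly asynchronous host-to-device copy when the source tensor lives in page-locked (pinned) host memory; from ordinary pageable memory the copy generally behaves synchronously. pin_memory() itself performs an extra host-side copy into pinned memory, so the win is not a cheaper transfer but the ability to overlap the transfer and the next batch's CPU work with GPU compute. A tiny illustration (standalone, not repo code):

import torch

assert torch.cuda.is_available(), "pin_memory() needs a CUDA-capable build"
x = torch.randn(1024, 1024)
print(x.is_pinned())   # False: ordinary pageable host memory
xp = x.pin_memory()    # returns a copy in page-locked (pinned) host memory
print(xp.is_pinned())  # True: eligible for truly asynchronous copies to the GPU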
@@ -227,6 +228,7 @@ if wandb_log and master_process:
     wandb.init(project=wandb_project, name=wandb_run_name, config=config)
 
 # training loop
+X, Y = get_batch('train') # fetch the very first batch
 t0 = time.time()
 while True:
 
@@ -269,8 +271,6 @@ while True:
     # forward backward update, with optional gradient accumulation to simulate larger batch size
     # and using the GradScaler if data type is float16
     for micro_step in range(gradient_accumulation_steps):
-        # fetch a batch
-        X, Y = get_batch('train')
         if ddp:
             # in DDP training we only need to sync gradients at the last micro step.
             # the official way to do this is with model.no_sync() context manager, but
@@ -279,6 +279,8 @@ while True:
             model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
         with ctx:
             logits, loss = model(X, Y)
+        # immediately async prefetch next batch while model is doing the forward pass on the GPU
+        X, Y = get_batch('train')
         # backward pass, with gradient scaling if training in fp16
         scaler.scale(loss).backward()
     # clip the gradient
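Editor's note on the placement: CUDA kernel launches are asynchronous, so by the time model(X, Y) returns, the forward pass has typically only been enqueued; calling get_batch('train') right there lets the CPU-side indexing and the pinned, non-blocking copy proceed while the GPU works. Rebinding the names X and Y at that point is safe because backward() reads the activations saved in the autograd graph, not the Python variables. A tiny example of that last point (hypothetical shapes, not repo code):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
w = torch.randn(4, 4, device=device, requires_grad=True)
X = torch.randn(8, 4, device=device)
loss = (X @ w).sum()                   # forward pass; the graph keeps its own reference to X
X = torch.randn(8, 4, device=device)   # "prefetch": rebind the name before backward
loss.backward()                        # still correct: backward uses saved tensors, not the name X
print(w.grad.shape)                    # torch.Size([4, 4])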