Mirror of https://github.com/osmarks/nanogpt-experiments.git, synced 2024-11-10 20:09:58 +00:00
add pytorch profiler support. not sure how to support both profiler and simple benchmarking, a bit gnarly atm hmm
This commit is contained in:
parent b760ef1358
commit 3000cf5dda

bench.py (63 lines changed)
@@ -7,7 +7,7 @@ import time
 import torch
 from model import GPTConfig, GPT
 
-device = 'cuda:3'
+device = 'cuda'
 torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
 torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
 torch.manual_seed(1337)
@@ -45,23 +45,52 @@ model.to(device)
 
 optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95))
 
-burn_in = 10 # number of burn in steps where we don't measure time
-num_steps = 30
-for k in range(num_steps):
-
-    if k == burn_in:
-        t0 = time.time() # start the timer
-
-    X, Y = get_batch('train')
-    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
-        logits, loss = model(X, Y)
-
-    optimizer.zero_grad(set_to_none=True)
-    loss.backward()
-    optimizer.step()
-    lossf = loss.item()
-    print(f"{k}/{num_steps} loss: {lossf:.4f}")
-
-torch.cuda.synchronize()
-t1 = time.time()
-print("time in ms per iteration: %.2f" % ((t1 - t0) / (num_steps - burn_in) * 1000))
+profile = False # use pytorch profiler, or just simple benchmarking?
+if profile:
+    # useful docs on pytorch profiler:
+    # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html
+    # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile
+    wait, warmup, active = 5, 5, 5
+    num_steps = wait + warmup + active
+    with torch.profiler.profile(
+        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+        schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
+        on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'),
+        record_shapes=True,
+        profile_memory=True,
+        with_stack=True, # incurs an additional overhead, disable if not needed
+        with_flops=True,
+        with_modules=False, # only for torchscript models atm
+    ) as prof:
+
+        for k in range(num_steps):
+            X, Y = get_batch('train')
+            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
+                logits, loss = model(X, Y)
+            optimizer.zero_grad(set_to_none=True)
+            loss.backward()
+            optimizer.step()
+            lossf = loss.item()
+            print(f"{k}/{num_steps} loss: {lossf:.4f}")
+
+            prof.step() # notify the profiler at end of each step
+
+else:
+
+    # simple benchmarking
+    torch.cuda.synchronize()
+    for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark
+        t0 = time.time()
+        for k in range(num_steps):
+            X, Y = get_batch('train')
+            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
+                logits, loss = model(X, Y)
+            optimizer.zero_grad(set_to_none=True)
+            loss.backward()
+            optimizer.step()
+            lossf = loss.item()
+            print(f"{k}/{num_steps} loss: {lossf:.4f}")
+        torch.cuda.synchronize()
+        t1 = time.time()
+        if stage == 1:
+            print(f"time per iteration: {(t1-t0)/num_steps*1000:.4f}ms")
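For context on the schedule above: the profiler idles for wait steps, runs but discards results for warmup steps, then records active steps, and each prof.step() call advances that state machine. The resulting trace is written to ./bench_log and can be opened in TensorBoard (the torch_tb_profiler plugin is typically required). Below is a minimal self-contained sketch of the same pattern, assuming only a CUDA build of PyTorch, with a toy matmul standing in for the real forward/backward/step and a printed summary table standing in for the TensorBoard trace:

import torch

a = torch.randn(1024, 1024, device='cuda')
b = torch.randn(1024, 1024, device='cuda')

wait, warmup, active = 5, 5, 5
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
) as prof:
    for k in range(wait + warmup + active):
        c = a @ b # toy stand-in for the real forward/backward/optimizer step
        prof.step() # advance the wait -> warmup -> active state machine
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

key_averages() gives a quick aggregate view; the TensorBoard trace has the full per-step timeline.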
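On the simple-benchmarking branch: CUDA kernels launch asynchronously, so the torch.cuda.synchronize() calls are what make the time.time() readings meaningful, and the first 10-step stage is pure burn-in (allocator warm-up, one-time autotuning) whose timing is thrown away. An alternative sketch using CUDA events, which time on the device itself rather than the host clock; run_step here is a hypothetical stand-in for the real training step:

import torch

a = torch.randn(2048, 2048, device='cuda')

def run_step():
    # hypothetical stand-in for one forward/backward/optimizer step
    return a @ a

for _ in range(10): # burn-in, mirroring the first stage above
    run_step()
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
num_steps = 20
start.record()
for _ in range(num_steps):
    run_step()
end.record()
torch.cuda.synchronize() # wait for queued kernels before reading the events
print(f"time per iteration: {start.elapsed_time(end) / num_steps:.4f}ms") # elapsed_time is in ms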