mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
small usability tweaks to bench
This commit is contained in:
parent
d995c22128
commit
d01863ef01
16
bench.py
16
bench.py
@ -9,13 +9,15 @@ import torch
|
|||||||
from model import GPTConfig, GPT
|
from model import GPTConfig, GPT
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
batch_size = 8
|
batch_size = 12
|
||||||
block_size = 1024
|
block_size = 1024
|
||||||
bias = True
|
bias = False
|
||||||
|
real_data = True
|
||||||
seed = 1337
|
seed = 1337
|
||||||
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
|
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
|
||||||
dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'
|
dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'
|
||||||
compile = True # use PyTorch 2.0 to compile the model to be faster
|
compile = True # use PyTorch 2.0 to compile the model to be faster
|
||||||
|
profile = False # use pytorch profiler, or just simple benchmarking?
|
||||||
exec(open('configurator.py').read()) # overrides from command line or config file
|
exec(open('configurator.py').read()) # overrides from command line or config file
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
@ -28,7 +30,6 @@ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torc
|
|||||||
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
||||||
|
|
||||||
# data loading init
|
# data loading init
|
||||||
real_data = True
|
|
||||||
if real_data:
|
if real_data:
|
||||||
dataset = 'openwebtext'
|
dataset = 'openwebtext'
|
||||||
data_dir = os.path.join('data', dataset)
|
data_dir = os.path.join('data', dataset)
|
||||||
@ -62,7 +63,6 @@ if compile:
|
|||||||
print("Compiling model...")
|
print("Compiling model...")
|
||||||
model = torch.compile(model) # pytorch 2.0
|
model = torch.compile(model) # pytorch 2.0
|
||||||
|
|
||||||
profile = False # use pytorch profiler, or just simple benchmarking?
|
|
||||||
if profile:
|
if profile:
|
||||||
# useful docs on pytorch profiler:
|
# useful docs on pytorch profiler:
|
||||||
# - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html
|
# - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html
|
||||||
@ -73,10 +73,10 @@ if profile:
|
|||||||
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
|
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
|
||||||
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
|
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
|
||||||
on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'),
|
on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'),
|
||||||
record_shapes=True,
|
record_shapes=False,
|
||||||
profile_memory=True,
|
profile_memory=False,
|
||||||
with_stack=True, # incurs an additional overhead, disable if not needed
|
with_stack=False, # incurs an additional overhead, disable if not needed
|
||||||
with_flops=True,
|
with_flops=False,
|
||||||
with_modules=False, # only for torchscript models atm
|
with_modules=False, # only for torchscript models atm
|
||||||
) as prof:
|
) as prof:
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user