mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-11-10 20:09:58 +00:00
pull out dtype up top
This commit is contained in:
parent
e7bac659f5
commit
fa57d464d7
5
bench.py
5
bench.py
@@ -14,6 +14,7 @@ torch.manual_seed(1337)
|
||||
|
||||
batch_size = 8
|
||||
block_size = 1024
|
||||
+dtype = torch.float16
|
||||
|
||||
# data loading init
|
||||
real_data = True
|
||||
@@ -65,7 +66,7 @@ if profile:
|
||||
|
||||
for k in range(num_steps):
|
||||
X, Y = get_batch('train')
|
||||
-with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
|
||||
+with torch.autocast(device_type='cuda', dtype=dtype):
|
||||
logits, loss = model(X, Y)
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
loss.backward()
|
||||
@@ -83,7 +84,7 @@ else:
|
||||
t0 = time.time()
|
||||
for k in range(num_steps):
|
||||
X, Y = get_batch('train')
|
||||
-with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
|
||||
+with torch.autocast(device_type='cuda', dtype=dtype):
|
||||
logits, loss = model(X, Y)
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
loss.backward()
|
||||
|
Loading…
Reference in New Issue
Block a user