mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-23 00:20:29 +00:00
pull out dtype up top
This commit is contained in:
parent
e7bac659f5
commit
fa57d464d7
5
bench.py
5
bench.py
@ -14,6 +14,7 @@ torch.manual_seed(1337)
|
|||||||
|
|
||||||
batch_size = 8
|
batch_size = 8
|
||||||
block_size = 1024
|
block_size = 1024
|
||||||
|
dtype = torch.float16
|
||||||
|
|
||||||
# data loading init
|
# data loading init
|
||||||
real_data = True
|
real_data = True
|
||||||
@ -65,7 +66,7 @@ if profile:
|
|||||||
|
|
||||||
for k in range(num_steps):
|
for k in range(num_steps):
|
||||||
X, Y = get_batch('train')
|
X, Y = get_batch('train')
|
||||||
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
|
with torch.autocast(device_type='cuda', dtype=dtype):
|
||||||
logits, loss = model(X, Y)
|
logits, loss = model(X, Y)
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
@ -83,7 +84,7 @@ else:
|
|||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
for k in range(num_steps):
|
for k in range(num_steps):
|
||||||
X, Y = get_batch('train')
|
X, Y = get_batch('train')
|
||||||
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
|
with torch.autocast(device_type='cuda', dtype=dtype):
|
||||||
logits, loss = model(X, Y)
|
logits, loss = model(X, Y)
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
Loading…
Reference in New Issue
Block a user