1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00

pull out dtype up top

This commit is contained in:
Andrej Karpathy 2022-12-29 05:32:55 +00:00
parent e7bac659f5
commit fa57d464d7

View File

@ -14,6 +14,7 @@ torch.manual_seed(1337)
batch_size = 8
block_size = 1024
dtype = torch.float16
# data loading init
real_data = True
@ -65,7 +66,7 @@ if profile:
for k in range(num_steps):
X, Y = get_batch('train')
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
with torch.autocast(device_type='cuda', dtype=dtype):
logits, loss = model(X, Y)
optimizer.zero_grad(set_to_none=True)
loss.backward()
@ -83,7 +84,7 @@ else:
t0 = time.time()
for k in range(num_steps):
X, Y = get_batch('train')
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
with torch.autocast(device_type='cuda', dtype=dtype):
logits, loss = model(X, Y)
optimizer.zero_grad(set_to_none=True)
loss.backward()