diff --git a/bench.py b/bench.py index 9ebb280..434632c 100644 --- a/bench.py +++ b/bench.py @@ -14,6 +14,7 @@ torch.manual_seed(1337) batch_size = 8 block_size = 1024 +dtype = torch.float16 # data loading init real_data = True @@ -65,7 +66,7 @@ if profile: for k in range(num_steps): X, Y = get_batch('train') - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type='cuda', dtype=dtype): logits, loss = model(X, Y) optimizer.zero_grad(set_to_none=True) loss.backward() @@ -83,7 +84,7 @@ else: t0 = time.time() for k in range(num_steps): X, Y = get_batch('train') - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type='cuda', dtype=dtype): logits, loss = model(X, Y) optimizer.zero_grad(set_to_none=True) loss.backward()