Mirror of https://github.com/osmarks/nanogpt-experiments.git, synced 2024-11-10 20:09:58 +00:00
clean up TODOs a bit, they are stale
This commit is contained in:
parent 25d95dbd65
commit a74e8363a2
1 changed file: train.py (4 changed lines: 2 additions, 2 deletions)
@@ -103,7 +103,7 @@ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
 ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
 ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
 
-# poor man's data loader, TODO evaluate need for actual DataLoader
+# poor man's data loader
 data_dir = os.path.join('data', dataset)
 train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
 val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
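For context, the "poor man's data loader" this comment now describes is a random-offset batch sampler over the memmapped token files, not a torch.utils.data.DataLoader. A minimal sketch of that pattern, with placeholder values standing in for train.py's globals (block_size, batch_size, device, and the token arrays) — an illustration, not the exact code in this commit:

import numpy as np
import torch

# placeholders standing in for train.py's configuration (illustration only)
block_size, batch_size, device = 64, 4, 'cpu'
train_data = np.random.randint(0, 50257, size=10_000).astype(np.uint16)
val_data = np.random.randint(0, 50257, size=1_000).astype(np.uint16)

def get_batch(split):
    # pick the memmapped uint16 token array for the requested split
    data = train_data if split == 'train' else val_data
    # sample batch_size random window offsets into the token stream
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # inputs are block_size-token windows; targets are the same windows shifted by one
    x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
    return x.to(device), y.to(device)

x, y = get_batch('train')  # x, y: (batch_size, block_size) int64 tensors

Because the .bin files are memmapped, each batch reads only the sampled windows from disk, so the dataset never has to fit in RAM.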
@@ -302,7 +302,7 @@ while True:
     dt = t1 - t0
     t0 = t1
     if iter_num % log_interval == 0 and master_process:
-        lossf = loss.item() # loss as float. TODO note CPU-GPU sync! profile, make sure not too slow
+        lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
         print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
     iter_num += 1
 
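The rewritten comment above flags that loss.item() is a host-device synchronization point: the CPU blocks until the GPU has finished the queued work that produces loss before the scalar can be copied back. Gating the call on log_interval amortizes that cost across many iterations. A minimal standalone sketch of the behavior, with placeholder values standing in for the real training state (not code from this commit):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
loss = torch.rand((), device=device)  # stand-in for the training loss tensor
iter_num, log_interval = 0, 10        # placeholder counters

if iter_num % log_interval == 0:
    # .item() forces the host to wait for all queued GPU work producing
    # `loss`, then copies the scalar to the CPU -- the sync point noted above
    lossf = loss.item()
    print(f"iter {iter_num}: loss {lossf:.4f}")
# between logging iterations the loss stays on-device and no sync occurs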