mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
very slight refactor, bit cleaner
This commit is contained in:
parent
dc149891b6
commit
e108ffb973
6
train.py
6
train.py
@ -112,12 +112,10 @@ def get_batch(split):
|
|||||||
ix = torch.randint(len(data) - block_size, (batch_size,))
|
ix = torch.randint(len(data) - block_size, (batch_size,))
|
||||||
x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
|
x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
|
||||||
y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
|
y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
|
||||||
# pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
|
if device_type == 'cuda':
|
||||||
if "cuda" in device:
|
# pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
|
||||||
# GPU training
|
|
||||||
x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
|
x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
|
||||||
else:
|
else:
|
||||||
# CPU or MPS training
|
|
||||||
x, y = x.to(device), y.to(device)
|
x, y = x.to(device), y.to(device)
|
||||||
return x, y
|
return x, y
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user