diff --git a/train.py b/train.py
index 6994d65..aea9f0e 100644
--- a/train.py
+++ b/train.py
@@ -113,7 +113,12 @@ def get_batch(split):
     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
     # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
-    x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
+    if "cuda" in device:
+        # GPU training
+        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
+    else:
+        # CPU or MPS training
+        x, y = x.to(device), y.to(device)
     return x, y

 # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
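
For context, a minimal standalone sketch of the transfer logic this diff introduces: pinning host memory and passing non_blocking=True only pays off on CUDA, so CPU and MPS devices fall back to a plain, blocking .to(device). The helper name to_device, the tensor shapes, and the vocabulary size below are illustrative only and do not appear in train.py.

    import torch

    def to_device(x: torch.Tensor, y: torch.Tensor, device: str):
        if "cuda" in device:
            # page-lock host memory so the copy can overlap with GPU compute
            return (x.pin_memory().to(device, non_blocking=True),
                    y.pin_memory().to(device, non_blocking=True))
        # CPU or MPS: pinned memory and async copies give no benefit here
        return x.to(device), y.to(device)

    if __name__ == "__main__":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # dummy batch: (batch_size=12, block_size=1024) token ids
        x = torch.randint(0, 50304, (12, 1024), dtype=torch.int64)
        y = torch.randint(0, 50304, (12, 1024), dtype=torch.int64)
        x, y = to_device(x, y, device)
        print(x.device, y.device)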