diff --git a/train.py b/train.py index 2dc22eb..3bf1803 100644 --- a/train.py +++ b/train.py @@ -140,7 +140,7 @@ def get_batch(split, step): else: data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') d_rng = random.Random(f"{split}-{step}-{seed}") - ix = [ d_rng.randint(0, len(data) - block_size) for _ in range(batch_size) ] + ix = [ d_rng.randint(0, len(data) - block_size) for _ in range(batch_size) ] # TODO: I think this needs to be len(data) - block_size - 1 but changing it breaks determinism badly x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) if device_type == 'cuda':