diff --git a/train.py b/train.py
index a482ab7..951bda9 100644
--- a/train.py
+++ b/train.py
@@ -113,10 +113,13 @@ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=
 # poor man's data loader
 data_dir = os.path.join('data', dataset)
-train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
-val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
 def get_batch(split):
-    data = train_data if split == 'train' else val_data
+    # We recreate np.memmap every batch to avoid a memory leak, as per
+    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
+    if split == 'train':
+        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
+    else:
+        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
     ix = torch.randint(len(data) - block_size, (batch_size,))
     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
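
For context, a minimal standalone sketch of the resulting data-loading pattern is below. The data_dir, block_size, and batch_size values are illustrative placeholders (in train.py they come from the config), and the .bin files are assumed to already exist on disk; the point is only that each call opens a fresh read-only memmap, so no long-lived mapping keeps accumulating touched pages in the training process.

import os
import numpy as np
import torch

data_dir = os.path.join('data', 'openwebtext')  # hypothetical dataset directory
block_size = 1024                               # illustrative context length
batch_size = 12                                 # illustrative batch size

def get_batch(split):
    # Re-open the memmap on every call so the OS page cache, rather than the
    # Python process, holds the file data; a single long-lived np.memmap lets
    # referenced pages pile up and look like a memory leak over training.
    fname = 'train.bin' if split == 'train' else 'val.bin'
    data = np.memmap(os.path.join(data_dir, fname), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
    return x, y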