From 5156fef93c15ef7e0dcdb35b4581a1dcd9c4d72e Mon Sep 17 00:00:00 2001
From: Kevin Slagle
Date: Thu, 25 Jan 2024 11:41:01 -0800
Subject: [PATCH] fix np.memmap memory leak

np.memmap doesn't free the memory it accesses, so resident memory keeps
growing until the entire dataset is cached in RAM once it has been fully
read. The simplest workaround suggested on Stack Overflow is to just
recreate the memmap for each batch; the extra overhead is negligible.

https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
---
 train.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index a482ab7..951bda9 100644
--- a/train.py
+++ b/train.py
@@ -113,10 +113,13 @@ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=
 # poor man's data loader
 data_dir = os.path.join('data', dataset)
-train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
-val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
 def get_batch(split):
-    data = train_data if split == 'train' else val_data
+    # We recreate np.memmap every batch to avoid a memory leak, as per
+    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
+    if split == 'train':
+        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
+    else:
+        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
     ix = torch.randint(len(data) - block_size, (batch_size,))
     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
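
Note: for reference, the resulting get_batch reads as the self-contained
sketch below. The dataset name 'openwebtext' and the block_size/batch_size
values are placeholder assumptions, not part of the patch; in train.py they
come from the training config.

import os
import numpy as np
import torch

data_dir = os.path.join('data', 'openwebtext')  # assumed dataset directory
block_size = 1024  # assumed; set by the config in train.py
batch_size = 12    # assumed; set by the config in train.py

def get_batch(split):
    # Reopen the memmap on every call instead of holding one mapping for the
    # whole run: when the old memmap object is garbage collected its mapping
    # is released, so the OS can reclaim the pages it had touched and the
    # process's resident memory stays bounded.
    path = os.path.join(data_dir, f'{split}.bin')
    data = np.memmap(path, dtype=np.uint16, mode='r')
    # Sample batch_size random offsets and gather (input, target) windows,
    # exactly as the patched train.py does.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
    return x, y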