1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2025-01-18 21:22:53 +00:00

fix np.memmap memory leak

nn.memmap doesn't free memory that it accesses. Thus, the entire dataset gets stored in RAM as the dataset has been fully accessed. The simplest workaround on stackoverflow is to just recreate the memmap for each batch. The extra overhead is negligible.

https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
This commit is contained in:
Kevin Slagle 2024-01-25 11:41:01 -08:00 committed by GitHub
parent eba36e8464
commit 5156fef93c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -113,10 +113,13 @@ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=
# poor man's data loader
data_dir = os.path.join('data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
def get_batch(split):
data = train_data if split == 'train' else val_data
# We recreate np.memmap every batch to avoid a memory leak, as per
# https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
if split == 'train':
data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
else:
data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
ix = torch.randint(len(data) - block_size, (batch_size,))
x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])