mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-11-10 20:09:58 +00:00
fix np.memmap memory leak
nn.memmap doesn't free memory that it accesses. Thus, the entire dataset gets stored in RAM as the dataset has been fully accessed. The simplest workaround on stackoverflow is to just recreate the memmap for each batch. The extra overhead is negligible. https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
This commit is contained in:
parent
eba36e8464
commit
5156fef93c
9
train.py
9
train.py
@ -113,10 +113,13 @@ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=
|
|||||||
|
|
||||||
# poor man's data loader
|
# poor man's data loader
|
||||||
data_dir = os.path.join('data', dataset)
|
data_dir = os.path.join('data', dataset)
|
||||||
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
|
|
||||||
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
|
|
||||||
def get_batch(split):
|
def get_batch(split):
|
||||||
data = train_data if split == 'train' else val_data
|
# We recreate np.memmap every batch to avoid a memory leak, as per
|
||||||
|
# https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
|
||||||
|
if split == 'train':
|
||||||
|
data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
|
||||||
|
else:
|
||||||
|
data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
|
||||||
ix = torch.randint(len(data) - block_size, (batch_size,))
|
ix = torch.randint(len(data) - block_size, (batch_size,))
|
||||||
x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
|
x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
|
||||||
y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
|
y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
|
||||||
|
Loading…
Reference in New Issue
Block a user