mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
Merge pull request #428 from kjslag/memmap-memory-leak
fix np.memmap memory leak
This commit is contained in:
commit
f68ac2200d
9
train.py
9
train.py
@ -113,10 +113,13 @@ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=
|
||||
|
||||
# poor man's data loader
|
||||
data_dir = os.path.join('data', dataset)
|
||||
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
|
||||
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
|
||||
def get_batch(split):
|
||||
data = train_data if split == 'train' else val_data
|
||||
# We recreate np.memmap every batch to avoid a memory leak, as per
|
||||
# https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
|
||||
if split == 'train':
|
||||
data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
|
||||
else:
|
||||
data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
|
||||
ix = torch.randint(len(data) - block_size, (batch_size,))
|
||||
x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
|
||||
y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
|
||||
|
Loading…
Reference in New Issue
Block a user