diff --git a/data/openwebtext/prepare.py b/data/openwebtext/prepare.py index 7710bdf..8dc30e1 100644 --- a/data/openwebtext/prepare.py +++ b/data/openwebtext/prepare.py @@ -54,12 +54,16 @@ for split, dset in tokenized.items(): filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) + total_batches = 1024 - print(f"writing {filename}...") idx = 0 - for example in tqdm(dset): - arr[idx : idx + example['len']] = example['ids'] - idx += example['len'] + for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): + # Batch together samples for faster write + batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') + arr_batch = np.concatenate(batch['ids']) + # Write into mmap + arr[idx : idx + len(arr_batch)] = arr_batch + idx += len(arr_batch) arr.flush() # train.bin is ~17GB, val.bin ~8.5MB