1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-02-01 03:39:13 +00:00
meme-search-engine/sae/shared.py

12 lines
292 B
Python

import torch
import pyarrow as pa
# Permit reduced-precision internal arithmetic (e.g. TF32) for float32 matmuls,
# trading a little accuracy for throughput. This is a global torch setting, so
# it takes effect for every script that imports this module.
torch.set_float32_matmul_precision("high")

# NOTE(review): presumably the fraction of rows used for training, with the
# remainder held out — confirm against the scripts that consume it.
train_split = 0.8

# Read the entire Arrow IPC file into `loaded_arrays` (a pyarrow Table) at
# import time. The path is relative to the current working directory, so
# importers must be launched from the directory this layout expects.
with pa.memory_map("../../sample_1m.arrow", "r") as source:
    loaded_arrays = pa.ipc.open_file(source).read_all()
def ckpt_path(steps):
    """Return the pair of checkpoint file paths for a given step count.

    Args:
        steps: Identifier interpolated into the file names using f-string
            formatting (typically an int step counter).

    Returns:
        A 2-tuple of strings ``("ckpt/<steps>.pt", "ckpt/<steps>.optim.pt")``,
        both relative to the current working directory. By naming, these are
        presumably the model and optimizer state files respectively — confirm
        against callers.
    """
    stem = f"ckpt/{steps}"
    return stem + ".pt", stem + ".optim.pt"