1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-11-14 15:54:48 +00:00
meme-search-engine/sae/train.py

101 lines
2.9 KiB
Python

import torch.nn
import torch.nn.functional as F
import torch
import numpy
import json
import time
from tqdm import tqdm
from dataclasses import dataclass, asdict
from model import SAEConfig, SAE
from shared import train_split, loaded_arrays, ckpt_path
# Prefer the GPU, but fall back to the CPU so the script can still start on a
# machine without CUDA instead of failing when the model is moved to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
@dataclass
class TrainConfig:
    """Settings for a single training run.

    The instance is converted with ``asdict`` and written as the first line of
    the run's JSONL log, and it is stored next to the model weights in each
    model checkpoint.
    """

    # Model settings; passed to the ``SAE`` constructor.
    model: SAEConfig
    # Learning rate given to the AdamW optimizer.
    lr: float
    # Weight-decay coefficient given to the AdamW optimizer.
    weight_decay: float
    # Number of embeddings stacked into each training batch.
    batch_size: int
    # Number of passes over the training portion of the data.
    epochs: int
    # When true, ``train_step`` is wrapped with ``torch.compile``.
    compile: bool
# Model settings for this run, built separately so the training settings
# below read as a flat list.
_sae_config = SAEConfig(
    d_emb=1152,
    d_hidden=65536,
    top_k=128,
    device=device,
    dtype=torch.float32,
    up_proj_bias=False,
)

# The single run configuration read by the rest of the script.
config = TrainConfig(
    model=_sae_config,
    lr=3e-4,
    weight_decay=0.0,
    batch_size=64,
    epochs=5,
    compile=True,
)
def exprange(min, max, n):
    """Return a generator of ``n`` values spaced geometrically from ``min`` to ``max``.

    The values are evenly spaced in log space, so consecutive values share a
    constant ratio. Both endpoints are produced when ``n >= 2``.

    Args:
        min: First value of the sequence; must be positive.
        max: Last value of the sequence; must be positive.
        n: Number of values to produce. ``n <= 0`` yields nothing and
            ``n == 1`` yields only ``min``.

    Returns:
        A lazy generator of floats.

    Raises:
        ValueError: If ``min`` or ``max`` is not positive (raised by
            ``math.log`` at call time, before the generator is returned).
    """
    # Imported here because the file's import block does not bring in `math`;
    # without it the function raised NameError on its first call.
    import math

    lmin, lmax = math.log(min), math.log(max)
    # A single point has no spacing; avoid dividing by `n - 1 == 0`.
    step = (lmax - lmin) / (n - 1) if n > 1 else 0.0
    return (math.exp(lmin + step * i) for i in range(n))
# Instantiate the model from the run configuration.
model = SAE(config.model)

# Report the total number of parameter elements (in millions), then the
# module's printed representation.
params = sum(map(torch.numel, model.parameters()))
print(f"{params/1e6:.1f}M parameters")
print(model)

# One AdamW optimizer over every model parameter; `train_step` steps it and
# `save_ckpt` stores its state.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay,
)
def train_step(model, batch):
    """Run one optimization step on ``batch`` and return the loss tensor.

    Uses the module-level ``optimizer``. The loss is the mean squared error
    between the model's output (converted with ``.float()``) and the input
    batch itself.

    Args:
        model: Callable that maps a batch to its reconstruction.
        batch: Input tensor; it is also the reconstruction target.

    Returns:
        The loss tensor for this step (not detached).
    """
    optimizer.zero_grad()
    output = model(batch)
    loss = F.mse_loss(output.float(), batch)
    loss.backward()
    optimizer.step()
    return loss


# Compilation is opt-in through the run configuration; when enabled, the name
# `train_step` is rebound to the compiled wrapper.
if config.compile:
    print("compiling...")
    train_step = torch.compile(train_step)
def save_ckpt(log, steps):
    """Write the optimizer state and the model checkpoint for ``steps``.

    Args:
        log: Unused; kept so existing call sites keep working.
        steps: Step count passed to ``ckpt_path`` to obtain the two output
            paths (model first, optimizer second).
    """
    model_file, optimizer_file = ckpt_path(steps)
    torch.save(optimizer.state_dict(), optimizer_file)
    # The run configuration is stored together with the weights.
    model_payload = {"model": model.state_dict(), "config": config}
    torch.save(model_payload, model_file)
class JSONEncoder(json.JSONEncoder):
    """JSON encoder that writes ``torch.dtype`` values as their string form."""

    def default(self, o):
        """Return ``str(o)`` for a ``torch.dtype``; otherwise defer to the base class."""
        if not isinstance(o, torch.dtype):
            return super().default(o)
        return str(o)
# Main training loop: writes one JSON line per optimization step to a
# timestamped log file and saves checkpoints as it goes.
# NOTE(review): the nesting below was reconstructed from a listing that had
# lost its indentation; confirm that the end-of-epoch checkpoint and counter
# dump belong inside the epoch loop.
logfile = f"logs/log-{time.time()}.jsonl"
with open(logfile, "w") as log:
    steps = 0
    # The first log line records the full run configuration.
    log.write(JSONEncoder().encode(asdict(config)) + "\n")
    # Rows at or beyond this index are outside the training split. Computed
    # once because it does not change between epochs.
    train_end = int(len(loaded_arrays) * train_split)
    for epoch in range(config.epochs):
        t = tqdm(range(0, train_end, config.batch_size))
        for batch_start in t:
            # Clamp the slice to the split boundary: without this, the last
            # batch of an epoch could reach past `train_end` and train on
            # rows that are not part of the training split.
            batch_end = min(batch_start + config.batch_size, train_end)
            batch = numpy.stack([
                numpy.frombuffer(embedding.as_py(), dtype=numpy.float16)
                for embedding in loaded_arrays["embedding"][batch_start:batch_end]
            ])
            # Only full batches are trained on; a short trailing batch is skipped.
            if len(batch) != config.batch_size:
                continue
            batch = torch.Tensor(batch).to(device)
            loss = train_step(model, batch)
            loss = loss.detach().cpu().item()
            t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}")
            log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
            # Periodic checkpoint, including one right after the first step.
            if steps % 5000 == 0:
                save_ckpt(log, steps)
            steps += 1
        # End of epoch: checkpoint, then print and save the model's feature
        # activation counters as returned by `reset_counters()`.
        save_ckpt(log, steps)
        print(model.feature_activation_counter.cpu())
        ctr = model.reset_counters()
        print(ctr)
        numpy.save(f"ckpt/{steps}.counters.npy", ctr)
print(logfile)