# Mirror of https://github.com/osmarks/meme-search-engine.git
# (synced 2025-01-20 06:02:56 +00:00; original file: 101 lines, 2.9 KiB, Python)
import torch.nn
|
|
import torch.nn.functional as F
|
|
import torch
|
|
import numpy
|
|
import json
|
|
import time
|
|
from tqdm import tqdm
|
|
from dataclasses import dataclass, asdict
|
|
|
|
from model import SAEConfig, SAE
|
|
from shared import train_split, loaded_arrays, ckpt_path
|
|
|
|
# Device holding the SAE parameters (passed into SAEConfig below) and every
# training batch (moved here in the training loop). Hard-coded: no CPU
# fallback is attempted if CUDA is unavailable.
device: str = "cuda"
|
|
|
|
@dataclass()
class TrainConfig:
    """Hyperparameters for a single SAE training run.

    The whole object is written (via ``asdict``) as the first line of the
    JSONL training log, and stored under ``"config"`` in every model checkpoint.
    """

    # Architecture of the sparse autoencoder being trained.
    model: SAEConfig
    # Learning rate handed to AdamW.
    lr: float
    # Weight-decay coefficient handed to AdamW.
    weight_decay: float
    # Rows per optimisation step; shorter tail batches are not trained on.
    batch_size: int
    # Number of passes over the training split.
    epochs: int
    # When true, the training step is wrapped with torch.compile.
    compile: bool
|
|
|
|
# Run configuration: SAE architecture first, then optimisation settings.
config = TrainConfig(
    model=SAEConfig(
        d_emb=1152,          # input embedding width (presumably matches the stored embeddings)
        d_hidden=1 << 16,    # 65536 hidden units
        top_k=128,
        up_proj_bias=False,
        dtype=torch.float32,
        device=device,
    ),
    epochs=5,
    batch_size=64,
    lr=3e-4,
    weight_decay=0.0,
    compile=True,
)
|
|
|
|
def exprange(min, max, n):
    """Return a generator of ``n`` values spaced evenly in log space.

    The values run from ``min`` to ``max`` (inclusive, up to floating-point
    rounding). Consecutive values differ by a constant factor.

    Args:
        min: First value; must be > 0. The name shadows the builtin only inside
            this function and is kept for keyword-argument compatibility.
        max: Last value; must be > 0.
        n: Number of values. ``n == 1`` yields just ``min``; ``n <= 0``
            yields nothing.

    Raises:
        ValueError: if ``min`` or ``max`` is not positive (raised by ``math.log``).
    """
    # Imported locally so the helper works regardless of module-level imports.
    import math

    lmin, lmax = math.log(min), math.log(max)
    # A single point has no interval to divide; avoid dividing by zero.
    step = (lmax - lmin) / (n - 1) if n > 1 else 0.0
    return (math.exp(lmin + step * i) for i in range(n))
|
|
|
|
# Instantiate the autoencoder from the run config and report its size.
model = SAE(config.model)
params = sum(map(torch.numel, model.parameters()))
print("{:.1f}M parameters".format(params / 1e6))
print(model)

# AdamW over every SAE parameter, using the run's lr and weight decay.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay,
)
|
|
|
|
def train_step(model, batch):
    """Run one optimiser update on ``batch`` and return the reconstruction loss.

    Uses the module-level ``optimizer``, which is expected to wrap ``model``'s
    parameters.

    Args:
        model: Autoencoder called on ``batch``; its output is expected to have
            the same shape as ``batch``.
        batch: Input tensor, which is also the reconstruction target.

    Returns:
        Scalar MSE loss tensor between ``model(batch)`` (cast to float32) and
        ``batch``. The tensor is still attached to the autograd graph.
    """
    optimizer.zero_grad()
    # The model is trained to reproduce its own input, so batch is also the target.
    loss = F.mse_loss(model(batch).float(), batch)
    loss.backward()
    optimizer.step()
    return loss
|
|
|
|
# Optionally replace train_step with a torch.compile-wrapped version. torch
# defers the actual compilation until the first call, so the message below is
# printed before any compile work happens.
if config.compile:
    print("compiling...")
    train_step = torch.compile(train_step)
|
|
|
|
def save_ckpt(log, steps):
    """Write optimizer state and model weights for global step ``steps``.

    ``ckpt_path(steps)`` supplies the destination paths as ``(model_path,
    optimizer_path)``. The optimizer file contains ``optimizer.state_dict()``.
    The model file contains ``{"model": model.state_dict(), "config": config}``.

    Args:
        log: Unused; accepted so existing call sites keep working.
        steps: Step count used to name the checkpoint files.
    """
    model_file, optim_file = ckpt_path(steps)
    torch.save(optimizer.state_dict(), optim_file)
    payload = {"model": model.state_dict(), "config": config}
    torch.save(payload, model_file)
|
|
|
|
class JSONEncoder(json.JSONEncoder):
    """JSON encoder that also serialises ``torch.dtype`` and ``torch.device``.

    ``asdict(config)`` leaves torch dtypes (and device objects, if a config
    ever uses ``torch.device(...)`` instead of a string) in place. The stock
    encoder rejects both types. They are written as their ``str()`` form,
    e.g. ``"torch.float32"`` or ``"cuda"``.
    """

    def default(self, o):
        # Called only for objects the base encoder cannot serialise natively.
        if isinstance(o, (torch.dtype, torch.device)):
            return str(o)
        # The base implementation raises TypeError for any other type.
        return super().default(o)
|
|
|
|
# Training entry point. For config.epochs epochs, it streams the training split
# in batch_size chunks and runs train_step on each full batch. Each step's loss
# goes to a JSONL log. A checkpoint is written every 5000 steps and at the end
# of every epoch.
import os

# Both directories are written to below; create them so a fresh checkout works.
os.makedirs("logs", exist_ok=True)
os.makedirs("ckpt", exist_ok=True)

logfile = f"logs/log-{time.time()}.jsonl"

# Rows [0, n_train) form the training split. Presumably rows from n_train on
# are held out (TODO confirm against `shared`), so no batch may read past it.
n_train = int(len(loaded_arrays) * train_split)
embeddings = loaded_arrays["embedding"]


def _load_batch(start, stop):
    """Decode stored float16 embedding rows [start, stop) into a stacked numpy array."""
    return numpy.stack([
        numpy.frombuffer(embedding.as_py(), dtype=numpy.float16)
        for embedding in embeddings[start:stop]
    ])


with open(logfile, "w") as log:
    steps = 0
    # The first log line records the full run configuration.
    log.write(JSONEncoder().encode(asdict(config)) + "\n")
    for epoch in range(config.epochs):
        t = tqdm(range(0, n_train, config.batch_size))
        for batch_start in t:
            # Clip at n_train: the unclipped slice let the final batch pull in
            # rows beyond the training split.
            batch_end = min(batch_start + config.batch_size, n_train)
            batch = _load_batch(batch_start, batch_end)

            # Only full batches are trained on, so every step sees the same shape.
            if len(batch) != config.batch_size:
                continue

            # Transfer as float16, then widen to float32 on the device (exact conversion).
            batch = torch.from_numpy(batch).to(device=device, dtype=torch.float32)
            loss = train_step(model, batch).item()
            t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}")
            log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
            if steps % 5000 == 0:
                save_ckpt(log, steps)
            steps += 1

        # End of epoch: checkpoint, report and reset the feature-activation counters.
        save_ckpt(log, steps)
        print(model.feature_activation_counter.cpu())
        ctr = model.reset_counters()
        print(ctr)
        numpy.save(f"ckpt/{steps}.counters.npy", ctr)

print(logfile)