1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-01-06 15:30:30 +00:00
meme-search-engine/meme-rater/train.py

129 lines
4.1 KiB
Python

import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
import math
from dataclasses import dataclass, asdict
from model import Config as ModelConfig, BradleyTerry
import shared
trains, validations = shared.fetch_ratings()
# Report the size of every (train, validation) split pair, in lockstep.
for train, validation in zip(trains, validations):
    print(*map(len, (train, validation)))
# Device handed to ModelConfig and used for every batch tensor built below.
device = "cuda"
@dataclass
class TrainConfig:
    """Hyperparameters and switches for this training script.

    The whole instance is serialised with ``asdict`` and written as the first
    line of the JSONL training log, so the field order below is also the key
    order of that record (and the positional-constructor order).
    """
    # Model configuration; passed to BradleyTerry(...) when the model is built.
    model: ModelConfig
    # AdamW learning rate.
    lr: float
    # AdamW weight decay.
    weight_decay: float
    # Number of consecutive samples taken from each ensemble member's data
    # ordering per training step.
    batch_size: int
    # Number of passes over the training data.
    epochs: int
    # If True, train_step is wrapped with torch.compile before training.
    compile: bool
    # If True, iterate over each training group returned by
    # shared.fetch_ratings() separately; if False, pool all groups into one list.
    data_grouped_by_iter: bool
# Training-run configuration (written verbatim as the first log record).
config = TrainConfig(
    # Optimiser / schedule settings.
    lr=3e-4,
    weight_decay=0.2,
    batch_size=1,
    epochs=5,
    # Runtime switches.
    compile=False,
    data_grouped_by_iter=False,
    # Model architecture and placement.
    model=ModelConfig(
        device=device,
        dtype=torch.float32,
        d_emb=1152,
        n_hidden=1,
        n_ensemble=16,
        dropout=0.1,
    ),
)
def exprange(min, max, n):
    """Return a generator of ``n`` log-evenly-spaced values from ``min`` to ``max``.

    Both endpoints are included (up to floating-point rounding). ``min`` and
    ``max`` must be positive because ``math.log`` is applied to them.

    Note: the parameter names shadow the ``min``/``max`` builtins; they are kept
    for backward compatibility with keyword callers.

    Args:
        min: first value of the range (> 0).
        max: last value of the range (> 0).
        n: number of values to produce; ``n <= 0`` yields nothing.

    Returns:
        A generator of floats.

    Raises:
        ValueError: if ``min`` or ``max`` is not positive (from ``math.log``).
    """
    lmin, lmax = math.log(min), math.log(max)
    if n <= 1:
        # With zero or one point there is no spacing to compute; dividing by
        # n - 1 == 0 would raise ZeroDivisionError for n == 1.
        return (math.exp(lmin) for _ in range(n))
    step = (lmax - lmin) / (n - 1)
    return (math.exp(lmin + step * i) for i in range(n))
# Build the Bradley-Terry model and report its size.
model = BradleyTerry(config.model)
params = sum(map(torch.numel, model.parameters()))
print(f"{params/1e6:.1f}M parameters")
print(model)
# AdamW over every model parameter, using the rates from `config`.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay,
)
def train_step(model, batch, real):
    """Run one optimisation step with the module-level ``optimizer``.

    Args:
        model: module mapping ``batch`` to win probabilities.
        batch: input tensor passed straight to ``model``.
        real: target probabilities for binary cross-entropy.

    Returns:
        The step's loss as a Python float.
    """
    optimizer.zero_grad()
    predicted = model(batch).float()  # BCE is computed in float32
    step_loss = F.binary_cross_entropy(predicted, real)
    step_loss.backward()
    optimizer.step()
    return step_loss.detach().cpu().item()
# Optionally replace train_step with a torch.compile-wrapped version.
if config.compile:
    print("compiling...")
train_step = torch.compile(train_step) if config.compile else train_step
def batch_from_inputs(inputs: list[list[tuple[numpy.ndarray, numpy.ndarray, float]]]):
    """Stack per-ensemble-member sample lists into input and target tensors.

    Args:
        inputs: one sequence of ``(emb1, emb2, rating)`` samples per ensemble
            member. All inner sequences must have the same length, as required
            by ``torch.stack``.

    Returns:
        ``(batch_input, target)``, both moved to ``device``.
        ``batch_input`` stacks each sample's two embeddings, cast to
        ``config.model.dtype``, giving shape
        ``(len(inputs), n_samples, 2, *emb.shape)``.
        ``target`` holds the ratings in a ``torch.Tensor`` of the default
        float dtype, with shape ``(len(inputs), n_samples)``.
    """
    dtype = config.model.dtype

    def _pair(emb1, emb2):
        # torch.Tensor(...) converts to the default float dtype first, then we
        # cast to the model dtype (kept as-is to preserve exact rounding).
        return torch.stack((torch.Tensor(emb1).to(dtype), torch.Tensor(emb2).to(dtype)))

    batch_input = torch.stack([
        torch.stack([_pair(emb1, emb2) for emb1, emb2, _rating in samples])
        for samples in inputs
    ]).to(device)
    target = torch.stack([
        torch.Tensor([rating for _emb1, _emb2, rating in samples])
        for samples in inputs
    ]).to(device)
    return batch_input, target
def evaluate(steps):
    """Compute validation BCE loss on every validation split and log it.

    For each split in ``validations``, the first 128 samples are fed
    identically to every ensemble member. The binary cross-entropy against the
    ratings is recorded under ``results["val_loss"][split_index]``. The results
    dict is written as one JSON line to the module-level ``log`` file; the
    integer split indices become string keys in JSON.

    The model is put in eval mode for the whole pass and restored to train
    mode afterwards, even if evaluation raises.

    Args:
        steps: current training step, recorded in the log entry.

    Returns:
        The results dict ``{"step", "time", "val_loss"}`` that was logged.
    """
    print("evaluating...")
    model.eval()
    results = {"step": steps, "time": time.time(), "val_loss": {}}
    try:
        with torch.no_grad():
            for vset, validation in enumerate(validations):
                batch_input, target = batch_from_inputs([validation[:128] for _ in range(config.model.n_ensemble)])
                result = model(batch_input).float()
                results["val_loss"][vset] = F.binary_cross_entropy(result, target).item()
    finally:
        # Previously model.train() ran inside the loop, so every split after
        # the first was scored with dropout enabled; restore train mode once.
        model.train()
    log.write(json.dumps(results) + "\n")
    return results
def save_ckpt(log, steps):
    """Save optimizer and model state dicts to the checkpoint paths for ``steps``.

    Paths come from ``shared.checkpoint_for(steps)`` as (model, optimizer).
    The ``log`` argument is accepted but not used.
    """
    print("saving...")
    model_path, optim_path = shared.checkpoint_for(steps)
    # Optimizer state is written first, then the model weights.
    optim_state = optimizer.state_dict()
    torch.save(optim_state, optim_path)
    model_state = model.state_dict()
    torch.save(model_state, model_path)
class JSONEncoder(json.JSONEncoder):
    """JSON encoder that also serialises ``torch.dtype`` values as their string form."""

    def default(self, o):
        # Anything that is not a torch dtype goes to the base class, which
        # raises TypeError for unsupported objects.
        if not isinstance(o, torch.dtype):
            return super().default(o)
        return str(o)
import os

logfile = f"logs/log-{time.time()}.jsonl"
# open(..., "w") does not create missing parent directories.
os.makedirs("logs", exist_ok=True)
# `log` is also read as a module-level global by evaluate().
with open(logfile, "w") as log:
    steps = 0
    # First log record: the complete training configuration.
    log.write(JSONEncoder().encode(asdict(config)) + "\n")
    for epoch in range(config.epochs):
        # Either train on each ratings group in turn, or on all groups pooled
        # into a single freshly built list.
        train_groups = trains if config.data_grouped_by_iter else [[sample for group in trains for sample in group]]
        for train in train_groups:
            # One data ordering of `train` per ensemble member.
            data_orders = shared.generate_random_permutations(train, config.model.n_ensemble)
            for bstart in range(0, len(train), config.batch_size):
                batch_input, target = batch_from_inputs([order[bstart:bstart + config.batch_size] for order in data_orders])
                loss = train_step(model, batch_input, target)
                print(steps, loss)
                log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
                # Validate every 10 steps; checkpoint every 250 (before validating).
                if steps % 10 == 0:
                    if steps % 250 == 0:
                        save_ckpt(log, steps)
                    evaluate(steps)
                steps += 1
    save_ckpt(log, steps)
print(logfile)