1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-11-10 22:09:54 +00:00

meme rater model code (documentation "later")

This commit is contained in:
osmarks 2024-04-21 23:50:48 +01:00
parent 0b0261f625
commit 58ce70bb5e
10 changed files with 572 additions and 5 deletions

View File

@ -0,0 +1,55 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
from model import Config, BradleyTerry
import shared
# Active-learning candidate selection: sample random pairs of files, ask every
# ensemble member for a win probability, and keep the pairs on which the
# members disagree the most (highest variance across the ensemble dimension).
batch_size = 128
num_pairs = batch_size * 1024  # a whole number of batches, so each batch is full
device = "cuda"

config = Config(
    d_emb=1152,
    n_hidden=1,
    n_ensemble=16,
    device=device,
    dtype=torch.bfloat16,
    dropout=0.5
)
model = BradleyTerry(config)
modelc, _ = shared.checkpoint_for(2250)
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)

# Each entry of `files` is a (filename, embedding) tuple.
files = shared.fetch_all_files()
variance = {}
pairs = [tuple(random.sample(files, 2)) for _ in range(num_pairs)]

model.eval()
with torch.inference_mode():
    for bstart in tqdm(range(0, len(pairs), batch_size)):
        batch = pairs[bstart:bstart + batch_size]
        filenames = []
        stacked = []
        for (f1, e1), (f2, e2) in batch:
            filenames.append((f1, f2))
            stacked.append(torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))))
        embs = torch.stack(stacked)
        # The same batch is presented to every ensemble member.
        inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
        win_probs = model(inputs)
        batchvar = torch.var(win_probs, dim=0)
        for filename, var in zip(filenames, batchvar):
            # A pair sampled more than once simply overwrites its earlier entry.
            variance[filename] = float(var)

# Highest-variance pairs first; only the top 256 are written out.
top = sorted(variance.items(), key=lambda item: -item[1])
with open("top.json", "w") as f:
    json.dump(top[:256], f)

76
meme-rater/al2.py Normal file
View File

@ -0,0 +1,76 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
from torch.func import functional_call, vmap, grad
from model import Config, BradleyTerry
import shared
# Setup for gradient-norm based pair selection: load a single-member model
# from a checkpoint and expose its parameters/buffers as plain dicts so they
# can be passed to torch.func.functional_call.
steps = 855  # checkpoint step to load
batch_size = 128
num_pairs = batch_size * 1024
device = "cuda"
config = Config(
    d_emb=1152,
    n_hidden=1,
    n_ensemble=1,
    device=device,
    dtype=torch.bfloat16,
    # Config declares a dropout field, so pass it explicitly. 0.0 disables
    # dropout, which keeps the per-sample gradient pass free of randomness.
    dropout=0.0
)
model = BradleyTerry(config)
modelc, _ = shared.checkpoint_for(steps)
model.load_state_dict(torch.load(modelc))
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params/1e6:.1f}M parameters")
print(model)
files = shared.fetch_all_files()
importance = {}
# Detached views of the weights, keyed by parameter/buffer name.
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}
# Per-sample gradient recipe:
# https://pytorch.org/tutorials/intermediate/per_sample_grads.html
def compute_loss(params, buffers, sample, target):
    """Return the binary cross-entropy of the model on one sample.

    `sample` and `target` each get a leading axis of size 1 added before the
    call, so the stateless model invocation sees a batch of one. The model is
    run through `functional_call` with the supplied `params` and `buffers`
    instead of its own registered state.
    """
    predictions = functional_call(model, (params, buffers), (sample.unsqueeze(0),))
    return F.binary_cross_entropy(predictions, target.unsqueeze(0))
# grad() differentiates compute_loss with respect to its first argument
# (params); vmap then maps that over dimension 1 of the sample and target
# tensors while sharing params and buffers across the mapped calls.
ft_compute_grad = grad(compute_loss)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 1, 1))

pairs = [tuple(random.sample(files, 2)) for _ in range(num_pairs)]

for bstart in tqdm(range(0, len(pairs), batch_size)):
    batch = pairs[bstart:bstart + batch_size]
    filenames = [(f1, f2) for ((f1, e1), (f2, e2)) in batch]
    embs = torch.stack([
        torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype)))
        for ((f1, e1), (f2, e2)) in batch
    ])
    inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
    # Instead of ensemble variance, score each pair by the size of the
    # gradient it would produce against a fixed 0.95 target.
    targets = torch.full((1, batch_size), 0.95).to(device)
    grads = ft_compute_sample_grad(params, buffers, inputs, targets)
    total_grad_norms = torch.zeros(batch_size).to(device)
    for v in grads.values():
        # Dimension 0 is the per-sample axis added by vmap; reduce the rest.
        param_dims = tuple(range(1, v.dim()))
        total_grad_norms += torch.linalg.vector_norm(v, dim=param_dims)
    for filename, tg in zip(filenames, total_grad_norms.cpu().numpy()):
        importance[filename] = float(tg)

# Largest summed gradient norm first; keep the top 256 pairs.
top = sorted(importance.items(), key=lambda item: -item[1])
with open("top.json", "w") as f:
    json.dump(top[:256], f)

View File

@ -0,0 +1,14 @@
import sqlite3
import json
import sys
# Load the pairs listed in top.json into the `queue` table, tagged with the
# iteration label given as the first command-line argument.
if len(sys.argv) < 2:
    # Fail with a usage message rather than a bare IndexError.
    sys.exit(f"usage: {sys.argv[0]} ITERATION")
iteration = sys.argv[1]
db = sqlite3.connect("data.sqlite3")
db.row_factory = sqlite3.Row
try:
    with open("top.json", "r") as f:
        # Each entry is [[meme1, meme2], score]; the score is not stored.
        listing = json.load(f)
    db.executemany(
        "INSERT INTO queue VALUES (?, ?, ?)",
        [ (meme1, meme2, iteration) for (meme1, meme2), _score in listing ],
    )
    db.commit()
finally:
    # Release the connection even if reading or inserting fails.
    db.close()

86
meme-rater/eval.py Normal file
View File

@ -0,0 +1,86 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
import torch
from model import Config, BradleyTerry
import shared
# Score every file with the ensemble and sort the files by their median score.
batch_size = 128
device = "cuda"

config = Config(
    d_emb=1152,
    n_hidden=1,
    n_ensemble=16,
    device=device,
    dtype=torch.float32,
    dropout=0.1
)
model = BradleyTerry(config).float()
modelc, _ = shared.checkpoint_for(1500)
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)

# Each entry of `files` is a (filename, embedding) tuple.
files = shared.fetch_all_files()
ratings = {}

model.eval()
with torch.inference_mode():
    for bstart in tqdm(range(0, len(files), batch_size)):
        batch = files[bstart:bstart + batch_size]
        filenames = [name for name, _ in batch]
        embs = torch.stack([torch.Tensor(vector) for _, vector in batch])
        # len(batch) rather than batch_size: the final batch may be short.
        inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
        # Call the ensemble directly: raw per-member scores, no pairing.
        scores = model.ensemble(inputs).float()
        mscores = torch.median(scores, dim=0).values
        for filename, mscore in zip(filenames, mscores):
            ratings[filename] = float(mscore)

# From here on `ratings` is a list of (filename, score), lowest score first.
ratings = sorted(ratings.items(), key=lambda item: item[1])
def percentile(p, n):
    """Return up to `n` consecutive entries of `ratings` starting at fraction `p`.

    The start index is `round(p * len(ratings))`; the slice is shorter than
    `n` when it runs past the end of the list.
    """
    start = round(p * len(ratings))
    stop = start + n
    return ratings[start:stop]
# Number of images shown per percentile group.
N = 25


def render_memeset(p):
    """Return an HTML fragment showing the `N` files at percentile `p`.

    Each image gets a checkbox whose id is `col-<p>-<index>`; the percentile
    itself is hidden inside a collapsed <details> element.
    """
    from html import escape

    filenames = percentile(p, N)
    # Escape the path so quotes, '&' or '<' in a filename cannot break out of
    # the src attribute; names without such characters render unchanged.
    cells = ''.join(
        f'<div><img src="{escape("images/" + name)}" width="30%"><br>'
        f'<input type=checkbox id="{"col-" + str(p) + "-" + str(i)}"></div>'
        for i, (name, _score) in enumerate(filenames)
    )
    return f"""
<div>
<details><summary>Reveal Memeset</summary>{p}</details>
{cells}
</div>
"""
# Build eval.html: one group of images per percentile, in shuffled order,
# followed by a script that counts the ticked checkboxes per percentile.
probs = [0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 0.95, 0.98, 0.99]
random.shuffle(probs)
sections = [render_memeset(p) for p in probs]
script = """
<script>
const computeCounts = () => {
const counts = {}
for (const x of document.querySelectorAll("input[type=checkbox]")) {
const [_, percentile, index] = x.getAttribute("id").split("-")
counts[percentile] ??= 0
if (x.checked) counts[percentile] += 1
}
console.log(counts)
}
</script>
"""
buf = "".join(["""<!DOCTYPE html>""", *sections, script])
with open("eval.html", "w") as f:
    f.write(buf)

View File

@ -0,0 +1,33 @@
import matplotlib.pyplot as plt
import json
# Hand-recorded results: percentile (as a string) -> number of memes kept.
data_json = '{"0.95":22,"0.75":21,"0.5":15,"0.98":23,"0.25":3,"0.05":0,"0.99":24,"0.1":2,"0.01":0,"0.02":0}'
data = json.loads(data_json)

# Turn the string keys into floats and order the points by percentile.
sorted_data = sorted((float(key), value) for key, value in data.items())
keys, values = zip(*sorted_data)

plt.plot(keys, values)
# One tick per measured percentile.
plt.xticks(keys, rotation=45)
plt.xlabel('Percentile')
plt.ylabel('Memes Kept')
plt.title('Final Model Evaluation')
plt.tight_layout()
plt.show()

52
meme-rater/model.py Normal file
View File

@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from functools import partial
import math
@dataclass
class Config:
    """Hyperparameters shared by Model, Ensemble and BradleyTerry."""
    # Input and hidden width of every Linear layer.
    d_emb: int
    # Number of hidden Linear layers (each paired with one Dropout).
    n_hidden: int
    # Number of independent Model instances in an Ensemble.
    n_ensemble: int
    # Device and dtype the Linear layers are created with.
    device: str
    dtype: torch.dtype
    # Dropout probability. Defaults to 0.0 (no dropout) so callers that
    # construct a Config without it keep working.
    dropout: float = 0.0
class Model(nn.Module):
    """Small MLP that maps an embedding to a single scalar score."""

    def __init__(self, config):
        super().__init__()
        linear_kwargs = {"dtype": config.dtype, "device": config.device}
        hidden_layers = []
        for _ in range(config.n_hidden):
            hidden_layers.append(nn.Linear(config.d_emb, config.d_emb, **linear_kwargs))
        self.hidden = nn.ModuleList(hidden_layers)
        self.dropout = nn.ModuleList([nn.Dropout(p=config.dropout) for _ in range(config.n_hidden)])
        self.output = nn.Linear(config.d_emb, 1, **linear_kwargs)

    def forward(self, embs):
        """Apply dropout -> linear -> SiLU per hidden layer, then the output layer."""
        x = embs
        for index, layer in enumerate(self.hidden):
            x = self.dropout[index](x)
            x = layer(x)
            x = F.silu(x)
        return self.output(x)
class Ensemble(nn.Module):
    """A set of independently parameterised Model instances."""

    def __init__(self, config):
        super().__init__()
        members = [Model(config) for _ in range(config.n_ensemble)]
        self.models = nn.ModuleList(members)

    def forward(self, embs):
        """Feed `embs[i]` to member `i` and stack the scores.

        Input is indexed as (model, batch, ...); the trailing size-1 output
        dimension is squeezed away, giving (model, batch).
        """
        outputs = []
        for index, member in enumerate(self.models):
            outputs.append(member(embs[index]))
        return torch.stack(outputs).squeeze(-1)
class BradleyTerry(nn.Module):
    """Pairwise comparison model: P(first wins) = sigmoid(score1 - score2)."""

    def __init__(self, config):
        super().__init__()
        self.ensemble = Ensemble(config)

    def forward(self, embs):
        """Return win probabilities for the first item of each pair.

        `embs` is indexed as (model, batch, 2, d_emb); index 0 and 1 on the
        third axis select the two items being compared. The result is indexed
        as (model, batch).
        """
        first = embs[:, :, 0]
        second = embs[:, :, 1]
        scores1 = self.ensemble(first).float()
        scores2 = self.ensemble(second).float()
        margin = scores1 - scores2
        return torch.sigmoid(margin)

View File

@ -11,15 +11,22 @@ routes = web.RouteTableDef()
async def get_pair(db):
    """Pick the next pair of memes to rate.

    A row from the `queue` table takes priority; otherwise two distinct
    filenames are drawn at random and sorted. Pairs that already have a row
    in `ratings` are skipped.

    Returns:
        (meme1, meme2, iteration), where iteration is None for random pairs.
    """
    while True:
        csr = await db.execute("SELECT * FROM queue")
        row = await csr.fetchone()
        await csr.close()
        iteration = None
        if row:
            m1, m2, iteration = row
        else:
            filenames = [ x[0] for x in await db.execute_fetchall("SELECT filename FROM files", ()) ]
            m1, m2 = tuple(sorted(random.sample(filenames, 2)))
        csr = await db.execute("SELECT 1 FROM ratings WHERE meme1 = ? AND meme2 = ?", (m1, m2))
        already_rated = await csr.fetchone()
        await csr.close()
        if not already_rated:
            return m1, m2, iteration
        if row:
            # The queued pair has already been rated. Remove it, otherwise the
            # same row is fetched on every pass and this loop never ends.
            await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (m1, m2))
            await db.commit()
@routes.get("/")
async def index(request):
meme1, meme2 = await get_pair(request.app["db"])
meme1, meme2, iteration = await get_pair(request.app["db"])
return web.Response(text=f"""
<!DOCTYPE html>
<html>
@ -46,6 +53,7 @@ async def index(request):
<input type="hidden" name="meme1" value="{meme1}">
<input type="hidden" name="meme2" value="{meme2}">
<input type="hidden" name="iteration" value="{str(iteration or 0)}">
<input type="submit" value="Submit">
<div class="memes">
<img src="/memes/{meme1}" id="meme1">
@ -81,8 +89,10 @@ async def rate(request):
post = await request.post()
meme1 = post["meme1"]
meme2 = post["meme2"]
iteration = post["iteration"]
rating = post["rating"]
await db.execute("INSERT INTO ratings (meme1, meme2, rating) VALUES (?, ?, ?)", (meme1, meme2, rating))
await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration) VALUES (?, ?, ?, ?)", (meme1, meme2, rating, iteration))
await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2))
await db.commit()
return web.HTTPFound("/")
@ -93,8 +103,15 @@ CREATE TABLE IF NOT EXISTS ratings (
meme1 TEXT NOT NULL,
meme2 TEXT NOT NULL,
rating TEXT NOT NULL,
iteration TEXT,
UNIQUE (meme1, meme2)
);
CREATE TABLE IF NOT EXISTS queue (
meme1 TEXT NOT NULL,
meme2 TEXT NOT NULL,
iteration TEXT NOT NULL,
UNIQUE (meme1, meme2, iteration)
);
""")
app.router.add_routes(routes)
app.router.add_static("/memes/", "./images")

52
meme-rater/run_graph.py Normal file
View File

@ -0,0 +1,52 @@
# claude-3
import json
import matplotlib.pyplot as plt
def _rolling_mean(values, window):
    """Return the means of every full-length window over `values`."""
    return [sum(values[i:i+window])/window for i in range(len(values)-window+1)]


# Read every JSON line of the training log.
with open('log.jsonl', 'r') as file:
    data = [json.loads(line) for line in file]

# Training entries carry "loss"; evaluation entries carry "val_loss".
train_entries = [entry for entry in data if "loss" in entry]
steps = [entry['step'] for entry in train_entries]
loss = [entry['loss'] for entry in train_entries]
val_entries = [entry for entry in data if 'val_loss' in entry]
val_loss_data = [entry['val_loss'] for entry in val_entries]
val_steps = [entry['step'] for entry in val_entries]

# Regroup the per-evaluation dicts into one series per validation key.
val_loss_series = {}
for val_loss in val_loss_data:
    for key, value in val_loss.items():
        val_loss_series.setdefault(key, []).append(value)

window_size = 50
rolling_avg = _rolling_mean(loss, window_size)
# A window ending at index i is plotted at step i.
rolling_steps = steps[window_size-1:]
val_rolling_avgs = {key: _rolling_mean(series, window_size) for key, series in val_loss_series.items()}

print([(name, min(series)) for name, series in val_loss_series.items()])

plt.figure(figsize=(10, 6))
plt.plot(rolling_steps, rolling_avg, label='Rolling Average (Loss)')
for key in val_loss_series:
    plt.plot(val_steps[window_size-1:], val_rolling_avgs[key], label=f'Rolling Average (Validation Loss {key})')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Loss and Validation Loss vs. Steps')
plt.legend()
plt.grid(True)
plt.show()

53
meme-rater/shared.py Normal file
View File

@ -0,0 +1,53 @@
import sqlite3
import hashlib
from collections import defaultdict
import numpy
import random
# Module-level SQLite connection, opened when this module is imported and
# used by the fetch_* helpers below.
db = sqlite3.connect("data.sqlite3")
db.row_factory = sqlite3.Row
# Target share of rating pairs routed to validation; is_val_set compares each
# meme's hash against half of this value.
val_fraction = 0.2
def is_val_set(meme1, meme2):
    """Return True when a rating pair belongs to the validation split.

    A meme counts as "validation" when the first byte of its SHA-256 digest,
    scaled by 1/255, is below `val_fraction / 2`. A pair is validation when
    either meme is. As the original author notes, this is not strictly
    correct but good enough.
    """
    threshold = val_fraction / 2
    for meme in (meme1, meme2):
        first_byte = hashlib.sha256(meme.encode("utf-8")).digest()[0]
        if first_byte / 255 < threshold:
            return True
    return False
def fetch_embedding(filename):
    """Return the stored embedding of `filename` as a float16 numpy array.

    Raises:
        KeyError: if the `files` table has no row for `filename`.
    """
    csr = db.execute("SELECT embedding_vector FROM files WHERE filename = ?", (filename,))
    try:
        row = csr.fetchone()
    finally:
        csr.close()
    if row is None:
        # Name the missing file instead of failing on None[0].
        raise KeyError(f"no embedding stored for {filename!r}")
    # frombuffer returns a read-only view over the blob; copy it so callers
    # get a writable array (the original note says PyTorch complains otherwise).
    return numpy.frombuffer(row[0], dtype="float16").copy()
def map_rating(rating, uncertainty=0.05):
    """Convert a stored rating string into a soft training target.

    "1" (meme 1 is better) maps to `1 - uncertainty`, "2" maps to
    `uncertainty`.

    Raises:
        ValueError: for any other rating value.
    """
    if rating == "1":
        return 1 - uncertainty
    if rating == "2":
        return uncertainty
    raise ValueError("invalid rating, please fix")
def fetch_ratings():
    """Load all ratings as (embedding1, embedding2, target) samples.

    Samples are split into train and validation by `is_val_set` and grouped
    by iteration number (a missing iteration counts as 0).

    Returns:
        (trains, validations): two lists of sample lists, each ordered by
        ascending iteration.
    """
    trains = defaultdict(list)
    validations = defaultdict(list)
    csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
    for meme1, meme2, rating, iteration in csr.fetchall():
        bucket = validations if is_val_set(meme1, meme2) else trains
        group = bucket[int(iteration or "0")]
        sample = (fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating))
        group.append(sample)
    csr.close()
    ordered_trains = [samples for _, samples in sorted(trains.items())]
    ordered_validations = [samples for _, samples in sorted(validations.items())]
    return ordered_trains, ordered_validations
def generate_random_permutations(x, n):
    """Return `n` independently shuffled copies of `x`.

    Each permutation is produced on a fresh copy, so the caller's sequence is
    left in its original order.
    """
    out = []
    for _ in range(n):
        perm = x.copy()
        random.shuffle(perm)
        out.append(perm)
    return out
def fetch_all_files():
    """Return every row of `files` as (filename, float16 embedding array)."""
    csr = db.execute("SELECT filename, embedding_vector FROM files")
    files = []
    for row in csr.fetchall():
        filename, blob = row[0], row[1]
        # Copy so the array does not alias the read-only blob buffer.
        files.append((filename, numpy.frombuffer(blob, dtype="float16").copy()))
    csr.close()
    return files
def checkpoint_for(steps):
    """Return the (model, optimizer) checkpoint paths for a given step count."""
    model_path, optim_path = (f"./ckpt/{kind}-{steps}.pt" for kind in ("model", "optim"))
    return model_path, optim_path

129
meme-rater/train.py Normal file
View File

@ -0,0 +1,129 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
import math
from dataclasses import dataclass, asdict
from model import Config as ModelConfig, BradleyTerry
import shared
# Load the rating samples, grouped by iteration, and print the train and
# validation counts for each group.
trains, validations = shared.fetch_ratings()
for n_train, n_validation in zip(map(len, trains), map(len, validations)):
    print(n_train, n_validation)
device = "cuda"
@dataclass
class TrainConfig:
    """Settings for one training run; written as the first line of the log file."""
    # Architecture settings passed to BradleyTerry.
    model: ModelConfig
    # AdamW learning rate and weight decay.
    lr: float
    weight_decay: float
    # Number of samples each ensemble member sees per optimizer step.
    batch_size: int
    # Number of passes over the training data.
    epochs: int
    # When true, train_step is wrapped with torch.compile.
    compile: bool
    # When true, iteration groups are trained one after another; otherwise
    # all groups are flattened into a single list.
    data_grouped_by_iter: bool
# Architecture settings for the ensemble being trained.
model_config = ModelConfig(
    d_emb=1152,
    n_hidden=1,
    n_ensemble=16,
    device=device,
    dtype=torch.float32,
    dropout=0.1
)

# Optimisation settings for this run.
config = TrainConfig(
    model=model_config,
    lr=3e-4,
    weight_decay=0.2,
    batch_size=1,
    epochs=5,
    compile=False,
    data_grouped_by_iter=False
)
def exprange(min, max, n):
    """Return a generator of `n` values spaced evenly in log space.

    The values run from `min` to `max` inclusive. With `n == 1` the single
    value is `min`; with `n <= 0` the generator is empty.
    """
    lmin, lmax = math.log(min), math.log(max)
    # A single point has no spacing; guard the division by (n - 1).
    step = (lmax - lmin) / (n - 1) if n > 1 else 0.0
    return (math.exp(lmin + step * i) for i in range(n))
# Build the model from the run configuration and report its size.
model = BradleyTerry(config.model)
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
# One AdamW optimizer over all parameters; train_step and save_ckpt use this global.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
def train_step(model, batch, real):
    """Run one optimizer step and return the loss as a Python float.

    `batch` is the model input and `real` the target win probabilities; the
    loss is the binary cross-entropy between the model output and `real`.
    Uses the module-level `optimizer`.
    """
    optimizer.zero_grad()
    predicted = model(batch).float()
    batch_loss = F.binary_cross_entropy(predicted, real)
    batch_loss.backward()
    optimizer.step()
    return batch_loss.detach().cpu().item()


# Optionally swap in a torch.compile-wrapped version of the step.
if config.compile:
    print("compiling...")
    train_step = torch.compile(train_step)
def batch_from_inputs(inputs: list[tuple[numpy.ndarray, numpy.ndarray, float]]):
    """Stack per-member sample lists into model input and target tensors.

    `inputs` is iterated as one sequence of (emb1, emb2, rating) samples per
    ensemble member. Returns `(batch_input, target)` on `device`: the input
    stacks members, then samples, then the two embeddings of each pair; the
    target stacks members, then the ratings of their samples.
    """
    dtype = config.model.dtype
    member_pairs = []
    member_targets = []
    for samples in inputs:
        pairs = [
            torch.stack((torch.Tensor(emb1).to(dtype), torch.Tensor(emb2).to(dtype)))
            for emb1, emb2, _ in samples
        ]
        member_pairs.append(torch.stack(pairs))
        member_targets.append(torch.Tensor([rating for _, _, rating in samples]))
    batch_input = torch.stack(member_pairs).to(device)
    target = torch.stack(member_targets).to(device)
    return batch_input, target
def evaluate(steps):
    """Compute the validation loss of every validation set and log it.

    For each set, its first 128 samples are fed to every ensemble member.
    One JSON line of the form {"step", "time", "val_loss": {set_index: loss}}
    is appended to the module-level `log` file.
    """
    print("evaluating...")
    model.eval()
    results = {"step": steps, "time": time.time(), "val_loss": {}}
    try:
        with torch.no_grad():
            for vset, validation in enumerate(validations):
                batch_input, target = batch_from_inputs([ validation[:128] for _ in range(config.model.n_ensemble) ])
                result = model(batch_input).float()
                val_loss = F.binary_cross_entropy(result, target).detach().cpu().item()
                results["val_loss"][vset] = val_loss
    finally:
        # Return to training mode only after ALL sets are scored. Switching
        # back inside the loop would evaluate later sets with dropout active.
        model.train()
    log.write(json.dumps(results) + "\n")
def save_ckpt(log, steps):
    """Write the optimizer and model state dicts for `steps` to the checkpoint paths.

    `log` is accepted for call-site compatibility but not used.
    """
    print("saving...")
    model_path, optim_path = shared.checkpoint_for(steps)
    for owner, path in ((optimizer, optim_path), (model, model_path)):
        torch.save(owner.state_dict(), path)
class JSONEncoder(json.JSONEncoder):
    """JSON encoder that serialises torch dtypes as their string form."""

    def default(self, o):
        # Anything that is not a dtype falls through to the base class,
        # which raises TypeError for unsupported objects.
        if not isinstance(o, torch.dtype):
            return super().default(o)
        return str(o)
# Main training loop. The log file gets the run config as its first line,
# then one JSON line per step plus one per evaluation.
logfile = f"logs/log-{time.time()}.jsonl"
with open(logfile, "w") as log:
    steps = 0
    log.write(JSONEncoder().encode(asdict(config)) + "\n")
    for epoch in range(config.epochs):
        if config.data_grouped_by_iter:
            groups = trains
        else:
            # Single group holding every sample from every iteration.
            groups = [[ sample for trainss in trains for sample in trainss ]]
        for train in groups:
            # Each ensemble member sees the data in its own random order.
            data_orders = shared.generate_random_permutations(train, config.model.n_ensemble)
            for bstart in range(0, len(train), config.batch_size):
                bend = bstart + config.batch_size
                batch_input, target = batch_from_inputs([ order[bstart:bend] for order in data_orders ])
                loss = train_step(model, batch_input, target)
                print(steps, loss)
                record = {"loss": loss, "step": steps, "time": time.time()}
                log.write(json.dumps(record) + "\n")
                # Evaluate every 10 steps; checkpoint every 250.
                if steps % 10 == 0:
                    if steps % 250 == 0:
                        save_ckpt(log, steps)
                    loss = evaluate(steps)
                steps += 1
    # Final checkpoint once all epochs are done.
    save_ckpt(log, steps)
print(logfile)