1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-04-28 13:33:11 +00:00

repurpose meme rater

This commit is contained in:
osmarks 2025-01-18 07:19:08 +00:00
parent 163dceca4b
commit 0a542ef579
9 changed files with 183 additions and 50 deletions

View File

@ -7,6 +7,7 @@ import numpy
import json
import time
from tqdm import tqdm
import sys
from model import Config, BradleyTerry
import shared
@ -20,11 +21,12 @@ config = Config(
n_hidden=1,
n_ensemble=16,
device=device,
dtype=torch.bfloat16,
dropout=0.5
dtype=torch.float32,
output_channels=3,
dropout=0.1
)
model = BradleyTerry(config)
modelc, _ = shared.checkpoint_for(2250)
modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
@ -45,11 +47,13 @@ with torch.inference_mode():
embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
win_probs = model(inputs)
#print(win_probs, win_probs.shape)
#print(win_probs.shape)
batchvar = torch.var(win_probs, dim=0)
batchvar = torch.var(win_probs, dim=0).max(-1).values
#print(batchvar, batchvar.shape)
for filename, var in zip(filenames, batchvar):
variance[filename] = float(var)
top = sorted(variance.items(), key=lambda x: -x[1])
with open("top.json", "w") as f:
json.dump(top[:256], f)
json.dump(top[:100], f)

View File

@ -0,0 +1,64 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
import sys
from collections import defaultdict
from model import Config, BradleyTerry
import shared
batch_size = 128
num_pairs = batch_size * 1024
device = "cuda"
config = Config(
d_emb=1152,
n_hidden=1,
n_ensemble=16,
device=device,
dtype=torch.float32,
output_channels=3,
dropout=0.1
)
model = BradleyTerry(config)
modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
files = shared.fetch_all_files()
results = {}
model.eval()
with torch.inference_mode():
for bstart in tqdm(range(0, len(files), batch_size)):
batch = files[bstart:bstart + batch_size]
filenames = [ f1 for f1, e1 in batch ]
embs = torch.stack([ torch.Tensor(e1).to(config.dtype) for f1, e1 in batch ])
inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
scores = model.ensemble(inputs).median(dim=0).values.cpu().numpy()
#print(batchvar, batchvar.shape)
for filename, score in zip(filenames, scores):
results[filename] = score
channel = int(sys.argv[2])
percentile = float(sys.argv[3])
output_pairs = int(sys.argv[4])
mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
select_from = top[:int(len(top) * percentile)]
out = []
for _ in range(output_pairs):
# dummy score for compatibility with existing code
out.append(((random.choice(select_from)[0], random.choice(select_from)[0]), 0))
with open("top.json", "w") as f:
json.dump(out, f)

View File

@ -8,11 +8,11 @@ import json
import time
from tqdm import tqdm
from torch.func import functional_call, vmap, grad
import sys
from model import Config, BradleyTerry
import shared
steps = 855
batch_size = 128
num_pairs = batch_size * 1024
device = "cuda"
@ -22,10 +22,12 @@ config = Config(
n_hidden=1,
n_ensemble=1,
device=device,
dtype=torch.bfloat16
dtype=torch.float32,
output_channels=3,
dropout=0.1
)
model = BradleyTerry(config)
modelc, _ = shared.checkpoint_for(855)
modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
@ -61,7 +63,7 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
#win_probs = model(inputs)
# TODO gradients
# don't take variance: do backwards pass and compute gradient norm
grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size), 0.95).to(device))
grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size, config.output_channels), 0.95).to(device))
total_grad_norms = torch.zeros(batch_size).to(device)
for k, v in grads.items():
param_dims = tuple(range(1, len(v.shape)))
@ -73,4 +75,4 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
top = sorted(importance.items(), key=lambda x: -x[1])
with open("top.json", "w") as f:
json.dump(top[:256], f)
json.dump(top[:256], f)

View File

@ -0,0 +1,20 @@
import jsonlines
import sqlite3
import numpy as np
import shared
shared.db.executescript("""
CREATE TABLE IF NOT EXISTS files (
filename TEXT NOT NULL,
title TEXT NOT NULL,
link TEXT NOT NULL,
embedding BLOB NOT NULL,
UNIQUE (filename)
);
""")
with jsonlines.open("sample.jsonl") as reader:
for obj in reader:
shared.db.execute("INSERT INTO files (filename, title, link, embedding) VALUES (?, ?, ?, ?)", (obj["metadata"]["final_url"], obj["title"], f"https://reddit.com/r/{obj['subreddit']}/comments/{obj['id']}", sqlite3.Binary(np.array(obj["embedding"], dtype=np.float16).tobytes())))
shared.db.commit()

View File

@ -13,13 +13,14 @@ class Config:
device: str
dtype: torch.dtype
dropout: float
output_channels: int
class Model(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ])
self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ])
self.output = nn.Linear(config.d_emb, 1, dtype=config.dtype, device=config.device)
self.output = nn.Linear(config.d_emb, config.output_channels, dtype=config.dtype, device=config.device)
def forward(self, embs):
x = embs
@ -34,8 +35,7 @@ class Ensemble(nn.Module):
# model batch
def forward(self, embs):
xs = torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
return xs.squeeze(-1)
return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
class BradleyTerry(nn.Module):
def __init__(self, config):
@ -49,4 +49,4 @@ class BradleyTerry(nn.Module):
#print(scores1, scores2)
probs = torch.sigmoid(scores1 - scores2) # model batch
#print(probs)
return probs
return probs

View File

@ -30,9 +30,10 @@ async def index(request):
return web.Response(text=f"""
<!DOCTYPE html>
<html>
<title>Data Labelling Frontend (Not Evil)</title>
<style>
.memes img {{
width: 45%;
width: 45%;
}}
@media (max-width: 768px) {{
@ -46,38 +47,71 @@ async def index(request):
}}
</style>
<body>
<h1>Meme Rating</h1>
<h1>Data Labelling Frontend (Not Evil)</h1>
<form action="/rate" method="POST">
<input type="radio" name="rating" value="1" id="rating1"> <label for="rating1">Meme 1 is better</label>
<input type="radio" name="rating" value="2" id="rating2"> <label for="rating2">Meme 2 is better</label>
<table>
<tr>
<td><input type="radio" name="rating-useful" value="1" id="rq1"> <label for="rq1">LHS is better (useful)</label></td>
<td><input type="radio" name="rating-useful" value="eq" id="rqe"> <label for="rqe">Tie</label></td>
<td><input type="radio" name="rating-useful" value="2" id="rq2"> <label for="rq2">RHS is better (useful)</label></td>
</tr>
<tr>
<td><input type="radio" name="rating-meme" value="1" id="rm1"> <label for="rm1">LHS is better (memetically)</label></td>
<td><input type="radio" name="rating-meme" value="eq" id="rme"> <label for="rme">Tie</label></td>
<td><input type="radio" name="rating-meme" value="2" id="rm2"> <label for="rm2">RHS is better (memetically)</label></td>
</tr>
<tr>
<td><input type="radio" name="rating-aesthetic" value="1" id="ra1"> <label for="ra1">LHS is better (aesthetically)</label></td>
<td><input type="radio" name="rating-aesthetic" value="eq" id="rae"> <label for="rae">Tie</label></td>
<td><input type="radio" name="rating-aesthetic" value="2" id="ra2"> <label for="ra2">RHS is better (aesthetically)</label></td>
</td>
</table>
<input type="hidden" name="meme1" value="{meme1}">
<input type="hidden" name="meme2" value="{meme2}">
<input type="hidden" name="iteration" value="{str(iteration or 0)}">
<input type="submit" value="Submit">
<div class="memes">
<img src="/memes/{meme1}" id="meme1">
<img src="/memes/{meme2}" id="meme2">
<img src="{meme1}" id="meme1">
<img src="{meme2}" id="meme2">
</div>
</form>
<script>
document.addEventListener("keypress", function(event) {{
if (event.key === "1") {{
document.querySelector("input[name='rating'][value='1']").checked = true
document.querySelector("form").submit()
}} else if (event.key === "2") {{
document.querySelector("input[name='rating'][value='2']").checked = true
const commitIfReady = () => {{
if (document.querySelector("input[name='rating-useful']:checked") && document.querySelector("input[name='rating-meme']:checked") && document.querySelector("input[name='rating-aesthetic']:checked")) {{
document.querySelector("form").submit()
}}
}}
document.addEventListener("keypress", function(event) {{
if (event.key === "q") {{
document.querySelector("input[name='rating-useful'][value='1']").checked = true
commitIfReady()
}} else if (event.key === "w") {{
document.querySelector("input[name='rating-useful'][value='eq']").checked = true
commitIfReady()
}} else if (event.key === "e") {{
document.querySelector("input[name='rating-useful'][value='2']").checked = true
commitIfReady()
}} else if (event.key === "a") {{
document.querySelector("input[name='rating-meme'][value='1']").checked = true
commitIfReady()
}} else if (event.key === "s") {{
document.querySelector("input[name='rating-meme'][value='eq']").checked = true
commitIfReady()
}} else if (event.key === "d") {{
document.querySelector("input[name='rating-meme'][value='2']").checked = true
commitIfReady()
}} else if (event.key === "z") {{
document.querySelector("input[name='rating-aesthetic'][value='1']").checked = true
commitIfReady()
}} else if (event.key === "x") {{
document.querySelector("input[name='rating-aesthetic'][value='eq']").checked = true
commitIfReady()
}} else if (event.key === "c") {{
document.querySelector("input[name='rating-aesthetic'][value='2']").checked = true
commitIfReady()
}}
}});
document.querySelector("#meme1").addEventListener("click", function(event) {{
document.querySelector("input[name='rating'][value='1']").checked = true
document.querySelector("form").submit()
}})
document.querySelector("#meme2").addEventListener("click", function(event) {{
document.querySelector("input[name='rating'][value='2']").checked = true
document.querySelector("form").submit()
}})
</script>
</body>
</html>
@ -90,8 +124,8 @@ async def rate(request):
meme1 = post["meme1"]
meme2 = post["meme2"]
iteration = post["iteration"]
rating = post["rating"]
await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration) VALUES (?, ?, ?, ?)", (meme1, meme2, rating, iteration))
rating = post["rating-useful"] + "," + post["rating-meme"] + "," + post["rating-aesthetic"]
await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration, ip) VALUES (?, ?, ?, ?, ?)", (meme1, meme2, rating, iteration, request.remote))
await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2))
await db.commit()
return web.HTTPFound("/")
@ -104,6 +138,7 @@ CREATE TABLE IF NOT EXISTS ratings (
meme2 TEXT NOT NULL,
rating TEXT NOT NULL,
iteration TEXT,
ip TEXT,
UNIQUE (meme1, meme2)
);
CREATE TABLE IF NOT EXISTS queue (

View File

@ -2,10 +2,11 @@
import json
import matplotlib.pyplot as plt
import sys
# Read data from log.jsonl
data = []
with open('log.jsonl', 'r') as file:
with open(sys.argv[1], 'r') as file:
for line in file:
data.append(json.loads(line))

View File

@ -3,6 +3,7 @@ import hashlib
from collections import defaultdict
import numpy
import random
import numpy as np
db = sqlite3.connect("data.sqlite3")
db.row_factory = sqlite3.Row
@ -20,19 +21,24 @@ def fetch_embedding(filename):
return x.copy() # PyTorch complains otherwise due to bad
def map_rating(rating, uncertainty=0.05):
match rating:
case "1": # meme 1 is better
return 1 - uncertainty
case "2":
return uncertainty
case _: raise ValueError("invalid rating, please fix")
def map_one(rating):
match rating:
case "1": # meme 1 is better
return 1 - uncertainty
case "2":
return uncertainty
case "eq":
return 0.5
case _: raise ValueError("invalid rating, please fix")
return np.array([map_one(r) for r in rating.split(",")])
def fetch_ratings():
trains = defaultdict(list)
validations = defaultdict(list)
csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
for meme1, meme2, rating, iteration in csr.fetchall():
(validations if is_val_set(meme1, meme2) else trains)[int(iteration or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
(validations if is_val_set(meme1, meme2) else trains)[int((iteration and iteration.split("-")[0]) or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
csr.close()
return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))
@ -50,4 +56,4 @@ def fetch_all_files():
return x
def checkpoint_for(steps):
return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt"
return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt"

View File

@ -36,7 +36,8 @@ config = TrainConfig(
n_ensemble=16,
device=device,
dtype=torch.float32,
dropout=0.1
dropout=0.1,
output_channels=3
),
lr=3e-4,
weight_decay=0.2,
@ -72,12 +73,12 @@ if config.compile:
print("compiling...")
train_step = torch.compile(train_step)
def batch_from_inputs(inputs: list[tuple[numpy.ndarray, numpy.ndarray, float]]):
def batch_from_inputs(inputs: list[list[tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]]]):
batch_input = torch.stack([
torch.stack([ torch.stack((torch.Tensor(emb1).to(config.model.dtype), torch.Tensor(emb2).to(config.model.dtype))) for emb1, emb2, rating in input ])
for input in inputs
]).to(device)
target = torch.stack([ torch.Tensor([ rating for emb1, emb2, rating in input ]) for input in inputs ]).to(device)
target = torch.stack([ torch.Tensor(numpy.array([ rating for emb1, emb2, rating in input ])).to(config.model.dtype) for input in inputs ]).to(device)
return batch_input, target
def evaluate(steps):
@ -118,7 +119,7 @@ with open(logfile, "w") as log:
print(steps, loss)
log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
if steps % 10 == 0:
if steps % 250 == 0: save_ckpt(log, steps)
if steps % 100 == 0: save_ckpt(log, steps)
loss = evaluate(steps)
#print(loss)
#best = min(loss, best)
@ -126,4 +127,4 @@ with open(logfile, "w") as log:
save_ckpt(log, steps)
print(logfile)
print(logfile)