1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-04-30 22:34:06 +00:00

repurpose meme rater

This commit is contained in:
osmarks 2025-01-18 07:19:08 +00:00
parent 163dceca4b
commit 0a542ef579
9 changed files with 183 additions and 50 deletions

View File

@ -7,6 +7,7 @@ import numpy
import json import json
import time import time
from tqdm import tqdm from tqdm import tqdm
import sys
from model import Config, BradleyTerry from model import Config, BradleyTerry
import shared import shared
@ -20,11 +21,12 @@ config = Config(
n_hidden=1, n_hidden=1,
n_ensemble=16, n_ensemble=16,
device=device, device=device,
dtype=torch.bfloat16, dtype=torch.float32,
dropout=0.5 output_channels=3,
dropout=0.1
) )
model = BradleyTerry(config) model = BradleyTerry(config)
modelc, _ = shared.checkpoint_for(2250) modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
model.load_state_dict(torch.load(modelc)) model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters()) params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters") print(f"{params/1e6:.1f}M parameters")
@ -45,11 +47,13 @@ with torch.inference_mode():
embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ]) embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device) inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
win_probs = model(inputs) win_probs = model(inputs)
#print(win_probs, win_probs.shape)
#print(win_probs.shape) #print(win_probs.shape)
batchvar = torch.var(win_probs, dim=0) batchvar = torch.var(win_probs, dim=0).max(-1).values
#print(batchvar, batchvar.shape)
for filename, var in zip(filenames, batchvar): for filename, var in zip(filenames, batchvar):
variance[filename] = float(var) variance[filename] = float(var)
top = sorted(variance.items(), key=lambda x: -x[1]) top = sorted(variance.items(), key=lambda x: -x[1])
with open("top.json", "w") as f: with open("top.json", "w") as f:
json.dump(top[:256], f) json.dump(top[:100], f)

View File

@ -0,0 +1,64 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
import sys
from collections import defaultdict
from model import Config, BradleyTerry
import shared
# Scores every stored file with a trained ensemble, keeps the files that are
# above average on every output channel, ranks them by one chosen channel and
# writes random pairs drawn from the best-ranked fraction to "top.json".
#
# Command line: <checkpoint-step> <channel> <top-fraction> <output-pairs>

# Parse and validate all arguments before the expensive inference pass so
# that a typo fails immediately instead of after scoring every file.
if len(sys.argv) < 5:
    sys.exit(f"usage: {sys.argv[0]} <checkpoint-step> <channel> <top-fraction> <output-pairs>")
checkpoint_step = int(sys.argv[1])
channel = int(sys.argv[2])
percentile = float(sys.argv[3])  # fraction of the ranked list to sample from
output_pairs = int(sys.argv[4])

batch_size = 128
num_pairs = batch_size * 1024
device = "cuda"

config = Config(
    d_emb=1152,
    n_hidden=1,
    n_ensemble=16,
    device=device,
    dtype=torch.float32,
    output_channels=3,
    dropout=0.1
)
if not -config.output_channels <= channel < config.output_channels:
    sys.exit(f"channel must index one of {config.output_channels} output channels, got {channel}")

model = BradleyTerry(config)
modelc, _ = shared.checkpoint_for(checkpoint_step)
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)

# Each entry is unpacked below as a (filename, embedding) pair.
files = shared.fetch_all_files()
results = {}

model.eval()
with torch.inference_mode():
    for bstart in tqdm(range(0, len(files), batch_size)):
        batch = files[bstart:bstart + batch_size]
        filenames = [f1 for f1, e1 in batch]
        embs = torch.stack([torch.Tensor(e1).to(config.dtype) for f1, e1 in batch])
        # Every ensemble member sees the same batch; len(batch) rather than
        # batch_size keeps the final, shorter batch valid.
        inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
        # Median over dim 0 (presumably the ensemble-member axis) gives one
        # score vector per file.
        scores = model.ensemble(inputs).median(dim=0).values.cpu().numpy()
        for filename, score in zip(filenames, scores):
            results[filename] = score

if not results:
    sys.exit("no files to score")

# Per-channel mean (axis=0): each score is a vector with one entry per output
# channel, so the "above average on every channel" filter has to compare each
# channel against its own mean rather than one scalar pooled over channels.
mean_scores = numpy.mean(numpy.stack(list(results.values())), axis=0)
top = sorted(
    ((filename, score) for filename, score in results.items() if (score > mean_scores).all()),
    key=lambda x: x[1][channel],
    reverse=True,
)
select_from = top[:int(len(top) * percentile)]
candidates = [filename for filename, _ in select_from]
if len(candidates) < 2:
    sys.exit(f"only {len(candidates)} candidate(s) selected; need at least 2 to form a pair")

out = []
for _ in range(output_pairs):
    # random.sample draws two different candidates, so a file is never
    # paired with itself.
    first, second = random.sample(candidates, 2)
    # dummy score for compatibility with existing code
    out.append(((first, second), 0))

with open("top.json", "w") as f:
    json.dump(out, f)

View File

@ -8,11 +8,11 @@ import json
import time import time
from tqdm import tqdm from tqdm import tqdm
from torch.func import functional_call, vmap, grad from torch.func import functional_call, vmap, grad
import sys
from model import Config, BradleyTerry from model import Config, BradleyTerry
import shared import shared
steps = 855
batch_size = 128 batch_size = 128
num_pairs = batch_size * 1024 num_pairs = batch_size * 1024
device = "cuda" device = "cuda"
@ -22,10 +22,12 @@ config = Config(
n_hidden=1, n_hidden=1,
n_ensemble=1, n_ensemble=1,
device=device, device=device,
dtype=torch.bfloat16 dtype=torch.float32,
output_channels=3,
dropout=0.1
) )
model = BradleyTerry(config) model = BradleyTerry(config)
modelc, _ = shared.checkpoint_for(855) modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
model.load_state_dict(torch.load(modelc)) model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters()) params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters") print(f"{params/1e6:.1f}M parameters")
@ -61,7 +63,7 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
#win_probs = model(inputs) #win_probs = model(inputs)
# TODO gradients # TODO gradients
# don't take variance: do backwards pass and compute gradient norm # don't take variance: do backwards pass and compute gradient norm
grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size), 0.95).to(device)) grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size, config.output_channels), 0.95).to(device))
total_grad_norms = torch.zeros(batch_size).to(device) total_grad_norms = torch.zeros(batch_size).to(device)
for k, v in grads.items(): for k, v in grads.items():
param_dims = tuple(range(1, len(v.shape))) param_dims = tuple(range(1, len(v.shape)))

View File

@ -0,0 +1,20 @@
import jsonlines
import sqlite3
import numpy as np
import shared
# Creates the "files" table if needed and loads every record of
# "sample.jsonl" into it, storing the embedding as raw float16 bytes.
shared.db.executescript("""
CREATE TABLE IF NOT EXISTS files (
filename TEXT NOT NULL,
title TEXT NOT NULL,
link TEXT NOT NULL,
embedding BLOB NOT NULL,
UNIQUE (filename)
);
""")

# OR IGNORE: the table is created with IF NOT EXISTS, so the script is meant
# to be re-runnable; a plain INSERT would abort on the first filename that is
# already present (UNIQUE constraint). Skipped rows are counted and reported
# below rather than dropped silently.
insert_sql = "INSERT OR IGNORE INTO files (filename, title, link, embedding) VALUES (?, ?, ?, ?)"

seen = 0
# NOTE(review): relies on shared.db being a sqlite3 connection (total_changes).
changes_before = shared.db.total_changes
with jsonlines.open("sample.jsonl") as reader:
    for obj in reader:
        filename = obj["metadata"]["final_url"]
        link = f"https://reddit.com/r/{obj['subreddit']}/comments/{obj['id']}"
        embedding = sqlite3.Binary(np.array(obj["embedding"], dtype=np.float16).tobytes())
        shared.db.execute(insert_sql, (filename, obj["title"], link, embedding))
        seen += 1
shared.db.commit()

inserted = shared.db.total_changes - changes_before
print(f"{inserted} rows inserted, {seen - inserted} skipped (filename already present)")

View File

@ -13,13 +13,14 @@ class Config:
device: str device: str
dtype: torch.dtype dtype: torch.dtype
dropout: float dropout: float
output_channels: int
class Model(nn.Module): class Model(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ]) self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ])
self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ]) self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ])
self.output = nn.Linear(config.d_emb, 1, dtype=config.dtype, device=config.device) self.output = nn.Linear(config.d_emb, config.output_channels, dtype=config.dtype, device=config.device)
def forward(self, embs): def forward(self, embs):
x = embs x = embs
@ -34,8 +35,7 @@ class Ensemble(nn.Module):
# model batch # model batch
def forward(self, embs): def forward(self, embs):
xs = torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1 return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
return xs.squeeze(-1)
class BradleyTerry(nn.Module): class BradleyTerry(nn.Module):
def __init__(self, config): def __init__(self, config):

View File

@ -30,6 +30,7 @@ async def index(request):
return web.Response(text=f""" return web.Response(text=f"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<title>Data Labelling Frontend (Not Evil)</title>
<style> <style>
.memes img {{ .memes img {{
width: 45%; width: 45%;
@ -46,38 +47,71 @@ async def index(request):
}} }}
</style> </style>
<body> <body>
<h1>Meme Rating</h1> <h1>Data Labelling Frontend (Not Evil)</h1>
<form action="/rate" method="POST"> <form action="/rate" method="POST">
<input type="radio" name="rating" value="1" id="rating1"> <label for="rating1">Meme 1 is better</label> <table>
<input type="radio" name="rating" value="2" id="rating2"> <label for="rating2">Meme 2 is better</label> <tr>
<td><input type="radio" name="rating-useful" value="1" id="rq1"> <label for="rq1">LHS is better (useful)</label></td>
<td><input type="radio" name="rating-useful" value="eq" id="rqe"> <label for="rqe">Tie</label></td>
<td><input type="radio" name="rating-useful" value="2" id="rq2"> <label for="rq2">RHS is better (useful)</label></td>
</tr>
<tr>
<td><input type="radio" name="rating-meme" value="1" id="rm1"> <label for="rm1">LHS is better (memetically)</label></td>
<td><input type="radio" name="rating-meme" value="eq" id="rme"> <label for="rme">Tie</label></td>
<td><input type="radio" name="rating-meme" value="2" id="rm2"> <label for="rm2">RHS is better (memetically)</label></td>
</tr>
<tr>
<td><input type="radio" name="rating-aesthetic" value="1" id="ra1"> <label for="ra1">LHS is better (aesthetically)</label></td>
<td><input type="radio" name="rating-aesthetic" value="eq" id="rae"> <label for="rae">Tie</label></td>
<td><input type="radio" name="rating-aesthetic" value="2" id="ra2"> <label for="ra2">RHS is better (aesthetically)</label></td>
</tr>
</table>
<input type="hidden" name="meme1" value="{meme1}"> <input type="hidden" name="meme1" value="{meme1}">
<input type="hidden" name="meme2" value="{meme2}"> <input type="hidden" name="meme2" value="{meme2}">
<input type="hidden" name="iteration" value="{str(iteration or 0)}"> <input type="hidden" name="iteration" value="{str(iteration or 0)}">
<input type="submit" value="Submit"> <input type="submit" value="Submit">
<div class="memes"> <div class="memes">
<img src="/memes/{meme1}" id="meme1"> <img src="{meme1}" id="meme1">
<img src="/memes/{meme2}" id="meme2"> <img src="{meme2}" id="meme2">
</div> </div>
</form> </form>
<script> <script>
document.addEventListener("keypress", function(event) {{ const commitIfReady = () => {{
if (event.key === "1") {{ if (document.querySelector("input[name='rating-useful']:checked") && document.querySelector("input[name='rating-meme']:checked") && document.querySelector("input[name='rating-aesthetic']:checked")) {{
document.querySelector("input[name='rating'][value='1']").checked = true
document.querySelector("form").submit()
}} else if (event.key === "2") {{
document.querySelector("input[name='rating'][value='2']").checked = true
document.querySelector("form").submit() document.querySelector("form").submit()
}} }}
}}
document.addEventListener("keypress", function(event) {{
if (event.key === "q") {{
document.querySelector("input[name='rating-useful'][value='1']").checked = true
commitIfReady()
}} else if (event.key === "w") {{
document.querySelector("input[name='rating-useful'][value='eq']").checked = true
commitIfReady()
}} else if (event.key === "e") {{
document.querySelector("input[name='rating-useful'][value='2']").checked = true
commitIfReady()
}} else if (event.key === "a") {{
document.querySelector("input[name='rating-meme'][value='1']").checked = true
commitIfReady()
}} else if (event.key === "s") {{
document.querySelector("input[name='rating-meme'][value='eq']").checked = true
commitIfReady()
}} else if (event.key === "d") {{
document.querySelector("input[name='rating-meme'][value='2']").checked = true
commitIfReady()
}} else if (event.key === "z") {{
document.querySelector("input[name='rating-aesthetic'][value='1']").checked = true
commitIfReady()
}} else if (event.key === "x") {{
document.querySelector("input[name='rating-aesthetic'][value='eq']").checked = true
commitIfReady()
}} else if (event.key === "c") {{
document.querySelector("input[name='rating-aesthetic'][value='2']").checked = true
commitIfReady()
}}
}}); }});
document.querySelector("#meme1").addEventListener("click", function(event) {{
document.querySelector("input[name='rating'][value='1']").checked = true
document.querySelector("form").submit()
}})
document.querySelector("#meme2").addEventListener("click", function(event) {{
document.querySelector("input[name='rating'][value='2']").checked = true
document.querySelector("form").submit()
}})
</script> </script>
</body> </body>
</html> </html>
@ -90,8 +124,8 @@ async def rate(request):
meme1 = post["meme1"] meme1 = post["meme1"]
meme2 = post["meme2"] meme2 = post["meme2"]
iteration = post["iteration"] iteration = post["iteration"]
rating = post["rating"] rating = post["rating-useful"] + "," + post["rating-meme"] + "," + post["rating-aesthetic"]
await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration) VALUES (?, ?, ?, ?)", (meme1, meme2, rating, iteration)) await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration, ip) VALUES (?, ?, ?, ?, ?)", (meme1, meme2, rating, iteration, request.remote))
await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2)) await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2))
await db.commit() await db.commit()
return web.HTTPFound("/") return web.HTTPFound("/")
@ -104,6 +138,7 @@ CREATE TABLE IF NOT EXISTS ratings (
meme2 TEXT NOT NULL, meme2 TEXT NOT NULL,
rating TEXT NOT NULL, rating TEXT NOT NULL,
iteration TEXT, iteration TEXT,
ip TEXT,
UNIQUE (meme1, meme2) UNIQUE (meme1, meme2)
); );
CREATE TABLE IF NOT EXISTS queue ( CREATE TABLE IF NOT EXISTS queue (

View File

@ -2,10 +2,11 @@
import json import json
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import sys
# Read data from log.jsonl # Read data from log.jsonl
data = [] data = []
with open('log.jsonl', 'r') as file: with open(sys.argv[1], 'r') as file:
for line in file: for line in file:
data.append(json.loads(line)) data.append(json.loads(line))

View File

@ -3,6 +3,7 @@ import hashlib
from collections import defaultdict from collections import defaultdict
import numpy import numpy
import random import random
import numpy as np
db = sqlite3.connect("data.sqlite3") db = sqlite3.connect("data.sqlite3")
db.row_factory = sqlite3.Row db.row_factory = sqlite3.Row
@ -20,19 +21,24 @@ def fetch_embedding(filename):
return x.copy() # PyTorch complains otherwise due to bad return x.copy() # PyTorch complains otherwise due to bad
def map_rating(rating, uncertainty=0.05):
    """Convert a stored rating string into an array of target win probabilities.

    `rating` is a comma-separated list with one token per rating channel.
    Each token is mapped to the target probability that meme 1 wins:

    * "1"  (meme 1 is better) -> 1 - uncertainty
    * "2"  (meme 2 is better) -> uncertainty
    * "eq" (tie)              -> 0.5

    Returns a 1-D numpy array with one entry per token, in order.
    Raises ValueError if any token is not one of the three above.
    """
    targets = {
        "1": 1 - uncertainty,
        "2": uncertainty,
        "eq": 0.5,
    }
    try:
        return np.array([targets[token] for token in rating.split(",")])
    except KeyError as err:
        raise ValueError("invalid rating, please fix") from err
def fetch_ratings():
    """Load all ratings, grouped by iteration and split into train/validation.

    Each row of the `ratings` table becomes a tuple
    `(embedding of meme1, embedding of meme2, map_rating(rating))`.
    Rows for which `is_val_set(meme1, meme2)` is true go to the validation
    split, all others to the training split.

    The group key is the integer before the first "-" in the `iteration`
    column (e.g. "3-abc" -> 3); a missing or empty value counts as 0.

    Returns `(trains, validations)`: two lists of groups, each ordered by
    ascending iteration number.
    """
    trains = defaultdict(list)
    validations = defaultdict(list)

    csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
    try:
        rows = csr.fetchall()
    finally:
        # Release the cursor even if fetching fails.
        csr.close()

    for meme1, meme2, rating, iteration in rows:
        iteration_prefix = iteration.split("-")[0] if iteration else ""
        iteration_index = int(iteration_prefix or "0")
        split = validations if is_val_set(meme1, meme2) else trains
        split[iteration_index].append(
            (fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating))
        )

    return (
        [trains[key] for key in sorted(trains)],
        [validations[key] for key in sorted(validations)],
    )

View File

@ -36,7 +36,8 @@ config = TrainConfig(
n_ensemble=16, n_ensemble=16,
device=device, device=device,
dtype=torch.float32, dtype=torch.float32,
dropout=0.1 dropout=0.1,
output_channels=3
), ),
lr=3e-4, lr=3e-4,
weight_decay=0.2, weight_decay=0.2,
@ -72,12 +73,12 @@ if config.compile:
print("compiling...") print("compiling...")
train_step = torch.compile(train_step) train_step = torch.compile(train_step)
def batch_from_inputs(inputs: list[list[tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]]]):
    """Stack nested lists of rated pairs into model input and target tensors.

    `inputs` is a list of groups; each group is a list of
    `(emb1, emb2, rating)` tuples. All groups must have the same length for
    the stacking to succeed. (Presumably one group per ensemble member —
    confirm against the caller.)

    Returns `(batch_input, target)`, both on `device` and converted to
    `config.model.dtype`:

    * `batch_input` stacks the two embeddings of every pair, giving
      dimensions (group, pair, 2, embedding).
    * `target` stacks the rating arrays, giving dimensions
      (group, pair, ...rating shape).
    """
    dtype = config.model.dtype
    batch_input = torch.stack([
        torch.stack([
            torch.stack((torch.Tensor(emb1).to(dtype), torch.Tensor(emb2).to(dtype)))
            for emb1, emb2, _rating in group
        ])
        for group in inputs
    ]).to(device)
    # Go through one numpy array per group so a list of per-pair rating
    # arrays becomes a single 2-D tensor.
    target = torch.stack([
        torch.Tensor(numpy.array([rating for _emb1, _emb2, rating in group])).to(dtype)
        for group in inputs
    ]).to(device)
    return batch_input, target
def evaluate(steps): def evaluate(steps):
@ -118,7 +119,7 @@ with open(logfile, "w") as log:
print(steps, loss) print(steps, loss)
log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n") log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
if steps % 10 == 0: if steps % 10 == 0:
if steps % 250 == 0: save_ckpt(log, steps) if steps % 100 == 0: save_ckpt(log, steps)
loss = evaluate(steps) loss = evaluate(steps)
#print(loss) #print(loss)
#best = min(loss, best) #best = min(loss, best)