Mirror of https://github.com/osmarks/meme-search-engine.git (synced 2025-04-28 13:33:11 +00:00)

Commit 0a542ef579 ("repurpose meme rater"), parent 163dceca4b

@@ -7,6 +7,7 @@ import numpy
 import json
 import time
 from tqdm import tqdm
+import sys
 
 from model import Config, BradleyTerry
 import shared
@@ -20,11 +21,12 @@ config = Config(
     n_hidden=1,
     n_ensemble=16,
     device=device,
-    dtype=torch.bfloat16,
-    dropout=0.5
+    dtype=torch.float32,
+    output_channels=3,
+    dropout=0.1
 )
 model = BradleyTerry(config)
-modelc, _ = shared.checkpoint_for(2250)
+modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
 model.load_state_dict(torch.load(modelc))
 params = sum(p.numel() for p in model.parameters())
 print(f"{params/1e6:.1f}M parameters")
@@ -45,11 +47,13 @@ with torch.inference_mode():
         embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
         inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
         win_probs = model(inputs)
         #print(win_probs, win_probs.shape)
-        batchvar = torch.var(win_probs, dim=0)
+        #print(win_probs.shape)
+        batchvar = torch.var(win_probs, dim=0).max(-1).values
         #print(batchvar, batchvar.shape)
         for filename, var in zip(filenames, batchvar):
             variance[filename] = float(var)
+
 top = sorted(variance.items(), key=lambda x: -x[1])
 with open("top.json", "w") as f:
-    json.dump(top[:256], f)
+    json.dump(top[:100], f)
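(Note: a shape walk-through of the selection criterion in the hunk above, using random stand-in tensors; a sketch, not code from the repository.)

    import torch

    win_probs = torch.rand(16, 128, 3)             # (n_ensemble, batch, output_channels)
    per_channel_var = torch.var(win_probs, dim=0)  # (batch, output_channels): ensemble disagreement per channel
    score = per_channel_var.max(-1).values         # (batch,): the most uncertain channel sets the pair's priority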

meme-rater/active_learning_find_top.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import sqlite3
+import random
+import numpy
+import json
+import time
+from tqdm import tqdm
+import sys
+from collections import defaultdict
+
+from model import Config, BradleyTerry
+import shared
+
+batch_size = 128
+num_pairs = batch_size * 1024
+device = "cuda"
+
+config = Config(
+    d_emb=1152,
+    n_hidden=1,
+    n_ensemble=16,
+    device=device,
+    dtype=torch.float32,
+    output_channels=3,
+    dropout=0.1
+)
+model = BradleyTerry(config)
+modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
+model.load_state_dict(torch.load(modelc))
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+files = shared.fetch_all_files()
+results = {}
+
+model.eval()
+with torch.inference_mode():
+    for bstart in tqdm(range(0, len(files), batch_size)):
+        batch = files[bstart:bstart + batch_size]
+        filenames = [ f1 for f1, e1 in batch ]
+        embs = torch.stack([ torch.Tensor(e1).to(config.dtype) for f1, e1 in batch ])
+        inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
+        scores = model.ensemble(inputs).median(dim=0).values.cpu().numpy()
+        #print(batchvar, batchvar.shape)
+        for filename, score in zip(filenames, scores):
+            results[filename] = score
+
+channel = int(sys.argv[2])
+percentile = float(sys.argv[3])
+output_pairs = int(sys.argv[4])
+mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
+top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
+select_from = top[:int(len(top) * percentile)]
+
+out = []
+for _ in range(output_pairs):
+    # dummy score for compatibility with existing code
+    out.append(((random.choice(select_from)[0], random.choice(select_from)[0]), 0))
+
+with open("top.json", "w") as f:
+    json.dump(out, f)

@@ -8,11 +8,11 @@ import json
 import time
 from tqdm import tqdm
 from torch.func import functional_call, vmap, grad
+import sys
 
 from model import Config, BradleyTerry
 import shared
 
-steps = 855
 batch_size = 128
 num_pairs = batch_size * 1024
 device = "cuda"
@@ -22,10 +22,12 @@ config = Config(
     n_hidden=1,
     n_ensemble=1,
     device=device,
-    dtype=torch.bfloat16
+    dtype=torch.float32,
+    output_channels=3,
+    dropout=0.1
 )
 model = BradleyTerry(config)
-modelc, _ = shared.checkpoint_for(855)
+modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
 model.load_state_dict(torch.load(modelc))
 params = sum(p.numel() for p in model.parameters())
 print(f"{params/1e6:.1f}M parameters")
@@ -61,7 +63,7 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
     #win_probs = model(inputs)
     # TODO gradients
     # don't take variance: do backwards pass and compute gradient norm
-    grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size), 0.95).to(device))
+    grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size, config.output_channels), 0.95).to(device))
     total_grad_norms = torch.zeros(batch_size).to(device)
     for k, v in grads.items():
         param_dims = tuple(range(1, len(v.shape)))
@@ -73,4 +75,4 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
 
 top = sorted(importance.items(), key=lambda x: -x[1])
 with open("top.json", "w") as f:
-    json.dump(top[:256], f)
\ No newline at end of file
+    json.dump(top[:256], f)
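(Note: ft_compute_sample_grad is defined outside the hunks shown here. The following is a minimal sketch of the usual torch.func per-sample-gradient recipe such a helper is built from, assuming `model` as constructed earlier in the script; the loss, tensor layouts and in_dims are assumptions for illustration, not the file's actual definition.)

    import torch
    import torch.nn.functional as F
    from torch.func import functional_call, vmap, grad

    params = {k: v.detach() for k, v in model.named_parameters()}
    buffers = {k: v.detach() for k, v in model.named_buffers()}

    def compute_loss(params, buffers, sample, target):
        # sample: one embedding pair (assumed shape (1, 2, d_emb)); target: its per-channel label
        win_prob = functional_call(model, (params, buffers), (sample.unsqueeze(1),))
        return F.binary_cross_entropy(win_prob.reshape(-1), target.reshape(-1))

    ft_compute_grad = grad(compute_loss)  # gradient w.r.t. params for a single sample
    # share params/buffers across samples, map over the batch axis (dim 1) of inputs and targets
    ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 1, 1))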

meme-rater/load_from_json.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+import jsonlines
+import sqlite3
+import numpy as np
+
+import shared
+
+shared.db.executescript("""
+CREATE TABLE IF NOT EXISTS files (
+    filename TEXT NOT NULL,
+    title TEXT NOT NULL,
+    link TEXT NOT NULL,
+    embedding BLOB NOT NULL,
+    UNIQUE (filename)
+);
+""")
+
+with jsonlines.open("sample.jsonl") as reader:
+    for obj in reader:
+        shared.db.execute("INSERT INTO files (filename, title, link, embedding) VALUES (?, ?, ?, ?)", (obj["metadata"]["final_url"], obj["title"], f"https://reddit.com/r/{obj['subreddit']}/comments/{obj['id']}", sqlite3.Binary(np.array(obj["embedding"], dtype=np.float16).tobytes())))
+shared.db.commit()
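(Note: the loader above implies roughly the following record shape for sample.jsonl; only the keys are taken from the code, the values are invented.)

    record = {
        "id": "abc123",                                   # reddit post id
        "subreddit": "memes",
        "title": "example post title",
        "metadata": {"final_url": "https://i.redd.it/example.png"},
        "embedding": [0.01, -0.02, 0.03],                 # in practice d_emb=1152 floats, stored as float16
    }
    # filename <- metadata.final_url; link <- https://reddit.com/r/{subreddit}/comments/{id}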

@@ -13,13 +13,14 @@ class Config:
     device: str
     dtype: torch.dtype
     dropout: float
+    output_channels: int
 
 class Model(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ])
         self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ])
-        self.output = nn.Linear(config.d_emb, 1, dtype=config.dtype, device=config.device)
+        self.output = nn.Linear(config.d_emb, config.output_channels, dtype=config.dtype, device=config.device)
 
     def forward(self, embs):
         x = embs
@@ -34,8 +35,7 @@ class Ensemble(nn.Module):
 
     # model batch
     def forward(self, embs):
-        xs = torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
-        return xs.squeeze(-1)
+        return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
 
 class BradleyTerry(nn.Module):
     def __init__(self, config):
@@ -49,4 +49,4 @@ class BradleyTerry(nn.Module):
         #print(scores1, scores2)
         probs = torch.sigmoid(scores1 - scores2) # model batch
         #print(probs)
-        return probs
\ No newline at end of file
+        return probs
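(Note: a minimal shape check for the reworked model, assuming, as the active-learning scripts above do, that BradleyTerry accepts a (n_ensemble, batch, 2, d_emb) tensor of embedding pairs; the numbers are illustrative only.)

    import torch
    from model import Config, BradleyTerry

    config = Config(d_emb=1152, n_hidden=1, n_ensemble=16, device="cpu",
                    dtype=torch.float32, output_channels=3, dropout=0.1)
    model = BradleyTerry(config)

    pairs = torch.randn(config.n_ensemble, 4, 2, config.d_emb)  # ensemble x batch x {meme1, meme2} x emb
    probs = model(pairs)
    print(probs.shape)  # expected: (n_ensemble, batch, output_channels) = (16, 4, 3)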

@@ -30,9 +30,10 @@ async def index(request):
     return web.Response(text=f"""
 <!DOCTYPE html>
 <html>
+<title>Data Labelling Frontend (Not Evil)</title>
 <style>
     .memes img {{
-        width: 45%;
+        width: 45%;
     }}
 
     @media (max-width: 768px) {{
@@ -46,38 +47,71 @@ async def index(request):
     }}
 </style>
 <body>
-    <h1>Meme Rating</h1>
+    <h1>Data Labelling Frontend (Not Evil)</h1>
     <form action="/rate" method="POST">
-        <input type="radio" name="rating" value="1" id="rating1"> <label for="rating1">Meme 1 is better</label>
-        <input type="radio" name="rating" value="2" id="rating2"> <label for="rating2">Meme 2 is better</label>
+        <table>
+            <tr>
+                <td><input type="radio" name="rating-useful" value="1" id="rq1"> <label for="rq1">LHS is better (useful)</label></td>
+                <td><input type="radio" name="rating-useful" value="eq" id="rqe"> <label for="rqe">Tie</label></td>
+                <td><input type="radio" name="rating-useful" value="2" id="rq2"> <label for="rq2">RHS is better (useful)</label></td>
+            </tr>
+            <tr>
+                <td><input type="radio" name="rating-meme" value="1" id="rm1"> <label for="rm1">LHS is better (memetically)</label></td>
+                <td><input type="radio" name="rating-meme" value="eq" id="rme"> <label for="rme">Tie</label></td>
+                <td><input type="radio" name="rating-meme" value="2" id="rm2"> <label for="rm2">RHS is better (memetically)</label></td>
+            </tr>
+            <tr>
+                <td><input type="radio" name="rating-aesthetic" value="1" id="ra1"> <label for="ra1">LHS is better (aesthetically)</label></td>
+                <td><input type="radio" name="rating-aesthetic" value="eq" id="rae"> <label for="rae">Tie</label></td>
+                <td><input type="radio" name="rating-aesthetic" value="2" id="ra2"> <label for="ra2">RHS is better (aesthetically)</label></td>
+            </td>
+        </table>
 
         <input type="hidden" name="meme1" value="{meme1}">
         <input type="hidden" name="meme2" value="{meme2}">
         <input type="hidden" name="iteration" value="{str(iteration or 0)}">
         <input type="submit" value="Submit">
         <div class="memes">
-            <img src="/memes/{meme1}" id="meme1">
-            <img src="/memes/{meme2}" id="meme2">
+            <img src="{meme1}" id="meme1">
+            <img src="{meme2}" id="meme2">
         </div>
     </form>
     <script>
-        document.addEventListener("keypress", function(event) {{
-            if (event.key === "1") {{
-                document.querySelector("input[name='rating'][value='1']").checked = true
-                document.querySelector("form").submit()
-            }} else if (event.key === "2") {{
-                document.querySelector("input[name='rating'][value='2']").checked = true
+        const commitIfReady = () => {{
+            if (document.querySelector("input[name='rating-useful']:checked") && document.querySelector("input[name='rating-meme']:checked") && document.querySelector("input[name='rating-aesthetic']:checked")) {{
                 document.querySelector("form").submit()
             }}
+        }}
+        document.addEventListener("keypress", function(event) {{
+            if (event.key === "q") {{
+                document.querySelector("input[name='rating-useful'][value='1']").checked = true
+                commitIfReady()
+            }} else if (event.key === "w") {{
+                document.querySelector("input[name='rating-useful'][value='eq']").checked = true
+                commitIfReady()
+            }} else if (event.key === "e") {{
+                document.querySelector("input[name='rating-useful'][value='2']").checked = true
+                commitIfReady()
+            }} else if (event.key === "a") {{
+                document.querySelector("input[name='rating-meme'][value='1']").checked = true
+                commitIfReady()
+            }} else if (event.key === "s") {{
+                document.querySelector("input[name='rating-meme'][value='eq']").checked = true
+                commitIfReady()
+            }} else if (event.key === "d") {{
+                document.querySelector("input[name='rating-meme'][value='2']").checked = true
+                commitIfReady()
+            }} else if (event.key === "z") {{
+                document.querySelector("input[name='rating-aesthetic'][value='1']").checked = true
+                commitIfReady()
+            }} else if (event.key === "x") {{
+                document.querySelector("input[name='rating-aesthetic'][value='eq']").checked = true
+                commitIfReady()
+            }} else if (event.key === "c") {{
+                document.querySelector("input[name='rating-aesthetic'][value='2']").checked = true
+                commitIfReady()
+            }}
         }});
-        document.querySelector("#meme1").addEventListener("click", function(event) {{
-            document.querySelector("input[name='rating'][value='1']").checked = true
-            document.querySelector("form").submit()
-        }})
-        document.querySelector("#meme2").addEventListener("click", function(event) {{
-            document.querySelector("input[name='rating'][value='2']").checked = true
-            document.querySelector("form").submit()
-        }})
     </script>
 </body>
 </html>
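(Note: the key bindings implemented by the handler above, summarised as a mapping; this is a reading aid, not code from the repository. The form auto-submits via commitIfReady once all three radio groups have a selection.)

    KEYMAP = {
        "q": ("rating-useful", "1"),    "w": ("rating-useful", "eq"),    "e": ("rating-useful", "2"),
        "a": ("rating-meme", "1"),      "s": ("rating-meme", "eq"),      "d": ("rating-meme", "2"),
        "z": ("rating-aesthetic", "1"), "x": ("rating-aesthetic", "eq"), "c": ("rating-aesthetic", "2"),
    }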
@@ -90,8 +124,8 @@ async def rate(request):
     meme1 = post["meme1"]
     meme2 = post["meme2"]
     iteration = post["iteration"]
-    rating = post["rating"]
-    await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration) VALUES (?, ?, ?, ?)", (meme1, meme2, rating, iteration))
+    rating = post["rating-useful"] + "," + post["rating-meme"] + "," + post["rating-aesthetic"]
+    await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration, ip) VALUES (?, ?, ?, ?, ?)", (meme1, meme2, rating, iteration, request.remote))
     await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2))
     await db.commit()
     return web.HTTPFound("/")
@@ -104,6 +138,7 @@ CREATE TABLE IF NOT EXISTS ratings (
     meme2 TEXT NOT NULL,
     rating TEXT NOT NULL,
     iteration TEXT,
+    ip TEXT,
     UNIQUE (meme1, meme2)
 );
 CREATE TABLE IF NOT EXISTS queue (

@@ -2,10 +2,11 @@
 
 import json
 import matplotlib.pyplot as plt
+import sys
 
 # Read data from log.jsonl
 data = []
-with open('log.jsonl', 'r') as file:
+with open(sys.argv[1], 'r') as file:
     for line in file:
         data.append(json.loads(line))
 

@@ -3,6 +3,7 @@ import hashlib
 from collections import defaultdict
 import numpy
 import random
+import numpy as np
 
 db = sqlite3.connect("data.sqlite3")
 db.row_factory = sqlite3.Row
@@ -20,19 +21,24 @@ def fetch_embedding(filename):
     return x.copy() # PyTorch complains otherwise due to bad
 
 def map_rating(rating, uncertainty=0.05):
-    match rating:
-        case "1": # meme 1 is better
-            return 1 - uncertainty
-        case "2":
-            return uncertainty
-        case _: raise ValueError("invalid rating, please fix")
+    def map_one(rating):
+        match rating:
+            case "1": # meme 1 is better
+                return 1 - uncertainty
+            case "2":
+                return uncertainty
+            case "eq":
+                return 0.5
+            case _: raise ValueError("invalid rating, please fix")
+
+    return np.array([map_one(r) for r in rating.split(",")])
 
 def fetch_ratings():
     trains = defaultdict(list)
     validations = defaultdict(list)
     csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
     for meme1, meme2, rating, iteration in csr.fetchall():
-        (validations if is_val_set(meme1, meme2) else trains)[int(iteration or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
+        (validations if is_val_set(meme1, meme2) else trains)[int((iteration and iteration.split("-")[0]) or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
     csr.close()
     return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))
 
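(Note: a worked example of the three-channel rating mapping, assuming shared.py is importable and the default uncertainty of 0.05; the channel order follows the rater server's rating-useful, rating-meme, rating-aesthetic concatenation.)

    import shared

    shared.map_rating("1,eq,2")
    # -> array([0.95, 0.5 , 0.05]): LHS better on useful, tie on memetic, RHS better on aesthetic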
@@ -50,4 +56,4 @@ def fetch_all_files():
     return x
 
 def checkpoint_for(steps):
-    return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt"
\ No newline at end of file
+    return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt"

@@ -36,7 +36,8 @@ config = TrainConfig(
         n_ensemble=16,
         device=device,
         dtype=torch.float32,
-        dropout=0.1
+        dropout=0.1,
+        output_channels=3
     ),
     lr=3e-4,
     weight_decay=0.2,
@@ -72,12 +73,12 @@ if config.compile:
     print("compiling...")
     train_step = torch.compile(train_step)
 
-def batch_from_inputs(inputs: list[tuple[numpy.ndarray, numpy.ndarray, float]]):
+def batch_from_inputs(inputs: list[list[tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]]]):
     batch_input = torch.stack([
         torch.stack([ torch.stack((torch.Tensor(emb1).to(config.model.dtype), torch.Tensor(emb2).to(config.model.dtype))) for emb1, emb2, rating in input ])
         for input in inputs
     ]).to(device)
-    target = torch.stack([ torch.Tensor([ rating for emb1, emb2, rating in input ]) for input in inputs ]).to(device)
+    target = torch.stack([ torch.Tensor(numpy.array([ rating for emb1, emb2, rating in input ])).to(config.model.dtype) for input in inputs ]).to(device)
     return batch_input, target
 
 def evaluate(steps):
@@ -118,7 +119,7 @@ with open(logfile, "w") as log:
             print(steps, loss)
             log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
         if steps % 10 == 0:
-            if steps % 250 == 0: save_ckpt(log, steps)
+            if steps % 100 == 0: save_ckpt(log, steps)
             loss = evaluate(steps)
             #print(loss)
             #best = min(loss, best)
@@ -126,4 +127,4 @@ with open(logfile, "w") as log:
 
     save_ckpt(log, steps)
 
-print(logfile)
\ No newline at end of file
+print(logfile)