Mirror of https://github.com/osmarks/meme-search-engine.git (synced 2025-05-12 20:24:05 +00:00)
meme rater model code (documentation "later")
This commit is contained in: parent 0b0261f625, commit 58ce70bb5e
meme-rater/active_learning.py (new file)
@@ -0,0 +1,55 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm

from model import Config, BradleyTerry
import shared

batch_size = 128
num_pairs = batch_size * 1024
device = "cuda"

config = Config(
    d_emb=1152,
    n_hidden=1,
    n_ensemble=16,
    device=device,
    dtype=torch.bfloat16,
    dropout=0.5
)
model = BradleyTerry(config)
modelc, _ = shared.checkpoint_for(2250)
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)

files = shared.fetch_all_files()
variance = {}

pairs = []
for _ in range(num_pairs):
    pairs.append(tuple(random.sample(files, 2)))

model.eval()
with torch.inference_mode():
    for bstart in tqdm(range(0, len(pairs), batch_size)):
        batch = pairs[bstart:bstart + batch_size]
        filenames = [ (f1, f2) for ((f1, e1), (f2, e2)) in batch ]
        embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
        inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
        win_probs = model(inputs)
        #print(win_probs.shape)
        batchvar = torch.var(win_probs, dim=0)
        for filename, var in zip(filenames, batchvar):
            variance[filename] = float(var)

top = sorted(variance.items(), key=lambda x: -x[1])
with open("top.json", "w") as f:
    json.dump(top[:256], f)
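The acquisition rule here is ensemble disagreement: all 16 heads score the same candidate pair, and the pairs with the highest variance in predicted win probability are written to top.json for human labelling. A minimal sketch of that step (toy shapes, illustrative values only):

import torch

# Toy sketch of the variance-based ranking above; shapes are illustrative.
win_probs = torch.rand(16, 4)                   # (n_ensemble, batch) win probabilities
disagreement = torch.var(win_probs, dim=0)      # variance across ensemble members
ranked = torch.argsort(disagreement, descending=True)
print(ranked)                                   # candidate indices, most-disputed first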
meme-rater/al2.py (new file)
@@ -0,0 +1,76 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
from torch.func import functional_call, vmap, grad

from model import Config, BradleyTerry
import shared

steps = 855
batch_size = 128
num_pairs = batch_size * 1024
device = "cuda"

config = Config(
    d_emb=1152,
    n_hidden=1,
    n_ensemble=1,
    device=device,
    dtype=torch.bfloat16
)
model = BradleyTerry(config)
modelc, _ = shared.checkpoint_for(855)
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)

files = shared.fetch_all_files()
importance = {}

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

# https://pytorch.org/tutorials/intermediate/per_sample_grads.html
def compute_loss(params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = functional_call(model, (params, buffers), (batch,))
    loss = F.binary_cross_entropy(predictions, targets)
    return loss

ft_compute_grad = grad(compute_loss)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 1, 1))

pairs = []
for _ in range(num_pairs):
    pairs.append(tuple(random.sample(files, 2)))

for bstart in tqdm(range(0, len(pairs), batch_size)):
    batch = pairs[bstart:bstart + batch_size]
    filenames = [ (f1, f2) for ((f1, e1), (f2, e2)) in batch ]
    embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
    inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
    #win_probs = model(inputs)
    # TODO gradients
    # don't take variance: do backwards pass and compute gradient norm
    grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size), 0.95).to(device))
    total_grad_norms = torch.zeros(batch_size).to(device)
    for k, v in grads.items():
        param_dims = tuple(range(1, len(v.shape)))
        total_grad_norms += torch.linalg.vector_norm(v, dim=param_dims)
    tgn = total_grad_norms.cpu().numpy()

    for filename, tg in zip(filenames, tgn):
        importance[filename] = float(tg)

top = sorted(importance.items(), key=lambda x: -x[1])
with open("top.json", "w") as f:
    json.dump(top[:256], f)
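This variant scores candidates by gradient norm rather than ensemble variance: it follows the linked torch.func per-sample-gradients recipe, vmapping grad over the batch dimension (dim 1 of the (n_ensemble, batch, 2, d_emb) input) against a fixed soft target of 0.95, and keeps the pairs whose hypothetical update would move the parameters most (note that, as written, this Config(...) call omits the dropout field that model.py's Config requires). A toy version of the same pattern on a stand-in linear scorer, not the committed model:

import torch
import torch.nn.functional as F
from torch.func import functional_call, vmap, grad

toy = torch.nn.Linear(8, 1)  # stand-in scorer, illustrative only
params = {k: v.detach() for k, v in toy.named_parameters()}

def loss_fn(params, sample, target):
    pred = torch.sigmoid(functional_call(toy, params, (sample.unsqueeze(0),))).squeeze()
    return F.binary_cross_entropy(pred, target)

per_sample_grad = vmap(grad(loss_fn), in_dims=(None, 0, 0))
samples, targets = torch.randn(4, 8), torch.full((4,), 0.95)
grads = per_sample_grad(params, samples, targets)
# one scalar "informativeness" score per sample: total gradient norm over all parameters
norms = sum(torch.linalg.vector_norm(g, dim=tuple(range(1, g.dim()))) for g in grads.values())
print(norms)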
meme-rater/copy_into_queue.py (new file)
@@ -0,0 +1,14 @@
import sqlite3
import json
import sys

iteration = sys.argv[1]

db = sqlite3.connect("data.sqlite3")
db.row_factory = sqlite3.Row

with open("top.json", "r") as f:
    listing = json.load(f)

db.executemany("INSERT INTO queue VALUES (?, ?, ?)", [ (x[0], x[1], iteration) for x, v in listing ])
db.commit()
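This pushes the pairs selected by active_learning.py / al2.py into the rating server's queue table for the next labelling round, tagged with an iteration number passed on the command line. The top.json it reads is, per the scripts above, a list of ([meme1, meme2], score) entries sorted highest score first, so x[0] and x[1] are the two filenames; an illustrative read:

import json

# Illustrative only: top.json entries are [[meme1, meme2], score].
with open("top.json") as f:
    for (meme1, meme2), score in json.load(f)[:3]:
        print(meme1, meme2, score)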
meme-rater/eval.py (new file)
@@ -0,0 +1,86 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
import torch

from model import Config, BradleyTerry
import shared

batch_size = 128
device = "cuda"

config = Config(
    d_emb=1152,
    n_hidden=1,
    n_ensemble=16,
    device=device,
    dtype=torch.float32,
    dropout=0.1
)
model = BradleyTerry(config).float()
modelc, _ = shared.checkpoint_for(1500)
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)

files = shared.fetch_all_files()
ratings = {}

model.eval()
with torch.inference_mode():
    for bstart in tqdm(range(0, len(files), batch_size)):
        batch = files[bstart:bstart + batch_size]
        filenames = [ filename for filename, embedding in batch ]
        embs = torch.stack([ torch.Tensor(embedding) for filename, embedding in batch ])
        inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
        scores = model.ensemble(inputs).float()
        mscores = torch.median(scores, dim=0).values
        for filename, mscore in zip(filenames, mscores):
            ratings[filename] = float(mscore)

ratings = sorted(ratings.items(), key=lambda x: x[1])

def percentile(p, n):
    base = round(p * len(ratings))
    return ratings[base:base + n]

N = 25
def render_memeset(p):
    filenames = percentile(p, N)
    return f"""
    <div>
        <details><summary>Reveal Memeset</summary>{p}</details>
        {''.join(f'<div><img src="{"images/" + f}" width="30%"><br><input type=checkbox id="{"col-" + str(p) + "-" + str(i)}"></div>' for i, (f, s) in enumerate(filenames))}
    </div>
    """

buf = """<!DOCTYPE html>"""
probs = [0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 0.95, 0.98, 0.99]
random.shuffle(probs)
for p in probs:
#for p in [0.3]:
    buf += render_memeset(p)

buf += """
<script>
const computeCounts = () => {
    const counts = {}
    for (const x of document.querySelectorAll("input[type=checkbox]")) {
        const [_, percentile, index] = x.getAttribute("id").split("-")
        counts[percentile] ??= 0
        if (x.checked) counts[percentile] += 1
    }
    console.log(counts)
}
</script>
"""

with open("eval.html", "w") as f:
    f.write(buf)
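eval.py is a blind human check of the model's ordering: each meme is rated by the median score across the 16 ensemble members, then 25 memes from each of several score percentiles are rendered (percentiles shuffled and hidden behind a details element) so a human can tick the ones worth keeping and read the per-percentile counts from computeCounts() in the browser console. The scoring step in miniature (toy shapes):

import torch

# Toy sketch: the median across ensemble members is used as each meme's rating.
scores = torch.randn(16, 3)                 # (n_ensemble, n_memes)
print(torch.median(scores, dim=0).values)   # one rating per meme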
meme-rater/final_eval_results.py (new file)
@@ -0,0 +1,33 @@
import matplotlib.pyplot as plt
import json

# Data as a JSON string
data_json = '{"0.95":22,"0.75":21,"0.5":15,"0.98":23,"0.25":3,"0.05":0,"0.99":24,"0.1":2,"0.01":0,"0.02":0}'

# Parse the JSON string into a dictionary
data = json.loads(data_json)

# Extract the keys and values from the dictionary
keys = list(data.keys())
values = list(data.values())

# Convert the keys to floats
keys = [float(key) for key in keys]

# Sort the keys and values based on the keys
sorted_data = sorted(zip(keys, values))
keys, values = zip(*sorted_data)

plt.plot(keys, values)

# Set the x-axis tick labels
plt.xticks(keys, rotation=45)

# Add labels and title
plt.xlabel('Percentile')
plt.ylabel('Memes Kept')
plt.title('Final Model Evaluation')

# Display the plot
plt.tight_layout()
plt.show()
meme-rater/model.py (new file)
@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from functools import partial
import math

@dataclass
class Config:
    d_emb: int
    n_hidden: int
    n_ensemble: int
    device: str
    dtype: torch.dtype
    dropout: float

class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ])
        self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ])
        self.output = nn.Linear(config.d_emb, 1, dtype=config.dtype, device=config.device)

    def forward(self, embs):
        x = embs
        for (layer, dropout) in zip(self.hidden, self.dropout):
            x = F.silu(layer(dropout(x)))
        return self.output(x)

class Ensemble(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.models = nn.ModuleList([ Model(config) for i in range(config.n_ensemble) ])

    # model batch
    def forward(self, embs):
        xs = torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
        return xs.squeeze(-1)

class BradleyTerry(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ensemble = Ensemble(config)

    def forward(self, embs): # model batch input=2 d_emb
        scores1 = self.ensemble(embs[:, :, 0]).float() # model batch
        scores2 = self.ensemble(embs[:, :, 1]).float()
        # win probabilities
        #print(scores1, scores2)
        probs = torch.sigmoid(scores1 - scores2) # model batch
        #print(probs)
        return probs
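Each ensemble member is a small dropout MLP mapping one 1152-dimensional image embedding to a scalar score; BradleyTerry compares two memes by taking sigmoid(score1 - score2) as the probability that the first wins, which is the Bradley-Terry pairwise-comparison model and is what the binary cross-entropy loss in train.py fits. In isolation:

import torch

# Bradley-Terry in one line: P(meme1 beats meme2) = sigmoid(s1 - s2).
s1, s2 = torch.tensor(1.3), torch.tensor(0.4)   # illustrative latent scores
print(float(torch.sigmoid(s1 - s2)))            # ~0.71: the higher-scored meme is favoured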
Changes to the existing rating web server:

@@ -11,15 +11,22 @@ routes = web.RouteTableDef()
 
 async def get_pair(db):
     while True:
-        filenames = [ x[0] for x in await db.execute_fetchall("SELECT filename FROM files", ()) ]
-        m1, m2 = tuple(sorted(random.sample(filenames, 2)))
+        csr = await db.execute("SELECT * FROM queue")
+        row = await csr.fetchone()
+        await csr.close()
+        iteration = None
+        if row:
+            m1, m2, iteration = row
+        else:
+            filenames = [ x[0] for x in await db.execute_fetchall("SELECT filename FROM files", ()) ]
+            m1, m2 = tuple(sorted(random.sample(filenames, 2)))
         csr = await db.execute("SELECT 1 FROM ratings WHERE meme1 = ? AND meme2 = ?", (m1, m2))
         if not await csr.fetchone():
-            return m1, m2
+            return m1, m2, iteration
 
 @routes.get("/")
 async def index(request):
-    meme1, meme2 = await get_pair(request.app["db"])
+    meme1, meme2, iteration = await get_pair(request.app["db"])
     return web.Response(text=f"""
 <!DOCTYPE html>
 <html>
@@ -46,6 +53,7 @@ async def index(request):
 
             <input type="hidden" name="meme1" value="{meme1}">
             <input type="hidden" name="meme2" value="{meme2}">
+            <input type="hidden" name="iteration" value="{str(iteration or 0)}">
             <input type="submit" value="Submit">
             <div class="memes">
                 <img src="/memes/{meme1}" id="meme1">
@@ -81,8 +89,10 @@ async def rate(request):
     post = await request.post()
     meme1 = post["meme1"]
     meme2 = post["meme2"]
+    iteration = post["iteration"]
     rating = post["rating"]
-    await db.execute("INSERT INTO ratings (meme1, meme2, rating) VALUES (?, ?, ?)", (meme1, meme2, rating))
+    await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration) VALUES (?, ?, ?, ?)", (meme1, meme2, rating, iteration))
+    await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2))
     await db.commit()
     return web.HTTPFound("/")
 
@@ -93,8 +103,15 @@ CREATE TABLE IF NOT EXISTS ratings (
     meme1 TEXT NOT NULL,
     meme2 TEXT NOT NULL,
     rating TEXT NOT NULL,
+    iteration TEXT,
     UNIQUE (meme1, meme2)
 );
+CREATE TABLE IF NOT EXISTS queue (
+    meme1 TEXT NOT NULL,
+    meme2 TEXT NOT NULL,
+    iteration TEXT NOT NULL,
+    UNIQUE (meme1, meme2, iteration)
+);
 """)
 app.router.add_routes(routes)
 app.router.add_static("/memes/", "./images")
meme-rater/run_graph.py (new file)
@@ -0,0 +1,52 @@
# claude-3

import json
import matplotlib.pyplot as plt

# Read data from log.jsonl
data = []
with open('log.jsonl', 'r') as file:
    for line in file:
        data.append(json.loads(line))

# Extract steps, loss, and val_loss
steps = [entry['step'] for entry in data if "loss" in entry]
loss = [entry['loss'] for entry in data if "loss" in entry]
val_loss_data = [entry['val_loss'] for entry in data if 'val_loss' in entry]
val_steps = [entry['step'] for entry in data if 'val_loss' in entry]

# Extract individual validation loss series
val_loss_series = {}
for val_loss in val_loss_data:
    for key, value in val_loss.items():
        if key not in val_loss_series:
            val_loss_series[key] = []
        val_loss_series[key].append(value)

# Calculate rolling average for loss
window_size = 50
rolling_avg = [sum(loss[i:i+window_size])/window_size for i in range(len(loss)-window_size+1)]
rolling_steps = steps[window_size-1:]

# Calculate rolling averages for validation loss series
val_rolling_avgs = {}
for key, series in val_loss_series.items():
    val_rolling_avgs[key] = [sum(series[i:i+window_size])/window_size for i in range(len(series)-window_size+1)]

print([(name, min(series)) for name, series in val_loss_series.items()])

# Create the plot
plt.figure(figsize=(10, 6))
#plt.plot(steps, loss, label='Loss')
plt.plot(rolling_steps, rolling_avg, label='Rolling Average (Loss)')

for key, series in val_loss_series.items():
    #plt.plot(val_steps, series, marker='o', linestyle='', label=f'Validation Loss ({key})')
    plt.plot(val_steps[window_size-1:], val_rolling_avgs[key], label=f'Rolling Average (Validation Loss {key})')

plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Loss and Validation Loss vs. Steps')
plt.legend()
plt.grid(True)
plt.show()
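run_graph.py assumes the log.jsonl layout written by train.py further down: a config record first, then per-step records with "loss", "step" and "time", plus periodic records whose "val_loss" field is keyed by validation subset. Illustrative records only (real values come from training):

import json

# Illustrative log.jsonl lines; values are made up.
lines = [
    '{"loss": 0.693, "step": 0, "time": 1700000000.0}',
    '{"step": 10, "time": 1700000001.0, "val_loss": {"0": 0.69}}',
]
for line in lines:
    entry = json.loads(line)
    print("val" if "val_loss" in entry else "train", entry["step"])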
meme-rater/shared.py (new file)
@@ -0,0 +1,53 @@
import sqlite3
import hashlib
from collections import defaultdict
import numpy
import random

db = sqlite3.connect("data.sqlite3")
db.row_factory = sqlite3.Row

val_fraction = 0.2
def is_val_set(meme1, meme2):
    def is_one_val(meme):
        return hashlib.sha256(meme.encode("utf-8")).digest()[0] / 255 < (val_fraction / 2) # not strictly correct but good enough
    return is_one_val(meme1) or is_one_val(meme2)

def fetch_embedding(filename):
    csr = db.execute("SELECT embedding_vector FROM files WHERE filename = ?", (filename,))
    x = numpy.frombuffer(csr.fetchone()[0], dtype="float16")
    csr.close()
    return x.copy() # PyTorch complains otherwise due to bad

def map_rating(rating, uncertainty=0.05):
    match rating:
        case "1": # meme 1 is better
            return 1 - uncertainty
        case "2":
            return uncertainty
        case _: raise ValueError("invalid rating, please fix")

def fetch_ratings():
    trains = defaultdict(list)
    validations = defaultdict(list)
    csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
    for meme1, meme2, rating, iteration in csr.fetchall():
        (validations if is_val_set(meme1, meme2) else trains)[int(iteration or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
    csr.close()
    return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))

def generate_random_permutations(x, n):
    out = []
    for _ in range(n):
        random.shuffle(x)
        out.append(x.copy())
    return out

def fetch_all_files():
    csr = db.execute("SELECT filename, embedding_vector FROM files")
    x = [ (row[0], numpy.frombuffer(row[1], dtype="float16").copy()) for row in csr.fetchall() ]
    csr.close()
    return x

def checkpoint_for(steps):
    return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt"
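shared.py expects a data.sqlite3 with a files table holding each meme's filename and its embedding as a raw float16 blob (1152-dimensional, matching d_emb in the configs above), plus the ratings and queue tables from the server schema earlier. The train/validation split is deterministic per meme (first byte of a SHA-256 of the filename), so a pair goes to validation if either of its memes does. A rough sketch of the expected files format, with hypothetical data:

import sqlite3
import numpy

# Hypothetical database setup; column names match the queries in shared.py.
db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE files (filename TEXT PRIMARY KEY, embedding_vector BLOB)")
emb = numpy.random.rand(1152).astype("float16")
db.execute("INSERT INTO files VALUES (?, ?)", ("example.png", emb.tobytes()))

# Round-trip the embedding the same way fetch_embedding does.
blob = db.execute("SELECT embedding_vector FROM files WHERE filename = ?", ("example.png",)).fetchone()[0]
print(numpy.frombuffer(blob, dtype="float16").shape)   # (1152,)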
meme-rater/train.py (new file)
@@ -0,0 +1,129 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
import math
from dataclasses import dataclass, asdict

from model import Config as ModelConfig, BradleyTerry
import shared

trains, validations = shared.fetch_ratings()
for train, validation in zip(trains, validations):
    print(len(train), len(validation))

device = "cuda"

@dataclass
class TrainConfig:
    model: ModelConfig
    lr: float
    weight_decay: float
    batch_size: int
    epochs: int
    compile: bool
    data_grouped_by_iter: bool

config = TrainConfig(
    model=ModelConfig(
        d_emb=1152,
        n_hidden=1,
        n_ensemble=16,
        device=device,
        dtype=torch.float32,
        dropout=0.1
    ),
    lr=3e-4,
    weight_decay=0.2,
    batch_size=1,
    epochs=5,
    compile=False,
    data_grouped_by_iter=False
)

def exprange(min, max, n):
    lmin, lmax = math.log(min), math.log(max)
    step = (lmax - lmin) / (n - 1)
    return (math.exp(lmin + step * i) for i in range(n))

model = BradleyTerry(config.model)
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

def train_step(model, batch, real):
    optimizer.zero_grad()
    # model batch
    win_probabilities = model(batch).float()
    loss = F.binary_cross_entropy(win_probabilities, real)
    loss.backward()
    optimizer.step()
    loss = loss.detach().cpu().item()
    return loss

if config.compile:
    print("compiling...")
    train_step = torch.compile(train_step)

def batch_from_inputs(inputs: list[tuple[numpy.ndarray, numpy.ndarray, float]]):
    batch_input = torch.stack([
        torch.stack([ torch.stack((torch.Tensor(emb1).to(config.model.dtype), torch.Tensor(emb2).to(config.model.dtype))) for emb1, emb2, rating in input ])
        for input in inputs
    ]).to(device)
    target = torch.stack([ torch.Tensor([ rating for emb1, emb2, rating in input ]) for input in inputs ]).to(device)
    return batch_input, target

def evaluate(steps):
    print("evaluating...")
    model.eval()
    results = {"step": steps, "time": time.time(), "val_loss": {}}
    for vset, validation in enumerate(validations):
        with torch.no_grad():
            batch_input, target = batch_from_inputs([ validation[:128] for _ in range(config.model.n_ensemble) ])
            result = model(batch_input).float()
            val_loss = F.binary_cross_entropy(result, target).detach().cpu().item()
        model.train()
        results["val_loss"][vset] = val_loss
    log.write(json.dumps(results) + "\n")

def save_ckpt(log, steps):
    print("saving...")
    modelc, optimc = shared.checkpoint_for(steps)
    torch.save(optimizer.state_dict(), optimc)
    torch.save(model.state_dict(), modelc)

class JSONEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, torch.dtype):
            return str(o)
        else: return super().default(o)

logfile = f"logs/log-{time.time()}.jsonl"
with open(logfile, "w") as log:
    steps = 0
    log.write(JSONEncoder().encode(asdict(config)) + "\n")
    for epoch in range(config.epochs):
        for train in (trains if config.data_grouped_by_iter else [[ sample for trainss in trains for sample in trainss ]]):
            data_orders = shared.generate_random_permutations(train, config.model.n_ensemble)
            for bstart in range(0, len(train), config.batch_size):
                batch_input, target = batch_from_inputs([ order[bstart:bstart + config.batch_size] for order in data_orders ])
                loss = train_step(model, batch_input, target)
                print(steps, loss)
                log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
                if steps % 10 == 0:
                    if steps % 250 == 0: save_ckpt(log, steps)
                    loss = evaluate(steps)
                    #print(loss)
                    #best = min(loss, best)
                steps += 1

    save_ckpt(log, steps)

print(logfile)
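Training notes: each rating becomes a soft target (0.95 or 0.05 via shared.map_rating), and the 16 ensemble members are trained in one forward pass on the (n_ensemble, batch, 2, d_emb) tensors built by batch_from_inputs, with each member fed its own independently shuffled copy of the data so the heads decorrelate; checkpoints land in ./ckpt/ where eval.py and the active-learning scripts pick them up via shared.checkpoint_for. The layout contract in miniature (toy sizes):

import torch

# Toy sizes; the committed config uses n_ensemble=16, batch_size=1, d_emb=1152.
n_ensemble, batch, d_emb = 3, 2, 8
batch_input = torch.randn(n_ensemble, batch, 2, d_emb)  # member i only ever sees slice i
target = torch.full((n_ensemble, batch), 0.95)          # soft win labels from map_rating

print(batch_input[:, :, 0].shape)   # (3, 2, 8): the "meme1" embeddings per member per pair
print(target.shape)                 # (3, 2): one win-probability target per member per pair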