diff --git a/meme-rater/active_learning.py b/meme-rater/active_learning.py
new file mode 100644
index 0000000..fae847c
--- /dev/null
+++ b/meme-rater/active_learning.py
@@ -0,0 +1,55 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import sqlite3
+import random
+import numpy
+import json
+import time
+from tqdm import tqdm
+
+from model import Config, BradleyTerry
+import shared
+
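+# Active learning: sample random meme pairs and surface the ones the ensemble
+# disagrees about most (highest variance in predicted win probability) for
+# human rating.
+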
+batch_size = 128
+num_pairs = batch_size * 1024
+device = "cuda"
+
+config = Config(
+ d_emb=1152,
+ n_hidden=1,
+ n_ensemble=16,
+ device=device,
+ dtype=torch.bfloat16,
+ dropout=0.5
+)
+model = BradleyTerry(config)
+modelc, _ = shared.checkpoint_for(2250)
+model.load_state_dict(torch.load(modelc))
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+files = shared.fetch_all_files()
+variance = {}
+
+pairs = []
+for _ in range(num_pairs):
+ pairs.append(tuple(random.sample(files, 2)))
+
+model.eval()
+with torch.inference_mode():
+ for bstart in tqdm(range(0, len(pairs), batch_size)):
+ batch = pairs[bstart:bstart + batch_size]
+ filenames = [ (f1, f2) for ((f1, e1), (f2, e2)) in batch ]
+        embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
+        inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), 2, config.d_emb)).to(device) # len(batch) rather than batch_size, in case the final batch is short
+ win_probs = model(inputs)
+        # disagreement between ensemble members: variance over the model dimension
+        batchvar = torch.var(win_probs, dim=0)
+ for filename, var in zip(filenames, batchvar):
+ variance[filename] = float(var)
+
+top = sorted(variance.items(), key=lambda x: -x[1])
+with open("top.json", "w") as f:
+ json.dump(top[:256], f)
\ No newline at end of file
diff --git a/meme-rater/al2.py b/meme-rater/al2.py
new file mode 100644
index 0000000..2410fa4
--- /dev/null
+++ b/meme-rater/al2.py
@@ -0,0 +1,76 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import sqlite3
+import random
+import numpy
+import json
+import time
+from tqdm import tqdm
+from torch.func import functional_call, vmap, grad
+
+from model import Config, BradleyTerry
+import shared
+
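+# Active learning, round 2: instead of ensemble disagreement, rank candidate
+# pairs by how large a gradient step an assumed label would induce (an
+# expected-gradient-length-style acquisition heuristic).
+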
+steps = 855
+batch_size = 128
+num_pairs = batch_size * 1024
+device = "cuda"
+
+config = Config(
+ d_emb=1152,
+ n_hidden=1,
+ n_ensemble=1,
+ device=device,
+ dtype=torch.bfloat16
+)
+model = BradleyTerry(config)
+modelc, _ = shared.checkpoint_for(steps)
+model.load_state_dict(torch.load(modelc))
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+files = shared.fetch_all_files()
+importance = {}
+
+params = {k: v.detach() for k, v in model.named_parameters()}
+buffers = {k: v.detach() for k, v in model.named_buffers()}
+
+# https://pytorch.org/tutorials/intermediate/per_sample_grads.html
+def compute_loss(params, buffers, sample, target):
+ batch = sample.unsqueeze(0)
+ targets = target.unsqueeze(0)
+
+ predictions = functional_call(model, (params, buffers), (batch,))
+ loss = F.binary_cross_entropy(predictions, targets)
+ return loss
+
+ft_compute_grad = grad(compute_loss)
+ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 1, 1))
+
+pairs = []
+for _ in range(num_pairs):
+ pairs.append(tuple(random.sample(files, 2)))
+
+for bstart in tqdm(range(0, len(pairs), batch_size)):
+ batch = pairs[bstart:bstart + batch_size]
+ filenames = [ (f1, f2) for ((f1, e1), (f2, e2)) in batch ]
+ embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
+    inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), 2, config.d_emb)).to(device) # len(batch) in case the final batch is short
+    # don't take variance: do a per-sample backward pass and score each pair by
+    # the gradient norm an assumed "meme 1 wins" label of 0.95 would induce
+    grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, len(batch)), 0.95).to(device))
+    total_grad_norms = torch.zeros(len(batch)).to(device)
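+    # summing per-tensor norms is a cheap upper bound on the norm of the full
+    # concatenated gradient; fine for ranking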
+ for k, v in grads.items():
+ param_dims = tuple(range(1, len(v.shape)))
+ total_grad_norms += torch.linalg.vector_norm(v, dim=param_dims)
+ tgn = total_grad_norms.cpu().numpy()
+
+ for filename, tg in zip(filenames, tgn):
+ importance[filename] = float(tg)
+
+top = sorted(importance.items(), key=lambda x: -x[1])
+with open("top.json", "w") as f:
+ json.dump(top[:256], f)
\ No newline at end of file
diff --git a/meme-rater/copy_into_queue.py b/meme-rater/copy_into_queue.py
new file mode 100644
index 0000000..b6bd084
--- /dev/null
+++ b/meme-rater/copy_into_queue.py
@@ -0,0 +1,14 @@
+import sqlite3
+import json
+import sys
+
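+# Load the pairs selected by active learning (top.json) into the rating queue,
+# tagged with the iteration number passed as argv[1].
+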
+iteration = sys.argv[1]
+
+db = sqlite3.connect("data.sqlite3")
+db.row_factory = sqlite3.Row
+
+with open("top.json", "r") as f:
+ listing = json.load(f)
+
+db.executemany("INSERT INTO queue VALUES (?, ?, ?)", [ (x[0], x[1], iteration) for x, v in listing ])
+db.commit()
\ No newline at end of file
diff --git a/meme-rater/eval.py b/meme-rater/eval.py
new file mode 100644
index 0000000..e35be1d
--- /dev/null
+++ b/meme-rater/eval.py
@@ -0,0 +1,86 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import sqlite3
+import random
+import numpy
+import json
+import time
+from tqdm import tqdm
+
+from model import Config, BradleyTerry
+import shared
+
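+# Score every meme with the trained ensemble (median over members) and render
+# samples at selected percentiles into eval.html for manual review.
+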
+batch_size = 128
+device = "cuda"
+
+config = Config(
+ d_emb=1152,
+ n_hidden=1,
+ n_ensemble=16,
+ device=device,
+ dtype=torch.float32,
+ dropout=0.1
+)
+model = BradleyTerry(config).float()
+modelc, _ = shared.checkpoint_for(1500)
+model.load_state_dict(torch.load(modelc))
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+files = shared.fetch_all_files()
+ratings = {}
+
+model.eval()
+with torch.inference_mode():
+ for bstart in tqdm(range(0, len(files), batch_size)):
+ batch = files[bstart:bstart + batch_size]
+ filenames = [ filename for filename, embedding in batch ]
+ embs = torch.stack([ torch.Tensor(embedding) for filename, embedding in batch ])
+ inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
+ scores = model.ensemble(inputs).float()
+ mscores = torch.median(scores, dim=0).values
+ for filename, mscore in zip(filenames, mscores):
+ ratings[filename] = float(mscore)
+
+ratings = sorted(ratings.items(), key=lambda x: x[1])
+
+def percentile(p, n):
+ base = round(p * len(ratings))
+ return ratings[base:base + n]
+
+N = 25
+def render_memeset(p):
+ filenames = percentile(p, N)
+ return f"""
+
+
Reveal Memeset
{p}
+ {''.join(f'
' for i, (f, s) in enumerate(filenames))}
+
+"""
+
+buf = """"""
+probs = [0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 0.95, 0.98, 0.99]
+random.shuffle(probs)
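+# shuffled so the reviewer doesn't know which percentile a memeset is from until revealed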
+for p in probs:
+#for p in [0.3]:
+ buf += render_memeset(p)
+
+buf += """
+
+"""
+
+with open("eval.html", "w") as f:
+ f.write(buf)
\ No newline at end of file
diff --git a/meme-rater/final_eval_results.py b/meme-rater/final_eval_results.py
new file mode 100644
index 0000000..aed2262
--- /dev/null
+++ b/meme-rater/final_eval_results.py
@@ -0,0 +1,33 @@
+import matplotlib.pyplot as plt
+import json
+
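+# Plot how many of the 25 memes shown at each model-predicted percentile the
+# reviewer marked as worth keeping (counts taken from the eval.html session).
+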
+# Data as a JSON string
+data_json = '{"0.95":22,"0.75":21,"0.5":15,"0.98":23,"0.25":3,"0.05":0,"0.99":24,"0.1":2,"0.01":0,"0.02":0}'
+
+# Parse the JSON string into a dictionary
+data = json.loads(data_json)
+
+# Extract the keys and values from the dictionary
+keys = list(data.keys())
+values = list(data.values())
+
+# Convert the keys to floats
+keys = [float(key) for key in keys]
+
+# Sort the keys and values based on the keys
+sorted_data = sorted(zip(keys, values))
+keys, values = zip(*sorted_data)
+
+plt.plot(keys, values)
+
+# Set the x-axis tick labels
+plt.xticks(keys, rotation=45)
+
+# Add labels and title
+plt.xlabel('Percentile')
+plt.ylabel('Memes Kept')
+plt.title('Final Model Evaluation')
+
+# Display the plot
+plt.tight_layout()
+plt.show()
diff --git a/meme-rater/model.py b/meme-rater/model.py
new file mode 100644
index 0000000..1487ea3
--- /dev/null
+++ b/meme-rater/model.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from dataclasses import dataclass
+from functools import partial
+import math
+
+@dataclass
+class Config:
+ d_emb: int
+ n_hidden: int
+ n_ensemble: int
+ device: str
+ dtype: torch.dtype
+ dropout: float
+
+class Model(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ])
+ self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ])
+ self.output = nn.Linear(config.d_emb, 1, dtype=config.dtype, device=config.device)
+
+ def forward(self, embs):
+ x = embs
+ for (layer, dropout) in zip(self.hidden, self.dropout):
+ x = F.silu(layer(dropout(x)))
+ return self.output(x)
+
+class Ensemble(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.models = nn.ModuleList([ Model(config) for i in range(config.n_ensemble) ])
+
+    # embs: (n_ensemble, batch, d_emb); member i scores its own slice embs[i]
+    def forward(self, embs):
+        xs = torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # (n_ensemble, batch, 1)
+        return xs.squeeze(-1) # (n_ensemble, batch)
+
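+# Bradley-Terry pairwise preference model: every ensemble member scores both
+# memes, and P(meme 1 wins) = sigmoid(score1 - score2).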
+class BradleyTerry(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.ensemble = Ensemble(config)
+
+    def forward(self, embs): # embs: (n_ensemble, batch, 2, d_emb)
+        scores1 = self.ensemble(embs[:, :, 0]).float() # (n_ensemble, batch)
+        scores2 = self.ensemble(embs[:, :, 1]).float()
+        # probability that meme 1 beats meme 2, shape (n_ensemble, batch)
+        probs = torch.sigmoid(scores1 - scores2)
+        return probs
\ No newline at end of file
diff --git a/meme-rater/rater_server.py b/meme-rater/rater_server.py
index 5f41cab..f003482 100644
--- a/meme-rater/rater_server.py
+++ b/meme-rater/rater_server.py
@@ -11,15 +11,22 @@ routes = web.RouteTableDef()
async def get_pair(db):
while True:
- filenames = [ x[0] for x in await db.execute_fetchall("SELECT filename FROM files", ()) ]
- m1, m2 = tuple(sorted(random.sample(filenames, 2)))
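+        # serve pairs queued by active learning first; fall back to random sampling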
+ csr = await db.execute("SELECT * FROM queue")
+ row = await csr.fetchone()
+ await csr.close()
+ iteration = None
+ if row:
+ m1, m2, iteration = row
+ else:
+ filenames = [ x[0] for x in await db.execute_fetchall("SELECT filename FROM files", ()) ]
+ m1, m2 = tuple(sorted(random.sample(filenames, 2)))
csr = await db.execute("SELECT 1 FROM ratings WHERE meme1 = ? AND meme2 = ?", (m1, m2))
if not await csr.fetchone():
- return m1, m2
+ return m1, m2, iteration
@routes.get("/")
async def index(request):
- meme1, meme2 = await get_pair(request.app["db"])
+ meme1, meme2, iteration = await get_pair(request.app["db"])
return web.Response(text=f"""
@@ -46,6 +53,7 @@ async def index(request):
+            <input type="hidden" name="iteration" value="{iteration}">
@@ -81,8 +89,10 @@ async def rate(request):
post = await request.post()
meme1 = post["meme1"]
meme2 = post["meme2"]
+ iteration = post["iteration"]
rating = post["rating"]
- await db.execute("INSERT INTO ratings (meme1, meme2, rating) VALUES (?, ?, ?)", (meme1, meme2, rating))
+ await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration) VALUES (?, ?, ?, ?)", (meme1, meme2, rating, iteration))
+ await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2))
await db.commit()
return web.HTTPFound("/")
@@ -93,8 +103,15 @@ CREATE TABLE IF NOT EXISTS ratings (
meme1 TEXT NOT NULL,
meme2 TEXT NOT NULL,
rating TEXT NOT NULL,
+ iteration TEXT,
UNIQUE (meme1, meme2)
);
+CREATE TABLE IF NOT EXISTS queue (
+ meme1 TEXT NOT NULL,
+ meme2 TEXT NOT NULL,
+ iteration TEXT NOT NULL,
+ UNIQUE (meme1, meme2, iteration)
+);
""")
app.router.add_routes(routes)
app.router.add_static("/memes/", "./images")
diff --git a/meme-rater/run_graph.py b/meme-rater/run_graph.py
new file mode 100644
index 0000000..0f152b5
--- /dev/null
+++ b/meme-rater/run_graph.py
@@ -0,0 +1,52 @@
+# claude-3
+
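+# Plot the training loss and per-validation-set losses (rolling averages) from
+# the JSONL log written by train.py.
+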
+import json
+import matplotlib.pyplot as plt
+
+# Read data from log.jsonl
+data = []
+with open('log.jsonl', 'r') as file:
+ for line in file:
+ data.append(json.loads(line))
+
+# Extract steps, loss, and val_loss
+steps = [entry['step'] for entry in data if "loss" in entry]
+loss = [entry['loss'] for entry in data if "loss" in entry]
+val_loss_data = [entry['val_loss'] for entry in data if 'val_loss' in entry]
+val_steps = [entry['step'] for entry in data if 'val_loss' in entry]
+
+# Extract individual validation loss series
+val_loss_series = {}
+for val_loss in val_loss_data:
+ for key, value in val_loss.items():
+ if key not in val_loss_series:
+ val_loss_series[key] = []
+ val_loss_series[key].append(value)
+
+# Calculate rolling average for loss
+window_size = 50
+rolling_avg = [sum(loss[i:i+window_size])/window_size for i in range(len(loss)-window_size+1)]
+rolling_steps = steps[window_size-1:]
+
+# Calculate rolling averages for validation loss series
+val_rolling_avgs = {}
+for key, series in val_loss_series.items():
+ val_rolling_avgs[key] = [sum(series[i:i+window_size])/window_size for i in range(len(series)-window_size+1)]
+
+print([(name, min(series)) for name, series in val_loss_series.items()])
+
+# Create the plot
+plt.figure(figsize=(10, 6))
+#plt.plot(steps, loss, label='Loss')
+plt.plot(rolling_steps, rolling_avg, label='Rolling Average (Loss)')
+
+for key, series in val_loss_series.items():
+ #plt.plot(val_steps, series, marker='o', linestyle='', label=f'Validation Loss ({key})')
+ plt.plot(val_steps[window_size-1:], val_rolling_avgs[key], label=f'Rolling Average (Validation Loss {key})')
+
+plt.xlabel('Steps')
+plt.ylabel('Loss')
+plt.title('Loss and Validation Loss vs. Steps')
+plt.legend()
+plt.grid(True)
+plt.show()
diff --git a/meme-rater/shared.py b/meme-rater/shared.py
new file mode 100644
index 0000000..a8398d7
--- /dev/null
+++ b/meme-rater/shared.py
@@ -0,0 +1,53 @@
+import sqlite3
+import hashlib
+from collections import defaultdict
+import numpy
+import random
+
+db = sqlite3.connect("data.sqlite3")
+db.row_factory = sqlite3.Row
+
+val_fraction = 0.2
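+# Validation membership is a pure function of the filename hash, so the split
+# stays stable across runs and as new ratings arrive.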
+def is_val_set(meme1, meme2):
+ def is_one_val(meme):
+ return hashlib.sha256(meme.encode("utf-8")).digest()[0] / 255 < (val_fraction / 2) # not strictly correct but good enough
+ return is_one_val(meme1) or is_one_val(meme2)
+
+def fetch_embedding(filename):
+ csr = db.execute("SELECT embedding_vector FROM files WHERE filename = ?", (filename,))
+ x = numpy.frombuffer(csr.fetchone()[0], dtype="float16")
+ csr.close()
+    return x.copy() # copy because numpy.frombuffer returns a read-only array, which PyTorch complains about
+
+def map_rating(rating, uncertainty=0.05):
+ match rating:
+ case "1": # meme 1 is better
+ return 1 - uncertainty
+ case "2":
+ return uncertainty
+ case _: raise ValueError("invalid rating, please fix")
+
+def fetch_ratings():
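+    # (emb1, emb2, target) triples grouped by active-learning iteration, split into train/val by filename hash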
+ trains = defaultdict(list)
+ validations = defaultdict(list)
+ csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
+ for meme1, meme2, rating, iteration in csr.fetchall():
+ (validations if is_val_set(meme1, meme2) else trains)[int(iteration or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
+ csr.close()
+ return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))
+
+def generate_random_permutations(x, n):
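+    # n independently shuffled copies of x (note: shuffles x in place as a side effect)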
+ out = []
+ for _ in range(n):
+ random.shuffle(x)
+ out.append(x.copy())
+ return out
+
+def fetch_all_files():
+ csr = db.execute("SELECT filename, embedding_vector FROM files")
+ x = [ (row[0], numpy.frombuffer(row[1], dtype="float16").copy()) for row in csr.fetchall() ]
+ csr.close()
+ return x
+
+def checkpoint_for(steps):
+ return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt"
\ No newline at end of file
diff --git a/meme-rater/train.py b/meme-rater/train.py
new file mode 100644
index 0000000..0acde88
--- /dev/null
+++ b/meme-rater/train.py
@@ -0,0 +1,129 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import sqlite3
+import random
+import numpy
+import json
+import time
+from tqdm import tqdm
+import math
+from dataclasses import dataclass, asdict
+
+from model import Config as ModelConfig, BradleyTerry
+import shared
+
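+# Train the Bradley-Terry ensemble on the human pairwise ratings; each member
+# is trained on its own random ordering of the same data.
+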
+trains, validations = shared.fetch_ratings()
+for train, validation in zip(trains, validations):
+ print(len(train), len(validation))
+
+device = "cuda"
+
+@dataclass
+class TrainConfig:
+ model: ModelConfig
+ lr: float
+ weight_decay: float
+ batch_size: int
+ epochs: int
+ compile: bool
+ data_grouped_by_iter: bool
+
+config = TrainConfig(
+ model=ModelConfig(
+ d_emb=1152,
+ n_hidden=1,
+ n_ensemble=16,
+ device=device,
+ dtype=torch.float32,
+ dropout=0.1
+ ),
+ lr=3e-4,
+ weight_decay=0.2,
+ batch_size=1,
+ epochs=5,
+ compile=False,
+ data_grouped_by_iter=False
+)
+
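+# n log-spaced values between min and max (sweep helper; unused in this run)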
+def exprange(min, max, n):
+ lmin, lmax = math.log(min), math.log(max)
+ step = (lmax - lmin) / (n - 1)
+ return (math.exp(lmin + step * i) for i in range(n))
+
+model = BradleyTerry(config.model)
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
+
+def train_step(model, batch, real):
+ optimizer.zero_grad()
+ # model batch
+ win_probabilities = model(batch).float()
+ loss = F.binary_cross_entropy(win_probabilities, real)
+ loss.backward()
+ optimizer.step()
+ loss = loss.detach().cpu().item()
+ return loss
+
+if config.compile:
+ print("compiling...")
+ train_step = torch.compile(train_step)
+
+def batch_from_inputs(inputs: list[tuple[numpy.ndarray, numpy.ndarray, float]]):
+ batch_input = torch.stack([
+ torch.stack([ torch.stack((torch.Tensor(emb1).to(config.model.dtype), torch.Tensor(emb2).to(config.model.dtype))) for emb1, emb2, rating in input ])
+ for input in inputs
+ ]).to(device)
+ target = torch.stack([ torch.Tensor([ rating for emb1, emb2, rating in input ]) for input in inputs ]).to(device)
+ return batch_input, target
+
+def evaluate(steps):
+ print("evaluating...")
+ model.eval()
+ results = {"step": steps, "time": time.time(), "val_loss": {}}
+ for vset, validation in enumerate(validations):
+ with torch.no_grad():
+ batch_input, target = batch_from_inputs([ validation[:128] for _ in range(config.model.n_ensemble) ])
+ result = model(batch_input).float()
+ val_loss = F.binary_cross_entropy(result, target).detach().cpu().item()
+        results["val_loss"][vset] = val_loss
+    model.train() # restore train mode (dropout) only after all validation sets are scored
+ log.write(json.dumps(results) + "\n")
+
+def save_ckpt(log, steps):
+ print("saving...")
+ modelc, optimc = shared.checkpoint_for(steps)
+ torch.save(optimizer.state_dict(), optimc)
+ torch.save(model.state_dict(), modelc)
+
+class JSONEncoder(json.JSONEncoder):
+ def default(self, o):
+ if isinstance(o, torch.dtype):
+ return str(o)
+ else: return super().default(o)
+
+logfile = f"logs/log-{time.time()}.jsonl"
+with open(logfile, "w") as log:
+ steps = 0
+ log.write(JSONEncoder().encode(asdict(config)) + "\n")
+ for epoch in range(config.epochs):
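+        # either one pass per active-learning round, or all rounds flattened into a single pool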
+ for train in (trains if config.data_grouped_by_iter else [[ sample for trainss in trains for sample in trainss ]]):
+ data_orders = shared.generate_random_permutations(train, config.model.n_ensemble)
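+            # one independent shuffle per ensemble member, so members see differently-ordered batches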
+ for bstart in range(0, len(train), config.batch_size):
+ batch_input, target = batch_from_inputs([ order[bstart:bstart + config.batch_size] for order in data_orders ])
+ loss = train_step(model, batch_input, target)
+ print(steps, loss)
+ log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
+ if steps % 10 == 0:
+ if steps % 250 == 0: save_ckpt(log, steps)
+                evaluate(steps) # writes val losses to the JSONL log
+ steps += 1
+
+ save_ckpt(log, steps)
+
+print(logfile)
\ No newline at end of file