diff --git a/meme-rater/active_learning.py b/meme-rater/active_learning.py
new file mode 100644
index 0000000..fae847c
--- /dev/null
+++ b/meme-rater/active_learning.py
@@ -0,0 +1,55 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import sqlite3
+import random
+import numpy
+import json
+import time
+from tqdm import tqdm
+
+from model import Config, BradleyTerry
+import shared
+
+batch_size = 128
+num_pairs = batch_size * 1024
+device = "cuda"
+
+config = Config(
+	d_emb=1152,
+	n_hidden=1,
+	n_ensemble=16,
+	device=device,
+	dtype=torch.bfloat16,
+	dropout=0.5
+)
+model = BradleyTerry(config)
+modelc, _ = shared.checkpoint_for(2250)
+model.load_state_dict(torch.load(modelc))
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+files = shared.fetch_all_files()
+variance = {}
+
+pairs = []
+for _ in range(num_pairs):
+	pairs.append(tuple(random.sample(files, 2)))
+
+model.eval()
+with torch.inference_mode():
+	for bstart in tqdm(range(0, len(pairs), batch_size)):
+		batch = pairs[bstart:bstart + batch_size]
+		filenames = [ (f1, f2) for ((f1, e1), (f2, e2)) in batch ]
+		embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
+		inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
+		win_probs = model(inputs)
+		#print(win_probs.shape)
+		batchvar = torch.var(win_probs, dim=0)
+		for filename, var in zip(filenames, batchvar):
+			variance[filename] = float(var)
+
+top = sorted(variance.items(), key=lambda x: -x[1])
+with open("top.json", "w") as f:
+	json.dump(top[:256], f)
\ No newline at end of file
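active_learning.py is the first of two pair selectors, a query-by-committee rule: each of the 16 ensemble members predicts a win probability for every candidate pair, and the pairs the members disagree on most (largest variance across the ensemble dimension) are written to top.json; copy_into_queue.py then pushes them into the rating queue, and train.py folds the new labels back in. A minimal sketch of the same acquisition rule, with a toy linear ensemble standing in for BradleyTerry (all names illustrative; shapes follow the (n_ensemble, batch, 2, d_emb) convention above):

import torch

n_ensemble, batch, d_emb = 16, 4, 8
weights = torch.randn(n_ensemble, d_emb) # one toy scorer per ensemble member

def toy_win_probs(inputs):
	# score both memes in each pair with each member, sigmoid the difference
	scores = torch.einsum("nbpd,nd->nbp", inputs, weights)
	return torch.sigmoid(scores[..., 0] - scores[..., 1])

inputs = torch.randn(batch, 2, d_emb).unsqueeze(0).expand(n_ensemble, batch, 2, d_emb)
win_probs = toy_win_probs(inputs)          # (n_ensemble, batch)
disagreement = torch.var(win_probs, dim=0) # variance over ensemble members
print(disagreement.argsort(descending=True)) # most contested pairs first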
diff --git a/meme-rater/al2.py b/meme-rater/al2.py
new file mode 100644
index 0000000..2410fa4
--- /dev/null
+++ b/meme-rater/al2.py
@@ -0,0 +1,76 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import sqlite3
+import random
+import numpy
+import json
+import time
+from tqdm import tqdm
+from torch.func import functional_call, vmap, grad
+
+from model import Config, BradleyTerry
+import shared
+
+steps = 855
+batch_size = 128
+num_pairs = batch_size * 1024
+device = "cuda"
+
+config = Config(
+	d_emb=1152,
+	n_hidden=1,
+	n_ensemble=1,
+	device=device,
+	dtype=torch.bfloat16,
+	dropout=0.0 # Config has no default for this; dropout is irrelevant when we only measure gradients
+)
+model = BradleyTerry(config)
+modelc, _ = shared.checkpoint_for(steps)
+model.load_state_dict(torch.load(modelc))
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+files = shared.fetch_all_files()
+importance = {}
+
+params = {k: v.detach() for k, v in model.named_parameters()}
+buffers = {k: v.detach() for k, v in model.named_buffers()}
+
+# https://pytorch.org/tutorials/intermediate/per_sample_grads.html
+def compute_loss(params, buffers, sample, target):
+	batch = sample.unsqueeze(0)
+	targets = target.unsqueeze(0)
+
+	predictions = functional_call(model, (params, buffers), (batch,))
+	loss = F.binary_cross_entropy(predictions, targets)
+	return loss
+
+ft_compute_grad = grad(compute_loss)
+ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 1, 1))
+
+pairs = []
+for _ in range(num_pairs):
+	pairs.append(tuple(random.sample(files, 2)))
+
+for bstart in tqdm(range(0, len(pairs), batch_size)):
+	batch = pairs[bstart:bstart + batch_size]
+	filenames = [ (f1, f2) for ((f1, e1), (f2, e2)) in batch ]
+	embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
+	inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
+	# unlike active_learning.py, don't take ensemble variance: compute the
+	# per-sample gradient norm against an assumed "meme 1 wins" soft label
+	grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size), 0.95).to(device))
+	total_grad_norms = torch.zeros(batch_size).to(device)
+	for k, v in grads.items():
+		param_dims = tuple(range(1, len(v.shape)))
+		total_grad_norms += torch.linalg.vector_norm(v, dim=param_dims)
+	tgn = total_grad_norms.cpu().numpy()
+
+	for filename, tg in zip(filenames, tgn):
+		importance[filename] = float(tg)
+
+top = sorted(importance.items(), key=lambda x: -x[1])
+with open("top.json", "w") as f:
+	json.dump(top[:256], f)
\ No newline at end of file
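al2.py replaces the disagreement heuristic with expected gradient length: assume meme 1 wins with the same 0.95 soft label used in training, take the per-sample gradient of the loss under that assumption, and rank pairs by total gradient norm (a large norm means the label would move the model a lot). The vmap/grad construction is the standard torch.func per-sample-gradients recipe from the tutorial linked in the comment; a reduced sketch on a plain linear scorer (all names illustrative):

import torch
import torch.nn.functional as F
from torch.func import functional_call, vmap, grad

model = torch.nn.Linear(8, 1)
params = {k: v.detach() for k, v in model.named_parameters()}

def compute_loss(params, sample, target):
	# functional_call runs the module with externally supplied parameters,
	# which is what lets grad differentiate w.r.t. them per sample
	pred = torch.sigmoid(functional_call(model, params, (sample.unsqueeze(0),)))
	return F.binary_cross_entropy(pred.squeeze(), target)

per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))

samples = torch.randn(32, 8)
targets = torch.full((32,), 0.95) # assumed label, as in al2.py
grads = per_sample_grads(params, samples, targets)

norms = torch.zeros(32)
for v in grads.values():
	norms += torch.linalg.vector_norm(v.flatten(1), dim=1)
print(norms.argsort(descending=True)) # largest expected update first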
diff --git a/meme-rater/copy_into_queue.py b/meme-rater/copy_into_queue.py
new file mode 100644
index 0000000..b6bd084
--- /dev/null
+++ b/meme-rater/copy_into_queue.py
@@ -0,0 +1,14 @@
+import sqlite3
+import json
+import sys
+
+iteration = sys.argv[1]
+
+db = sqlite3.connect("data.sqlite3")
+db.row_factory = sqlite3.Row
+
+with open("top.json", "r") as f:
+	listing = json.load(f)
+
+db.executemany("INSERT INTO queue VALUES (?, ?, ?)", [ (x[0], x[1], iteration) for x, v in listing ])
+db.commit()
\ No newline at end of file
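top.json, as written by either selector, is a list of ([meme1, meme2], score) pairs in descending score order; only the filenames and the command-line iteration tag survive into the queue table. The shape it expects, with made-up filenames:

import json
listing = json.loads('[[["a.png", "b.png"], 0.241], [["c.png", "d.png"], 0.198]]')
iteration = "3" # sys.argv[1] in copy_into_queue.py
rows = [ (x[0], x[1], iteration) for x, v in listing ]
print(rows) # [('a.png', 'b.png', '3'), ('c.png', 'd.png', '3')]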
diff --git a/meme-rater/eval.py b/meme-rater/eval.py
new file mode 100644
index 0000000..e35be1d
--- /dev/null
+++ b/meme-rater/eval.py
@@ -0,0 +1,86 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import sqlite3
+import random
+import numpy
+import json
+import time
+from tqdm import tqdm
+
+from model import Config, BradleyTerry
+import shared
+
+batch_size = 128
+device = "cuda"
+
+config = Config(
+	d_emb=1152,
+	n_hidden=1,
+	n_ensemble=16,
+	device=device,
+	dtype=torch.float32,
+	dropout=0.1
+)
+model = BradleyTerry(config).float()
+modelc, _ = shared.checkpoint_for(1500)
+model.load_state_dict(torch.load(modelc))
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+files = shared.fetch_all_files()
+ratings = {}
+
+model.eval()
+with torch.inference_mode():
+	for bstart in tqdm(range(0, len(files), batch_size)):
+		batch = files[bstart:bstart + batch_size]
+		filenames = [ filename for filename, embedding in batch ]
+		embs = torch.stack([ torch.Tensor(embedding) for filename, embedding in batch ])
+		inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
+		scores = model.ensemble(inputs).float()
+		mscores = torch.median(scores, dim=0).values
+		for filename, mscore in zip(filenames, mscores):
+			ratings[filename] = float(mscore)
+
+ratings = sorted(ratings.items(), key=lambda x: x[1])
+
+def percentile(p, n):
+	base = round(p * len(ratings))
+	return ratings[base:base + n]
+
+N = 25
+def render_memeset(p):
+	filenames = percentile(p, N)
+	return f"""
+<details>
+	<summary>Reveal Memeset {p}</summary>
+	{''.join(f'<img src="images/{f}" loading="lazy">' for i, (f, s) in enumerate(filenames))}
+</details>
+"""
+
+buf = """"""
+probs = [0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 0.95, 0.98, 0.99]
+random.shuffle(probs)
+for p in probs:
+#for p in [0.3]:
+	buf += render_memeset(p)
+
+buf += """
+"""
+
+with open("eval.html", "w") as f:
+	f.write(buf)
\ No newline at end of file
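eval.py ranks every file by the median of the 16 ensemble scores rather than the mean, and eval.html shows N = 25 files at each percentile (the percentile blocks are rendered in shuffled order) so the ranking can be sanity-checked by eye. The median is the safer aggregate because one badly calibrated member barely moves it; a toy demonstration:

import torch
scores = torch.tensor([[0.80, 0.85, 0.90, 1.00],   # member scores for one meme
                       [0.80, 0.85, 0.90, 9.00]])  # same, one member way off
print(scores.mean(dim=1))          # tensor([0.8875, 2.8875]) -- dragged by the outlier
print(scores.median(dim=1).values) # tensor([0.8500, 0.8500]) -- unchanged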
+""" + +buf = """""" +probs = [0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 0.95, 0.98, 0.99] +random.shuffle(probs) +for p in probs: +#for p in [0.3]: + buf += render_memeset(p) + +buf += """ + +""" + +with open("eval.html", "w") as f: + f.write(buf) \ No newline at end of file diff --git a/meme-rater/final_eval_results.py b/meme-rater/final_eval_results.py new file mode 100644 index 0000000..aed2262 --- /dev/null +++ b/meme-rater/final_eval_results.py @@ -0,0 +1,33 @@ +import matplotlib.pyplot as plt +import json + +# Data as a JSON string +data_json = '{"0.95":22,"0.75":21,"0.5":15,"0.98":23,"0.25":3,"0.05":0,"0.99":24,"0.1":2,"0.01":0,"0.02":0}' + +# Parse the JSON string into a dictionary +data = json.loads(data_json) + +# Extract the keys and values from the dictionary +keys = list(data.keys()) +values = list(data.values()) + +# Convert the keys to floats +keys = [float(key) for key in keys] + +# Sort the keys and values based on the keys +sorted_data = sorted(zip(keys, values)) +keys, values = zip(*sorted_data) + +plt.plot(keys, values) + +# Set the x-axis tick labels +plt.xticks(keys, rotation=45) + +# Add labels and title +plt.xlabel('Percentile') +plt.ylabel('Memes Kept') +plt.title('Final Model Evaluation') + +# Display the plot +plt.tight_layout() +plt.show() diff --git a/meme-rater/model.py b/meme-rater/model.py new file mode 100644 index 0000000..1487ea3 --- /dev/null +++ b/meme-rater/model.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import dataclass +from functools import partial +import math + +@dataclass +class Config: + d_emb: int + n_hidden: int + n_ensemble: int + device: str + dtype: torch.dtype + dropout: float + +class Model(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ]) + self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ]) + self.output = nn.Linear(config.d_emb, 1, dtype=config.dtype, device=config.device) + + def forward(self, embs): + x = embs + for (layer, dropout) in zip(self.hidden, self.dropout): + x = F.silu(layer(dropout(x))) + return self.output(x) + +class Ensemble(nn.Module): + def __init__(self, config): + super().__init__() + self.models = nn.ModuleList([ Model(config) for i in range(config.n_ensemble) ]) + + # model batch + def forward(self, embs): + xs = torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1 + return xs.squeeze(-1) + +class BradleyTerry(nn.Module): + def __init__(self, config): + super().__init__() + self.ensemble = Ensemble(config) + + def forward(self, embs): # model batch input=2 d_emb + scores1 = self.ensemble(embs[:, :, 0]).float() # model batch + scores2 = self.ensemble(embs[:, :, 1]).float() + # win probabilities + #print(scores1, scores2) + probs = torch.sigmoid(scores1 - scores2) # model batch + #print(probs) + return probs \ No newline at end of file diff --git a/meme-rater/rater_server.py b/meme-rater/rater_server.py index 5f41cab..f003482 100644 --- a/meme-rater/rater_server.py +++ b/meme-rater/rater_server.py @@ -11,15 +11,22 @@ routes = web.RouteTableDef() async def get_pair(db): while True: - filenames = [ x[0] for x in await db.execute_fetchall("SELECT filename FROM files", ()) ] - m1, m2 = tuple(sorted(random.sample(filenames, 2))) + csr = await db.execute("SELECT * FROM queue") + row = await csr.fetchone() + await 
diff --git a/meme-rater/rater_server.py b/meme-rater/rater_server.py
index 5f41cab..f003482 100644
--- a/meme-rater/rater_server.py
+++ b/meme-rater/rater_server.py
@@ -11,15 +11,22 @@ routes = web.RouteTableDef()
 
 async def get_pair(db):
 	while True:
-		filenames = [ x[0] for x in await db.execute_fetchall("SELECT filename FROM files", ()) ]
-		m1, m2 = tuple(sorted(random.sample(filenames, 2)))
+		csr = await db.execute("SELECT * FROM queue")
+		row = await csr.fetchone()
+		await csr.close()
+		iteration = None
+		if row:
+			m1, m2, iteration = row
+		else:
+			filenames = [ x[0] for x in await db.execute_fetchall("SELECT filename FROM files", ()) ]
+			m1, m2 = tuple(sorted(random.sample(filenames, 2)))
 		csr = await db.execute("SELECT 1 FROM ratings WHERE meme1 = ? AND meme2 = ?", (m1, m2))
 		if not await csr.fetchone():
-			return m1, m2
+			return m1, m2, iteration
 
 @routes.get("/")
 async def index(request):
-	meme1, meme2 = await get_pair(request.app["db"])
+	meme1, meme2, iteration = await get_pair(request.app["db"])
 	return web.Response(text=f"""
@@ -46,6 +53,7 @@ async def index(request):
+			<input type="hidden" name="iteration" value="{iteration}">
@@ -81,8 +89,10 @@ async def rate(request):
 	post = await request.post()
 	meme1 = post["meme1"]
 	meme2 = post["meme2"]
+	iteration = post["iteration"]
 	rating = post["rating"]
-	await db.execute("INSERT INTO ratings (meme1, meme2, rating) VALUES (?, ?, ?)", (meme1, meme2, rating))
+	await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration) VALUES (?, ?, ?, ?)", (meme1, meme2, rating, iteration))
+	await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2))
 	await db.commit()
 	return web.HTTPFound("/")
@@ -93,8 +103,15 @@ CREATE TABLE IF NOT EXISTS ratings (
 	meme1 TEXT NOT NULL,
 	meme2 TEXT NOT NULL,
 	rating TEXT NOT NULL,
+	iteration TEXT,
 	UNIQUE (meme1, meme2)
 );
+CREATE TABLE IF NOT EXISTS queue (
+	meme1 TEXT NOT NULL,
+	meme2 TEXT NOT NULL,
+	iteration TEXT NOT NULL,
+	UNIQUE (meme1, meme2, iteration)
+);
 """)
 app.router.add_routes(routes)
 app.router.add_static("/memes/", "./images")
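run_graph.py (below) plots the log that train.py writes: a JSONL stream interleaving a config record, per-step training loss records, and periodic validation records holding one loss per validation subset, keyed by iteration group. Illustrative lines (values made up); note that train.py writes to logs/log-<timestamp>.jsonl while run_graph.py reads a fixed log.jsonl, so the file has to be copied or symlinked into place:

{"loss": 0.61, "step": 120, "time": 1700000000.0}
{"step": 120, "time": 1700000000.1, "val_loss": {"0": 0.58, "1": 0.63}}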
diff --git a/meme-rater/run_graph.py b/meme-rater/run_graph.py
new file mode 100644
index 0000000..0f152b5
--- /dev/null
+++ b/meme-rater/run_graph.py
@@ -0,0 +1,52 @@
+# claude-3
+
+import json
+import matplotlib.pyplot as plt
+
+# Read data from log.jsonl
+data = []
+with open('log.jsonl', 'r') as file:
+	for line in file:
+		data.append(json.loads(line))
+
+# Extract steps, loss, and val_loss
+steps = [entry['step'] for entry in data if "loss" in entry]
+loss = [entry['loss'] for entry in data if "loss" in entry]
+val_loss_data = [entry['val_loss'] for entry in data if 'val_loss' in entry]
+val_steps = [entry['step'] for entry in data if 'val_loss' in entry]
+
+# Extract individual validation loss series
+val_loss_series = {}
+for val_loss in val_loss_data:
+	for key, value in val_loss.items():
+		if key not in val_loss_series:
+			val_loss_series[key] = []
+		val_loss_series[key].append(value)
+
+# Calculate rolling average for loss
+window_size = 50
+rolling_avg = [sum(loss[i:i+window_size])/window_size for i in range(len(loss)-window_size+1)]
+rolling_steps = steps[window_size-1:]
+
+# Calculate rolling averages for validation loss series
+val_rolling_avgs = {}
+for key, series in val_loss_series.items():
+	val_rolling_avgs[key] = [sum(series[i:i+window_size])/window_size for i in range(len(series)-window_size+1)]
+
+print([(name, min(series)) for name, series in val_loss_series.items()])
+
+# Create the plot
+plt.figure(figsize=(10, 6))
+#plt.plot(steps, loss, label='Loss')
+plt.plot(rolling_steps, rolling_avg, label='Rolling Average (Loss)')
+
+for key, series in val_loss_series.items():
+	#plt.plot(val_steps, series, marker='o', linestyle='', label=f'Validation Loss ({key})')
+	plt.plot(val_steps[window_size-1:], val_rolling_avgs[key], label=f'Rolling Average (Validation Loss {key})')
+
+plt.xlabel('Steps')
+plt.ylabel('Loss')
+plt.title('Loss and Validation Loss vs. Steps')
+plt.legend()
+plt.grid(True)
+plt.show()
diff --git a/meme-rater/shared.py b/meme-rater/shared.py
new file mode 100644
index 0000000..a8398d7
--- /dev/null
+++ b/meme-rater/shared.py
@@ -0,0 +1,53 @@
+import sqlite3
+import hashlib
+from collections import defaultdict
+import numpy
+import random
+
+db = sqlite3.connect("data.sqlite3")
+db.row_factory = sqlite3.Row
+
+val_fraction = 0.2
+def is_val_set(meme1, meme2):
+	def is_one_val(meme):
+		return hashlib.sha256(meme.encode("utf-8")).digest()[0] / 255 < (val_fraction / 2) # not strictly correct but good enough
+	return is_one_val(meme1) or is_one_val(meme2)
+
+def fetch_embedding(filename):
+	csr = db.execute("SELECT embedding_vector FROM files WHERE filename = ?", (filename,))
+	x = numpy.frombuffer(csr.fetchone()[0], dtype="float16")
+	csr.close()
+	return x.copy() # PyTorch complains otherwise because the frombuffer array is read-only
+
+def map_rating(rating, uncertainty=0.05):
+	match rating:
+		case "1": # meme 1 is better
+			return 1 - uncertainty
+		case "2": # meme 2 is better
+			return uncertainty
+		case _: raise ValueError("invalid rating, please fix")
+
+def fetch_ratings():
+	trains = defaultdict(list)
+	validations = defaultdict(list)
+	csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
+	for meme1, meme2, rating, iteration in csr.fetchall():
+		(validations if is_val_set(meme1, meme2) else trains)[int(iteration or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
+	csr.close()
+	return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))
+
+def generate_random_permutations(x, n):
+	# NB: shuffles x in place
+	out = []
+	for _ in range(n):
+		random.shuffle(x)
+		out.append(x.copy())
+	return out
+
+def fetch_all_files():
+	csr = db.execute("SELECT filename, embedding_vector FROM files")
+	x = [ (row[0], numpy.frombuffer(row[1], dtype="float16").copy()) for row in csr.fetchall() ]
+	csr.close()
+	return x
+
+def checkpoint_for(steps):
+	return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt"
\ No newline at end of file
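shared.py's validation split hashes the filename instead of drawing randomly, so membership is stable across runs and new ratings of the same meme always land on the same side. A pair is held out if either meme is held out, which is why each side tests at val_fraction / 2: per-meme ~10% gives roughly 19% of pairs held out, and the first-digest-byte-over-255 trick is slightly biased, hence the "not strictly correct" comment. A quick check of the per-meme rate:

import hashlib

def is_one_val(meme, val_fraction=0.2):
	# first digest byte is ~uniform over 0..255, so this keeps ~10% of memes
	return hashlib.sha256(meme.encode("utf-8")).digest()[0] / 255 < (val_fraction / 2)

names = [f"meme{i}.png" for i in range(10000)] # hypothetical filenames
print(sum(map(is_one_val, names)) / len(names)) # ~0.1, identical on every run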
diff --git a/meme-rater/train.py b/meme-rater/train.py
new file mode 100644
index 0000000..0acde88
--- /dev/null
+++ b/meme-rater/train.py
@@ -0,0 +1,129 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import sqlite3
+import random
+import numpy
+import json
+import time
+from tqdm import tqdm
+import math
+from dataclasses import dataclass, asdict
+
+from model import Config as ModelConfig, BradleyTerry
+import shared
+
+trains, validations = shared.fetch_ratings()
+for train, validation in zip(trains, validations):
+	print(len(train), len(validation))
+
+device = "cuda"
+
+@dataclass
+class TrainConfig:
+	model: ModelConfig
+	lr: float
+	weight_decay: float
+	batch_size: int
+	epochs: int
+	compile: bool
+	data_grouped_by_iter: bool
+
+config = TrainConfig(
+	model=ModelConfig(
+		d_emb=1152,
+		n_hidden=1,
+		n_ensemble=16,
+		device=device,
+		dtype=torch.float32,
+		dropout=0.1
+	),
+	lr=3e-4,
+	weight_decay=0.2,
+	batch_size=1,
+	epochs=5,
+	compile=False,
+	data_grouped_by_iter=False
+)
+
+def exprange(min, max, n):
+	lmin, lmax = math.log(min), math.log(max)
+	step = (lmax - lmin) / (n - 1)
+	return (math.exp(lmin + step * i) for i in range(n))
+
+model = BradleyTerry(config.model)
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
+
+def train_step(model, batch, real):
+	optimizer.zero_grad()
+	# model batch
+	win_probabilities = model(batch).float()
+	loss = F.binary_cross_entropy(win_probabilities, real)
+	loss.backward()
+	optimizer.step()
+	loss = loss.detach().cpu().item()
+	return loss
+
+if config.compile:
+	print("compiling...")
+	train_step = torch.compile(train_step)
+
+def batch_from_inputs(inputs: list[list[tuple[numpy.ndarray, numpy.ndarray, float]]]):
+	batch_input = torch.stack([
+		torch.stack([ torch.stack((torch.Tensor(emb1).to(config.model.dtype), torch.Tensor(emb2).to(config.model.dtype))) for emb1, emb2, rating in input ])
+		for input in inputs
+	]).to(device)
+	target = torch.stack([ torch.Tensor([ rating for emb1, emb2, rating in input ]) for input in inputs ]).to(device)
+	return batch_input, target
+
+def evaluate(steps):
+	print("evaluating...")
+	model.eval()
+	results = {"step": steps, "time": time.time(), "val_loss": {}}
+	for vset, validation in enumerate(validations):
+		with torch.no_grad():
+			batch_input, target = batch_from_inputs([ validation[:128] for _ in range(config.model.n_ensemble) ])
+			result = model(batch_input).float()
+			val_loss = F.binary_cross_entropy(result, target).detach().cpu().item()
+		results["val_loss"][vset] = val_loss
+	model.train()
+	log.write(json.dumps(results) + "\n")
+
+def save_ckpt(log, steps):
+	print("saving...")
+	modelc, optimc = shared.checkpoint_for(steps)
+	torch.save(optimizer.state_dict(), optimc)
+	torch.save(model.state_dict(), modelc)
+
+class JSONEncoder(json.JSONEncoder):
+	def default(self, o):
+		if isinstance(o, torch.dtype):
+			return str(o)
+		else: return super().default(o)
+
+logfile = f"logs/log-{time.time()}.jsonl"
+with open(logfile, "w") as log:
+	steps = 0
+	log.write(JSONEncoder().encode(asdict(config)) + "\n")
+	for epoch in range(config.epochs):
+		for train in (trains if config.data_grouped_by_iter else [[ sample for trainss in trains for sample in trainss ]]):
+			data_orders = shared.generate_random_permutations(train, config.model.n_ensemble)
+			for bstart in range(0, len(train), config.batch_size):
+				batch_input, target = batch_from_inputs([ order[bstart:bstart + config.batch_size] for order in data_orders ])
+				loss = train_step(model, batch_input, target)
+				print(steps, loss)
+				log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
+				if steps % 10 == 0:
+					if steps % 250 == 0: save_ckpt(log, steps)
+					evaluate(steps)
+				steps += 1
+
+	save_ckpt(log, steps)
+
+print(logfile)
\ No newline at end of file
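train.py runs all 16 ensemble members in lockstep: generate_random_permutations gives each member its own shuffling of the same ratings, batch_from_inputs stacks those into a single tensor, and one binary_cross_entropy call trains every member at once. A shape summary under the config above (a sketch, not from the source):

import torch
n_ensemble, batch_size, d_emb = 16, 1, 1152
batch_input = torch.randn(n_ensemble, batch_size, 2, d_emb) # one ordering per member
target = torch.full((n_ensemble, batch_size), 0.95)         # soft labels from map_rating
# BradleyTerry(batch_input) -> (n_ensemble, batch_size) win probabilities,
# so a single loss.backward() updates all members; the differing data orders
# (plus dropout) are what keep the members decorrelated enough for the
# variance-based selection in active_learning.py to work.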