diff --git a/meme-rater/active_learning.py b/meme-rater/active_learning.py
index fae847c..dac462a 100644
--- a/meme-rater/active_learning.py
+++ b/meme-rater/active_learning.py
@@ -7,6 +7,7 @@ import numpy
 import json
 import time
 from tqdm import tqdm
+import sys
 
 from model import Config, BradleyTerry
 import shared
@@ -20,11 +21,12 @@ config = Config(
     n_hidden=1,
     n_ensemble=16,
     device=device,
-    dtype=torch.bfloat16,
-    dropout=0.5
+    dtype=torch.float32,
+    output_channels=3,
+    dropout=0.1
 )
 model = BradleyTerry(config)
-modelc, _ = shared.checkpoint_for(2250)
+modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
 model.load_state_dict(torch.load(modelc))
 params = sum(p.numel() for p in model.parameters())
 print(f"{params/1e6:.1f}M parameters")
@@ -45,11 +47,13 @@ with torch.inference_mode():
         embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
         inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
         win_probs = model(inputs)
+        #print(win_probs, win_probs.shape)
         #print(win_probs.shape)
-        batchvar = torch.var(win_probs, dim=0)
+        batchvar = torch.var(win_probs, dim=0).max(-1).values
+        #print(batchvar, batchvar.shape)
         for filename, var in zip(filenames, batchvar):
             variance[filename] = float(var)
 
 top = sorted(variance.items(), key=lambda x: -x[1])
 with open("top.json", "w") as f:
-    json.dump(top[:256], f)
\ No newline at end of file
+    json.dump(top[:100], f)
diff --git a/meme-rater/active_learning_find_top.py b/meme-rater/active_learning_find_top.py
new file mode 100644
index 0000000..b12b435
--- /dev/null
+++ b/meme-rater/active_learning_find_top.py
@@ -0,0 +1,64 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import sqlite3
+import random
+import numpy
+import json
+import time
+from tqdm import tqdm
+import sys
+from collections import defaultdict
+
+from model import Config, BradleyTerry
+import shared
+
+batch_size = 128
+num_pairs = batch_size * 1024
+device = "cuda"
+
+config = Config(
+    d_emb=1152,
+    n_hidden=1,
+    n_ensemble=16,
+    device=device,
+    dtype=torch.float32,
+    output_channels=3,
+    dropout=0.1
+)
+model = BradleyTerry(config)
+modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
+model.load_state_dict(torch.load(modelc))
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+files = shared.fetch_all_files()
+results = {}
+
+model.eval()
+with torch.inference_mode():
+    for bstart in tqdm(range(0, len(files), batch_size)):
+        batch = files[bstart:bstart + batch_size]
+        filenames = [ f1 for f1, e1 in batch ]
+        embs = torch.stack([ torch.Tensor(e1).to(config.dtype) for f1, e1 in batch ])
+        inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
+        scores = model.ensemble(inputs).median(dim=0).values.cpu().numpy()
+        #print(batchvar, batchvar.shape)
+        for filename, score in zip(filenames, scores):
+            results[filename] = score
+
+channel = int(sys.argv[2])
+percentile = float(sys.argv[3])
+output_pairs = int(sys.argv[4])
+mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
+top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
+select_from = top[:int(len(top) * percentile)]
+
+out = []
+for _ in range(output_pairs):
+    # dummy score for compatibility with existing code
+    out.append(((random.choice(select_from)[0],
+        random.choice(select_from)[0]), 0))
+
+with open("top.json", "w") as f:
+    json.dump(out, f)
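
A note on the selection rule in active_learning_find_top.py above: mean_scores is a scalar (the mean over all files and all channels), so the (score > mean_scores).all() filter keeps only files that beat that global mean on every channel before ranking by the requested channel. A standalone sketch of that behaviour on invented toy scores, not data from the repo:

# Toy illustration of the find_top filter; filenames and scores are made up.
import numpy

results = {
    "a.png": numpy.array([0.9, 0.8, 0.7]),  # above the global mean on all channels
    "b.png": numpy.array([0.9, 0.1, 0.7]),  # channel 1 below the mean: filtered out
    "c.png": numpy.array([0.2, 0.1, 0.3]),
}
mean_scores = numpy.mean(numpy.stack(list(results.values())))  # scalar, ~0.52 here
kept = [f for f, s in results.items() if (s > mean_scores).all()]
assert kept == ["a.png"]
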
diff --git a/meme-rater/al2.py b/meme-rater/active_learning_gradients.py
similarity index 91%
rename from meme-rater/al2.py
rename to meme-rater/active_learning_gradients.py
index 2410fa4..13909df 100644
--- a/meme-rater/al2.py
+++ b/meme-rater/active_learning_gradients.py
@@ -8,11 +8,11 @@ import json
 import time
 from tqdm import tqdm
 from torch.func import functional_call, vmap, grad
+import sys
 
 from model import Config, BradleyTerry
 import shared
 
-steps = 855
 batch_size = 128
 num_pairs = batch_size * 1024
 device = "cuda"
@@ -22,10 +22,12 @@ config = Config(
     n_hidden=1,
     n_ensemble=1,
     device=device,
-    dtype=torch.bfloat16
+    dtype=torch.float32,
+    output_channels=3,
+    dropout=0.1
 )
 model = BradleyTerry(config)
-modelc, _ = shared.checkpoint_for(855)
+modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
 model.load_state_dict(torch.load(modelc))
 params = sum(p.numel() for p in model.parameters())
 print(f"{params/1e6:.1f}M parameters")
@@ -61,7 +63,7 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
     #win_probs = model(inputs)
     # TODO gradients
     # don't take variance: do backwards pass and compute gradient norm
-    grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size), 0.95).to(device))
+    grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size, config.output_channels), 0.95).to(device))
     total_grad_norms = torch.zeros(batch_size).to(device)
     for k, v in grads.items():
         param_dims = tuple(range(1, len(v.shape)))
@@ -73,4 +75,4 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
 
 top = sorted(importance.items(), key=lambda x: -x[1])
 with open("top.json", "w") as f:
-    json.dump(top[:256], f)
\ No newline at end of file
+    json.dump(top[:256], f)
diff --git a/meme-rater/load_from_json.py b/meme-rater/load_from_json.py
new file mode 100644
index 0000000..970efe1
--- /dev/null
+++ b/meme-rater/load_from_json.py
@@ -0,0 +1,20 @@
+import jsonlines
+import sqlite3
+import numpy as np
+
+import shared
+
+shared.db.executescript("""
+CREATE TABLE IF NOT EXISTS files (
+    filename TEXT NOT NULL,
+    title TEXT NOT NULL,
+    link TEXT NOT NULL,
+    embedding BLOB NOT NULL,
+    UNIQUE (filename)
+);
+""")
+
+with jsonlines.open("sample.jsonl") as reader:
+    for obj in reader:
+        shared.db.execute("INSERT INTO files (filename, title, link, embedding) VALUES (?, ?, ?, ?)", (obj["metadata"]["final_url"], obj["title"], f"https://reddit.com/r/{obj['subreddit']}/comments/{obj['id']}", sqlite3.Binary(np.array(obj["embedding"], dtype=np.float16).tobytes())))
+shared.db.commit()
diff --git a/meme-rater/model.py b/meme-rater/model.py
index 1487ea3..49cf551 100644
--- a/meme-rater/model.py
+++ b/meme-rater/model.py
@@ -13,13 +13,14 @@ class Config:
     device: str
     dtype: torch.dtype
     dropout: float
+    output_channels: int
 
 class Model(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ])
         self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ])
-        self.output = nn.Linear(config.d_emb, 1, dtype=config.dtype, device=config.device)
+        self.output = nn.Linear(config.d_emb, config.output_channels, dtype=config.dtype, device=config.device)
 
     def forward(self, embs):
         x = embs
@@ -34,8 +35,7 @@ class Ensemble(nn.Module):
 
     # model batch
    def forward(self, embs):
-        xs = torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
-        return xs.squeeze(-1)
+        return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_channels
 
 class BradleyTerry(nn.Module):
     def __init__(self, config):
@@ -49,4 +49,4 @@ class BradleyTerry(nn.Module):
         #print(scores1, scores2)
         probs = torch.sigmoid(scores1 - scores2) # model batch
         #print(probs)
-        return probs
\ No newline at end of file
+        return probs
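
With output_channels in model.py, the ensemble now emits one score per channel instead of squeezing to a single output, and active_learning.py (above) reduces ensemble disagreement with a max over channels. A minimal shape sketch, independent of the repo's model and checkpoint:

# Shapes only: 16 ensemble members, 128 candidate pairs, 3 rating channels.
import torch

win_probs = torch.rand(16, 128, 3)                     # stands in for model(inputs)
batchvar = torch.var(win_probs, dim=0).max(-1).values  # variance across ensemble, worst channel
assert batchvar.shape == (128,)
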
diff --git a/meme-rater/rater_server.py b/meme-rater/rater_server.py
index f003482..9ee2080 100644
--- a/meme-rater/rater_server.py
+++ b/meme-rater/rater_server.py
@@ -30,9 +30,10 @@ async def index(request):
     return web.Response(text=f"""

[HTML hunks not fully recovered: extraction stripped the markup. The surviving
+/- lines show the page title and heading changing from "Meme Rating" to
"Data Labelling Frontend (Not Evil)", and a second hunk (two removed lines,
roughly seventeen added, then two removed and two added) replacing the single
pair of "rating" submit controls with a table of per-channel controls named
rating-useful, rating-meme and rating-aesthetic, matching the handler change
below.]

@@ -90,8 +124,8 @@ async def rate(request):
     meme1 = post["meme1"]
     meme2 = post["meme2"]
     iteration = post["iteration"]
-    rating = post["rating"]
-    await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration) VALUES (?, ?, ?, ?)", (meme1, meme2, rating, iteration))
+    rating = post["rating-useful"] + "," + post["rating-meme"] + "," + post["rating-aesthetic"]
+    await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration, ip) VALUES (?, ?, ?, ?, ?)", (meme1, meme2, rating, iteration, request.remote))
     await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2))
     await db.commit()
     return web.HTTPFound("/")
@@ -104,6 +138,7 @@ CREATE TABLE IF NOT EXISTS ratings (
     meme2 TEXT NOT NULL,
     rating TEXT NOT NULL,
     iteration TEXT,
+    ip TEXT,
     UNIQUE (meme1, meme2)
 );
 CREATE TABLE IF NOT EXISTS queue (
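
How a submission round-trips after this change, sketched with example values since the form markup was not recovered: the three radio groups are joined into one TEXT value, which shared.map_rating (below) later decodes into a per-channel target vector.

# Example only: the field values are assumed; the field names are the ones the handler reads.
post = {"rating-useful": "1", "rating-meme": "eq", "rating-aesthetic": "2"}
rating = post["rating-useful"] + "," + post["rating-meme"] + "," + post["rating-aesthetic"]
assert rating == "1,eq,2"  # stored in ratings.rating; decodes to [0.95, 0.5, 0.05]
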
diff --git a/meme-rater/run_graph.py b/meme-rater/run_graph.py
index 0f152b5..36db9d5 100644
--- a/meme-rater/run_graph.py
+++ b/meme-rater/run_graph.py
@@ -2,10 +2,11 @@
 import json
 import matplotlib.pyplot as plt
+import sys
 
 # Read data from log.jsonl
 data = []
-with open('log.jsonl', 'r') as file:
+with open(sys.argv[1], 'r') as file:
     for line in file:
         data.append(json.loads(line))
 
diff --git a/meme-rater/shared.py b/meme-rater/shared.py
index 76360c3..5c52c97 100644
--- a/meme-rater/shared.py
+++ b/meme-rater/shared.py
@@ -3,6 +3,7 @@ import hashlib
 from collections import defaultdict
 import numpy
 import random
+import numpy as np
 
 db = sqlite3.connect("data.sqlite3")
 db.row_factory = sqlite3.Row
@@ -20,19 +21,24 @@ def fetch_embedding(filename):
     return x.copy() # PyTorch complains otherwise due to bad
 
 def map_rating(rating, uncertainty=0.05):
-    match rating:
-        case "1": # meme 1 is better
-            return 1 - uncertainty
-        case "2":
-            return uncertainty
-        case _: raise ValueError("invalid rating, please fix")
+    def map_one(rating):
+        match rating:
+            case "1": # meme 1 is better
+                return 1 - uncertainty
+            case "2":
+                return uncertainty
+            case "eq":
+                return 0.5
+            case _: raise ValueError("invalid rating, please fix")
+
+    return np.array([map_one(r) for r in rating.split(",")])
 
 def fetch_ratings():
     trains = defaultdict(list)
     validations = defaultdict(list)
     csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
     for meme1, meme2, rating, iteration in csr.fetchall():
-        (validations if is_val_set(meme1, meme2) else trains)[int(iteration or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
+        (validations if is_val_set(meme1, meme2) else trains)[int((iteration and iteration.split("-")[0]) or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
     csr.close()
     return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))
@@ -50,4 +56,4 @@ def fetch_all_files():
     return x
 
 def checkpoint_for(steps):
-    return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt"
\ No newline at end of file
+    return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt"
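
The fetch_ratings change above also loosens iteration parsing: an iteration tag may now carry a suffix after a hyphen, and a missing value still falls into group 0. A small sketch of the expression's behaviour (the "3-foo" tag is hypothetical):

# Mirrors int((iteration and iteration.split("-")[0]) or "0") from shared.py.
def iteration_key(iteration):
    return int((iteration and iteration.split("-")[0]) or "0")

assert iteration_key(None) == 0
assert iteration_key("2") == 2
assert iteration_key("3-foo") == 3
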
diff --git a/meme-rater/train.py b/meme-rater/train.py
index 0acde88..15b0df2 100644
--- a/meme-rater/train.py
+++ b/meme-rater/train.py
@@ -36,7 +36,8 @@ config = TrainConfig(
         n_ensemble=16,
         device=device,
         dtype=torch.float32,
-        dropout=0.1
+        dropout=0.1,
+        output_channels=3
     ),
     lr=3e-4,
     weight_decay=0.2,
@@ -72,12 +73,12 @@ if config.compile:
     print("compiling...")
     train_step = torch.compile(train_step)
 
-def batch_from_inputs(inputs: list[tuple[numpy.ndarray, numpy.ndarray, float]]):
+def batch_from_inputs(inputs: list[list[tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]]]):
     batch_input = torch.stack([
         torch.stack([ torch.stack((torch.Tensor(emb1).to(config.model.dtype), torch.Tensor(emb2).to(config.model.dtype))) for emb1, emb2, rating in input ])
     for input in inputs ]).to(device)
-    target = torch.stack([ torch.Tensor([ rating for emb1, emb2, rating in input ]) for input in inputs ]).to(device)
+    target = torch.stack([ torch.Tensor(numpy.array([ rating for emb1, emb2, rating in input ])).to(config.model.dtype) for input in inputs ]).to(device)
     return batch_input, target
 
 def evaluate(steps):
@@ -118,7 +119,7 @@ with open(logfile, "w") as log:
             print(steps, loss)
             log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
             if steps % 10 == 0:
-                if steps % 250 == 0: save_ckpt(log, steps)
+                if steps % 100 == 0: save_ckpt(log, steps)
                 loss = evaluate(steps)
                 #print(loss)
                 #best = min(loss, best)
@@ -126,4 +127,4 @@ with open(logfile, "w") as log:
 
     save_ckpt(log, steps)
 
-print(logfile)
\ No newline at end of file
+print(logfile)
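
With three output channels, each rating in train.py is now a length-3 vector rather than a float, so batch_from_inputs builds targets of shape (n_ensemble, batch, 3). A self-contained sketch of the target construction, with dummy embeddings elided and sizes chosen for illustration:

# Mirrors the target line in batch_from_inputs; sizes are examples.
import numpy, torch

rating = numpy.array([0.95, 0.5, 0.05])          # map_rating("1,eq,2")
inputs = [[(None, None, rating)] * 4] * 2        # 2 ensemble members, batch of 4
target = torch.stack([ torch.Tensor(numpy.array([ r for _, _, r in inp ])) for inp in inputs ])
assert target.shape == (2, 4, 3)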