diff --git a/meme-rater/active_learning.py b/meme-rater/active_learning.py
index fae847c..dac462a 100644
--- a/meme-rater/active_learning.py
+++ b/meme-rater/active_learning.py
@@ -7,6 +7,7 @@ import numpy
 import json
 import time
 from tqdm import tqdm
+import sys
 
 from model import Config, BradleyTerry
 import shared
@@ -20,11 +21,12 @@ config = Config(
 	n_hidden=1,
 	n_ensemble=16,
 	device=device,
-	dtype=torch.bfloat16,
-	dropout=0.5
+	dtype=torch.float32,
+	output_channels=3,
+	dropout=0.1
 )
 model = BradleyTerry(config)
-modelc, _ = shared.checkpoint_for(2250)
+modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
 model.load_state_dict(torch.load(modelc))
 params = sum(p.numel() for p in model.parameters())
 print(f"{params/1e6:.1f}M parameters")
@@ -45,11 +47,13 @@ with torch.inference_mode():
 		embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
 		inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
 		win_probs = model(inputs)
+		#print(win_probs, win_probs.shape)
 		#print(win_probs.shape)
-		batchvar = torch.var(win_probs, dim=0)
+		batchvar = torch.var(win_probs, dim=0).max(-1).values
+		#print(batchvar, batchvar.shape)
 		for filename, var in zip(filenames, batchvar):
 			variance[filename] = float(var)
 
 top = sorted(variance.items(), key=lambda x: -x[1])
 with open("top.json", "w") as f:
-	json.dump(top[:256], f)
\ No newline at end of file
+	json.dump(top[:100], f)
diff --git a/meme-rater/active_learning_find_top.py b/meme-rater/active_learning_find_top.py
new file mode 100644
index 0000000..b12b435
--- /dev/null
+++ b/meme-rater/active_learning_find_top.py
@@ -0,0 +1,64 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import sqlite3
+import random
+import numpy
+import json
+import time
+from tqdm import tqdm
+import sys
+from collections import defaultdict
+
+from model import Config, BradleyTerry
+import shared
+
+batch_size = 128
+num_pairs = batch_size * 1024
+device = "cuda"
+
+config = Config(
+	d_emb=1152,
+	n_hidden=1,
+	n_ensemble=16,
+	device=device,
+	dtype=torch.float32,
+	output_channels=3,
+	dropout=0.1
+)
+model = BradleyTerry(config)
+modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
+model.load_state_dict(torch.load(modelc))
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+files = shared.fetch_all_files()
+results = {}
+
+model.eval()
+with torch.inference_mode():
+	for bstart in tqdm(range(0, len(files), batch_size)):
+		batch = files[bstart:bstart + batch_size]
+		filenames = [ f1 for f1, e1 in batch ]
+		embs = torch.stack([ torch.Tensor(e1).to(config.dtype) for f1, e1 in batch ])
+		inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
+		scores = model.ensemble(inputs).median(dim=0).values.cpu().numpy()
+		#print(batchvar, batchvar.shape)
+		for filename, score in zip(filenames, scores):
+			results[filename] = score
+
+channel = int(sys.argv[2])
+percentile = float(sys.argv[3])
+output_pairs = int(sys.argv[4])
+mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
+top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
+select_from = top[:int(len(top) * percentile)]
+
+out = []
+for _ in range(output_pairs):
+	# dummy score for compatibility with existing code
+	out.append(((random.choice(select_from)[0], random.choice(select_from)[0]), 0))
+
+with open("top.json", "w") as f:
+	json.dump(out, f)
diff --git a/meme-rater/al2.py b/meme-rater/active_learning_gradients.py
similarity index 91%
rename from meme-rater/al2.py
rename to meme-rater/active_learning_gradients.py
index 2410fa4..13909df 100644
--- a/meme-rater/al2.py
+++ b/meme-rater/active_learning_gradients.py
@@ -8,11 +8,11 @@ import json
 import time
 from tqdm import tqdm
 from torch.func import functional_call, vmap, grad
+import sys
 
 from model import Config, BradleyTerry
 import shared
 
-steps = 855
 batch_size = 128
 num_pairs = batch_size * 1024
 device = "cuda"
@@ -22,10 +22,12 @@ config = Config(
 	n_hidden=1,
 	n_ensemble=1,
 	device=device,
-	dtype=torch.bfloat16
+	dtype=torch.float32,
+	output_channels=3,
+	dropout=0.1
 )
 model = BradleyTerry(config)
-modelc, _ = shared.checkpoint_for(855)
+modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
 model.load_state_dict(torch.load(modelc))
 params = sum(p.numel() for p in model.parameters())
 print(f"{params/1e6:.1f}M parameters")
@@ -61,7 +63,7 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
 
 	#win_probs = model(inputs) # TODO gradients
 	# don't take variance: do backwards pass and compute gradient norm
-	grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size), 0.95).to(device))
+	grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size, config.output_channels), 0.95).to(device))
 	total_grad_norms = torch.zeros(batch_size).to(device)
 	for k, v in grads.items():
 		param_dims = tuple(range(1, len(v.shape)))
@@ -73,4 +75,4 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
 
 top = sorted(importance.items(), key=lambda x: -x[1])
 with open("top.json", "w") as f:
-	json.dump(top[:256], f)
\ No newline at end of file
+	json.dump(top[:256], f)
diff --git a/meme-rater/load_from_json.py b/meme-rater/load_from_json.py
new file mode 100644
index 0000000..970efe1
--- /dev/null
+++ b/meme-rater/load_from_json.py
@@ -0,0 +1,20 @@
+import jsonlines
+import sqlite3
+import numpy as np
+
+import shared
+
+shared.db.executescript("""
+CREATE TABLE IF NOT EXISTS files (
+	filename TEXT NOT NULL,
+	title TEXT NOT NULL,
+	link TEXT NOT NULL,
+	embedding BLOB NOT NULL,
+	UNIQUE (filename)
+);
+""")
+
+with jsonlines.open("sample.jsonl") as reader:
+	for obj in reader:
+		shared.db.execute("INSERT INTO files (filename, title, link, embedding) VALUES (?, ?, ?, ?)", (obj["metadata"]["final_url"], obj["title"], f"https://reddit.com/r/{obj['subreddit']}/comments/{obj['id']}", sqlite3.Binary(np.array(obj["embedding"], dtype=np.float16).tobytes())))
+shared.db.commit()
diff --git a/meme-rater/model.py b/meme-rater/model.py
index 1487ea3..49cf551 100644
--- a/meme-rater/model.py
+++ b/meme-rater/model.py
@@ -13,13 +13,14 @@ class Config:
 	device: str
 	dtype: torch.dtype
 	dropout: float
+	output_channels: int
 
 class Model(nn.Module):
 	def __init__(self, config):
 		super().__init__()
 		self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ])
 		self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ])
-		self.output = nn.Linear(config.d_emb, 1, dtype=config.dtype, device=config.device)
+		self.output = nn.Linear(config.d_emb, config.output_channels, dtype=config.dtype, device=config.device)
 
 	def forward(self, embs):
 		x = embs
@@ -34,8 +35,7 @@ class Ensemble(nn.Module):
 
 	# model batch
 	def forward(self, embs):
-		xs = torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
-		return xs.squeeze(-1)
+		return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
 
 class BradleyTerry(nn.Module):
 	def __init__(self, config):
@@ -49,4 +49,4 @@ class BradleyTerry(nn.Module):
 		#print(scores1, scores2)
 		probs = torch.sigmoid(scores1 - scores2) # model batch
 		#print(probs)
-		return probs
\ No newline at end of file
+		return probs
diff --git a/meme-rater/rater_server.py b/meme-rater/rater_server.py
index f003482..9ee2080 100644
--- a/meme-rater/rater_server.py
+++ b/meme-rater/rater_server.py
@@ -30,9 +30,10 @@ async def index(request):
 	return web.Response(text=f"""
+