1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-11-10 22:09:54 +00:00
meme-search-engine/meme-rater/active_learning.py

55 lines
1.5 KiB
Python
Raw Normal View History

import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
from model import Config, BradleyTerry
import shared
batch_size = 128
num_pairs = batch_size * 1024
device = "cuda"
config = Config(
d_emb=1152,
n_hidden=1,
n_ensemble=16,
device=device,
dtype=torch.bfloat16,
dropout=0.5
)
model = BradleyTerry(config)
modelc, _ = shared.checkpoint_for(2250)
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
files = shared.fetch_all_files()
variance = {}
pairs = []
for _ in range(num_pairs):
pairs.append(tuple(random.sample(files, 2)))
model.eval()
with torch.inference_mode():
for bstart in tqdm(range(0, len(pairs), batch_size)):
batch = pairs[bstart:bstart + batch_size]
filenames = [ (f1, f2) for ((f1, e1), (f2, e2)) in batch ]
embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
win_probs = model(inputs)
#print(win_probs.shape)
batchvar = torch.var(win_probs, dim=0)
for filename, var in zip(filenames, batchvar):
variance[filename] = float(var)
top = sorted(variance.items(), key=lambda x: -x[1])
with open("top.json", "w") as f:
json.dump(top[:256], f)