1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-09-21 18:19:35 +00:00
meme-search-engine/meme-rater/meme_pipeline.py
osmarks bd426a30ba Port meme acquisition pipeline to new API, database
Also fix a really stupid oversight in crawling code.
2024-05-22 15:43:56 +01:00

98 lines
3.1 KiB
Python

import asyncio
import json
import os
import subprocess
import sys
import time
from pathlib import Path

import aiohttp
import torch
from tqdm import tqdm

import shared
from model import Config, BradleyTerry
# meme-search-engine instance that holds the existing library; each candidate
# meme's embedding is POSTed here to look for near-duplicates before queueing.
meme_search_backend: str = "http://localhost:1707/"
# Memes whose median ensemble rating is strictly above this value become
# candidates for the library queue; all others are deleted from disk.
# NOTE(review): origin of this exact value isn't shown here — presumably a
# tuned cut-off (compare the 95th-percentile rating printed after scoring).
score_threshold: float = 1.7264162302017212
# Bookkeeping tables:
#   last_crawl    - one row per completed run, holding the time its crawl began
#   library_queue - memes that passed rating + duplicate check, keyed by filename
_schema_sql = """
CREATE TABLE IF NOT EXISTS last_crawl (time INTEGER);
CREATE TABLE IF NOT EXISTS library_queue (
filename TEXT PRIMARY KEY,
score REAL NOT NULL
);
"""
shared.db.executescript(_schema_sql)
shared.db.commit()

# Start time of the most recent recorded run, or 0 if none exists yet. It is
# handed to crawler.py as its only argument (presumably a "since" timestamp).
_cursor = shared.db.execute("SELECT MAX(time) FROM last_crawl")
_newest = _cursor.fetchone()[0]
_cursor.close()
last_crawl = _newest if _newest else 0
# Load the meme-search-engine config used for the rater's own index. Its
# "files" entry is the directory holding the crawled images; rejected memes
# are later deleted from it.
# JSON is UTF-8 by spec, so decode explicitly instead of relying on the locale.
with open("rater_mse_config.json", "r", encoding="utf-8") as f:
    mse_config = json.load(f)
basedir = Path(mse_config["files"])
print("crawling...")
# Captured before crawling: it is recorded in last_crawl once the run finishes,
# and used to compute the memes-per-second rate after scoring.
crawl_start = time.time()
# Run the crawler with this same interpreter rather than whatever `python`
# happens to be on PATH (which may be absent or lack this venv's packages).
subprocess.run([sys.executable, "crawler.py", str(last_crawl)], check=True)
print("indexing...")
# Run the meme-search-engine indexer with the rater's config; check=True raises
# CalledProcessError on a non-zero exit, same as check_returncode().
subprocess.run(["./meme-search-engine", "rater_mse_config.json"], check=True)
print("evaluating...")
batch_size = 128
device = "cuda"

# Ensemble reward-model hyperparameters. Presumably these must match the
# configuration the checkpoint was trained with, or load_state_dict will fail.
config = Config(
    d_emb=1152,
    n_hidden=1,
    n_ensemble=16,
    device=device,
    dtype=torch.float32,
    dropout=0.1,
)
model = BradleyTerry(config).float()

# NOTE(review): 1500 is passed straight to shared.checkpoint_for; presumably a
# training step number — confirm. Only the first returned element is used.
checkpoint_path, _ = shared.checkpoint_for(1500)
state_dict = torch.load(checkpoint_path)
model.load_state_dict(state_dict)

n_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f"{n_params / 1e6:.1f}M parameters")
print(model)
# Score every indexed meme with the ensemble. A meme's rating is the median
# of the per-member scores.
files = shared.fetch_all_files()
ratings = {}
model.eval()
with torch.inference_mode():
    for bstart in tqdm(range(0, len(files), batch_size)):
        batch = files[bstart:bstart + batch_size]
        filenames = [filename for filename, _ in batch]
        embs = torch.stack([torch.Tensor(embedding) for _, embedding in batch])
        # Every ensemble member scores the same batch, so replicate it along a
        # new leading axis of size n_ensemble.
        inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
        scores = model.ensemble(inputs).float()
        # Reduce over dim 0 (assumes model.ensemble keeps the member axis first).
        mscores = torch.median(scores, dim=0).values
        # tolist() moves the whole batch to host once instead of syncing per item.
        ratings.update(zip(filenames, mscores.tolist()))

# Report the ~95th-percentile rating. round(n * 0.95) can equal n (n = 1, 2,
# 10, ...), which indexed past the end, and an empty crawl has no ratings at
# all, so clamp the index and skip the report when there is nothing to rank.
if ratings:
    sorted_ratings = sorted(ratings.values())
    p95_index = min(round(len(sorted_ratings) * 0.95), len(sorted_ratings) - 1)
    print(sorted_ratings[p95_index])

elapsed = crawl_start - last_crawl
if elapsed > 0:
    print(f"{len(ratings)} memes in {elapsed} seconds ({len(ratings) / elapsed * 1e3}mHz)")
else:
    # Clock went backwards or two runs share a timestamp: no meaningful rate.
    print(f"{len(ratings)} memes")
files = dict(files)


async def run_inserts():
    """Queue high-rated, non-duplicate memes and delete all the others.

    For each entry in ``ratings``:
      * if its rating exceeds ``score_threshold`` and ``meme_search_backend``
        reports no sufficiently similar existing meme, it is inserted (or
        replaced) in ``library_queue``;
      * otherwise its file under ``basedir`` is deleted.
    Afterwards ``crawl_start`` is recorded in ``last_crawl``, and the database
    is committed once at the end.
    """
    # One session, reused for every lookup, so connections are pooled.
    # aiohttp.request() would instead open a fresh session per call.
    async with aiohttp.ClientSession() as session:

        async def duplicate_exists(embedding):
            """Return True if the backend's best match is above the similarity cut-off."""
            payload = {
                # Coerce to plain floats so JSON serialisation works whatever
                # sequence type `embedding` is.
                "terms": [{"embedding": [float(x) for x in embedding]}],
                "k": 1,
            }
            async with session.post(meme_search_backend, json=payload) as res:
                res.raise_for_status()
                result = await res.json()
            matches = result["matches"]
            # An empty library yields no matches: nothing can be a duplicate.
            if not matches:
                return False
            # First field of the best match, treated as a similarity score.
            closest = matches[0][0]
            return closest > 0.99  # arbitrary threshold, TODO

        for filename, rating in ratings.items():
            # Short-circuit: the backend is only queried for high-rated memes.
            if rating > score_threshold and not await duplicate_exists(files[filename]):
                shared.db.execute("INSERT OR REPLACE INTO library_queue VALUES (?, ?)", (filename, rating))
            else:
                # A file that is already gone must not abort the run before
                # last_crawl is recorded and the queue is committed.
                (basedir / filename).unlink(missing_ok=True)

    shared.db.execute("INSERT INTO last_crawl VALUES (?)", (crawl_start,))
    shared.db.commit()


asyncio.run(run_inserts())