1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-11-10 22:09:54 +00:00

full pipeline

This commit is contained in:
osmarks 2024-04-22 18:44:29 +01:00
parent 7bae095384
commit 80db16d02a
4 changed files with 258 additions and 51 deletions

View File

@ -67,8 +67,7 @@ async def download(sess, url, file):
await fh.write(chunk)
return dict(res.headers)
if __name__ == "__main__":
async def main():
async def main(time_threshold):
sem = asyncio.Semaphore(16)
async with aiohttp.ClientSession() as sess:
@ -113,6 +112,13 @@ if __name__ == "__main__":
#print("downloading new set")
async with asyncio.TaskGroup() as tg:
for item in items:
if time_threshold and time_threshold > item["created"]:
return
tg.create_task(dl_task(item))
asyncio.run(main())
if __name__ == "__main__":
threshold = None
if len(sys.argv) > 1:
print("thresholding at", sys.argv[1])
threshold = float(sys.argv[1])
asyncio.run(main(threshold))

View File

@ -0,0 +1,102 @@
from aiohttp import web
import aiosqlite
import asyncio
import random
import sys
import json
import os
from pathlib import Path
import shutil
PORT, DATABASE, TARGET_DIR = sys.argv[1:]
with open("rater_mse_config.json", "r") as f:
mse_config = json.load(f)
basedir = Path(mse_config["files"])
TARGET_DIR = Path(TARGET_DIR)
app = web.Application(client_max_size=32*1024**2)
routes = web.RouteTableDef()
@routes.get("/")
async def index(request):
csr = await request.app["db"].execute("SELECT filename FROM library_queue ORDER BY score DESC")
filename, = await csr.fetchone()
await csr.close()
return web.Response(text=f"""
<!DOCTYPE html>
<html>
<style>
.memes img {{
width: 100%;
}}
input {{
width: 100%;
}}
.memes {{
margin-top: 2em;
}}
</style>
<body>
<h1>Meme Processing</h1>
<form action="/" method="POST">
<input type="text" name="filename" id="filename" autofocus>
<input type="hidden" name="original_filename" value="{filename}">
<input type="submit" value="Submit">
<div class="memes">
<img src="/memes/{filename}" id="meme1">
</div>
</form>
<script>
document.addEventListener("keypress", function(event) {{
if (event.key === "Enter") {{
document.querySelector("input[name='rating'][value='1']").checked = true
document.querySelector("form").submit()
}}
}});
</script>
</body>
</html>
""", content_type="text/html")
def find_new_path(basename, ext):
ctr = 1
while True:
new = TARGET_DIR / (basename + ("" if ctr == 1 else "-" + str(ctr)) + ext)
if not new.exists(): return new
ctr += 1
@routes.post("/")
async def rate(request):
db = request.app["db"]
post = await request.post()
filename = post["filename"]
original_filename = post["original_filename"]
real_path = basedir / original_filename
assert real_path.is_file()
if filename == "": # bad meme, discard
real_path.unlink()
else:
new_path = find_new_path(filename.replace(" ", "-"), real_path.suffix)
print(real_path, new_path, real_path.suffix)
shutil.move(real_path, new_path)
await db.execute("DELETE FROM library_queue WHERE filename = ?", (original_filename,))
await db.commit()
return web.HTTPFound("/")
async def main():
app["db"] = await aiosqlite.connect(DATABASE)
app.router.add_routes(routes)
app.router.add_static("/memes/", "./images")
print("Ready")
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "", int(PORT))
await site.start()
if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(main())
loop.run_forever()

View File

@ -0,0 +1,97 @@
import subprocess
import torch
from tqdm import tqdm
import json
from pathlib import Path
import os
import asyncio
import aiohttp
import time
import shared
from model import Config, BradleyTerry
meme_search_backend = "http://localhost:1707/"
score_threshold = 1.5162627696990967
shared.db.executescript("""
CREATE TABLE IF NOT EXISTS last_crawl (time INTEGER);
CREATE TABLE IF NOT EXISTS library_queue (
filename TEXT PRIMARY KEY,
score REAL NOT NULL
);
""")
shared.db.commit()
csr = shared.db.execute("SELECT MAX(time) FROM last_crawl")
row = csr.fetchone()
last_crawl = row[0] or 0
csr.close()
with open("rater_mse_config.json", "r") as f:
mse_config = json.load(f)
basedir = Path(mse_config["files"])
print("crawling...")
crawl_start = time.time()
subprocess.run(["python", "crawler.py", str(last_crawl)]).check_returncode()
print("indexing...")
subprocess.run(["python", "../mse.py", "rater_mse_config.json"]).check_returncode()
print("evaluating...")
batch_size = 128
device = "cuda"
config = Config(
d_emb=1152,
n_hidden=1,
n_ensemble=16,
device=device,
dtype=torch.float32,
dropout=0.1
)
model = BradleyTerry(config).float()
modelc, _ = shared.checkpoint_for(1500)
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
files = shared.fetch_all_files()
ratings = {}
model.eval()
with torch.inference_mode():
for bstart in tqdm(range(0, len(files), batch_size)):
batch = files[bstart:bstart + batch_size]
filenames = [ filename for filename, embedding in batch ]
embs = torch.stack([ torch.Tensor(embedding) for filename, embedding in batch ])
inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
scores = model.ensemble(inputs).float()
mscores = torch.median(scores, dim=0).values
for filename, mscore in zip(filenames, mscores):
ratings[filename] = float(mscore)
print(sorted(ratings.values())[round(len(ratings) * 0.85)])
files = dict(files)
async def run_inserts():
async with aiohttp.ClientSession():
async def duplicate_exists(embedding):
async with aiohttp.request("POST", meme_search_backend, json={
"embeddings": [ list(float(x) for x in embedding) ], # sorry
"top_k": 1
}) as res:
result = await res.json()
closest = result[0]["score"]
return closest > 0.99 # arbitrary threshold, TODO
for filename, rating in ratings.items():
if rating > score_threshold and not await duplicate_exists(files[filename]):
shared.db.execute("INSERT OR REPLACE INTO library_queue VALUES (?, ?)", (filename, rating))
else:
os.unlink(basedir / filename)
shared.db.execute("INSERT INTO last_crawl VALUES (?)", (crawl_start,))
shared.db.commit()
asyncio.run(run_inserts())

14
mse.py
View File

@ -42,10 +42,11 @@ async def run_query(request):
if text := data.get("text", []):
embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
weights = [ w for x, w in images ] + [ w for x, w in text ]
embeddings = [ e * w for e, w in zip(embeddings, weights) ]
if not embeddings:
weighted_embeddings = [ e * w for e, w in zip(embeddings, weights) ]
weighted_embeddings.extend([ numpy.array(x) for x in data.get("embeddings", []) ])
if not weighted_embeddings:
return web.json_response([])
return web.json_response(app["index"].search(sum(embeddings)))
return web.json_response(app["index"].search(sum(weighted_embeddings), top_k=data.get("top_k", 4000)))
@routes.get("/")
async def health_check(request):
@ -70,8 +71,8 @@ class Index:
self.inference_server_config = inference_server_config
self.lock = asyncio.Lock()
def search(self, query):
distances, indices = self.faiss_index.search(numpy.array([query]), 4000)
def search(self, query, top_k):
distances, indices = self.faiss_index.search(numpy.array([query]), top_k)
distances = distances[0]
indices = indices[0]
try:
@ -214,6 +215,7 @@ async def main():
app["index"] = index
await index.reload()
print("Ready")
if CONFIG.get("no_run_server", False): return
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "", CONFIG["port"])
@ -223,4 +225,4 @@ if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(main())
loop.run_forever()
if CONFIG.get("no_run_server", False) == False: loop.run_forever()