mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-10 22:09:54 +00:00
full pipeline
This commit is contained in:
parent
7bae095384
commit
80db16d02a
@ -67,8 +67,7 @@ async def download(sess, url, file):
|
||||
await fh.write(chunk)
|
||||
return dict(res.headers)
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def main():
|
||||
async def main(time_threshold):
|
||||
sem = asyncio.Semaphore(16)
|
||||
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
@ -113,6 +112,13 @@ if __name__ == "__main__":
|
||||
#print("downloading new set")
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
for item in items:
|
||||
if time_threshold and time_threshold > item["created"]:
|
||||
return
|
||||
tg.create_task(dl_task(item))
|
||||
|
||||
asyncio.run(main())
|
||||
if __name__ == "__main__":
|
||||
threshold = None
|
||||
if len(sys.argv) > 1:
|
||||
print("thresholding at", sys.argv[1])
|
||||
threshold = float(sys.argv[1])
|
||||
asyncio.run(main(threshold))
|
102
meme-rater/library_processing_server.py
Normal file
102
meme-rater/library_processing_server.py
Normal file
@ -0,0 +1,102 @@
|
||||
from aiohttp import web
|
||||
import aiosqlite
|
||||
import asyncio
|
||||
import random
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
# Command-line contract: exactly three positional arguments. Checking the
# count up front gives a usage message rather than an unpacking ValueError.
if len(sys.argv) != 4:
    sys.exit(f"usage: {sys.argv[0]} PORT DATABASE TARGET_DIR")
PORT, DATABASE, TARGET_DIR = sys.argv[1:]

# The "files" entry of the rater's search-engine config is the directory that
# queued filenames are resolved against (see `rate`).
with open("rater_mse_config.json", "r") as f:
    mse_config = json.load(f)
basedir = Path(mse_config["files"])

# Directory that accepted memes are moved into (see `find_new_path`).
TARGET_DIR = Path(TARGET_DIR)

# Request bodies up to 32 MiB are accepted.
app = web.Application(client_max_size=32*1024**2)
routes = web.RouteTableDef()
|
||||
|
||||
@routes.get("/")
async def index(request):
    """Render the processing page for the highest-scored queued meme.

    Reads the top entry of `library_queue` (ordered by score, descending) and
    returns an HTML form that shows the image and asks for a new filename.
    When the queue is empty a short placeholder page is returned instead of
    failing on the missing row.
    """
    # Function-scope imports keep the module's import block untouched.
    from html import escape
    from urllib.parse import quote

    # Only one row is ever consumed, so let SQLite stop after the first.
    csr = await request.app["db"].execute(
        "SELECT filename FROM library_queue ORDER BY score DESC LIMIT 1"
    )
    row = await csr.fetchone()
    await csr.close()

    if row is None:
        # fetchone() yields None when nothing is queued.
        return web.Response(text="""
<!DOCTYPE html>
<html>
<body>
<h1>Meme Processing</h1>
<p>The library queue is empty.</p>
</body>
</html>
""", content_type="text/html")

    filename, = row
    # Filenames are not guaranteed to be HTML- or URL-safe: escape for the
    # attribute value and percent-encode for the image URL.
    attr_filename = escape(filename)
    url_filename = escape(quote(filename))

    # NOTE(review): the script below looks up an input named 'rating', which
    # this form does not define; it is kept unchanged here - confirm whether
    # the keypress handler is still wanted.
    return web.Response(text=f"""
<!DOCTYPE html>
<html>
<style>
.memes img {{
    width: 100%;
}}

input {{
    width: 100%;
}}

.memes {{
    margin-top: 2em;
}}
</style>
<body>
    <h1>Meme Processing</h1>
    <form action="/" method="POST">
        <input type="text" name="filename" id="filename" autofocus>
        <input type="hidden" name="original_filename" value="{attr_filename}">
        <input type="submit" value="Submit">
        <div class="memes">
            <img src="/memes/{url_filename}" id="meme1">
        </div>
    </form>
    <script>
        document.addEventListener("keypress", function(event) {{
            if (event.key === "Enter") {{
                document.querySelector("input[name='rating'][value='1']").checked = true
                document.querySelector("form").submit()
            }}
        }});
    </script>
</body>
</html>
""", content_type="text/html")
|
||||
|
||||
def find_new_path(basename, ext):
    """Return the first unused path in TARGET_DIR for `basename` + `ext`.

    The plain name is tried first; if it is taken, "-2", "-3", ... is
    appended to `basename` until a path that does not exist is found.
    """
    attempt = 1
    candidate = TARGET_DIR / (basename + ext)
    while candidate.exists():
        attempt += 1
        candidate = TARGET_DIR / (basename + "-" + str(attempt) + ext)
    return candidate
|
||||
|
||||
@routes.post("/")
async def rate(request):
    """Handle a submitted decision for one queued meme.

    Form fields:
        filename: new base name for the meme; an empty string discards it.
        original_filename: the queue entry, relative to `basedir`.

    An accepted meme is moved into TARGET_DIR under a non-clashing name
    (spaces become "-", the original suffix is kept); a discarded one is
    deleted. In both cases the queue row is removed and the client is
    redirected back to the index page.

    Raises:
        web.HTTPBadRequest: a field is missing or tries to escape its directory.
        web.HTTPNotFound: the queued file no longer exists on disk.
        web.HTTPFound: on success, redirecting to "/".
    """
    db = request.app["db"]
    post = await request.post()
    filename = post.get("filename")
    original_filename = post.get("original_filename")
    if not isinstance(filename, str) or not isinstance(original_filename, str) or not original_filename:
        raise web.HTTPBadRequest(text="missing form field")

    # Both names arrive from the client, so neither may point outside the
    # directory it is joined onto.
    relative = Path(original_filename)
    if relative.is_absolute() or ".." in relative.parts:
        raise web.HTTPBadRequest(text="invalid original_filename")
    real_path = basedir / relative
    # Explicit check instead of `assert`, which is stripped under -O.
    if not real_path.is_file():
        raise web.HTTPNotFound(text="queued file does not exist")

    if filename == "": # bad meme, discard
        real_path.unlink()
    else:
        new_name = filename.replace(" ", "-")
        if new_name == ".." or Path(new_name).name != new_name:
            raise web.HTTPBadRequest(text="filename must not contain path separators")
        new_path = find_new_path(new_name, real_path.suffix)
        print(real_path, new_path, real_path.suffix)
        shutil.move(real_path, new_path)

    await db.execute("DELETE FROM library_queue WHERE filename = ?", (original_filename,))
    await db.commit()
    # aiohttp expects HTTP exceptions to be raised; returning them is deprecated.
    raise web.HTTPFound("/")
|
||||
|
||||
async def main():
    """Open the database, register the routes and start listening on PORT."""
    app["db"] = await aiosqlite.connect(DATABASE)

    router = app.router
    router.add_routes(routes)
    # Images referenced by the index page are served from ./images.
    router.add_static("/memes/", "./images")
    print("Ready")

    app_runner = web.AppRunner(app)
    await app_runner.setup()
    # Empty host string, as in the TCPSite call's second argument.
    await web.TCPSite(app_runner, "", int(PORT)).start()
|
||||
|
||||
if __name__ == "__main__":
    # main() only starts the site and returns, so the loop is then kept
    # running to serve requests.
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    event_loop.run_until_complete(main())
    event_loop.run_forever()
|
97
meme-rater/meme_pipeline.py
Normal file
97
meme-rater/meme_pipeline.py
Normal file
@ -0,0 +1,97 @@
|
||||
import subprocess
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
from pathlib import Path
|
||||
import os
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import time
|
||||
|
||||
import shared
|
||||
from model import Config, BradleyTerry
|
||||
|
||||
# Endpoint that run_inserts() POSTs embeddings to for duplicate detection.
meme_search_backend = "http://localhost:1707/"
# Ratings must exceed this value for a file to be queued.
# NOTE(review): origin of the constant is not visible here - confirm.
score_threshold = 1.5162627696990967

# last_crawl records when each crawl started; library_queue holds the files
# awaiting manual processing together with their score.
shared.db.executescript("""
CREATE TABLE IF NOT EXISTS last_crawl (time INTEGER);
CREATE TABLE IF NOT EXISTS library_queue (
    filename TEXT PRIMARY KEY,
    score REAL NOT NULL
);
""")
shared.db.commit()

# MAX() yields a NULL row on an empty table, which falls back to 0.
csr = shared.db.execute("SELECT MAX(time) FROM last_crawl")
(newest_crawl,) = csr.fetchone()
csr.close()
last_crawl = newest_crawl if newest_crawl else 0
|
||||
|
||||
# The "files" entry of the search-engine config is the directory the
# rated files live in.
with open("rater_mse_config.json", "r") as f:
    mse_config = json.load(f)
basedir = Path(mse_config["files"])

print("crawling...")
# Taken before the crawl begins; stored in last_crawl by run_inserts().
crawl_start = time.time()
# check=True raises CalledProcessError on a non-zero exit status.
subprocess.run(["python", "crawler.py", str(last_crawl)], check=True)

print("indexing...")
subprocess.run(["python", "../mse.py", "rater_mse_config.json"], check=True)

print("evaluating...")
|
||||
|
||||
# Number of files scored per forward pass.
batch_size = 128
device = "cuda"

# Rating model hyperparameters.
# NOTE(review): presumably these must match the checkpoint loaded below - confirm.
config = Config(d_emb=1152, n_hidden=1, n_ensemble=16, device=device, dtype=torch.float32, dropout=0.1)
model = BradleyTerry(config).float()

# Restore the weights from the checkpoint selected by shared.checkpoint_for(1500).
checkpoint_path, _ = shared.checkpoint_for(1500)
state = torch.load(checkpoint_path)
model.load_state_dict(state)

# Report the total parameter count, in millions.
n_params = 0
for p in model.parameters():
    n_params += p.numel()
print(f"{n_params/1e6:.1f}M parameters")
print(model)
|
||||
|
||||
# (filename, embedding) pairs for every file to rate.
files = shared.fetch_all_files()
# filename -> median ensemble score
ratings = {}

model.eval()
with torch.inference_mode():
    for bstart in tqdm(range(0, len(files), batch_size)):
        batch = files[bstart:bstart + batch_size]
        filenames = [ filename for filename, embedding in batch ]
        embs = torch.stack([ torch.Tensor(embedding) for filename, embedding in batch ])
        # Present the same batch to each ensemble member:
        # (n_ensemble, len(batch), d_emb).
        inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
        scores = model.ensemble(inputs).float()
        # Median over the leading axis.
        # NOTE(review): assumes model.ensemble keeps the ensemble axis first - confirm.
        mscores = torch.median(scores, dim=0).values
        for filename, mscore in zip(filenames, mscores):
            ratings[filename] = float(mscore)

# Print the score at roughly the 85th percentile. round(len * 0.85) equals
# len for fewer than four ratings, so clamp the index and skip an empty set
# rather than raising IndexError.
if ratings:
    ordered_scores = sorted(ratings.values())
    cutoff = min(round(len(ordered_scores) * 0.85), len(ordered_scores) - 1)
    print(ordered_scores[cutoff])

# Re-key by filename for the embedding lookups in run_inserts().
files = dict(files)
|
||||
|
||||
async def run_inserts():
    """Queue well-rated, non-duplicate files and delete the rest.

    For every rated file: if its rating exceeds `score_threshold` and the
    search backend reports no near-identical item, it is upserted into
    `library_queue`; otherwise the file is removed from `basedir`. Finally
    the crawl start time is recorded and the transaction committed.
    """
    # One session is shared by all duplicate checks instead of opening a new
    # one per request.
    async with aiohttp.ClientSession() as sess:
        async def duplicate_exists(embedding):
            """Return True if the backend's closest match scores above 0.99."""
            async with sess.post(meme_search_backend, json={
                "embeddings": [ list(float(x) for x in embedding) ], # sorry
                "top_k": 1
            }) as res:
                # Fail on an HTTP error rather than on a confusing body.
                res.raise_for_status()
                result = await res.json()
            if not result:
                # Empty result list: nothing to be a duplicate of.
                return False
            closest = result[0]["score"]
            return closest > 0.99 # arbitrary threshold, TODO

        for filename, rating in ratings.items():
            if rating > score_threshold and not await duplicate_exists(files[filename]):
                shared.db.execute("INSERT OR REPLACE INTO library_queue VALUES (?, ?)", (filename, rating))
            else:
                os.unlink(basedir / filename)
        shared.db.execute("INSERT INTO last_crawl VALUES (?)", (crawl_start,))
        shared.db.commit()


asyncio.run(run_inserts())
|
14
mse.py
14
mse.py
@ -42,10 +42,11 @@ async def run_query(request):
|
||||
if text := data.get("text", []):
|
||||
embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
|
||||
weights = [ w for x, w in images ] + [ w for x, w in text ]
|
||||
embeddings = [ e * w for e, w in zip(embeddings, weights) ]
|
||||
if not embeddings:
|
||||
weighted_embeddings = [ e * w for e, w in zip(embeddings, weights) ]
|
||||
weighted_embeddings.extend([ numpy.array(x) for x in data.get("embeddings", []) ])
|
||||
if not weighted_embeddings:
|
||||
return web.json_response([])
|
||||
return web.json_response(app["index"].search(sum(embeddings)))
|
||||
return web.json_response(app["index"].search(sum(weighted_embeddings), top_k=data.get("top_k", 4000)))
|
||||
|
||||
@routes.get("/")
|
||||
async def health_check(request):
|
||||
@ -70,8 +71,8 @@ class Index:
|
||||
self.inference_server_config = inference_server_config
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
def search(self, query):
|
||||
distances, indices = self.faiss_index.search(numpy.array([query]), 4000)
|
||||
def search(self, query, top_k):
|
||||
distances, indices = self.faiss_index.search(numpy.array([query]), top_k)
|
||||
distances = distances[0]
|
||||
indices = indices[0]
|
||||
try:
|
||||
@ -214,6 +215,7 @@ async def main():
|
||||
app["index"] = index
|
||||
await index.reload()
|
||||
print("Ready")
|
||||
if CONFIG.get("no_run_server", False): return
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "", CONFIG["port"])
|
||||
@ -223,4 +225,4 @@ if __name__ == "__main__":
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(main())
|
||||
loop.run_forever()
|
||||
if CONFIG.get("no_run_server", False) == False: loop.run_forever()
|
||||
|
Loading…
Reference in New Issue
Block a user