mirror of https://github.com/osmarks/meme-search-engine.git
synced 2025-02-07 14:40:08 +00:00

commit 80db16d02a (parent 7bae095384)

    full pipeline
@@ -67,8 +67,7 @@ async def download(sess, url, file):
             await fh.write(chunk)
     return dict(res.headers)
 
-if __name__ == "__main__":
-    async def main():
+async def main(time_threshold):
         sem = asyncio.Semaphore(16)
 
         async with aiohttp.ClientSession() as sess:
@@ -113,6 +112,13 @@ if __name__ == "__main__":
             #print("downloading new set")
             async with asyncio.TaskGroup() as tg:
                 for item in items:
+                    if time_threshold and time_threshold > item["created"]:
+                        return
                     tg.create_task(dl_task(item))
 
-    asyncio.run(main())
+if __name__ == "__main__":
+    threshold = None
+    if len(sys.argv) > 1:
+        print("thresholding at", sys.argv[1])
+        threshold = float(sys.argv[1])
+    asyncio.run(main(threshold))
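
The new time_threshold argument makes crawls incremental: items are presumably ordered newest-first, so the first item created at or before the threshold ends the whole run. A minimal sketch of the handoff, grounded in meme_pipeline.py below, which launches the crawler with the previous crawl's start time:

    # Sketch: how the pipeline below drives the crawler. last_crawl is the
    # MAX(time) recorded in the last_crawl table; 0 (falsy) disables thresholding.
    import subprocess

    last_crawl = 0  # stand-in value; the pipeline reads this from SQLite
    subprocess.run(["python", "crawler.py", str(last_crawl)]).check_returncode()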
meme-rater/library_processing_server.py (new file, 102 lines)
@@ -0,0 +1,102 @@
+from aiohttp import web
+import aiosqlite
+import asyncio
+import random
+import sys
+import json
+import os
+from pathlib import Path
+import shutil
+
+PORT, DATABASE, TARGET_DIR = sys.argv[1:]
+with open("rater_mse_config.json", "r") as f:
+    mse_config = json.load(f)
+basedir = Path(mse_config["files"])
+TARGET_DIR = Path(TARGET_DIR)
+
+app = web.Application(client_max_size=32*1024**2)
+routes = web.RouteTableDef()
+
+@routes.get("/")
+async def index(request):
+    csr = await request.app["db"].execute("SELECT filename FROM library_queue ORDER BY score DESC")
+    filename, = await csr.fetchone()
+    await csr.close()
+    return web.Response(text=f"""
+<!DOCTYPE html>
+<html>
+<style>
+.memes img {{
+    width: 100%;
+}}
+
+input {{
+    width: 100%;
+}}
+
+.memes {{
+    margin-top: 2em;
+}}
+</style>
+<body>
+    <h1>Meme Processing</h1>
+    <form action="/" method="POST">
+        <input type="text" name="filename" id="filename" autofocus>
+        <input type="hidden" name="original_filename" value="{filename}">
+        <input type="submit" value="Submit">
+        <div class="memes">
+            <img src="/memes/{filename}" id="meme1">
+        </div>
+    </form>
+    <script>
+        document.addEventListener("keypress", function(event) {{
+            if (event.key === "Enter") {{
+                document.querySelector("input[name='rating'][value='1']").checked = true
+                document.querySelector("form").submit()
+            }}
+        }});
+    </script>
+</body>
+</html>
+""", content_type="text/html")
+
+def find_new_path(basename, ext):
+    ctr = 1
+    while True:
+        new = TARGET_DIR / (basename + ("" if ctr == 1 else "-" + str(ctr)) + ext)
+        if not new.exists(): return new
+        ctr += 1
+
+@routes.post("/")
+async def rate(request):
+    db = request.app["db"]
+    post = await request.post()
+    filename = post["filename"]
+    original_filename = post["original_filename"]
+    real_path = basedir / original_filename
+    assert real_path.is_file()
+    if filename == "": # bad meme, discard
+        real_path.unlink()
+    else:
+        new_path = find_new_path(filename.replace(" ", "-"), real_path.suffix)
+        print(real_path, new_path, real_path.suffix)
+        shutil.move(real_path, new_path)
+    await db.execute("DELETE FROM library_queue WHERE filename = ?", (original_filename,))
+    await db.commit()
+    return web.HTTPFound("/")
+
+async def main():
+    app["db"] = await aiosqlite.connect(DATABASE)
+    app.router.add_routes(routes)
+    app.router.add_static("/memes/", "./images")
+    print("Ready")
+    runner = web.AppRunner(app)
+    await runner.setup()
+    site = web.TCPSite(runner, "", int(PORT))
+    await site.start()
+
+if __name__ == "__main__":
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.run_until_complete(main())
+    loop.run_forever()
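
Submitted names get spaces replaced with dashes and are then deduplicated against the target directory by find_new_path. A self-contained sketch of that naming behaviour (the demo directory is illustrative):

    from pathlib import Path

    TARGET_DIR = Path("/tmp/library-demo")  # illustrative stand-in for the real target dir

    def find_new_path(basename, ext):
        # Same logic as the server's helper: try name.ext, then name-2.ext, name-3.ext, ...
        ctr = 1
        while True:
            new = TARGET_DIR / (basename + ("" if ctr == 1 else "-" + str(ctr)) + ext)
            if not new.exists(): return new
            ctr += 1

    TARGET_DIR.mkdir(exist_ok=True)
    print(find_new_path("meme", ".png"))   # .../meme.png while it is free
    (TARGET_DIR / "meme.png").touch()
    print(find_new_path("meme", ".png"))   # .../meme-2.png once taken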
meme-rater/meme_pipeline.py (new file, 97 lines)
@@ -0,0 +1,97 @@
+import subprocess
+import torch
+from tqdm import tqdm
+import json
+from pathlib import Path
+import os
+import asyncio
+import aiohttp
+import time
+
+import shared
+from model import Config, BradleyTerry
+
+meme_search_backend = "http://localhost:1707/"
+score_threshold = 1.5162627696990967
+
+shared.db.executescript("""
+CREATE TABLE IF NOT EXISTS last_crawl (time INTEGER);
+CREATE TABLE IF NOT EXISTS library_queue (
+    filename TEXT PRIMARY KEY,
+    score REAL NOT NULL
+);
+""")
+shared.db.commit()
+csr = shared.db.execute("SELECT MAX(time) FROM last_crawl")
+row = csr.fetchone()
+last_crawl = row[0] or 0
+csr.close()
+
+with open("rater_mse_config.json", "r") as f:
+    mse_config = json.load(f)
+basedir = Path(mse_config["files"])
+
+print("crawling...")
+crawl_start = time.time()
+subprocess.run(["python", "crawler.py", str(last_crawl)]).check_returncode()
+print("indexing...")
+subprocess.run(["python", "../mse.py", "rater_mse_config.json"]).check_returncode()
+print("evaluating...")
+
+batch_size = 128
+device = "cuda"
+
+config = Config(
+    d_emb=1152,
+    n_hidden=1,
+    n_ensemble=16,
+    device=device,
+    dtype=torch.float32,
+    dropout=0.1
+)
+model = BradleyTerry(config).float()
+modelc, _ = shared.checkpoint_for(1500)
+model.load_state_dict(torch.load(modelc))
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+files = shared.fetch_all_files()
+ratings = {}
+
+model.eval()
+with torch.inference_mode():
+    for bstart in tqdm(range(0, len(files), batch_size)):
+        batch = files[bstart:bstart + batch_size]
+        filenames = [ filename for filename, embedding in batch ]
+        embs = torch.stack([ torch.Tensor(embedding) for filename, embedding in batch ])
+        inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
+        scores = model.ensemble(inputs).float()
+        mscores = torch.median(scores, dim=0).values
+        for filename, mscore in zip(filenames, mscores):
+            ratings[filename] = float(mscore)
+
+print(sorted(ratings.values())[round(len(ratings) * 0.85)])
+
+files = dict(files)
+
+async def run_inserts():
+    async with aiohttp.ClientSession():
+        async def duplicate_exists(embedding):
+            async with aiohttp.request("POST", meme_search_backend, json={
+                "embeddings": [ list(float(x) for x in embedding) ], # sorry
+                "top_k": 1
+            }) as res:
+                result = await res.json()
+                closest = result[0]["score"]
+                return closest > 0.99 # arbitrary threshold, TODO
+
+        for filename, rating in ratings.items():
+            if rating > score_threshold and not await duplicate_exists(files[filename]):
+                shared.db.execute("INSERT OR REPLACE INTO library_queue VALUES (?, ?)", (filename, rating))
+            else:
+                os.unlink(basedir / filename)
+        shared.db.execute("INSERT INTO last_crawl VALUES (?)", (crawl_start,))
+        shared.db.commit()
+
+asyncio.run(run_inserts())
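
Each file is scored by all 16 ensemble members, and the pipeline keeps the per-file median, which is less sensitive to a single outlier head than the mean would be. A minimal, self-contained sketch of that reduction (shapes taken from the config above, values random):

    import torch

    n_ensemble, batch_size = 16, 4
    scores = torch.randn(n_ensemble, batch_size)    # one score per (member, file)
    mscores = torch.median(scores, dim=0).values    # per-file median across members
    print(mscores.shape)                            # torch.Size([4])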
mse.py (14 lines changed)
@@ -42,10 +42,11 @@ async def run_query(request):
     if text := data.get("text", []):
         embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
     weights = [ w for x, w in images ] + [ w for x, w in text ]
-    embeddings = [ e * w for e, w in zip(embeddings, weights) ]
-    if not embeddings:
+    weighted_embeddings = [ e * w for e, w in zip(embeddings, weights) ]
+    weighted_embeddings.extend([ numpy.array(x) for x in data.get("embeddings", []) ])
+    if not weighted_embeddings:
         return web.json_response([])
-    return web.json_response(app["index"].search(sum(embeddings)))
+    return web.json_response(app["index"].search(sum(weighted_embeddings), top_k=data.get("top_k", 4000)))
 
 @routes.get("/")
 async def health_check(request):
@@ -70,8 +71,8 @@ class Index:
         self.inference_server_config = inference_server_config
         self.lock = asyncio.Lock()
 
-    def search(self, query):
-        distances, indices = self.faiss_index.search(numpy.array([query]), 4000)
+    def search(self, query, top_k):
+        distances, indices = self.faiss_index.search(numpy.array([query]), top_k)
         distances = distances[0]
         indices = indices[0]
         try:
@@ -214,6 +215,7 @@ async def main():
     app["index"] = index
     await index.reload()
     print("Ready")
+    if CONFIG.get("no_run_server", False): return
     runner = web.AppRunner(app)
     await runner.setup()
     site = web.TCPSite(runner, "", CONFIG["port"])
@@ -223,4 +225,4 @@ if __name__ == "__main__":
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     loop.run_until_complete(main())
-    loop.run_forever()
+    if CONFIG.get("no_run_server", False) == False: loop.run_forever()
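
With this change, run_query also accepts precomputed embedding vectors and a top_k result count in the request JSON, which is exactly what the pipeline's duplicate check relies on. A hedged example body (field names are from the diff; the zero vector and the 1152 dimensionality are illustrative, taken from the rater config above):

    import json

    # One precomputed query vector plus a result count; omitting "top_k"
    # falls back to the previous fixed 4000.
    request_body = {
        "embeddings": [[0.0] * 1152],
        "top_k": 1,
    }
    print(json.dumps(request_body)[:48])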