diff --git a/meme-rater/crawler.py b/meme-rater/crawler.py
index 52bb27d..0a5e951 100644
--- a/meme-rater/crawler.py
+++ b/meme-rater/crawler.py
@@ -67,52 +67,58 @@ async def download(sess, url, file):
             await fh.write(chunk)
     return dict(res.headers)
 
-if __name__ == "__main__":
-    async def main():
-        sem = asyncio.Semaphore(16)
-
-        async with aiohttp.ClientSession() as sess:
-            async def download_item(item):
-                #print("starting on", item["name"])
-                print(".", end="")
+async def main(time_threshold):
+    sem = asyncio.Semaphore(16)
+
+    async with aiohttp.ClientSession() as sess:
+        async def download_item(item):
+            #print("starting on", item["name"])
+            print(".", end="")
+            sys.stdout.flush()
+            if item["over_18"] or not item["is_robot_indexable"]: return
+            id = item["name"]
+            bck = bucket(id)
+            os.makedirs(os.path.join("images", bck), exist_ok=True)
+            os.makedirs(os.path.join("meta", bck), exist_ok=True)
+            if not item["url"].startswith("https://"): return
+            meta_path = os.path.join("meta", bck, id + ".json")
+            if not os.path.exists(meta_path): # sorry
+                print("|", end="")
                 sys.stdout.flush()
-                if item["over_18"] or not item["is_robot_indexable"]: return
-                id = item["name"]
-                bck = bucket(id)
-                os.makedirs(os.path.join("images", bck), exist_ok=True)
-                os.makedirs(os.path.join("meta", bck), exist_ok=True)
-                if not item["url"].startswith("https://"): return
-                meta_path = os.path.join("meta", bck, id + ".json")
-                if not os.path.exists(meta_path): # sorry
-                    print("|", end="")
+                try:
+                    result = await download(sess, item["url"], os.path.join("images", bck, id))
+                except Exception as e:
+                    print("\nMeme acquisition failure:", e)
+                    return
+                if result:
+                    item["headers"] = result
+                    with open(meta_path, "w") as fh:
+                        json.dump(item, fh)
+                else:
+                    print("!", end="")
                     sys.stdout.flush()
-                    try:
-                        result = await download(sess, item["url"], os.path.join("images", bck, id))
-                    except Exception as e:
-                        print("\nMeme acquisition failure:", e)
+            #print("done on", id)
+
+        async def dl_task(item):
+            async with sem:
+                try:
+                    await asyncio.wait_for(download_item(item), timeout=30)
+                except asyncio.TimeoutError: pass
+
+        async for items in fetch_past(sess, "https://www.reddit.com/user/osmarks/m/memeharvesting/new", 20000):
+            #print("got new chunk")
+            await sem.acquire()
+            sem.release()
+            #print("downloading new set")
+            async with asyncio.TaskGroup() as tg:
+                for item in items:
+                    if time_threshold and time_threshold > item["created"]:
                         return
-                    if result:
-                        item["headers"] = result
-                        with open(meta_path, "w") as fh:
-                            json.dump(item, fh)
-                    else:
-                        print("!", end="")
-                        sys.stdout.flush()
-                    #print("done on", id)
+                    tg.create_task(dl_task(item))
 
-            async def dl_task(item):
-                async with sem:
-                    try:
-                        await asyncio.wait_for(download_item(item), timeout=30)
-                    except asyncio.TimeoutError: pass
-
-            async for items in fetch_past(sess, "https://www.reddit.com/user/osmarks/m/memeharvesting/new", 20000):
-                #print("got new chunk")
-                await sem.acquire()
-                sem.release()
-                #print("downloading new set")
-                async with asyncio.TaskGroup() as tg:
-                    for item in items:
-                        tg.create_task(dl_task(item))
-
-    asyncio.run(main())
\ No newline at end of file
+if __name__ == "__main__":
+    threshold = None
+    if len(sys.argv) > 1:
+        print("thresholding at", sys.argv[1])
+        threshold = float(sys.argv[1])
+    asyncio.run(main(threshold))
\ No newline at end of file
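Note: crawler.py now takes an optional Unix-timestamp threshold and stops crawling entirely once it reaches a post whose "created" time is older than it, making crawls incremental. A usage sketch, with a hypothetical timestamp:

    python crawler.py              # full backfill, as before
    python crawler.py 1700000000   # stop once posts predate this time

meme_pipeline.py below invokes it the same way, passing the stored last_crawl time.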
diff --git a/meme-rater/library_processing_server.py b/meme-rater/library_processing_server.py
new file mode 100644
index 0000000..dd70cde
--- /dev/null
+++ b/meme-rater/library_processing_server.py
@@ -0,0 +1,102 @@
+from aiohttp import web
+import aiosqlite
+import asyncio
+import random
+import sys
+import json
+import os
+from pathlib import Path
+import shutil
+
+PORT, DATABASE, TARGET_DIR = sys.argv[1:]
+with open("rater_mse_config.json", "r") as f:
+    mse_config = json.load(f)
+    basedir = Path(mse_config["files"])
+TARGET_DIR = Path(TARGET_DIR)
+
+app = web.Application(client_max_size=32*1024**2)
+routes = web.RouteTableDef()
+
+@routes.get("/")
+async def index(request):
+    csr = await request.app["db"].execute("SELECT filename FROM library_queue ORDER BY score DESC")
+    filename, = await csr.fetchone()
+    await csr.close()
+    return web.Response(text=f"""
+<!DOCTYPE html>
+<title>Meme Processing</title>
+<body>
+    <h1>Meme Processing</h1>
+    <img src="/memes/{filename}">
+    <form action="/" method="POST">
+        <input type="hidden" name="original_filename" value="{filename}">
+        <input type="text" name="filename">
+        <input type="submit" value="Submit">
+    </form>
+</body>
+    """, content_type="text/html")
+
+def find_new_path(basename, ext):
+    ctr = 1
+    while True:
+        new = TARGET_DIR / (basename + ("" if ctr == 1 else "-" + str(ctr)) + ext)
+        if not new.exists(): return new
+        ctr += 1
+
+@routes.post("/")
+async def rate(request):
+    db = request.app["db"]
+    post = await request.post()
+    filename = post["filename"]
+    original_filename = post["original_filename"]
+    real_path = basedir / original_filename
+    assert real_path.is_file()
+    if filename == "": # bad meme, discard
+        real_path.unlink()
+    else:
+        new_path = find_new_path(filename.replace(" ", "-"), real_path.suffix)
+        print(real_path, new_path, real_path.suffix)
+        shutil.move(real_path, new_path)
+    await db.execute("DELETE FROM library_queue WHERE filename = ?", (original_filename,))
+    await db.commit()
+    return web.HTTPFound("/")
+
+async def main():
+    app["db"] = await aiosqlite.connect(DATABASE)
+    app.router.add_routes(routes)
+    app.router.add_static("/memes/", "./images")
+    print("Ready")
+    runner = web.AppRunner(app)
+    await runner.setup()
+    site = web.TCPSite(runner, "", int(PORT))
+    await site.start()
+
+if __name__ == "__main__":
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.run_until_complete(main())
+    loop.run_forever()
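Note: library_processing_server.py is the manual review step. Each GET / renders the highest-scored meme still in library_queue; submitting the form with a new name moves the file into TARGET_DIR, while an empty name discards it. It takes three positional arguments; an invocation sketch with hypothetical values:

    python library_processing_server.py 7600 data.sqlite3 /srv/meme-library

find_new_path deduplicates names by appending a counter, so a colliding "cat.png" becomes "cat-2.png", then "cat-3.png", and so on.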
+ + + + """, content_type="text/html") + +def find_new_path(basename, ext): + ctr = 1 + while True: + new = TARGET_DIR / (basename + ("" if ctr == 1 else "-" + str(ctr)) + ext) + if not new.exists(): return new + ctr += 1 + +@routes.post("/") +async def rate(request): + db = request.app["db"] + post = await request.post() + filename = post["filename"] + original_filename = post["original_filename"] + real_path = basedir / original_filename + assert real_path.is_file() + if filename == "": # bad meme, discard + real_path.unlink() + else: + new_path = find_new_path(filename.replace(" ", "-"), real_path.suffix) + print(real_path, new_path, real_path.suffix) + shutil.move(real_path, new_path) + await db.execute("DELETE FROM library_queue WHERE filename = ?", (original_filename,)) + await db.commit() + return web.HTTPFound("/") + +async def main(): + app["db"] = await aiosqlite.connect(DATABASE) + app.router.add_routes(routes) + app.router.add_static("/memes/", "./images") + print("Ready") + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "", int(PORT)) + await site.start() + +if __name__ == "__main__": + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(main()) + loop.run_forever() diff --git a/meme-rater/meme_pipeline.py b/meme-rater/meme_pipeline.py new file mode 100644 index 0000000..819fe36 --- /dev/null +++ b/meme-rater/meme_pipeline.py @@ -0,0 +1,97 @@ +import subprocess +import torch +from tqdm import tqdm +import json +from pathlib import Path +import os +import asyncio +import aiohttp +import time + +import shared +from model import Config, BradleyTerry + +meme_search_backend = "http://localhost:1707/" +score_threshold = 1.5162627696990967 + +shared.db.executescript(""" +CREATE TABLE IF NOT EXISTS last_crawl (time INTEGER); +CREATE TABLE IF NOT EXISTS library_queue ( + filename TEXT PRIMARY KEY, + score REAL NOT NULL +); +""") +shared.db.commit() +csr = shared.db.execute("SELECT MAX(time) FROM last_crawl") +row = csr.fetchone() +last_crawl = row[0] or 0 +csr.close() + +with open("rater_mse_config.json", "r") as f: + mse_config = json.load(f) + basedir = Path(mse_config["files"]) + +print("crawling...") +crawl_start = time.time() +subprocess.run(["python", "crawler.py", str(last_crawl)]).check_returncode() +print("indexing...") +subprocess.run(["python", "../mse.py", "rater_mse_config.json"]).check_returncode() +print("evaluating...") + +batch_size = 128 +device = "cuda" + +config = Config( + d_emb=1152, + n_hidden=1, + n_ensemble=16, + device=device, + dtype=torch.float32, + dropout=0.1 +) +model = BradleyTerry(config).float() +modelc, _ = shared.checkpoint_for(1500) +model.load_state_dict(torch.load(modelc)) +params = sum(p.numel() for p in model.parameters()) +print(f"{params/1e6:.1f}M parameters") +print(model) + +files = shared.fetch_all_files() +ratings = {} + +model.eval() +with torch.inference_mode(): + for bstart in tqdm(range(0, len(files), batch_size)): + batch = files[bstart:bstart + batch_size] + filenames = [ filename for filename, embedding in batch ] + embs = torch.stack([ torch.Tensor(embedding) for filename, embedding in batch ]) + inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device) + scores = model.ensemble(inputs).float() + mscores = torch.median(scores, dim=0).values + for filename, mscore in zip(filenames, mscores): + ratings[filename] = float(mscore) + +print(sorted(ratings.values())[round(len(ratings) * 0.85)]) + +files = dict(files) + +async def 
diff --git a/mse.py b/mse.py
index 68b68a9..36e1fba 100644
--- a/mse.py
+++ b/mse.py
@@ -42,10 +42,11 @@ async def run_query(request):
     if text := data.get("text", []):
         embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
     weights = [ w for x, w in images ] + [ w for x, w in text ]
-    embeddings = [ e * w for e, w in zip(embeddings, weights) ]
-    if not embeddings:
+    weighted_embeddings = [ e * w for e, w in zip(embeddings, weights) ]
+    weighted_embeddings.extend([ numpy.array(x) for x in data.get("embeddings", []) ])
+    if not weighted_embeddings:
         return web.json_response([])
-    return web.json_response(app["index"].search(sum(embeddings)))
+    return web.json_response(app["index"].search(sum(weighted_embeddings), top_k=data.get("top_k", 4000)))
 
 @routes.get("/")
 async def health_check(request):
@@ -70,8 +71,8 @@ class Index:
         self.inference_server_config = inference_server_config
         self.lock = asyncio.Lock()
 
-    def search(self, query):
-        distances, indices = self.faiss_index.search(numpy.array([query]), 4000)
+    def search(self, query, top_k):
+        distances, indices = self.faiss_index.search(numpy.array([query]), top_k)
         distances = distances[0]
         indices = indices[0]
         try:
@@ -214,6 +215,7 @@ async def main():
     app["index"] = index
     await index.reload()
     print("Ready")
+    if CONFIG.get("no_run_server", False): return
     runner = web.AppRunner(app)
     await runner.setup()
     site = web.TCPSite(runner, "", CONFIG["port"])
@@ -223,4 +225,4 @@ if __name__ == "__main__":
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     loop.run_until_complete(main())
-    loop.run_forever()
+    if CONFIG.get("no_run_server", False) == False: loop.run_forever()
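Note on the mse.py changes: POST queries may now carry precomputed vectors under an "embeddings" key (summed into the query unweighted, alongside the weighted image/text terms), top_k is per-request instead of the hardcoded 4000, and the no_run_server config flag makes the script build its index and exit, letting meme_pipeline.py use it as a one-shot indexer. A query sketch, assuming an instance on localhost:1707 and 1152-dimensional embeddings per the rater config:

    import asyncio, aiohttp

    async def query():
        async with aiohttp.request("POST", "http://localhost:1707/", json={
            "text": [["cat meme", 1.0]],   # (query, weight) pairs, as before
            "embeddings": [[0.0] * 1152],  # new: raw vectors added into the query sum
            "top_k": 10,                   # new: overrides the old fixed 4000
        }) as res:
            print(await res.json())

    asyncio.run(query())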