diff --git a/meme-rater/crawler.py b/meme-rater/crawler.py
index 52bb27d..0a5e951 100644
--- a/meme-rater/crawler.py
+++ b/meme-rater/crawler.py
@@ -67,52 +67,58 @@ async def download(sess, url, file):
await fh.write(chunk)
return dict(res.headers)
-if __name__ == "__main__":
- async def main():
- sem = asyncio.Semaphore(16)
-
- async with aiohttp.ClientSession() as sess:
- async def download_item(item):
- #print("starting on", item["name"])
- print(".", end="")
+async def main(time_threshold):
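+    # time_threshold is a Unix timestamp (or None): crawling stops once posts older than it are reached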
+ sem = asyncio.Semaphore(16)
+
+ async with aiohttp.ClientSession() as sess:
+ async def download_item(item):
+ #print("starting on", item["name"])
+ print(".", end="")
+ sys.stdout.flush()
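+            # skip NSFW posts and posts reddit has delisted (removed or deleted)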
+ if item["over_18"] or not item["is_robot_indexable"]: return
+ id = item["name"]
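+            # bucket() shards files into subdirectories so no single directory grows too large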
+ bck = bucket(id)
+ os.makedirs(os.path.join("images", bck), exist_ok=True)
+ os.makedirs(os.path.join("meta", bck), exist_ok=True)
+ if not item["url"].startswith("https://"): return
+ meta_path = os.path.join("meta", bck, id + ".json")
+ if not os.path.exists(meta_path): # sorry
+ print("|", end="")
sys.stdout.flush()
- if item["over_18"] or not item["is_robot_indexable"]: return
- id = item["name"]
- bck = bucket(id)
- os.makedirs(os.path.join("images", bck), exist_ok=True)
- os.makedirs(os.path.join("meta", bck), exist_ok=True)
- if not item["url"].startswith("https://"): return
- meta_path = os.path.join("meta", bck, id + ".json")
- if not os.path.exists(meta_path): # sorry
- print("|", end="")
+ try:
+ result = await download(sess, item["url"], os.path.join("images", bck, id))
+ except Exception as e:
+ print("\nMeme acquisition failure:", e)
+ return
+ if result:
+ item["headers"] = result
+ with open(meta_path, "w") as fh:
+ json.dump(item, fh)
+ else:
+ print("!", end="")
sys.stdout.flush()
- try:
- result = await download(sess, item["url"], os.path.join("images", bck, id))
- except Exception as e:
- print("\nMeme acquisition failure:", e)
+ #print("done on", id)
+
+ async def dl_task(item):
+ async with sem:
+ try:
+ await asyncio.wait_for(download_item(item), timeout=30)
+ except asyncio.TimeoutError: pass
+
+ async for items in fetch_past(sess, "https://www.reddit.com/user/osmarks/m/memeharvesting/new", 20000):
+ #print("got new chunk")
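+            # crude backpressure: wait for a free download slot before fetching the next page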
+ await sem.acquire()
+ sem.release()
+ #print("downloading new set")
+ async with asyncio.TaskGroup() as tg:
+ for item in items:
+ if time_threshold and time_threshold > item["created"]:
return
- if result:
- item["headers"] = result
- with open(meta_path, "w") as fh:
- json.dump(item, fh)
- else:
- print("!", end="")
- sys.stdout.flush()
- #print("done on", id)
+ tg.create_task(dl_task(item))
- async def dl_task(item):
- async with sem:
- try:
- await asyncio.wait_for(download_item(item), timeout=30)
- except asyncio.TimeoutError: pass
-
- async for items in fetch_past(sess, "https://www.reddit.com/user/osmarks/m/memeharvesting/new", 20000):
- #print("got new chunk")
- await sem.acquire()
- sem.release()
- #print("downloading new set")
- async with asyncio.TaskGroup() as tg:
- for item in items:
- tg.create_task(dl_task(item))
-
- asyncio.run(main())
\ No newline at end of file
+if __name__ == "__main__":
+ threshold = None
+ if len(sys.argv) > 1:
+ print("thresholding at", sys.argv[1])
+ threshold = float(sys.argv[1])
+ asyncio.run(main(threshold))
\ No newline at end of file
diff --git a/meme-rater/library_processing_server.py b/meme-rater/library_processing_server.py
new file mode 100644
index 0000000..dd70cde
--- /dev/null
+++ b/meme-rater/library_processing_server.py
@@ -0,0 +1,102 @@
+from aiohttp import web
+import aiosqlite
+import asyncio
+import random
+import sys
+import json
+import os
+from pathlib import Path
+import shutil
+
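+# usage: python library_processing_server.py PORT DATABASE TARGET_DIR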
+PORT, DATABASE, TARGET_DIR = sys.argv[1:]
+with open("rater_mse_config.json", "r") as f:
+ mse_config = json.load(f)
+ basedir = Path(mse_config["files"])
+TARGET_DIR = Path(TARGET_DIR)
+
+app = web.Application(client_max_size=32*1024**2)
+routes = web.RouteTableDef()
+
+@routes.get("/")
+async def index(request):
+    csr = await request.app["db"].execute("SELECT filename FROM library_queue ORDER BY score DESC LIMIT 1")
+    row = await csr.fetchone()
+    await csr.close()
+    if row is None:
+        return web.Response(text="Queue empty.", content_type="text/html")
+    filename, = row
+    return web.Response(text=f"""
+<!DOCTYPE html>
+<html>
+    <head>
+        <meta charset="utf-8">
+        <title>Meme Processing</title>
+    </head>
+    <body>
+        <h1>Meme Processing</h1>
+        <img src="/memes/{filename}" style="max-width: 100%">
+        <form action="/" method="POST">
+            <input type="hidden" name="original_filename" value="{filename}">
+            <input name="filename" autofocus placeholder="new filename (leave blank to discard)">
+            <input type="submit" value="Save">
+        </form>
+    </body>
+</html>
+""", content_type="text/html")
+
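+# find an unused path in TARGET_DIR by appending -2, -3, ... until there is no collision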
+def find_new_path(basename, ext):
+ ctr = 1
+ while True:
+ new = TARGET_DIR / (basename + ("" if ctr == 1 else "-" + str(ctr)) + ext)
+ if not new.exists(): return new
+ ctr += 1
+
+@routes.post("/")
+async def rate(request):
+ db = request.app["db"]
+ post = await request.post()
+ filename = post["filename"]
+ original_filename = post["original_filename"]
+ real_path = basedir / original_filename
+ assert real_path.is_file()
+ if filename == "": # bad meme, discard
+ real_path.unlink()
+ else:
+ new_path = find_new_path(filename.replace(" ", "-"), real_path.suffix)
+ print(real_path, new_path, real_path.suffix)
+ shutil.move(real_path, new_path)
+ await db.execute("DELETE FROM library_queue WHERE filename = ?", (original_filename,))
+ await db.commit()
+ return web.HTTPFound("/")
+
+async def main():
+ app["db"] = await aiosqlite.connect(DATABASE)
+ app.router.add_routes(routes)
+    app.router.add_static("/memes/", str(basedir))
+ print("Ready")
+ runner = web.AppRunner(app)
+ await runner.setup()
+ site = web.TCPSite(runner, "", int(PORT))
+ await site.start()
+
+if __name__ == "__main__":
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ loop.run_until_complete(main())
+ loop.run_forever()
diff --git a/meme-rater/meme_pipeline.py b/meme-rater/meme_pipeline.py
new file mode 100644
index 0000000..819fe36
--- /dev/null
+++ b/meme-rater/meme_pipeline.py
@@ -0,0 +1,97 @@
+import subprocess
+import torch
+from tqdm import tqdm
+import json
+from pathlib import Path
+import os
+import asyncio
+import aiohttp
+import time
+
+import shared
+from model import Config, BradleyTerry
+
+meme_search_backend = "http://localhost:1707/"
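+# cutoff for admission to the library queue; presumably hand-picked from the percentile printout below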
+score_threshold = 1.5162627696990967
+
+shared.db.executescript("""
+CREATE TABLE IF NOT EXISTS last_crawl (time INTEGER);
+CREATE TABLE IF NOT EXISTS library_queue (
+ filename TEXT PRIMARY KEY,
+ score REAL NOT NULL
+);
+""")
+shared.db.commit()
+csr = shared.db.execute("SELECT MAX(time) FROM last_crawl")
+row = csr.fetchone()
+last_crawl = row[0] or 0
+csr.close()
+
+with open("rater_mse_config.json", "r") as f:
+ mse_config = json.load(f)
+ basedir = Path(mse_config["files"])
+
+print("crawling...")
+crawl_start = time.time()
+subprocess.run(["python", "crawler.py", str(last_crawl)]).check_returncode()
+print("indexing...")
+subprocess.run(["python", "../mse.py", "rater_mse_config.json"]).check_returncode()
+print("evaluating...")
+
+batch_size = 128
+device = "cuda"
+
+config = Config(
+ d_emb=1152,
+ n_hidden=1,
+ n_ensemble=16,
+ device=device,
+ dtype=torch.float32,
+ dropout=0.1
+)
+model = BradleyTerry(config).float()
+modelc, _ = shared.checkpoint_for(1500)
+model.load_state_dict(torch.load(modelc))
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+files = shared.fetch_all_files()
+ratings = {}
+
+model.eval()
+with torch.inference_mode():
+ for bstart in tqdm(range(0, len(files), batch_size)):
+ batch = files[bstart:bstart + batch_size]
+ filenames = [ filename for filename, embedding in batch ]
+ embs = torch.stack([ torch.Tensor(embedding) for filename, embedding in batch ])
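+        # replicate the batch for every ensemble member: (n_ensemble, batch, d_emb)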
+ inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
+ scores = model.ensemble(inputs).float()
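+        # the ensemble median is more robust to individual bad heads than the mean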
+ mscores = torch.median(scores, dim=0).values
+ for filename, mscore in zip(filenames, mscores):
+ ratings[filename] = float(mscore)
+
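+# print the 85th-percentile score as a reference for tuning score_threshold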
+print(sorted(ratings.values())[round(len(ratings) * 0.85)])
+
+files = dict(files)
+
+async def run_inserts():
+    # one shared session for all dedup queries
+    async with aiohttp.ClientSession() as sess:
+        async def duplicate_exists(embedding):
+            async with sess.post(meme_search_backend, json={
+                "embeddings": [ list(float(x) for x in embedding) ], # sorry
+                "top_k": 1
+            }) as res:
+                result = await res.json()
+                if not result: return False # empty index: nothing to duplicate against
+                closest = result[0]["score"]
+                return closest > 0.99 # arbitrary threshold, TODO
+
+ for filename, rating in ratings.items():
+ if rating > score_threshold and not await duplicate_exists(files[filename]):
+ shared.db.execute("INSERT OR REPLACE INTO library_queue VALUES (?, ?)", (filename, rating))
+ else:
+ os.unlink(basedir / filename)
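+        # record the crawl start time so the next run only fetches newer posts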
+ shared.db.execute("INSERT INTO last_crawl VALUES (?)", (crawl_start,))
+ shared.db.commit()
+
+asyncio.run(run_inserts())
\ No newline at end of file
diff --git a/mse.py b/mse.py
index 68b68a9..36e1fba 100644
--- a/mse.py
+++ b/mse.py
@@ -42,10 +42,11 @@ async def run_query(request):
if text := data.get("text", []):
embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
weights = [ w for x, w in images ] + [ w for x, w in text ]
- embeddings = [ e * w for e, w in zip(embeddings, weights) ]
- if not embeddings:
+ weighted_embeddings = [ e * w for e, w in zip(embeddings, weights) ]
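+    # raw embedding vectors can be passed directly (meme_pipeline.py uses this for deduplication)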
+ weighted_embeddings.extend([ numpy.array(x) for x in data.get("embeddings", []) ])
+ if not weighted_embeddings:
return web.json_response([])
- return web.json_response(app["index"].search(sum(embeddings)))
+ return web.json_response(app["index"].search(sum(weighted_embeddings), top_k=data.get("top_k", 4000)))
@routes.get("/")
async def health_check(request):
@@ -70,8 +71,8 @@ class Index:
self.inference_server_config = inference_server_config
self.lock = asyncio.Lock()
- def search(self, query):
- distances, indices = self.faiss_index.search(numpy.array([query]), 4000)
+ def search(self, query, top_k):
+ distances, indices = self.faiss_index.search(numpy.array([query]), top_k)
distances = distances[0]
indices = indices[0]
try:
@@ -214,6 +215,7 @@ async def main():
app["index"] = index
await index.reload()
print("Ready")
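+    # lets meme_pipeline.py reuse the indexing code without starting the HTTP server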
+ if CONFIG.get("no_run_server", False): return
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "", CONFIG["port"])
@@ -223,4 +225,4 @@ if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(main())
- loop.run_forever()
+    if not CONFIG.get("no_run_server", False): loop.run_forever()