diff --git a/meme-rater/crawler.py b/meme-rater/crawler.py
index 0a5e951..f02bb41 100644
--- a/meme-rater/crawler.py
+++ b/meme-rater/crawler.py
@@ -55,6 +55,11 @@ filetypes = {
     "image/webp": "webp",
     "image/avif": "avif"
 }
+hard_exclude = {
+    ".mp4",
+    ".mkv",
+    ".webm"
+}
 CHUNK_SIZE = 1<<18 # entirely arbitrary
 
 async def download(sess, url, file):
@@ -80,13 +85,20 @@ async def main(time_threshold):
             bck = bucket(id)
             os.makedirs(os.path.join("images", bck), exist_ok=True)
             os.makedirs(os.path.join("meta", bck), exist_ok=True)
+            if not item.get("preview"): return
             if not item["url"].startswith("https://"): return
             meta_path = os.path.join("meta", bck, id + ".json")
             if not os.path.exists(meta_path): # sorry
                 print("|", end="")
                 sys.stdout.flush()
+                excluded = False
+                for excl in hard_exclude:
+                    if item["url"].endswith(excl):
+                        excluded = True
+                        break
                 try:
-                    result = await download(sess, item["url"], os.path.join("images", bck, id))
+                    if not excluded:
+                        result = await download(sess, item["url"], os.path.join("images", bck, id))
                 except Exception as e:
                     print("\nMeme acquisition failure:", e)
                     return
diff --git a/meme-rater/meme_pipeline.py b/meme-rater/meme_pipeline.py
index 9853842..74b2deb 100644
--- a/meme-rater/meme_pipeline.py
+++ b/meme-rater/meme_pipeline.py
@@ -35,7 +35,7 @@ print("crawling...")
 crawl_start = time.time()
 subprocess.run(["python", "crawler.py", str(last_crawl)]).check_returncode()
 print("indexing...")
-subprocess.run(["python", "../mse.py", "rater_mse_config.json"]).check_returncode()
+subprocess.run(["./meme-search-engine", "rater_mse_config.json"]).check_returncode()
 print("evaluating...")
 
 batch_size = 128
@@ -80,11 +80,11 @@ async def run_inserts():
     async with aiohttp.ClientSession():
         async def duplicate_exists(embedding):
             async with aiohttp.request("POST", meme_search_backend, json={
-                "embeddings": [ list(float(x) for x in embedding) ], # sorry
-                "top_k": 1
+                "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
+                "k": 1
             }) as res:
                 result = await res.json()
-                closest = result[0]["score"]
+                closest = result["matches"][0][0]
                 return closest > 0.99 # arbitrary threshold, TODO
 
         for filename, rating in ratings.items():
diff --git a/meme-rater/rater_mse_config.json b/meme-rater/rater_mse_config.json
index a3e7225..c322caf 100644
--- a/meme-rater/rater_mse_config.json
+++ b/meme-rater/rater_mse_config.json
@@ -1,5 +1,5 @@
 {
-    "clip_server": "http://100.64.0.10:1708/",
+    "clip_server": "http://100.64.0.10:1708",
     "db_path": "data.sqlite3",
     "port": 1707,
     "files": "./images",
diff --git a/meme-rater/shared.py b/meme-rater/shared.py
index a8398d7..76360c3 100644
--- a/meme-rater/shared.py
+++ b/meme-rater/shared.py
@@ -14,7 +14,7 @@ def is_val_set(meme1, meme2):
     return is_one_val(meme1) or is_one_val(meme2)
 
 def fetch_embedding(filename):
-    csr = db.execute("SELECT embedding_vector FROM files WHERE filename = ?", (filename,))
+    csr = db.execute("SELECT embedding FROM files WHERE filename = ?", (filename,))
     x = numpy.frombuffer(csr.fetchone()[0], dtype="float16")
     csr.close()
     return x.copy() # PyTorch complains otherwise due to bad
@@ -44,7 +44,7 @@ def generate_random_permutations(x, n):
     return out
 
 def fetch_all_files():
-    csr = db.execute("SELECT filename, embedding_vector FROM files")
+    csr = db.execute("SELECT filename, embedding FROM files WHERE embedding IS NOT NULL")
     x = [ (row[0], numpy.frombuffer(row[1], dtype="float16").copy()) for row in csr.fetchall() ]
     csr.close()
     return x
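
The most substantive change above is the duplicate check in meme_pipeline.py, which moves to the backend's newer query format ("terms"/"k" in the request, "matches" in the response). A minimal standalone sketch of that exchange follows, assuming meme_search_backend points at the search endpoint and that the response is JSON of the form {"matches": [[score, ...], ...]} with the best match first (which is what indexing result["matches"][0][0] relies on); the function name, parameters and threshold default are illustrative, not part of this diff:

import aiohttp

async def duplicate_exists(meme_search_backend, embedding, threshold=0.99):
    # Ask the backend for the single nearest neighbour of this embedding
    # using the "terms"/"k" request shape introduced in the diff.
    async with aiohttp.request("POST", meme_search_backend, json={
        "terms": [{ "embedding": [ float(x) for x in embedding ] }],
        "k": 1
    }) as res:
        result = await res.json()
    # Assumed response shape: {"matches": [[score, ...], ...]}, best match first.
    # Treat a similarity score above the (arbitrary) threshold as a duplicate.
    return result["matches"][0][0] > threshold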