Mirror of https://github.com/osmarks/meme-search-engine.git (synced 2024-11-10 22:09:54 +00:00)

Port meme acquisition pipeline to new API, database

Also fix a really stupid oversight in crawling code.
This commit is contained in:
osmarks 2024-05-22 15:43:56 +01:00
parent 30b1b72712
commit bd426a30ba
4 changed files with 20 additions and 8 deletions

View File

@@ -55,6 +55,11 @@ filetypes = {
     "image/webp": "webp",
     "image/avif": "avif"
 }
+hard_exclude = {
+    ".mp4",
+    ".mkv",
+    ".webm"
+}
 CHUNK_SIZE = 1<<18 # entirely arbitrary
 async def download(sess, url, file):
@@ -80,12 +85,19 @@ async def main(time_threshold):
 bck = bucket(id)
 os.makedirs(os.path.join("images", bck), exist_ok=True)
 os.makedirs(os.path.join("meta", bck), exist_ok=True)
+if not item.get("preview"): return
 if not item["url"].startswith("https://"): return
 meta_path = os.path.join("meta", bck, id + ".json")
 if not os.path.exists(meta_path): # sorry
 print("|", end="")
 sys.stdout.flush()
+excluded = False
+for excl in hard_exclude:
+    if item["url"].endswith(excl):
+        excluded = True
+        break
 try:
+    if not excluded:
 result = await download(sess, item["url"], os.path.join("images", bck, id))
 except Exception as e:
 print("\nMeme acquisition failure:", e)

View File

@@ -35,7 +35,7 @@ print("crawling...")
 crawl_start = time.time()
 subprocess.run(["python", "crawler.py", str(last_crawl)]).check_returncode()
 print("indexing...")
-subprocess.run(["python", "../mse.py", "rater_mse_config.json"]).check_returncode()
+subprocess.run(["./meme-search-engine", "rater_mse_config.json"]).check_returncode()
 print("evaluating...")
 batch_size = 128
@@ -80,11 +80,11 @@ async def run_inserts():
 async with aiohttp.ClientSession():
 async def duplicate_exists(embedding):
 async with aiohttp.request("POST", meme_search_backend, json={
-    "embeddings": [ list(float(x) for x in embedding) ], # sorry
-    "top_k": 1
+    "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
+    "k": 1
 }) as res:
 result = await res.json()
-closest = result[0]["score"]
+closest = result["matches"][0][0]
 return closest > 0.99 # arbitrary threshold, TODO
 for filename, rating in ratings.items():

View File

@@ -1,5 +1,5 @@
 {
-    "clip_server": "http://100.64.0.10:1708/",
+    "clip_server": "http://100.64.0.10:1708",
     "db_path": "data.sqlite3",
     "port": 1707,
     "files": "./images",

View File

@@ -14,7 +14,7 @@ def is_val_set(meme1, meme2):
 return is_one_val(meme1) or is_one_val(meme2)
 def fetch_embedding(filename):
-csr = db.execute("SELECT embedding_vector FROM files WHERE filename = ?", (filename,))
+csr = db.execute("SELECT embedding FROM files WHERE filename = ?", (filename,))
 x = numpy.frombuffer(csr.fetchone()[0], dtype="float16")
 csr.close()
 return x.copy() # PyTorch complains otherwise due to bad
@@ -44,7 +44,7 @@ def generate_random_permutations(x, n):
 return out
 def fetch_all_files():
-csr = db.execute("SELECT filename, embedding_vector FROM files")
+csr = db.execute("SELECT filename, embedding FROM files WHERE embedding IS NOT NULL")
 x = [ (row[0], numpy.frombuffer(row[1], dtype="float16").copy()) for row in csr.fetchall() ]
 csr.close()
 return x