mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-10 22:09:54 +00:00
Port meme acquisition pipeline to new API, database
Also fix a really stupid oversight in crawling code.
This commit is contained in:
parent
30b1b72712
commit
bd426a30ba
@ -55,6 +55,11 @@ filetypes = {
|
|||||||
"image/webp": "webp",
|
"image/webp": "webp",
|
||||||
"image/avif": "avif"
|
"image/avif": "avif"
|
||||||
}
|
}
|
||||||
|
hard_exclude = {
|
||||||
|
".mp4",
|
||||||
|
".mkv",
|
||||||
|
".webm"
|
||||||
|
}
|
||||||
|
|
||||||
CHUNK_SIZE = 1<<18 # entirely arbitrary
|
CHUNK_SIZE = 1<<18 # entirely arbitrary
|
||||||
async def download(sess, url, file):
|
async def download(sess, url, file):
|
||||||
@ -80,12 +85,19 @@ async def main(time_threshold):
|
|||||||
bck = bucket(id)
|
bck = bucket(id)
|
||||||
os.makedirs(os.path.join("images", bck), exist_ok=True)
|
os.makedirs(os.path.join("images", bck), exist_ok=True)
|
||||||
os.makedirs(os.path.join("meta", bck), exist_ok=True)
|
os.makedirs(os.path.join("meta", bck), exist_ok=True)
|
||||||
|
if not item.get("preview"): return
|
||||||
if not item["url"].startswith("https://"): return
|
if not item["url"].startswith("https://"): return
|
||||||
meta_path = os.path.join("meta", bck, id + ".json")
|
meta_path = os.path.join("meta", bck, id + ".json")
|
||||||
if not os.path.exists(meta_path): # sorry
|
if not os.path.exists(meta_path): # sorry
|
||||||
print("|", end="")
|
print("|", end="")
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
excluded = False
|
||||||
|
for excl in hard_exclude:
|
||||||
|
if item["url"].endswith(excl):
|
||||||
|
excluded = True
|
||||||
|
break
|
||||||
try:
|
try:
|
||||||
|
if not excluded:
|
||||||
result = await download(sess, item["url"], os.path.join("images", bck, id))
|
result = await download(sess, item["url"], os.path.join("images", bck, id))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("\nMeme acquisition failure:", e)
|
print("\nMeme acquisition failure:", e)
|
||||||
|
@ -35,7 +35,7 @@ print("crawling...")
|
|||||||
crawl_start = time.time()
|
crawl_start = time.time()
|
||||||
subprocess.run(["python", "crawler.py", str(last_crawl)]).check_returncode()
|
subprocess.run(["python", "crawler.py", str(last_crawl)]).check_returncode()
|
||||||
print("indexing...")
|
print("indexing...")
|
||||||
subprocess.run(["python", "../mse.py", "rater_mse_config.json"]).check_returncode()
|
subprocess.run(["./meme-search-engine", "rater_mse_config.json"]).check_returncode()
|
||||||
print("evaluating...")
|
print("evaluating...")
|
||||||
|
|
||||||
batch_size = 128
|
batch_size = 128
|
||||||
@ -80,11 +80,11 @@ async def run_inserts():
|
|||||||
async with aiohttp.ClientSession():
|
async with aiohttp.ClientSession():
|
||||||
async def duplicate_exists(embedding):
|
async def duplicate_exists(embedding):
|
||||||
async with aiohttp.request("POST", meme_search_backend, json={
|
async with aiohttp.request("POST", meme_search_backend, json={
|
||||||
"embeddings": [ list(float(x) for x in embedding) ], # sorry
|
"terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
|
||||||
"top_k": 1
|
"k": 1
|
||||||
}) as res:
|
}) as res:
|
||||||
result = await res.json()
|
result = await res.json()
|
||||||
closest = result[0]["score"]
|
closest = result["matches"][0][0]
|
||||||
return closest > 0.99 # arbitrary threshold, TODO
|
return closest > 0.99 # arbitrary threshold, TODO
|
||||||
|
|
||||||
for filename, rating in ratings.items():
|
for filename, rating in ratings.items():
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"clip_server": "http://100.64.0.10:1708/",
|
"clip_server": "http://100.64.0.10:1708",
|
||||||
"db_path": "data.sqlite3",
|
"db_path": "data.sqlite3",
|
||||||
"port": 1707,
|
"port": 1707,
|
||||||
"files": "./images",
|
"files": "./images",
|
||||||
|
@ -14,7 +14,7 @@ def is_val_set(meme1, meme2):
|
|||||||
return is_one_val(meme1) or is_one_val(meme2)
|
return is_one_val(meme1) or is_one_val(meme2)
|
||||||
|
|
||||||
def fetch_embedding(filename):
|
def fetch_embedding(filename):
|
||||||
csr = db.execute("SELECT embedding_vector FROM files WHERE filename = ?", (filename,))
|
csr = db.execute("SELECT embedding FROM files WHERE filename = ?", (filename,))
|
||||||
x = numpy.frombuffer(csr.fetchone()[0], dtype="float16")
|
x = numpy.frombuffer(csr.fetchone()[0], dtype="float16")
|
||||||
csr.close()
|
csr.close()
|
||||||
return x.copy() # PyTorch complains otherwise due to bad
|
return x.copy() # PyTorch complains otherwise due to bad
|
||||||
@ -44,7 +44,7 @@ def generate_random_permutations(x, n):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def fetch_all_files():
|
def fetch_all_files():
|
||||||
csr = db.execute("SELECT filename, embedding_vector FROM files")
|
csr = db.execute("SELECT filename, embedding FROM files WHERE embedding IS NOT NULL")
|
||||||
x = [ (row[0], numpy.frombuffer(row[1], dtype="float16").copy()) for row in csr.fetchall() ]
|
x = [ (row[0], numpy.frombuffer(row[1], dtype="float16").copy()) for row in csr.fetchall() ]
|
||||||
csr.close()
|
csr.close()
|
||||||
return x
|
return x
|
||||||
|
Loading…
Reference in New Issue
Block a user