mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-10 22:09:54 +00:00
Port meme acquisition pipeline to new API, database
Also fix a really stupid oversight in crawling code.
This commit is contained in:
parent
30b1b72712
commit
bd426a30ba
@ -55,6 +55,11 @@ filetypes = {
|
||||
"image/webp": "webp",
|
||||
"image/avif": "avif"
|
||||
}
|
||||
hard_exclude = {
|
||||
".mp4",
|
||||
".mkv",
|
||||
".webm"
|
||||
}
|
||||
|
||||
CHUNK_SIZE = 1<<18 # entirely arbitrary
|
||||
async def download(sess, url, file):
|
||||
@ -80,12 +85,19 @@ async def main(time_threshold):
|
||||
bck = bucket(id)
|
||||
os.makedirs(os.path.join("images", bck), exist_ok=True)
|
||||
os.makedirs(os.path.join("meta", bck), exist_ok=True)
|
||||
if not item.get("preview"): return
|
||||
if not item["url"].startswith("https://"): return
|
||||
meta_path = os.path.join("meta", bck, id + ".json")
|
||||
if not os.path.exists(meta_path): # sorry
|
||||
print("|", end="")
|
||||
sys.stdout.flush()
|
||||
excluded = False
|
||||
for excl in hard_exclude:
|
||||
if item["url"].endswith(excl):
|
||||
excluded = True
|
||||
break
|
||||
try:
|
||||
if not excluded:
|
||||
result = await download(sess, item["url"], os.path.join("images", bck, id))
|
||||
except Exception as e:
|
||||
print("\nMeme acquisition failure:", e)
|
||||
|
@ -35,7 +35,7 @@ print("crawling...")
|
||||
crawl_start = time.time()
|
||||
subprocess.run(["python", "crawler.py", str(last_crawl)]).check_returncode()
|
||||
print("indexing...")
|
||||
subprocess.run(["python", "../mse.py", "rater_mse_config.json"]).check_returncode()
|
||||
subprocess.run(["./meme-search-engine", "rater_mse_config.json"]).check_returncode()
|
||||
print("evaluating...")
|
||||
|
||||
batch_size = 128
|
||||
@ -80,11 +80,11 @@ async def run_inserts():
|
||||
async with aiohttp.ClientSession():
|
||||
async def duplicate_exists(embedding):
|
||||
async with aiohttp.request("POST", meme_search_backend, json={
|
||||
"embeddings": [ list(float(x) for x in embedding) ], # sorry
|
||||
"top_k": 1
|
||||
"terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
|
||||
"k": 1
|
||||
}) as res:
|
||||
result = await res.json()
|
||||
closest = result[0]["score"]
|
||||
closest = result["matches"][0][0]
|
||||
return closest > 0.99 # arbitrary threshold, TODO
|
||||
|
||||
for filename, rating in ratings.items():
|
||||
|
@ -1,5 +1,5 @@
|
||||
{
|
||||
"clip_server": "http://100.64.0.10:1708/",
|
||||
"clip_server": "http://100.64.0.10:1708",
|
||||
"db_path": "data.sqlite3",
|
||||
"port": 1707,
|
||||
"files": "./images",
|
||||
|
@ -14,7 +14,7 @@ def is_val_set(meme1, meme2):
|
||||
return is_one_val(meme1) or is_one_val(meme2)
|
||||
|
||||
def fetch_embedding(filename):
|
||||
csr = db.execute("SELECT embedding_vector FROM files WHERE filename = ?", (filename,))
|
||||
csr = db.execute("SELECT embedding FROM files WHERE filename = ?", (filename,))
|
||||
x = numpy.frombuffer(csr.fetchone()[0], dtype="float16")
|
||||
csr.close()
|
||||
return x.copy() # PyTorch complains otherwise due to bad
|
||||
@ -44,7 +44,7 @@ def generate_random_permutations(x, n):
|
||||
return out
|
||||
|
||||
def fetch_all_files():
|
||||
csr = db.execute("SELECT filename, embedding_vector FROM files")
|
||||
csr = db.execute("SELECT filename, embedding FROM files WHERE embedding IS NOT NULL")
|
||||
x = [ (row[0], numpy.frombuffer(row[1], dtype="float16").copy()) for row in csr.fetchall() ]
|
||||
csr.close()
|
||||
return x
|
||||
|
Loading…
Reference in New Issue
Block a user