1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-06-06 16:34:06 +00:00
meme-search-engine/generate_queries_bin.py
2025-01-16 21:10:12 +00:00

51 lines
2.0 KiB
Python

import sys, aiohttp, msgpack, numpy, pgvector.asyncpg, asyncio
async def use_emb_server(sess, query):
async with sess.post("http://100.64.0.10:1708/", data=msgpack.dumps(query), timeout=aiohttp.ClientTimeout(connect=5, sock_connect=5, sock_read=None)) as res:
response = msgpack.loads(await res.read())
if res.status == 200:
return response
else:
raise Exception(response if res.headers.get("content-type") == "application/msgpack" else (await res.text()))
BATCH_SIZE = 32
async def main():
with open("query_data.bin", "wb") as f:
with open("queries.txt", "r") as g:
write_lock = asyncio.Lock()
async with aiohttp.ClientSession() as sess:
async with asyncio.TaskGroup() as tg:
sem = asyncio.Semaphore(3)
async def process_batch(batch):
while True:
try:
embs = await use_emb_server(sess, { "text": batch })
async with write_lock:
f.write(b"".join(embs))
sys.stdout.write(".")
sys.stdout.flush()
break
except Exception as e:
print(e)
await asyncio.sleep(5)
sem.release()
async def dispatch(batch):
await sem.acquire()
tg.create_task(process_batch(batch))
batch = []
while line := g.readline():
if line.strip(): batch.append(line.strip())
if len(batch) == BATCH_SIZE:
await dispatch(batch)
batch = []
if len(batch) > 0:
await dispatch(batch)
if __name__ == "__main__":
asyncio.run(main())