mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-06-06 16:34:06 +00:00
51 lines
2.0 KiB
Python
51 lines
2.0 KiB
Python
import sys, aiohttp, msgpack, numpy, pgvector.asyncpg, asyncio
|
|
|
|
async def use_emb_server(sess, query):
|
|
async with sess.post("http://100.64.0.10:1708/", data=msgpack.dumps(query), timeout=aiohttp.ClientTimeout(connect=5, sock_connect=5, sock_read=None)) as res:
|
|
response = msgpack.loads(await res.read())
|
|
if res.status == 200:
|
|
return response
|
|
else:
|
|
raise Exception(response if res.headers.get("content-type") == "application/msgpack" else (await res.text()))
|
|
|
|
BATCH_SIZE = 32
|
|
|
|
async def main():
|
|
with open("query_data.bin", "wb") as f:
|
|
with open("queries.txt", "r") as g:
|
|
write_lock = asyncio.Lock()
|
|
async with aiohttp.ClientSession() as sess:
|
|
async with asyncio.TaskGroup() as tg:
|
|
sem = asyncio.Semaphore(3)
|
|
|
|
async def process_batch(batch):
|
|
while True:
|
|
try:
|
|
embs = await use_emb_server(sess, { "text": batch })
|
|
async with write_lock:
|
|
f.write(b"".join(embs))
|
|
sys.stdout.write(".")
|
|
sys.stdout.flush()
|
|
break
|
|
except Exception as e:
|
|
print(e)
|
|
await asyncio.sleep(5)
|
|
|
|
sem.release()
|
|
|
|
async def dispatch(batch):
|
|
await sem.acquire()
|
|
tg.create_task(process_batch(batch))
|
|
|
|
batch = []
|
|
while line := g.readline():
|
|
if line.strip(): batch.append(line.strip())
|
|
if len(batch) == BATCH_SIZE:
|
|
await dispatch(batch)
|
|
batch = []
|
|
if len(batch) > 0:
|
|
await dispatch(batch)
|
|
|
|
if __name__ == "__main__":
|
|
asyncio.run(main())
|