1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-09-21 18:19:35 +00:00
meme-search-engine/mse.py

295 lines
12 KiB
Python
Raw Normal View History

2023-09-28 16:30:20 +00:00
from aiohttp import web
import aiohttp
import asyncio
import traceback
import umsgpack
from PIL import Image
import base64
import aiosqlite
import faiss
import numpy
import os
import aiohttp_cors
import json
2023-09-29 17:34:06 +00:00
import io
import time
2023-09-28 16:30:20 +00:00
import sys
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import threading
2023-09-28 16:30:20 +00:00
# Runtime configuration: a JSON document whose path is the first CLI argument.
with open(sys.argv[1], "r") as config_file:
    CONFIG = json.loads(config_file.read())

# Accept request bodies up to 32 MiB; query images arrive base64-encoded inside JSON.
app = web.Application(client_max_size=32 * 1024 * 1024)
# Handlers below attach themselves to this table via @routes.get / @routes.post.
routes = web.RouteTableDef()
async def clip_server(sess: aiohttp.ClientSession, query, unpack_buffer=True):
    """POST a msgpack-encoded ``query`` to the CLIP server and decode the reply.

    Args:
        sess: HTTP session used for the request.
        query: dict sent as msgpack, e.g. {"images": [...]} or {"text": [...]}.
        unpack_buffer: when true, convert each returned buffer into a float16
            numpy array; otherwise return the raw msgpack-decoded items.

    Raises:
        Exception: on a non-200 status, carrying the decoded msgpack error body
            if the server labelled it msgpack, otherwise the response text.
    """
    async with sess.post(CONFIG["clip_server"], data=umsgpack.dumps(query)) as res:
        body = await res.read()
        if res.status != 200:
            # Only decode the error body as msgpack when the server says it is
            # msgpack; plain-text/HTML errors are reported verbatim instead of
            # failing inside the msgpack decoder.
            if res.headers.get("content-type") == "application/msgpack":
                raise Exception(umsgpack.loads(body))
            raise Exception(await res.text())
        response = umsgpack.loads(body)
    if unpack_buffer:
        response = [numpy.frombuffer(x, dtype="float16") for x in response]
    return response
2023-09-28 16:30:20 +00:00
@routes.post("/")
async def run_query(request):
    """Run a weighted multimodal similarity search.

    The JSON body may contain:
      - "images": list of [base64-encoded image, weight] pairs
      - "text": list of [query string, weight] pairs
      - "embeddings": list of raw embedding vectors, added with weight 1
      - "top_k": maximum number of results (default 4000)
    Image and text items are embedded by the CLIP server, scaled by their
    weights and summed with any raw embeddings into one query vector.
    Responds with the index's search results, or [] if nothing was supplied.
    """
    sess = request.app["session"]
    index = request.app["index"]
    data = await request.json()

    # `or []` also tolerates an explicit JSON null for these keys.
    images = data.get("images") or []
    text = data.get("text") or []

    embeddings = []
    if images:
        target_image_size = index.inference_server_config["image_size"]
        loop = asyncio.get_running_loop()
        # PIL decoding/resizing is CPU-bound; run it in the default executor so
        # one large upload does not stall every other request on the event loop.
        loaded = await asyncio.gather(*[
            loop.run_in_executor(None, load_image, io.BytesIO(base64.b64decode(x)), target_image_size)
            for x, w in images
        ])
        embeddings.extend(await clip_server(sess, {"images": [buf for buf, _ in loaded]}))
    if text:
        embeddings.extend(await clip_server(sess, {"text": [x for x, w in text]}))

    # Assumes the CLIP server returns one embedding per input, in request order
    # (images first, then text), so weights line up positionally.
    weights = [w for x, w in images] + [w for x, w in text]
    weighted_embeddings = [e * w for e, w in zip(embeddings, weights)]
    weighted_embeddings.extend(numpy.array(x) for x in data.get("embeddings", []))
    if not weighted_embeddings:
        return web.json_response([])
    return web.json_response(index.search(sum(weighted_embeddings), top_k=data.get("top_k", 4000)))
2023-09-28 16:30:20 +00:00
@routes.get("/")
async def health_check(request):
    """Liveness probe: always answers 200 with the plain-text body "OK"."""
    ok_response = web.Response(text="OK")
    return ok_response
@routes.post("/reload_index")
async def reload_index_route(request):
    """Trigger Index.reload() on the app's index and reply with JSON ``true``."""
    index = request.app["index"]
    await index.reload()
    return web.json_response(True)
2023-10-08 21:52:17 +00:00
def load_image(path, image_size):
    """Decode an image and re-encode it as an RGB BMP of exactly ``image_size``.

    Args:
        path: filesystem path or binary file-like object accepted by PIL.
        image_size: (width, height) target size passed to PIL.

    Returns:
        ``(bmp_bytes, path)``. ``path`` is echoed back so callers consuming
        results out of order (e.g. via asyncio.as_completed) can tell which
        input each result belongs to.
    """
    # Context manager so the underlying file is closed even for multi-frame
    # formats, where PIL otherwise keeps it open after loading.
    with Image.open(path) as im:
        # Ask the decoder (JPEG only, per PIL docs) to decode at a reduced
        # scale close to the target size, which is much cheaper than full decode.
        im.draft("RGB", image_size)
        buf = io.BytesIO()
        im.resize(image_size).convert("RGB").save(buf, format="BMP")
    return buf.getvalue(), path
2023-09-28 16:30:20 +00:00
class Index:
    """FAISS inner-product index over image embeddings persisted in SQLite.

    ``associated_filenames[i]`` is the file (relative to CONFIG["files"]) whose
    embedding sits at position ``i`` of ``faiss_index``; ``reload`` keeps the
    two in lockstep.
    """

    # Executed on every reload; idempotent thanks to IF NOT EXISTS.
    _SCHEMA = """
    CREATE TABLE IF NOT EXISTS files (
        filename TEXT PRIMARY KEY,
        modtime REAL NOT NULL,
        embedding_vector BLOB NOT NULL
    );
    CREATE TABLE IF NOT EXISTS ocr (
        filename TEXT PRIMARY KEY REFERENCES files(filename),
        scan_time INTEGER NOT NULL,
        text TEXT NOT NULL,
        raw_segments TEXT
    );
    CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
        filename,
        text,
        tokenize='unicode61 remove_diacritics 2',
        content='ocr'
    );
    CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON ocr BEGIN
        INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, new.text);
    END;
    CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON ocr BEGIN
        INSERT INTO ocr_fts (ocr_fts, rowid, filename, text) VALUES ('delete', old.rowid, old.filename, old.text);
    END;
    """

    def __init__(self, inference_server_config, http_session):
        """
        Args:
            inference_server_config: config dict fetched from the CLIP backend;
                this class reads its "embedding_size", "image_size" and "batch" keys.
            http_session: aiohttp.ClientSession used for CLIP and OCR requests.
        """
        self.faiss_index = faiss.IndexFlatIP(inference_server_config["embedding_size"])
        self.associated_filenames = []
        self.inference_server_config = inference_server_config
        # Serializes reloads so two rescans never mutate the index concurrently.
        self.lock = asyncio.Lock()
        self.session = http_session

    def search(self, query, top_k):
        """Return up to ``top_k`` nearest files to ``query`` by inner product.

        Returns a list of {"score": float, "file": str} dicts in the order FAISS
        returns them.
        """
        distances, indices = self.faiss_index.search(numpy.array([query]), top_k)
        distances = distances[0]
        indices = indices[0]
        # FAISS pads with -1 when fewer than top_k vectors are stored; cut the
        # results at the first padding entry.
        padding = numpy.where(indices == -1)[0]
        if len(padding):
            indices = indices[:padding[0]]
        return [
            {"score": float(distance), "file": self.associated_filenames[index]}
            for index, distance in zip(indices, distances)
        ]

    async def run_ocr(self):
        """OCR every indexed file whose OCR result is missing or older than the file.

        Does nothing unless CONFIG["enable_ocr"] is truthy. Results are upserted
        into the ``ocr`` table as (filename, scan time, text, JSON regions).
        Files that fail to load or scan are reported and skipped rather than
        aborting the whole run.
        """
        if not CONFIG.get("enable_ocr"):
            return
        import ocr
        print("Running OCR")
        conn = await aiosqlite.connect(CONFIG["db_path"])
        try:
            # INSERT OR REPLACE only fires the ocr_fts_del trigger for the row it
            # replaces when recursive triggers are enabled; without this,
            # re-OCRed files leave stale entries in the full-text index.
            await conn.execute("PRAGMA recursive_triggers = ON")
            unocred = await conn.execute_fetchall(
                "SELECT files.filename FROM files LEFT JOIN ocr ON files.filename = ocr.filename "
                "WHERE ocr.scan_time IS NULL OR ocr.scan_time < files.modtime"
            )
            ocr_sem = asyncio.Semaphore(20)  # Google has more concurrency than our internal CLIP backend. I am sure they will be fine.
            load_sem = threading.Semaphore(100)  # provide backpressure in loading to avoid using 50GB of RAM (this happened)

            async def run_image(filename, chunks):
                try:
                    text, regions = await ocr.scan_chunks(self.session, chunks)
                    await conn.execute(
                        "INSERT OR REPLACE INTO ocr VALUES (?, ?, ?, ?)",
                        (filename, time.time(), text, json.dumps(regions)),
                    )
                    await conn.commit()
                    sys.stdout.write(".")
                    sys.stdout.flush()
                except Exception as e:
                    print("OCR failed on", filename, repr(e))
                finally:
                    ocr_sem.release()

            def load_and_chunk_image(filename):
                # Runs in a worker thread. Blocks while 100 chunked images are
                # waiting to be handed off; the event loop frees a slot per hand-off.
                load_sem.acquire()
                try:
                    im = Image.open(Path(CONFIG["files"]) / filename)
                    return filename, ocr.chunk_image(im)
                except Exception as e:
                    # This file will never be handed off, so free its slot here;
                    # otherwise failures slowly starve (and can deadlock) the pool.
                    load_sem.release()
                    print("Failed to load", filename, repr(e))
                    return filename, None

            loop = asyncio.get_running_loop()
            async with asyncio.TaskGroup() as tg:
                with ThreadPoolExecutor(max_workers=CONFIG.get("n_workers", 1)) as executor:
                    loads = [loop.run_in_executor(executor, load_and_chunk_image, row[0]) for row in unocred]
                    for task in asyncio.as_completed(loads):
                        filename, chunks = await task
                        if chunks is None:
                            continue
                        await ocr_sem.acquire()
                        tg.create_task(run_image(filename, chunks))
                        load_sem.release()
        finally:
            await conn.close()

    async def reload(self):
        """Rescan CONFIG["files"], refresh embeddings and the index, then run OCR.

        The scan and index update run under ``self.lock``; OCR (if enabled)
        runs after the lock is released.
        """
        async with self.lock:
            print("Indexing")
            conn = await aiosqlite.connect(CONFIG["db_path"])
            try:
                conn.row_factory = aiosqlite.Row
                await conn.executescript(self._SCHEMA)
                stored, seen_files, modified = await self._embed_changed_files(conn)
                await self._sync_index(conn, stored, seen_files, modified)
            finally:
                await conn.close()
        await self.run_ocr()

    async def _embed_changed_files(self, conn):
        """(Re)embed every file under CONFIG["files"] whose mtime differs from the DB.

        Images are decoded in a thread pool and sent to the CLIP server in
        batches of ``inference_server_config["batch"]``; each batch is upserted
        into ``files`` as soon as its embeddings arrive.

        Returns:
            ``(stored, seen_files, modified)``: the {filename: modtime} map read
            from ``files`` before scanning, the set of relative filenames found
            on disk, and the set of filenames whose embeddings were written.
        """
        root = CONFIG["files"]
        image_size = self.inference_server_config["image_size"]
        batch_size = self.inference_server_config["batch"]
        stored = {
            filename: modtime
            for filename, modtime in await conn.execute_fetchall("SELECT filename, modtime FROM files")
        }
        seen_files = set()
        modified = set()
        failed = set()
        batch_sem = asyncio.Semaphore(3)  # at most three CLIP batches in flight
        loop = asyncio.get_running_loop()

        with ThreadPoolExecutor(max_workers=CONFIG.get("n_workers", 1)) as executor:
            async with asyncio.TaskGroup() as tg:
                async def do_batch(batch):
                    try:
                        query = {"images": [item[2] for item in batch]}
                        embeddings = await clip_server(self.session, query, False)
                        await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
                            (filename, modtime, embedding)
                            for (filename, modtime, _), embedding in zip(batch, embeddings)
                        ])
                        await conn.commit()
                        modified.update(filename for filename, _, _ in batch)
                        sys.stdout.write(".")
                        sys.stdout.flush()
                    finally:
                        batch_sem.release()

                async def dispatch_batch(batch):
                    # Pauses the scan while three batches are already in flight.
                    await batch_sem.acquire()
                    tg.create_task(do_batch(batch))

                batch = []
                for dirpath, _, filenames in os.walk(root):
                    paths = set()
                    done = set()
                    for name in filenames:
                        path = os.path.join(dirpath, name)
                        file = os.path.relpath(path, root)
                        try:
                            mtime = os.stat(path).st_mtime
                        except OSError:
                            # e.g. a dangling symlink: report it instead of aborting the scan.
                            failed.add(path)
                            continue
                        seen_files.add(file)
                        if mtime != stored.get(file):
                            paths.add(path)
                    loads = [loop.run_in_executor(executor, load_image, path, image_size) for path in paths]
                    for task in asyncio.as_completed(loads):
                        try:
                            b, path = await task
                            mtime = os.stat(path).st_mtime
                        except Exception:
                            # The failing path is unknown here; it is caught by ``paths - done``.
                            continue
                        done.add(path)
                        batch.append((os.path.relpath(path, root), mtime, b))
                        if len(batch) == batch_size:
                            await dispatch_batch(batch)
                            batch = []
                    failed |= paths - done
                if batch:
                    await dispatch_batch(batch)
        print()
        for path in failed:
            print("Failed to load", path)
        return stored, seen_files, modified

    async def _sync_index(self, conn, stored, seen_files, modified):
        """Make ``faiss_index``/``associated_filenames`` match the ``files`` table.

        Deletes ``files`` rows for files no longer on disk, drops index entries
        for vanished or re-embedded files, and appends every ``files`` row not
        already indexed (including the fresh vectors of re-embedded files).
        All database awaits happen before the index is touched, so the index
        and the filename list change together without yielding to concurrent
        ``search`` calls.
        """
        vanished = [filename for filename in stored if filename not in seen_files]
        if vanished:
            await conn.executemany("DELETE FROM files WHERE filename = ?", [(f,) for f in vanished])
            await conn.commit()

        keep = [f in seen_files and f not in modified for f in self.associated_filenames]
        kept = {f for f, k in zip(self.associated_filenames, keep) if k}
        new_vectors = []
        new_filenames = []
        async with conn.execute("SELECT filename, embedding_vector FROM files") as cursor:
            while row := await cursor.fetchone():
                filename, embedding_vector = row
                if filename not in kept:
                    new_vectors.append(numpy.frombuffer(embedding_vector, dtype="float16"))
                    new_filenames.append(filename)

        # No awaits below: index and filename list are updated atomically w.r.t. the loop.
        remove_indices = [i for i, k in enumerate(keep) if not k]
        print("Deleting", len(remove_indices), "old entries")
        if remove_indices:
            # IndexFlat.remove_ids compacts the remaining vectors in order,
            # matching the filtered filename list.
            self.faiss_index.remove_ids(numpy.array(remove_indices, dtype="int64"))
            self.associated_filenames = [f for f, k in zip(self.associated_filenames, keep) if k]
        if new_vectors:
            self.faiss_index.add(numpy.array(new_vectors, dtype="float32"))
            self.associated_filenames.extend(new_filenames)
2023-09-28 16:30:20 +00:00
# Register every @routes handler, then open all of them to cross-origin
# requests from any origin (without credentials).
app.router.add_routes(routes)
cors = aiohttp_cors.setup(
    app,
    defaults={
        "*": aiohttp_cors.ResourceOptions(
            allow_credentials=False,
            expose_headers="*",
            allow_headers="*",
        ),
    },
)
# Iterate over a snapshot so that any routes cors.add registers (presumably
# OPTIONS preflight handlers) are not visited by this loop.
for route in [*app.router.routes()]:
    cors.add(route)
async def main():
    """Connect to the CLIP backend, build the index, then start the HTTP server.

    Retries once per second until the backend's config endpoint answers with a
    success status. If CONFIG["no_run_server"] is truthy, closes the HTTP
    session and returns after indexing instead of starting the server.
    """
    sess = aiohttp.ClientSession()
    # Tolerate the configured base URL with or without a trailing slash.
    config_url = CONFIG["clip_server"].rstrip("/") + "/config"
    while True:
        try:
            async with sess.get(config_url) as res:
                res.raise_for_status()
                inference_server_config = umsgpack.unpackb(await res.read())
            print("Backend config:", inference_server_config)
            break
        except Exception:
            # Backend not reachable yet (or bad reply): report and retry.
            # Catching Exception, not everything, lets cancellation through.
            traceback.print_exc()
            await asyncio.sleep(1)
    index = Index(inference_server_config, sess)
    app["index"] = index
    app["session"] = sess
    await index.reload()
    print("Ready")
    if CONFIG.get("no_run_server", False):
        # One-shot indexing run: nothing else uses the session, so close it
        # rather than leaking it.
        await sess.close()
        return
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "", CONFIG["port"])
    await site.start()
2023-10-08 21:52:17 +00:00
if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(main())
    # main() skips starting the server exactly when no_run_server is truthy;
    # use the same truthiness test here (not `== False`) so that values such
    # as null keep the process serving whenever main() started the server.
    if not CONFIG.get("no_run_server", False):
        loop.run_forever()