1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-11-13 23:34:49 +00:00
meme-search-engine/mse.py
osmarks 7cb42e028f Rewrite entire application (well, backend) in Rust and also Go
I decided I wanted to integrate the experimental OCR thing better, so I rewrote in Go and also integrated the thumbnailer.
However, Go is a bad language and I only used it out of spite.
It turned out to have a very hard-to-fix memory leak due to some unclear interaction between libvips and both sets of bindings I tried, so I had Claude-3 transpile it to Rust then spent a while fixing the several mistakes it made and making tweaks.
The new Rust version works, although I need to actually do something with the OCR data and make the index queryable concurrently.
2024-05-21 00:09:04 +01:00

295 lines
12 KiB
Python

from aiohttp import web
import aiohttp
import asyncio
import traceback
import umsgpack
from PIL import Image
import base64
import aiosqlite
import faiss
import numpy
import os
import aiohttp_cors
import json
import io
import time
import sys
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import threading
# The path to the JSON config file is the first command-line argument.
# Read it as UTF-8 explicitly rather than in the platform's locale encoding,
# so non-ASCII paths in the config decode the same everywhere.
with open(sys.argv[1], "r", encoding="utf-8") as config_file:
    CONFIG = json.load(config_file)

# Request bodies may be up to 32 MiB so base64-encoded query images fit.
app = web.Application(client_max_size=32*1024**2)
routes = web.RouteTableDef()
async def clip_server(sess: aiohttp.ClientSession, query, unpack_buffer=True):
    """POST a msgpack-encoded ``query`` to CONFIG["clip_server"] and decode the reply.

    With ``unpack_buffer`` (the default) each item of the decoded reply is
    turned into a float16 numpy vector; otherwise the decoded msgpack items
    are returned unchanged.

    Raises ``Exception`` carrying the server's error body on a non-200 reply.
    """
    async with sess.post(CONFIG["clip_server"], data=umsgpack.dumps(query)) as res:
        body = await res.read()
        if res.status != 200:
            # Only msgpack error bodies are decoded as msgpack. Decoding anything
            # else (e.g. a plain-text or HTML error page) would make umsgpack
            # raise and hide the server's actual error message.
            if res.headers.get("content-type") == "application/msgpack":
                raise Exception(umsgpack.loads(body))
            raise Exception(await res.text())
        response = umsgpack.loads(body)
    if unpack_buffer:
        response = [numpy.frombuffer(x, dtype="float16") for x in response]
    return response
@routes.post("/")
async def run_query(request):
    """Search the index with a weighted sum of query embeddings.

    Request JSON keys (all optional):
      images:     list of [base64-encoded image, weight]
      text:       list of [text, weight]
      embeddings: list of raw vectors, added to the sum without weighting
      top_k:      result limit passed to Index.search (default 4000)

    Responds with the JSON list from Index.search, or [] if nothing to search with.
    """
    sess = app["session"]
    index = app["index"]
    data = await request.json()
    images = data.get("images", [])
    text = data.get("text", [])
    embeddings = []
    if images:
        target_image_size = index.inference_server_config["image_size"]
        loop = asyncio.get_running_loop()
        # Decoding/resizing images is CPU-bound PIL work; run it in the default
        # executor so the event loop keeps serving other requests meanwhile.
        # gather() preserves input order, which the weights below rely on.
        encoded = await asyncio.gather(*(
            loop.run_in_executor(None, load_image, io.BytesIO(base64.b64decode(x)), target_image_size)
            for x, _ in images
        ))
        embeddings.extend(await clip_server(sess, {"images": [bmp for bmp, _ in encoded]}))
    if text:
        embeddings.extend(await clip_server(sess, {"text": [x for x, _ in text]}))
    # Embeddings were produced images-first, then text, so weights follow that order.
    weights = [w for _, w in images] + [w for _, w in text]
    weighted_embeddings = [e * w for e, w in zip(embeddings, weights)]
    weighted_embeddings.extend(numpy.array(x) for x in data.get("embeddings", []))
    if not weighted_embeddings:
        return web.json_response([])
    return web.json_response(index.search(sum(weighted_embeddings), top_k=data.get("top_k", 4000)))
@routes.get("/")
async def health_check(request):
    """Liveness probe: replies to GET / with the plain-text body "OK"."""
    body = "OK"
    return web.Response(text=body)
@routes.post("/reload_index")
async def reload_index_route(request):
    """Run Index.reload() to completion, then reply with JSON ``true``."""
    index = request.app["index"]
    await index.reload()
    return web.json_response(True)
def load_image(path, image_size):
    """Decode an image, resize it, convert it to RGB, and return it BMP-encoded.

    ``path`` is anything ``PIL.Image.open`` accepts (a filesystem path or a
    binary file object). ``image_size`` is passed straight to PIL's ``draft``
    and ``resize``, i.e. a (width, height) pair.

    Returns ``(bmp_bytes, path)``. Echoing ``path`` back lets callers that use
    ``asyncio.as_completed`` tell which input a result belongs to.
    """
    # The context manager closes the underlying file when the image was opened
    # from a path. Otherwise, formats PIL keeps open for seeking (e.g.
    # multi-frame GIFs) leak a file descriptor per call until GC.
    with Image.open(path) as im:
        # draft() lets decoders such as JPEG downscale during decoding (cheaper).
        im.draft("RGB", image_size)
        buf = io.BytesIO()
        im.resize(image_size).convert("RGB").save(buf, format="BMP")
    return buf.getvalue(), path
class Index:
    """FAISS inner-product index over image embeddings persisted in SQLite.

    ``associated_filenames[i]`` is the path (relative to ``CONFIG["files"]``)
    of the vector at position ``i`` in ``faiss_index``. Every mutation updates
    both without an ``await`` in between. Because ``search`` runs synchronously
    on the event loop, it never observes the two out of step.
    """

    _SCHEMA = """
    CREATE TABLE IF NOT EXISTS files (
        filename TEXT PRIMARY KEY,
        modtime REAL NOT NULL,
        embedding_vector BLOB NOT NULL
    );
    CREATE TABLE IF NOT EXISTS ocr (
        filename TEXT PRIMARY KEY REFERENCES files(filename),
        scan_time INTEGER NOT NULL,
        text TEXT NOT NULL,
        raw_segments TEXT
    );
    CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
        filename,
        text,
        tokenize='unicode61 remove_diacritics 2',
        content='ocr'
    );
    CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON ocr BEGIN
        INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, new.text);
    END;
    CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON ocr BEGIN
        INSERT INTO ocr_fts (ocr_fts, rowid, filename, text) VALUES ('delete', old.rowid, old.filename, old.text);
    END;
    """

    def __init__(self, inference_server_config, http_session):
        # inference_server_config is the CLIP backend's config; this class reads
        # its "embedding_size", "image_size" and "batch" keys.
        self.faiss_index = faiss.IndexFlatIP(inference_server_config["embedding_size"])
        self.associated_filenames = []
        self.inference_server_config = inference_server_config
        self.lock = asyncio.Lock()  # serialises reload()
        self.session = http_session

    def search(self, query, top_k):
        """Return up to ``top_k`` ``{"score": float, "file": str}`` dicts in FAISS result order.

        ``query`` is one embedding vector. It is cast to float32, the dtype FAISS
        computes in.
        """
        # FAISS rejects k <= 0, and asking for more than ntotal only yields -1 padding.
        top_k = min(int(top_k), self.faiss_index.ntotal)
        if top_k <= 0:
            return []
        distances, indices = self.faiss_index.search(numpy.array([query], dtype="float32"), top_k)
        results = []
        for index, distance in zip(indices[0], distances[0]):
            if index == -1:  # FAISS pads with -1 once it runs out of vectors
                break
            results.append({"score": float(distance), "file": self.associated_filenames[index]})
        return results

    def _drop_entries(self, should_drop):
        """Remove every indexed vector whose filename satisfies ``should_drop``; return how many."""
        positions = [i for i, filename in enumerate(self.associated_filenames) if should_drop(filename)]
        if positions:
            # remove_ids on a flat index shifts the remaining vectors down in
            # order, so the filename list is compacted the same way. FAISS ids are int64.
            self.faiss_index.remove_ids(numpy.array(positions, dtype="int64"))
            dropped = set(positions)
            self.associated_filenames = [
                filename for i, filename in enumerate(self.associated_filenames) if i not in dropped
            ]
        return len(positions)

    async def run_ocr(self):
        """OCR every file whose ``ocr`` row is missing or older than its mtime.

        Does nothing unless CONFIG["enable_ocr"] is set. Results are written to
        the ``ocr`` table. Per-image failures are reported and skipped.
        """
        if not CONFIG.get("enable_ocr"):
            return
        import ocr  # only needed (and imported) when OCR is enabled
        print("Running OCR")
        conn = await aiosqlite.connect(CONFIG["db_path"])
        try:
            unocred = await conn.execute_fetchall("SELECT files.filename FROM files LEFT JOIN ocr ON files.filename = ocr.filename WHERE ocr.scan_time IS NULL OR ocr.scan_time < files.modtime")
            ocr_sem = asyncio.Semaphore(20)  # Google has more concurrency than our internal CLIP backend. I am sure they will be fine.
            load_sem = threading.Semaphore(100)  # provide backpressure in loading to avoid using 50GB of RAM (this happened)

            async def run_image(filename, chunks):
                try:
                    text, regions = await ocr.scan_chunks(self.session, chunks)
                    await conn.execute("INSERT OR REPLACE INTO ocr VALUES (?, ?, ?, ?)", (filename, time.time(), text, json.dumps(regions)))
                    await conn.commit()
                    sys.stdout.write(".")
                    sys.stdout.flush()
                except Exception as e:
                    # Best effort: one bad image must not abort the run. Catching
                    # Exception (not a bare except) still lets cancellation through.
                    print("OCR failed on", filename, e)
                finally:
                    ocr_sem.release()

            def load_and_chunk_image(filename):
                # Runs in a worker thread. It blocks here while 100 loaded images
                # are waiting to be dispatched to OCR.
                load_sem.acquire()
                try:
                    im = Image.open(Path(CONFIG["files"]) / filename)
                    return filename, ocr.chunk_image(im), None
                except Exception as e:
                    # Report rather than raise. An exception here would abort the
                    # loop below before this permit is released, leaving workers
                    # blocked on load_sem and the executor shutdown waiting forever.
                    return filename, None, e

            async with asyncio.TaskGroup() as tg:
                with ThreadPoolExecutor(max_workers=CONFIG.get("n_workers", 1)) as executor:
                    loop = asyncio.get_running_loop()
                    futures = [loop.run_in_executor(executor, load_and_chunk_image, row[0]) for row in unocred]
                    for task in asyncio.as_completed(futures):
                        filename, chunks, error = await task
                        if error is not None:
                            load_sem.release()
                            print("OCR failed on", filename, error)
                            continue
                        await ocr_sem.acquire()
                        tg.create_task(run_image(filename, chunks))
                        load_sem.release()
        finally:
            # The TaskGroup has awaited every run_image task by this point.
            await conn.close()

    async def _embed_changed_files(self, conn, executor):
        """Embed every image under CONFIG["files"] whose mtime differs from its ``files`` row.

        Embeddings are written to the ``files`` table in batches of
        ``inference_server_config["batch"]``. Returns ``(seen_files, modified)``:
        every relative path found on disk, and the paths re-embedded by this call.
        """
        root = CONFIG["files"]
        image_size = self.inference_server_config["image_size"]
        batch_size = self.inference_server_config["batch"]
        loop = asyncio.get_running_loop()
        known_mtimes = {
            filename: modtime
            for filename, modtime in await conn.execute_fetchall("SELECT filename, modtime FROM files")
        }
        seen_files = set()
        modified = set()
        failed = set()
        batch_sem = asyncio.Semaphore(3)  # at most 3 embedding batches in flight

        async with asyncio.TaskGroup() as tg:
            async def do_batch(batch):
                try:
                    query = {"images": [image for _, _, image in batch]}
                    embeddings = await clip_server(self.session, query, False)
                    await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
                        (filename, modtime, embedding) for (filename, modtime, _), embedding in zip(batch, embeddings)
                    ])
                    await conn.commit()
                    modified.update(filename for filename, _, _ in batch)
                    sys.stdout.write(".")
                    sys.stdout.flush()
                finally:
                    batch_sem.release()

            async def dispatch_batch(batch):
                await batch_sem.acquire()
                tg.create_task(do_batch(batch))

            batch = []
            for dirpath, _, filenames in os.walk(root):
                paths = set()
                for name in filenames:
                    path = os.path.join(dirpath, name)
                    try:
                        mtime = os.stat(path).st_mtime
                    except OSError:
                        # e.g. a dangling symlink or a file removed mid-scan;
                        # previously this aborted the whole reload.
                        failed.add(path)
                        continue
                    file = os.path.relpath(path, root)
                    seen_files.add(file)
                    if mtime != known_mtimes.get(file):
                        paths.add(path)
                done = set()
                futures = [loop.run_in_executor(executor, load_image, path, image_size) for path in paths]
                for task in asyncio.as_completed(futures):
                    try:
                        image, path = await task
                        st = os.stat(path)
                    except Exception:
                        continue  # reported below via `paths - done`
                    done.add(path)
                    batch.append((os.path.relpath(path, root), st.st_mtime, image))
                    if len(batch) == batch_size:
                        await dispatch_batch(batch)
                        batch = []
                failed |= paths - done
            if batch:
                await dispatch_batch(batch)

        print()
        for path in failed:
            print("Failed to load", path)
        return seen_files, modified

    async def _sync_index(self, conn, seen_files, modified):
        """Make the in-memory index match the files on disk and the ``files`` table.

        Stale vectors (files gone from disk, and old embeddings of files re-embedded
        this run) are removed *before* new rows are added. Removing afterwards would
        also discard the freshly added embeddings of every re-embedded file.
        """
        removed = self._drop_entries(lambda filename: filename not in seen_files or filename in modified)
        print("Deleting", removed, "old entries")

        indexed = set(self.associated_filenames)
        new_vectors, new_filenames, vanished = [], [], []
        async with conn.execute("SELECT filename, embedding_vector FROM files") as csr:
            while row := await csr.fetchone():
                filename, embedding_vector = row
                if filename not in seen_files:
                    vanished.append((filename,))
                elif filename not in indexed:
                    new_vectors.append(numpy.frombuffer(embedding_vector, dtype="float16"))
                    new_filenames.append(filename)
        if vanished:
            await conn.executemany("DELETE FROM files WHERE filename = ?", vanished)
            await conn.commit()
        if new_vectors:
            # add + extend with no await in between keeps search() consistent.
            self.faiss_index.add(numpy.array(new_vectors, dtype="float32"))
            self.associated_filenames.extend(new_filenames)

    async def reload(self):
        """Rescan CONFIG["files"], embed new or changed images, sync the index, then run OCR.

        Concurrent calls are serialised by ``self.lock``.
        """
        async with self.lock:
            with ThreadPoolExecutor(max_workers=CONFIG.get("n_workers", 1)) as executor:
                print("Indexing")
                conn = await aiosqlite.connect(CONFIG["db_path"])
                conn.row_factory = aiosqlite.Row
                try:
                    await conn.executescript(self._SCHEMA)
                    seen_files, modified = await self._embed_changed_files(conn, executor)
                    await self._sync_index(conn, seen_files, modified)
                finally:
                    await conn.close()
            # Always run OCR, even when no new embeddings were added.
            await self.run_ocr()
# Register the decorated handlers, then allow cross-origin requests from any
# origin on every route: no credentials, and all headers allowed and exposed.
app.router.add_routes(routes)
_cors_defaults = {
    "*": aiohttp_cors.ResourceOptions(
        allow_credentials=False,
        expose_headers="*",
        allow_headers="*",
    ),
}
cors = aiohttp_cors.setup(app, defaults=_cors_defaults)
# Iterate over a snapshot of the routes, presumably because cors.add() can
# register extra (preflight) routes while we loop.
for route in [*app.router.routes()]:
    cors.add(route)
async def main():
    """Wait for the CLIP backend, load the index, then start the HTTP server.

    Returns once the site is listening (the caller keeps the loop running), or
    right after indexing when CONFIG["no_run_server"] is truthy.
    """
    sess = aiohttp.ClientSession()
    while True:
        try:
            async with sess.get(CONFIG["clip_server"] + "config") as res:
                inference_server_config = umsgpack.unpackb(await res.read())
            print("Backend config:", inference_server_config)
            break
        except Exception:
            # Backend not reachable yet (or sent something undecodable): retry.
            # Catching Exception rather than everything keeps Ctrl+C and
            # task cancellation working while we wait.
            traceback.print_exc()
            await asyncio.sleep(1)
    index = Index(inference_server_config, sess)
    app["index"] = index
    app["session"] = sess
    await index.reload()
    print("Ready")
    if CONFIG.get("no_run_server", False):
        # Index-only run: nothing else will use the session, so close it
        # instead of leaving it for aiohttp's "Unclosed client session" warning.
        await sess.close()
        return
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "", CONFIG["port"])
    await site.start()
if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(main())
    # main() returns once the site is listening, so keep the loop alive to serve.
    # Use the same truthiness test as main(). With `== False`, a value such as
    # null started the server and then exited immediately.
    if not CONFIG.get("no_run_server", False):
        loop.run_forever()