mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-13 23:34:49 +00:00
osmarks
7cb42e028f
I decided I wanted to integrate the experimental OCR thing better, so I rewrote it in Go and also integrated the thumbnailer. However, Go is a bad language and I only used it out of spite. It turned out to have a very hard-to-fix memory leak due to some unclear interaction between libvips and both sets of bindings I tried, so I had Claude-3 transpile it to Rust, then spent a while fixing the several mistakes it made and making tweaks. The new Rust version works, although I need to actually do something with the OCR data and make the index queryable concurrently.
295 lines
12 KiB
Python
295 lines
12 KiB
Python
from aiohttp import web
|
|
import aiohttp
|
|
import asyncio
|
|
import traceback
|
|
import umsgpack
|
|
from PIL import Image
|
|
import base64
|
|
import aiosqlite
|
|
import faiss
|
|
import numpy
|
|
import os
|
|
import aiohttp_cors
|
|
import json
|
|
import io
|
|
import time
|
|
import sys
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from pathlib import Path
|
|
import threading
|
|
|
|
# The JSON config file path is the first command-line argument. JSON is UTF-8
# by specification, so don't depend on the platform's locale encoding.
with open(sys.argv[1], "r", encoding="utf-8") as config_file:
    CONFIG = json.load(config_file)

# Accept request bodies up to 32 MiB: search queries may carry base64-encoded images.
app = web.Application(client_max_size=32*1024**2)
routes = web.RouteTableDef()
|
|
|
|
async def clip_server(sess: aiohttp.ClientSession, query, unpack_buffer=True):
    """POST a msgpack-encoded ``query`` to CONFIG["clip_server"] and decode the reply.

    Returns the msgpack-decoded response. When ``unpack_buffer`` is true each
    element (a raw byte buffer) is turned into a float16 numpy array.

    Raises Exception on a non-200 status, carrying the server's error payload:
    the decoded msgpack object if the server labels it application/msgpack,
    otherwise the response text.
    """
    async with sess.post(CONFIG["clip_server"], data=umsgpack.dumps(query)) as res:
        body = await res.read()
        if res.status != 200:
            # Check the status before decoding: a plain-text/HTML error body would
            # otherwise surface as a confusing msgpack unpack error.
            if res.content_type == "application/msgpack":
                raise Exception(umsgpack.loads(body))
            # The body is already buffered, so text() just decodes it.
            raise Exception(await res.text())
        response = umsgpack.loads(body)
        if unpack_buffer:
            response = [ numpy.frombuffer(x, dtype="float16") for x in response ]
        return response
|
|
|
|
@routes.post("/")
async def run_query(request):
    """Search the index with a weighted sum of query embeddings.

    Request JSON (every key optional):
      images:     list of [base64-encoded image, weight] pairs
      text:       list of [string, weight] pairs
      embeddings: list of raw vectors, added with weight 1
      top_k:      maximum number of results (default 4000)

    Responds with the JSON list returned by Index.search, or [] when the request
    contains nothing to search with.
    """
    index = request.app["index"]
    sess = request.app["session"]
    data = await request.json()

    # `or []` also covers explicit JSON nulls, which `.get(key, [])` would pass through.
    images = data.get("images") or []
    text = data.get("text") or []
    raw_embeddings = data.get("embeddings") or []

    embeddings = []
    if images:
        target_image_size = index.inference_server_config["image_size"]
        embeddings.extend(await clip_server(sess, { "images": [ load_image(io.BytesIO(base64.b64decode(x)), target_image_size)[0] for x, w in images ] }))
    if text:
        embeddings.extend(await clip_server(sess, { "text": [ x for x, w in text ] }))

    # Embeddings were appended images-first, then text, so weights follow the same order.
    weights = [ w for x, w in images ] + [ w for x, w in text ]
    # clip_server yields float16 vectors; weight and sum them in float32 to avoid
    # float16 precision loss (and overflow for large weights).
    weighted_embeddings = [ e.astype(numpy.float32) * w for e, w in zip(embeddings, weights) ]
    weighted_embeddings.extend([ numpy.array(x) for x in raw_embeddings ])

    if not weighted_embeddings:
        return web.json_response([])

    top_k = data.get("top_k")
    if top_k is None:
        top_k = 4000
    return web.json_response(index.search(sum(weighted_embeddings), top_k=top_k))
|
|
|
|
@routes.get("/")
async def health_check(request):
    """Liveness probe for GET /: always replies with the text body "OK"."""
    reply = web.Response(text="OK")
    return reply
|
|
|
|
@routes.post("/reload_index")
async def reload_index_route(request):
    """POST /reload_index: rescan and reindex the files, then reply with JSON `true`.

    The reply is only sent once the reload has completed.
    """
    index = request.app["index"]
    await index.reload()
    return web.json_response(True)
|
|
|
|
def load_image(path, image_size):
    """Decode an image, resize it to ``image_size`` and re-encode it as RGB BMP.

    path: filesystem path or binary file-like object accepted by PIL.Image.open.
    image_size: target (width, height) passed to Image.resize.

    Returns (bmp_bytes, path); ``path`` is echoed back so callers consuming
    results out of order (e.g. via asyncio.as_completed) know which input it was.
    """
    # Close the image when done: for multi-frame formats PIL keeps a path-opened
    # file handle open even after the pixel data has been loaded.
    with Image.open(path) as im:
        # Let decoders that support it (e.g. JPEG) decode at reduced size directly.
        im.draft("RGB", image_size)
        buf = io.BytesIO()
        im.resize(image_size).convert("RGB").save(buf, format="BMP")
    return buf.getvalue(), path
|
|
|
|
class Index:
    """Searchable FAISS inner-product index of CLIP image embeddings kept in SQLite.

    Position ``i`` of ``faiss_index`` holds the embedding of
    ``associated_filenames[i]``. Both are only changed together with no ``await``
    in between, so a search running on the event loop never sees them out of step.
    """

    def __init__(self, inference_server_config, http_session):
        """Create an empty index.

        inference_server_config: config dict from the inference server; this class
            reads its "embedding_size", "image_size" and "batch" keys.
        http_session: aiohttp session used for CLIP and OCR requests.
        """
        self.faiss_index = faiss.IndexFlatIP(inference_server_config["embedding_size"])
        self.associated_filenames = []
        self.inference_server_config = inference_server_config
        # Serialises reload(); search() does not take it.
        self.lock = asyncio.Lock()
        self.session = http_session

    def search(self, query, top_k):
        """Return up to ``top_k`` files ranked by inner product with ``query``.

        Each result is ``{"score": float, "file": str}``, best first. ``top_k`` may
        come straight from a client, so it is clamped to the number of indexed
        vectors; a non-positive value or an empty index yields ``[]``.
        """
        top_k = min(int(top_k), self.faiss_index.ntotal)
        if top_k <= 0:
            return []
        distances, indices = self.faiss_index.search(numpy.array([query]), top_k)
        results = []
        for index, distance in zip(indices[0], distances[0]):
            if index == -1:  # FAISS pads missing results with -1
                break
            results.append({ "score": float(distance), "file": self.associated_filenames[index] })
        return results

    async def run_ocr(self):
        """OCR every stored file whose OCR result is missing or older than the file.

        No-op unless CONFIG["enable_ocr"] is set. Results are written to the ``ocr``
        table (mirrored into ``ocr_fts`` by triggers). A file that fails to load or
        scan is reported and skipped instead of aborting the whole pass.
        """
        if not CONFIG.get("enable_ocr"): return

        import ocr

        print("Running OCR")

        conn = await aiosqlite.connect(CONFIG["db_path"])
        try:
            unocred = await conn.execute_fetchall("SELECT files.filename FROM files LEFT JOIN ocr ON files.filename = ocr.filename WHERE ocr.scan_time IS NULL OR ocr.scan_time < files.modtime")

            ocr_sem = asyncio.Semaphore(20) # Google has more concurrency than our internal CLIP backend. I am sure they will be fine.
            load_sem = threading.Semaphore(100) # provide backpressure in loading to avoid using 50GB of RAM (this happened)

            async def run_image(filename, chunks):
                try:
                    text, regions = await ocr.scan_chunks(self.session, chunks)
                    # Delete-then-insert rather than INSERT OR REPLACE: the implicit delete
                    # done by REPLACE only fires ocr_fts_del when recursive_triggers is on,
                    # so REPLACE would leave stale rows in the external-content FTS table.
                    await conn.execute("DELETE FROM ocr WHERE filename = ?", (filename,))
                    await conn.execute("INSERT INTO ocr VALUES (?, ?, ?, ?)", (filename, time.time(), text, json.dumps(regions)))
                    await conn.commit()
                    sys.stdout.write(".")
                    sys.stdout.flush()
                except Exception as e:
                    print("OCR failed on", filename, e)
                finally:
                    ocr_sem.release()

            def load_and_chunk_image(filename):
                # Runs in a worker thread and blocks while 100 loaded images await
                # dispatch; the event loop returns the permit once it has handled the
                # result. Errors are returned, not raised, so one bad file can't abort
                # the loop below (which would also leave permits unreleased and could
                # hang the executor shutdown).
                load_sem.acquire()
                try:
                    im = Image.open(Path(CONFIG["files"]) / filename)
                    return filename, ocr.chunk_image(im), None
                except Exception as e:
                    return filename, None, e

            async with asyncio.TaskGroup() as tg:
                with ThreadPoolExecutor(max_workers=CONFIG.get("n_workers", 1)) as executor:
                    loop = asyncio.get_running_loop()
                    for task in asyncio.as_completed([ loop.run_in_executor(executor, load_and_chunk_image, row[0]) for row in unocred ]):
                        filename, chunks, error = await task
                        if error is not None:
                            load_sem.release()
                            print("OCR failed on", filename, error)
                            continue
                        await ocr_sem.acquire()
                        tg.create_task(run_image(filename, chunks))
                        load_sem.release()
        finally:
            await conn.close()

    async def reload(self):
        """Rescan CONFIG["files"], embed new/changed images and sync the FAISS index.

        Creates the SQLite schema if needed, sends every image whose mtime differs
        from the stored one to the CLIP server in batches, forgets DB rows for files
        no longer on disk, then updates the in-memory index so it holds one vector
        per stored file. Finally runs the OCR pass. Calls are serialised by
        ``self.lock``.
        """
        async with self.lock:
            with ThreadPoolExecutor(max_workers=CONFIG.get("n_workers", 1)) as executor:
                print("Indexing")
                conn = await aiosqlite.connect(CONFIG["db_path"])
                conn.row_factory = aiosqlite.Row
                try:
                    await conn.executescript("""
                    CREATE TABLE IF NOT EXISTS files (
                        filename TEXT PRIMARY KEY,
                        modtime REAL NOT NULL,
                        embedding_vector BLOB NOT NULL
                    );
                    CREATE TABLE IF NOT EXISTS ocr (
                        filename TEXT PRIMARY KEY REFERENCES files(filename),
                        scan_time INTEGER NOT NULL,
                        text TEXT NOT NULL,
                        raw_segments TEXT
                    );

                    CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
                        filename,
                        text,
                        tokenize='unicode61 remove_diacritics 2',
                        content='ocr'
                    );

                    CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON ocr BEGIN
                        INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, new.text);
                    END;

                    CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON ocr BEGIN
                        INSERT INTO ocr_fts (ocr_fts, rowid, filename, text) VALUES ('delete', old.rowid, old.filename, old.text);
                    END;
                    """)

                    stored_mtimes = {}
                    for filename, modtime in await conn.execute_fetchall("SELECT filename, modtime FROM files"):
                        stored_mtimes[filename] = modtime

                    image_size = self.inference_server_config["image_size"]
                    batch_size = self.inference_server_config["batch"]
                    loop = asyncio.get_running_loop()
                    modified = set()    # relative paths re-embedded during this reload
                    seen_files = set()  # relative paths of every file found on disk
                    failed = set()      # absolute paths that could not be stat'ed or loaded

                    async with asyncio.TaskGroup() as tg:
                        # At most 3 CLIP batch requests in flight at once.
                        batch_sem = asyncio.Semaphore(3)

                        async def do_batch(batch):
                            try:
                                query = { "images": [ image for _, _, image in batch ] }
                                embeddings = await clip_server(self.session, query, False)
                                await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
                                    (filename, modtime, embedding) for (filename, modtime, _), embedding in zip(batch, embeddings)
                                ])
                                await conn.commit()
                                modified.update(filename for filename, _, _ in batch)
                                sys.stdout.write(".")
                                sys.stdout.flush()
                            finally:
                                batch_sem.release()

                        async def dispatch_batch(batch):
                            await batch_sem.acquire()
                            tg.create_task(do_batch(batch))

                        batch = []
                        for dirpath, _, filenames in os.walk(CONFIG["files"]):
                            paths = set()
                            done = set()
                            for name in filenames:
                                path = os.path.join(dirpath, name)
                                file = os.path.relpath(path, CONFIG["files"])
                                try:
                                    st = os.stat(path)
                                except OSError:
                                    # e.g. a dangling symlink: report it instead of
                                    # aborting the whole reload.
                                    failed.add(path)
                                    continue
                                seen_files.add(file)
                                if st.st_mtime != stored_mtimes.get(file):
                                    paths.add(path)

                            # Decode/resize changed images in worker threads and batch
                            # them for CLIP in completion order.
                            for task in asyncio.as_completed([ loop.run_in_executor(executor, load_image, path, image_size) for path in paths ]):
                                try:
                                    b, path = await task
                                    st = os.stat(path)
                                except Exception:
                                    # The failing path isn't known here; it is reported
                                    # via `paths - done` below.
                                    continue
                                done.add(path)
                                batch.append((os.path.relpath(path, CONFIG["files"]), st.st_mtime, b))
                                if len(batch) == batch_size:
                                    await dispatch_batch(batch)
                                    batch = []
                            failed |= paths - done
                        if batch:
                            await dispatch_batch(batch)

                    print()
                    for failed_ in failed:
                        print("Failed to load", failed_)

                    # Forget stored files that are no longer on disk.
                    await conn.executemany("DELETE FROM files WHERE filename = ?", [ (filename,) for filename in stored_mtimes if filename not in seen_files ])
                    await conn.commit()

                    # Index entries to drop: vanished files, plus files re-embedded above
                    # (their fresh vectors are re-read from the DB below). This must be
                    # decided before re-adding, otherwise the fresh vectors would be
                    # dropped too.
                    remove_indices = []
                    kept_filenames = []
                    for position, filename in enumerate(self.associated_filenames):
                        if filename not in seen_files or filename in modified:
                            remove_indices.append(position)
                        else:
                            kept_filenames.append(filename)

                    # Every stored embedding not already represented in the index.
                    kept_set = set(kept_filenames)
                    new_data = []
                    new_filenames = []
                    async with conn.execute("SELECT filename, embedding_vector FROM files") as csr:
                        while row := await csr.fetchone():
                            filename, embedding_vector = row
                            if filename not in kept_set:
                                new_data.append(numpy.frombuffer(embedding_vector, dtype="float16"))
                                new_filenames.append(filename)

                    # Apply removals and additions with no await in between so concurrent
                    # searches always see faiss_index and associated_filenames in step.
                    print("Deleting", len(remove_indices), "old entries")
                    if remove_indices:
                        # kept_filenames preserves order, relying on IndexFlat.remove_ids
                        # keeping the remaining vectors in their original order.
                        self.faiss_index.remove_ids(numpy.array(remove_indices, dtype=numpy.int64))
                    if new_data:
                        self.faiss_index.add(numpy.array(new_data))
                    self.associated_filenames = kept_filenames + new_filenames
                finally:
                    await conn.close()

            await self.run_ocr()
|
|
|
|
app.router.add_routes(routes)

# Permissive CORS: any origin may call any route, send any request header and
# read any response header; credentials (cookies etc.) are not allowed.
cors = aiohttp_cors.setup(
    app,
    defaults={
        "*": aiohttp_cors.ResourceOptions(
            allow_credentials=False,
            expose_headers="*",
            allow_headers="*",
        ),
    },
)
# Take a snapshot of the routes first, in case registering CORS touches the router.
for route in [*app.router.routes()]:
    cors.add(route)
|
|
|
|
async def main():
    """Wait for the CLIP backend, build the index, then start the HTTP server.

    Polls ``CONFIG["clip_server"] + "config"`` once a second until the backend
    returns its msgpack config, creates and loads the Index, and - unless
    CONFIG["no_run_server"] is truthy - starts serving on CONFIG["port"] and
    returns, leaving the caller to keep the event loop running.
    """
    sess = aiohttp.ClientSession()
    while True:
        try:
            async with await sess.get(CONFIG["clip_server"] + "config") as res:
                inference_server_config = umsgpack.unpackb(await res.read())
                print("Backend config:", inference_server_config)
                break
        except Exception:
            # Backend probably not up yet: log and retry. Catch Exception, not
            # everything, so cancellation can still stop startup.
            traceback.print_exc()
            await asyncio.sleep(1)
    index = Index(inference_server_config, sess)
    app["index"] = index
    app["session"] = sess
    await index.reload()
    print("Ready")
    if CONFIG.get("no_run_server", False):
        # Index-only run: nothing else will use the session, so close it cleanly.
        await sess.close()
        return
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "", CONFIG["port"])
    await site.start()
|
|
|
|
if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(main())
    # Use the same truthiness test as main(): with `== False`, a falsy value such as
    # null or "" made main() start the server and then the process exit at once.
    if not CONFIG.get("no_run_server", False):
        loop.run_forever()
|