mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2026-06-02 10:52:18 +00:00
"release" unfinished scripts and miscellaneous JSON files
This commit is contained in:
File diff suppressed because it is too large
Load Diff
Generated
+2795
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,29 @@
|
||||
[package]
|
||||
name = "meme-search-engine"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "0.5", features = ["multipart"] }
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
anyhow = "1"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version="0.3", features = ["env-filter"] }
|
||||
tower-http = { version = "0.2.0", features = ["fs", "trace", "add-extension"] }
|
||||
rusty_ulid = "1"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
rmp-serde = "1"
|
||||
futures-util = "0.3"
|
||||
regex = "1.5"
|
||||
lazy_static = "1"
|
||||
config = { version = "0.13", default-features = false, features = ["toml"] }
|
||||
faiss = { version = "0.12", features = [] }
|
||||
reqwest = "0.11"
|
||||
walkdir = "2"
|
||||
rusqlite = { version = "0.30.0", features = ["bundled"] }
|
||||
futures = "0.3"
|
||||
image = { version = "0.24", features = ["avif", "webp", "default"] }
|
||||
rayon = "1.8"
|
||||
@@ -0,0 +1,189 @@
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
from aiohttp import web
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import traceback
|
||||
import umsgpack
|
||||
import collections
|
||||
import queue
|
||||
from PIL import Image
|
||||
from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
import torch
|
||||
from transformers import SiglipTokenizer, SiglipImageProcessor, T5TokenizerFast, SiglipTextConfig, SiglipVisionConfig
|
||||
import numpy
|
||||
|
||||
# Load the JSON configuration from the path given as the first command-line argument.
with open(sys.argv[1], "r") as config_file:
    CONFIG = json.load(config_file)

# blatantly copypasted from colab
# https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/SigLIP_demo.ipynb
# CONFIG["model"] is unpacked as a (variant, resolution) pair; the pair selects the
# checkpoint filename, text variant, embedding size, sequence length and vocabulary
# size from the single-entry table below (any other pair raises KeyError).
VARIANT, RES = CONFIG["model"]
CKPT, TXTVARIANT, EMBDIM, SEQLEN, VOCAB = {
    ("So400m/14", 384): ("webli_en_so400m_384_58765454-fp16.safetensors", "So400m", 1152, 64, 32_000),
}[VARIANT, RES]

# NOTE(review): ml_collections, model_mod and pp_builder are referenced below but are
# not among the imports visible in this file — presumably the big_vision modules used
# by the linked notebook; confirm and import them before running.
model_cfg = ml_collections.ConfigDict()
model_cfg.image_model = "vit" # TODO(lbeyer): remove later, default
model_cfg.text_model = "proj.image_text.text_transformer" # TODO(lbeyer): remove later, default
model_cfg.image = dict(variant=VARIANT, pool_type="map")
model_cfg.text = dict(variant=TXTVARIANT, vocab_size=VOCAB)
model_cfg.out_dim = (None, EMBDIM) # (image_out_dim, text_out_dim)
model_cfg.bias_init = -10.0
model_cfg.temperature_init = 10.0

model = model_mod.Model(**model_cfg)

init_params = None # sanity checks are a low-interest-rate phenomenon
# CKPT is passed as-is, so it is resolved relative to the working directory.
model_params = model_mod.load(init_params, f"{CKPT}", model_cfg) # assume path

# Image preprocessing: resize to RES and rescale values to the range -1..1.
pp_img = pp_builder.get_preprocess_fn(f"resize({RES})|value_range(-1, 1)")
# Maps vocabulary size to the tokenizer model name passed to tokenize().
TOKENIZERS = {
    32_000: "c4_en",
    250_000: "mc4",
}
# Text preprocessing: tokenize the "text" key to at most SEQLEN tokens.
pp_txt = pp_builder.get_preprocess_fn(f'tokenize(max_len={SEQLEN}, model="{TOKENIZERS[VOCAB]}", eos="sticky", pad_value=1, inkey="text")')
print("Model loaded")
|
||||
|
||||
# Maximum number of items accepted in a single request (checked in preprocessing_thread).
BS = CONFIG["max_batch_size"]
# Model name used in the "model" label of the Prometheus metrics below.
MODELNAME = CONFIG["model_name"]

# One unit of work passed between the request handler, the preprocessing thread and
# the inference thread. `callback` is invoked as callback(success, payload).
InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "images", "callback"])

# Prometheus metrics exported by the /metrics route.
items_ctr = Counter("modelserver_total_items", "Items run through model server", ["model", "modality"])
inference_time_hist = Histogram("modelserver_inftime", "Time running inference", ["model", "batch_size"])
batch_count_ctr = Counter("modelserver_batchcount", "Inference batches run", ["model"])
|
||||
|
||||
@jax.jit
def run_text_model(text_batch):
    """Apply the model to a text batch only and return the text features.

    The image input is passed as None; of the three values returned by
    model.apply, only the second one is kept.
    """
    variables = {"params": model_params}
    _, text_features, _ = model.apply(variables, None, text_batch)
    return text_features
|
||||
|
||||
@jax.jit
def run_image_model(image_batch):
    """Apply the model to an image batch only and return the image features.

    The text input is passed as None; of the three values returned by
    model.apply, only the first one is kept.
    """
    variables = {"params": model_params}
    image_features, _, _ = model.apply(variables, image_batch, None)
    return image_features
|
||||
|
||||
def round_down_to_power_of_two(x):
    """Return the largest power of two that does not exceed the positive integer x."""
    highest_set_bit = x.bit_length() - 1
    return 1 << highest_set_bit
|
||||
|
||||
def minimize_jits(fn, batch):
    """Run `fn` over `batch` in power-of-two sized chunks and collect the results.

    The batch is consumed from the front: each step takes the largest power-of-two
    number of leading rows, calls `fn` on that slice and writes the result into the
    matching rows of the output. Chunk sizes are therefore always powers of two.

    Args:
        fn: callable applied to each chunk; its result is assigned to `s` rows of a
            float16 array with EMBDIM columns.
        batch: array-like with a `.shape` whose first dimension is the row count.

    Returns:
        A float16 numpy array of shape (batch.shape[0], EMBDIM). An empty batch
        yields an empty array without calling `fn`.
    """
    out = numpy.zeros((batch.shape[0], EMBDIM), dtype="float16")
    i = 0
    # Test for remaining rows *before* each step: with the check at the end of the
    # loop an empty batch reached round_down_to_power_of_two(0), which raises
    # "negative shift count".
    while batch.shape[0] > 0:
        s = round_down_to_power_of_two(batch.shape[0])
        out[i:(i + s), ...] = fn(batch[:s, ...])
        i += s
        batch = batch[s:, ...]
    return out
|
||||
|
||||
def do_inference(params: InferenceParameters):
    """Run one preprocessed batch through the model and report via its callback.

    Text takes priority over images when both are set. On success the callback is
    invoked as callback(True, features_as_numpy_array); on any exception the
    traceback is printed and the callback is invoked as callback(False, message).
    """
    try:
        text, images, callback = params
        if text is not None:
            items_ctr.labels(MODELNAME, "text").inc(text.shape[0])
            timer = inference_time_hist.labels(MODELNAME + "-text", text.shape[0]).time()
            with timer:
                features = run_text_model(text)
        elif images is not None:
            items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
            timer = inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time()
            with timer:
                features = run_image_model(images)
        batch_count_ctr.labels(MODELNAME).inc()
        result = numpy.asarray(features)
        callback(True, result)
    except Exception as exc:
        traceback.print_exc()
        callback(False, str(exc))
|
||||
|
||||
# Bounded hand-off queue (at most 100 pending batches) feeding the inference thread.
iq = queue.Queue(maxsize=100)

def infer_thread():
    """Worker loop: block on the inference queue and run each batch as it arrives."""
    while True:
        job = iq.get()
        do_inference(job)
|
||||
|
||||
# Bounded queue (at most 100 pending requests) of raw requests awaiting preprocessing.
pq = queue.Queue(100)

def preprocessing_thread():
    """Worker loop: preprocess raw requests and forward them to the inference queue.

    Each item taken from `pq` is a (text, images, callback) triple. Text takes
    priority when both are given. Text entries are tokenized one at a time with
    pp_txt; image entries are decoded from bytes with PIL, converted to RGB and
    passed through pp_img. The resulting array is put on `iq`.

    Any failure (oversized batch, missing input, undecodable image, ...) is printed
    and reported as callback(False, message); the loop then continues.
    """
    while True:
        text, images, callback = pq.get()
        try:
            # Validation raises explicitly rather than using `assert`, so it is still
            # enforced when Python runs with -O. The messages are unchanged.
            if text:
                if len(text) > BS:
                    raise ValueError(f"max batch size is {BS}")
                # I feel like this ought to be batchable but I can't see how to do that
                text = numpy.array([pp_txt({"text": item})["labels"] for item in text])
            elif images:
                if len(images) > BS:
                    raise ValueError(f"max batch size is {BS}")
                images = numpy.array([
                    pp_img({"image": numpy.array(Image.open(io.BytesIO(image)).convert("RGB"))})["image"]
                    for image in images
                ])
            else:
                raise ValueError("images or text required")
            iq.put(InferenceParameters(text, images, callback))
        except Exception as e:
            traceback.print_exc()
            callback(False, str(e))
|
||||
|
||||
# aiohttp application; client_max_size caps accepted request bodies at 2**26 bytes.
app = web.Application(client_max_size=2**26)
# Route table filled in by the @routes decorators below and registered on `app` later.
routes = web.RouteTableDef()
|
||||
|
||||
@routes.post("/")
async def run_inference(request):
    """Handle POST /: embed a msgpack-encoded batch of text or images.

    The request body is a msgpack map whose "text" and/or "images" entries are
    handed to the preprocessing queue. The handler then waits until a worker thread
    invokes the callback.

    Responses are msgpack-encoded:
        200: a list with one float16 byte string per input item.
        500: the error message produced by a worker.
        503: the preprocessing queue is full and the request was not accepted.
    """
    loop = asyncio.get_running_loop()
    data = umsgpack.loads(await request.read())
    event = asyncio.Event()
    results = None

    def callback(*argv):
        # Called from a worker thread: store the outcome, then wake the handler on
        # the event loop's own thread.
        nonlocal results
        results = argv
        loop.call_soon_threadsafe(event.set)

    try:
        pq.put_nowait(InferenceParameters(data.get("text"), data.get("images"), callback))
    except queue.Full:
        # Without this the exception escaped the handler and surfaced as a generic,
        # non-msgpack 500. Report overload explicitly so clients can back off.
        return web.Response(
            body=umsgpack.dumps("inference queue is full"),
            status=503,
            content_type="application/msgpack",
        )
    await event.wait()
    body_data = results[1]
    if results[0]:
        status = 200
        body_data = [x.astype("float16").tobytes() for x in body_data]
    else:
        status = 500
        print(results[1])
    return web.Response(body=umsgpack.dumps(body_data), status=status, content_type="application/msgpack")
|
||||
|
||||
@routes.get("/config")
async def config(request):
    """Handle GET /config: describe the loaded model as a msgpack map."""
    description = {
        "model": CONFIG["model"],
        "batch": BS,
        "image_size": (RES, RES),
        "embedding_size": EMBDIM
    }
    payload = umsgpack.dumps(description)
    return web.Response(body=payload, status=200, content_type="application/msgpack")
|
||||
|
||||
@routes.get("/")
async def health(request):
    """Handle GET /: liveness probe answering 204 with no body."""
    no_content = web.Response(status=204)
    return no_content
|
||||
|
||||
@routes.get("/metrics")
async def metrics(request):
    """Handle GET /metrics: serve the Prometheus registry in its text exposition form."""
    exposition = generate_latest(REGISTRY)
    return web.Response(body=exposition)
|
||||
|
||||
# Register every handler collected in `routes` on the application.
app.router.add_routes(routes)
|
||||
|
||||
async def run_webserver():
    """Start the aiohttp site on CONFIG["port"] (empty host string) and announce readiness."""
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "", CONFIG["port"])
    await site.start()
    # Announce readiness only once the socket is actually listening, so anything
    # waiting for this line can connect immediately.
    print("Ready")
|
||||
|
||||
try:
    # The workers loop forever, so they must be daemon threads: otherwise the
    # interpreter waits for them and sys.exit(0) below never ends the process.
    infer_worker = threading.Thread(target=infer_thread, daemon=True)
    infer_worker.start()
    preprocessing_worker = threading.Thread(target=preprocessing_thread, daemon=True)
    preprocessing_worker.start()
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(run_webserver())
    loop.run_forever()
except KeyboardInterrupt:
    # Exit quietly on Ctrl-C (sys is already imported at the top of the file).
    sys.exit(0)
|
||||
@@ -0,0 +1,4 @@
|
||||
log_level = "debug"
|
||||
listen_address = "[::1]:1710"
|
||||
db_path = "./mse.sqlite3"
|
||||
images_path = "/data/public"
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,212 @@
|
||||
from aiohttp import web
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import traceback
|
||||
import umsgpack
|
||||
from PIL import Image
|
||||
import base64
|
||||
import aiosqlite
|
||||
import faiss
|
||||
import numpy
|
||||
import os
|
||||
import aiohttp_cors
|
||||
import json
|
||||
import io
|
||||
import sys
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
# Load the JSON configuration from the path given as the first command-line argument.
with open(sys.argv[1], "r") as config_file:
    CONFIG = json.load(config_file)

# aiohttp application; client_max_size caps accepted request bodies at 32 MiB.
app = web.Application(client_max_size=32*1024**2)
# Route table filled in by the @routes decorators below.
routes = web.RouteTableDef()
|
||||
|
||||
async def clip_server(query, unpack_buffer=True):
    """POST a msgpack-encoded query to the inference server and return its reply.

    Args:
        query: msgpack-serializable object sent as the request body.
        unpack_buffer: when true, each element of the decoded reply is turned into
            a float16 numpy array; when false the raw decoded reply is returned.

    Returns:
        The decoded reply (a list of arrays when unpack_buffer is true).

    Raises:
        Exception: on any non-200 status, carrying the decoded msgpack body when the
            reply is application/msgpack and the reply text otherwise.
    """
    async with aiohttp.ClientSession() as sess:
        async with sess.post(CONFIG["clip_server"], data=umsgpack.dumps(query)) as res:
            # Check the status first. Decoding unconditionally made a non-msgpack
            # error page fail inside umsgpack.loads, so the text fallback below was
            # never reached.
            if res.status != 200:
                if res.headers.get("content-type") == "application/msgpack":
                    raise Exception(umsgpack.loads(await res.read()))
                raise Exception(await res.text())
            response = umsgpack.loads(await res.read())
            if unpack_buffer:
                response = [numpy.frombuffer(x, dtype="float16") for x in response]
            return response
|
||||
|
||||
@routes.post("/")
async def run_query(request):
    """Handle POST /: search the index with a weighted mix of image and text queries.

    The JSON body may contain "images" and "text", each a list of (value, weight)
    pairs; image values are base64-encoded. Every item is embedded by the inference
    server, scaled by its weight, and the sum of the scaled embeddings is used as
    the search vector. An empty query returns an empty JSON list.
    """
    data = await request.json()
    images = data.get("images", [])
    text = data.get("text", [])

    embeddings = []
    if images:
        decoded_images = [base64.b64decode(payload) for payload, _ in images]
        embeddings.extend(await clip_server({ "images": decoded_images }))
    if text:
        strings = [string for string, _ in text]
        embeddings.extend(await clip_server({ "text": strings }))

    # Weights are listed in the same order as the embeddings: images first, then text.
    weights = [weight for _, weight in images] + [weight for _, weight in text]
    embeddings = [embedding * weight for embedding, weight in zip(embeddings, weights)]
    if not embeddings:
        return web.json_response([])
    combined = sum(embeddings)
    return web.json_response(app["index"].search(combined))
|
||||
|
||||
@routes.get("/")
async def health_check(request):
    """Handle GET /: liveness probe answering with the text "OK"."""
    ok_response = web.Response(text="OK")
    return ok_response
|
||||
|
||||
@routes.post("/reload_index")
async def reload_index_route(request):
    """Handle POST /reload_index: rebuild the index, then answer JSON `true`."""
    index = request.app["index"]
    await index.reload()
    return web.json_response(True)
|
||||
|
||||
def load_image(path, image_size):
    """Load an image file, resize it and return it as BMP bytes.

    Args:
        path: path of the image file to open.
        image_size: target size passed to PIL's draft() and resize().

    Returns:
        A (bmp_bytes, path) tuple; the path is echoed back so callers that collect
        results out of order can tell which file each result belongs to.
    """
    # Open via a context manager so the underlying file handle is released
    # promptly rather than whenever the image object is garbage-collected.
    with Image.open(path) as im:
        im.draft("RGB", image_size)
        buf = io.BytesIO()
        im.resize(image_size).convert("RGB").save(buf, format="BMP")
    return buf.getvalue(), path
|
||||
|
||||
class Index:
    """In-memory inner-product FAISS index over image embeddings, backed by SQLite.

    `associated_filenames[i]` is the filename of the i-th vector in `faiss_index`.
    Embeddings are persisted in the `files` table of the database at
    CONFIG["db_path"] so that unchanged files are not re-embedded on reload.
    """

    def __init__(self, inference_server_config):
        """Create an empty index sized from the inference server's configuration.

        Args:
            inference_server_config: mapping providing "embedding_size",
                "image_size" and "batch".
        """
        self.faiss_index = faiss.IndexFlatIP(inference_server_config["embedding_size"])
        self.associated_filenames = []
        self.inference_server_config = inference_server_config
        # Serializes reload() calls.
        self.lock = asyncio.Lock()

    def search(self, query):
        """Return up to 4000 best matches for a single query vector.

        Returns:
            A list of {"score": float, "file": filename} dicts in the order FAISS
            returns them.
        """
        distances, indices = self.faiss_index.search(numpy.array([query]), 4000)
        distances = distances[0]
        indices = indices[0]
        # FAISS pads missing results with -1; cut the results at the first pad.
        try:
            indices = indices[:numpy.where(indices==-1)[0][0]]
        except IndexError:
            pass
        return [ { "score": float(distance), "file": self.associated_filenames[index] } for index, distance in zip(indices, distances) ]

    async def reload(self):
        """Scan CONFIG["files"], embed new or changed images and refresh the index.

        Files whose modification time differs from the stored one are loaded in a
        process pool, embedded in batches by the inference server and upserted into
        the database. Vectors for modified files are then dropped from the FAISS
        index, and every database row not currently indexed is added.
        """
        async with self.lock:
            with ProcessPoolExecutor(max_workers=12) as executor:
                print("Indexing")
                conn = await aiosqlite.connect(CONFIG["db_path"], parent_loop=asyncio.get_running_loop())
                conn.row_factory = aiosqlite.Row
                await conn.executescript("""
                CREATE TABLE IF NOT EXISTS files (
                filename TEXT PRIMARY KEY,
                modtime REAL NOT NULL,
                embedding_vector BLOB NOT NULL
                );
                """)
                try:
                    async with asyncio.TaskGroup() as tg:
                        # Limits the number of embedding batches in flight.
                        batch_sem = asyncio.Semaphore(32)

                        modified = set()

                        async def do_batch(batch):
                            try:
                                query = { "images": [ arg[2] for arg in batch ] }
                                embeddings = await clip_server(query, False)
                                await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
                                    (filename, modtime, embedding) for (filename, modtime, _), embedding in zip(batch, embeddings)
                                ])
                                await conn.commit()
                                for filename, _, _ in batch:
                                    modified.add(filename)
                                sys.stdout.write(".")
                            finally:
                                batch_sem.release()

                        async def dispatch_batch(batch):
                            await batch_sem.acquire()
                            tg.create_task(do_batch(batch))

                        # filename -> stored modtime for everything already in the database.
                        files = {}
                        for filename, modtime in await conn.execute_fetchall("SELECT filename, modtime FROM files"):
                            files[filename] = modtime
                        await conn.commit()
                        batch = []

                        for dirpath, _, filenames in os.walk(CONFIG["files"]):
                            paths = []
                            for file in filenames:
                                path = os.path.join(dirpath, file)
                                file = os.path.relpath(path, CONFIG["files"])
                                st = os.stat(path)
                                if st.st_mtime != files.get(file):
                                    paths.append(path)
                            for task in asyncio.as_completed([ asyncio.get_running_loop().run_in_executor(executor, load_image, path, self.inference_server_config["image_size"]) for path in paths ]):
                                try:
                                    b, path = await task
                                    st = os.stat(path)
                                    file = os.path.relpath(path, CONFIG["files"])
                                except Exception as e:
                                    # NOTE(review): `file` here is whatever name was assigned
                                    # last, not necessarily the file that failed.
                                    print(file, "failed", e)
                                    continue
                                batch.append((file, st.st_mtime, b))
                                if len(batch) == self.inference_server_config["batch"]:
                                    await dispatch_batch(batch)
                                    batch = []
                        if batch:
                            await dispatch_batch(batch)

                    # NOTE(review): `files` holds the names already in the database, not the
                    # names found on disk, so the "not in files" test does not appear to
                    # detect files deleted from disk — confirm the intent.
                    remove_indices = []
                    for index, filename in enumerate(self.associated_filenames):
                        if filename not in files or filename in modified:
                            remove_indices.append(index)
                            self.associated_filenames[index] = None
                            if filename not in files:
                                await conn.execute("DELETE FROM files WHERE filename = ?", (filename,))
                    await conn.commit()
                    # TODO concurrency
                    # TODO understand what that comment meant
                    if remove_indices:
                        self.faiss_index.remove_ids(numpy.array(remove_indices))
                        self.associated_filenames = [ x for x in self.associated_filenames if x is not None ]

                    filenames_set = set(self.associated_filenames)
                    new_data = []
                    new_filenames = []
                    async with conn.execute("SELECT * FROM files") as csr:
                        while row := await csr.fetchone():
                            filename, modtime, embedding_vector = row
                            if filename not in filenames_set:
                                new_data.append(numpy.frombuffer(embedding_vector, dtype="float16"))
                                new_filenames.append(filename)
                    # Skip the add when nothing is new: numpy.array([]) is 1-D, and
                    # FAISS's add() needs an (n, d) matrix, so a reload with no changes
                    # used to raise. Add the vectors before recording their filenames
                    # so a failed add cannot leave the two out of step.
                    if new_data:
                        self.faiss_index.add(numpy.array(new_data))
                        self.associated_filenames.extend(new_filenames)
                finally:
                    await conn.close()
|
||||
|
||||
# Register every handler collected in `routes` on the application.
app.router.add_routes(routes)

# Allow cross-origin requests from any origin ("*"), without credentials, with all
# headers allowed and exposed.
cors = aiohttp_cors.setup(app, defaults={
    "*": aiohttp_cors.ResourceOptions(
        allow_credentials=False,
        expose_headers="*",
        allow_headers="*",
    )
})
# Apply the CORS defaults to each registered route; list() snapshots the routes
# before cors.add() is called on them.
for route in list(app.router.routes()):
    cors.add(route)
|
||||
|
||||
async def main():
    """Wait for the inference server, build the index, then start serving.

    The inference server's /config endpoint is polled once a second until it
    answers. Its reply configures the Index, which is fully loaded before the HTTP
    site starts on CONFIG["port"].
    """
    while True:
        async with aiohttp.ClientSession() as sess:
            try:
                # CONFIG["clip_server"] is concatenated directly with "config".
                async with sess.get(CONFIG["clip_server"] + "config") as res:
                    inference_server_config = umsgpack.unpackb(await res.read())
                    print("Backend config:", inference_server_config)
                    break
            # Catch Exception rather than everything: a bare except also swallowed
            # task cancellation and KeyboardInterrupt, so the loop could not be stopped.
            except Exception:
                traceback.print_exc()
                await asyncio.sleep(1)
    index = Index(inference_server_config)
    app["index"] = index
    await index.reload()
    print("Ready")
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "", CONFIG["port"])
    await site.start()
|
||||
|
||||
if __name__ == "__main__":
    # Use an explicit loop: main() returns once the site has started, and
    # run_forever() then keeps the server alive.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(main())
    loop.run_forever()
|
||||
@@ -0,0 +1,167 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower_http::{services::ServeDir, add_extension::AddExtensionLayer, services::ServeFile};
|
||||
use axum::{extract::{Json, Extension, Multipart, Path as AxumPath}, http::{StatusCode, Request}, response::{IntoResponse}, body::Body, routing::{get, post, get_service}, Router};
|
||||
use std::sync::Arc;
|
||||
use tokio::{sync::RwLock, runtime::Handle, fs::File};
|
||||
use anyhow::Result;
|
||||
use rusqlite::Connection;
|
||||
use std::collections::HashMap;
|
||||
use futures::{stream, StreamExt};
|
||||
use std::path::Path;
|
||||
use image::{io::Reader as ImageReader, imageops};
|
||||
use tokio::task::block_in_place;
|
||||
use rayon::prelude::*;
|
||||
use std::io::Cursor;
|
||||
|
||||
mod util;
|
||||
|
||||
use util::CONFIG;
|
||||
|
||||
/// Description of the inference backend, as (de)serialized with serde.
// NOTE(review): the Python inference server's /config route sends "model" as the
// configured model pair and "image_size" as a two-element tuple, which does not
// obviously match `String` / `usize` here — confirm the wire format before use.
#[derive(Serialize, Deserialize)]
struct InferenceServerConfig {
    // Model identifier reported by the backend.
    model: String,
    // Length of each embedding vector.
    embedding_size: usize,
    // Batch size reported by the backend.
    batch: usize,
    // Image size reported by the backend.
    image_size: usize
}
|
||||
|
||||
/// The in-memory vector index returned by `build_index`.
struct Index {
    vectors: faiss::FlatIndex // we need the index to implement Send, which an arbitrary boxed one might not
}
|
||||
|
||||
async fn build_index() -> Result<Index> {
|
||||
let mut conn = block_in_place(|| Connection::open(&CONFIG.db_path))?;
|
||||
block_in_place(|| conn.execute("CREATE TABLE IF NOT EXISTS files (
|
||||
filename TEXT PRIMARY KEY,
|
||||
modtime REAL NOT NULL,
|
||||
embedding_vector BLOB NOT NULL
|
||||
)", ()))?;
|
||||
|
||||
let SIZE = 1024; // TODO
|
||||
let IMAGE_SIZE = 384; // TODO
|
||||
let BS = 32;
|
||||
|
||||
let mut files = HashMap::new();
|
||||
let mut new_files: HashMap<String, (i64, Option<Vec<u8>>)> = HashMap::new();
|
||||
let mut vectors = faiss::index_factory(SIZE, "Flat", faiss::MetricType::InnerProduct)?.into_flat()?;
|
||||
|
||||
block_in_place(|| -> Result<()> {
|
||||
let mut stmt = conn.prepare_cached("SELECT filename, modtime FROM files")?;
|
||||
let mut rows = stmt.query([])?;
|
||||
while let Some(row) = rows.next()? {
|
||||
let filename: String = row.get(0)?;
|
||||
let modtime: i64 = row.get(1)?;
|
||||
files.insert(filename, modtime);
|
||||
}
|
||||
|
||||
for entry in walkdir::WalkDir::new(&CONFIG.images_path).follow_links(true) {
|
||||
let entry = entry?;
|
||||
if entry.file_type().is_file() {
|
||||
let metadata = entry.metadata()?;
|
||||
let modtime = metadata.modified()?.duration_since(std::time::UNIX_EPOCH)?.as_secs() as i64;
|
||||
let filename = entry.path().strip_prefix(&CONFIG.images_path)?.to_string_lossy().to_string();
|
||||
match files.get(&filename) {
|
||||
Some(old_modtime) if *old_modtime < modtime => new_files.insert(filename, (modtime, None)),
|
||||
None => new_files.insert(filename, (modtime, None)),
|
||||
_ => None
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
let (itx, mut irx) = tokio::sync::mpsc::channel(BS * 2);
|
||||
|
||||
let new_files_ = new_files.clone();
|
||||
let image_reader_task = tokio::task::spawn_blocking(move || {
|
||||
new_files_.par_iter().try_for_each(|(filename, _)| -> Result<()> {
|
||||
let mut path = Path::new(&CONFIG.images_path).to_path_buf();
|
||||
path.push(filename);
|
||||
let image = ImageReader::open(path)?.with_guessed_format()?.decode()?;
|
||||
let resized = imageops::resize(&image.into_rgb8(), IMAGE_SIZE, IMAGE_SIZE, imageops::Lanczos3);
|
||||
let mut bytes: Vec<u8> = Vec::new();
|
||||
resized.write_to(&mut Cursor::new(&mut bytes), image::ImageOutputFormat::Png)?;
|
||||
itx.blocking_send((filename.to_string(), bytes))?;
|
||||
Ok(())
|
||||
})
|
||||
});
|
||||
|
||||
let dispatch_batch = |batch| {
|
||||
|
||||
};
|
||||
|
||||
let mut batch = vec![];
|
||||
while let Some((filename, image)) = irx.recv().await {
|
||||
if batch.len() == BS {
|
||||
dispatch_batch(std::mem::replace(&mut batch, vec![]));
|
||||
}
|
||||
batch.push((filename, image));
|
||||
}
|
||||
if batch.len() > 0 {
|
||||
dispatch_batch(std::mem::replace(&mut batch, vec![]));
|
||||
}
|
||||
|
||||
// TODO switch to blocking
|
||||
{
|
||||
let tx = conn.transaction()?;
|
||||
{
|
||||
let mut stmt = tx.prepare_cached("INSERT OR REPLACE INTO files VALUES (?, ?, ?)")?;
|
||||
for (filename, (modtime, embedding)) in new_files {
|
||||
stmt.execute((filename, modtime, embedding.unwrap()))?;
|
||||
}
|
||||
}
|
||||
tx.commit()?;
|
||||
}
|
||||
|
||||
Ok(Index {
|
||||
vectors: vectors
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
if std::env::var_os("RUST_LOG").is_none() {
|
||||
std::env::set_var("RUST_LOG", format!("meme-search-engine={}", CONFIG.log_level))
|
||||
}
|
||||
|
||||
let notify = tokio::sync::Notify::new();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
notify.notified().await;
|
||||
let index = build_index().await.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
//let db = Arc::new(RwLock::new(DB::init().await?));
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(health))
|
||||
.route("/", post(run_query));
|
||||
//.layer(AddExtensionLayer::new(db));
|
||||
|
||||
let addr = CONFIG.listen_address.parse().unwrap();
|
||||
tracing::info!("listening on {}", addr);
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Liveness probe handler: always answers with the body `OK`.
async fn health() -> String {
    String::from("OK")
}
|
||||
|
||||
/// JSON body accepted by the `run_query` handler.
#[derive(Debug, Serialize, Deserialize)]
struct RawQuery {
    // text queries
    text: Vec<String>,
    images: Vec<String> // base64 (sorry)
}
|
||||
|
||||
async fn run_query(query: Json<RawQuery>) -> Json<Vec<(String, f32)>> {
|
||||
tracing::info!("{:?}", query);
|
||||
Json(vec![])
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
use anyhow::{Result, Context};
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
/// Application settings, deserialized from the `./config` source by `load_config`.
// NOTE(review): all five fields are plain `String`s with no serde default, so
// presumably each key must be present in the config source; a config lacking
// `backend_url` would likely fail to parse — confirm.
#[derive(Debug, Serialize, Deserialize)]
pub struct Config {
    // level inserted into the default RUST_LOG directive
    pub log_level: String,
    // socket address the HTTP server binds to
    pub listen_address: String,
    // root directory scanned for images
    pub images_path: String,
    // path of the SQLite database file
    pub db_path: String,
    // presumably the inference server's base URL; not read by the code shown here
    pub backend_url: String
}
|
||||
|
||||
fn load_config() -> Result<Config> {
|
||||
use config::{Config, File};
|
||||
let s = Config::builder()
|
||||
.add_source(File::with_name("./config"))
|
||||
.build().context("loading config")?;
|
||||
Ok(s.try_deserialize().context("parsing config")?)
|
||||
}
|
||||
|
||||
lazy_static::lazy_static! {
    // Global configuration, loaded on first access; panics if loading or parsing fails.
    pub static ref CONFIG: Config = load_config().unwrap();
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,19 @@
|
||||
import numpy
|
||||
import xgboost as xgb
|
||||
|
||||
import shared
|
||||
|
||||
# Fetch the rating samples; only the training split is used below.
trains, validations = shared.fetch_ratings()

ranker = xgb.XGBRanker(
    tree_method="hist",
    lambdarank_num_pair_per_sample=8,
    objective="rank:ndcg",
    lambdarank_pair_method="topk",
    device="cuda"
)

# Flatten the per-group training lists into one list of (meme1, meme2, rating) samples.
flat_samples = [ sample for group in trains for sample in group ]

# Each sample becomes a two-document query: rows (meme1, meme2), labels
# (rating, 1 - rating) truncated to int, and the sample's position as the query id
# for both rows.
feature_rows = []
label_rows = []
query_rows = []
for pair_index, (meme1, meme2, rating) in enumerate(flat_samples):
    feature_rows.append(numpy.stack((meme1, meme2)))
    label_rows.append(numpy.stack((int(rating), int(1 - rating))))
    query_rows.append(numpy.stack((pair_index, pair_index)))

X = numpy.concatenate(feature_rows)
Y = numpy.concatenate(label_rows)
qid = numpy.concatenate(query_rows)
ranker.fit(X, Y, qid=qid, verbose=True)
|
||||
Reference in New Issue
Block a user