mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-10 22:09:54 +00:00
"release" unfinished scripts and miscellaneous JSON files
This commit is contained in:
parent
caa8306ff7
commit
fa863c2075
7041
meme-rater/log.jsonl
Normal file
7041
meme-rater/log.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
2795
misc/Cargo.lock
generated
Normal file
2795
misc/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
29
misc/Cargo.toml
Normal file
29
misc/Cargo.toml
Normal file
@ -0,0 +1,29 @@
|
||||
[package]
|
||||
name = "meme-search-engine"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "0.5", features = ["multipart"] }
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
anyhow = "1"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version="0.3", features = ["env-filter"] }
|
||||
tower-http = { version = "0.2.0", features = ["fs", "trace", "add-extension"] }
|
||||
rusty_ulid = "1"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
rmp-serde = "1"
|
||||
futures-util = "0.3"
|
||||
regex = "1.5"
|
||||
lazy_static = "1"
|
||||
config = { version = "0.13", default-features = false, features = ["toml"] }
|
||||
faiss = { version = "0.12", features = [] }
|
||||
reqwest = "0.11"
|
||||
walkdir = "2"
|
||||
rusqlite = { version = "0.30.0", features = ["bundled"] }
|
||||
futures = "0.3"
|
||||
image = { version = "0.24", features = ["avif", "webp", "default"] }
|
||||
rayon = "1.8"
|
189
misc/clip_accursed.py
Normal file
189
misc/clip_accursed.py
Normal file
@ -0,0 +1,189 @@
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
from aiohttp import web
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import traceback
|
||||
import umsgpack
|
||||
import collections
|
||||
import queue
|
||||
from PIL import Image
|
||||
from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
import torch
|
||||
from transformers import SiglipTokenizer, SiglipImageProcessor, T5TokenizerFast, SiglipTextConfig, SiglipVisionConfig
|
||||
import numpy
|
||||
|
||||
with open(sys.argv[1], "r") as config_file:
|
||||
CONFIG = json.load(config_file)
|
||||
|
||||
# blatantly copypasted from colab
|
||||
# https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/SigLIP_demo.ipynb
|
||||
VARIANT, RES = CONFIG["model"]
|
||||
CKPT, TXTVARIANT, EMBDIM, SEQLEN, VOCAB = {
|
||||
("So400m/14", 384): ("webli_en_so400m_384_58765454-fp16.safetensors", "So400m", 1152, 64, 32_000),
|
||||
}[VARIANT, RES]
|
||||
|
||||
model_cfg = ml_collections.ConfigDict()
|
||||
model_cfg.image_model = "vit" # TODO(lbeyer): remove later, default
|
||||
model_cfg.text_model = "proj.image_text.text_transformer" # TODO(lbeyer): remove later, default
|
||||
model_cfg.image = dict(variant=VARIANT, pool_type="map")
|
||||
model_cfg.text = dict(variant=TXTVARIANT, vocab_size=VOCAB)
|
||||
model_cfg.out_dim = (None, EMBDIM) # (image_out_dim, text_out_dim)
|
||||
model_cfg.bias_init = -10.0
|
||||
model_cfg.temperature_init = 10.0
|
||||
|
||||
model = model_mod.Model(**model_cfg)
|
||||
|
||||
init_params = None # sanity checks are a low-interest-rate phenomenon
|
||||
model_params = model_mod.load(init_params, f"{CKPT}", model_cfg) # assume path
|
||||
|
||||
pp_img = pp_builder.get_preprocess_fn(f"resize({RES})|value_range(-1, 1)")
|
||||
TOKENIZERS = {
|
||||
32_000: "c4_en",
|
||||
250_000: "mc4",
|
||||
}
|
||||
pp_txt = pp_builder.get_preprocess_fn(f'tokenize(max_len={SEQLEN}, model="{TOKENIZERS[VOCAB]}", eos="sticky", pad_value=1, inkey="text")')
|
||||
print("Model loaded")
|
||||
|
||||
BS = CONFIG["max_batch_size"]
|
||||
MODELNAME = CONFIG["model_name"]
|
||||
|
||||
InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "images", "callback"])
|
||||
|
||||
items_ctr = Counter("modelserver_total_items", "Items run through model server", ["model", "modality"])
|
||||
inference_time_hist = Histogram("modelserver_inftime", "Time running inference", ["model", "batch_size"])
|
||||
batch_count_ctr = Counter("modelserver_batchcount", "Inference batches run", ["model"])
|
||||
|
||||
@jax.jit
|
||||
def run_text_model(text_batch):
|
||||
_, features, out = model.apply({"params": model_params}, None, text_batch)
|
||||
return features
|
||||
|
||||
@jax.jit
|
||||
def run_image_model(image_batch):
|
||||
features, _, out = model.apply({"params": model_params}, image_batch, None)
|
||||
return features
|
||||
|
||||
def round_down_to_power_of_two(x):
|
||||
return 1<<(x.bit_length()-1)
|
||||
|
||||
def minimize_jits(fn, batch):
|
||||
out = numpy.zeros((batch.shape[0], EMBDIM), dtype="float16")
|
||||
i = 0
|
||||
while True:
|
||||
batch_dim = batch.shape[0]
|
||||
s = round_down_to_power_of_two(batch_dim)
|
||||
fst = batch[:s,...]
|
||||
out[i:(i + s), ...] = fn(fst)
|
||||
i += s
|
||||
batch = batch[s:, ...]
|
||||
if batch.shape[0] == 0: break
|
||||
return out
|
||||
|
||||
def do_inference(params: InferenceParameters):
|
||||
try:
|
||||
text, images, callback = params
|
||||
if text is not None:
|
||||
items_ctr.labels(MODELNAME, "text").inc(text.shape[0])
|
||||
with inference_time_hist.labels(MODELNAME + "-text", text.shape[0]).time():
|
||||
features = run_text_model(text)
|
||||
elif images is not None:
|
||||
items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
|
||||
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
|
||||
features = run_image_model(images)
|
||||
batch_count_ctr.labels(MODELNAME).inc()
|
||||
callback(True, numpy.asarray(features))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
callback(False, str(e))
|
||||
|
||||
iq = queue.Queue(100)
|
||||
def infer_thread():
|
||||
while True:
|
||||
do_inference(iq.get())
|
||||
|
||||
pq = queue.Queue(100)
|
||||
def preprocessing_thread():
|
||||
while True:
|
||||
text, images, callback = pq.get()
|
||||
try:
|
||||
if text:
|
||||
assert len(text) <= BS, f"max batch size is {BS}"
|
||||
# I feel like this ought to be batchable but I can't see how to do that
|
||||
text = numpy.array([pp_txt({"text": text})["labels"] for text in text])
|
||||
elif images:
|
||||
assert len(images) <= BS, f"max batch size is {BS}"
|
||||
images = numpy.array([pp_img({"image": numpy.array(Image.open(io.BytesIO(image)).convert("RGB"))})["image"] for image in images])
|
||||
else:
|
||||
assert False, "images or text required"
|
||||
iq.put(InferenceParameters(text, images, callback))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
callback(False, str(e))
|
||||
|
||||
app = web.Application(client_max_size=2**26)
|
||||
routes = web.RouteTableDef()
|
||||
|
||||
@routes.post("/")
|
||||
async def run_inference(request):
|
||||
loop = asyncio.get_event_loop()
|
||||
data = umsgpack.loads(await request.read())
|
||||
event = asyncio.Event()
|
||||
results = None
|
||||
def callback(*argv):
|
||||
nonlocal results
|
||||
results = argv
|
||||
loop.call_soon_threadsafe(lambda: event.set())
|
||||
pq.put_nowait(InferenceParameters(data.get("text"), data.get("images"), callback))
|
||||
await event.wait()
|
||||
body_data = results[1]
|
||||
if results[0]:
|
||||
status = 200
|
||||
body_data = [x.astype("float16").tobytes() for x in body_data]
|
||||
else:
|
||||
status = 500
|
||||
print(results[1])
|
||||
return web.Response(body=umsgpack.dumps(body_data), status=status, content_type="application/msgpack")
|
||||
|
||||
@routes.get("/config")
|
||||
async def config(request):
|
||||
return web.Response(body=umsgpack.dumps({
|
||||
"model": CONFIG["model"],
|
||||
"batch": BS,
|
||||
"image_size": (RES, RES),
|
||||
"embedding_size": EMBDIM
|
||||
}), status=200, content_type="application/msgpack")
|
||||
|
||||
@routes.get("/")
|
||||
async def health(request):
|
||||
return web.Response(status=204)
|
||||
|
||||
@routes.get("/metrics")
|
||||
async def metrics(request):
|
||||
return web.Response(body=generate_latest(REGISTRY))
|
||||
|
||||
app.router.add_routes(routes)
|
||||
|
||||
async def run_webserver():
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "", CONFIG["port"])
|
||||
print("Ready")
|
||||
await site.start()
|
||||
|
||||
try:
|
||||
th = threading.Thread(target=infer_thread)
|
||||
th.start()
|
||||
th = threading.Thread(target=preprocessing_thread)
|
||||
th.start()
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(run_webserver())
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
import sys
|
||||
sys.exit(0)
|
4
misc/config.toml
Normal file
4
misc/config.toml
Normal file
@ -0,0 +1,4 @@
|
||||
log_level = "debug"
|
||||
listen_address = "[::1]:1710"
|
||||
db_path = "./mse.sqlite3"
|
||||
images_path = "/data/public"
|
13
misc/eval.html
Normal file
13
misc/eval.html
Normal file
File diff suppressed because one or more lines are too long
2355
misc/log-1713711190.1288905.jsonl
Normal file
2355
misc/log-1713711190.1288905.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
2354
misc/log-1713711267.4419498.jsonl
Normal file
2354
misc/log-1713711267.4419498.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
212
misc/mse_accursed.py
Normal file
212
misc/mse_accursed.py
Normal file
@ -0,0 +1,212 @@
|
||||
from aiohttp import web
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import traceback
|
||||
import umsgpack
|
||||
from PIL import Image
|
||||
import base64
|
||||
import aiosqlite
|
||||
import faiss
|
||||
import numpy
|
||||
import os
|
||||
import aiohttp_cors
|
||||
import json
|
||||
import io
|
||||
import sys
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
with open(sys.argv[1], "r") as config_file:
|
||||
CONFIG = json.load(config_file)
|
||||
|
||||
app = web.Application(client_max_size=32*1024**2)
|
||||
routes = web.RouteTableDef()
|
||||
|
||||
async def clip_server(query, unpack_buffer=True):
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
async with sess.post(CONFIG["clip_server"], data=umsgpack.dumps(query)) as res:
|
||||
response = umsgpack.loads(await res.read())
|
||||
if res.status == 200:
|
||||
if unpack_buffer:
|
||||
response = [ numpy.frombuffer(x, dtype="float16") for x in response ]
|
||||
return response
|
||||
else:
|
||||
raise Exception(response if res.headers.get("content-type") == "application/msgpack" else (await res.text()))
|
||||
|
||||
@routes.post("/")
|
||||
async def run_query(request):
|
||||
data = await request.json()
|
||||
embeddings = []
|
||||
if images := data.get("images", []):
|
||||
embeddings.extend(await clip_server({ "images": [ base64.b64decode(x) for x, w in images ] }))
|
||||
if text := data.get("text", []):
|
||||
embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
|
||||
weights = [ w for x, w in images ] + [ w for x, w in text ]
|
||||
embeddings = [ e * w for e, w in zip(embeddings, weights) ]
|
||||
if not embeddings:
|
||||
return web.json_response([])
|
||||
return web.json_response(app["index"].search(sum(embeddings)))
|
||||
|
||||
@routes.get("/")
|
||||
async def health_check(request):
|
||||
return web.Response(text="OK")
|
||||
|
||||
@routes.post("/reload_index")
|
||||
async def reload_index_route(request):
|
||||
await request.app["index"].reload()
|
||||
return web.json_response(True)
|
||||
|
||||
def load_image(path, image_size):
|
||||
im = Image.open(path)
|
||||
im.draft("RGB", image_size)
|
||||
buf = io.BytesIO()
|
||||
im.resize(image_size).convert("RGB").save(buf, format="BMP")
|
||||
return buf.getvalue(), path
|
||||
|
||||
class Index:
|
||||
def __init__(self, inference_server_config):
|
||||
self.faiss_index = faiss.IndexFlatIP(inference_server_config["embedding_size"])
|
||||
self.associated_filenames = []
|
||||
self.inference_server_config = inference_server_config
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
def search(self, query):
|
||||
distances, indices = self.faiss_index.search(numpy.array([query]), 4000)
|
||||
distances = distances[0]
|
||||
indices = indices[0]
|
||||
try:
|
||||
indices = indices[:numpy.where(indices==-1)[0][0]]
|
||||
except IndexError: pass
|
||||
return [ { "score": float(distance), "file": self.associated_filenames[index] } for index, distance in zip(indices, distances) ]
|
||||
|
||||
async def reload(self):
|
||||
async with self.lock:
|
||||
with ProcessPoolExecutor(max_workers=12) as executor:
|
||||
print("Indexing")
|
||||
conn = await aiosqlite.connect(CONFIG["db_path"], parent_loop=asyncio.get_running_loop())
|
||||
conn.row_factory = aiosqlite.Row
|
||||
await conn.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
filename TEXT PRIMARY KEY,
|
||||
modtime REAL NOT NULL,
|
||||
embedding_vector BLOB NOT NULL
|
||||
);
|
||||
""")
|
||||
try:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
batch_sem = asyncio.Semaphore(32)
|
||||
|
||||
modified = set()
|
||||
|
||||
async def do_batch(batch):
|
||||
try:
|
||||
query = { "images": [ arg[2] for arg in batch ] }
|
||||
embeddings = await clip_server(query, False)
|
||||
await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
|
||||
(filename, modtime, embedding) for (filename, modtime, _), embedding in zip(batch, embeddings)
|
||||
])
|
||||
await conn.commit()
|
||||
for filename, _, _ in batch:
|
||||
modified.add(filename)
|
||||
sys.stdout.write(".")
|
||||
finally:
|
||||
batch_sem.release()
|
||||
|
||||
async def dispatch_batch(batch):
|
||||
await batch_sem.acquire()
|
||||
tg.create_task(do_batch(batch))
|
||||
|
||||
files = {}
|
||||
for filename, modtime in await conn.execute_fetchall("SELECT filename, modtime FROM files"):
|
||||
files[filename] = modtime
|
||||
await conn.commit()
|
||||
batch = []
|
||||
|
||||
for dirpath, _, filenames in os.walk(CONFIG["files"]):
|
||||
paths = []
|
||||
for file in filenames:
|
||||
path = os.path.join(dirpath, file)
|
||||
file = os.path.relpath(path, CONFIG["files"])
|
||||
st = os.stat(path)
|
||||
if st.st_mtime != files.get(file):
|
||||
paths.append(path)
|
||||
for task in asyncio.as_completed([ asyncio.get_running_loop().run_in_executor(executor, load_image, path, self.inference_server_config["image_size"]) for path in paths ]):
|
||||
try:
|
||||
b, path = await task
|
||||
st = os.stat(path)
|
||||
file = os.path.relpath(path, CONFIG["files"])
|
||||
except Exception as e:
|
||||
print(file, "failed", e)
|
||||
continue
|
||||
batch.append((file, st.st_mtime, b))
|
||||
if len(batch) == self.inference_server_config["batch"]:
|
||||
await dispatch_batch(batch)
|
||||
batch = []
|
||||
if batch:
|
||||
await dispatch_batch(batch)
|
||||
|
||||
remove_indices = []
|
||||
for index, filename in enumerate(self.associated_filenames):
|
||||
if filename not in files or filename in modified:
|
||||
remove_indices.append(index)
|
||||
self.associated_filenames[index] = None
|
||||
if filename not in files:
|
||||
await conn.execute("DELETE FROM files WHERE filename = ?", (filename,))
|
||||
await conn.commit()
|
||||
# TODO concurrency
|
||||
# TODO understand what that comment meant
|
||||
if remove_indices:
|
||||
self.faiss_index.remove_ids(numpy.array(remove_indices))
|
||||
self.associated_filenames = [ x for x in self.associated_filenames if x is not None ]
|
||||
|
||||
filenames_set = set(self.associated_filenames)
|
||||
new_data = []
|
||||
new_filenames = []
|
||||
async with conn.execute("SELECT * FROM files") as csr:
|
||||
while row := await csr.fetchone():
|
||||
filename, modtime, embedding_vector = row
|
||||
if filename not in filenames_set:
|
||||
new_data.append(numpy.frombuffer(embedding_vector, dtype="float16"))
|
||||
new_filenames.append(filename)
|
||||
new_data = numpy.array(new_data)
|
||||
self.associated_filenames.extend(new_filenames)
|
||||
self.faiss_index.add(new_data)
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
app.router.add_routes(routes)
|
||||
|
||||
cors = aiohttp_cors.setup(app, defaults={
|
||||
"*": aiohttp_cors.ResourceOptions(
|
||||
allow_credentials=False,
|
||||
expose_headers="*",
|
||||
allow_headers="*",
|
||||
)
|
||||
})
|
||||
for route in list(app.router.routes()):
|
||||
cors.add(route)
|
||||
|
||||
async def main():
|
||||
while True:
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
try:
|
||||
async with await sess.get(CONFIG["clip_server"] + "config") as res:
|
||||
inference_server_config = umsgpack.unpackb(await res.read())
|
||||
print("Backend config:", inference_server_config)
|
||||
break
|
||||
except:
|
||||
traceback.print_exc()
|
||||
await asyncio.sleep(1)
|
||||
index = Index(inference_server_config)
|
||||
app["index"] = index
|
||||
await index.reload()
|
||||
print("Ready")
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "", CONFIG["port"])
|
||||
await site.start()
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(main())
|
||||
loop.run_forever()
|
167
misc/src/main.rs
Normal file
167
misc/src/main.rs
Normal file
@ -0,0 +1,167 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower_http::{services::ServeDir, add_extension::AddExtensionLayer, services::ServeFile};
|
||||
use axum::{extract::{Json, Extension, Multipart, Path as AxumPath}, http::{StatusCode, Request}, response::{IntoResponse}, body::Body, routing::{get, post, get_service}, Router};
|
||||
use std::sync::Arc;
|
||||
use tokio::{sync::RwLock, runtime::Handle, fs::File};
|
||||
use anyhow::Result;
|
||||
use rusqlite::Connection;
|
||||
use std::collections::HashMap;
|
||||
use futures::{stream, StreamExt};
|
||||
use std::path::Path;
|
||||
use image::{io::Reader as ImageReader, imageops};
|
||||
use tokio::task::block_in_place;
|
||||
use rayon::prelude::*;
|
||||
use std::io::Cursor;
|
||||
|
||||
mod util;
|
||||
|
||||
use util::CONFIG;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct InferenceServerConfig {
|
||||
model: String,
|
||||
embedding_size: usize,
|
||||
batch: usize,
|
||||
image_size: usize
|
||||
}
|
||||
|
||||
struct Index {
|
||||
vectors: faiss::FlatIndex // we need the index to implement Send, which an arbitrary boxed one might not
|
||||
}
|
||||
|
||||
async fn build_index() -> Result<Index> {
|
||||
let mut conn = block_in_place(|| Connection::open(&CONFIG.db_path))?;
|
||||
block_in_place(|| conn.execute("CREATE TABLE IF NOT EXISTS files (
|
||||
filename TEXT PRIMARY KEY,
|
||||
modtime REAL NOT NULL,
|
||||
embedding_vector BLOB NOT NULL
|
||||
)", ()))?;
|
||||
|
||||
let SIZE = 1024; // TODO
|
||||
let IMAGE_SIZE = 384; // TODO
|
||||
let BS = 32;
|
||||
|
||||
let mut files = HashMap::new();
|
||||
let mut new_files: HashMap<String, (i64, Option<Vec<u8>>)> = HashMap::new();
|
||||
let mut vectors = faiss::index_factory(SIZE, "Flat", faiss::MetricType::InnerProduct)?.into_flat()?;
|
||||
|
||||
block_in_place(|| -> Result<()> {
|
||||
let mut stmt = conn.prepare_cached("SELECT filename, modtime FROM files")?;
|
||||
let mut rows = stmt.query([])?;
|
||||
while let Some(row) = rows.next()? {
|
||||
let filename: String = row.get(0)?;
|
||||
let modtime: i64 = row.get(1)?;
|
||||
files.insert(filename, modtime);
|
||||
}
|
||||
|
||||
for entry in walkdir::WalkDir::new(&CONFIG.images_path).follow_links(true) {
|
||||
let entry = entry?;
|
||||
if entry.file_type().is_file() {
|
||||
let metadata = entry.metadata()?;
|
||||
let modtime = metadata.modified()?.duration_since(std::time::UNIX_EPOCH)?.as_secs() as i64;
|
||||
let filename = entry.path().strip_prefix(&CONFIG.images_path)?.to_string_lossy().to_string();
|
||||
match files.get(&filename) {
|
||||
Some(old_modtime) if *old_modtime < modtime => new_files.insert(filename, (modtime, None)),
|
||||
None => new_files.insert(filename, (modtime, None)),
|
||||
_ => None
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
let (itx, mut irx) = tokio::sync::mpsc::channel(BS * 2);
|
||||
|
||||
let new_files_ = new_files.clone();
|
||||
let image_reader_task = tokio::task::spawn_blocking(move || {
|
||||
new_files_.par_iter().try_for_each(|(filename, _)| -> Result<()> {
|
||||
let mut path = Path::new(&CONFIG.images_path).to_path_buf();
|
||||
path.push(filename);
|
||||
let image = ImageReader::open(path)?.with_guessed_format()?.decode()?;
|
||||
let resized = imageops::resize(&image.into_rgb8(), IMAGE_SIZE, IMAGE_SIZE, imageops::Lanczos3);
|
||||
let mut bytes: Vec<u8> = Vec::new();
|
||||
resized.write_to(&mut Cursor::new(&mut bytes), image::ImageOutputFormat::Png)?;
|
||||
itx.blocking_send((filename.to_string(), bytes))?;
|
||||
Ok(())
|
||||
})
|
||||
});
|
||||
|
||||
let dispatch_batch = |batch| {
|
||||
|
||||
};
|
||||
|
||||
let mut batch = vec![];
|
||||
while let Some((filename, image)) = irx.recv().await {
|
||||
if batch.len() == BS {
|
||||
dispatch_batch(std::mem::replace(&mut batch, vec![]));
|
||||
}
|
||||
batch.push((filename, image));
|
||||
}
|
||||
if batch.len() > 0 {
|
||||
dispatch_batch(std::mem::replace(&mut batch, vec![]));
|
||||
}
|
||||
|
||||
// TODO switch to blocking
|
||||
{
|
||||
let tx = conn.transaction()?;
|
||||
{
|
||||
let mut stmt = tx.prepare_cached("INSERT OR REPLACE INTO files VALUES (?, ?, ?)")?;
|
||||
for (filename, (modtime, embedding)) in new_files {
|
||||
stmt.execute((filename, modtime, embedding.unwrap()))?;
|
||||
}
|
||||
}
|
||||
tx.commit()?;
|
||||
}
|
||||
|
||||
Ok(Index {
|
||||
vectors: vectors
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
if std::env::var_os("RUST_LOG").is_none() {
|
||||
std::env::set_var("RUST_LOG", format!("meme-search-engine={}", CONFIG.log_level))
|
||||
}
|
||||
|
||||
let notify = tokio::sync::Notify::new();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
notify.notified().await;
|
||||
let index = build_index().await.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
//let db = Arc::new(RwLock::new(DB::init().await?));
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(health))
|
||||
.route("/", post(run_query));
|
||||
//.layer(AddExtensionLayer::new(db));
|
||||
|
||||
let addr = CONFIG.listen_address.parse().unwrap();
|
||||
tracing::info!("listening on {}", addr);
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn health() -> String {
|
||||
format!("OK")
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct RawQuery {
|
||||
text: Vec<String>,
|
||||
images: Vec<String> // base64 (sorry)
|
||||
}
|
||||
|
||||
async fn run_query(query: Json<RawQuery>) -> Json<Vec<(String, f32)>> {
|
||||
tracing::info!("{:?}", query);
|
||||
Json(vec![])
|
||||
}
|
23
misc/src/util.rs
Normal file
23
misc/src/util.rs
Normal file
@ -0,0 +1,23 @@
|
||||
use anyhow::{Result, Context};
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
pub log_level: String,
|
||||
pub listen_address: String,
|
||||
pub images_path: String,
|
||||
pub db_path: String,
|
||||
pub backend_url: String
|
||||
}
|
||||
|
||||
fn load_config() -> Result<Config> {
|
||||
use config::{Config, File};
|
||||
let s = Config::builder()
|
||||
.add_source(File::with_name("./config"))
|
||||
.build().context("loading config")?;
|
||||
Ok(s.try_deserialize().context("parsing config")?)
|
||||
}
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
pub static ref CONFIG: Config = load_config().unwrap();
|
||||
}
|
1
misc/top.json
Normal file
1
misc/top.json
Normal file
File diff suppressed because one or more lines are too long
1
misc/top2.json
Normal file
1
misc/top2.json
Normal file
File diff suppressed because one or more lines are too long
19
misc/train_xgboost.py
Normal file
19
misc/train_xgboost.py
Normal file
@ -0,0 +1,19 @@
|
||||
import numpy
|
||||
import xgboost as xgb
|
||||
|
||||
import shared
|
||||
|
||||
trains, validations = shared.fetch_ratings()
|
||||
|
||||
ranker = xgb.XGBRanker(
|
||||
tree_method="hist",
|
||||
lambdarank_num_pair_per_sample=8,
|
||||
objective="rank:ndcg",
|
||||
lambdarank_pair_method="topk",
|
||||
device="cuda"
|
||||
)
|
||||
flat_samples = [ sample for trainss in trains for sample in trainss ]
|
||||
X = numpy.concatenate([ numpy.stack((meme1, meme2)) for meme1, meme2, rating in flat_samples ])
|
||||
Y = numpy.concatenate([ numpy.stack((int(rating), int(1 - rating))) for meme1, meme2, rating in flat_samples ])
|
||||
qid = numpy.concatenate([ numpy.stack((i, i)) for i in range(len(flat_samples)) ])
|
||||
ranker.fit(X, Y, qid=qid, verbose=True)
|
Loading…
Reference in New Issue
Block a user