1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-10-24 03:57:39 +00:00

"release" unfinished scripts and miscellaneous JSON files

This commit is contained in:
2024-05-18 14:34:30 +01:00
parent caa8306ff7
commit fa863c2075
14 changed files with 15203 additions and 0 deletions

7041
meme-rater/log.jsonl Normal file

File diff suppressed because it is too large Load Diff

2795
misc/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

29
misc/Cargo.toml Normal file
View File

@@ -0,0 +1,29 @@
[package]
name = "meme-search-engine"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = { version = "0.5", features = ["multipart"] }
tokio = { version = "1.0", features = ["full"] }
anyhow = "1"
tracing = "0.1"
tracing-subscriber = { version="0.3", features = ["env-filter"] }
tower-http = { version = "0.2.0", features = ["fs", "trace", "add-extension"] }
rusty_ulid = "1"
serde = { version = "1.0", features = ["derive"] }
chrono = { version = "0.4", features = ["serde"] }
rmp-serde = "1"
futures-util = "0.3"
regex = "1.5"
lazy_static = "1"
config = { version = "0.13", default-features = false, features = ["toml"] }
faiss = { version = "0.12", features = [] }
reqwest = "0.11"
walkdir = "2"
rusqlite = { version = "0.30.0", features = ["bundled"] }
futures = "0.3"
image = { version = "0.24", features = ["avif", "webp", "default"] }
rayon = "1.8"

189
misc/clip_accursed.py Normal file
View File

@@ -0,0 +1,189 @@
import os
import time
import threading
from aiohttp import web
import aiohttp
import asyncio
import traceback
import umsgpack
import collections
import queue
from PIL import Image
from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
import io
import json
import sys
import torch
from transformers import SiglipTokenizer, SiglipImageProcessor, T5TokenizerFast, SiglipTextConfig, SiglipVisionConfig
import numpy
with open(sys.argv[1], "r") as config_file:
CONFIG = json.load(config_file)
# blatantly copypasted from colab
# https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/SigLIP_demo.ipynb
VARIANT, RES = CONFIG["model"]
CKPT, TXTVARIANT, EMBDIM, SEQLEN, VOCAB = {
("So400m/14", 384): ("webli_en_so400m_384_58765454-fp16.safetensors", "So400m", 1152, 64, 32_000),
}[VARIANT, RES]
model_cfg = ml_collections.ConfigDict()
model_cfg.image_model = "vit" # TODO(lbeyer): remove later, default
model_cfg.text_model = "proj.image_text.text_transformer" # TODO(lbeyer): remove later, default
model_cfg.image = dict(variant=VARIANT, pool_type="map")
model_cfg.text = dict(variant=TXTVARIANT, vocab_size=VOCAB)
model_cfg.out_dim = (None, EMBDIM) # (image_out_dim, text_out_dim)
model_cfg.bias_init = -10.0
model_cfg.temperature_init = 10.0
model = model_mod.Model(**model_cfg)
init_params = None # sanity checks are a low-interest-rate phenomenon
model_params = model_mod.load(init_params, f"{CKPT}", model_cfg) # assume path
pp_img = pp_builder.get_preprocess_fn(f"resize({RES})|value_range(-1, 1)")
TOKENIZERS = {
32_000: "c4_en",
250_000: "mc4",
}
pp_txt = pp_builder.get_preprocess_fn(f'tokenize(max_len={SEQLEN}, model="{TOKENIZERS[VOCAB]}", eos="sticky", pad_value=1, inkey="text")')
print("Model loaded")
BS = CONFIG["max_batch_size"]
MODELNAME = CONFIG["model_name"]
InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "images", "callback"])
items_ctr = Counter("modelserver_total_items", "Items run through model server", ["model", "modality"])
inference_time_hist = Histogram("modelserver_inftime", "Time running inference", ["model", "batch_size"])
batch_count_ctr = Counter("modelserver_batchcount", "Inference batches run", ["model"])
@jax.jit
def run_text_model(text_batch):
_, features, out = model.apply({"params": model_params}, None, text_batch)
return features
@jax.jit
def run_image_model(image_batch):
features, _, out = model.apply({"params": model_params}, image_batch, None)
return features
def round_down_to_power_of_two(x):
return 1<<(x.bit_length()-1)
def minimize_jits(fn, batch):
out = numpy.zeros((batch.shape[0], EMBDIM), dtype="float16")
i = 0
while True:
batch_dim = batch.shape[0]
s = round_down_to_power_of_two(batch_dim)
fst = batch[:s,...]
out[i:(i + s), ...] = fn(fst)
i += s
batch = batch[s:, ...]
if batch.shape[0] == 0: break
return out
def do_inference(params: InferenceParameters):
try:
text, images, callback = params
if text is not None:
items_ctr.labels(MODELNAME, "text").inc(text.shape[0])
with inference_time_hist.labels(MODELNAME + "-text", text.shape[0]).time():
features = run_text_model(text)
elif images is not None:
items_ctr.labels(MODELNAME, "image").inc(images.shape[0])
with inference_time_hist.labels(MODELNAME + "-image", images.shape[0]).time():
features = run_image_model(images)
batch_count_ctr.labels(MODELNAME).inc()
callback(True, numpy.asarray(features))
except Exception as e:
traceback.print_exc()
callback(False, str(e))
iq = queue.Queue(100)
def infer_thread():
while True:
do_inference(iq.get())
pq = queue.Queue(100)
def preprocessing_thread():
while True:
text, images, callback = pq.get()
try:
if text:
assert len(text) <= BS, f"max batch size is {BS}"
# I feel like this ought to be batchable but I can't see how to do that
text = numpy.array([pp_txt({"text": text})["labels"] for text in text])
elif images:
assert len(images) <= BS, f"max batch size is {BS}"
images = numpy.array([pp_img({"image": numpy.array(Image.open(io.BytesIO(image)).convert("RGB"))})["image"] for image in images])
else:
assert False, "images or text required"
iq.put(InferenceParameters(text, images, callback))
except Exception as e:
traceback.print_exc()
callback(False, str(e))
app = web.Application(client_max_size=2**26)
routes = web.RouteTableDef()
@routes.post("/")
async def run_inference(request):
loop = asyncio.get_event_loop()
data = umsgpack.loads(await request.read())
event = asyncio.Event()
results = None
def callback(*argv):
nonlocal results
results = argv
loop.call_soon_threadsafe(lambda: event.set())
pq.put_nowait(InferenceParameters(data.get("text"), data.get("images"), callback))
await event.wait()
body_data = results[1]
if results[0]:
status = 200
body_data = [x.astype("float16").tobytes() for x in body_data]
else:
status = 500
print(results[1])
return web.Response(body=umsgpack.dumps(body_data), status=status, content_type="application/msgpack")
@routes.get("/config")
async def config(request):
return web.Response(body=umsgpack.dumps({
"model": CONFIG["model"],
"batch": BS,
"image_size": (RES, RES),
"embedding_size": EMBDIM
}), status=200, content_type="application/msgpack")
@routes.get("/")
async def health(request):
return web.Response(status=204)
@routes.get("/metrics")
async def metrics(request):
return web.Response(body=generate_latest(REGISTRY))
app.router.add_routes(routes)
async def run_webserver():
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "", CONFIG["port"])
print("Ready")
await site.start()
try:
th = threading.Thread(target=infer_thread)
th.start()
th = threading.Thread(target=preprocessing_thread)
th.start()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_webserver())
loop.run_forever()
except KeyboardInterrupt:
import sys
sys.exit(0)

4
misc/config.toml Normal file
View File

@@ -0,0 +1,4 @@
log_level = "debug"
listen_address = "[::1]:1710"
db_path = "./mse.sqlite3"
images_path = "/data/public"

13
misc/eval.html Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

212
misc/mse_accursed.py Normal file
View File

@@ -0,0 +1,212 @@
from aiohttp import web
import aiohttp
import asyncio
import traceback
import umsgpack
from PIL import Image
import base64
import aiosqlite
import faiss
import numpy
import os
import aiohttp_cors
import json
import io
import sys
from concurrent.futures import ProcessPoolExecutor
with open(sys.argv[1], "r") as config_file:
CONFIG = json.load(config_file)
app = web.Application(client_max_size=32*1024**2)
routes = web.RouteTableDef()
async def clip_server(query, unpack_buffer=True):
async with aiohttp.ClientSession() as sess:
async with sess.post(CONFIG["clip_server"], data=umsgpack.dumps(query)) as res:
response = umsgpack.loads(await res.read())
if res.status == 200:
if unpack_buffer:
response = [ numpy.frombuffer(x, dtype="float16") for x in response ]
return response
else:
raise Exception(response if res.headers.get("content-type") == "application/msgpack" else (await res.text()))
@routes.post("/")
async def run_query(request):
data = await request.json()
embeddings = []
if images := data.get("images", []):
embeddings.extend(await clip_server({ "images": [ base64.b64decode(x) for x, w in images ] }))
if text := data.get("text", []):
embeddings.extend(await clip_server({ "text": [ x for x, w in text ] }))
weights = [ w for x, w in images ] + [ w for x, w in text ]
embeddings = [ e * w for e, w in zip(embeddings, weights) ]
if not embeddings:
return web.json_response([])
return web.json_response(app["index"].search(sum(embeddings)))
@routes.get("/")
async def health_check(request):
return web.Response(text="OK")
@routes.post("/reload_index")
async def reload_index_route(request):
await request.app["index"].reload()
return web.json_response(True)
def load_image(path, image_size):
im = Image.open(path)
im.draft("RGB", image_size)
buf = io.BytesIO()
im.resize(image_size).convert("RGB").save(buf, format="BMP")
return buf.getvalue(), path
class Index:
def __init__(self, inference_server_config):
self.faiss_index = faiss.IndexFlatIP(inference_server_config["embedding_size"])
self.associated_filenames = []
self.inference_server_config = inference_server_config
self.lock = asyncio.Lock()
def search(self, query):
distances, indices = self.faiss_index.search(numpy.array([query]), 4000)
distances = distances[0]
indices = indices[0]
try:
indices = indices[:numpy.where(indices==-1)[0][0]]
except IndexError: pass
return [ { "score": float(distance), "file": self.associated_filenames[index] } for index, distance in zip(indices, distances) ]
async def reload(self):
async with self.lock:
with ProcessPoolExecutor(max_workers=12) as executor:
print("Indexing")
conn = await aiosqlite.connect(CONFIG["db_path"], parent_loop=asyncio.get_running_loop())
conn.row_factory = aiosqlite.Row
await conn.executescript("""
CREATE TABLE IF NOT EXISTS files (
filename TEXT PRIMARY KEY,
modtime REAL NOT NULL,
embedding_vector BLOB NOT NULL
);
""")
try:
async with asyncio.TaskGroup() as tg:
batch_sem = asyncio.Semaphore(32)
modified = set()
async def do_batch(batch):
try:
query = { "images": [ arg[2] for arg in batch ] }
embeddings = await clip_server(query, False)
await conn.executemany("INSERT OR REPLACE INTO files VALUES (?, ?, ?)", [
(filename, modtime, embedding) for (filename, modtime, _), embedding in zip(batch, embeddings)
])
await conn.commit()
for filename, _, _ in batch:
modified.add(filename)
sys.stdout.write(".")
finally:
batch_sem.release()
async def dispatch_batch(batch):
await batch_sem.acquire()
tg.create_task(do_batch(batch))
files = {}
for filename, modtime in await conn.execute_fetchall("SELECT filename, modtime FROM files"):
files[filename] = modtime
await conn.commit()
batch = []
for dirpath, _, filenames in os.walk(CONFIG["files"]):
paths = []
for file in filenames:
path = os.path.join(dirpath, file)
file = os.path.relpath(path, CONFIG["files"])
st = os.stat(path)
if st.st_mtime != files.get(file):
paths.append(path)
for task in asyncio.as_completed([ asyncio.get_running_loop().run_in_executor(executor, load_image, path, self.inference_server_config["image_size"]) for path in paths ]):
try:
b, path = await task
st = os.stat(path)
file = os.path.relpath(path, CONFIG["files"])
except Exception as e:
print(file, "failed", e)
continue
batch.append((file, st.st_mtime, b))
if len(batch) == self.inference_server_config["batch"]:
await dispatch_batch(batch)
batch = []
if batch:
await dispatch_batch(batch)
remove_indices = []
for index, filename in enumerate(self.associated_filenames):
if filename not in files or filename in modified:
remove_indices.append(index)
self.associated_filenames[index] = None
if filename not in files:
await conn.execute("DELETE FROM files WHERE filename = ?", (filename,))
await conn.commit()
# TODO concurrency
# TODO understand what that comment meant
if remove_indices:
self.faiss_index.remove_ids(numpy.array(remove_indices))
self.associated_filenames = [ x for x in self.associated_filenames if x is not None ]
filenames_set = set(self.associated_filenames)
new_data = []
new_filenames = []
async with conn.execute("SELECT * FROM files") as csr:
while row := await csr.fetchone():
filename, modtime, embedding_vector = row
if filename not in filenames_set:
new_data.append(numpy.frombuffer(embedding_vector, dtype="float16"))
new_filenames.append(filename)
new_data = numpy.array(new_data)
self.associated_filenames.extend(new_filenames)
self.faiss_index.add(new_data)
finally:
await conn.close()
app.router.add_routes(routes)
cors = aiohttp_cors.setup(app, defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=False,
expose_headers="*",
allow_headers="*",
)
})
for route in list(app.router.routes()):
cors.add(route)
async def main():
while True:
async with aiohttp.ClientSession() as sess:
try:
async with await sess.get(CONFIG["clip_server"] + "config") as res:
inference_server_config = umsgpack.unpackb(await res.read())
print("Backend config:", inference_server_config)
break
except:
traceback.print_exc()
await asyncio.sleep(1)
index = Index(inference_server_config)
app["index"] = index
await index.reload()
print("Ready")
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "", CONFIG["port"])
await site.start()
if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(main())
loop.run_forever()

167
misc/src/main.rs Normal file
View File

@@ -0,0 +1,167 @@
use serde::{Deserialize, Serialize};
use tower_http::{services::ServeDir, add_extension::AddExtensionLayer, services::ServeFile};
use axum::{extract::{Json, Extension, Multipart, Path as AxumPath}, http::{StatusCode, Request}, response::{IntoResponse}, body::Body, routing::{get, post, get_service}, Router};
use std::sync::Arc;
use tokio::{sync::RwLock, runtime::Handle, fs::File};
use anyhow::Result;
use rusqlite::Connection;
use std::collections::HashMap;
use futures::{stream, StreamExt};
use std::path::Path;
use image::{io::Reader as ImageReader, imageops};
use tokio::task::block_in_place;
use rayon::prelude::*;
use std::io::Cursor;
mod util;
use util::CONFIG;
#[derive(Serialize, Deserialize)]
struct InferenceServerConfig {
model: String,
embedding_size: usize,
batch: usize,
image_size: usize
}
struct Index {
vectors: faiss::FlatIndex // we need the index to implement Send, which an arbitrary boxed one might not
}
async fn build_index() -> Result<Index> {
let mut conn = block_in_place(|| Connection::open(&CONFIG.db_path))?;
block_in_place(|| conn.execute("CREATE TABLE IF NOT EXISTS files (
filename TEXT PRIMARY KEY,
modtime REAL NOT NULL,
embedding_vector BLOB NOT NULL
)", ()))?;
let SIZE = 1024; // TODO
let IMAGE_SIZE = 384; // TODO
let BS = 32;
let mut files = HashMap::new();
let mut new_files: HashMap<String, (i64, Option<Vec<u8>>)> = HashMap::new();
let mut vectors = faiss::index_factory(SIZE, "Flat", faiss::MetricType::InnerProduct)?.into_flat()?;
block_in_place(|| -> Result<()> {
let mut stmt = conn.prepare_cached("SELECT filename, modtime FROM files")?;
let mut rows = stmt.query([])?;
while let Some(row) = rows.next()? {
let filename: String = row.get(0)?;
let modtime: i64 = row.get(1)?;
files.insert(filename, modtime);
}
for entry in walkdir::WalkDir::new(&CONFIG.images_path).follow_links(true) {
let entry = entry?;
if entry.file_type().is_file() {
let metadata = entry.metadata()?;
let modtime = metadata.modified()?.duration_since(std::time::UNIX_EPOCH)?.as_secs() as i64;
let filename = entry.path().strip_prefix(&CONFIG.images_path)?.to_string_lossy().to_string();
match files.get(&filename) {
Some(old_modtime) if *old_modtime < modtime => new_files.insert(filename, (modtime, None)),
None => new_files.insert(filename, (modtime, None)),
_ => None
};
}
}
Ok(())
})?;
let (itx, mut irx) = tokio::sync::mpsc::channel(BS * 2);
let new_files_ = new_files.clone();
let image_reader_task = tokio::task::spawn_blocking(move || {
new_files_.par_iter().try_for_each(|(filename, _)| -> Result<()> {
let mut path = Path::new(&CONFIG.images_path).to_path_buf();
path.push(filename);
let image = ImageReader::open(path)?.with_guessed_format()?.decode()?;
let resized = imageops::resize(&image.into_rgb8(), IMAGE_SIZE, IMAGE_SIZE, imageops::Lanczos3);
let mut bytes: Vec<u8> = Vec::new();
resized.write_to(&mut Cursor::new(&mut bytes), image::ImageOutputFormat::Png)?;
itx.blocking_send((filename.to_string(), bytes))?;
Ok(())
})
});
let dispatch_batch = |batch| {
};
let mut batch = vec![];
while let Some((filename, image)) = irx.recv().await {
if batch.len() == BS {
dispatch_batch(std::mem::replace(&mut batch, vec![]));
}
batch.push((filename, image));
}
if batch.len() > 0 {
dispatch_batch(std::mem::replace(&mut batch, vec![]));
}
// TODO switch to blocking
{
let tx = conn.transaction()?;
{
let mut stmt = tx.prepare_cached("INSERT OR REPLACE INTO files VALUES (?, ?, ?)")?;
for (filename, (modtime, embedding)) in new_files {
stmt.execute((filename, modtime, embedding.unwrap()))?;
}
}
tx.commit()?;
}
Ok(Index {
vectors: vectors
})
}
#[tokio::main]
async fn main() -> Result<()> {
if std::env::var_os("RUST_LOG").is_none() {
std::env::set_var("RUST_LOG", format!("meme-search-engine={}", CONFIG.log_level))
}
let notify = tokio::sync::Notify::new();
tokio::spawn(async move {
loop {
notify.notified().await;
let index = build_index().await.unwrap();
}
});
tracing_subscriber::fmt::init();
//let db = Arc::new(RwLock::new(DB::init().await?));
let app = Router::new()
.route("/", get(health))
.route("/", post(run_query));
//.layer(AddExtensionLayer::new(db));
let addr = CONFIG.listen_address.parse().unwrap();
tracing::info!("listening on {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await?;
Ok(())
}
async fn health() -> String {
format!("OK")
}
#[derive(Debug, Serialize, Deserialize)]
struct RawQuery {
text: Vec<String>,
images: Vec<String> // base64 (sorry)
}
async fn run_query(query: Json<RawQuery>) -> Json<Vec<(String, f32)>> {
tracing::info!("{:?}", query);
Json(vec![])
}

23
misc/src/util.rs Normal file
View File

@@ -0,0 +1,23 @@
use anyhow::{Result, Context};
use serde::{Serialize, Deserialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct Config {
pub log_level: String,
pub listen_address: String,
pub images_path: String,
pub db_path: String,
pub backend_url: String
}
fn load_config() -> Result<Config> {
use config::{Config, File};
let s = Config::builder()
.add_source(File::with_name("./config"))
.build().context("loading config")?;
Ok(s.try_deserialize().context("parsing config")?)
}
lazy_static::lazy_static! {
pub static ref CONFIG: Config = load_config().unwrap();
}

1
misc/top.json Normal file

File diff suppressed because one or more lines are too long

1
misc/top2.json Normal file

File diff suppressed because one or more lines are too long

19
misc/train_xgboost.py Normal file
View File

@@ -0,0 +1,19 @@
import numpy
import xgboost as xgb
import shared
trains, validations = shared.fetch_ratings()
ranker = xgb.XGBRanker(
tree_method="hist",
lambdarank_num_pair_per_sample=8,
objective="rank:ndcg",
lambdarank_pair_method="topk",
device="cuda"
)
flat_samples = [ sample for trainss in trains for sample in trainss ]
X = numpy.concatenate([ numpy.stack((meme1, meme2)) for meme1, meme2, rating in flat_samples ])
Y = numpy.concatenate([ numpy.stack((int(rating), int(1 - rating))) for meme1, meme2, rating in flat_samples ])
qid = numpy.concatenate([ numpy.stack((i, i)) for i in range(len(flat_samples)) ])
ranker.fit(X, Y, qid=qid, verbose=True)