mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-07-25 21:32:52 +00:00

query code

This commit is contained in:
osmarks 2025-01-18 17:09:00 +00:00
parent 63caba2746
commit a5a6e960bb
9 changed files with 683 additions and 329 deletions

Cargo.lock generated

@@ -194,6 +194,17 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "auto-const-array"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62f7df18977a1ee03650ee4b31b4aefed6d56bac188760b6e37610400fe8d4bb"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "autocfg"
version = "1.4.0"
@@ -1190,6 +1201,15 @@ dependencies = [
"slab",
]
[[package]]
name = "fxhash"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
dependencies = [
"byteorder",
]
[[package]]
name = "gemm"
version = "0.17.1"
@@ -1747,6 +1767,16 @@ dependencies = [
"syn 2.0.79",
]
[[package]]
name = "io-uring"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "595a0399f411a508feb2ec1e970a4a30c249351e30208960d58298de8660b0e5"
dependencies = [
"bitflags 1.3.2",
"libc",
]
[[package]]
name = "ipnet"
version = "2.10.1"
@@ -1956,10 +1986,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
dependencies = [
"autocfg",
"num_cpus",
"once_cell",
"rawpointer",
"thread-tree",
]
[[package]]
@@ -2034,14 +2061,17 @@ dependencies = [
"futures-util",
"half",
"hamming",
"http-body-util",
"hyper",
"image",
"itertools 0.13.0",
"json5",
"lazy_static",
"matrixmultiply",
"maud",
"memmap2",
"mimalloc",
"monoio",
"monoio-compat",
"ndarray",
"num_cpus",
"prometheus",
@@ -2076,6 +2106,15 @@ dependencies = [
"stable_deref_trait",
]
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]]
name = "mimalloc"
version = "0.1.43"
@@ -2150,6 +2189,51 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "monoio"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3bd0f8bcde87b1949f95338b547543fcab187bc7e7a5024247e359a5e828ba6a"
dependencies = [
"auto-const-array",
"bytes",
"fxhash",
"io-uring",
"libc",
"memchr",
"mio 0.8.11",
"monoio-macros",
"nix",
"pin-project-lite",
"socket2",
"tokio",
"windows-sys 0.48.0",
]
[[package]]
name = "monoio-compat"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aab377095e8792bb9829fac13c5cb111a0af4b733f4821581ca3f6e1eec5ae6c"
dependencies = [
"hyper",
"monoio",
"pin-project-lite",
"reusable-box-future",
"tokio",
]
[[package]]
name = "monoio-macros"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "176a5f5e69613d9e88337cf2a65e11135332b4efbcc628404a7c555e4452084c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "mp4parse"
version = "0.17.0"
@@ -2209,6 +2293,19 @@ version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]]
name = "nix"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
dependencies = [
"bitflags 1.3.2",
"cfg-if",
"libc",
"memoffset",
"pin-utils",
]
[[package]]
name = "nom"
version = "7.1.3"
@@ -2982,6 +3079,12 @@ dependencies = [
"windows-registry",
]
[[package]]
name = "reusable-box-future"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e0e61cd21fbddd85fbd9367b775660a01d388c08a61c6d2824af480b0309bb9"
[[package]]
name = "rgb"
version = "0.8.50"
@@ -3809,15 +3912,6 @@ dependencies = [
"syn 2.0.79",
]
[[package]]
name = "thread-tree"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffbd370cb847953a25954d9f63e14824a36113f8c72eecf6eccef5dc4b45d630"
dependencies = [
"crossbeam-channel",
]
[[package]]
name = "thread_local"
version = "1.1.8"


@@ -57,8 +57,11 @@ bitcode = "0.6"
simsimd = "6"
foldhash = "0.1"
memmap2 = "0.9"
matrixmultiply = { version = "0.3", features = ["threading"] }
candle-core = "0.8"
monoio = "0.2"
hyper = "1"
monoio-compat = { version = "0.2", features = ["hyper"] }
http-body-util = "0.1"
[[bin]]
name = "reddit-dump"


@@ -9,12 +9,18 @@ import time
from tqdm import tqdm
import sys
from collections import defaultdict
import msgpack
from model import Config, BradleyTerry
import shared
batch_size = 128
num_pairs = batch_size * 1024
def fetch_files_with_timestamps():
csr = shared.db.execute("SELECT filename, embedding, timestamp FROM files WHERE embedding IS NOT NULL")
x = [ (row[0], numpy.frombuffer(row[1], dtype="float16").copy(), row[2]) for row in csr.fetchall() ]
csr.close()
return x
batch_size = 2048
device = "cuda" device = "cuda"
config = Config( config = Config(
@ -27,36 +33,42 @@ config = Config(
dropout=0.1 dropout=0.1
) )
model = BradleyTerry(config) model = BradleyTerry(config)
model.eval()
modelc, _ = shared.checkpoint_for(int(sys.argv[1])) modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
model.load_state_dict(torch.load(modelc)) model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters()) params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters") print(f"{params/1e6:.1f}M parameters")
print(model) print(model)
files = shared.fetch_all_files() for x in model.ensemble.models:
results = {} x.output.bias.data.fill_(0) # hack to match behaviour of cut-down implementation
results = defaultdict(list)
model.eval() model.eval()
files = fetch_files_with_timestamps()
with torch.inference_mode(): with torch.inference_mode():
for bstart in tqdm(range(0, len(files), batch_size)): for bstart in tqdm(range(0, len(files), batch_size)):
batch = files[bstart:bstart + batch_size] batch = files[bstart:bstart + batch_size]
filenames = [ f1 for f1, e1 in batch ] timestamps = [ t1 for f1, e1, t1 in batch ]
embs = torch.stack([ torch.Tensor(e1).to(config.dtype) for f1, e1 in batch ]) embs = torch.stack([ torch.Tensor(e1).to(config.dtype) for f1, e1, t1 in batch ])
inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device) inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
scores = model.ensemble(inputs).median(dim=0).values.cpu().numpy() scores = model.ensemble(inputs).mean(dim=0).cpu().numpy()
for sr in scores:
for i, s in enumerate(sr):
results[i].append(s)
# add an extra timestamp channel
results[config.output_channels].extend(timestamps)
cdfs = []
# we want to encode scores in one byte, and 255/0xFF is reserved for "greater than maximum bucket"
cdf_bins = 255
for i, s in results.items():
quantiles = numpy.linspace(0, 1, cdf_bins)
cdf = numpy.quantile(numpy.array(s), quantiles)
print(cdf)
cdfs.append(cdf.tolist())
channel = int(sys.argv[2])
percentile = float(sys.argv[3])
with open("cdfs.msgpack", "wb") as f:
msgpack.pack(cdfs, f)
output_pairs = int(sys.argv[4])
mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
select_from = top[:int(len(top) * percentile)]
out = []
for _ in range(output_pairs):
# dummy score for compatibility with existing code
out.append(((random.choice(select_from)[0], random.choice(select_from)[0]), 0))
with open("top.json", "w") as f:
json.dump(out, f)
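(For reference: the per-channel CDFs written to cdfs.msgpack are later inverted by the index builder to squash each raw score into one byte. A minimal sketch of that mapping, with an illustrative helper name not present in the commit, is a binary search over one channel's 255 ascending quantiles, with 0xFF reserved for values above the last bucket.)

fn cdf_bucket(cdf: &[f32], score: f32) -> u8 {
    // `cdf` holds 255 ascending quantiles for one channel, so the resulting
    // index is always in 0..=255; 255 means "above the highest quantile".
    match cdf.binary_search_by(|x| x.partial_cmp(&score).unwrap()) {
        Ok(i) | Err(i) => i as u8,
    }
}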


@@ -10,11 +10,12 @@ CREATE TABLE IF NOT EXISTS files (
title TEXT NOT NULL,
link TEXT NOT NULL,
embedding BLOB NOT NULL,
timestamp INTEGER NOT NULL,
UNIQUE (filename)
);
""")
with jsonlines.open("sample.jsonl") as reader:
for obj in reader:
shared.db.execute("INSERT INTO files (filename, title, link, embedding) VALUES (?, ?, ?, ?)", (obj["metadata"]["final_url"], obj["title"], f"https://reddit.com/r/{obj['subreddit']}/comments/{obj['id']}", sqlite3.Binary(np.array(obj["embedding"], dtype=np.float16).tobytes())))
shared.db.execute("INSERT OR REPLACE INTO files (filename, title, link, embedding, timestamp) VALUES (?, ?, ?, ?, ?)", (obj["metadata"]["final_url"], obj["title"], f"https://reddit.com/r/{obj['subreddit']}/comments/{obj['id']}", sqlite3.Binary(np.array(obj["embedding"], dtype=np.float16).tobytes()), obj["timestamp"]))
shared.db.commit()


@@ -1,4 +1,5 @@
use image::codecs::bmp::BmpEncoder;
use lazy_static::lazy_static;
use serde::{Serialize, Deserialize};
use std::borrow::Borrow;
use std::cell::RefCell;
@@ -10,6 +11,11 @@ use tracing::instrument;
use fast_image_resize::{Resizer, ResizeOptions, ResizeAlg};
use fast_image_resize::images::{Image, ImageRef};
use anyhow::Context;
use std::collections::HashMap;
use ndarray::ArrayBase;
use prometheus::{register_int_counter_vec, IntCounterVec};
use base64::prelude::*;
use std::future::Future;
std::thread_local! {
static RESIZER: RefCell<Resizer> = RefCell::new(Resizer::new());
@@ -22,7 +28,7 @@ pub struct InferenceServerConfig {
pub embedding_size: usize,
}
pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: &InferenceServerConfig, image: T) -> Result<Vec<u8>> {
// the model currently in use wants aspect ratio 1:1 regardless of input
// I think this was previously being handled in the CLIP server but that is slightly lossy
@@ -48,7 +54,7 @@ pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: I
}
pub async fn resize_for_embed<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
let resized = tokio::task::spawn_blocking(move || resize_for_embed_sync(config, image)).await??;
let resized = tokio::task::spawn_blocking(move || resize_for_embed_sync(&config, image)).await??;
Ok(resized)
}
@@ -163,5 +169,103 @@ pub struct IndexHeader {
pub count: u32,
pub dead_count: u32,
pub record_pad_size: usize,
pub quantizer: diskann::vector::ProductQuantizer
pub quantizer: diskann::vector::ProductQuantizer,
pub descriptor_cdfs: Vec<Vec<f32>>
}
#[derive(Serialize, Deserialize)]
pub struct FrontendInit {
pub n_total: u64,
pub predefined_embedding_names: Vec<String>,
pub d_emb: usize
}
pub type EmbeddingVector = Vec<f32>;
#[derive(Debug, Serialize)]
pub struct QueryResult {
pub matches: Vec<(f32, String, String, u64, Option<(u32, u32)>)>,
pub formats: Vec<String>,
pub extensions: HashMap<String, String>,
}
#[derive(Debug, Deserialize)]
pub struct QueryTerm {
pub embedding: Option<EmbeddingVector>,
pub image: Option<String>,
pub text: Option<String>,
pub predefined_embedding: Option<String>,
pub weight: Option<f32>,
}
#[derive(Debug, Deserialize)]
pub struct QueryRequest {
pub terms: Vec<QueryTerm>,
pub k: Option<usize>,
#[serde(default)]
pub include_video: bool
}
lazy_static::lazy_static! {
static ref TERMS_COUNTER: IntCounterVec = register_int_counter_vec!("mse_terms", "terms used in queries, by type", &["type"]).unwrap();
}
pub async fn get_total_embedding<A: Future<Output = Result<Vec<Vec<u8>>>>, B: Future<Output = Result<serde_bytes::ByteBuf>>, S: Clone, T: Clone, F: Fn(EmbeddingRequest, S) -> A, G: Fn(Vec<u8>, T) -> B>(terms: &Vec<QueryTerm>, ic: &InferenceServerConfig, query_server: F, resize_image: G, predefined_embeddings: &HashMap<String, ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<[usize; 1]>>>, image_state: T, query_state: S) -> Result<Vec<f32>> {
let mut total_embedding = ndarray::Array::from(vec![0.0; ic.embedding_size]);
let mut image_batch = Vec::new();
let mut image_weights = Vec::new();
let mut text_batch = Vec::new();
let mut text_weights = Vec::new();
for term in terms {
if let Some(image) = &term.image {
TERMS_COUNTER.get_metric_with_label_values(&["image"]).unwrap().inc();
let bytes = BASE64_STANDARD.decode(image)?;
image_batch.push(resize_image(bytes, image_state.clone()).await?);
image_weights.push(term.weight.unwrap_or(1.0));
}
if let Some(text) = &term.text {
TERMS_COUNTER.get_metric_with_label_values(&["text"]).unwrap().inc();
text_batch.push(text.clone());
text_weights.push(term.weight.unwrap_or(1.0));
}
if let Some(embedding) = &term.embedding {
TERMS_COUNTER.get_metric_with_label_values(&["embedding"]).unwrap().inc();
let weight = term.weight.unwrap_or(1.0);
for (i, value) in embedding.iter().enumerate() {
total_embedding[i] += value * weight;
}
}
if let Some(name) = &term.predefined_embedding {
let embedding = predefined_embeddings.get(name).context("name invalid")?;
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
}
}
let mut batches = vec![];
if !image_batch.is_empty() {
batches.push(
(EmbeddingRequest::Images {
images: image_batch
}, image_weights)
);
}
if !text_batch.is_empty() {
batches.push(
(EmbeddingRequest::Text {
text: text_batch,
}, text_weights)
);
}
for (batch, weights) in batches {
let embs: Vec<Vec<u8>> = query_server(batch, query_state.clone()).await?;
for (emb, weight) in embs.into_iter().zip(weights) {
total_embedding += &(ndarray::Array::from_vec(decode_fp16_buffer(&emb)) * weight);
}
}
Ok(total_embedding.to_vec())
}


@@ -66,7 +66,9 @@ struct CLIArguments {
#[argh(option, short='M', description="score model path")]
score_model: Option<String>,
#[argh(option, short='G', description="GPU (CUDA) device to use")]
gpu: Option<usize>
gpu: Option<usize>,
#[argh(option, description="descriptor CDFs")]
cdfs: Option<String>,
}
#[derive(Clone, Deserialize, Serialize, Debug)]
@@ -301,6 +303,13 @@ fn main() -> Result<()> {
None
};
let cdfs = if let Some(cdfs) = &args.cdfs {
let data = fs::read(cdfs).context("read cdfs")?;
Some(rmp_serde::from_read::<_, Vec<Vec<f32>>>(&data[..]).context("decode cdfs")?)
} else {
None
};
let mut output_file = args.output_embeddings.map(|x| fs::File::create(x).context("create output file")).transpose()?;
let mut i: u64 = 0;
@@ -436,17 +445,33 @@ fn main() -> Result<()> {
let codes = quantizer.quantize_batch(&batch_embeddings);
let score_model = score_model.as_ref().context("score model needed to output index")?;
let cdfs = cdfs.as_ref().context("score model CDFs needed to output index")?;
let scores = score_model.score_batch(&batch_embeddings)?;
for (i, (x, _embedding)) in batch.into_iter().enumerate() {
let (vertices, shards) = read_out_vertices(count + i as u32)?; // TODO: could parallelize this given the batching
let mut entry_scores = scores[i..(i + score_model.output_channels)].to_vec();
entry_scores.push(x.timestamp as f32); // seconds since epoch, so precision issues aren't awful
for (index, score) in entry_scores.iter().enumerate() {
// binary search CDF to invert
let cdf_bucket: u8 = match cdfs[index].binary_search_by(|x| x.partial_cmp(score).unwrap()) {
Ok(x) => x.try_into().unwrap(),
Err(x) => x.try_into().unwrap()
};
// write score descriptor to descriptors file
index_output_file.2.write_all(&[cdf_bucket])?;
}
let mut entry = PackedIndexEntry {
id: count + i as u32,
vertices,
vector: x.embedding.chunks_exact(2).map(|x| u16::from_le_bytes([x[0], x[1]])).collect(),
timestamp: x.timestamp,
dimensions: x.metadata.dimension,
scores: scores[i..(i + score_model.output_channels)].to_vec(),
scores: entry_scores,
url: x.metadata.final_url,
shards
};
@@ -504,7 +529,8 @@ fn main() -> Result<()> {
count: count as u32,
record_pad_size: RECORD_PAD_SIZE,
dead_count,
quantizer: pq_codec.unwrap()
quantizer: pq_codec.unwrap(),
descriptor_cdfs: cdfs.unwrap(),
};
file.write_all(rmp_serde::to_vec_named(&header)?.as_slice())?;
}
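(Note: the descriptor bytes are written in record order, one byte per score channel plus one for the timestamp, so a consumer can index them as descriptors[id * n_descriptors + channel], which is what the query side does. A tiny illustrative accessor, not part of this commit:)

fn descriptor_at(descriptors: &[u8], n_descriptors: usize, id: u32, channel: usize) -> u8 {
    // flat layout: record 0's descriptors, then record 1's, and so on
    descriptors[id as usize * n_descriptors + channel]
}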


@@ -13,7 +13,7 @@ use axum::{
Router,
http::StatusCode
};
use common::resize_for_embed_sync;
use common::{resize_for_embed_sync, FrontendInit};
use compact_str::CompactString;
use image::RgbImage;
use image::{imageops::FilterType, ImageReader, DynamicImage, ImageFormat};
@@ -24,7 +24,6 @@ use sqlx::{sqlite::SqliteConnectOptions, SqlitePool};
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::task::JoinHandle;
use walkdir::WalkDir;
use base64::prelude::*;
use faiss::{ConcurrentIndex, Index};
use futures_util::stream::{StreamExt, TryStreamExt};
use tokio_stream::wrappers::ReceiverStream;
@@ -32,20 +31,19 @@ use tower_http::cors::CorsLayer;
use faiss::index::scalar_quantizer;
use lazy_static::lazy_static;
use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
use ndarray::ArrayBase;
use tracing::instrument;
use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine};
mod ocr;
mod common;
mod video_reader;
use crate::ocr::scan_image;
use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server, decode_fp16_buffer};
use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server, decode_fp16_buffer, QueryRequest, QueryResult, EmbeddingVector};
lazy_static! {
static ref RELOADS_COUNTER: IntCounter = register_int_counter!("mse_reloads", "reloads executed").unwrap();
static ref QUERIES_COUNTER: IntCounter = register_int_counter!("mse_queries", "queries executed").unwrap();
static ref TERMS_COUNTER: IntCounterVec = register_int_counter_vec!("mse_terms", "terms used in queries, by type", &["type"]).unwrap();
static ref IMAGES_LOADED_COUNTER: IntCounter = register_int_counter!("mse_loads", "images loaded by ingest process").unwrap();
static ref IMAGES_LOADED_ERROR_COUNTER: IntCounter = register_int_counter!("mse_load_errors", "image load fails by ingest process").unwrap();
static ref VIDEOS_LOADED_COUNTER: IntCounter = register_int_counter!("mse_video_loads", "video loaded by ingest process").unwrap();
@@ -81,6 +79,13 @@ struct Config {
video_frame_interval: f32
}
#[derive(Debug, Clone)]
struct WConfig {
backend: InferenceServerConfig,
service: Config,
predefined_embeddings: HashMap<String, ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<[usize; 1]>>>
}
#[derive(Debug)]
struct IIndex {
vectors: scalar_quantizer::ScalarQuantizerIndexImpl,
@@ -147,13 +152,6 @@ struct FileRecord {
needs_metadata: bool
}
#[derive(Debug, Clone)]
struct WConfig {
backend: InferenceServerConfig,
service: Config,
predefined_embeddings: HashMap<String, ArrayBase<ndarray::OwnedRepr<f32>, ndarray::prelude::Dim<[usize; 1]>>>
}
#[derive(Debug)]
struct LoadedImage {
image: Arc<DynamicImage>,
@@ -387,7 +385,7 @@ async fn load_image(record: FileRecord, to_embed_tx: mpsc::Sender<EmbeddingInput
let mut last_metadata = None;
let callback = |frame: RgbImage| {
let frame: Arc<DynamicImage> = Arc::new(frame.into());
let embed_buf = resize_for_embed_sync(config.backend.clone(), frame.clone())?;
let embed_buf = resize_for_embed_sync(&config.backend, frame.clone())?;
let filename = Filename::VideoFrame(filename.clone(), i);
to_embed_tx.blocking_send(EmbeddingInput {
image: embed_buf,
@@ -893,32 +891,6 @@ async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
Ok(index)
}
type EmbeddingVector = Vec<f32>;
#[derive(Debug, Serialize)]
struct QueryResult {
matches: Vec<(f32, String, String, u64, Option<(u32, u32)>)>,
formats: Vec<String>,
extensions: HashMap<String, String>,
}
#[derive(Debug, Deserialize)]
struct QueryTerm {
embedding: Option<EmbeddingVector>,
image: Option<String>,
text: Option<String>,
predefined_embedding: Option<String>,
weight: Option<f32>,
}
#[derive(Debug, Deserialize)]
struct QueryRequest {
terms: Vec<QueryTerm>,
k: Option<usize>,
#[serde(default)]
include_video: bool
}
#[instrument(skip(index))]
async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult> {
let result = index.vectors.search(&query, k as usize)?;
@@ -958,65 +930,22 @@ async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bo
#[instrument(skip(config, client, index))]
async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IIndex, req: Json<QueryRequest>) -> Result<Response<Body>> {
let embedding = common::get_total_embedding(
&req.terms,
&config.backend,
|batch, (config, client)| async move {
query_clip_server(&client, &config.service.clip_server, "", batch).await
},
|image, config| async move {
let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&image))?);
Ok(serde_bytes::ByteBuf::from(resize_for_embed(config.backend.clone(), image).await?))
},
&config.clone().predefined_embeddings,
config.clone(),
(config.clone(), client.clone())).await?;
let mut total_embedding = ndarray::Array::from(vec![0.0; config.backend.embedding_size]);
let mut image_batch = Vec::new();
let mut image_weights = Vec::new();
let mut text_batch = Vec::new();
let mut text_weights = Vec::new();
for term in &req.terms {
if let Some(image) = &term.image {
TERMS_COUNTER.get_metric_with_label_values(&["image"]).unwrap().inc();
let bytes = BASE64_STANDARD.decode(image)?;
let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&bytes))?);
image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(config.backend.clone(), image).await?));
image_weights.push(term.weight.unwrap_or(1.0));
}
if let Some(text) = &term.text {
TERMS_COUNTER.get_metric_with_label_values(&["text"]).unwrap().inc();
text_batch.push(text.clone());
text_weights.push(term.weight.unwrap_or(1.0));
}
if let Some(embedding) = &term.embedding {
TERMS_COUNTER.get_metric_with_label_values(&["embedding"]).unwrap().inc();
let weight = term.weight.unwrap_or(1.0);
for (i, value) in embedding.iter().enumerate() {
total_embedding[i] += value * weight;
}
}
if let Some(name) = &term.predefined_embedding {
let embedding = config.predefined_embeddings.get(name).context("name invalid")?;
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
}
}
let mut batches = vec![];
if !image_batch.is_empty() {
batches.push(
(EmbeddingRequest::Images {
images: image_batch
}, image_weights)
);
}
if !text_batch.is_empty() {
batches.push(
(EmbeddingRequest::Text {
text: text_batch,
}, text_weights)
);
}
for (batch, weights) in batches {
let embs: Vec<Vec<u8>> = query_clip_server(&client, &config.service.clip_server, "/", batch).await?;
for (emb, weight) in embs.into_iter().zip(weights) {
total_embedding += &(ndarray::Array::from_vec(decode_fp16_buffer(&emb)) * weight);
}
}
let k = req.k.unwrap_or(1000);
let qres = query_index(index, total_embedding.to_vec(), k, req.include_video).await?;
let qres = query_index(index, embedding, k, req.include_video).await?;
let mut extensions = HashMap::new();
for (k, v) in image_formats(&config.service) {
@@ -1030,13 +959,6 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
}).into_response())
}
#[derive(Serialize, Deserialize)]
struct FrontendInit {
n_total: u64,
predefined_embedding_names: Vec<String>,
d_emb: usize
}
#[tokio::main]
async fn main() -> Result<()> {
console_subscriber::init();

File diff suppressed because one or more lines are too long


@@ -1,25 +1,32 @@
use anyhow::{bail, Context, Result};
use diskann::vector::scale_dot_result_f64;
use serde::{Serialize, Deserialize};
use anyhow::{Context, Result};
use lazy_static::lazy_static;
use monoio::fs;
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
use std::os::unix::prelude::FileExt;
use std::path::PathBuf;
use std::fs;
use base64::Engine;
use argh::FromArgs;
use chrono::{TimeZone, Utc, DateTime};
use itertools::Itertools;
use foldhash::{HashSet, HashSetExt};
use half::f16;
use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, ProductQuantizer, DistanceLUT, scale_dot_result}};
use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, QueryLUT, scale_dot_result, scale_dot_result_f64}};
use simsimd::SpatialSimilarity;
use memmap2::{Mmap, MmapOptions};
use std::rc::Rc;
use monoio::net::{TcpListener, TcpStream};
use monoio::io::IntoPollIo;
use hyper::{body::{Body, Bytes, Incoming, Frame}, server::conn::http1, Method, Request, Response, StatusCode};
use http_body_util::{BodyExt, Empty, Full};
use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
use std::pin::Pin;
use std::future::Future;
use serde::Serialize;
use std::str::FromStr;
use std::collections::HashMap;
mod common;
use common::{PackedIndexEntry, IndexHeader};
use common::{resize_for_embed_sync, FrontendInit, IndexHeader, InferenceServerConfig, PackedIndexEntry, QueryRequest, QueryResult};
#[derive(FromArgs)]
#[derive(FromArgs, Clone)]
#[argh(description="Query disk index")]
struct CLIArguments {
#[argh(positional)]
@@ -35,77 +42,132 @@ struct CLIArguments {
#[argh(option, short='L', description="search list size")]
search_list_size: Option<usize>,
#[argh(switch, description="always use full-precision vectors (slow)")]
disable_pq: bool
disable_pq: bool,
#[argh(option, short='l', description="listen address")]
listen_address: Option<String>,
#[argh(option, short='c', description="clip server")]
clip_server: Option<String>,
}
fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> {
let offset = id as usize * header.record_pad_size;
let mut buf = vec![0; header.record_pad_size as usize];
data_file.read_exact_at(&mut buf, offset as u64)?;
lazy_static! {
static ref QUERIES_COUNTER: IntCounter = register_int_counter!("mse_queries", "queries executed").unwrap();
static ref TERMS_COUNTER: IntCounterVec = register_int_counter_vec!("mse_terms", "terms used in queries, by type", &["type"]).unwrap();
static ref NODE_READS: IntCounter = register_int_counter!("mse_node_reads", "graph nodes read").unwrap();
static ref PQ_COMPARISONS: IntCounter = register_int_counter!("mse_pq_comparisons", "product quantization comparisons").unwrap();
}
async fn read_node<'a>(id: u32, index: Rc<Index>) -> Result<PackedIndexEntry> {
let offset = id as usize * index.header.record_pad_size;
let buf = vec![0; index.header.record_pad_size as usize];
let (res, buf) = index.data_file.read_exact_at(buf, offset as u64).await;
res?;
NODE_READS.inc();
let len = u16::from_le_bytes(buf[0..2].try_into().unwrap()) as usize;
Ok(bitcode::decode(&buf[2..len+2])?)
}
fn read_pq_codes(id: u32, codes: &Mmap, buf: &mut Vec<u8>, pq_code_size: usize) {
let loc = (id as usize) * pq_code_size;
buf.extend(&codes[loc..loc+pq_code_size])
fn next_several_unvisited(s: &mut NeighbourBuffer, n: usize) -> Option<Vec<u32>> {
let mut result = Vec::new();
for _ in 0..n {
if let Some(neighbour) = s.next_unvisited() {
result.push(neighbour);
} else {
break;
}
}
if result.len() > 0 {
Some(result)
} else {
None
}
}
fn read_pq_codes(id: u32, index: Rc<Index>, buf: &mut Vec<u8>) {
let loc = (id as usize) * index.pq_code_size;
buf.extend(&index.pq_codes[loc..loc+index.pq_code_size])
}
struct Scratch {
visited_adjacent: HashSet<u32>,
visited: HashSet<u32>,
neighbour_buffer: NeighbourBuffer,
neighbour_pre_buffer: Vec<u32>,
visited_list: Vec<(u32, i64, String, Vec<u32>, Vec<f32>)>
}
struct IndexRef<'a> {
data_file: &'a mut fs::File,
pq_codes: &'a Mmap,
header: &'a IndexHeader,
pq_code_size: usize
struct Index {
data_file: fs::File,
pq_codes: Mmap,
header: Rc<IndexHeader>,
pq_code_size: usize,
descriptors: Mmap,
n_descriptors: usize
}
fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &DistanceLUT, index: IndexRef, disable_pq: bool) -> Result<(usize, usize)> {
scratch.visited.clear();
struct DescriptorScales(Vec<f32>);
fn descriptor_product(index: Rc<Index>, scales: &DescriptorScales, neighbour: u32) -> i64 {
let mut result = 0;
// effectively an extra part of the vector to dot product
for (j, d) in scales.0.iter().enumerate() {
result += scale_dot_result(d * index.descriptors[neighbour as usize * index.n_descriptors + j] as f32);
}
result
}
async fn greedy_search<'a>(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &QueryLUT, descriptor_scales: &DescriptorScales, index: Rc<Index>, disable_pq: bool, beamwidth: usize) -> Result<(usize, usize)> {
scratch.visited_adjacent.clear();
scratch.neighbour_buffer.clear();
scratch.visited_list.clear();
scratch.visited.clear();
let mut cmps = 0;
let mut pq_cmps = 0;
let node = read_node(start, index.data_file, index.header)?;
let vector = bytemuck::cast_slice(&node.vector);
scratch.neighbour_buffer.insert(start, fast_dot_noprefetch(query, &vector));
scratch.visited.insert(start);
while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
scratch.neighbour_buffer.insert(start, 0);
scratch.visited_adjacent.insert(start);
while let Some(pts) = next_several_unvisited(&mut scratch.neighbour_buffer, beamwidth) {
//println!("pt {} {:?}", pt, graph.out_neighbours(pt));
scratch.neighbour_pre_buffer.clear();
let node = read_node(pt, index.data_file, index.header)?;
let vector = bytemuck::cast_slice(&node.vector);
let distance = fast_dot_noprefetch(query, &vector);
cmps += 1;
scratch.visited_list.push((pt, distance, node.url, node.shards, node.scores));
for &neighbour in node.vertices.iter() {
if scratch.visited.insert(neighbour) {
scratch.neighbour_pre_buffer.push(neighbour);
let mut join_handles = Vec::with_capacity(pts.len());
for &pt in pts.iter() {
join_handles.push(monoio::spawn(read_node(pt, index.clone())));
}
for handle in join_handles {
let index = index.clone();
let node = handle.await?;
let vector = bytemuck::cast_slice(&node.vector);
let distance = fast_dot_noprefetch(query, &vector);
cmps += 1;
if scratch.visited.insert(node.id) {
scratch.visited_list.push((node.id, distance, node.url, node.shards, node.scores));
};
for &neighbour in node.vertices.iter() {
if scratch.visited_adjacent.insert(neighbour) {
scratch.neighbour_pre_buffer.push(neighbour);
}
} }
}
let mut pq_codes = Vec::with_capacity(index.pq_code_size * scratch.neighbour_pre_buffer.len());
for &neighbour in scratch.neighbour_pre_buffer.iter() {
read_pq_codes(neighbour, index.pq_codes, &mut pq_codes, index.pq_code_size);
}
let approx_scores = index.header.quantizer.asymmetric_dot_product(&query_preprocessed, &pq_codes);
for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
if disable_pq {
//let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
let node = read_node(neighbour, index.data_file, index.header)?;
let vector = bytemuck::cast_slice(&node.vector);
let distance = fast_dot_noprefetch(query, &vector);
scratch.neighbour_buffer.insert(neighbour, distance);
} else {
let mut pq_codes = Vec::with_capacity(index.pq_code_size * scratch.neighbour_pre_buffer.len());
for &neighbour in scratch.neighbour_pre_buffer.iter() {
read_pq_codes(neighbour, index.clone(), &mut pq_codes);
}
let mut approx_scores = index.header.quantizer.asymmetric_dot_product(&query_preprocessed, &pq_codes);
for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
if disable_pq {
let node = read_node(neighbour, index.clone()).await?;
let vector = bytemuck::cast_slice(&node.vector);
let mut score = fast_dot_noprefetch(query, &vector);
score += descriptor_product(index.clone(), &descriptor_scales, neighbour);
scratch.neighbour_buffer.insert(neighbour, score);
} else {
approx_scores[i] += descriptor_product(index.clone(), &descriptor_scales, neighbour);
scratch.neighbour_buffer.insert(neighbour, approx_scores[i]);
pq_cmps += 1;
PQ_COMPARISONS.inc();
}
} }
}
}
@@ -124,17 +186,21 @@ fn summary_stats(ranks: &mut [usize]) {
const K: usize = 20;
fn main() -> Result<()> {
let args: CLIArguments = argh::from_env();
async fn evaluate(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
let mut top_k_ranks_best_shard = vec![];
let mut top_rank_best_shard = vec![];
let mut pq_cmps = vec![];
let mut cmps = vec![];
let mut recall_total = 0;
let mut queries = vec![];
if let Some(query_vector_base64) = args.query_vector_base64 {
if let Some(query_vector_base64) = &args.query_vector_base64 {
let query_vector: Vec<f16> = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(query_vector_base64.as_bytes()).context("invalid base64")?);
queries.push(query_vector);
}
if let Some(query_vector_file) = args.query_vector_file {
if let Some(query_vector_file) = &args.query_vector_file {
let query_vectors = fs::read(query_vector_file)?;
let query_vectors = fs::read(query_vector_file).await?;
queries.extend(common::chunk_fp16_buffer(&query_vectors).chunks(1152).map(|x| x.to_vec()).collect::<Vec<_>>());
}
@@ -142,30 +208,12 @@ fn main() -> Result<()> {
queries.truncate(n);
}
let index_path = PathBuf::from(&args.index_path);
let header: IndexHeader = rmp_serde::from_read(BufReader::new(fs::File::open(index_path.join("index.msgpack"))?))?;
let mut data_file = fs::File::open(index_path.join("index.bin"))?;
let pq_codes_file = fs::File::open(index_path.join("index.pq-codes.bin"))?;
let pq_codes = unsafe {
// This is unsafe because other processes could in principle edit the mmap'd file.
// It would be annoying to do anything about this possibility, so ignore it.
MmapOptions::new().populate().map(&pq_codes_file)?
};
println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
let mut top_k_ranks_best_shard = vec![];
let mut top_rank_best_shard = vec![];
let mut pq_cmps = vec![];
let mut cmps = vec![];
let mut recall_total = 0;
for query_vector in queries.iter() {
let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
let query_preprocessed = header.quantizer.preprocess_query(&query_vector_fp32);
let query_preprocessed = index.header.quantizer.preprocess_query(&query_vector_fp32);
// TODO slightly dubious
let selected_shard = header.shards.iter().position_max_by_key(|x| {
let selected_shard = index.header.shards.iter().position_max_by_key(|x| {
scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query_vector_fp32).unwrap())
}).unwrap();
@@ -175,8 +223,8 @@ fn main() -> Result<()> {
let mut matches = vec![];
// brute force scan
for i in 0..header.count {
for i in 0..index.header.count {
let node = read_node(i, &mut data_file, &header)?;
let node = read_node(i, index.clone()).await?;
//println!("{} {}", i, node.url);
let vector = bytemuck::cast_slice(&node.vector);
matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards));
@@ -192,23 +240,22 @@ fn main() -> Result<()> {
let mut top_ranks = vec![usize::MAX; K];
for shard in 0..header.shards.len() {
let selected_start = header.shards[shard].1;
for shard in 0..index.header.shards.len() {
let selected_start = index.header.shards[shard].1;
let beamwidth = 3;
let mut scratch = Scratch {
visited: HashSet::new(),
neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
neighbour_pre_buffer: Vec::new(),
visited_list: Vec::new()
visited_list: Vec::new(),
visited_adjacent: HashSet::new()
};
//let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng);
let cmps_result = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
data_file: &mut data_file,
header: &header,
pq_codes: &pq_codes,
pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
}, args.disable_pq)?;
let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]);
let cmps_result = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, &descriptor_scales, index.clone(), args.disable_pq, beamwidth).await?;
// slightly dubious because this is across shards
pq_cmps.push(cmps_result.1);
@@ -255,3 +302,240 @@ fn main() -> Result<()> {
Ok(())
}
pub async fn query_clip_server<I, O>(base_url: &str, path: &str, data: Option<I>) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned {
// TODO connection pool or something
// also this won't work over TLS
let url = hyper::Uri::from_str(base_url)?;
let stream = TcpStream::connect(format!("{}:{}", url.host().unwrap(), url.port_u16().unwrap_or(80))).await?;
let io = monoio_compat::hyper::MonoioIo::new(stream.into_poll_io()?);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
monoio::spawn(async move {
if let Err(err) = conn.await {
tracing::error!("connection failed: {:?}", err);
}
});
let authority = url.authority().unwrap().clone();
let req = Request::builder()
.uri(path)
.header(hyper::header::HOST, authority.as_str())
.header(hyper::header::CONTENT_TYPE, "application/msgpack");
let res = match data {
Some(data) => sender.send_request(req.method(Method::POST).body(Full::new(Bytes::from(rmp_serde::to_vec_named(&data)?)))?).await?,
None => sender.send_request(req.method(Method::GET).body(Full::new(Bytes::from("")))?).await?
};
if res.status() != StatusCode::OK {
return Err(anyhow::anyhow!("unexpected status code: {}", res.status()));
}
let data = res.collect().await?.to_bytes();
let result: O = rmp_serde::from_slice(&data)?;
Ok(result)
}
#[derive(Clone)]
struct Service {
index: Rc<Index>,
inference_server_config: Rc<InferenceServerConfig>,
args: Rc<CLIArguments>
}
impl hyper::service::Service<Request<Incoming>> for Service {
type Response = Response<Full<Bytes>>;
type Error = anyhow::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn call(&self, req: Request<Incoming>) -> Self::Future {
let index = self.index.clone();
let args = self.args.clone();
let inference_server_config = self.inference_server_config.clone();
Box::pin(async move {
let mut body = match (req.method(), req.uri().path()) {
(&Method::GET, "/") => Response::new(Full::new(Bytes::from(serde_json::to_vec(&FrontendInit {
n_total: (index.header.count - index.header.dead_count) as u64,
d_emb: index.header.quantizer.n_dims,
predefined_embedding_names: vec![]
})?))),
(&Method::POST, "/") => {
let upper = req.body().size_hint().upper().unwrap_or(u64::MAX);
if upper > 1<<23 {
let mut resp = Response::new(Full::new(Bytes::from("Body too big")));
*resp.status_mut() = hyper::StatusCode::PAYLOAD_TOO_LARGE;
return Ok(resp);
}
let whole_body = req.collect().await?.to_bytes();
let body: QueryRequest = serde_json::from_slice(&whole_body)?;
let query = common::get_total_embedding(
&body.terms,
&*inference_server_config,
|batch, _config| {
query_clip_server(args.clip_server.as_ref().unwrap(), "/", Some(batch))
},
|image, config| async move {
let image = image::load_from_memory(&image)?;
Ok(serde_bytes::ByteBuf::from(resize_for_embed_sync(&*config, image)?))
},
&std::collections::HashMap::new(),
inference_server_config.clone(),
()
).await?;
let selected_shard = index.header.shards.iter().position_max_by_key(|x| {
scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query).unwrap())
}).unwrap();
let selected_start = index.header.shards[selected_shard].1;
let beamwidth = 3;
let mut scratch = Scratch {
visited: HashSet::new(),
neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
neighbour_pre_buffer: Vec::new(),
visited_list: Vec::new(),
visited_adjacent: HashSet::new()
};
let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]);
let query_preprocessed = index.header.quantizer.preprocess_query(&query);
let query = query.iter().map(|x| half::f16::from_f32(*x)).collect::<Vec<f16>>();
let cmps_result = greedy_search(&mut scratch, selected_start, &query, &query_preprocessed, &descriptor_scales, index.clone(), args.disable_pq, beamwidth).await?;
scratch.visited_list.sort_by_key(|x| -x.1);
let matches = scratch.visited_list.drain(..).map(|(id, score, url, shards, scores)| (score as f32, url, String::new(), 0, None)).collect::<Vec<_>>();
let result = QueryResult {
formats: vec![],
extensions: HashMap::new(),
matches
};
let result = serde_json::to_vec(&result)?;
Response::new(Full::new(Bytes::from(result)))
},
(&Method::GET, "/metrics") => {
let mut buffer = Vec::new();
let encoder = prometheus::TextEncoder::new();
let metric_families = prometheus::gather();
encoder.encode(&metric_families, &mut buffer).unwrap();
Response::builder()
.header(hyper::header::CONTENT_TYPE, "text/plain; version=0.0.4")
.body(Full::new(Bytes::from(buffer))).unwrap()
},
(&Method::OPTIONS, "/") => {
Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Full::new(Bytes::from(""))).unwrap()
},
_ => Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found")))
.unwrap()
};
body.headers_mut().entry(hyper::header::CONTENT_TYPE).or_insert(hyper::header::HeaderValue::from_static("application/json"));
body.headers_mut().entry(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN).or_insert(hyper::header::HeaderValue::from_static("*"));
body.headers_mut().entry(hyper::header::ACCESS_CONTROL_ALLOW_METHODS).or_insert(hyper::header::HeaderValue::from_static("GET, POST, OPTIONS"));
body.headers_mut().entry(hyper::header::ACCESS_CONTROL_ALLOW_HEADERS).or_insert(hyper::header::HeaderValue::from_static("Content-Type"));
Result::<_, anyhow::Error>::Ok(body)
})
}
}
async fn get_backend_config(clip_server: &Option<String>) -> Result<InferenceServerConfig> {
loop {
match query_clip_server(clip_server.as_ref().unwrap(), "/config", Option::<()>::None).await {
Ok(config) => return Ok(config),
Err(err) => {
tracing::warn!("waiting for clip server: {}", err);
monoio::time::sleep(std::time::Duration::from_secs(1)).await;
}
};
}
}
async fn serve(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
let service = Service {
index,
inference_server_config: Rc::new(get_backend_config(&args.clip_server).await?),
args: Rc::new(args.clone())
};
let listener = TcpListener::bind(args.listen_address.as_ref().unwrap())?;
println!("Listening");
loop {
let (stream, _) = listener.accept().await?;
let stream_poll = monoio_compat::hyper::MonoioIo::new(stream.into_poll_io()?);
let service = service.clone();
monoio::spawn(async move {
// Handle the connection from the client using HTTP1 and pass any
// HTTP requests received on that connection to the `hello` function
if let Err(err) = http1::Builder::new()
.timer(monoio_compat::hyper::MonoioTimer)
.serve_connection(stream_poll, service)
.await
{
println!("Error serving connection: {:?}", err);
}
});
}
}
#[monoio::main(threads=1, enable_timer=true)]
async fn main() -> Result<()> {
let args: CLIArguments = argh::from_env();
let index_path = PathBuf::from(&args.index_path);
let header: IndexHeader = rmp_serde::from_slice(&fs::read(index_path.join("index.msgpack")).await?)?;
let header = Rc::new(header);
// contains graph structure, full-precision vectors, and bulk metadata
let data_file = fs::File::open(index_path.join("index.bin")).await?;
// contains product quantization codes
let pq_codes_file = fs::File::open(index_path.join("index.pq-codes.bin")).await?;
let pq_codes = unsafe {
// This is unsafe because other processes could in principle edit the mmap'd file.
// It would be annoying to do anything about this possibility, so ignore it.
MmapOptions::new().populate().map(&pq_codes_file)?
};
// contains metadata descriptors
let descriptors_file = fs::File::open(index_path.join("index.descriptor-codes.bin")).await?;
let descriptors = unsafe {
MmapOptions::new().populate().map(&descriptors_file)?
};
println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
let index = Rc::new(Index {
data_file,
header: header.clone(),
pq_codes,
pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
descriptors,
n_descriptors: header.descriptor_cdfs.len(),
});
if args.listen_address.is_some() {
serve(&args, index).await?;
} else {
evaluate(&args, index).await?;
}
Ok(())
}
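(For context: the new POST / endpoint accepts the same QueryRequest JSON shape as the axum-based main server. A hedged client-side sketch follows; reqwest, tokio and the listen address are assumptions for illustration, not part of this commit.)

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Hypothetical client; the address must match whatever was passed via -l/--listen-address.
    let body = serde_json::json!({
        "terms": [ { "text": "example query", "weight": 1.0 } ],
        "k": 10,
        "include_video": false
    });
    let res: serde_json::Value = reqwest::Client::new()
        .post("http://127.0.0.1:5601/")
        .json(&body)
        .send()
        .await?
        .json()
        .await?;
    // The response mirrors QueryResult: matches, formats, extensions.
    println!("{:?}", res["matches"]);
    Ok(())
}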