Mirror of https://github.com/osmarks/meme-search-engine.git (synced 2025-07-25 21:32:52 +00:00)

Commit a5a6e960bb: query code
Parent: 63caba2746
Cargo.lock (generated): 120 lines changed

@@ -194,6 +194,17 @@ version = "1.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
 
+[[package]]
+name = "auto-const-array"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "62f7df18977a1ee03650ee4b31b4aefed6d56bac188760b6e37610400fe8d4bb"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.79",
+]
+
 [[package]]
 name = "autocfg"
 version = "1.4.0"
@@ -1190,6 +1201,15 @@ dependencies = [
  "slab",
 ]
 
+[[package]]
+name = "fxhash"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
+dependencies = [
+ "byteorder",
+]
+
 [[package]]
 name = "gemm"
 version = "0.17.1"
@@ -1747,6 +1767,16 @@ dependencies = [
  "syn 2.0.79",
 ]
 
+[[package]]
+name = "io-uring"
+version = "0.6.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "595a0399f411a508feb2ec1e970a4a30c249351e30208960d58298de8660b0e5"
+dependencies = [
+ "bitflags 1.3.2",
+ "libc",
+]
+
 [[package]]
 name = "ipnet"
 version = "2.10.1"
@@ -1956,10 +1986,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
 dependencies = [
  "autocfg",
- "num_cpus",
- "once_cell",
  "rawpointer",
- "thread-tree",
 ]
 
 [[package]]
@@ -2034,14 +2061,17 @@ dependencies = [
  "futures-util",
  "half",
  "hamming",
+ "http-body-util",
+ "hyper",
  "image",
  "itertools 0.13.0",
  "json5",
  "lazy_static",
- "matrixmultiply",
  "maud",
  "memmap2",
  "mimalloc",
+ "monoio",
+ "monoio-compat",
  "ndarray",
  "num_cpus",
  "prometheus",
@@ -2076,6 +2106,15 @@ dependencies = [
  "stable_deref_trait",
 ]
 
+[[package]]
+name = "memoffset"
+version = "0.7.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
+dependencies = [
+ "autocfg",
+]
+
 [[package]]
 name = "mimalloc"
 version = "0.1.43"
@@ -2150,6 +2189,51 @@ dependencies = [
  "windows-sys 0.52.0",
 ]
 
+[[package]]
+name = "monoio"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3bd0f8bcde87b1949f95338b547543fcab187bc7e7a5024247e359a5e828ba6a"
+dependencies = [
+ "auto-const-array",
+ "bytes",
+ "fxhash",
+ "io-uring",
+ "libc",
+ "memchr",
+ "mio 0.8.11",
+ "monoio-macros",
+ "nix",
+ "pin-project-lite",
+ "socket2",
+ "tokio",
+ "windows-sys 0.48.0",
+]
+
+[[package]]
+name = "monoio-compat"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aab377095e8792bb9829fac13c5cb111a0af4b733f4821581ca3f6e1eec5ae6c"
+dependencies = [
+ "hyper",
+ "monoio",
+ "pin-project-lite",
+ "reusable-box-future",
+ "tokio",
+]
+
+[[package]]
+name = "monoio-macros"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "176a5f5e69613d9e88337cf2a65e11135332b4efbcc628404a7c555e4452084c"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.79",
+]
+
 [[package]]
 name = "mp4parse"
 version = "0.17.0"
@@ -2209,6 +2293,19 @@ version = "1.0.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
 
+[[package]]
+name = "nix"
+version = "0.26.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
+dependencies = [
+ "bitflags 1.3.2",
+ "cfg-if",
+ "libc",
+ "memoffset",
+ "pin-utils",
+]
+
 [[package]]
 name = "nom"
 version = "7.1.3"
@@ -2982,6 +3079,12 @@ dependencies = [
  "windows-registry",
 ]
 
+[[package]]
+name = "reusable-box-future"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e0e61cd21fbddd85fbd9367b775660a01d388c08a61c6d2824af480b0309bb9"
+
 [[package]]
 name = "rgb"
 version = "0.8.50"
@@ -3809,15 +3912,6 @@ dependencies = [
  "syn 2.0.79",
 ]
 
-[[package]]
-name = "thread-tree"
-version = "0.3.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ffbd370cb847953a25954d9f63e14824a36113f8c72eecf6eccef5dc4b45d630"
-dependencies = [
- "crossbeam-channel",
-]
-
 [[package]]
 name = "thread_local"
 version = "1.1.8"
Cargo.toml:

@@ -57,8 +57,11 @@ bitcode = "0.6"
 simsimd = "6"
 foldhash = "0.1"
 memmap2 = "0.9"
-matrixmultiply = { version = "0.3", features = ["threading"] }
 candle-core = "0.8"
+monoio = "0.2"
+hyper = "1"
+monoio-compat = { version = "0.2", features = ["hyper"] }
+http-body-util = "0.1"
 
 [[bin]]
 name = "reddit-dump"
@@ -9,12 +9,18 @@ import time
 from tqdm import tqdm
 import sys
 from collections import defaultdict
+import msgpack
 
 from model import Config, BradleyTerry
 import shared
 
-batch_size = 128
-num_pairs = batch_size * 1024
+def fetch_files_with_timestamps():
+    csr = shared.db.execute("SELECT filename, embedding, timestamp FROM files WHERE embedding IS NOT NULL")
+    x = [ (row[0], numpy.frombuffer(row[1], dtype="float16").copy(), row[2]) for row in csr.fetchall() ]
+    csr.close()
+    return x
+
+batch_size = 2048
 device = "cuda"
 
 config = Config(
@@ -27,36 +33,42 @@ config = Config(
     dropout=0.1
 )
 model = BradleyTerry(config)
+model.eval()
 modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
 model.load_state_dict(torch.load(modelc))
 params = sum(p.numel() for p in model.parameters())
 print(f"{params/1e6:.1f}M parameters")
 print(model)
 
-files = shared.fetch_all_files()
-results = {}
+for x in model.ensemble.models:
+    x.output.bias.data.fill_(0) # hack to match behaviour of cut-down implementation
+
+results = defaultdict(list)
 model.eval()
+
+files = fetch_files_with_timestamps()
+
 with torch.inference_mode():
     for bstart in tqdm(range(0, len(files), batch_size)):
         batch = files[bstart:bstart + batch_size]
-        filenames = [ f1 for f1, e1 in batch ]
-        embs = torch.stack([ torch.Tensor(e1).to(config.dtype) for f1, e1 in batch ])
+        timestamps = [ t1 for f1, e1, t1 in batch ]
+        embs = torch.stack([ torch.Tensor(e1).to(config.dtype) for f1, e1, t1 in batch ])
         inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
-        scores = model.ensemble(inputs).median(dim=0).values.cpu().numpy()
+        scores = model.ensemble(inputs).mean(dim=0).cpu().numpy()
+        for sr in scores:
+            for i, s in enumerate(sr):
+                results[i].append(s)
+        # add an extra timestamp channel
+        results[config.output_channels].extend(timestamps)
 
-channel = int(sys.argv[2])
-percentile = float(sys.argv[3])
-output_pairs = int(sys.argv[4])
-mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
-top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
-select_from = top[:int(len(top) * percentile)]
+cdfs = []
+# we want to encode scores in one byte, and 255/0xFF is reserved for "greater than maximum bucket"
+cdf_bins = 255
+for i, s in results.items():
+    quantiles = numpy.linspace(0, 1, cdf_bins)
+    cdf = numpy.quantile(numpy.array(s), quantiles)
+    print(cdf)
+    cdfs.append(cdf.tolist())
 
-out = []
-for _ in range(output_pairs):
-    # dummy score for compatibility with existing code
-    out.append(((random.choice(select_from)[0], random.choice(select_from)[0]), 0))
-
-with open("top.json", "w") as f:
-    json.dump(out, f)
+with open("cdfs.msgpack", "wb") as f:
+    msgpack.pack(cdfs, f)
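Note on the scheme set up here: each score channel gets 255 sorted quantile boundaries, so a raw score can later be stored in a single byte, with 255 (0xFF) reserved for values past the last boundary. A minimal sketch of that bucketing step, written in Rust to match the indexer changes later in this commit; score_to_bucket is a hypothetical helper name, the indexer does the equivalent inline with binary_search_by:

// Map one raw score into a one-byte bucket using a channel's sorted CDF of
// 255 quantile boundaries. binary_search_by yields Ok(i) on an exact boundary
// hit and Err(i) with the insertion point otherwise; either way i is the
// bucket, and anything above the last boundary lands in the reserved 255.
fn score_to_bucket(cdf: &[f32], score: f32) -> u8 {
    let idx = match cdf.binary_search_by(|b| b.partial_cmp(&score).unwrap()) {
        Ok(i) => i,
        Err(i) => i,
    };
    idx.min(255) as u8
}

fn main() {
    // Stand-in CDF with evenly spaced boundaries; real CDFs come from the
    // per-channel quantiles written to cdfs.msgpack above.
    let cdf: Vec<f32> = (0..255).map(|i| i as f32 / 255.0).collect();
    assert_eq!(score_to_bucket(&cdf, -1.0), 0);   // below every boundary
    assert_eq!(score_to_bucket(&cdf, 0.5), 128);  // mid-range
    assert_eq!(score_to_bucket(&cdf, 2.0), 255);  // past the last boundary
}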
@@ -10,11 +10,12 @@ CREATE TABLE IF NOT EXISTS files (
     title TEXT NOT NULL,
     link TEXT NOT NULL,
     embedding BLOB NOT NULL,
+    timestamp INTEGER NOT NULL,
     UNIQUE (filename)
 );
 """)
 
 with jsonlines.open("sample.jsonl") as reader:
     for obj in reader:
-        shared.db.execute("INSERT INTO files (filename, title, link, embedding) VALUES (?, ?, ?, ?)", (obj["metadata"]["final_url"], obj["title"], f"https://reddit.com/r/{obj['subreddit']}/comments/{obj['id']}", sqlite3.Binary(np.array(obj["embedding"], dtype=np.float16).tobytes())))
+        shared.db.execute("INSERT OR REPLACE INTO files (filename, title, link, embedding, timestamp) VALUES (?, ?, ?, ?, ?)", (obj["metadata"]["final_url"], obj["title"], f"https://reddit.com/r/{obj['subreddit']}/comments/{obj['id']}", sqlite3.Binary(np.array(obj["embedding"], dtype=np.float16).tobytes()), obj["timestamp"]))
 shared.db.commit()
src/common.rs: 110 lines changed

@@ -1,4 +1,5 @@
 use image::codecs::bmp::BmpEncoder;
+use lazy_static::lazy_static;
 use serde::{Serialize, Deserialize};
 use std::borrow::Borrow;
 use std::cell::RefCell;
@@ -10,6 +11,11 @@ use tracing::instrument;
 use fast_image_resize::{Resizer, ResizeOptions, ResizeAlg};
 use fast_image_resize::images::{Image, ImageRef};
 use anyhow::Context;
+use std::collections::HashMap;
+use ndarray::ArrayBase;
+use prometheus::{register_int_counter_vec, IntCounterVec};
+use base64::prelude::*;
+use std::future::Future;
 
 std::thread_local! {
     static RESIZER: RefCell<Resizer> = RefCell::new(Resizer::new());
@@ -22,7 +28,7 @@ pub struct InferenceServerConfig {
     pub embedding_size: usize,
 }
 
-pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
+pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: &InferenceServerConfig, image: T) -> Result<Vec<u8>> {
     // the model currently in use wants aspect ratio 1:1 regardless of input
     // I think this was previously being handled in the CLIP server but that is slightly lossy
 
@@ -48,7 +54,7 @@ pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: I
 }
 
 pub async fn resize_for_embed<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
-    let resized = tokio::task::spawn_blocking(move || resize_for_embed_sync(config, image)).await??;
+    let resized = tokio::task::spawn_blocking(move || resize_for_embed_sync(&config, image)).await??;
     Ok(resized)
 }
 
@@ -163,5 +169,103 @@ pub struct IndexHeader {
     pub count: u32,
     pub dead_count: u32,
     pub record_pad_size: usize,
-    pub quantizer: diskann::vector::ProductQuantizer
+    pub quantizer: diskann::vector::ProductQuantizer,
+    pub descriptor_cdfs: Vec<Vec<f32>>
 }
+
+#[derive(Serialize, Deserialize)]
+pub struct FrontendInit {
+    pub n_total: u64,
+    pub predefined_embedding_names: Vec<String>,
+    pub d_emb: usize
+}
+
+pub type EmbeddingVector = Vec<f32>;
+
+#[derive(Debug, Serialize)]
+pub struct QueryResult {
+    pub matches: Vec<(f32, String, String, u64, Option<(u32, u32)>)>,
+    pub formats: Vec<String>,
+    pub extensions: HashMap<String, String>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct QueryTerm {
+    pub embedding: Option<EmbeddingVector>,
+    pub image: Option<String>,
+    pub text: Option<String>,
+    pub predefined_embedding: Option<String>,
+    pub weight: Option<f32>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct QueryRequest {
+    pub terms: Vec<QueryTerm>,
+    pub k: Option<usize>,
+    #[serde(default)]
+    pub include_video: bool
+}
+
+lazy_static::lazy_static! {
+    static ref TERMS_COUNTER: IntCounterVec = register_int_counter_vec!("mse_terms", "terms used in queries, by type", &["type"]).unwrap();
+}
+
+pub async fn get_total_embedding<A: Future<Output = Result<Vec<Vec<u8>>>>, B: Future<Output = Result<serde_bytes::ByteBuf>>, S: Clone, T: Clone, F: Fn(EmbeddingRequest, S) -> A, G: Fn(Vec<u8>, T) -> B>(terms: &Vec<QueryTerm>, ic: &InferenceServerConfig, query_server: F, resize_image: G, predefined_embeddings: &HashMap<String, ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<[usize; 1]>>>, image_state: T, query_state: S) -> Result<Vec<f32>> {
+    let mut total_embedding = ndarray::Array::from(vec![0.0; ic.embedding_size]);
+
+    let mut image_batch = Vec::new();
+    let mut image_weights = Vec::new();
+    let mut text_batch = Vec::new();
+    let mut text_weights = Vec::new();
+
+    for term in terms {
+        if let Some(image) = &term.image {
+            TERMS_COUNTER.get_metric_with_label_values(&["image"]).unwrap().inc();
+            let bytes = BASE64_STANDARD.decode(image)?;
+            image_batch.push(resize_image(bytes, image_state.clone()).await?);
+            image_weights.push(term.weight.unwrap_or(1.0));
+        }
+        if let Some(text) = &term.text {
+            TERMS_COUNTER.get_metric_with_label_values(&["text"]).unwrap().inc();
+            text_batch.push(text.clone());
+            text_weights.push(term.weight.unwrap_or(1.0));
+        }
+        if let Some(embedding) = &term.embedding {
+            TERMS_COUNTER.get_metric_with_label_values(&["embedding"]).unwrap().inc();
+            let weight = term.weight.unwrap_or(1.0);
+            for (i, value) in embedding.iter().enumerate() {
+                total_embedding[i] += value * weight;
+            }
+        }
+        if let Some(name) = &term.predefined_embedding {
+            let embedding = predefined_embeddings.get(name).context("name invalid")?;
+            total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
+        }
+    }
+
+    let mut batches = vec![];
+
+    if !image_batch.is_empty() {
+        batches.push(
+            (EmbeddingRequest::Images {
+                images: image_batch
+            }, image_weights)
+        );
+    }
+    if !text_batch.is_empty() {
+        batches.push(
+            (EmbeddingRequest::Text {
+                text: text_batch,
+            }, text_weights)
+        );
+    }
+
+    for (batch, weights) in batches {
+        let embs: Vec<Vec<u8>> = query_server(batch, query_state.clone()).await?;
+        for (emb, weight) in embs.into_iter().zip(weights) {
+            total_embedding += &(ndarray::Array::from_vec(decode_fp16_buffer(&emb)) * weight);
+        }
+    }
+
+    Ok(total_embedding.to_vec())
+}
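For reference, the QueryTerm and QueryRequest structs moved here define the JSON body both the existing tokio server and the new disk-index server accept. A sketch of one such request built with serde_json; the text, weights and the predefined-embedding name are purely illustrative (valid names come from predefined_embedding_names in FrontendInit):

use serde_json::json;

fn main() {
    // Two weighted terms: free text plus a predefined embedding.
    // k caps the number of results; include_video is #[serde(default)],
    // so it can be omitted and defaults to false.
    let body = json!({
        "terms": [
            { "text": "rust io_uring meme", "weight": 1.0 },
            { "predefined_embedding": "some-name", "weight": 0.5 }
        ],
        "k": 100
    });
    println!("{}", body);
}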
@@ -66,7 +66,9 @@ struct CLIArguments {
     #[argh(option, short='M', description="score model path")]
     score_model: Option<String>,
     #[argh(option, short='G', description="GPU (CUDA) device to use")]
-    gpu: Option<usize>
+    gpu: Option<usize>,
+    #[argh(option, description="descriptor CDFs")]
+    cdfs: Option<String>,
 }
 
 #[derive(Clone, Deserialize, Serialize, Debug)]
@@ -301,6 +303,13 @@ fn main() -> Result<()> {
         None
     };
 
+    let cdfs = if let Some(cdfs) = &args.cdfs {
+        let data = fs::read(cdfs).context("read cdfs")?;
+        Some(rmp_serde::from_read::<_, Vec<Vec<f32>>>(&data[..]).context("decode cdfs")?)
+    } else {
+        None
+    };
+
     let mut output_file = args.output_embeddings.map(|x| fs::File::create(x).context("create output file")).transpose()?;
 
     let mut i: u64 = 0;
@@ -436,17 +445,33 @@ fn main() -> Result<()> {
             let codes = quantizer.quantize_batch(&batch_embeddings);
 
             let score_model = score_model.as_ref().context("score model needed to output index")?;
+            let cdfs = cdfs.as_ref().context("score model CDFs needed to output index")?;
             let scores = score_model.score_batch(&batch_embeddings)?;
 
             for (i, (x, _embedding)) in batch.into_iter().enumerate() {
                 let (vertices, shards) = read_out_vertices(count + i as u32)?; // TODO: could parallelize this given the batching
+
+                let mut entry_scores = scores[i..(i + score_model.output_channels)].to_vec();
+
+                entry_scores.push(x.timestamp as f32); // seconds since epoch, so precision issues aren't awful
+
+                for (index, score) in entry_scores.iter().enumerate() {
+                    // binary search CDF to invert
+                    let cdf_bucket: u8 = match cdfs[index].binary_search_by(|x| x.partial_cmp(score).unwrap()) {
+                        Ok(x) => x.try_into().unwrap(),
+                        Err(x) => x.try_into().unwrap()
+                    };
+                    // write score descriptor to descriptors file
+                    index_output_file.2.write_all(&[cdf_bucket])?;
+                }
+
                 let mut entry = PackedIndexEntry {
                     id: count + i as u32,
                     vertices,
                     vector: x.embedding.chunks_exact(2).map(|x| u16::from_le_bytes([x[0], x[1]])).collect(),
                     timestamp: x.timestamp,
                     dimensions: x.metadata.dimension,
-                    scores: scores[i..(i + score_model.output_channels)].to_vec(),
+                    scores: entry_scores,
                     url: x.metadata.final_url,
                     shards
                 };
@@ -504,7 +529,8 @@ fn main() -> Result<()> {
             count: count as u32,
             record_pad_size: RECORD_PAD_SIZE,
             dead_count,
-            quantizer: pq_codec.unwrap()
+            quantizer: pq_codec.unwrap(),
+            descriptor_cdfs: cdfs.unwrap(),
         };
         file.write_all(rmp_serde::to_vec_named(&header)?.as_slice())?;
     }
src/main.rs: 128 lines changed

@@ -13,7 +13,7 @@ use axum::{
     Router,
     http::StatusCode
 };
-use common::resize_for_embed_sync;
+use common::{resize_for_embed_sync, FrontendInit};
 use compact_str::CompactString;
 use image::RgbImage;
 use image::{imageops::FilterType, ImageReader, DynamicImage, ImageFormat};
@@ -24,7 +24,6 @@ use sqlx::{sqlite::SqliteConnectOptions, SqlitePool};
 use tokio::sync::{broadcast, mpsc, RwLock};
 use tokio::task::JoinHandle;
 use walkdir::WalkDir;
-use base64::prelude::*;
 use faiss::{ConcurrentIndex, Index};
 use futures_util::stream::{StreamExt, TryStreamExt};
 use tokio_stream::wrappers::ReceiverStream;
@@ -32,20 +31,19 @@ use tower_http::cors::CorsLayer;
 use faiss::index::scalar_quantizer;
 use lazy_static::lazy_static;
 use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
-use ndarray::ArrayBase;
 use tracing::instrument;
+use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine};
 
 mod ocr;
 mod common;
 mod video_reader;
 
 use crate::ocr::scan_image;
-use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server, decode_fp16_buffer};
+use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server, decode_fp16_buffer, QueryRequest, QueryResult, EmbeddingVector};
 
 lazy_static! {
     static ref RELOADS_COUNTER: IntCounter = register_int_counter!("mse_reloads", "reloads executed").unwrap();
     static ref QUERIES_COUNTER: IntCounter = register_int_counter!("mse_queries", "queries executed").unwrap();
-    static ref TERMS_COUNTER: IntCounterVec = register_int_counter_vec!("mse_terms", "terms used in queries, by type", &["type"]).unwrap();
     static ref IMAGES_LOADED_COUNTER: IntCounter = register_int_counter!("mse_loads", "images loaded by ingest process").unwrap();
     static ref IMAGES_LOADED_ERROR_COUNTER: IntCounter = register_int_counter!("mse_load_errors", "image load fails by ingest process").unwrap();
     static ref VIDEOS_LOADED_COUNTER: IntCounter = register_int_counter!("mse_video_loads", "video loaded by ingest process").unwrap();
@@ -81,6 +79,13 @@ struct Config {
     video_frame_interval: f32
 }
 
+#[derive(Debug, Clone)]
+struct WConfig {
+    backend: InferenceServerConfig,
+    service: Config,
+    predefined_embeddings: HashMap<String, ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<[usize; 1]>>>
+}
+
 #[derive(Debug)]
 struct IIndex {
     vectors: scalar_quantizer::ScalarQuantizerIndexImpl,
@@ -147,13 +152,6 @@ struct FileRecord {
     needs_metadata: bool
 }
 
-#[derive(Debug, Clone)]
-struct WConfig {
-    backend: InferenceServerConfig,
-    service: Config,
-    predefined_embeddings: HashMap<String, ArrayBase<ndarray::OwnedRepr<f32>, ndarray::prelude::Dim<[usize; 1]>>>
-}
-
 #[derive(Debug)]
 struct LoadedImage {
     image: Arc<DynamicImage>,
@@ -387,7 +385,7 @@ async fn load_image(record: FileRecord, to_embed_tx: mpsc::Sender<EmbeddingInput
         let mut last_metadata = None;
         let callback = |frame: RgbImage| {
             let frame: Arc<DynamicImage> = Arc::new(frame.into());
-            let embed_buf = resize_for_embed_sync(config.backend.clone(), frame.clone())?;
+            let embed_buf = resize_for_embed_sync(&config.backend, frame.clone())?;
             let filename = Filename::VideoFrame(filename.clone(), i);
             to_embed_tx.blocking_send(EmbeddingInput {
                 image: embed_buf,
@@ -893,32 +891,6 @@ async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
     Ok(index)
 }
 
-type EmbeddingVector = Vec<f32>;
-
-#[derive(Debug, Serialize)]
-struct QueryResult {
-    matches: Vec<(f32, String, String, u64, Option<(u32, u32)>)>,
-    formats: Vec<String>,
-    extensions: HashMap<String, String>,
-}
-
-#[derive(Debug, Deserialize)]
-struct QueryTerm {
-    embedding: Option<EmbeddingVector>,
-    image: Option<String>,
-    text: Option<String>,
-    predefined_embedding: Option<String>,
-    weight: Option<f32>,
-}
-
-#[derive(Debug, Deserialize)]
-struct QueryRequest {
-    terms: Vec<QueryTerm>,
-    k: Option<usize>,
-    #[serde(default)]
-    include_video: bool
-}
-
 #[instrument(skip(index))]
 async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult> {
     let result = index.vectors.search(&query, k as usize)?;
@@ -958,65 +930,22 @@ async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bo
 
 #[instrument(skip(config, client, index))]
 async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IIndex, req: Json<QueryRequest>) -> Result<Response<Body>> {
-    let mut total_embedding = ndarray::Array::from(vec![0.0; config.backend.embedding_size]);
-
-    let mut image_batch = Vec::new();
-    let mut image_weights = Vec::new();
-    let mut text_batch = Vec::new();
-    let mut text_weights = Vec::new();
-
-    for term in &req.terms {
-        if let Some(image) = &term.image {
-            TERMS_COUNTER.get_metric_with_label_values(&["image"]).unwrap().inc();
-            let bytes = BASE64_STANDARD.decode(image)?;
-            let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&bytes))?);
-            image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(config.backend.clone(), image).await?));
-            image_weights.push(term.weight.unwrap_or(1.0));
-        }
-        if let Some(text) = &term.text {
-            TERMS_COUNTER.get_metric_with_label_values(&["text"]).unwrap().inc();
-            text_batch.push(text.clone());
-            text_weights.push(term.weight.unwrap_or(1.0));
-        }
-        if let Some(embedding) = &term.embedding {
-            TERMS_COUNTER.get_metric_with_label_values(&["embedding"]).unwrap().inc();
-            let weight = term.weight.unwrap_or(1.0);
-            for (i, value) in embedding.iter().enumerate() {
-                total_embedding[i] += value * weight;
-            }
-        }
-        if let Some(name) = &term.predefined_embedding {
-            let embedding = config.predefined_embeddings.get(name).context("name invalid")?;
-            total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
-        }
-    }
-
-    let mut batches = vec![];
-
-    if !image_batch.is_empty() {
-        batches.push(
-            (EmbeddingRequest::Images {
-                images: image_batch
-            }, image_weights)
-        );
-    }
-    if !text_batch.is_empty() {
-        batches.push(
-            (EmbeddingRequest::Text {
-                text: text_batch,
-            }, text_weights)
-        );
-    }
-
-    for (batch, weights) in batches {
-        let embs: Vec<Vec<u8>> = query_clip_server(&client, &config.service.clip_server, "/", batch).await?;
-        for (emb, weight) in embs.into_iter().zip(weights) {
-            total_embedding += &(ndarray::Array::from_vec(decode_fp16_buffer(&emb)) * weight);
-        }
-    }
+    let embedding = common::get_total_embedding(
+        &req.terms,
+        &config.backend,
+        |batch, (config, client)| async move {
+            query_clip_server(&client, &config.service.clip_server, "", batch).await
+        },
+        |image, config| async move {
+            let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&image))?);
+            Ok(serde_bytes::ByteBuf::from(resize_for_embed(config.backend.clone(), image).await?))
+        },
+        &config.clone().predefined_embeddings,
+        config.clone(),
+        (config.clone(), client.clone())).await?;
 
     let k = req.k.unwrap_or(1000);
-    let qres = query_index(index, total_embedding.to_vec(), k, req.include_video).await?;
+    let qres = query_index(index, embedding, k, req.include_video).await?;
 
     let mut extensions = HashMap::new();
     for (k, v) in image_formats(&config.service) {
@@ -1030,13 +959,6 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
     }).into_response())
 }
 
-#[derive(Serialize, Deserialize)]
-struct FrontendInit {
-    n_total: u64,
-    predefined_embedding_names: Vec<String>,
-    d_emb: usize
-}
-
 #[tokio::main]
 async fn main() -> Result<()> {
     console_subscriber::init();
File diff suppressed because one or more lines are too long
@@ -1,25 +1,32 @@
-use anyhow::{bail, Context, Result};
-use diskann::vector::scale_dot_result_f64;
-use serde::{Serialize, Deserialize};
-use std::io::{BufReader, Read, Seek, SeekFrom, Write};
-use std::os::unix::prelude::FileExt;
+use anyhow::{Context, Result};
+use lazy_static::lazy_static;
+use monoio::fs;
 use std::path::PathBuf;
-use std::fs;
 use base64::Engine;
 use argh::FromArgs;
-use chrono::{TimeZone, Utc, DateTime};
 use itertools::Itertools;
 use foldhash::{HashSet, HashSetExt};
 use half::f16;
-use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, ProductQuantizer, DistanceLUT, scale_dot_result}};
+use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, QueryLUT, scale_dot_result, scale_dot_result_f64}};
 use simsimd::SpatialSimilarity;
 use memmap2::{Mmap, MmapOptions};
+use std::rc::Rc;
+use monoio::net::{TcpListener, TcpStream};
+use monoio::io::IntoPollIo;
+use hyper::{body::{Body, Bytes, Incoming, Frame}, server::conn::http1, Method, Request, Response, StatusCode};
+use http_body_util::{BodyExt, Empty, Full};
+use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
+use std::pin::Pin;
+use std::future::Future;
+use serde::Serialize;
+use std::str::FromStr;
+use std::collections::HashMap;
 
 mod common;
 
-use common::{PackedIndexEntry, IndexHeader};
+use common::{resize_for_embed_sync, FrontendInit, IndexHeader, InferenceServerConfig, PackedIndexEntry, QueryRequest, QueryResult};
 
-#[derive(FromArgs)]
+#[derive(FromArgs, Clone)]
 #[argh(description="Query disk index")]
 struct CLIArguments {
     #[argh(positional)]
@@ -35,77 +42,132 @@ struct CLIArguments {
     #[argh(option, short='L', description="search list size")]
     search_list_size: Option<usize>,
     #[argh(switch, description="always use full-precision vectors (slow)")]
-    disable_pq: bool
+    disable_pq: bool,
+    #[argh(option, short='l', description="listen address")]
+    listen_address: Option<String>,
+    #[argh(option, short='c', description="clip server")]
+    clip_server: Option<String>,
 }
 
-fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> {
-    let offset = id as usize * header.record_pad_size;
-    let mut buf = vec![0; header.record_pad_size as usize];
-    data_file.read_exact_at(&mut buf, offset as u64)?;
+lazy_static! {
+    static ref QUERIES_COUNTER: IntCounter = register_int_counter!("mse_queries", "queries executed").unwrap();
+    static ref TERMS_COUNTER: IntCounterVec = register_int_counter_vec!("mse_terms", "terms used in queries, by type", &["type"]).unwrap();
+    static ref NODE_READS: IntCounter = register_int_counter!("mse_node_reads", "graph nodes read").unwrap();
+    static ref PQ_COMPARISONS: IntCounter = register_int_counter!("mse_pq_comparisons", "product quantization comparisons").unwrap();
+}
+
+async fn read_node<'a>(id: u32, index: Rc<Index>) -> Result<PackedIndexEntry> {
+    let offset = id as usize * index.header.record_pad_size;
+    let buf = vec![0; index.header.record_pad_size as usize];
+    let (res, buf) = index.data_file.read_exact_at(buf, offset as u64).await;
+    res?;
+    NODE_READS.inc();
     let len = u16::from_le_bytes(buf[0..2].try_into().unwrap()) as usize;
     Ok(bitcode::decode(&buf[2..len+2])?)
 }
 
-fn read_pq_codes(id: u32, codes: &Mmap, buf: &mut Vec<u8>, pq_code_size: usize) {
-    let loc = (id as usize) * pq_code_size;
-    buf.extend(&codes[loc..loc+pq_code_size])
+fn next_several_unvisited(s: &mut NeighbourBuffer, n: usize) -> Option<Vec<u32>> {
+    let mut result = Vec::new();
+    for _ in 0..n {
+        if let Some(neighbour) = s.next_unvisited() {
+            result.push(neighbour);
+        } else {
+            break;
+        }
+    }
+    if result.len() > 0 {
+        Some(result)
+    } else {
+        None
+    }
+}
+
+fn read_pq_codes(id: u32, index: Rc<Index>, buf: &mut Vec<u8>) {
+    let loc = (id as usize) * index.pq_code_size;
+    buf.extend(&index.pq_codes[loc..loc+index.pq_code_size])
 }
 
 struct Scratch {
+    visited_adjacent: HashSet<u32>,
     visited: HashSet<u32>,
     neighbour_buffer: NeighbourBuffer,
     neighbour_pre_buffer: Vec<u32>,
     visited_list: Vec<(u32, i64, String, Vec<u32>, Vec<f32>)>
 }
 
-struct IndexRef<'a> {
-    data_file: &'a mut fs::File,
-    pq_codes: &'a Mmap,
-    header: &'a IndexHeader,
-    pq_code_size: usize
+struct Index {
+    data_file: fs::File,
+    pq_codes: Mmap,
+    header: Rc<IndexHeader>,
+    pq_code_size: usize,
+    descriptors: Mmap,
+    n_descriptors: usize
 }
 
-fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &DistanceLUT, index: IndexRef, disable_pq: bool) -> Result<(usize, usize)> {
-    scratch.visited.clear();
+struct DescriptorScales(Vec<f32>);
+
+fn descriptor_product(index: Rc<Index>, scales: &DescriptorScales, neighbour: u32) -> i64 {
+    let mut result = 0;
+    // effectively an extra part of the vector to dot product
+    for (j, d) in scales.0.iter().enumerate() {
+        result += scale_dot_result(d * index.descriptors[neighbour as usize * index.n_descriptors + j] as f32);
+    }
+    result
+}
+
+async fn greedy_search<'a>(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &QueryLUT, descriptor_scales: &DescriptorScales, index: Rc<Index>, disable_pq: bool, beamwidth: usize) -> Result<(usize, usize)> {
+    scratch.visited_adjacent.clear();
     scratch.neighbour_buffer.clear();
     scratch.visited_list.clear();
+    scratch.visited.clear();
 
     let mut cmps = 0;
     let mut pq_cmps = 0;
 
-    let node = read_node(start, index.data_file, index.header)?;
-    let vector = bytemuck::cast_slice(&node.vector);
-    scratch.neighbour_buffer.insert(start, fast_dot_noprefetch(query, &vector));
-    scratch.visited.insert(start);
+    scratch.neighbour_buffer.insert(start, 0);
+    scratch.visited_adjacent.insert(start);
 
-    while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
-        //println!("pt {} {:?}", pt, graph.out_neighbours(pt));
+    while let Some(pts) = next_several_unvisited(&mut scratch.neighbour_buffer, beamwidth) {
         scratch.neighbour_pre_buffer.clear();
-        let node = read_node(pt, index.data_file, index.header)?;
-        let vector = bytemuck::cast_slice(&node.vector);
-        let distance = fast_dot_noprefetch(query, &vector);
-        cmps += 1;
-        scratch.visited_list.push((pt, distance, node.url, node.shards, node.scores));
-        for &neighbour in node.vertices.iter() {
-            if scratch.visited.insert(neighbour) {
-                scratch.neighbour_pre_buffer.push(neighbour);
+
+        let mut join_handles = Vec::with_capacity(pts.len());
+
+        for &pt in pts.iter() {
+            join_handles.push(monoio::spawn(read_node(pt, index.clone())));
+        }
+
+        for handle in join_handles {
+            let index = index.clone();
+            let node = handle.await?;
+            let vector = bytemuck::cast_slice(&node.vector);
+            let distance = fast_dot_noprefetch(query, &vector);
+            cmps += 1;
+            if scratch.visited.insert(node.id) {
+                scratch.visited_list.push((node.id, distance, node.url, node.shards, node.scores));
+            };
+            for &neighbour in node.vertices.iter() {
+                if scratch.visited_adjacent.insert(neighbour) {
+                    scratch.neighbour_pre_buffer.push(neighbour);
+                }
             }
-        }
-        let mut pq_codes = Vec::with_capacity(index.pq_code_size * scratch.neighbour_pre_buffer.len());
-        for &neighbour in scratch.neighbour_pre_buffer.iter() {
-            read_pq_codes(neighbour, index.pq_codes, &mut pq_codes, index.pq_code_size);
-        }
-        let approx_scores = index.header.quantizer.asymmetric_dot_product(&query_preprocessed, &pq_codes);
-        for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
-            if disable_pq {
-                //let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
-                let node = read_node(neighbour, index.data_file, index.header)?;
-                let vector = bytemuck::cast_slice(&node.vector);
-                let distance = fast_dot_noprefetch(query, &vector);
-                scratch.neighbour_buffer.insert(neighbour, distance);
-            } else {
-                scratch.neighbour_buffer.insert(neighbour, approx_scores[i]);
-                pq_cmps += 1;
+            let mut pq_codes = Vec::with_capacity(index.pq_code_size * scratch.neighbour_pre_buffer.len());
+            for &neighbour in scratch.neighbour_pre_buffer.iter() {
+                read_pq_codes(neighbour, index.clone(), &mut pq_codes);
+            }
+            let mut approx_scores = index.header.quantizer.asymmetric_dot_product(&query_preprocessed, &pq_codes);
+            for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
+                if disable_pq {
+                    let node = read_node(neighbour, index.clone()).await?;
+                    let vector = bytemuck::cast_slice(&node.vector);
+                    let mut score = fast_dot_noprefetch(query, &vector);
+                    score += descriptor_product(index.clone(), &descriptor_scales, neighbour);
+                    scratch.neighbour_buffer.insert(neighbour, score);
+                } else {
+                    approx_scores[i] += descriptor_product(index.clone(), &descriptor_scales, neighbour);
+                    scratch.neighbour_buffer.insert(neighbour, approx_scores[i]);
+                    pq_cmps += 1;
+                    PQ_COMPARISONS.inc();
+                }
             }
         }
     }
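The descriptor_product added here folds each candidate's descriptor bytes (the per-channel CDF buckets plus the timestamp channel written at index build time) into the traversal's ranking score, effectively as extra vector dimensions, as its comment says. A sketch of that arithmetic in plain f32, ignoring the fixed-point scaling that scale_dot_result applies; combined_score and descriptor_weights are illustrative names, not part of the commit:

// Combined ranking score for one candidate: the query/embedding dot product
// plus a weighted sum over its descriptor bytes. With all weights at 0.0
// (as the current call sites pass) the descriptors have no effect on ranking.
fn combined_score(embedding_dot: f32, descriptors: &[u8], descriptor_weights: &[f32]) -> f32 {
    let descriptor_term: f32 = descriptor_weights
        .iter()
        .zip(descriptors)
        .map(|(w, &d)| w * d as f32)
        .sum();
    embedding_dot + descriptor_term
}

fn main() {
    // e.g. a candidate whose quality-score bucket is 200/255 and whose
    // timestamp bucket is 37, queried with weight only on the first channel.
    let score = combined_score(12.5, &[200, 37], &[0.01, 0.0]);
    assert!((score - 14.5).abs() < 1e-6);
}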
@ -124,17 +186,21 @@ fn summary_stats(ranks: &mut [usize]) {
|
|||||||
|
|
||||||
const K: usize = 20;
|
const K: usize = 20;
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
async fn evaluate(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
|
||||||
let args: CLIArguments = argh::from_env();
|
let mut top_k_ranks_best_shard = vec![];
|
||||||
|
let mut top_rank_best_shard = vec![];
|
||||||
|
let mut pq_cmps = vec![];
|
||||||
|
let mut cmps = vec![];
|
||||||
|
let mut recall_total = 0;
|
||||||
|
|
||||||
let mut queries = vec![];
|
let mut queries = vec![];
|
||||||
|
|
||||||
if let Some(query_vector_base64) = args.query_vector_base64 {
|
if let Some(query_vector_base64) = &args.query_vector_base64 {
|
||||||
let query_vector: Vec<f16> = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(query_vector_base64.as_bytes()).context("invalid base64")?);
|
let query_vector: Vec<f16> = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(query_vector_base64.as_bytes()).context("invalid base64")?);
|
||||||
queries.push(query_vector);
|
queries.push(query_vector);
|
||||||
}
|
}
|
||||||
if let Some(query_vector_file) = args.query_vector_file {
|
if let Some(query_vector_file) = &args.query_vector_file {
|
||||||
let query_vectors = fs::read(query_vector_file)?;
|
let query_vectors = fs::read(query_vector_file).await?;
|
||||||
queries.extend(common::chunk_fp16_buffer(&query_vectors).chunks(1152).map(|x| x.to_vec()).collect::<Vec<_>>());
|
queries.extend(common::chunk_fp16_buffer(&query_vectors).chunks(1152).map(|x| x.to_vec()).collect::<Vec<_>>());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,30 +208,12 @@ fn main() -> Result<()> {
|
|||||||
queries.truncate(n);
|
queries.truncate(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
let index_path = PathBuf::from(&args.index_path);
|
|
||||||
let header: IndexHeader = rmp_serde::from_read(BufReader::new(fs::File::open(index_path.join("index.msgpack"))?))?;
|
|
||||||
let mut data_file = fs::File::open(index_path.join("index.bin"))?;
|
|
||||||
let pq_codes_file = fs::File::open(index_path.join("index.pq-codes.bin"))?;
|
|
||||||
let pq_codes = unsafe {
|
|
||||||
// This is unsafe because other processes could in principle edit the mmap'd file.
|
|
||||||
// It would be annoying to do anything about this possibility, so ignore it.
|
|
||||||
MmapOptions::new().populate().map(&pq_codes_file)?
|
|
||||||
};
|
|
||||||
|
|
||||||
println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
|
|
||||||
|
|
||||||
let mut top_k_ranks_best_shard = vec![];
|
|
||||||
let mut top_rank_best_shard = vec![];
|
|
||||||
let mut pq_cmps = vec![];
|
|
||||||
let mut cmps = vec![];
|
|
||||||
let mut recall_total = 0;
|
|
||||||
|
|
||||||
for query_vector in queries.iter() {
|
for query_vector in queries.iter() {
|
||||||
let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
|
let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
|
||||||
let query_preprocessed = header.quantizer.preprocess_query(&query_vector_fp32);
|
let query_preprocessed = index.header.quantizer.preprocess_query(&query_vector_fp32);
|
||||||
|
|
||||||
// TODO slightly dubious
|
// TODO slightly dubious
|
||||||
let selected_shard = header.shards.iter().position_max_by_key(|x| {
|
let selected_shard = index.header.shards.iter().position_max_by_key(|x| {
|
||||||
scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query_vector_fp32).unwrap())
|
scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query_vector_fp32).unwrap())
|
||||||
}).unwrap();
|
}).unwrap();
|
||||||
|
|
||||||
@ -175,8 +223,8 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let mut matches = vec![];
|
let mut matches = vec![];
|
||||||
// brute force scan
|
// brute force scan
|
||||||
for i in 0..header.count {
|
for i in 0..index.header.count {
|
||||||
let node = read_node(i, &mut data_file, &header)?;
|
let node = read_node(i, index.clone()).await?;
|
||||||
//println!("{} {}", i, node.url);
|
//println!("{} {}", i, node.url);
|
||||||
let vector = bytemuck::cast_slice(&node.vector);
|
let vector = bytemuck::cast_slice(&node.vector);
|
||||||
matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards));
|
matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards));
|
||||||
@ -192,23 +240,22 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let mut top_ranks = vec![usize::MAX; K];
|
let mut top_ranks = vec![usize::MAX; K];
|
||||||
|
|
||||||
for shard in 0..header.shards.len() {
|
for shard in 0..index.header.shards.len() {
|
||||||
let selected_start = header.shards[shard].1;
|
let selected_start = index.header.shards[shard].1;
|
||||||
|
|
||||||
|
let beamwidth = 3;
|
||||||
|
|
||||||
let mut scratch = Scratch {
|
let mut scratch = Scratch {
|
||||||
visited: HashSet::new(),
|
visited: HashSet::new(),
|
||||||
neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
|
neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
|
||||||
neighbour_pre_buffer: Vec::new(),
|
neighbour_pre_buffer: Vec::new(),
|
||||||
visited_list: Vec::new()
|
visited_list: Vec::new(),
|
||||||
|
visited_adjacent: HashSet::new()
|
||||||
};
|
};
|
||||||
|
|
||||||
//let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng);
|
let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]);
|
||||||
let cmps_result = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
|
|
||||||
data_file: &mut data_file,
|
let cmps_result = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, &descriptor_scales, index.clone(), args.disable_pq, beamwidth).await?;
|
||||||
header: &header,
|
|
||||||
pq_codes: &pq_codes,
|
|
||||||
pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
|
|
||||||
}, args.disable_pq)?;
|
|
||||||
|
|
||||||
// slightly dubious because this is across shards
|
// slightly dubious because this is across shards
|
||||||
pq_cmps.push(cmps_result.1);
|
pq_cmps.push(cmps_result.1);
|
||||||
@ -255,3 +302,240 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn query_clip_server<I, O>(base_url: &str, path: &str, data: Option<I>) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned {
|
||||||
|
// TODO connection pool or something
|
||||||
|
// also this won't work over TLS
|
||||||
|
|
||||||
|
let url = hyper::Uri::from_str(base_url)?;
|
||||||
|
|
||||||
|
let stream = TcpStream::connect(format!("{}:{}", url.host().unwrap(), url.port_u16().unwrap_or(80))).await?;
|
||||||
|
let io = monoio_compat::hyper::MonoioIo::new(stream.into_poll_io()?);
|
||||||
|
|
||||||
|
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
|
||||||
|
monoio::spawn(async move {
|
||||||
|
if let Err(err) = conn.await {
|
||||||
|
tracing::error!("connection failed: {:?}", err);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let authority = url.authority().unwrap().clone();
|
||||||
|
|
||||||
|
let req = Request::builder()
|
||||||
|
.uri(path)
|
||||||
|
.header(hyper::header::HOST, authority.as_str())
|
||||||
|
.header(hyper::header::CONTENT_TYPE, "application/msgpack");
|
||||||
|
|
||||||
|
let res = match data {
|
||||||
|
Some(data) => sender.send_request(req.method(Method::POST).body(Full::new(Bytes::from(rmp_serde::to_vec_named(&data)?)))?).await?,
|
||||||
|
None => sender.send_request(req.method(Method::GET).body(Full::new(Bytes::from("")))?).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
if res.status() != StatusCode::OK {
|
||||||
|
return Err(anyhow::anyhow!("unexpected status code: {}", res.status()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = res.collect().await?.to_bytes();
|
||||||
|
|
||||||
|
let result: O = rmp_serde::from_slice(&data)?;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct Service {
|
||||||
|
index: Rc<Index>,
|
||||||
|
inference_server_config: Rc<InferenceServerConfig>,
|
||||||
|
args: Rc<CLIArguments>
|
||||||
|
}
|
||||||
|
|
||||||
impl hyper::service::Service<Request<Incoming>> for Service {
    type Response = Response<Full<Bytes>>;
    type Error = anyhow::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    fn call(&self, req: Request<Incoming>) -> Self::Future {
        let index = self.index.clone();
        let args = self.args.clone();
        let inference_server_config = self.inference_server_config.clone();

        Box::pin(async move {
            let mut body = match (req.method(), req.uri().path()) {
                (&Method::GET, "/") => Response::new(Full::new(Bytes::from(serde_json::to_vec(&FrontendInit {
                    n_total: (index.header.count - index.header.dead_count) as u64,
                    d_emb: index.header.quantizer.n_dims,
                    predefined_embedding_names: vec![]
                })?))),
                (&Method::POST, "/") => {
                    let upper = req.body().size_hint().upper().unwrap_or(u64::MAX);
                    if upper > 1<<23 {
                        let mut resp = Response::new(Full::new(Bytes::from("Body too big")));
                        *resp.status_mut() = hyper::StatusCode::PAYLOAD_TOO_LARGE;
                        return Ok(resp);
                    }

                    let whole_body = req.collect().await?.to_bytes();

                    let body: QueryRequest = serde_json::from_slice(&whole_body)?;

                    let query = common::get_total_embedding(
                        &body.terms,
                        &*inference_server_config,
                        |batch, _config| {
                            query_clip_server(args.clip_server.as_ref().unwrap(), "/", Some(batch))
                        },
                        |image, config| async move {
                            let image = image::load_from_memory(&image)?;
                            Ok(serde_bytes::ByteBuf::from(resize_for_embed_sync(&*config, image)?))
                        },
                        &std::collections::HashMap::new(),
                        inference_server_config.clone(),
                        ()
                    ).await?;

                    let selected_shard = index.header.shards.iter().position_max_by_key(|x| {
                        scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query).unwrap())
                    }).unwrap();
                    let selected_start = index.header.shards[selected_shard].1;

                    let beamwidth = 3;

                    let mut scratch = Scratch {
                        visited: HashSet::new(),
                        neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
                        neighbour_pre_buffer: Vec::new(),
                        visited_list: Vec::new(),
                        visited_adjacent: HashSet::new()
                    };

                    let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]);

                    let query_preprocessed = index.header.quantizer.preprocess_query(&query);

                    let query = query.iter().map(|x| half::f16::from_f32(*x)).collect::<Vec<f16>>();

                    let cmps_result = greedy_search(&mut scratch, selected_start, &query, &query_preprocessed, &descriptor_scales, index.clone(), args.disable_pq, beamwidth).await?;

                    scratch.visited_list.sort_by_key(|x| -x.1);

                    let matches = scratch.visited_list.drain(..).map(|(id, score, url, shards, scores)| (score as f32, url, String::new(), 0, None)).collect::<Vec<_>>();

                    let result = QueryResult {
                        formats: vec![],
                        extensions: HashMap::new(),
                        matches
                    };

                    let result = serde_json::to_vec(&result)?;

                    Response::new(Full::new(Bytes::from(result)))
                },
                (&Method::GET, "/metrics") => {
                    let mut buffer = Vec::new();
                    let encoder = prometheus::TextEncoder::new();
                    let metric_families = prometheus::gather();
                    encoder.encode(&metric_families, &mut buffer).unwrap();
                    Response::builder()
                        .header(hyper::header::CONTENT_TYPE, "text/plain; version=0.0.4")
                        .body(Full::new(Bytes::from(buffer))).unwrap()
                },
                (&Method::OPTIONS, "/") => {
                    Response::builder()
                        .status(StatusCode::NO_CONTENT)
                        .body(Full::new(Bytes::from(""))).unwrap()
                },
                _ => Response::builder()
                    .status(StatusCode::NOT_FOUND)
                    .body(Full::new(Bytes::from("Not Found")))
                    .unwrap()
            };

            body.headers_mut().entry(hyper::header::CONTENT_TYPE).or_insert(hyper::header::HeaderValue::from_static("application/json"));
            body.headers_mut().entry(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN).or_insert(hyper::header::HeaderValue::from_static("*"));
            body.headers_mut().entry(hyper::header::ACCESS_CONTROL_ALLOW_METHODS).or_insert(hyper::header::HeaderValue::from_static("GET, POST, OPTIONS"));
            body.headers_mut().entry(hyper::header::ACCESS_CONTROL_ALLOW_HEADERS).or_insert(hyper::header::HeaderValue::from_static("Content-Type"));

            Result::<_, anyhow::Error>::Ok(body)
        })
    }
}

async fn get_backend_config(clip_server: &Option<String>) -> Result<InferenceServerConfig> {
    loop {
        match query_clip_server(clip_server.as_ref().unwrap(), "/config", Option::<()>::None).await {
            Ok(config) => return Ok(config),
            Err(err) => {
                tracing::warn!("waiting for clip server: {}", err);
                monoio::time::sleep(std::time::Duration::from_secs(1)).await;
            }
        };
    }
}

async fn serve(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
    let service = Service {
        index,
        inference_server_config: Rc::new(get_backend_config(&args.clip_server).await?),
        args: Rc::new(args.clone())
    };

    let listener = TcpListener::bind(args.listen_address.as_ref().unwrap())?;
    println!("Listening");
    loop {
        let (stream, _) = listener.accept().await?;
        let stream_poll = monoio_compat::hyper::MonoioIo::new(stream.into_poll_io()?);
        let service = service.clone();
        monoio::spawn(async move {
            // Handle the connection from the client using HTTP/1 and pass any
            // HTTP requests received on that connection to the `Service` handler
            if let Err(err) = http1::Builder::new()
                .timer(monoio_compat::hyper::MonoioTimer)
                .serve_connection(stream_poll, service)
                .await
            {
                println!("Error serving connection: {:?}", err);
            }
        });
    }
}

#[monoio::main(threads=1, enable_timer=true)]
async fn main() -> Result<()> {
    let args: CLIArguments = argh::from_env();

    let index_path = PathBuf::from(&args.index_path);
    let header: IndexHeader = rmp_serde::from_slice(&fs::read(index_path.join("index.msgpack")).await?)?;
    let header = Rc::new(header);
    // contains graph structure, full-precision vectors, and bulk metadata
    let data_file = fs::File::open(index_path.join("index.bin")).await?;
    // contains product quantization codes
    let pq_codes_file = fs::File::open(index_path.join("index.pq-codes.bin")).await?;
    let pq_codes = unsafe {
        // This is unsafe because other processes could in principle edit the mmap'd file.
        // It would be annoying to do anything about this possibility, so ignore it.
        MmapOptions::new().populate().map(&pq_codes_file)?
    };
    // contains metadata descriptors
    let descriptors_file = fs::File::open(index_path.join("index.descriptor-codes.bin")).await?;
    let descriptors = unsafe {
        MmapOptions::new().populate().map(&descriptors_file)?
    };

    println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());

    let index = Rc::new(Index {
        data_file,
        header: header.clone(),
        pq_codes,
        pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
        descriptors,
        n_descriptors: header.descriptor_cdfs.len(),
    });

    if args.listen_address.is_some() {
        serve(&args, index).await?;
    } else {
        evaluate(&args, index).await?;
    }

    Ok(())
}