1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-09-11 15:26:02 +00:00

Multithread query server

While profiling suggests that most operations are cheap and IO-bound rather than CPU-bound, the GEMM for deduplication is pretty slow. As such, use multiple threads for higher throughput.
This commit is contained in:
osmarks
2025-01-31 13:47:47 +00:00
parent 5215822e39
commit e57931d47f
5 changed files with 130 additions and 86 deletions

47
diskann/Cargo.lock generated
View File

@@ -26,18 +26,6 @@ version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]]
name = "bitvec"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [
"funty",
"radium",
"tap",
"wyz",
]
[[package]]
name = "bytemuck"
version = "1.20.0"
@@ -140,13 +128,11 @@ name = "diskann"
version = "0.1.0"
dependencies = [
"anyhow",
"bitvec",
"bytemuck",
"fastrand",
"foldhash",
"half",
"matrixmultiply",
"min-max-heap",
"rayon",
"rmp-serde",
"serde",
@@ -174,12 +160,6 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
[[package]]
name = "funty"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "half"
version = "2.4.1"
@@ -229,12 +209,6 @@ dependencies = [
"rawpointer",
]
[[package]]
name = "min-max-heap"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2687e6cf9c00f48e9284cf9fd15f2ef341d03cc7743abf9df4c5f07fdee50b18"
[[package]]
name = "mio"
version = "0.8.11"
@@ -331,12 +305,6 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "radium"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]]
name = "rawpointer"
version = "0.2.1"
@@ -491,12 +459,6 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "thread_local"
version = "1.1.8"
@@ -744,12 +706,3 @@ name = "windows_x86_64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "wyz"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
dependencies = [
"tap",
]

View File

@@ -70,7 +70,10 @@ fn main() -> Result<()> {
l: 192,
maxc: 750,
alpha: 65200,
saturate_graph: false
saturate_graph: false,
query_breakpoint: vecs.len() as u32,
query_alpha: 65200,
max_add_per_stitch_iter: 0
};
let mut graph = IndexGraph::empty(vecs.len(), config.r);
@@ -105,13 +108,16 @@ fn main() -> Result<()> {
l: 200,
alpha: 65536,
maxc: 0,
saturate_graph: false
saturate_graph: false,
query_breakpoint: vecs.len() as u32,
query_alpha: 65200,
max_add_per_stitch_iter: 0
};
let mut scratch = Scratch::new(config);
for (i, vec) in tqdm::tqdm(vecs.iter().enumerate()) {
let ctr = greedy_search(&mut scratch, medioid, &vec, &vecs, &graph, config);
let ctr = greedy_search(&mut scratch, medioid, false, &vec, &vecs, &graph, config);
cmps_ctr += ctr.distances;
cmps.push(ctr.distances);
if scratch.neighbour_buffer.ids[0] == (i as u32) {

View File

@@ -1,5 +1,3 @@
use core::f32;
use half::f16;
use simsimd::SpatialSimilarity;
use fastrand::Rng;
@@ -420,6 +418,7 @@ pub fn scale_dot_result_f64(x: f64) -> i64 {
#[cfg(test)]
mod bench {
use super::*;
use half::slice::HalfFloatSliceExt;
use test::Bencher;
#[bench]
@@ -451,4 +450,27 @@ mod bench {
fast_dot_noprefetch(&a, &b)
});
}
/// Benchmarks `ProductQuantizer::preprocess_query` for a single random query.
///
/// The quantizer is deserialized from `opq.msgpack` in the current working
/// directory; the benchmark panics if that file cannot be read or decoded.
#[bench]
fn bench_preprocess_query(be: &mut Bencher) {
    // Fixed RNG seed, so every run is driven by the same generator state.
    let mut rng = fastrand::Rng::with_seed(1);
    let quantizer_bytes = std::fs::read("opq.msgpack").unwrap();
    let pq: ProductQuantizer = rmp_serde::from_slice(&quantizer_bytes).unwrap();
    let query = Vector::randn(&mut rng, pq.n_dims).to_f32_vec();
    // Only the preprocessing call is inside the timed closure.
    be.iter(|| pq.preprocess_query(&query));
}
/// Benchmarks `ProductQuantizer::asymmetric_dot_product` against a buffer of
/// random PQ code bytes.
///
/// The quantizer is deserialized from `opq.msgpack` in the current working
/// directory; the benchmark panics if that file cannot be read or decoded.
#[bench]
fn bench_asymmetric_dot_product(be: &mut Bencher) {
    // Fixed RNG seed, so every run is driven by the same generator state.
    let mut rng = fastrand::Rng::with_seed(1);
    let quantizer_bytes = std::fs::read("opq.msgpack").unwrap();
    let pq: ProductQuantizer = rmp_serde::from_slice(&quantizer_bytes).unwrap();

    // The lookup table is built once, outside the timed section.
    let query = Vector::randn(&mut rng, pq.n_dims).to_f32_vec();
    let lut = pq.preprocess_query(&query);

    // 100 * (n_dims / n_dims_per_code) bytes of random codes
    // (presumably the codes of 100 vectors - confirm against the quantizer).
    let code_bytes = 100 * pq.n_dims / pq.n_dims_per_code;
    let mut codes = vec![0; code_bytes];
    rng.fill(&mut codes);

    be.iter(|| pq.asymmetric_dot_product(&lut, &codes));
}
}

29
perf_test.py Normal file
View File

@@ -0,0 +1,29 @@
import numpy as np
import aiohttp
import asyncio
import sys
queries = np.random.randn(1000, 1152)
async def main():
    """Fire 1000 search requests at the local query server, 100 at a time.

    Rows 0..999 of the module-level ``queries`` array are each POSTed to
    ``http://localhost:5601`` as a single-term embedding query with ``k=10``.
    A dot is printed to stdout for every response received.
    """
    # Bounds the number of requests in flight: a slot is taken before each
    # task is created and given back once that task's lookup has returned.
    in_flight = asyncio.Semaphore(100)

    async with aiohttp.ClientSession() as sess, asyncio.TaskGroup() as tg:

        async def lookup(embedding):
            # Convert each element to a plain Python float so the payload
            # can be JSON-encoded.
            payload = {
                "terms": [{"embedding": [float(x) for x in embedding]}],
                "k": 10,
            }
            async with sess.post("http://localhost:5601", json=payload) as res:
                # Progress marker: one dot per response.
                sys.stdout.write(".")
                sys.stdout.flush()
                body = await res.json()
                return body["matches"]

        async def dispatch(i):
            await lookup(queries[i])
            in_flight.release()

        for i in range(1000):
            await in_flight.acquire()
            tg.create_task(dispatch(i))
asyncio.run(main())

View File

@@ -22,6 +22,7 @@ use serde::{Serialize, Deserialize};
use std::str::FromStr;
use std::collections::HashMap;
use std::io::Write;
use std::sync::Arc;
mod common;
@@ -97,7 +98,7 @@ const DUPLICATES_THRESHOLD: f32 = 0.95;
fn read_pq_codes(id: u32, index: Rc<Index>, buf: &mut Vec<u8>) {
let loc = (id as usize) * index.pq_code_size;
buf.extend(&index.pq_codes[loc..loc+index.pq_code_size])
buf.extend(&index.memory_maps.pq_codes[loc..loc+index.pq_code_size])
}
struct VisitedNode {
@@ -121,11 +122,10 @@ struct Scratch {
struct Index {
data_file: fs::File,
pq_codes: Mmap,
header: Rc<IndexHeader>,
pq_code_size: usize,
descriptors: Mmap,
n_descriptors: usize
n_descriptors: usize,
memory_maps: Arc<MemoryMaps>
}
struct DescriptorScales(Vec<f32>);
@@ -134,7 +134,7 @@ fn descriptor_product(index: Rc<Index>, scales: &DescriptorScales, neighbour: u3
let mut result = 0;
// effectively an extra part of the vector to dot product
for (j, d) in scales.0.iter().enumerate() {
result += scale_dot_result(d * index.descriptors[neighbour as usize * index.n_descriptors + j] as f32);
result += scale_dot_result(d * index.memory_maps.descriptors[neighbour as usize * index.n_descriptors + j] as f32);
}
result
}
@@ -220,7 +220,9 @@ fn summary_stats(ranks: &mut [usize]) {
const K: usize = 20;
async fn evaluate(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
#[monoio::main(threads=1)]
async fn evaluate(args: Arc<CLIArguments>, memory_maps: Arc<MemoryMaps>) -> Result<()> {
let index = initialize_index(args.clone(), memory_maps).await?;
let mut top_k_ranks_best_shard = vec![];
let mut top_rank_best_shard = vec![];
let mut pq_cmps = vec![];
@@ -384,7 +386,7 @@ struct TelemetryMessage {
event: String,
#[serde(rename="instanceId")]
instance_id: String,
page: String
page: Option<String>
}
#[derive(Clone)]
@@ -610,7 +612,7 @@ fn telemetry_handler(rx: std::sync::mpsc::Receiver<TelemetryMessage>, config: Se
Ok(())
}
async fn serve(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
async fn serve(args: Arc<CLIArguments>, index: Rc<Index>) -> Result<()> {
let config: ServerConfig = serde_json::from_slice(&std::fs::read(args.config_path.as_ref().unwrap())?)?;
let (telemetry_channel, telemetry_receiver) = std::sync::mpsc::channel();
@@ -645,51 +647,83 @@ async fn serve(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
}
}
#[monoio::main(threads=1, enable_timer=true)]
async fn main() -> Result<()> {
let args: CLIArguments = argh::from_env();
/// Memory-mapped index files, built once by `initialize_memory_maps` and then
/// shared (behind an `Arc`) by every `Index` instance.
struct MemoryMaps {
// Mapping of `index.pq-codes.bin` (product quantization codes).
pq_codes: memmap2::Mmap,
// Mapping of `index.descriptor-codes.bin` (metadata descriptors).
descriptors: memmap2::Mmap,
// Guards returned by `region::lock` for the two mappings; empty when memory
// locking was not requested. Stored here so they live as long as the maps
// (NOTE(review): per the `region` crate, dropping a guard unlocks its region - confirm).
guards: Vec<region::LockGuard>
}
async fn initialize_index(args: Arc<CLIArguments>, memory_maps: Arc<MemoryMaps>) -> Result<Rc<Index>> {
let index_path = PathBuf::from(&args.index_path);
let header: IndexHeader = rmp_serde::from_slice(&fs::read(index_path.join("index.msgpack")).await?)?;
let header = Rc::new(header);
// contains graph structure, full-precision vectors, and bulk metadata
let data_file = fs::File::open(index_path.join("index.bin")).await?;
// contains product quantization codes
let pq_codes_file = fs::File::open(index_path.join("index.pq-codes.bin")).await?;
let pq_codes = unsafe {
// This is unsafe because other processes could in principle edit the mmap'd file.
// It would be annoying to do anything about this possibility, so ignore it.
MmapOptions::new().populate().map_copy_read_only(&pq_codes_file)?
};
// contains metadata descriptors
let descriptors_file = fs::File::open(index_path.join("index.descriptor-codes.bin")).await?;
let descriptors = unsafe {
MmapOptions::new().populate().map_copy_read_only(&descriptors_file)?
};
let _guards = if args.lock_memory {
let g1 = region::lock(descriptors.as_ptr(), descriptors.len())?;
let g2 = region::lock(pq_codes.as_ptr(), pq_codes.len())?;
Some((g1, g2))
} else {
None
};
println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
let index = Rc::new(Index {
data_file,
header: header.clone(),
pq_codes,
pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
descriptors,
n_descriptors: header.descriptor_cdfs.len(),
memory_maps
});
if args.config_path.is_some() {
serve(&args, index).await?;
Ok(index)
}
/// Memory-maps the PQ code file and the descriptor code file found under
/// `args.index_path`.
///
/// Both files are opened with blocking `std::fs` calls and mapped with
/// `populate()`. When `args.lock_memory` is set, both mappings are also passed
/// to `region::lock`, and the resulting guards are returned inside the
/// `MemoryMaps` so they live as long as the mappings do.
///
/// # Errors
///
/// Returns an error if either file cannot be opened or mapped, or if locking
/// either region fails.
fn initialize_memory_maps(args: &CLIArguments) -> Result<MemoryMaps> {
    let index_path = PathBuf::from(&args.index_path);

    // contains product quantization codes
    let pq_codes_file = std::fs::File::open(index_path.join("index.pq-codes.bin"))?;
    // SAFETY: this is only sound while nothing else modifies the mapped file.
    // Other processes could in principle edit it; it would be annoying to do
    // anything about this possibility, so it is deliberately ignored.
    let pq_codes = unsafe {
        MmapOptions::new().populate().map_copy_read_only(&pq_codes_file)?
    };

    // contains metadata descriptors
    let descriptors_file = std::fs::File::open(index_path.join("index.descriptor-codes.bin"))?;
    // SAFETY: same caveat as for the PQ codes mapping above.
    let descriptors = unsafe {
        MmapOptions::new().populate().map_copy_read_only(&descriptors_file)?
    };

    let guards = if args.lock_memory {
        let g1 = region::lock(descriptors.as_ptr(), descriptors.len())?;
        let g2 = region::lock(pq_codes.as_ptr(), pq_codes.len())?;
        vec![g1, g2]
    } else {
        // Nothing to lock. (A stray `evaluate(&args, index).await?;` left over
        // from the old `main` used to sit here; it cannot compile in this
        // non-async function and referenced an undefined `index`.)
        vec![]
    };

    Ok(MemoryMaps { pq_codes, descriptors, guards })
}
fn main() -> Result<()> {
let args: CLIArguments = argh::from_env();
let maps = Arc::new(initialize_memory_maps(&args)?);
let args = Arc::new(args);
if args.config_path.is_some() {
let mut join_handles = vec![];
for _ in 0..num_cpus::get() {
let args_ = args.clone();
let maps_ = maps.clone();
let handle = std::thread::spawn(move || {
let mut rt = monoio::RuntimeBuilder::<monoio::IoUringDriver>::new().enable_timer().build().unwrap();
let index = rt.block_on(initialize_index(args_.clone(), maps_))?;
rt.block_on(serve(args_, index))
});
join_handles.push(handle);
}
for handle in join_handles {
handle.join().unwrap()?;
}
} else {
evaluate(args, maps)?;
}
Ok(())