mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-09-11 15:26:02 +00:00
Multithread query server
While profiling suggests that most operations are cheap and IO-bound rather than CPU-bound, the GEMM for deduplication is pretty slow. As such, use multiple threads for higher throughput.
This commit is contained in:
47
diskann/Cargo.lock
generated
47
diskann/Cargo.lock
generated
@@ -26,18 +26,6 @@ version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
|
||||
|
||||
[[package]]
|
||||
name = "bitvec"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
|
||||
dependencies = [
|
||||
"funty",
|
||||
"radium",
|
||||
"tap",
|
||||
"wyz",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.20.0"
|
||||
@@ -140,13 +128,11 @@ name = "diskann"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bitvec",
|
||||
"bytemuck",
|
||||
"fastrand",
|
||||
"foldhash",
|
||||
"half",
|
||||
"matrixmultiply",
|
||||
"min-max-heap",
|
||||
"rayon",
|
||||
"rmp-serde",
|
||||
"serde",
|
||||
@@ -174,12 +160,6 @@ version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
|
||||
|
||||
[[package]]
|
||||
name = "funty"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.4.1"
|
||||
@@ -229,12 +209,6 @@ dependencies = [
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "min-max-heap"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2687e6cf9c00f48e9284cf9fd15f2ef341d03cc7743abf9df4c5f07fdee50b18"
|
||||
|
||||
[[package]]
|
||||
name = "mio"
|
||||
version = "0.8.11"
|
||||
@@ -331,12 +305,6 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "radium"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
|
||||
|
||||
[[package]]
|
||||
name = "rawpointer"
|
||||
version = "0.2.1"
|
||||
@@ -491,12 +459,6 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tap"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
|
||||
|
||||
[[package]]
|
||||
name = "thread_local"
|
||||
version = "1.1.8"
|
||||
@@ -744,12 +706,3 @@ name = "windows_x86_64_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||
|
||||
[[package]]
|
||||
name = "wyz"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
|
||||
dependencies = [
|
||||
"tap",
|
||||
]
|
||||
|
@@ -70,7 +70,10 @@ fn main() -> Result<()> {
|
||||
l: 192,
|
||||
maxc: 750,
|
||||
alpha: 65200,
|
||||
saturate_graph: false
|
||||
saturate_graph: false,
|
||||
query_breakpoint: vecs.len() as u32,
|
||||
query_alpha: 65200,
|
||||
max_add_per_stitch_iter: 0
|
||||
};
|
||||
|
||||
let mut graph = IndexGraph::empty(vecs.len(), config.r);
|
||||
@@ -105,13 +108,16 @@ fn main() -> Result<()> {
|
||||
l: 200,
|
||||
alpha: 65536,
|
||||
maxc: 0,
|
||||
saturate_graph: false
|
||||
saturate_graph: false,
|
||||
query_breakpoint: vecs.len() as u32,
|
||||
query_alpha: 65200,
|
||||
max_add_per_stitch_iter: 0
|
||||
};
|
||||
|
||||
let mut scratch = Scratch::new(config);
|
||||
|
||||
for (i, vec) in tqdm::tqdm(vecs.iter().enumerate()) {
|
||||
let ctr = greedy_search(&mut scratch, medioid, &vec, &vecs, &graph, config);
|
||||
let ctr = greedy_search(&mut scratch, medioid, false, &vec, &vecs, &graph, config);
|
||||
cmps_ctr += ctr.distances;
|
||||
cmps.push(ctr.distances);
|
||||
if scratch.neighbour_buffer.ids[0] == (i as u32) {
|
||||
|
@@ -1,5 +1,3 @@
|
||||
use core::f32;
|
||||
|
||||
use half::f16;
|
||||
use simsimd::SpatialSimilarity;
|
||||
use fastrand::Rng;
|
||||
@@ -420,6 +418,7 @@ pub fn scale_dot_result_f64(x: f64) -> i64 {
|
||||
#[cfg(test)]
|
||||
mod bench {
|
||||
use super::*;
|
||||
use half::slice::HalfFloatSliceExt;
|
||||
use test::Bencher;
|
||||
|
||||
#[bench]
|
||||
@@ -451,4 +450,27 @@ mod bench {
|
||||
fast_dot_noprefetch(&a, &b)
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_preprocess_query(be: &mut Bencher) {
|
||||
let mut rng = fastrand::Rng::with_seed(1);
|
||||
let pq = rmp_serde::from_slice::<ProductQuantizer>(&std::fs::read("opq.msgpack").unwrap()).unwrap();
|
||||
let query = Vector::randn(&mut rng, pq.n_dims).to_f32_vec();
|
||||
be.iter(|| {
|
||||
pq.preprocess_query(&query)
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_asymmetric_dot_product(be: &mut Bencher) {
|
||||
let mut rng = fastrand::Rng::with_seed(1);
|
||||
let pq = rmp_serde::from_slice::<ProductQuantizer>(&std::fs::read("opq.msgpack").unwrap()).unwrap();
|
||||
let query = Vector::randn(&mut rng, pq.n_dims).to_f32_vec();
|
||||
let lut = pq.preprocess_query(&query);
|
||||
let mut pq_vectors = vec![0; 100 * pq.n_dims / pq.n_dims_per_code];
|
||||
rng.fill(&mut pq_vectors);
|
||||
be.iter(|| {
|
||||
pq.asymmetric_dot_product(&lut, &pq_vectors)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
29
perf_test.py
Normal file
29
perf_test.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import numpy as np
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
queries = np.random.randn(1000, 1152)
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
sem = asyncio.Semaphore(100)
|
||||
async def lookup(embedding):
|
||||
async with sess.post("http://localhost:5601", json={
|
||||
"terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
|
||||
"k": 10
|
||||
}) as res:
|
||||
sys.stdout.write(".")
|
||||
sys.stdout.flush()
|
||||
return (await res.json())["matches"]
|
||||
|
||||
async def dispatch(i):
|
||||
await lookup(queries[i])
|
||||
sem.release()
|
||||
|
||||
for i in range(1000):
|
||||
await sem.acquire()
|
||||
tg.create_task(dispatch(i))
|
||||
|
||||
asyncio.run(main())
|
@@ -22,6 +22,7 @@ use serde::{Serialize, Deserialize};
|
||||
use std::str::FromStr;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
mod common;
|
||||
|
||||
@@ -97,7 +98,7 @@ const DUPLICATES_THRESHOLD: f32 = 0.95;
|
||||
|
||||
fn read_pq_codes(id: u32, index: Rc<Index>, buf: &mut Vec<u8>) {
|
||||
let loc = (id as usize) * index.pq_code_size;
|
||||
buf.extend(&index.pq_codes[loc..loc+index.pq_code_size])
|
||||
buf.extend(&index.memory_maps.pq_codes[loc..loc+index.pq_code_size])
|
||||
}
|
||||
|
||||
struct VisitedNode {
|
||||
@@ -121,11 +122,10 @@ struct Scratch {
|
||||
|
||||
struct Index {
|
||||
data_file: fs::File,
|
||||
pq_codes: Mmap,
|
||||
header: Rc<IndexHeader>,
|
||||
pq_code_size: usize,
|
||||
descriptors: Mmap,
|
||||
n_descriptors: usize
|
||||
n_descriptors: usize,
|
||||
memory_maps: Arc<MemoryMaps>
|
||||
}
|
||||
|
||||
struct DescriptorScales(Vec<f32>);
|
||||
@@ -134,7 +134,7 @@ fn descriptor_product(index: Rc<Index>, scales: &DescriptorScales, neighbour: u3
|
||||
let mut result = 0;
|
||||
// effectively an extra part of the vector to dot product
|
||||
for (j, d) in scales.0.iter().enumerate() {
|
||||
result += scale_dot_result(d * index.descriptors[neighbour as usize * index.n_descriptors + j] as f32);
|
||||
result += scale_dot_result(d * index.memory_maps.descriptors[neighbour as usize * index.n_descriptors + j] as f32);
|
||||
}
|
||||
result
|
||||
}
|
||||
@@ -220,7 +220,9 @@ fn summary_stats(ranks: &mut [usize]) {
|
||||
|
||||
const K: usize = 20;
|
||||
|
||||
async fn evaluate(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
|
||||
#[monoio::main(threads=1)]
|
||||
async fn evaluate(args: Arc<CLIArguments>, memory_maps: Arc<MemoryMaps>) -> Result<()> {
|
||||
let index = initialize_index(args.clone(), memory_maps).await?;
|
||||
let mut top_k_ranks_best_shard = vec![];
|
||||
let mut top_rank_best_shard = vec![];
|
||||
let mut pq_cmps = vec![];
|
||||
@@ -384,7 +386,7 @@ struct TelemetryMessage {
|
||||
event: String,
|
||||
#[serde(rename="instanceId")]
|
||||
instance_id: String,
|
||||
page: String
|
||||
page: Option<String>
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -610,7 +612,7 @@ fn telemetry_handler(rx: std::sync::mpsc::Receiver<TelemetryMessage>, config: Se
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn serve(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
|
||||
async fn serve(args: Arc<CLIArguments>, index: Rc<Index>) -> Result<()> {
|
||||
let config: ServerConfig = serde_json::from_slice(&std::fs::read(args.config_path.as_ref().unwrap())?)?;
|
||||
|
||||
let (telemetry_channel, telemetry_receiver) = std::sync::mpsc::channel();
|
||||
@@ -645,51 +647,83 @@ async fn serve(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
#[monoio::main(threads=1, enable_timer=true)]
|
||||
async fn main() -> Result<()> {
|
||||
let args: CLIArguments = argh::from_env();
|
||||
struct MemoryMaps {
|
||||
pq_codes: memmap2::Mmap,
|
||||
descriptors: memmap2::Mmap,
|
||||
guards: Vec<region::LockGuard>
|
||||
}
|
||||
|
||||
async fn initialize_index(args: Arc<CLIArguments>, memory_maps: Arc<MemoryMaps>) -> Result<Rc<Index>> {
|
||||
let index_path = PathBuf::from(&args.index_path);
|
||||
let header: IndexHeader = rmp_serde::from_slice(&fs::read(index_path.join("index.msgpack")).await?)?;
|
||||
let header = Rc::new(header);
|
||||
// contains graph structure, full-precision vectors, and bulk metadata
|
||||
let data_file = fs::File::open(index_path.join("index.bin")).await?;
|
||||
// contains product quantization codes
|
||||
let pq_codes_file = fs::File::open(index_path.join("index.pq-codes.bin")).await?;
|
||||
let pq_codes = unsafe {
|
||||
// This is unsafe because other processes could in principle edit the mmap'd file.
|
||||
// It would be annoying to do anything about this possibility, so ignore it.
|
||||
MmapOptions::new().populate().map_copy_read_only(&pq_codes_file)?
|
||||
};
|
||||
// contains metadata descriptors
|
||||
let descriptors_file = fs::File::open(index_path.join("index.descriptor-codes.bin")).await?;
|
||||
let descriptors = unsafe {
|
||||
MmapOptions::new().populate().map_copy_read_only(&descriptors_file)?
|
||||
};
|
||||
|
||||
let _guards = if args.lock_memory {
|
||||
let g1 = region::lock(descriptors.as_ptr(), descriptors.len())?;
|
||||
let g2 = region::lock(pq_codes.as_ptr(), pq_codes.len())?;
|
||||
Some((g1, g2))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
|
||||
|
||||
let index = Rc::new(Index {
|
||||
data_file,
|
||||
header: header.clone(),
|
||||
pq_codes,
|
||||
pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
|
||||
descriptors,
|
||||
n_descriptors: header.descriptor_cdfs.len(),
|
||||
memory_maps
|
||||
});
|
||||
|
||||
if args.config_path.is_some() {
|
||||
serve(&args, index).await?;
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
fn initialize_memory_maps(args: &CLIArguments) -> Result<MemoryMaps> {
|
||||
let index_path = PathBuf::from(&args.index_path);
|
||||
let pq_codes_file = std::fs::File::open(index_path.join("index.pq-codes.bin"))?;
|
||||
let pq_codes = unsafe {
|
||||
// This is unsafe because other processes could in principle edit the mmap'd file.
|
||||
// It would be annoying to do anything about this possibility, so ignore it.
|
||||
MmapOptions::new().populate().map_copy_read_only(&pq_codes_file)?
|
||||
};
|
||||
// contains metadata descriptors
|
||||
let descriptors_file = std::fs::File::open(index_path.join("index.descriptor-codes.bin"))?;
|
||||
let descriptors = unsafe {
|
||||
MmapOptions::new().populate().map_copy_read_only(&descriptors_file)?
|
||||
};
|
||||
|
||||
let guards = if args.lock_memory {
|
||||
let g1 = region::lock(descriptors.as_ptr(), descriptors.len())?;
|
||||
let g2 = region::lock(pq_codes.as_ptr(), pq_codes.len())?;
|
||||
vec![g1, g2]
|
||||
} else {
|
||||
evaluate(&args, index).await?;
|
||||
vec![]
|
||||
};
|
||||
|
||||
Ok(MemoryMaps { pq_codes, descriptors, guards })
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args: CLIArguments = argh::from_env();
|
||||
|
||||
let maps = Arc::new(initialize_memory_maps(&args)?);
|
||||
|
||||
let args = Arc::new(args);
|
||||
|
||||
if args.config_path.is_some() {
|
||||
let mut join_handles = vec![];
|
||||
for _ in 0..num_cpus::get() {
|
||||
let args_ = args.clone();
|
||||
let maps_ = maps.clone();
|
||||
let handle = std::thread::spawn(move || {
|
||||
let mut rt = monoio::RuntimeBuilder::<monoio::IoUringDriver>::new().enable_timer().build().unwrap();
|
||||
let index = rt.block_on(initialize_index(args_.clone(), maps_))?;
|
||||
rt.block_on(serve(args_, index))
|
||||
});
|
||||
join_handles.push(handle);
|
||||
}
|
||||
for handle in join_handles {
|
||||
handle.join().unwrap()?;
|
||||
}
|
||||
} else {
|
||||
evaluate(args, maps)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
Reference in New Issue
Block a user