diff --git a/diskann/Cargo.lock b/diskann/Cargo.lock index 370cfdc..64dde60 100644 --- a/diskann/Cargo.lock +++ b/diskann/Cargo.lock @@ -26,18 +26,6 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" -[[package]] -name = "bitvec" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" -dependencies = [ - "funty", - "radium", - "tap", - "wyz", -] - [[package]] name = "bytemuck" version = "1.20.0" @@ -140,13 +128,11 @@ name = "diskann" version = "0.1.0" dependencies = [ "anyhow", - "bitvec", "bytemuck", "fastrand", "foldhash", "half", "matrixmultiply", - "min-max-heap", "rayon", "rmp-serde", "serde", @@ -174,12 +160,6 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" -[[package]] -name = "funty" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" - [[package]] name = "half" version = "2.4.1" @@ -229,12 +209,6 @@ dependencies = [ "rawpointer", ] -[[package]] -name = "min-max-heap" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2687e6cf9c00f48e9284cf9fd15f2ef341d03cc7743abf9df4c5f07fdee50b18" - [[package]] name = "mio" version = "0.8.11" @@ -331,12 +305,6 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "radium" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" - [[package]] name = "rawpointer" version = "0.2.1" @@ -491,12 +459,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "tap" -version = "1.0.1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" - [[package]] name = "thread_local" version = "1.1.8" @@ -744,12 +706,3 @@ name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" - -[[package]] -name = "wyz" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" -dependencies = [ - "tap", -] diff --git a/diskann/src/main.rs b/diskann/src/main.rs index 90a4eb8..cff30c2 100644 --- a/diskann/src/main.rs +++ b/diskann/src/main.rs @@ -70,7 +70,10 @@ fn main() -> Result<()> { l: 192, maxc: 750, alpha: 65200, - saturate_graph: false + saturate_graph: false, + query_breakpoint: vecs.len() as u32, + query_alpha: 65200, + max_add_per_stitch_iter: 0 }; let mut graph = IndexGraph::empty(vecs.len(), config.r); @@ -105,13 +108,16 @@ fn main() -> Result<()> { l: 200, alpha: 65536, maxc: 0, - saturate_graph: false + saturate_graph: false, + query_breakpoint: vecs.len() as u32, + query_alpha: 65200, + max_add_per_stitch_iter: 0 }; let mut scratch = Scratch::new(config); for (i, vec) in tqdm::tqdm(vecs.iter().enumerate()) { - let ctr = greedy_search(&mut scratch, medioid, &vec, &vecs, &graph, config); + let ctr = greedy_search(&mut scratch, medioid, false, &vec, &vecs, &graph, config); cmps_ctr += ctr.distances; cmps.push(ctr.distances); if scratch.neighbour_buffer.ids[0] == (i as u32) { diff --git a/diskann/src/vector.rs b/diskann/src/vector.rs index df98ace..b85ae80 100644 --- a/diskann/src/vector.rs +++ b/diskann/src/vector.rs @@ -1,5 +1,3 @@ -use core::f32; - use half::f16; use simsimd::SpatialSimilarity; use fastrand::Rng; @@ -420,6 +418,7 @@ pub fn scale_dot_result_f64(x: f64) -> i64 { #[cfg(test)] mod bench { use super::*; + use 
half::slice::HalfFloatSliceExt;
     use test::Bencher;
 
     #[bench]
@@ -451,4 +450,27 @@ mod bench {
             fast_dot_noprefetch(&a, &b)
         });
     }
+
+    #[bench]
+    fn bench_preprocess_query(be: &mut Bencher) {
+        let mut rng = fastrand::Rng::with_seed(1);
+        let pq = rmp_serde::from_slice::(&std::fs::read("opq.msgpack").unwrap()).unwrap();
+        let query = Vector::randn(&mut rng, pq.n_dims).to_f32_vec();
+        be.iter(|| {
+            pq.preprocess_query(&query)
+        });
+    }
+
+    #[bench]
+    fn bench_asymmetric_dot_product(be: &mut Bencher) {
+        let mut rng = fastrand::Rng::with_seed(1);
+        let pq = rmp_serde::from_slice::(&std::fs::read("opq.msgpack").unwrap()).unwrap();
+        let query = Vector::randn(&mut rng, pq.n_dims).to_f32_vec();
+        let lut = pq.preprocess_query(&query);
+        let mut pq_vectors = vec![0; 100 * pq.n_dims / pq.n_dims_per_code];
+        rng.fill(&mut pq_vectors);
+        be.iter(|| {
+            pq.asymmetric_dot_product(&lut, &pq_vectors)
+        });
+    }
 }
diff --git a/perf_test.py b/perf_test.py
new file mode 100644
index 0000000..ceb6be7
--- /dev/null
+++ b/perf_test.py
@@ -0,0 +1,32 @@
+import numpy as np
+import aiohttp
+import asyncio
+import sys
+
+queries = np.random.randn(1000, 1152)
+
+async def main():
+    async with aiohttp.ClientSession() as sess:
+        async with asyncio.TaskGroup() as tg:
+            sem = asyncio.Semaphore(100)
+            async def lookup(embedding):
+                async with sess.post("http://localhost:5601", json={
+                    "terms": [{ "embedding": embedding.tolist() }],
+                    "k": 10
+                }) as res:
+                    sys.stdout.write(".")
+                    sys.stdout.flush()
+                    return (await res.json())["matches"]
+
+            async def dispatch(i):
+                # hand the concurrency slot back even when the request raises
+                try:
+                    await lookup(queries[i])
+                finally:
+                    sem.release()
+
+            for i in range(len(queries)):
+                await sem.acquire()
+                tg.create_task(dispatch(i))
+
+asyncio.run(main())
diff --git a/src/query_disk_index.rs b/src/query_disk_index.rs
index 65c137a..a9e31b6 100644
--- a/src/query_disk_index.rs
+++ b/src/query_disk_index.rs
@@ -22,6 +22,7 @@ use serde::{Serialize, Deserialize};
 use std::str::FromStr;
 use std::collections::HashMap;
 use 
std::io::Write; +use std::sync::Arc; mod common; @@ -97,7 +98,7 @@ const DUPLICATES_THRESHOLD: f32 = 0.95; fn read_pq_codes(id: u32, index: Rc, buf: &mut Vec) { let loc = (id as usize) * index.pq_code_size; - buf.extend(&index.pq_codes[loc..loc+index.pq_code_size]) + buf.extend(&index.memory_maps.pq_codes[loc..loc+index.pq_code_size]) } struct VisitedNode { @@ -121,11 +122,10 @@ struct Scratch { struct Index { data_file: fs::File, - pq_codes: Mmap, header: Rc, pq_code_size: usize, - descriptors: Mmap, - n_descriptors: usize + n_descriptors: usize, + memory_maps: Arc } struct DescriptorScales(Vec); @@ -134,7 +134,7 @@ fn descriptor_product(index: Rc, scales: &DescriptorScales, neighbour: u3 let mut result = 0; // effectively an extra part of the vector to dot product for (j, d) in scales.0.iter().enumerate() { - result += scale_dot_result(d * index.descriptors[neighbour as usize * index.n_descriptors + j] as f32); + result += scale_dot_result(d * index.memory_maps.descriptors[neighbour as usize * index.n_descriptors + j] as f32); } result } @@ -220,7 +220,9 @@ fn summary_stats(ranks: &mut [usize]) { const K: usize = 20; -async fn evaluate(args: &CLIArguments, index: Rc) -> Result<()> { +#[monoio::main(threads=1)] +async fn evaluate(args: Arc, memory_maps: Arc) -> Result<()> { + let index = initialize_index(args.clone(), memory_maps).await?; let mut top_k_ranks_best_shard = vec![]; let mut top_rank_best_shard = vec![]; let mut pq_cmps = vec![]; @@ -384,7 +386,7 @@ struct TelemetryMessage { event: String, #[serde(rename="instanceId")] instance_id: String, - page: String + page: Option } #[derive(Clone)] @@ -610,7 +612,7 @@ fn telemetry_handler(rx: std::sync::mpsc::Receiver, config: Se Ok(()) } -async fn serve(args: &CLIArguments, index: Rc) -> Result<()> { +async fn serve(args: Arc, index: Rc) -> Result<()> { let config: ServerConfig = serde_json::from_slice(&std::fs::read(args.config_path.as_ref().unwrap())?)?; let (telemetry_channel, telemetry_receiver) = 
std::sync::mpsc::channel();
@@ -645,51 +647,85 @@ async fn serve(args: &CLIArguments, index: Rc) -> Result<()> {
     }
 }
 
-#[monoio::main(threads=1, enable_timer=true)]
-async fn main() -> Result<()> {
-    let args: CLIArguments = argh::from_env();
+struct MemoryMaps {
+    // Declared before the maps: fields drop in declaration order, so the pages are
+    // unlocked while they are still mapped instead of after they have been unmapped.
+    guards: Vec<region::LockGuard>,
+    pq_codes: memmap2::Mmap,
+    descriptors: memmap2::Mmap
+}
 
+async fn initialize_index(args: Arc<CLIArguments>, memory_maps: Arc<MemoryMaps>) -> Result<Rc<Index>> {
     let index_path = PathBuf::from(&args.index_path);
     let header: IndexHeader = rmp_serde::from_slice(&fs::read(index_path.join("index.msgpack")).await?)?;
     let header = Rc::new(header);
     // contains graph structure, full-precision vectors, and bulk metadata
     let data_file = fs::File::open(index_path.join("index.bin")).await?;
-    // contains product quantization codes
-    let pq_codes_file = fs::File::open(index_path.join("index.pq-codes.bin")).await?;
-    let pq_codes = unsafe {
-        // This is unsafe because other processes could in principle edit the mmap'd file.
-        // It would be annoying to do anything about this possibility, so ignore it.
-        MmapOptions::new().populate().map_copy_read_only(&pq_codes_file)?
-    };
-    // contains metadata descriptors
-    let descriptors_file = fs::File::open(index_path.join("index.descriptor-codes.bin")).await?;
-    let descriptors = unsafe {
-        MmapOptions::new().populate().map_copy_read_only(&descriptors_file)?
-    };
 
-    let _guards = if args.lock_memory {
-        let g1 = region::lock(descriptors.as_ptr(), descriptors.len())?;
-        let g2 = region::lock(pq_codes.as_ptr(), pq_codes.len())?;
-        Some((g1, g2))
-    } else {
-        None
-    };
 
     println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
 
     let index = Rc::new(Index {
         data_file,
         header: header.clone(),
-        pq_codes,
         pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
-        descriptors,
         n_descriptors: header.descriptor_cdfs.len(),
+        memory_maps
     });
 
-    if args.config_path.is_some() {
-        serve(&args, index).await?;
+    Ok(index)
+}
+
+fn initialize_memory_maps(args: &CLIArguments) -> Result<MemoryMaps> {
+    let index_path = PathBuf::from(&args.index_path);
+    // contains product quantization codes
+    let pq_codes_file = std::fs::File::open(index_path.join("index.pq-codes.bin"))?;
+    let pq_codes = unsafe {
+        // This is unsafe because other processes could in principle edit the mmap'd file.
+        // It would be annoying to do anything about this possibility, so ignore it.
+        MmapOptions::new().populate().map_copy_read_only(&pq_codes_file)?
+    };
+    // contains metadata descriptors
+    let descriptors_file = std::fs::File::open(index_path.join("index.descriptor-codes.bin"))?;
+    let descriptors = unsafe {
+        MmapOptions::new().populate().map_copy_read_only(&descriptors_file)?
+    };
+
+    let guards = if args.lock_memory {
+        let g1 = region::lock(descriptors.as_ptr(), descriptors.len())?;
+        let g2 = region::lock(pq_codes.as_ptr(), pq_codes.len())?;
+        vec![g1, g2]
     } else {
-        evaluate(&args, index).await?;
+        vec![]
+    };
+
+    Ok(MemoryMaps { pq_codes, descriptors, guards })
+}
+
+fn main() -> Result<()> {
+    let args: CLIArguments = argh::from_env();
+
+    let maps = Arc::new(initialize_memory_maps(&args)?);
+
+    let args = Arc::new(args);
+
+    if args.config_path.is_some() {
+        let mut join_handles = vec![];
+        for _ in 0..num_cpus::get() {
+            let args_ = args.clone();
+            let maps_ = maps.clone();
+            let handle = std::thread::spawn(move || {
+                let mut rt = monoio::RuntimeBuilder::::new().enable_timer().build().unwrap();
+                let index = rt.block_on(initialize_index(args_.clone(), maps_))?;
+                rt.block_on(serve(args_, index))
+            });
+            join_handles.push(handle);
+        }
+        for handle in join_handles {
+            handle.join().unwrap()?;
+        }
+    } else {
+        evaluate(args, maps)?;
     }
 
     Ok(())