diff --git a/diskann/Cargo.lock b/diskann/Cargo.lock
index f9a41da..370cfdc 100644
--- a/diskann/Cargo.lock
+++ b/diskann/Cargo.lock
@@ -146,6 +146,7 @@ dependencies = [
  "foldhash",
  "half",
  "matrixmultiply",
+ "min-max-heap",
  "rayon",
  "rmp-serde",
  "serde",
@@ -228,6 +229,12 @@ dependencies = [
  "rawpointer",
 ]
 
+[[package]]
+name = "min-max-heap"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2687e6cf9c00f48e9284cf9fd15f2ef341d03cc7743abf9df4c5f07fdee50b18"
+
 [[package]]
 name = "mio"
 version = "0.8.11"
diff --git a/diskann/src/lib.rs b/diskann/src/lib.rs
index f795747..9fd7797 100644
--- a/diskann/src/lib.rs
+++ b/diskann/src/lib.rs
@@ -19,20 +19,6 @@ pub struct IndexGraph {
 }
 
 impl IndexGraph {
-    pub fn random_r_regular(rng: &mut Rng, n: usize, r: usize, capacity: usize) -> Self {
-        let mut graph = Vec::with_capacity(n);
-        for _ in 0..n {
-            let mut adjacency = Vec::with_capacity(capacity);
-            for _ in 0..r {
-                adjacency.push(rng.u32(0..(n as u32)));
-            }
-            graph.push(RwLock::new(adjacency));
-        }
-        IndexGraph {
-            graph
-        }
-    }
-
     pub fn empty(n: usize, capacity: usize) -> IndexGraph {
         let mut graph = Vec::with_capacity(n);
         for _ in 0..n {
@@ -57,7 +43,8 @@ pub struct IndexBuildConfig {
     pub r: usize,
     pub l: usize,
     pub maxc: usize,
-    pub alpha: i64
+    pub alpha: i64,
+    pub saturate_graph: bool
 }
 
 
@@ -79,6 +66,7 @@ pub fn medioid(vecs: &VectorList) -> u32 {
 
 // neighbours list sorted by score descending
 // TODO: this may actually be an awful datastructure
+// we could also have a heap of unvisited things, but the algorithm's stopping condition cares about visited things, and this is probably still the easiest way
 #[derive(Clone, Debug)]
 pub struct NeighbourBuffer {
     pub ids: Vec<u32>,
@@ -244,8 +232,9 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
     let mut candidate_index = 0;
     while neigh.len() < config.r && candidate_index < candidates.len() {
         let p_star = candidates[candidate_index].0;
+        let p_star_score = candidates[candidate_index].1;
         candidate_index += 1;
-        if p_star == p || p_star == u32::MAX {
+        if p_star == p || p_star_score == i64::MIN {
             continue;
         }
 
@@ -256,7 +245,7 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
         // mark remaining candidates as not-to-be-used if "not much better than" current candidate
         for i in (candidate_index+1)..candidates.len() {
             let p_prime = candidates[i].0;
-            if p_prime != u32::MAX {
+            if candidates[i].1 != i64::MIN {
                 scratch.robust_prune_scratch_buffer.push((i, p_prime));
             }
         }
@@ -268,7 +257,18 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
             let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16;
 
             if alpha_times_p_star_prime_score >= p_prime_p_score {
-                candidates[ci].0 = u32::MAX;
+                candidates[ci].1 = i64::MIN;
+            }
+        }
+    }
+
+    if config.saturate_graph {
+        for &(id, _score) in candidates.iter() {
+            if neigh.len() == config.r {
+                return;
+            }
+            if !neigh.contains(&id) {
+                neigh.push(id);
             }
         }
     }
@@ -313,38 +313,7 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V
     });
 }
 
-// RoarGraph's AcquireNeighbours algorithm is actually almost identical to Vamana/DiskANN's RobustPrune, but with fixed α = 1.0.
-// We replace Vamana's random initialization of the graph with Neighbourhood-Aware Projection from RoarGraph - there's no way to use a large enough
-// query set that I would be confident in using *only* RoarGraph's algorithm
-pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec<Vec<u32>>, query_knns_bwd: &Vec<Vec<u32>>, config: IndexBuildConfig, vecs: &VectorList) {
-    let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
-    rng.shuffle(&mut sigmas);
-
-    // Iterate through graph vertices in a random order
-    let rng = Mutex::new(rng.fork());
-    sigmas.into_par_iter().for_each_init(|| (rng.lock().unwrap().fork(), Scratch::new(config)), |(rng, scratch), sigma_i| {
-        scratch.visited.clear();
-        scratch.visited_list.clear();
-        scratch.neighbour_pre_buffer.clear();
-        for &query_neighbour in query_knns[sigma_i as usize].iter() {
-            for &projected_neighbour in query_knns_bwd[query_neighbour as usize].iter() {
-                if scratch.visited.insert(projected_neighbour) {
-                    scratch.neighbour_pre_buffer.push(projected_neighbour);
-                }
-            }
-        }
-        rng.shuffle(&mut scratch.neighbour_pre_buffer);
-        scratch.neighbour_pre_buffer.truncate(config.maxc * 2);
-        for (i, &projected_neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
-            let score = fast_dot(&vecs[sigma_i as usize], &vecs[projected_neighbour as usize], &vecs[scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()] as usize]);
-            scratch.visited_list.push((projected_neighbour, score));
-        }
-        let mut neighbours = graph.out_neighbours_mut(sigma_i);
-        robust_prune(scratch, sigma_i, &mut *neighbours, vecs, config);
-    })
-}
-
-pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig) {
+pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig, max_iters: usize) {
     let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
     rng.shuffle(&mut sigmas);
 
@@ -353,7 +322,7 @@ pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<
     sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| {
         let mut neighbours = graph.out_neighbours_mut(sigma_i);
         let mut i = 0;
-        while neighbours.len() < config.r && i < 100 {
+        while neighbours.len() < config.r && i < max_iters {
             let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap();
             let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap();
             if !neighbours.contains(&projected_neighbour) {
diff --git a/diskann/src/main.rs b/diskann/src/main.rs
index ba3208a..90a4eb8 100644
--- a/diskann/src/main.rs
+++ b/diskann/src/main.rs
@@ -52,7 +52,14 @@ fn main() -> Result<()> {
 
     let vecs = {
         let _timer = Timer::new("loaded vectors");
-        &load_file("query.bin", Some(D_EMB * n))?
+        &load_file("real.bin", None)?
+    };
+
+    println!("{} vectors", vecs.len());
+
+    let queries = {
+        let _timer = Timer::new("loaded queries");
+        &load_file("../query5.bin", None)?
     };
 
     let (graph, medioid) = {
@@ -63,9 +70,12 @@ fn main() -> Result<()> {
             l: 192,
             maxc: 750,
             alpha: 65200,
+            saturate_graph: false
         };
 
-        let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap);
+        let mut graph = IndexGraph::empty(vecs.len(), config.r);
+
+        random_fill_graph(&mut rng, &mut graph, config.r);
 
         let medioid = medioid(&vecs);
 
@@ -92,9 +102,10 @@ fn main() -> Result<()> {
 
     let mut config = IndexBuildConfig {
         r: 64,
-        l: 50,
+        l: 200,
         alpha: 65536,
         maxc: 0,
+        saturate_graph: false
     };
 
     let mut scratch = Scratch::new(config);
@@ -112,8 +123,8 @@ fn main() -> Result<()> {
 
     let end = time.elapsed();
 
-    println!("recall@1: {} ({}/{})", recall as f32 / n as f32, recall, n);
-    println!("cmps: {} ({}/{})", cmps_ctr as f32 / n as f32, cmps_ctr, n);
+    println!("recall@1: {} ({}/{})", recall as f32 / vecs.len() as f32, recall, vecs.len());
+    println!("cmps: {} ({}/{})", cmps_ctr as f32 / vecs.len() as f32, cmps_ctr, vecs.len());
     println!("median comparisons: {}", cmps[cmps.len() / 2]);
     //println!("brute force recall@1: {} ({}/{})", brute_force_recall as f32 / brute_force_queries as f32, brute_force_recall, brute_force_queries);
     println!("{} QPS", n as f32 / end.as_secs_f32());
diff --git a/src/dump_processor.rs b/src/dump_processor.rs
index 11abf00..5c9a99e 100644
--- a/src/dump_processor.rs
+++ b/src/dump_processor.rs
@@ -14,6 +14,7 @@ use itertools::Itertools;
 use simsimd::SpatialSimilarity;
 use std::hash::Hasher;
 use foldhash::{HashSet, HashSetExt};
+use std::os::unix::prelude::FileExt;
 
 use diskann::vector::{scale_dot_result_f64, ProductQuantizer};
 
@@ -161,15 +162,29 @@ fn main() -> Result<()> {
     let (mut queries_index, max_query_id) = if let Some(queries_file) = args.queries {
         println!("constructing index");
         // not memory-efficient but this is small
-        let data = fs::read(queries_file).context("read queries file")?;
+        let mut file = fs::File::open(queries_file).context("read queries file")?;
+        let mut size = file.metadata()?.len();
         //let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?;
-        let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?;
+        let mut index = faiss::index_factory(D_EMB, "HNSW64,SQ8", faiss::MetricType::InnerProduct)?;
         //let mut index = faiss::index_factory(D_EMB, "IVF4096,SQfp16", faiss::MetricType::InnerProduct)?;
-        let unpacked = common::decode_fp16_buffer(&data);
-        index.train(&unpacked)?;
-        index.add(&unpacked)?;
+        let mut buf = vec![0; (D_EMB as usize) * (1<<18)];
+        loop {
+            if size == 0 {
+                break;
+            }
+            if size < (buf.len() as u64) {
+                buf.resize(size as usize, 0);
+            }
+            file.read_exact(&mut buf)?;
+            size -= buf.len() as u64;
+            let unpacked = common::decode_fp16_buffer(&buf);
+            if !index.is_trained() { index.train(&unpacked)?; print!("train"); }
+            index.add(&unpacked)?;
+            print!(".");
+        }
         println!("done");
-        (Some(index), unpacked.len() / D_EMB as usize)
+        let ntotal = index.ntotal();
+        (Some(index), ntotal as usize)
     } else {
         (None, 0)
     };
@@ -267,14 +282,15 @@ fn main() -> Result<()> {
             let shard = shard as usize;
             // this random access is almost certainly rather slow
             // parallelize?
-            files[shard].1.seek(SeekFrom::Start(offset))?;
             let mut buf = vec![0; len as usize];
-            files[shard].1.read_exact(&mut buf)?;
-            let s: &mut [u32] = bytemuck::cast_slice_mut(&mut *buf);
-            for within_shard_id in s.iter_mut() {
-                *within_shard_id = shard_id_mappings[shard].1[*within_shard_id as usize];
+            files[shard].1.read_exact_at(&mut buf, offset)?;
+            let s: &[u32] = bytemuck::cast_slice(&buf);
+            for within_shard_id in s.iter() {
+                let global_id = shard_id_mappings[shard].1[*within_shard_id as usize];
+                if !out_vertices.contains(&global_id) {
+                    out_vertices.push(global_id);
+                }
             }
-            out_vertices.extend(s.iter().unique());
         }
 
         Ok((out_vertices, shards))
@@ -422,7 +438,7 @@ fn main() -> Result<()> {
         let codes = quantizer.quantize_batch(&batch_embeddings);
 
         for (i, (x, _embedding)) in batch.into_iter().enumerate() {
-            let (vertices, shards) = read_out_vertices(count)?; // TODO: could parallelize this given the batching
+            let (vertices, shards) = read_out_vertices(count + i as u32)?; // TODO: could parallelize this given the batching
             let mut entry = PackedIndexEntry {
                 id: count + i as u32,
                 vertices,
diff --git a/src/generate_index_shard.rs b/src/generate_index_shard.rs
index 3c42863..813c947 100644
--- a/src/generate_index_shard.rs
+++ b/src/generate_index_shard.rs
@@ -3,7 +3,7 @@ use itertools::Itertools;
 use std::io::{BufReader, BufWriter, Write};
 use rmp_serde::decode::Error as DecodeError;
 use std::fs;
-use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer, report_degrees};
+use diskann::{augment_bipartite, build_graph, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer, report_degrees, medioid};
 use half::f16;
 
 mod common;
@@ -41,10 +41,11 @@ fn main() -> Result<()> {
     }
 
     let mut config = IndexBuildConfig {
-        r: 40,
-        l: 200,
+        r: 64,
+        l: 192,
         maxc: 750,
-        alpha: 65300
+        alpha: 65200,
+        saturate_graph: false
     };
 
     let vecs = VectorList {
@@ -67,9 +68,7 @@ fn main() -> Result<()> {
 
     report_degrees(&graph);
 
-    let medioid = vecs.iter().position_max_by_key(|&v| {
-        dot(v, &centroid_fp16)
-    }).unwrap() as u32;
+    let medioid = medioid(&vecs);
 
     {
         let _timer = Timer::new("first pass");
@@ -101,7 +100,8 @@ fn main() -> Result<()> {
 
     {
         let _timer = Timer::new("augment bipartite");
-        augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config);
+        //augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config, 50);
+        //random_fill_graph(&mut rng, &mut graph, config.r);
     }
 
     let len = original_ids.len();
diff --git a/src/query_disk_index.rs b/src/query_disk_index.rs
index 6879ed1..b529a5e 100644
--- a/src/query_disk_index.rs
+++ b/src/query_disk_index.rs
@@ -2,6 +2,7 @@ use anyhow::{bail, Context, Result};
 use diskann::vector::scale_dot_result_f64;
 use serde::{Serialize, Deserialize};
 use std::io::{BufReader, Read, Seek, SeekFrom, Write};
+use std::os::unix::prelude::FileExt;
 use std::path::PathBuf;
 use std::fs;
 use base64::Engine;
@@ -37,9 +38,8 @@ struct CLIArguments {
 
 fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> {
     let offset = id as usize * header.record_pad_size;
-    data_file.seek(SeekFrom::Start(offset as u64))?;
     let mut buf = vec![0; header.record_pad_size as usize];
-    data_file.read_exact(&mut buf)?;
+    data_file.read_exact_at(&mut buf, offset as u64)?;
     let len = u16::from_le_bytes(buf[0..2].try_into().unwrap()) as usize;
     Ok(bitcode::decode(&buf[2..len+2])?)
 }
@@ -117,9 +117,11 @@ fn summary_stats(ranks: &mut [usize]) {
     ranks.sort_unstable();
     let median = ranks[ranks.len() / 2] + 1;
     let harmonic_mean = ranks.iter().map(|x| 1.0 / ((x+1) as f64)).sum::<f64>() / ranks.len() as f64;
-    println!("median {} mean {} max {} min {} harmonic mean {}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean);
+    println!("median {} mean {:.2} max {} min {} harmonic mean {:.2}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean);
 }
 
+const K: usize = 20;
+
 fn main() -> Result<()> {
     let args: CLIArguments = argh::from_env();
 
@@ -150,8 +152,11 @@ fn main() -> Result<()> {
 
     println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
 
-    let mut top_20_ranks_best_shard = vec![];
+    let mut top_k_ranks_best_shard = vec![];
     let mut top_rank_best_shard = vec![];
+    let mut pq_cmps = vec![];
+    let mut cmps = vec![];
+    let mut recall_total = 0;
 
     for query_vector in queries.iter() {
         let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
@@ -183,26 +188,30 @@ fn main() -> Result<()> {
             println!("brute force: {} {} {} {:?}", id, distance, url, shards);
         }*/
 
-        let mut top_ranks = vec![usize::MAX; 20];
+        let mut top_ranks = vec![usize::MAX; K];
 
         for shard in 0..header.shards.len() {
            let selected_start = header.shards[shard].1;
 
            let mut scratch = Scratch {
                visited: HashSet::new(),
-               neighbour_buffer: NeighbourBuffer::new(5000),
+               neighbour_buffer: NeighbourBuffer::new(1000),
                neighbour_pre_buffer: Vec::new(),
                visited_list: Vec::new()
            };
 
            //let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng);
-           let cmps = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
+           let cmps_result = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
                data_file: &mut data_file,
                header: &header,
                pq_codes: &pq_codes,
                pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
            }, args.disable_pq)?;
 
+           // slightly dubious because this is across shards
+           pq_cmps.push(cmps_result.1);
+           cmps.push(cmps_result.0);
+
            if args.verbose {
                println!("index scan {}: {:?} cmps", shard, cmps);
            }
@@ -221,14 +230,26 @@ fn main() -> Result<()> {
             if args.verbose { println!("") }
         }
 
+        // results list is always correctly sorted
+        for &rank in top_ranks.iter() {
+            if rank < K {
+                recall_total += 1;
+            }
+        }
+
         top_rank_best_shard.push(top_ranks[0]);
-        top_20_ranks_best_shard.extend(top_ranks);
+        top_k_ranks_best_shard.extend(top_ranks);
     }
 
     println!("ranks of top 20:");
-    summary_stats(&mut top_20_ranks_best_shard);
+    summary_stats(&mut top_k_ranks_best_shard);
     println!("ranks of top 1:");
     summary_stats(&mut top_rank_best_shard);
+    println!("pq comparisons:");
+    summary_stats(&mut pq_cmps);
+    println!("comparisons:");
+    summary_stats(&mut cmps);
+    println!("recall@{}: {}", K, recall_total as f64 / (K * queries.len()) as f64);
 
     Ok(())
 }
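
Note on the robust_prune changes in diskann/src/lib.rs: pruned candidates were previously marked by overwriting their id with u32::MAX; they are now marked by setting their score to i64::MIN instead, which leaves the id usable, and the new saturate_graph option uses exactly those leftover ids to pad each adjacency list back up toward full degree r. A minimal standalone sketch of the resulting control flow, with simplified types and the α-based invalidation elided (prune_sketch is a hypothetical free function, not the crate's API):

    // Sketch only: mirrors the control flow added by this patch, with the
    // distance/α machinery stubbed out. `candidates` is (id, score), sorted
    // by score descending; invalid entries carry score == i64::MIN.
    fn prune_sketch(candidates: &[(u32, i64)], p: u32, r: usize, saturate_graph: bool) -> Vec<u32> {
        let mut neigh: Vec<u32> = Vec::with_capacity(r);
        for &(p_star, p_star_score) in candidates {
            if neigh.len() >= r {
                break;
            }
            // score sentinel instead of id sentinel: the id stays usable below
            if p_star == p || p_star_score == i64::MIN {
                continue;
            }
            neigh.push(p_star);
            // ...the real code here invalidates remaining "not much better"
            // candidates by setting candidates[ci].1 = i64::MIN...
        }
        if saturate_graph {
            // top the list up toward full degree r with leftover candidate ids
            for &(id, _score) in candidates {
                if neigh.len() == r {
                    break;
                }
                if !neigh.contains(&id) {
                    neigh.push(id);
                }
            }
        }
        neigh
    }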
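
Note on the streaming index build in dump_processor.rs: the queries file is no longer read whole via fs::read; it is consumed in fixed-size chunks, the faiss index is trained once on the first chunk (the is_trained() check) and every chunk is then added, with the final count taken from ntotal(). Assuming the file stores fp16 vectors at 2 bytes per dimension (an assumption; the diff only shows decode_fp16_buffer), the chunk arithmetic works out as follows:

    // Chunk sizing for the streaming loop, assuming 2-byte (fp16) scalars.
    // D_EMB's value here is hypothetical, for illustration only.
    const D_EMB: usize = 1152;
    const CHUNK_BYTES: usize = D_EMB * (1 << 18);       // buffer allocated in the loop
    const BYTES_PER_VECTOR: usize = 2 * D_EMB;          // fp16
    const VECTORS_PER_CHUNK: usize = CHUNK_BYTES / BYTES_PER_VECTOR; // = 1 << 17 = 131072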
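
Note on the I/O changes: the seek + read_exact pairs in dump_processor.rs and query_disk_index.rs became read_exact_at from std::os::unix::fs::FileExt (i.e. pread(2)), which takes an explicit offset and never moves the file cursor, so a shared &File can be read without mutable access; that is what makes the "parallelize?" TODO plausible. An illustrative helper (read_record is hypothetical, not code from this repo):

    use std::fs::File;
    use std::io;
    use std::os::unix::prelude::FileExt;

    // Positioned read of one fixed-size record. Takes &File, not &mut File:
    // read_exact_at does not touch the cursor, so threads can share the handle.
    fn read_record(file: &File, record_size: usize, id: u64) -> io::Result<Vec<u8>> {
        let mut buf = vec![0u8; record_size];
        file.read_exact_at(&mut buf, id * record_size as u64)?;
        Ok(buf)
    }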
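
Note on the recall accounting added to query_disk_index.rs: top_ranks holds, for each of the K results returned for a query, the 0-based rank of that result under brute-force search (usize::MAX when not found), so counting entries with rank < K and dividing by K * queries.len() is standard recall@K. An equivalent hypothetical helper:

    // recall@k over all queries: fraction of returned results whose
    // ground-truth rank falls within the top k (ranks are 0-based).
    fn recall_at_k(true_ranks_of_results: &[usize], k: usize, n_queries: usize) -> f64 {
        let hits = true_ranks_of_results.iter().filter(|&&r| r < k).count();
        hits as f64 / (k * n_queries) as f64
    }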