mirror of
				https://github.com/osmarks/meme-search-engine.git
				synced 2025-10-31 15:23:04 +00:00 
			
		
		
		
	fix entire index algorithm (very silly bug)
This commit is contained in:
		
							
								
								
									
										7
									
								
								diskann/Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										7
									
								
								diskann/Cargo.lock
									
									
									
										generated
									
									
									
								
							| @@ -146,6 +146,7 @@ dependencies = [ | |||||||
|  "foldhash", |  "foldhash", | ||||||
|  "half", |  "half", | ||||||
|  "matrixmultiply", |  "matrixmultiply", | ||||||
|  |  "min-max-heap", | ||||||
|  "rayon", |  "rayon", | ||||||
|  "rmp-serde", |  "rmp-serde", | ||||||
|  "serde", |  "serde", | ||||||
| @@ -228,6 +229,12 @@ dependencies = [ | |||||||
|  "rawpointer", |  "rawpointer", | ||||||
| ] | ] | ||||||
|  |  | ||||||
|  | [[package]] | ||||||
|  | name = "min-max-heap" | ||||||
|  | version = "1.3.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "2687e6cf9c00f48e9284cf9fd15f2ef341d03cc7743abf9df4c5f07fdee50b18" | ||||||
|  |  | ||||||
| [[package]] | [[package]] | ||||||
| name = "mio" | name = "mio" | ||||||
| version = "0.8.11" | version = "0.8.11" | ||||||
|   | |||||||
| @@ -19,20 +19,6 @@ pub struct IndexGraph { | |||||||
| } | } | ||||||
|  |  | ||||||
| impl IndexGraph { | impl IndexGraph { | ||||||
|     pub fn random_r_regular(rng: &mut Rng, n: usize, r: usize, capacity: usize) -> Self { |  | ||||||
|         let mut graph = Vec::with_capacity(n); |  | ||||||
|         for _ in 0..n { |  | ||||||
|             let mut adjacency = Vec::with_capacity(capacity); |  | ||||||
|             for _ in 0..r { |  | ||||||
|                 adjacency.push(rng.u32(0..(n as u32))); |  | ||||||
|             } |  | ||||||
|             graph.push(RwLock::new(adjacency)); |  | ||||||
|         } |  | ||||||
|         IndexGraph { |  | ||||||
|             graph |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     pub fn empty(n: usize, capacity: usize) -> IndexGraph { |     pub fn empty(n: usize, capacity: usize) -> IndexGraph { | ||||||
|         let mut graph = Vec::with_capacity(n); |         let mut graph = Vec::with_capacity(n); | ||||||
|         for _ in 0..n { |         for _ in 0..n { | ||||||
| @@ -57,7 +43,8 @@ pub struct IndexBuildConfig { | |||||||
|     pub r: usize, |     pub r: usize, | ||||||
|     pub l: usize, |     pub l: usize, | ||||||
|     pub maxc: usize, |     pub maxc: usize, | ||||||
|     pub alpha: i64 |     pub alpha: i64, | ||||||
|  |     pub saturate_graph: bool | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -79,6 +66,7 @@ pub fn medioid(vecs: &VectorList) -> u32 { | |||||||
|  |  | ||||||
| // neighbours list sorted by score descending | // neighbours list sorted by score descending | ||||||
| // TODO: this may actually be an awful datastructure | // TODO: this may actually be an awful datastructure | ||||||
|  | // we could also have a heap of unvisited things, but the algorithm's stopping condition cares about visited things, and this is probably still the easiest way | ||||||
| #[derive(Clone, Debug)] | #[derive(Clone, Debug)] | ||||||
| pub struct NeighbourBuffer { | pub struct NeighbourBuffer { | ||||||
|     pub ids: Vec<u32>, |     pub ids: Vec<u32>, | ||||||
| @@ -244,8 +232,9 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect | |||||||
|     let mut candidate_index = 0; |     let mut candidate_index = 0; | ||||||
|     while neigh.len() < config.r && candidate_index < candidates.len() { |     while neigh.len() < config.r && candidate_index < candidates.len() { | ||||||
|         let p_star = candidates[candidate_index].0; |         let p_star = candidates[candidate_index].0; | ||||||
|  |         let p_star_score = candidates[candidate_index].1; | ||||||
|         candidate_index += 1; |         candidate_index += 1; | ||||||
|         if p_star == p || p_star == u32::MAX { |         if p_star == p || p_star_score == i64::MIN { | ||||||
|             continue; |             continue; | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @@ -256,7 +245,7 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect | |||||||
|         // mark remaining candidates as not-to-be-used if "not much better than" current candidate |         // mark remaining candidates as not-to-be-used if "not much better than" current candidate | ||||||
|         for i in (candidate_index+1)..candidates.len() { |         for i in (candidate_index+1)..candidates.len() { | ||||||
|             let p_prime = candidates[i].0; |             let p_prime = candidates[i].0; | ||||||
|             if p_prime != u32::MAX { |             if candidates[i].1 != i64::MIN { | ||||||
|                 scratch.robust_prune_scratch_buffer.push((i, p_prime)); |                 scratch.robust_prune_scratch_buffer.push((i, p_prime)); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @@ -268,7 +257,18 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect | |||||||
|             let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16; |             let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16; | ||||||
|  |  | ||||||
|             if alpha_times_p_star_prime_score >= p_prime_p_score { |             if alpha_times_p_star_prime_score >= p_prime_p_score { | ||||||
|                 candidates[ci].0 = u32::MAX; |                 candidates[ci].1 = i64::MIN; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if config.saturate_graph { | ||||||
|  |         for &(id, _score) in candidates.iter() { | ||||||
|  |             if neigh.len() == config.r { | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |             if !neigh.contains(&id) { | ||||||
|  |                 neigh.push(id); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -313,38 +313,7 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V | |||||||
|     }); |     }); | ||||||
| } | } | ||||||
|  |  | ||||||
| // RoarGraph's AcquireNeighbours algorithm is actually almost identical to Vamana/DiskANN's RobustPrune, but with fixed α = 1.0. | pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig, max_iters: usize) { | ||||||
| // We replace Vamana's random initialization of the graph with Neighbourhood-Aware Projection from RoarGraph - there's no way to use a large enough |  | ||||||
| // query set that I would be confident in using *only* RoarGraph's algorithm |  | ||||||
| pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec<Vec<u32>>, query_knns_bwd: &Vec<Vec<u32>>, config: IndexBuildConfig, vecs: &VectorList) { |  | ||||||
|     let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect(); |  | ||||||
|     rng.shuffle(&mut sigmas); |  | ||||||
|  |  | ||||||
|     // Iterate through graph vertices in a random order |  | ||||||
|     let rng = Mutex::new(rng.fork()); |  | ||||||
|     sigmas.into_par_iter().for_each_init(|| (rng.lock().unwrap().fork(), Scratch::new(config)), |(rng, scratch), sigma_i| { |  | ||||||
|         scratch.visited.clear(); |  | ||||||
|         scratch.visited_list.clear(); |  | ||||||
|         scratch.neighbour_pre_buffer.clear(); |  | ||||||
|         for &query_neighbour in query_knns[sigma_i as usize].iter() { |  | ||||||
|             for &projected_neighbour in query_knns_bwd[query_neighbour as usize].iter() { |  | ||||||
|                 if scratch.visited.insert(projected_neighbour) { |  | ||||||
|                     scratch.neighbour_pre_buffer.push(projected_neighbour); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         rng.shuffle(&mut scratch.neighbour_pre_buffer); |  | ||||||
|         scratch.neighbour_pre_buffer.truncate(config.maxc * 2); |  | ||||||
|         for (i, &projected_neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() { |  | ||||||
|             let score = fast_dot(&vecs[sigma_i as usize], &vecs[projected_neighbour as usize], &vecs[scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()] as usize]); |  | ||||||
|             scratch.visited_list.push((projected_neighbour, score)); |  | ||||||
|         } |  | ||||||
|         let mut neighbours = graph.out_neighbours_mut(sigma_i); |  | ||||||
|         robust_prune(scratch, sigma_i, &mut *neighbours, vecs, config); |  | ||||||
|     }) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig) { |  | ||||||
|     let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect(); |     let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect(); | ||||||
|     rng.shuffle(&mut sigmas); |     rng.shuffle(&mut sigmas); | ||||||
|  |  | ||||||
| @@ -353,7 +322,7 @@ pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec< | |||||||
|     sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| { |     sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| { | ||||||
|         let mut neighbours = graph.out_neighbours_mut(sigma_i); |         let mut neighbours = graph.out_neighbours_mut(sigma_i); | ||||||
|         let mut i = 0; |         let mut i = 0; | ||||||
|         while neighbours.len() < config.r && i < 100 { |         while neighbours.len() < config.r && i < max_iters { | ||||||
|             let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap(); |             let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap(); | ||||||
|             let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap(); |             let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap(); | ||||||
|             if !neighbours.contains(&projected_neighbour) { |             if !neighbours.contains(&projected_neighbour) { | ||||||
|   | |||||||
| @@ -52,7 +52,14 @@ fn main() -> Result<()> { | |||||||
|     let vecs = { |     let vecs = { | ||||||
|         let _timer = Timer::new("loaded vectors"); |         let _timer = Timer::new("loaded vectors"); | ||||||
|  |  | ||||||
|         &load_file("query.bin", Some(D_EMB * n))? |         &load_file("real.bin", None)? | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     println!("{} vectors", vecs.len()); | ||||||
|  |  | ||||||
|  |     let queries = { | ||||||
|  |         let _timer = Timer::new("loaded queries"); | ||||||
|  |         &load_file("../query5.bin", None)? | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     let (graph, medioid) = { |     let (graph, medioid) = { | ||||||
| @@ -63,9 +70,12 @@ fn main() -> Result<()> { | |||||||
|             l: 192, |             l: 192, | ||||||
|             maxc: 750, |             maxc: 750, | ||||||
|             alpha: 65200, |             alpha: 65200, | ||||||
|  |             saturate_graph: false | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap); |         let mut graph = IndexGraph::empty(vecs.len(), config.r); | ||||||
|  |  | ||||||
|  |         random_fill_graph(&mut rng, &mut graph, config.r); | ||||||
|  |  | ||||||
|         let medioid = medioid(&vecs); |         let medioid = medioid(&vecs); | ||||||
|  |  | ||||||
| @@ -92,9 +102,10 @@ fn main() -> Result<()> { | |||||||
|  |  | ||||||
|     let mut config = IndexBuildConfig { |     let mut config = IndexBuildConfig { | ||||||
|         r: 64, |         r: 64, | ||||||
|         l: 50, |         l: 200, | ||||||
|         alpha: 65536, |         alpha: 65536, | ||||||
|         maxc: 0, |         maxc: 0, | ||||||
|  |         saturate_graph: false | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     let mut scratch = Scratch::new(config); |     let mut scratch = Scratch::new(config); | ||||||
| @@ -112,8 +123,8 @@ fn main() -> Result<()> { | |||||||
|  |  | ||||||
|     let end = time.elapsed(); |     let end = time.elapsed(); | ||||||
|  |  | ||||||
|     println!("recall@1: {} ({}/{})", recall as f32 / n as f32, recall, n); |     println!("recall@1: {} ({}/{})", recall as f32 / vecs.len() as f32, recall, vecs.len()); | ||||||
|     println!("cmps: {} ({}/{})", cmps_ctr as f32 / n as f32, cmps_ctr, n); |     println!("cmps: {} ({}/{})", cmps_ctr as f32 / vecs.len() as f32, cmps_ctr, vecs.len()); | ||||||
|     println!("median comparisons: {}", cmps[cmps.len() / 2]); |     println!("median comparisons: {}", cmps[cmps.len() / 2]); | ||||||
|     //println!("brute force recall@1: {} ({}/{})", brute_force_recall as f32 / brute_force_queries as f32, brute_force_recall, brute_force_queries); |     //println!("brute force recall@1: {} ({}/{})", brute_force_recall as f32 / brute_force_queries as f32, brute_force_recall, brute_force_queries); | ||||||
|     println!("{} QPS", n as f32 / end.as_secs_f32()); |     println!("{} QPS", n as f32 / end.as_secs_f32()); | ||||||
|   | |||||||
| @@ -14,6 +14,7 @@ use itertools::Itertools; | |||||||
| use simsimd::SpatialSimilarity; | use simsimd::SpatialSimilarity; | ||||||
| use std::hash::Hasher; | use std::hash::Hasher; | ||||||
| use foldhash::{HashSet, HashSetExt}; | use foldhash::{HashSet, HashSetExt}; | ||||||
|  | use std::os::unix::prelude::FileExt; | ||||||
|  |  | ||||||
| use diskann::vector::{scale_dot_result_f64, ProductQuantizer}; | use diskann::vector::{scale_dot_result_f64, ProductQuantizer}; | ||||||
|  |  | ||||||
| @@ -161,15 +162,29 @@ fn main() -> Result<()> { | |||||||
|     let (mut queries_index, max_query_id) = if let Some(queries_file) = args.queries { |     let (mut queries_index, max_query_id) = if let Some(queries_file) = args.queries { | ||||||
|         println!("constructing index"); |         println!("constructing index"); | ||||||
|         // not memory-efficient but this is small |         // not memory-efficient but this is small | ||||||
|         let data = fs::read(queries_file).context("read queries file")?; |         let mut file = fs::File::open(queries_file).context("read queries file")?; | ||||||
|  |         let mut size = file.metadata()?.len(); | ||||||
|         //let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?; |         //let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?; | ||||||
|         let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?; |         let mut index = faiss::index_factory(D_EMB, "HNSW64,SQ8", faiss::MetricType::InnerProduct)?; | ||||||
|         //let mut index = faiss::index_factory(D_EMB, "IVF4096,SQfp16", faiss::MetricType::InnerProduct)?; |         //let mut index = faiss::index_factory(D_EMB, "IVF4096,SQfp16", faiss::MetricType::InnerProduct)?; | ||||||
|         let unpacked = common::decode_fp16_buffer(&data); |         let mut buf = vec![0; (D_EMB as usize) * (1<<18)]; | ||||||
|         index.train(&unpacked)?; |         loop { | ||||||
|  |             if size == 0 { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             if size < (buf.len() as u64) { | ||||||
|  |                 buf.resize(size as usize, 0); | ||||||
|  |             } | ||||||
|  |             file.read_exact(&mut buf)?; | ||||||
|  |             size -= buf.len() as u64; | ||||||
|  |             let unpacked = common::decode_fp16_buffer(&buf); | ||||||
|  |             if !index.is_trained() { index.train(&unpacked)?; print!("train"); } | ||||||
|             index.add(&unpacked)?; |             index.add(&unpacked)?; | ||||||
|  |             print!("."); | ||||||
|  |         } | ||||||
|         println!("done"); |         println!("done"); | ||||||
|         (Some(index), unpacked.len() / D_EMB as usize) |         let ntotal = index.ntotal(); | ||||||
|  |         (Some(index), ntotal as usize) | ||||||
|     } else { |     } else { | ||||||
|         (None, 0) |         (None, 0) | ||||||
|     }; |     }; | ||||||
| @@ -267,14 +282,15 @@ fn main() -> Result<()> { | |||||||
|                 let shard = shard as usize; |                 let shard = shard as usize; | ||||||
|                 // this random access is almost certainly rather slow |                 // this random access is almost certainly rather slow | ||||||
|                 // parallelize? |                 // parallelize? | ||||||
|                 files[shard].1.seek(SeekFrom::Start(offset))?; |  | ||||||
|                 let mut buf = vec![0; len as usize]; |                 let mut buf = vec![0; len as usize]; | ||||||
|                 files[shard].1.read_exact(&mut buf)?; |                 files[shard].1.read_exact_at(&mut buf, offset)?; | ||||||
|                 let s: &mut [u32] = bytemuck::cast_slice_mut(&mut *buf); |                 let s: &[u32] = bytemuck::cast_slice(&mut *buf); | ||||||
|                 for within_shard_id in s.iter_mut() { |                 for within_shard_id in s.iter() { | ||||||
|                     *within_shard_id = shard_id_mappings[shard].1[*within_shard_id as usize]; |                     let global_id = shard_id_mappings[shard].1[*within_shard_id as usize]; | ||||||
|  |                     if !out_vertices.contains(&global_id) { | ||||||
|  |                         out_vertices.push(global_id); | ||||||
|  |                     } | ||||||
|                 } |                 } | ||||||
|                 out_vertices.extend(s.iter().unique()); |  | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             Ok((out_vertices, shards)) |             Ok((out_vertices, shards)) | ||||||
| @@ -422,7 +438,7 @@ fn main() -> Result<()> { | |||||||
|             let codes = quantizer.quantize_batch(&batch_embeddings); |             let codes = quantizer.quantize_batch(&batch_embeddings); | ||||||
|  |  | ||||||
|             for (i, (x, _embedding)) in batch.into_iter().enumerate() { |             for (i, (x, _embedding)) in batch.into_iter().enumerate() { | ||||||
|                 let (vertices, shards) = read_out_vertices(count)?; // TODO: could parallelize this given the batching |                 let (vertices, shards) = read_out_vertices(count + i as u32)?; // TODO: could parallelize this given the batching | ||||||
|                 let mut entry = PackedIndexEntry { |                 let mut entry = PackedIndexEntry { | ||||||
|                     id: count + i as u32, |                     id: count + i as u32, | ||||||
|                     vertices, |                     vertices, | ||||||
|   | |||||||
| @@ -3,7 +3,7 @@ use itertools::Itertools; | |||||||
| use std::io::{BufReader, BufWriter, Write}; | use std::io::{BufReader, BufWriter, Write}; | ||||||
| use rmp_serde::decode::Error as DecodeError; | use rmp_serde::decode::Error as DecodeError; | ||||||
| use std::fs; | use std::fs; | ||||||
| use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer, report_degrees}; | use diskann::{augment_bipartite, build_graph, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer, report_degrees, medioid}; | ||||||
| use half::f16; | use half::f16; | ||||||
|  |  | ||||||
| mod common; | mod common; | ||||||
| @@ -41,10 +41,11 @@ fn main() -> Result<()> { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     let mut config = IndexBuildConfig { |     let mut config = IndexBuildConfig { | ||||||
|         r: 40, |         r: 64, | ||||||
|         l: 200, |         l: 192, | ||||||
|         maxc: 750, |         maxc: 750, | ||||||
|         alpha: 65300 |         alpha: 65200, | ||||||
|  |         saturate_graph: false | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     let vecs = VectorList { |     let vecs = VectorList { | ||||||
| @@ -67,9 +68,7 @@ fn main() -> Result<()> { | |||||||
|  |  | ||||||
|     report_degrees(&graph); |     report_degrees(&graph); | ||||||
|  |  | ||||||
|     let medioid = vecs.iter().position_max_by_key(|&v| { |     let medioid = medioid(&vecs); | ||||||
|         dot(v, ¢roid_fp16) |  | ||||||
|     }).unwrap() as u32; |  | ||||||
|  |  | ||||||
|     { |     { | ||||||
|         let _timer = Timer::new("first pass"); |         let _timer = Timer::new("first pass"); | ||||||
| @@ -101,7 +100,8 @@ fn main() -> Result<()> { | |||||||
|  |  | ||||||
|     { |     { | ||||||
|         let _timer = Timer::new("augment bipartite"); |         let _timer = Timer::new("augment bipartite"); | ||||||
|         augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config); |         //augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config, 50); | ||||||
|  |         //random_fill_graph(&mut rng, &mut graph, config.r); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     let len = original_ids.len(); |     let len = original_ids.len(); | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ use anyhow::{bail, Context, Result}; | |||||||
| use diskann::vector::scale_dot_result_f64; | use diskann::vector::scale_dot_result_f64; | ||||||
| use serde::{Serialize, Deserialize}; | use serde::{Serialize, Deserialize}; | ||||||
| use std::io::{BufReader, Read, Seek, SeekFrom, Write}; | use std::io::{BufReader, Read, Seek, SeekFrom, Write}; | ||||||
|  | use std::os::unix::prelude::FileExt; | ||||||
| use std::path::PathBuf; | use std::path::PathBuf; | ||||||
| use std::fs; | use std::fs; | ||||||
| use base64::Engine; | use base64::Engine; | ||||||
| @@ -37,9 +38,8 @@ struct CLIArguments { | |||||||
|  |  | ||||||
| fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> { | fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> { | ||||||
|     let offset = id as usize * header.record_pad_size; |     let offset = id as usize * header.record_pad_size; | ||||||
|     data_file.seek(SeekFrom::Start(offset as u64))?; |  | ||||||
|     let mut buf = vec![0; header.record_pad_size as usize]; |     let mut buf = vec![0; header.record_pad_size as usize]; | ||||||
|     data_file.read_exact(&mut buf)?; |     data_file.read_exact_at(&mut buf, offset as u64)?; | ||||||
|     let len = u16::from_le_bytes(buf[0..2].try_into().unwrap()) as usize; |     let len = u16::from_le_bytes(buf[0..2].try_into().unwrap()) as usize; | ||||||
|     Ok(bitcode::decode(&buf[2..len+2])?) |     Ok(bitcode::decode(&buf[2..len+2])?) | ||||||
| } | } | ||||||
| @@ -117,9 +117,11 @@ fn summary_stats(ranks: &mut [usize]) { | |||||||
|     ranks.sort_unstable(); |     ranks.sort_unstable(); | ||||||
|     let median = ranks[ranks.len() / 2] + 1; |     let median = ranks[ranks.len() / 2] + 1; | ||||||
|     let harmonic_mean = ranks.iter().map(|x| 1.0 / ((x+1) as f64)).sum::<f64>() / ranks.len() as f64; |     let harmonic_mean = ranks.iter().map(|x| 1.0 / ((x+1) as f64)).sum::<f64>() / ranks.len() as f64; | ||||||
|     println!("median {} mean {} max {} min {} harmonic mean {}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean); |     println!("median {} mean {:.2} max {} min {} harmonic mean {:.2}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | const K: usize = 20; | ||||||
|  |  | ||||||
| fn main() -> Result<()> { | fn main() -> Result<()> { | ||||||
|     let args: CLIArguments = argh::from_env(); |     let args: CLIArguments = argh::from_env(); | ||||||
|  |  | ||||||
| @@ -150,8 +152,11 @@ fn main() -> Result<()> { | |||||||
|  |  | ||||||
|     println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len()); |     println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len()); | ||||||
|  |  | ||||||
|     let mut top_20_ranks_best_shard = vec![]; |     let mut top_k_ranks_best_shard = vec![]; | ||||||
|     let mut top_rank_best_shard = vec![]; |     let mut top_rank_best_shard = vec![]; | ||||||
|  |     let mut pq_cmps = vec![]; | ||||||
|  |     let mut cmps = vec![]; | ||||||
|  |     let mut recall_total = 0; | ||||||
|  |  | ||||||
|     for query_vector in queries.iter() { |     for query_vector in queries.iter() { | ||||||
|         let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>(); |         let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>(); | ||||||
| @@ -183,26 +188,30 @@ fn main() -> Result<()> { | |||||||
|             println!("brute force: {} {} {} {:?}", id, distance, url, shards); |             println!("brute force: {} {} {} {:?}", id, distance, url, shards); | ||||||
|         }*/ |         }*/ | ||||||
|  |  | ||||||
|         let mut top_ranks = vec![usize::MAX; 20]; |         let mut top_ranks = vec![usize::MAX; K]; | ||||||
|  |  | ||||||
|         for shard in 0..header.shards.len() { |         for shard in 0..header.shards.len() { | ||||||
|             let selected_start = header.shards[shard].1; |             let selected_start = header.shards[shard].1; | ||||||
|  |  | ||||||
|             let mut scratch = Scratch { |             let mut scratch = Scratch { | ||||||
|                 visited: HashSet::new(), |                 visited: HashSet::new(), | ||||||
|                 neighbour_buffer: NeighbourBuffer::new(5000), |                 neighbour_buffer: NeighbourBuffer::new(1000), | ||||||
|                 neighbour_pre_buffer: Vec::new(), |                 neighbour_pre_buffer: Vec::new(), | ||||||
|                 visited_list: Vec::new() |                 visited_list: Vec::new() | ||||||
|             }; |             }; | ||||||
|  |  | ||||||
|             //let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng); |             //let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng); | ||||||
|             let cmps = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef { |             let cmps_result = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef { | ||||||
|                 data_file: &mut data_file, |                 data_file: &mut data_file, | ||||||
|                 header: &header, |                 header: &header, | ||||||
|                 pq_codes: &pq_codes, |                 pq_codes: &pq_codes, | ||||||
|                 pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code, |                 pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code, | ||||||
|             }, args.disable_pq)?; |             }, args.disable_pq)?; | ||||||
|  |  | ||||||
|  |             // slightly dubious because this is across shards | ||||||
|  |             pq_cmps.push(cmps_result.1); | ||||||
|  |             cmps.push(cmps_result.0); | ||||||
|  |  | ||||||
|             if args.verbose { |             if args.verbose { | ||||||
|                 println!("index scan {}: {:?} cmps", shard, cmps); |                 println!("index scan {}: {:?} cmps", shard, cmps); | ||||||
|             } |             } | ||||||
| @@ -221,14 +230,26 @@ fn main() -> Result<()> { | |||||||
|             if args.verbose { println!("") } |             if args.verbose { println!("") } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         // results list is always correctly sorted | ||||||
|  |         for &rank in top_ranks.iter() { | ||||||
|  |             if rank < K { | ||||||
|  |                 recall_total += 1; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|         top_rank_best_shard.push(top_ranks[0]); |         top_rank_best_shard.push(top_ranks[0]); | ||||||
|         top_20_ranks_best_shard.extend(top_ranks); |         top_k_ranks_best_shard.extend(top_ranks); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     println!("ranks of top 20:"); |     println!("ranks of top 20:"); | ||||||
|     summary_stats(&mut top_20_ranks_best_shard); |     summary_stats(&mut top_k_ranks_best_shard); | ||||||
|     println!("ranks of top 1:"); |     println!("ranks of top 1:"); | ||||||
|     summary_stats(&mut top_rank_best_shard); |     summary_stats(&mut top_rank_best_shard); | ||||||
|  |     println!("pq comparisons:"); | ||||||
|  |     summary_stats(&mut pq_cmps); | ||||||
|  |     println!("comparisons:"); | ||||||
|  |     summary_stats(&mut cmps); | ||||||
|  |     println!("recall@{}: {}", K, recall_total as f64 / (K * queries.len()) as f64); | ||||||
|  |  | ||||||
|     Ok(()) |     Ok(()) | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 osmarks
					osmarks