mirror of https://github.com/osmarks/meme-search-engine.git (synced 2025-10-31 15:23:04 +00:00)
	correct DiskANN algorithm (silly bug with greedy search)
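The bug: `NeighbourBuffer` tracks the earliest unvisited slot so `greedy_search` knows where to resume, but inserting a candidate overwrote that cursor unconditionally instead of taking the minimum, so the search could skip candidates that landed earlier in the buffer. A minimal sketch of the corrected bookkeeping (simplified types, not the repo's full struct; it mirrors the first diskann hunk below):

```rust
struct NeighbourBuffer {
    next_unvisited: Option<u32>,
}

impl NeighbourBuffer {
    // Called after a candidate is inserted at position `loc`.
    fn record_insertion(&mut self, loc: usize) {
        match self.next_unvisited {
            // An unvisited candidate may now sit *earlier* than the recorded
            // cursor; overwriting unconditionally (the old behaviour) made
            // greedy search skip it and terminate too early.
            Some(ref mut next) => *next = (loc as u32).min(*next),
            None => self.next_unvisited = Some(loc as u32),
        }
    }
}
```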
.gitignore (vendored): 1 line changed
@@ -15,3 +15,4 @@ diskann/target
 *.bin
 *.msgpack
 */flamegraph.svg
+*/*.bin
@@ -146,8 +146,15 @@ impl NeighbourBuffer {
         self.scores.truncate(self.size);
         self.visited.truncate(self.size);
 
+        match self.next_unvisited {
+            Some(ref mut next_unvisited) => {
+                *next_unvisited = (loc as u32).min(*next_unvisited);
+            },
+            None => {
                 self.next_unvisited = Some(loc as u32);
             }
+        }
+    }
 
     pub fn clear(&mut self) {
         self.ids.clear();
@@ -194,7 +201,6 @@ pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs:
     let mut counters = GreedySearchCounters { distances: 0 };
 
     while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
-        //println!("pt {} {:?}", pt, graph.out_neighbours(pt));
         scratch.neighbour_pre_buffer.clear();
         for &neighbour in graph.out_neighbours(pt).iter() {
             if scratch.visited.insert(neighbour) {
@@ -296,14 +302,12 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V
         let neighbours = graph.out_neighbours(sigma_i).to_owned();
         for neighbour in neighbours {
             let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour);
-            // To cut down pruning time slightly, allow accumulating more neighbours than usual limit
-            if neighbour_neighbours.len() == config.r_cap {
-                let mut n = neighbour_neighbours.to_vec();
+            if neighbour_neighbours.len() == config.r {
                 scratch.visited_list.clear();
                 merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config);
                 merge_existing_neighbours(&mut scratch.visited_list, neighbour, &vec![sigma_i], vecs, config);
-                robust_prune(scratch, neighbour, &mut n, vecs, config);
-            } else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r_cap {
+                robust_prune(scratch, neighbour, &mut neighbour_neighbours, vecs, config);
+            } else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r {
                 neighbour_neighbours.push(sigma_i);
             }
         }
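The second fix in this hunk: the old code copied the adjacency list into `n`, pruned the copy, and then dropped it, so a full list was never actually updated and the backward edge to `sigma_i` was silently lost; it also let lists grow to `r_cap` past the nominal degree bound. A hedged sketch of the intended backward-edge rule (plain `Vec` standing in for the repo's locked adjacency lists, a closure standing in for the scratch-based `robust_prune`):

```rust
// Sketch only: names and signature are simplified, not the repo's API.
fn add_backward_edge<F>(adj: &mut Vec<u32>, sigma_i: u32, r: usize, robust_prune: F)
where
    F: Fn(&mut Vec<u32>, u32),
{
    if adj.len() == r {
        // Prune the real adjacency list (with sigma_i among the candidates),
        // not a throwaway copy -- this is what the old code got wrong.
        robust_prune(adj, sigma_i);
    } else if !adj.contains(&sigma_i) {
        adj.push(sigma_i);
    }
}
```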
@@ -387,3 +391,18 @@ impl Drop for Timer {
         println!("{}: {:.2}s", self.0, self.1.elapsed().as_secs_f32());
     }
 }
+
+pub fn report_degrees(graph: &IndexGraph) {
+    let mut total_degree = 0;
+    let mut degrees = Vec::with_capacity(graph.graph.len());
+    for out_neighbours in graph.graph.iter() {
+        let deg = out_neighbours.read().unwrap().len();
+        total_degree += deg;
+        degrees.push(deg);
+    }
+    degrees.sort_unstable();
+    println!("average degree {}", (total_degree as f64) / (graph.graph.len() as f64));
+    println!("median degree {}", degrees[degrees.len() / 2]);
+    println!("min degree {}", degrees[0]);
+    println!("max degree {}", degrees[degrees.len() - 1]);
+}
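A self-contained analogue of the new `report_degrees` helper, assuming (as the diff suggests) that `IndexGraph.graph` is a vector of `RwLock`-wrapped adjacency lists:

```rust
use std::sync::RwLock;

fn main() {
    // Toy 4-node graph standing in for IndexGraph.graph.
    let graph: Vec<RwLock<Vec<u32>>> = vec![
        RwLock::new(vec![1, 2]),
        RwLock::new(vec![0]),
        RwLock::new(vec![0, 1, 3]),
        RwLock::new(vec![2]),
    ];
    let mut degrees: Vec<usize> = graph.iter().map(|l| l.read().unwrap().len()).collect();
    let total: usize = degrees.iter().sum();
    degrees.sort_unstable();
    println!("average degree {}", total as f64 / graph.len() as f64); // 1.75
    println!("median degree {}", degrees[degrees.len() / 2]);         // 2
    println!("min degree {}", degrees[0]);                            // 1
    println!("max degree {}", degrees[degrees.len() - 1]);            // 3
}
```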
@@ -7,7 +7,7 @@ use std::{io::Read, time::Instant};
 use anyhow::Result;
 use half::f16;
 
-use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, SCALE, dot, VectorList, self}, Timer};
+use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, SCALE, dot, VectorList, self}, Timer, report_degrees, random_fill_graph};
 use simsimd::SpatialSimilarity;
 
 const D_EMB: usize = 1152;
@@ -26,12 +26,13 @@ const PQ_TEST_SIZE: usize = 1000;
 fn main() -> Result<()> {
     tracing_subscriber::fmt::init();
 
+    /*/
     {
         let file = std::fs::File::open("opq.msgpack")?;
         let codec: vector::ProductQuantizer = rmp_serde::from_read(file)?;
         let input = load_file("embeddings.bin", Some(D_EMB * PQ_TEST_SIZE))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
         let codes = codec.quantize_batch(&input);
-        println!("{:?}", codes);
+        //println!("{:?}", codes);
         let raw_query = load_file("query.bin", Some(D_EMB))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
         let query = codec.preprocess_query(&raw_query);
         let mut real_scores = vec![];
@@ -41,17 +42,17 @@ fn main() -> Result<()> {
         let pq_scores = codec.asymmetric_dot_product(&query, &codes);
         for (x, y) in real_scores.iter().zip(pq_scores.iter()) {
             let y = (*y as f32) / SCALE;
-            println!("{} {} {} {}", x, y, x - y, (x - y) / x);
-        }
+            //println!("{} {} {} {}", x, y, x - y, (x - y) / x);
         }
+    }*/
 
     let mut rng = fastrand::Rng::with_seed(1);
 
-    let n = 100000;
+    let n = 100_000;
     let vecs = {
         let _timer = Timer::new("loaded vectors");
 
-        &load_file("embeddings.bin", Some(D_EMB * n))?
+        &load_file("query.bin", Some(D_EMB * n))?
     };
 
     let (graph, medioid) = {
@@ -59,10 +60,10 @@ fn main() -> Result<()> {
 
         let mut config = IndexBuildConfig {
             r: 64,
-            r_cap: 80,
-            l: 128,
+            r_cap: 64,
+            l: 192,
             maxc: 750,
-            alpha: 65536,
+            alpha: 65200,
         };
 
         let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap);
@@ -70,8 +71,11 @@ fn main() -> Result<()> {
         let medioid = medioid(&vecs);
 
         build_graph(&mut rng, &mut graph, medioid, &vecs, config);
-        config.alpha = 58000;
-        build_graph(&mut rng, &mut graph, medioid, &vecs, config);
+        report_degrees(&graph);
+        //random_fill_graph(&mut rng, &mut graph, config.r);
+        //config.alpha = 65536;
+        //build_graph(&mut rng, &mut graph, medioid, &vecs, config);
+        report_degrees(&graph);
 
         (graph, medioid)
     };
@@ -82,8 +86,6 @@ fn main() -> Result<()> {
         edge_ctr += adjlist.read().unwrap().len();
     }
 
-    println!("average degree: {}", edge_ctr as f32 / graph.graph.len() as f32);
-
     let time = Instant::now();
     let mut recall = 0;
     let mut cmps_ctr = 0;
@@ -55,7 +55,7 @@ struct CLIArguments {
     #[argh(switch, short='t', description="print titles")]
     titles: bool,
     #[argh(option, description="truncate centroids list")]
-    clip_centroids: Option<usize>,
+    clip_shards: Option<usize>,
     #[argh(switch, description="print original linked URL")]
     original_url: bool,
     #[argh(option, short='q', description="product quantization codec path")]
@@ -180,7 +180,7 @@ fn main() -> Result<()> {
         let centroids_data = fs::read(centroids).context("read centroids file")?;
         let mut centroids_data = common::decode_fp16_buffer(&centroids_data);
 
-        if let Some(clip) = args.clip_centroids {
+        if let Some(clip) = args.clip_shards {
             centroids_data.truncate(clip * D_EMB as usize);
         }
@@ -209,6 +209,14 @@ fn main() -> Result<()> {
             let path = file.path();
             let filename = path.file_name().unwrap().to_str().unwrap();
             let (fst, snd) = filename.split_once(".").unwrap();
+
+            let id: u32 = str::parse(fst)?;
+            if let Some(clip) = args.clip_shards {
+                if id >= (clip as u32) {
+                    continue;
+                }
+            }
+
             if snd == "shard-header.msgpack" {
                 let header: ShardHeader = rmp_serde::from_read(BufReader::new(fs::File::open(path)?))?;
                 if original_ids_to_shards.len() < (header.max as usize + 1) {
@@ -238,7 +246,6 @@ fn main() -> Result<()> {
                 shard_id_mappings.push((header.id, header.mapping));
             } else if snd == "shard.bin" {
                 let file = fs::File::open(&path).context("open shard file")?;
-                let id: u32 = str::parse(fst)?;
                 files.push((id, file));
             }
         }
@@ -246,11 +253,16 @@ fn main() -> Result<()> {
         files.sort_by_key(|(id, _)| *id);
         shard_id_mappings.sort_by_key(|(id, _)| *id);
 
+
         let read_out_vertices = move |id: u32| -> Result<(Vec<u32>, Vec<u32>)> {
             let mut out_vertices: Vec<u32> = vec![];
             let mut shards: Vec<u32> = vec![];
             // look up each location in shard files
             for &(shard, offset, len) in original_ids_to_shards[id as usize].iter() {
+                if (shard, offset, len) == EMPTY_LOOKUP {
+                    continue;
+                }
+
                 shards.push(shard);
                 let shard = shard as usize;
                 // this random access is almost certainly rather slow
@@ -3,7 +3,7 @@ use itertools::Itertools;
 use std::io::{BufReader, BufWriter, Write};
 use rmp_serde::decode::Error as DecodeError;
 use std::fs;
-use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer};
+use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer, report_degrees};
 use half::f16;
 
 mod common;
@@ -12,19 +12,6 @@ use common::{ShardInputHeader, ShardedRecord, ShardHeader};
 
 const D_EMB: usize = 1152;
 
-fn report_degrees(graph: &IndexGraph) {
-    let mut total_degree = 0;
-    let mut degrees = Vec::with_capacity(graph.graph.len());
-    for out_neighbours in graph.graph.iter() {
-        let deg = out_neighbours.read().unwrap().len();
-        total_degree += deg;
-        degrees.push(deg);
-    }
-    degrees.sort_unstable();
-    println!("average degree {}", (total_degree as f32) / (graph.graph.len() as f32));
-    println!("median degree {}", degrees[degrees.len() / 2]);
-}
-
 fn main() -> Result<()> {
     let mut rng = fastrand::Rng::new();
 
@@ -55,10 +42,10 @@ fn main() -> Result<()> {
 
     let mut config = IndexBuildConfig {
         r: 64,
-        r_cap: 80,
-        l: 500,
-        maxc: 950,
-        alpha: 65536
+        r_cap: 64,
+        l: 200,
+        maxc: 750,
+        alpha: 65300
     };
 
     let vecs = VectorList {
@@ -93,12 +80,12 @@ fn main() -> Result<()> {
     report_degrees(&graph);
 
     {
-        let _timer = Timer::new("second pass");
-        config.alpha = 60000;
+        //let _timer = Timer::new("second pass");
+        //config.alpha = 62000;
         //build_graph(&mut rng, &mut graph, medioid, &vecs, config);
     }
 
-    report_degrees(&graph);
+    //report_degrees(&graph);
 
     std::mem::drop(vecs);
@@ -115,7 +102,7 @@ fn main() -> Result<()> {
 
     {
         let _timer = Timer::new("augment bipartite");
-        //augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config);
+        augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config);
     }
 
     let len = original_ids.len();
@@ -7,7 +7,6 @@ use std::fs;
 use base64::Engine;
 use argh::FromArgs;
 use chrono::{TimeZone, Utc, DateTime};
-use std::collections::VecDeque;
 use itertools::Itertools;
 use foldhash::{HashSet, HashSetExt};
 use half::f16;
@@ -23,9 +22,17 @@ use common::{PackedIndexEntry, IndexHeader};
 #[argh(description="Query disk index")]
 struct CLIArguments {
     #[argh(positional)]
-    query_vector: String,
-    #[argh(positional)]
-    index_path: String
+    index_path: String,
+    #[argh(option, short='q', description="query vector in base64")]
+    query_vector_base64: Option<String>,
+    #[argh(option, short='f', description="file of FP16 query vectors")]
+    query_vector_file: Option<String>,
+    #[argh(switch, short='v', description="verbose")]
+    verbose: bool,
+    #[argh(option, short='n', description="stop at n queries")]
+    n: Option<usize>,
+    #[argh(switch, description="always use full-precision vectors (slow)")]
+    disable_pq: bool
 }
 
 fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> {
@@ -56,7 +63,7 @@ struct IndexRef<'a> {
     pq_code_size: usize
 }
 
-fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &DistanceLUT, index: IndexRef) -> Result<(usize, usize)> {
+fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &DistanceLUT, index: IndexRef, disable_pq: bool) -> Result<(usize, usize)> {
     scratch.visited.clear();
     scratch.neighbour_buffer.clear();
     scratch.visited_list.clear();
@@ -88,24 +95,48 @@ fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preproc
         }
         let approx_scores = index.header.quantizer.asymmetric_dot_product(&query_preprocessed, &pq_codes);
         for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
-            //let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
-            //let node = read_node(neighbour, index.data_file, index.header)?;
-            //let vector = bytemuck::cast_slice(&node.vector);
-            //let distance = fast_dot_noprefetch(query, &vector);
-            pq_cmps += 1;
-            scratch.neighbour_buffer.insert(neighbour, approx_scores[i]);
-            //scratch.neighbour_buffer.insert(neighbour, distance);
+            if disable_pq {
+                //let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
+                let node = read_node(neighbour, index.data_file, index.header)?;
+                let vector = bytemuck::cast_slice(&node.vector);
+                let distance = fast_dot_noprefetch(query, &vector);
+                scratch.neighbour_buffer.insert(neighbour, distance);
+            } else {
+                scratch.neighbour_buffer.insert(neighbour, approx_scores[i]);
+                pq_cmps += 1;
+            }
         }
     }
 
     Ok((cmps, pq_cmps))
 }
 
+fn summary_stats(ranks: &mut [usize]) {
+    let sum = ranks.iter().sum::<usize>();
+    let mean = sum as f64 / ranks.len() as f64 + 1.0;
+    ranks.sort_unstable();
+    let median = ranks[ranks.len() / 2] + 1;
+    let harmonic_mean = ranks.iter().map(|x| 1.0 / ((x+1) as f64)).sum::<f64>() / ranks.len() as f64;
+    println!("median {} mean {} max {} min {} harmonic mean {}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean);
+}
+
 fn main() -> Result<()> {
     let args: CLIArguments = argh::from_env();
 
-    let query_vector: Vec<f16> = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(args.query_vector.as_bytes()).context("invalid base64")?);
-    let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
+    let mut queries = vec![];
+
+    if let Some(query_vector_base64) = args.query_vector_base64 {
+        let query_vector: Vec<f16> = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(query_vector_base64.as_bytes()).context("invalid base64")?);
+        queries.push(query_vector);
+    }
+    if let Some(query_vector_file) = args.query_vector_file {
+        let query_vectors = fs::read(query_vector_file)?;
+        queries.extend(common::chunk_fp16_buffer(&query_vectors).chunks(1152).map(|x| x.to_vec()).collect::<Vec<_>>());
+    }
+
+    if let Some(n) = args.n {
+        queries.truncate(n);
+    }
+
     let index_path = PathBuf::from(&args.index_path);
     let header: IndexHeader = rmp_serde::from_read(BufReader::new(fs::File::open(index_path.join("index.msgpack"))?))?;
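A note on the new `summary_stats` helper above: it takes 0-based ranks (positions in the brute-force ordering) and reports them 1-based, so perfect retrieval prints all 1s. A small usage sketch, calling the function exactly as defined in the hunk above:

```rust
// Assumes summary_stats from the hunk above is in scope.
let mut ranks = vec![0usize, 0, 3, 9]; // 0-based ranks from four probes
summary_stats(&mut ranks);
// prints: median 4 mean 4 max 10 min 1 harmonic mean ~1.7021
```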
@@ -117,16 +148,42 @@ fn main() -> Result<()> {
         MmapOptions::new().populate().map(&pq_codes_file)?
     };
 
-    let query_preprocessed = header.quantizer.preprocess_query(&query_vector_fp32);
-
     println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
 
+    let mut top_20_ranks_best_shard = vec![];
+    let mut top_rank_best_shard = vec![];
+
+    for query_vector in queries.iter() {
+        let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
+        let query_preprocessed = header.quantizer.preprocess_query(&query_vector_fp32);
+
         // TODO slightly dubious
         let selected_shard = header.shards.iter().position_max_by_key(|x| {
             scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query_vector_fp32).unwrap())
         }).unwrap();
 
-    println!("best shard is {}", selected_shard);
+        if args.verbose {
+            println!("selected shard is {}", selected_shard);
+        }
+
+        let mut matches = vec![];
+        // brute force scan
+        for i in 0..header.count {
+            let node = read_node(i, &mut data_file, &header)?;
+            //println!("{} {}", i, node.url);
+            let vector = bytemuck::cast_slice(&node.vector);
+            matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards));
+        }
+
+        matches.sort_unstable_by_key(|x| -x.1);
+        let mut matches = matches.into_iter().enumerate().map(|(i, (id, distance, url, shards))| (id, i)).collect::<Vec<_>>();
+        matches.sort_unstable();
+
+        /*for (id, distance, url, shards) in matches.iter().take(20) {
+            println!("brute force: {} {} {} {:?}", id, distance, url, shards);
+        }*/
+
+        let mut top_ranks = vec![usize::MAX; 20];
+
         for shard in 0..header.shards.len() {
             let selected_start = header.shards[shard].1;
@@ -144,30 +201,37 @@ fn main() -> Result<()> {
                 header: &header,
                 pq_codes: &pq_codes,
                 pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
-        })?;
+            }, args.disable_pq)?;
 
+            if args.verbose {
                 println!("index scan {}: {:?} cmps", shard, cmps);
+            }
 
             scratch.visited_list.sort_by_key(|x| -x.1);
-        for (id, distance, url, shards) in scratch.visited_list.iter().take(20) {
+            for (i, (id, distance, url, shards)) in scratch.visited_list.iter().take(20).enumerate() {
+                if args.verbose {
                     println!("index scan: {} {} {} {:?}", id, distance, url, shards);
+                };
+                let found_id = match matches.binary_search(&(*id, 0)) {
+                    Ok(pos) => pos,
+                    Err(pos) => pos
+                };
+                if args.verbose {
+                    println!("rank {}", matches[found_id].1);
+                };
+                top_ranks[i] = std::cmp::min(top_ranks[i], matches[found_id].1);
             }
-        println!("");
+            if args.verbose { println!("") }
         }
 
-    let mut matches = vec![];
-    // brute force scan
-    for i in 0..header.count {
-        let node = read_node(i, &mut data_file, &header)?;
-        //println!("{} {}", i, node.url);
-        let vector = bytemuck::cast_slice(&node.vector);
-        matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards));
-    }
+        top_rank_best_shard.push(top_ranks[0]);
+        top_20_ranks_best_shard.extend(top_ranks);
+    }
 
-    matches.sort_by_key(|x| -x.1);
-    for (id, distance, url, shards) in matches.iter().take(20) {
-        println!("brute force: {} {} {} {:?}", id, distance, url, shards);
-    }
+    println!("ranks of top 20:");
+    summary_stats(&mut top_20_ranks_best_shard);
+    println!("ranks of top 1:");
+    summary_stats(&mut top_rank_best_shard);
 
     Ok(())
 }