mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-02-23 06:20:04 +00:00

correct DiskANN algorithm (silly bug with greedy search)

osmarks 2025-01-11 07:35:04 +00:00
parent e9ee563381
commit 8ce51bcb56
6 changed files with 188 additions and 103 deletions
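The greedy-search bug this commit fixes is in NeighbourBuffer's next_unvisited bookkeeping (see the NeighbourBuffer hunk below): the cursor was previously overwritten unconditionally on every insertion, so it could jump forward past candidates that had not been visited yet; the fix takes the minimum of the old cursor and the new insertion position. A minimal standalone sketch of the corrected update, using a hypothetical Cursor type and note_insertion method in place of the crate's NeighbourBuffer:

// Sketch only, not the crate's code: models the cursor update shown in the
// NeighbourBuffer hunk below. `loc` is the position a candidate was inserted at.
struct Cursor {
    next_unvisited: Option<u32>,
}

impl Cursor {
    fn note_insertion(&mut self, loc: usize) {
        match self.next_unvisited {
            // take the minimum rather than overwriting unconditionally
            Some(ref mut next_unvisited) => *next_unvisited = (loc as u32).min(*next_unvisited),
            None => self.next_unvisited = Some(loc as u32),
        }
    }
}

fn main() {
    let mut c = Cursor { next_unvisited: None };
    c.note_insertion(5);
    c.note_insertion(2); // a candidate landed earlier in the buffer
    assert_eq!(c.next_unvisited, Some(2)); // cursor moves back instead of staying at 5
}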

.gitignore
View File

@ -15,3 +15,4 @@ diskann/target
*.bin
*.msgpack
*/flamegraph.svg
+*/*.bin

View File

@ -146,8 +146,15 @@ impl NeighbourBuffer {
        self.scores.truncate(self.size);
        self.visited.truncate(self.size);
+        match self.next_unvisited {
+            Some(ref mut next_unvisited) => {
+                *next_unvisited = (loc as u32).min(*next_unvisited);
+            },
+            None => {
                self.next_unvisited = Some(loc as u32);
            }
+        }
+    }

    pub fn clear(&mut self) {
        self.ids.clear();
@ -194,7 +201,6 @@ pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs:
    let mut counters = GreedySearchCounters { distances: 0 };
    while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
-        //println!("pt {} {:?}", pt, graph.out_neighbours(pt));
        scratch.neighbour_pre_buffer.clear();
        for &neighbour in graph.out_neighbours(pt).iter() {
            if scratch.visited.insert(neighbour) {
@ -296,14 +302,12 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V
        let neighbours = graph.out_neighbours(sigma_i).to_owned();
        for neighbour in neighbours {
            let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour);
-            // To cut down pruning time slightly, allow accumulating more neighbours than usual limit
-            if neighbour_neighbours.len() == config.r_cap {
-                let mut n = neighbour_neighbours.to_vec();
+            if neighbour_neighbours.len() == config.r {
                scratch.visited_list.clear();
                merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config);
                merge_existing_neighbours(&mut scratch.visited_list, neighbour, &vec![sigma_i], vecs, config);
-                robust_prune(scratch, neighbour, &mut n, vecs, config);
+                robust_prune(scratch, neighbour, &mut neighbour_neighbours, vecs, config);
-            } else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r_cap {
+            } else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r {
                neighbour_neighbours.push(sigma_i);
            }
        }
@ -387,3 +391,18 @@ impl Drop for Timer {
        println!("{}: {:.2}s", self.0, self.1.elapsed().as_secs_f32());
    }
}
+pub fn report_degrees(graph: &IndexGraph) {
+    let mut total_degree = 0;
+    let mut degrees = Vec::with_capacity(graph.graph.len());
+    for out_neighbours in graph.graph.iter() {
+        let deg = out_neighbours.read().unwrap().len();
+        total_degree += deg;
+        degrees.push(deg);
+    }
+    degrees.sort_unstable();
+    println!("average degree {}", (total_degree as f64) / (graph.graph.len() as f64));
+    println!("median degree {}", degrees[degrees.len() / 2]);
+    println!("min degree {}", degrees[0]);
+    println!("max degree {}", degrees[degrees.len() - 1]);
+}

View File

@ -7,7 +7,7 @@ use std::{io::Read, time::Instant};
use anyhow::Result;
use half::f16;
-use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, SCALE, dot, VectorList, self}, Timer};
+use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, SCALE, dot, VectorList, self}, Timer, report_degrees, random_fill_graph};
use simsimd::SpatialSimilarity;

const D_EMB: usize = 1152;
@ -26,12 +26,13 @@ const PQ_TEST_SIZE: usize = 1000;
fn main() -> Result<()> {
    tracing_subscriber::fmt::init();
+    /*/
    {
        let file = std::fs::File::open("opq.msgpack")?;
        let codec: vector::ProductQuantizer = rmp_serde::from_read(file)?;
        let input = load_file("embeddings.bin", Some(D_EMB * PQ_TEST_SIZE))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
        let codes = codec.quantize_batch(&input);
-        println!("{:?}", codes);
+        //println!("{:?}", codes);
        let raw_query = load_file("query.bin", Some(D_EMB))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
        let query = codec.preprocess_query(&raw_query);
        let mut real_scores = vec![];
@ -41,17 +42,17 @@ fn main() -> Result<()> {
        let pq_scores = codec.asymmetric_dot_product(&query, &codes);
        for (x, y) in real_scores.iter().zip(pq_scores.iter()) {
            let y = (*y as f32) / SCALE;
-            println!("{} {} {} {}", x, y, x - y, (x - y) / x);
+            //println!("{} {} {} {}", x, y, x - y, (x - y) / x);
        }
-    }
+    }*/
    let mut rng = fastrand::Rng::with_seed(1);
-    let n = 100000;
+    let n = 100_000;
    let vecs = {
        let _timer = Timer::new("loaded vectors");
-        &load_file("embeddings.bin", Some(D_EMB * n))?
+        &load_file("query.bin", Some(D_EMB * n))?
    };
    let (graph, medioid) = {
@ -59,10 +60,10 @@ fn main() -> Result<()> {
        let mut config = IndexBuildConfig {
            r: 64,
-            r_cap: 80,
+            r_cap: 64,
-            l: 128,
+            l: 192,
            maxc: 750,
-            alpha: 65536,
+            alpha: 65200,
        };
        let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap);
@ -70,8 +71,11 @@ fn main() -> Result<()> {
        let medioid = medioid(&vecs);
        build_graph(&mut rng, &mut graph, medioid, &vecs, config);
-        config.alpha = 58000;
-        build_graph(&mut rng, &mut graph, medioid, &vecs, config);
+        report_degrees(&graph);
+        //random_fill_graph(&mut rng, &mut graph, config.r);
+        //config.alpha = 65536;
+        //build_graph(&mut rng, &mut graph, medioid, &vecs, config);
+        report_degrees(&graph);
        (graph, medioid)
    };
@ -82,8 +86,6 @@ fn main() -> Result<()> {
        edge_ctr += adjlist.read().unwrap().len();
    }
-    println!("average degree: {}", edge_ctr as f32 / graph.graph.len() as f32);
    let time = Instant::now();
    let mut recall = 0;
    let mut cmps_ctr = 0;

View File

@ -55,7 +55,7 @@ struct CLIArguments {
    #[argh(switch, short='t', description="print titles")]
    titles: bool,
    #[argh(option, description="truncate centroids list")]
-    clip_centroids: Option<usize>,
+    clip_shards: Option<usize>,
    #[argh(switch, description="print original linked URL")]
    original_url: bool,
    #[argh(option, short='q', description="product quantization codec path")]
@ -180,7 +180,7 @@ fn main() -> Result<()> {
        let centroids_data = fs::read(centroids).context("read centroids file")?;
        let mut centroids_data = common::decode_fp16_buffer(&centroids_data);
-        if let Some(clip) = args.clip_centroids {
+        if let Some(clip) = args.clip_shards {
            centroids_data.truncate(clip * D_EMB as usize);
        }
@ -209,6 +209,14 @@ fn main() -> Result<()> {
        let path = file.path();
        let filename = path.file_name().unwrap().to_str().unwrap();
        let (fst, snd) = filename.split_once(".").unwrap();
+        let id: u32 = str::parse(fst)?;
+        if let Some(clip) = args.clip_shards {
+            if id >= (clip as u32) {
+                continue;
+            }
+        }
        if snd == "shard-header.msgpack" {
            let header: ShardHeader = rmp_serde::from_read(BufReader::new(fs::File::open(path)?))?;
            if original_ids_to_shards.len() < (header.max as usize + 1) {
@ -238,7 +246,6 @@ fn main() -> Result<()> {
            shard_id_mappings.push((header.id, header.mapping));
        } else if snd == "shard.bin" {
            let file = fs::File::open(&path).context("open shard file")?;
-            let id: u32 = str::parse(fst)?;
            files.push((id, file));
        }
    }
@ -246,11 +253,16 @@ fn main() -> Result<()> {
    files.sort_by_key(|(id, _)| *id);
    shard_id_mappings.sort_by_key(|(id, _)| *id);
    let read_out_vertices = move |id: u32| -> Result<(Vec<u32>, Vec<u32>)> {
        let mut out_vertices: Vec<u32> = vec![];
        let mut shards: Vec<u32> = vec![];
        // look up each location in shard files
        for &(shard, offset, len) in original_ids_to_shards[id as usize].iter() {
+            if (shard, offset, len) == EMPTY_LOOKUP {
+                continue;
+            }
            shards.push(shard);
            let shard = shard as usize;
            // this random access is almost certainly rather slow

View File

@ -3,7 +3,7 @@ use itertools::Itertools;
use std::io::{BufReader, BufWriter, Write};
use rmp_serde::decode::Error as DecodeError;
use std::fs;
-use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer};
+use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer, report_degrees};
use half::f16;

mod common;
@ -12,19 +12,6 @@ use common::{ShardInputHeader, ShardedRecord, ShardHeader};
const D_EMB: usize = 1152;
-fn report_degrees(graph: &IndexGraph) {
-    let mut total_degree = 0;
-    let mut degrees = Vec::with_capacity(graph.graph.len());
-    for out_neighbours in graph.graph.iter() {
-        let deg = out_neighbours.read().unwrap().len();
-        total_degree += deg;
-        degrees.push(deg);
-    }
-    degrees.sort_unstable();
-    println!("average degree {}", (total_degree as f32) / (graph.graph.len() as f32));
-    println!("median degree {}", degrees[degrees.len() / 2]);
-}

fn main() -> Result<()> {
    let mut rng = fastrand::Rng::new();
@ -55,10 +42,10 @@ fn main() -> Result<()> {
    let mut config = IndexBuildConfig {
        r: 64,
-        r_cap: 80,
+        r_cap: 64,
-        l: 500,
+        l: 200,
-        maxc: 950,
+        maxc: 750,
-        alpha: 65536
+        alpha: 65300
    };

    let vecs = VectorList {
@ -93,12 +80,12 @@ fn main() -> Result<()> {
    report_degrees(&graph);

    {
-        let _timer = Timer::new("second pass");
+        //let _timer = Timer::new("second pass");
-        config.alpha = 60000;
+        //config.alpha = 62000;
        //build_graph(&mut rng, &mut graph, medioid, &vecs, config);
    }
-    report_degrees(&graph);
+    //report_degrees(&graph);

    std::mem::drop(vecs);
@ -115,7 +102,7 @@ fn main() -> Result<()> {
    {
        let _timer = Timer::new("augment bipartite");
-        //augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config);
+        augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config);
    }

    let len = original_ids.len();

View File

@ -7,7 +7,6 @@ use std::fs;
use base64::Engine;
use argh::FromArgs;
use chrono::{TimeZone, Utc, DateTime};
-use std::collections::VecDeque;
use itertools::Itertools;
use foldhash::{HashSet, HashSetExt};
use half::f16;
@ -23,9 +22,17 @@ use common::{PackedIndexEntry, IndexHeader};
#[argh(description="Query disk index")]
struct CLIArguments {
    #[argh(positional)]
-    query_vector: String,
-    #[argh(positional)]
-    index_path: String
+    index_path: String,
+    #[argh(option, short='q', description="query vector in base64")]
+    query_vector_base64: Option<String>,
+    #[argh(option, short='f', description="file of FP16 query vectors")]
+    query_vector_file: Option<String>,
+    #[argh(switch, short='v', description="verbose")]
+    verbose: bool,
+    #[argh(option, short='n', description="stop at n queries")]
+    n: Option<usize>,
+    #[argh(switch, description="always use full-precision vectors (slow)")]
+    disable_pq: bool
}
fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> { fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> {
@ -56,7 +63,7 @@ struct IndexRef<'a> {
    pq_code_size: usize
}

-fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &DistanceLUT, index: IndexRef) -> Result<(usize, usize)> {
+fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &DistanceLUT, index: IndexRef, disable_pq: bool) -> Result<(usize, usize)> {
    scratch.visited.clear();
    scratch.neighbour_buffer.clear();
    scratch.visited_list.clear();
@ -88,24 +95,48 @@ fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preproc
        }
        let approx_scores = index.header.quantizer.asymmetric_dot_product(&query_preprocessed, &pq_codes);
        for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
+            if disable_pq {
                //let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
-                //let node = read_node(neighbour, index.data_file, index.header)?;
-                //let vector = bytemuck::cast_slice(&node.vector);
-                //let distance = fast_dot_noprefetch(query, &vector);
-                pq_cmps += 1;
+                let node = read_node(neighbour, index.data_file, index.header)?;
+                let vector = bytemuck::cast_slice(&node.vector);
+                let distance = fast_dot_noprefetch(query, &vector);
+                scratch.neighbour_buffer.insert(neighbour, distance);
+            } else {
                scratch.neighbour_buffer.insert(neighbour, approx_scores[i]);
-                //scratch.neighbour_buffer.insert(neighbour, distance);
+                pq_cmps += 1;
+            }
        }
    }
    Ok((cmps, pq_cmps))
}
+fn summary_stats(ranks: &mut [usize]) {
+    let sum = ranks.iter().sum::<usize>();
+    let mean = sum as f64 / ranks.len() as f64 + 1.0;
+    ranks.sort_unstable();
+    let median = ranks[ranks.len() / 2] + 1;
+    let harmonic_mean = ranks.iter().map(|x| 1.0 / ((x+1) as f64)).sum::<f64>() / ranks.len() as f64;
+    println!("median {} mean {} max {} min {} harmonic mean {}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean);
+}
fn main() -> Result<()> {
    let args: CLIArguments = argh::from_env();
-    let query_vector: Vec<f16> = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(args.query_vector.as_bytes()).context("invalid base64")?);
-    let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
+    let mut queries = vec![];
+    if let Some(query_vector_base64) = args.query_vector_base64 {
+        let query_vector: Vec<f16> = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(query_vector_base64.as_bytes()).context("invalid base64")?);
+        queries.push(query_vector);
+    }
+    if let Some(query_vector_file) = args.query_vector_file {
+        let query_vectors = fs::read(query_vector_file)?;
+        queries.extend(common::chunk_fp16_buffer(&query_vectors).chunks(1152).map(|x| x.to_vec()).collect::<Vec<_>>());
+    }
+    if let Some(n) = args.n {
+        queries.truncate(n);
+    }
    let index_path = PathBuf::from(&args.index_path);
    let header: IndexHeader = rmp_serde::from_read(BufReader::new(fs::File::open(index_path.join("index.msgpack"))?))?;
@ -117,16 +148,42 @@ fn main() -> Result<()> {
        MmapOptions::new().populate().map(&pq_codes_file)?
    };

-    let query_preprocessed = header.quantizer.preprocess_query(&query_vector_fp32);
    println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
+    let mut top_20_ranks_best_shard = vec![];
+    let mut top_rank_best_shard = vec![];
+    for query_vector in queries.iter() {
+        let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
+        let query_preprocessed = header.quantizer.preprocess_query(&query_vector_fp32);
        // TODO slightly dubious
        let selected_shard = header.shards.iter().position_max_by_key(|x| {
            scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query_vector_fp32).unwrap())
        }).unwrap();
-        println!("best shard is {}", selected_shard);
+        if args.verbose {
+            println!("selected shard is {}", selected_shard);
+        }
+        let mut matches = vec![];
+        // brute force scan
+        for i in 0..header.count {
+            let node = read_node(i, &mut data_file, &header)?;
+            //println!("{} {}", i, node.url);
+            let vector = bytemuck::cast_slice(&node.vector);
+            matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards));
+        }
+        matches.sort_unstable_by_key(|x| -x.1);
+        let mut matches = matches.into_iter().enumerate().map(|(i, (id, distance, url, shards))| (id, i)).collect::<Vec<_>>();
+        matches.sort_unstable();
+        /*for (id, distance, url, shards) in matches.iter().take(20) {
+            println!("brute force: {} {} {} {:?}", id, distance, url, shards);
+        }*/
+        let mut top_ranks = vec![usize::MAX; 20];
        for shard in 0..header.shards.len() {
            let selected_start = header.shards[shard].1;
@ -144,30 +201,37 @@ fn main() -> Result<()> {
                header: &header,
                pq_codes: &pq_codes,
                pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
-            })?;
+            }, args.disable_pq)?;
+            if args.verbose {
                println!("index scan {}: {:?} cmps", shard, cmps);
+            }
            scratch.visited_list.sort_by_key(|x| -x.1);
-            for (id, distance, url, shards) in scratch.visited_list.iter().take(20) {
+            for (i, (id, distance, url, shards)) in scratch.visited_list.iter().take(20).enumerate() {
+                if args.verbose {
                    println!("index scan: {} {} {} {:?}", id, distance, url, shards);
+                };
+                let found_id = match matches.binary_search(&(*id, 0)) {
+                    Ok(pos) => pos,
+                    Err(pos) => pos
+                };
+                if args.verbose {
+                    println!("rank {}", matches[found_id].1);
+                };
+                top_ranks[i] = std::cmp::min(top_ranks[i], matches[found_id].1);
            }
-            println!("");
+            if args.verbose { println!("") }
        }
-        let mut matches = vec![];
-        // brute force scan
-        for i in 0..header.count {
-            let node = read_node(i, &mut data_file, &header)?;
-            //println!("{} {}", i, node.url);
-            let vector = bytemuck::cast_slice(&node.vector);
-            matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards));
+        top_rank_best_shard.push(top_ranks[0]);
+        top_20_ranks_best_shard.extend(top_ranks);
    }
-    matches.sort_by_key(|x| -x.1);
-    for (id, distance, url, shards) in matches.iter().take(20) {
-        println!("brute force: {} {} {} {:?}", id, distance, url, shards);
-    }
+    println!("ranks of top 20:");
+    summary_stats(&mut top_20_ranks_best_shard);
+    println!("ranks of top 1:");
+    summary_stats(&mut top_rank_best_shard);
    Ok(())
}
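For reference, the rank statistics printed by the new summary_stats above can be checked standalone; the function body is copied verbatim from this commit, and the example ranks in main are made up for illustration:

fn summary_stats(ranks: &mut [usize]) {
    let sum = ranks.iter().sum::<usize>();
    let mean = sum as f64 / ranks.len() as f64 + 1.0;
    ranks.sort_unstable();
    let median = ranks[ranks.len() / 2] + 1;
    let harmonic_mean = ranks.iter().map(|x| 1.0 / ((x+1) as f64)).sum::<f64>() / ranks.len() as f64;
    println!("median {} mean {} max {} min {} harmonic mean {}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean);
}

fn main() {
    // hypothetical 0-based ranks of the true best result across three queries;
    // they are reported 1-based, so this prints median 2, mean ~2.67, max 5, min 1, harmonic mean ~1.76
    let mut ranks = vec![4, 0, 1];
    summary_stats(&mut ranks);
}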