1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-06-06 16:34:06 +00:00

fix entire index algorithm (very silly bug)

This commit is contained in:
osmarks 2025-01-12 19:48:53 +00:00
parent 0a196694b1
commit 4dd97631df
6 changed files with 110 additions and 86 deletions

7
diskann/Cargo.lock generated
View File

@ -146,6 +146,7 @@ dependencies = [
"foldhash", "foldhash",
"half", "half",
"matrixmultiply", "matrixmultiply",
"min-max-heap",
"rayon", "rayon",
"rmp-serde", "rmp-serde",
"serde", "serde",
@ -228,6 +229,12 @@ dependencies = [
"rawpointer", "rawpointer",
] ]
[[package]]
name = "min-max-heap"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2687e6cf9c00f48e9284cf9fd15f2ef341d03cc7743abf9df4c5f07fdee50b18"
[[package]] [[package]]
name = "mio" name = "mio"
version = "0.8.11" version = "0.8.11"

View File

@ -19,20 +19,6 @@ pub struct IndexGraph {
} }
impl IndexGraph { impl IndexGraph {
pub fn random_r_regular(rng: &mut Rng, n: usize, r: usize, capacity: usize) -> Self {
let mut graph = Vec::with_capacity(n);
for _ in 0..n {
let mut adjacency = Vec::with_capacity(capacity);
for _ in 0..r {
adjacency.push(rng.u32(0..(n as u32)));
}
graph.push(RwLock::new(adjacency));
}
IndexGraph {
graph
}
}
pub fn empty(n: usize, capacity: usize) -> IndexGraph { pub fn empty(n: usize, capacity: usize) -> IndexGraph {
let mut graph = Vec::with_capacity(n); let mut graph = Vec::with_capacity(n);
for _ in 0..n { for _ in 0..n {
@ -57,7 +43,8 @@ pub struct IndexBuildConfig {
pub r: usize, pub r: usize,
pub l: usize, pub l: usize,
pub maxc: usize, pub maxc: usize,
pub alpha: i64 pub alpha: i64,
pub saturate_graph: bool
} }
@ -79,6 +66,7 @@ pub fn medioid(vecs: &VectorList) -> u32 {
// neighbours list sorted by score descending // neighbours list sorted by score descending
// TODO: this may actually be an awful datastructure // TODO: this may actually be an awful datastructure
// we could also have a heap of unvisited things, but the algorithm's stopping condition cares about visited things, and this is probably still the easiest way
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct NeighbourBuffer { pub struct NeighbourBuffer {
pub ids: Vec<u32>, pub ids: Vec<u32>,
@ -244,8 +232,9 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
let mut candidate_index = 0; let mut candidate_index = 0;
while neigh.len() < config.r && candidate_index < candidates.len() { while neigh.len() < config.r && candidate_index < candidates.len() {
let p_star = candidates[candidate_index].0; let p_star = candidates[candidate_index].0;
let p_star_score = candidates[candidate_index].1;
candidate_index += 1; candidate_index += 1;
if p_star == p || p_star == u32::MAX { if p_star == p || p_star_score == i64::MIN {
continue; continue;
} }
@ -256,7 +245,7 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
// mark remaining candidates as not-to-be-used if "not much better than" current candidate // mark remaining candidates as not-to-be-used if "not much better than" current candidate
for i in (candidate_index+1)..candidates.len() { for i in (candidate_index+1)..candidates.len() {
let p_prime = candidates[i].0; let p_prime = candidates[i].0;
if p_prime != u32::MAX { if candidates[i].1 != i64::MIN {
scratch.robust_prune_scratch_buffer.push((i, p_prime)); scratch.robust_prune_scratch_buffer.push((i, p_prime));
} }
} }
@ -268,7 +257,18 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16; let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16;
if alpha_times_p_star_prime_score >= p_prime_p_score { if alpha_times_p_star_prime_score >= p_prime_p_score {
candidates[ci].0 = u32::MAX; candidates[ci].1 = i64::MIN;
}
}
}
if config.saturate_graph {
for &(id, _score) in candidates.iter() {
if neigh.len() == config.r {
return;
}
if !neigh.contains(&id) {
neigh.push(id);
} }
} }
} }
@ -313,38 +313,7 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V
}); });
} }
// RoarGraph's AcquireNeighbours algorithm is actually almost identical to Vamana/DiskANN's RobustPrune, but with fixed α = 1.0. pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig, max_iters: usize) {
// We replace Vamana's random initialization of the graph with Neighbourhood-Aware Projection from RoarGraph - there's no way to use a large enough
// query set that I would be confident in using *only* RoarGraph's algorithm
pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec<Vec<u32>>, query_knns_bwd: &Vec<Vec<u32>>, config: IndexBuildConfig, vecs: &VectorList) {
let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
rng.shuffle(&mut sigmas);
// Iterate through graph vertices in a random order
let rng = Mutex::new(rng.fork());
sigmas.into_par_iter().for_each_init(|| (rng.lock().unwrap().fork(), Scratch::new(config)), |(rng, scratch), sigma_i| {
scratch.visited.clear();
scratch.visited_list.clear();
scratch.neighbour_pre_buffer.clear();
for &query_neighbour in query_knns[sigma_i as usize].iter() {
for &projected_neighbour in query_knns_bwd[query_neighbour as usize].iter() {
if scratch.visited.insert(projected_neighbour) {
scratch.neighbour_pre_buffer.push(projected_neighbour);
}
}
}
rng.shuffle(&mut scratch.neighbour_pre_buffer);
scratch.neighbour_pre_buffer.truncate(config.maxc * 2);
for (i, &projected_neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
let score = fast_dot(&vecs[sigma_i as usize], &vecs[projected_neighbour as usize], &vecs[scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()] as usize]);
scratch.visited_list.push((projected_neighbour, score));
}
let mut neighbours = graph.out_neighbours_mut(sigma_i);
robust_prune(scratch, sigma_i, &mut *neighbours, vecs, config);
})
}
pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig) {
let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect(); let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
rng.shuffle(&mut sigmas); rng.shuffle(&mut sigmas);
@ -353,7 +322,7 @@ pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<
sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| { sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| {
let mut neighbours = graph.out_neighbours_mut(sigma_i); let mut neighbours = graph.out_neighbours_mut(sigma_i);
let mut i = 0; let mut i = 0;
while neighbours.len() < config.r && i < 100 { while neighbours.len() < config.r && i < max_iters {
let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap(); let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap();
let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap(); let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap();
if !neighbours.contains(&projected_neighbour) { if !neighbours.contains(&projected_neighbour) {

View File

@ -52,7 +52,14 @@ fn main() -> Result<()> {
let vecs = { let vecs = {
let _timer = Timer::new("loaded vectors"); let _timer = Timer::new("loaded vectors");
&load_file("query.bin", Some(D_EMB * n))? &load_file("real.bin", None)?
};
println!("{} vectors", vecs.len());
let queries = {
let _timer = Timer::new("loaded queries");
&load_file("../query5.bin", None)?
}; };
let (graph, medioid) = { let (graph, medioid) = {
@ -63,9 +70,12 @@ fn main() -> Result<()> {
l: 192, l: 192,
maxc: 750, maxc: 750,
alpha: 65200, alpha: 65200,
saturate_graph: false
}; };
let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap); let mut graph = IndexGraph::empty(vecs.len(), config.r);
random_fill_graph(&mut rng, &mut graph, config.r);
let medioid = medioid(&vecs); let medioid = medioid(&vecs);
@ -92,9 +102,10 @@ fn main() -> Result<()> {
let mut config = IndexBuildConfig { let mut config = IndexBuildConfig {
r: 64, r: 64,
l: 50, l: 200,
alpha: 65536, alpha: 65536,
maxc: 0, maxc: 0,
saturate_graph: false
}; };
let mut scratch = Scratch::new(config); let mut scratch = Scratch::new(config);
@ -112,8 +123,8 @@ fn main() -> Result<()> {
let end = time.elapsed(); let end = time.elapsed();
println!("recall@1: {} ({}/{})", recall as f32 / n as f32, recall, n); println!("recall@1: {} ({}/{})", recall as f32 / vecs.len() as f32, recall, vecs.len());
println!("cmps: {} ({}/{})", cmps_ctr as f32 / n as f32, cmps_ctr, n); println!("cmps: {} ({}/{})", cmps_ctr as f32 / vecs.len() as f32, cmps_ctr, vecs.len());
println!("median comparisons: {}", cmps[cmps.len() / 2]); println!("median comparisons: {}", cmps[cmps.len() / 2]);
//println!("brute force recall@1: {} ({}/{})", brute_force_recall as f32 / brute_force_queries as f32, brute_force_recall, brute_force_queries); //println!("brute force recall@1: {} ({}/{})", brute_force_recall as f32 / brute_force_queries as f32, brute_force_recall, brute_force_queries);
println!("{} QPS", n as f32 / end.as_secs_f32()); println!("{} QPS", n as f32 / end.as_secs_f32());

View File

@ -14,6 +14,7 @@ use itertools::Itertools;
use simsimd::SpatialSimilarity; use simsimd::SpatialSimilarity;
use std::hash::Hasher; use std::hash::Hasher;
use foldhash::{HashSet, HashSetExt}; use foldhash::{HashSet, HashSetExt};
use std::os::unix::prelude::FileExt;
use diskann::vector::{scale_dot_result_f64, ProductQuantizer}; use diskann::vector::{scale_dot_result_f64, ProductQuantizer};
@ -161,15 +162,29 @@ fn main() -> Result<()> {
let (mut queries_index, max_query_id) = if let Some(queries_file) = args.queries { let (mut queries_index, max_query_id) = if let Some(queries_file) = args.queries {
println!("constructing index"); println!("constructing index");
// not memory-efficient but this is small // not memory-efficient but this is small
let data = fs::read(queries_file).context("read queries file")?; let mut file = fs::File::open(queries_file).context("read queries file")?;
let mut size = file.metadata()?.len();
//let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?; //let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?;
let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?; let mut index = faiss::index_factory(D_EMB, "HNSW64,SQ8", faiss::MetricType::InnerProduct)?;
//let mut index = faiss::index_factory(D_EMB, "IVF4096,SQfp16", faiss::MetricType::InnerProduct)?; //let mut index = faiss::index_factory(D_EMB, "IVF4096,SQfp16", faiss::MetricType::InnerProduct)?;
let unpacked = common::decode_fp16_buffer(&data); let mut buf = vec![0; (D_EMB as usize) * (1<<18)];
index.train(&unpacked)?; loop {
index.add(&unpacked)?; if size == 0 {
break;
}
if size < (buf.len() as u64) {
buf.resize(size as usize, 0);
}
file.read_exact(&mut buf)?;
size -= buf.len() as u64;
let unpacked = common::decode_fp16_buffer(&buf);
if !index.is_trained() { index.train(&unpacked)?; print!("train"); }
index.add(&unpacked)?;
print!(".");
}
println!("done"); println!("done");
(Some(index), unpacked.len() / D_EMB as usize) let ntotal = index.ntotal();
(Some(index), ntotal as usize)
} else { } else {
(None, 0) (None, 0)
}; };
@ -267,14 +282,15 @@ fn main() -> Result<()> {
let shard = shard as usize; let shard = shard as usize;
// this random access is almost certainly rather slow // this random access is almost certainly rather slow
// parallelize? // parallelize?
files[shard].1.seek(SeekFrom::Start(offset))?;
let mut buf = vec![0; len as usize]; let mut buf = vec![0; len as usize];
files[shard].1.read_exact(&mut buf)?; files[shard].1.read_exact_at(&mut buf, offset)?;
let s: &mut [u32] = bytemuck::cast_slice_mut(&mut *buf); let s: &[u32] = bytemuck::cast_slice(&mut *buf);
for within_shard_id in s.iter_mut() { for within_shard_id in s.iter() {
*within_shard_id = shard_id_mappings[shard].1[*within_shard_id as usize]; let global_id = shard_id_mappings[shard].1[*within_shard_id as usize];
if !out_vertices.contains(&global_id) {
out_vertices.push(global_id);
}
} }
out_vertices.extend(s.iter().unique());
} }
Ok((out_vertices, shards)) Ok((out_vertices, shards))
@ -422,7 +438,7 @@ fn main() -> Result<()> {
let codes = quantizer.quantize_batch(&batch_embeddings); let codes = quantizer.quantize_batch(&batch_embeddings);
for (i, (x, _embedding)) in batch.into_iter().enumerate() { for (i, (x, _embedding)) in batch.into_iter().enumerate() {
let (vertices, shards) = read_out_vertices(count)?; // TODO: could parallelize this given the batching let (vertices, shards) = read_out_vertices(count + i as u32)?; // TODO: could parallelize this given the batching
let mut entry = PackedIndexEntry { let mut entry = PackedIndexEntry {
id: count + i as u32, id: count + i as u32,
vertices, vertices,

View File

@ -3,7 +3,7 @@ use itertools::Itertools;
use std::io::{BufReader, BufWriter, Write}; use std::io::{BufReader, BufWriter, Write};
use rmp_serde::decode::Error as DecodeError; use rmp_serde::decode::Error as DecodeError;
use std::fs; use std::fs;
use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer, report_degrees}; use diskann::{augment_bipartite, build_graph, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer, report_degrees, medioid};
use half::f16; use half::f16;
mod common; mod common;
@ -41,10 +41,11 @@ fn main() -> Result<()> {
} }
let mut config = IndexBuildConfig { let mut config = IndexBuildConfig {
r: 40, r: 64,
l: 200, l: 192,
maxc: 750, maxc: 750,
alpha: 65300 alpha: 65200,
saturate_graph: false
}; };
let vecs = VectorList { let vecs = VectorList {
@ -67,9 +68,7 @@ fn main() -> Result<()> {
report_degrees(&graph); report_degrees(&graph);
let medioid = vecs.iter().position_max_by_key(|&v| { let medioid = medioid(&vecs);
dot(v, &centroid_fp16)
}).unwrap() as u32;
{ {
let _timer = Timer::new("first pass"); let _timer = Timer::new("first pass");
@ -101,7 +100,8 @@ fn main() -> Result<()> {
{ {
let _timer = Timer::new("augment bipartite"); let _timer = Timer::new("augment bipartite");
augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config); //augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config, 50);
//random_fill_graph(&mut rng, &mut graph, config.r);
} }
let len = original_ids.len(); let len = original_ids.len();

View File

@ -2,6 +2,7 @@ use anyhow::{bail, Context, Result};
use diskann::vector::scale_dot_result_f64; use diskann::vector::scale_dot_result_f64;
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use std::io::{BufReader, Read, Seek, SeekFrom, Write}; use std::io::{BufReader, Read, Seek, SeekFrom, Write};
use std::os::unix::prelude::FileExt;
use std::path::PathBuf; use std::path::PathBuf;
use std::fs; use std::fs;
use base64::Engine; use base64::Engine;
@ -37,9 +38,8 @@ struct CLIArguments {
fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> { fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> {
let offset = id as usize * header.record_pad_size; let offset = id as usize * header.record_pad_size;
data_file.seek(SeekFrom::Start(offset as u64))?;
let mut buf = vec![0; header.record_pad_size as usize]; let mut buf = vec![0; header.record_pad_size as usize];
data_file.read_exact(&mut buf)?; data_file.read_exact_at(&mut buf, offset as u64)?;
let len = u16::from_le_bytes(buf[0..2].try_into().unwrap()) as usize; let len = u16::from_le_bytes(buf[0..2].try_into().unwrap()) as usize;
Ok(bitcode::decode(&buf[2..len+2])?) Ok(bitcode::decode(&buf[2..len+2])?)
} }
@ -117,9 +117,11 @@ fn summary_stats(ranks: &mut [usize]) {
ranks.sort_unstable(); ranks.sort_unstable();
let median = ranks[ranks.len() / 2] + 1; let median = ranks[ranks.len() / 2] + 1;
let harmonic_mean = ranks.iter().map(|x| 1.0 / ((x+1) as f64)).sum::<f64>() / ranks.len() as f64; let harmonic_mean = ranks.iter().map(|x| 1.0 / ((x+1) as f64)).sum::<f64>() / ranks.len() as f64;
println!("median {} mean {} max {} min {} harmonic mean {}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean); println!("median {} mean {:.2} max {} min {} harmonic mean {:.2}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean);
} }
const K: usize = 20;
fn main() -> Result<()> { fn main() -> Result<()> {
let args: CLIArguments = argh::from_env(); let args: CLIArguments = argh::from_env();
@ -150,8 +152,11 @@ fn main() -> Result<()> {
println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len()); println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
let mut top_20_ranks_best_shard = vec![]; let mut top_k_ranks_best_shard = vec![];
let mut top_rank_best_shard = vec![]; let mut top_rank_best_shard = vec![];
let mut pq_cmps = vec![];
let mut cmps = vec![];
let mut recall_total = 0;
for query_vector in queries.iter() { for query_vector in queries.iter() {
let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>(); let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
@ -183,26 +188,30 @@ fn main() -> Result<()> {
println!("brute force: {} {} {} {:?}", id, distance, url, shards); println!("brute force: {} {} {} {:?}", id, distance, url, shards);
}*/ }*/
let mut top_ranks = vec![usize::MAX; 20]; let mut top_ranks = vec![usize::MAX; K];
for shard in 0..header.shards.len() { for shard in 0..header.shards.len() {
let selected_start = header.shards[shard].1; let selected_start = header.shards[shard].1;
let mut scratch = Scratch { let mut scratch = Scratch {
visited: HashSet::new(), visited: HashSet::new(),
neighbour_buffer: NeighbourBuffer::new(5000), neighbour_buffer: NeighbourBuffer::new(1000),
neighbour_pre_buffer: Vec::new(), neighbour_pre_buffer: Vec::new(),
visited_list: Vec::new() visited_list: Vec::new()
}; };
//let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng); //let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng);
let cmps = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef { let cmps_result = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
data_file: &mut data_file, data_file: &mut data_file,
header: &header, header: &header,
pq_codes: &pq_codes, pq_codes: &pq_codes,
pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code, pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
}, args.disable_pq)?; }, args.disable_pq)?;
// slightly dubious because this is across shards
pq_cmps.push(cmps_result.1);
cmps.push(cmps_result.0);
if args.verbose { if args.verbose {
println!("index scan {}: {:?} cmps", shard, cmps); println!("index scan {}: {:?} cmps", shard, cmps);
} }
@ -221,14 +230,26 @@ fn main() -> Result<()> {
if args.verbose { println!("") } if args.verbose { println!("") }
} }
// results list is always correctly sorted
for &rank in top_ranks.iter() {
if rank < K {
recall_total += 1;
}
}
top_rank_best_shard.push(top_ranks[0]); top_rank_best_shard.push(top_ranks[0]);
top_20_ranks_best_shard.extend(top_ranks); top_k_ranks_best_shard.extend(top_ranks);
} }
println!("ranks of top 20:"); println!("ranks of top 20:");
summary_stats(&mut top_20_ranks_best_shard); summary_stats(&mut top_k_ranks_best_shard);
println!("ranks of top 1:"); println!("ranks of top 1:");
summary_stats(&mut top_rank_best_shard); summary_stats(&mut top_rank_best_shard);
println!("pq comparisons:");
summary_stats(&mut pq_cmps);
println!("comparisons:");
summary_stats(&mut cmps);
println!("recall@{}: {}", K, recall_total as f64 / (K * queries.len()) as f64);
Ok(()) Ok(())
} }