diff --git a/diskann/Cargo.lock b/diskann/Cargo.lock
index f9a41da..370cfdc 100644
--- a/diskann/Cargo.lock
+++ b/diskann/Cargo.lock
@@ -146,6 +146,7 @@ dependencies = [
  "foldhash",
  "half",
  "matrixmultiply",
+ "min-max-heap",
  "rayon",
  "rmp-serde",
  "serde",
@@ -228,6 +229,12 @@ dependencies = [
  "rawpointer",
 ]
 
+[[package]]
+name = "min-max-heap"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2687e6cf9c00f48e9284cf9fd15f2ef341d03cc7743abf9df4c5f07fdee50b18"
+
 [[package]]
 name = "mio"
 version = "0.8.11"
diff --git a/diskann/aopq_train.py b/diskann/aopq_train.py
index 03c840e..d0b778f 100644
--- a/diskann/aopq_train.py
+++ b/diskann/aopq_train.py
@@ -11,26 +11,9 @@ output_code_size = 64
 output_code_bits = 8
 output_codebook_size = 2**output_code_bits
 n_dims_per_code = n_dims // output_code_size
-dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
-queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
-device = "cpu"
-
-index = faiss.index_factory(n_dims, "HNSW32,SQfp16", faiss.METRIC_INNER_PRODUCT)
-index.train(queryset)
-index.add(queryset)
-print("index ready")
-
-T = 64
-
-nearby_query_indices = torch.zeros((dataset.shape[0], T), dtype=torch.int32)
-
-SEARCH_BATCH_SIZE = 1024
-
-for i in range(0, len(dataset), SEARCH_BATCH_SIZE):
-    res = index.search(dataset[i:i+SEARCH_BATCH_SIZE], T)
-    nearby_query_indices[i:i+SEARCH_BATCH_SIZE] = torch.tensor(res[1])
-
-print("query indices ready")
+dataset = np.random.permutation(np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)).astype(np.float32)
+queryset = np.random.permutation(np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims))[:100000].astype(np.float32)
+device = "cuda"
 
 def pq_assign(centroids, batch):
     quantized = torch.zeros_like(batch)
@@ -47,28 +30,27 @@ def pq_assign(centroids, batch):
 # OOD-DiskANN (https://arxiv.org/abs/2211.12850) uses a more complicated scheme because it uses L2 norm
 # We only care about inner product so our quantization error (wrt a query) is just abs(dot(query, centroid - vector))
 # Directly optimize for this (wrt top queries; it might actually be better to use a random sample instead?)
-def partition(vectors, centroids, projection, opt, queries, nearby_query_indices, k, max_iter=100, batch_size=4096):
+def partition(vectors, centroids, projection, opt, queries, k, max_iter=100, batch_size=4096, query_batch_size=2048):
     n_vectors = len(vectors)
-    perm = torch.randperm(n_vectors, device=device)
+    #perm = torch.randperm(n_vectors, device=device)
     t = tqdm.trange(max_iter)
     for iter in t:
         total_loss = 0
         opt.zero_grad(set_to_none=True)
 
+        # randomly sample queries (with replacement, probably fine)
+        queries_for_iteration = queries[torch.randint(0, len(queries), (query_batch_size,), device=device)]
+
         for i in range(0, n_vectors, batch_size):
             loss = torch.tensor(0.0, device=device)
             batch = vectors[i:i+batch_size] @ projection
             quantized = pq_assign(centroids, batch)
             residuals = batch - quantized
-            # for each index in our set of nearby queries
-            for j in range(0, nearby_query_indices.shape[1]):
-                queries_for_batch_j = queries[nearby_query_indices[i:i+batch_size, j]]
-                # minimize quantiation error in direction of query, i.e. mean abs(dot(query, centroid - vector))
-                # PyTorch won't do batched dot products cleanly, to spite me. Do componentwise multiplication and reduce.
-                sg_errs = (queries_for_batch_j * residuals).sum(dim=-1)
-                loss += torch.mean(torch.abs(sg_errs))
+            batch_error = queries_for_iteration @ residuals.T
+
+            loss += torch.mean(torch.pow(batch_error, 2))
 
             total_loss += loss.detach().item()
             loss.backward()
@@ -90,10 +72,10 @@ queries = torch.tensor(queryset, device=device)
 perm = torch.randperm(len(vectors), device=device)
 centroids = vectors[perm[:output_codebook_size]]
 centroids.requires_grad = True
-opt = torch.optim.Adam([centroids], lr=0.001)
+opt = torch.optim.Adam([centroids], lr=0.0005)
 for i in range(30):
     # update centroids to minimize query-aware quantization loss
-    partition(vectors, centroids, projection, opt, queries, nearby_query_indices, output_codebook_size, max_iter=8)
+    partition(vectors, centroids, projection, opt, queries, output_codebook_size, max_iter=300)
     # compute new projection as R = VU^T from XY^T = USV^T (SVD)
     # where X is dataset vectors, Y is quantized dataset vectors
     with torch.no_grad():
@@ -102,12 +84,12 @@ for i in range(30):
         u, s, vt = torch.linalg.svd(vectors.T @ y)
         projection = vt.T @ u.T
 
-print("done")
+    with open("opq.msgpack", "wb") as f:
+        msgpack.pack({
+            "centroids": centroids.detach().cpu().numpy().flatten().tolist(),
+            "transform": projection.cpu().numpy().flatten().tolist(),
+            "n_dims_per_code": n_dims_per_code,
+            "n_dims": n_dims
+        }, f)
 
-with open("opq.msgpack", "wb") as f:
-    msgpack.pack({
-        "centroids": centroids.detach().cpu().numpy().flatten().tolist(),
-        "transform": projection.cpu().numpy().flatten().tolist(),
-        "n_dims_per_code": n_dims_per_code,
-        "n_dims": n_dims
-    }, f)
+print("done")
diff --git a/diskann/src/lib.rs b/diskann/src/lib.rs
index 758ca68..fc22b45 100644
--- a/diskann/src/lib.rs
+++ b/diskann/src/lib.rs
@@ -19,20 +19,6 @@ pub struct IndexGraph {
 }
 
 impl IndexGraph {
-    pub fn random_r_regular(rng: &mut Rng, n: usize, r: usize, capacity: usize) -> Self {
-        let mut graph = Vec::with_capacity(n);
-        for _ in 0..n {
-            let mut adjacency = Vec::with_capacity(capacity);
-            for _ in 0..r {
-                adjacency.push(rng.u32(0..(n as u32)));
-            }
-            graph.push(RwLock::new(adjacency));
-        }
-        IndexGraph {
-            graph
-        }
-    }
-
     pub fn empty(n: usize, capacity: usize) -> IndexGraph {
         let mut graph = Vec::with_capacity(n);
         for _ in 0..n {
@@ -57,7 +43,8 @@ pub struct IndexBuildConfig {
     pub r: usize,
     pub l: usize,
     pub maxc: usize,
-    pub alpha: i64
+    pub alpha: i64,
+    pub saturate_graph: bool
 }
 
 
@@ -79,6 +66,7 @@ pub fn medioid(vecs: &VectorList) -> u32 {
 
 // neighbours list sorted by score descending
 // TODO: this may actually be an awful datastructure
+// we could also have a heap of unvisited things, but the algorithm's stopping condition cares about visited things, and this is probably still the easiest way
 #[derive(Clone, Debug)]
 pub struct NeighbourBuffer {
     pub ids: Vec<u32>,
@@ -244,8 +232,9 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
     let mut candidate_index = 0;
     while neigh.len() < config.r && candidate_index < candidates.len() {
         let p_star = candidates[candidate_index].0;
+        let p_star_score = candidates[candidate_index].1;
         candidate_index += 1;
-        if p_star == p || p_star == u32::MAX {
+        if p_star == p || p_star_score == i64::MIN {
             continue;
         }
 
@@ -256,7 +245,7 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
         // mark remaining candidates as not-to-be-used if "not much better than" current candidate
         for i in (candidate_index+1)..candidates.len() {
             let p_prime = candidates[i].0;
-            if p_prime != u32::MAX {
+            if candidates[i].1 != i64::MIN {
                 scratch.robust_prune_scratch_buffer.push((i, p_prime));
             }
         }
@@ -268,7 +257,18 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
             let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16;
 
             if alpha_times_p_star_prime_score >= p_prime_p_score {
-                candidates[ci].0 = u32::MAX;
+                candidates[ci].1 = i64::MIN;
+            }
+        }
+    }
+
+    if config.saturate_graph {
+        for &(id, _score) in candidates.iter() {
+            if neigh.len() == config.r {
+                return;
+            }
+            if !neigh.contains(&id) {
+                neigh.push(id);
             }
         }
     }
@@ -382,7 +382,7 @@ pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec
     })
 }
 
-pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig) {
+pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig, max_iters: usize) {
     let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
     rng.shuffle(&mut sigmas);
 
@@ -391,7 +391,7 @@ pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<
     sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| {
         let mut neighbours = graph.out_neighbours_mut(sigma_i);
         let mut i = 0;
-        while neighbours.len() < config.r && i < 100 {
+        while neighbours.len() < config.r && i < max_iters {
             let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap();
             let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap();
             if !neighbours.contains(&projected_neighbour) {
diff --git a/diskann/src/main.rs b/diskann/src/main.rs
index ba3208a..90a4eb8 100644
--- a/diskann/src/main.rs
+++ b/diskann/src/main.rs
@@ -52,7 +52,14 @@ fn main() -> Result<()> {
 
     let vecs = {
         let _timer = Timer::new("loaded vectors");
-        &load_file("query.bin", Some(D_EMB * n))?
+        &load_file("real.bin", None)?
+    };
+
+    println!("{} vectors", vecs.len());
+
+    let queries = {
+        let _timer = Timer::new("loaded queries");
+        &load_file("../query5.bin", None)?
     };
 
     let (graph, medioid) = {
@@ -63,9 +70,12 @@ fn main() -> Result<()> {
             l: 192,
             maxc: 750,
             alpha: 65200,
+            saturate_graph: false
         };
 
-        let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap);
+        let mut graph = IndexGraph::empty(vecs.len(), config.r);
+
+        random_fill_graph(&mut rng, &mut graph, config.r);
 
         let medioid = medioid(&vecs);
 
@@ -92,9 +102,10 @@ fn main() -> Result<()> {
 
     let mut config = IndexBuildConfig {
         r: 64,
-        l: 50,
+        l: 200,
         alpha: 65536,
         maxc: 0,
+        saturate_graph: false
     };
 
     let mut scratch = Scratch::new(config);
@@ -112,8 +123,8 @@ fn main() -> Result<()> {
 
     let end = time.elapsed();
 
-    println!("recall@1: {} ({}/{})", recall as f32 / n as f32, recall, n);
-    println!("cmps: {} ({}/{})", cmps_ctr as f32 / n as f32, cmps_ctr, n);
+    println!("recall@1: {} ({}/{})", recall as f32 / vecs.len() as f32, recall, vecs.len());
+    println!("cmps: {} ({}/{})", cmps_ctr as f32 / vecs.len() as f32, cmps_ctr, vecs.len());
     println!("median comparisons: {}", cmps[cmps.len() / 2]);
     //println!("brute force recall@1: {} ({}/{})", brute_force_recall as f32 / brute_force_queries as f32, brute_force_recall, brute_force_queries);
     println!("{} QPS", n as f32 / end.as_secs_f32());
diff --git a/src/common.rs b/src/common.rs
index 8205264..da0cd8b 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -176,7 +176,8 @@ pub mod index_config {
         r: 40,
         l: 200,
         maxc: 900,
-        alpha: 65200
+        alpha: 65200,
+        saturate_graph: false
     };
 
     pub const PROJECTION_CUT_POINT: usize = 3;
diff --git a/src/dump_processor.rs b/src/dump_processor.rs
index c199090..6669a89 100644
--- a/src/dump_processor.rs
+++ b/src/dump_processor.rs
@@ -14,6 +14,7 @@ use itertools::Itertools;
 use simsimd::SpatialSimilarity;
 use std::hash::Hasher;
 use foldhash::{HashSet, HashSetExt};
+use std::os::unix::prelude::FileExt;
 
 use diskann::vector::{scale_dot_result_f64, ProductQuantizer};
 
@@ -160,15 +161,29 @@ fn main() -> Result<()> {
     let (mut queries_index, max_query_id) = if let Some(queries_file) = args.queries {
         println!("constructing index");
         // not memory-efficient but this is small
-        let data = fs::read(queries_file).context("read queries file")?;
+        let mut file = fs::File::open(queries_file).context("read queries file")?;
+        let mut size = file.metadata()?.len();
         //let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?;
-        let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?;
+        let mut index = faiss::index_factory(D_EMB, "HNSW64,SQ8", faiss::MetricType::InnerProduct)?;
         //let mut index = faiss::index_factory(D_EMB, "IVF4096,SQfp16", faiss::MetricType::InnerProduct)?;
-        let unpacked = common::decode_fp16_buffer(&data);
-        index.train(&unpacked)?;
-        index.add(&unpacked)?;
+        let mut buf = vec![0; (D_EMB as usize) * (1<<18)];
+        loop {
+            if size == 0 {
+                break;
+            }
+            if size < (buf.len() as u64) {
+                buf.resize(size as usize, 0);
+            }
+            file.read_exact(&mut buf)?;
+            size -= buf.len() as u64;
+            let unpacked = common::decode_fp16_buffer(&buf);
+            if !index.is_trained() { index.train(&unpacked)?; print!("train"); }
+            index.add(&unpacked)?;
+            print!(".");
+        }
         println!("done");
-        (Some(index), unpacked.len() / D_EMB as usize)
+        let ntotal = index.ntotal();
+        (Some(index), ntotal as usize)
     } else {
        (None, 0)
     };
@@ -266,14 +281,16 @@ fn main() -> Result<()> {
             let shard = shard as usize;
             // this random access is almost certainly rather slow
             // parallelize?
-            files[shard].1.seek(SeekFrom::Start(offset))?;
             let mut buf = vec![0; len as usize];
-            files[shard].1.read_exact(&mut buf)?;
-            let mut s: Vec<u32> = bytemuck::allocation::pod_collect_to_vec(&*buf);
-            for within_shard_id in s.iter_mut() {
-                *within_shard_id = shard_id_mappings[shard].1[*within_shard_id as usize];
+
+            files[shard].1.read_exact_at(&mut buf, offset)?;
+            let s: &[u32] = bytemuck::cast_slice(&mut *buf);
+            for within_shard_id in s.iter() {
+                let global_id = shard_id_mappings[shard].1[*within_shard_id as usize];
+                if !out_vertices.contains(&global_id) {
+                    out_vertices.push(global_id);
+                }
             }
-            out_vertices.extend(s.iter().unique());
         }
 
         Ok((out_vertices, shards))
@@ -422,7 +439,7 @@ fn main() -> Result<()> {
             let codes = quantizer.quantize_batch(&batch_embeddings);
 
             for (i, (x, _embedding)) in batch.into_iter().enumerate() {
-                let (vertices, shards) = read_out_vertices(count)?; // TODO: could parallelize this given the batching
+                let (vertices, shards) = read_out_vertices(count + i as u32)?; // TODO: could parallelize this given the batching
                 let mut entry = PackedIndexEntry {
                     id: count + i as u32,
                     vertices,
diff --git a/src/generate_index_shard.rs b/src/generate_index_shard.rs
index 24a52c2..2106558 100644
--- a/src/generate_index_shard.rs
+++ b/src/generate_index_shard.rs
@@ -4,7 +4,7 @@ use std::collections::BinaryHeap;
 use std::io::{BufReader, BufWriter, Write};
 use std::fs;
 use rmp_serde::decode::{Error as DecodeError, from_read};
-use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList, scale_dot_result}, IndexBuildConfig, IndexGraph, Timer, report_degrees};
+use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList, scale_dot_result}, IndexBuildConfig, IndexGraph, Timer, report_degrees, medioid};
 use half::f16;
 
 mod common;
@@ -101,9 +101,7 @@ fn main() -> Result<()> {
 
     report_degrees(&graph);
 
-    let medioid = vecs.iter().position_max_by_key(|&v| {
-        dot(v, &centroid_fp16)
-    }).unwrap() as u32;
+    let medioid = medioid(&vecs);
 
     {
         let _timer = Timer::new("first pass");
diff --git a/src/query_disk_index.rs b/src/query_disk_index.rs
index 41b8469..8ddbe22 100644
--- a/src/query_disk_index.rs
+++ b/src/query_disk_index.rs
@@ -2,6 +2,7 @@ use anyhow::{bail, Context, Result};
 use diskann::vector::scale_dot_result_f64;
 use serde::{Serialize, Deserialize};
 use std::io::{BufReader, Read, Seek, SeekFrom, Write};
+use std::os::unix::prelude::FileExt;
 use std::path::PathBuf;
 use std::fs;
 use base64::Engine;
@@ -31,15 +32,16 @@ struct CLIArguments {
     verbose: bool,
     #[argh(option, short='n', description="stop at n queries")]
     n: Option<usize>,
+    #[argh(option, short='L', description="search list size")]
+    search_list_size: Option<usize>,
     #[argh(switch, description="always use full-precision vectors (slow)")]
     disable_pq: bool
 }
 
 fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> {
     let offset = id as usize * header.record_pad_size;
-    data_file.seek(SeekFrom::Start(offset as u64))?;
     let mut buf = vec![0; header.record_pad_size as usize];
-    data_file.read_exact(&mut buf)?;
+    data_file.read_exact_at(&mut buf, offset as u64)?;
     let len = u16::from_le_bytes(buf[0..2].try_into().unwrap()) as usize;
     Ok(bitcode::decode(&buf[2..len+2])?)
 }
@@ -117,9 +119,11 @@ fn summary_stats(ranks: &mut [usize]) {
     ranks.sort_unstable();
     let median = ranks[ranks.len() / 2] + 1;
     let harmonic_mean = ranks.iter().map(|x| 1.0 / ((x+1) as f64)).sum::<f64>() / ranks.len() as f64;
-    println!("median {} mean {} max {} min {} harmonic mean {}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean);
+    println!("median {} mean {:.2} max {} min {} harmonic mean {:.2}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean);
 }
 
+const K: usize = 20;
+
 fn main() -> Result<()> {
     let args: CLIArguments = argh::from_env();
 
@@ -150,8 +154,11 @@ fn main() -> Result<()> {
 
     println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
 
-    let mut top_20_ranks_best_shard = vec![];
+    let mut top_k_ranks_best_shard = vec![];
     let mut top_rank_best_shard = vec![];
+    let mut pq_cmps = vec![];
+    let mut cmps = vec![];
+    let mut recall_total = 0;
 
     for query_vector in queries.iter() {
         let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
@@ -183,55 +190,68 @@ fn main() -> Result<()> {
             println!("brute force: {} {} {} {:?}", id, distance, url, shards);
         }*/
 
-        let mut top_ranks = vec![usize::MAX; 20];
+        let mut top_ranks = vec![usize::MAX; K];
 
         for shard in 0..header.shards.len() {
             let selected_start = header.shards[shard].1;
 
             let mut scratch = Scratch {
                 visited: HashSet::new(),
-                neighbour_buffer: NeighbourBuffer::new(300),
+                neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
                 neighbour_pre_buffer: Vec::new(),
                 visited_list: Vec::new()
            };
 
             //let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng);
-            let cmps = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
+            let cmps_result = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
                 data_file: &mut data_file,
                 header: &header,
                 pq_codes: &pq_codes,
                 pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
             }, args.disable_pq)?;
+            // slightly dubious because this is across shards
+            pq_cmps.push(cmps_result.1);
+            cmps.push(cmps_result.0);
+
             if args.verbose { println!("index scan {}: {:?} cmps", shard, cmps); }
 
             scratch.visited_list.sort_by_key(|x| -x.1);
 
             for (i, (id, distance, url, shards)) in scratch.visited_list.iter().take(20).enumerate() {
-                if args.verbose {
-                    println!("index scan: {} {} {} {:?}", id, distance, url, shards);
-                };
                 let found_id = match matches.binary_search(&(*id, 0)) {
                     Ok(pos) => pos,
                     Err(pos) => pos
                 };
                 if args.verbose {
-                    println!("rank {}", matches[found_id].1);
+                    println!("index scan: {} {} {} {:?}; rank {}", id, distance, url, shards, matches[found_id].1 + 1);
                 };
                 top_ranks[i] = std::cmp::min(top_ranks[i], matches[found_id].1);
             }
 
             if args.verbose { println!("") }
         }
 
+        // results list is always correctly sorted
+        for &rank in top_ranks.iter() {
+            if rank < K {
+                recall_total += 1;
+            }
+        }
+
         top_rank_best_shard.push(top_ranks[0]);
-        top_20_ranks_best_shard.extend(top_ranks);
+        top_k_ranks_best_shard.extend(top_ranks);
     }
 
     println!("ranks of top 20:");
-    summary_stats(&mut top_20_ranks_best_shard);
+    summary_stats(&mut top_k_ranks_best_shard);
 
     println!("ranks of top 1:");
     summary_stats(&mut top_rank_best_shard);
+    println!("pq comparisons:");
+    summary_stats(&mut pq_cmps);
+    println!("comparisons:");
+    summary_stats(&mut cmps);
+    println!("recall@{}: {}", K, recall_total as f64 / (K * queries.len()) as f64);
 
     Ok(())
 }
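
Editor's note on the aopq_train.py change above: the old code built a FAISS index over the query set and minimised the mean absolute inner-product error against each vector's T nearest queries, while the new code samples a fresh random query batch each iteration and minimises the mean squared inner-product error, which the single matmul `batch_error = queries_for_iteration @ residuals.T` computes for all query/vector pairs at once. A minimal sketch of that objective follows; the standalone helper name is illustrative (not part of the repository), but the tensor shapes and operations match the patched `partition`:

    import torch

    def query_aware_quantization_loss(queries_for_iteration, residuals):
        # residuals: (projected) vectors minus their PQ reconstruction, shape (B, d)
        # queries_for_iteration: sampled query batch, shape (Q, d)
        # batch_error[q, b] = <query_q, residual_b>, i.e. the quantization error seen by query q
        batch_error = queries_for_iteration @ residuals.T
        # mean squared inner-product quantization error over all (query, vector) pairs
        return torch.mean(torch.pow(batch_error, 2))

As a general observation (the commit itself does not state the motivation), the squared error is differentiable at zero and weights large per-query errors more heavily than the previous absolute-error formulation.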