mirror of https://github.com/osmarks/meme-search-engine.git

commit fdc2af6f8a: merge fixes into roargraph

diskann/Cargo.lock (generated)
@@ -146,6 +146,7 @@ dependencies = [
  "foldhash",
  "half",
  "matrixmultiply",
+ "min-max-heap",
  "rayon",
  "rmp-serde",
  "serde",
@@ -228,6 +229,12 @@ dependencies = [
  "rawpointer",
 ]

+[[package]]
+name = "min-max-heap"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2687e6cf9c00f48e9284cf9fd15f2ef341d03cc7743abf9df4c5f07fdee50b18"
+
 [[package]]
 name = "mio"
 version = "0.8.11"
@@ -11,26 +11,9 @@ output_code_size = 64
 output_code_bits = 8
 output_codebook_size = 2**output_code_bits
 n_dims_per_code = n_dims // output_code_size
-dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
-queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
-device = "cpu"
-
-index = faiss.index_factory(n_dims, "HNSW32,SQfp16", faiss.METRIC_INNER_PRODUCT)
-index.train(queryset)
-index.add(queryset)
-print("index ready")
-
-T = 64
-
-nearby_query_indices = torch.zeros((dataset.shape[0], T), dtype=torch.int32)
-
-SEARCH_BATCH_SIZE = 1024
-
-for i in range(0, len(dataset), SEARCH_BATCH_SIZE):
-    res = index.search(dataset[i:i+SEARCH_BATCH_SIZE], T)
-    nearby_query_indices[i:i+SEARCH_BATCH_SIZE] = torch.tensor(res[1])
-
-print("query indices ready")
+dataset = np.random.permutation(np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)).astype(np.float32)
+queryset = np.random.permutation(np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims))[:100000].astype(np.float32)
+device = "cuda"

 def pq_assign(centroids, batch):
     quantized = torch.zeros_like(batch)
@@ -47,28 +30,27 @@ def pq_assign(centroids, batch):
 # OOD-DiskANN (https://arxiv.org/abs/2211.12850) uses a more complicated scheme because it uses L2 norm
 # We only care about inner product so our quantization error (wrt a query) is just abs(dot(query, centroid - vector))
 # Directly optimize for this (wrt top queries; it might actually be better to use a random sample instead?)
-def partition(vectors, centroids, projection, opt, queries, nearby_query_indices, k, max_iter=100, batch_size=4096):
+def partition(vectors, centroids, projection, opt, queries, k, max_iter=100, batch_size=4096, query_batch_size=2048):
     n_vectors = len(vectors)
-    perm = torch.randperm(n_vectors, device=device)
+    #perm = torch.randperm(n_vectors, device=device)

     t = tqdm.trange(max_iter)
     for iter in t:
         total_loss = 0
         opt.zero_grad(set_to_none=True)

+        # randomly sample queries (with replacement, probably fine)
+        queries_for_iteration = queries[torch.randint(0, len(queries), (query_batch_size,), device=device)]
+
         for i in range(0, n_vectors, batch_size):
             loss = torch.tensor(0.0, device=device)
             batch = vectors[i:i+batch_size] @ projection
             quantized = pq_assign(centroids, batch)
             residuals = batch - quantized

-            # for each index in our set of nearby queries
-            for j in range(0, nearby_query_indices.shape[1]):
-                queries_for_batch_j = queries[nearby_query_indices[i:i+batch_size, j]]
-                # minimize quantization error in direction of query, i.e. mean abs(dot(query, centroid - vector))
-                # PyTorch won't do batched dot products cleanly, to spite me. Do componentwise multiplication and reduce.
-                sg_errs = (queries_for_batch_j * residuals).sum(dim=-1)
-                loss += torch.mean(torch.abs(sg_errs))
+            batch_error = queries_for_iteration @ residuals.T
+
+            loss += torch.mean(torch.pow(batch_error, 2))

             total_loss += loss.detach().item()
             loss.backward()
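The objective changes shape here: instead of walking per-vector lists of nearby queries (the T columns of `nearby_query_indices`), each iteration samples a fresh query batch and penalizes the squared inner-product error of every sampled query against every residual. A standalone PyTorch sketch of the new loss (variable names assumed, not taken verbatim from the script):

    import torch

    def query_aware_pq_loss(query_batch, batch, quantized):
        # residual of each vector under its current PQ assignment
        residuals = batch - quantized                # (B, D)
        # inner-product error of every sampled query against every residual
        batch_error = query_batch @ residuals.T      # (Q, B)
        # squared error, averaged over all query/vector pairs
        return torch.mean(batch_error ** 2)

Squaring (rather than the previous abs()) smooths the gradient near zero and weights large errors more heavily, and resampling queries with replacement each iteration acts as stochastic minibatching over the query distribution.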
@@ -90,10 +72,10 @@ queries = torch.tensor(queryset, device=device)
 perm = torch.randperm(len(vectors), device=device)
 centroids = vectors[perm[:output_codebook_size]]
 centroids.requires_grad = True
-opt = torch.optim.Adam([centroids], lr=0.001)
+opt = torch.optim.Adam([centroids], lr=0.0005)
 for i in range(30):
     # update centroids to minimize query-aware quantization loss
-    partition(vectors, centroids, projection, opt, queries, nearby_query_indices, output_codebook_size, max_iter=8)
+    partition(vectors, centroids, projection, opt, queries, output_codebook_size, max_iter=300)
     # compute new projection as R = VU^T from XY^T = USV^T (SVD)
     # where X is dataset vectors, Y is quantized dataset vectors
     with torch.no_grad():
@@ -102,12 +84,12 @@ for i in range(30):
         u, s, vt = torch.linalg.svd(vectors.T @ y)
         projection = vt.T @ u.T

-print("done")
 with open("opq.msgpack", "wb") as f:
     msgpack.pack({
         "centroids": centroids.detach().cpu().numpy().flatten().tolist(),
         "transform": projection.cpu().numpy().flatten().tolist(),
         "n_dims_per_code": n_dims_per_code,
         "n_dims": n_dims
     }, f)
+print("done")
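The projection update is the orthogonal Procrustes step familiar from OPQ: given dataset X and its quantization Y, the orthogonal transform minimizing the Frobenius error comes from the SVD of X^T Y. A small sketch verifying the classical solution (note the script composes its transform as `vt.T @ u.T`, the transpose of the P below; which convention applies depends on whether you solve X @ P ≈ Y or P @ X ≈ Y):

    import torch

    def procrustes_rotation(x, y):
        # classical solution: with x.T @ y = U S V^T, the orthogonal P
        # minimizing ||x @ P - y||_F is P = U @ V^T
        u, s, vt = torch.linalg.svd(x.T @ y)
        return u @ vt

    # quick check: recover a known rotation from noisy data
    d = 8
    q, _ = torch.linalg.qr(torch.randn(d, d))        # random orthogonal matrix
    x = torch.randn(1000, d)
    y = x @ q + 0.01 * torch.randn(1000, d)
    print(torch.dist(procrustes_rotation(x, y), q))  # ~0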
@@ -19,20 +19,6 @@ pub struct IndexGraph {
 }

 impl IndexGraph {
-    pub fn random_r_regular(rng: &mut Rng, n: usize, r: usize, capacity: usize) -> Self {
-        let mut graph = Vec::with_capacity(n);
-        for _ in 0..n {
-            let mut adjacency = Vec::with_capacity(capacity);
-            for _ in 0..r {
-                adjacency.push(rng.u32(0..(n as u32)));
-            }
-            graph.push(RwLock::new(adjacency));
-        }
-        IndexGraph {
-            graph
-        }
-    }
-
     pub fn empty(n: usize, capacity: usize) -> IndexGraph {
         let mut graph = Vec::with_capacity(n);
         for _ in 0..n {
@@ -57,7 +43,8 @@ pub struct IndexBuildConfig {
     pub r: usize,
     pub l: usize,
     pub maxc: usize,
-    pub alpha: i64
+    pub alpha: i64,
+    pub saturate_graph: bool
 }

@@ -79,6 +66,7 @@ pub fn medioid(vecs: &VectorList) -> u32 {

 // neighbours list sorted by score descending
 // TODO: this may actually be an awful datastructure
+// we could also have a heap of unvisited things, but the algorithm's stopping condition cares about visited things, and this is probably still the easiest way
 #[derive(Clone, Debug)]
 pub struct NeighbourBuffer {
     pub ids: Vec<u32>,
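What the comment describes is a small fixed-capacity pool kept sorted by score descending, with a visited bit per entry; greedy search stops once every surviving entry has been visited. A rough Python sketch of the semantics (not the Rust layout):

    import bisect

    class NeighbourBuffer:
        """Fixed-capacity candidate pool sorted by score descending (sketch)."""
        def __init__(self, capacity):
            self.capacity = capacity
            self.entries = []  # (score, id, visited), best score first

        def insert(self, id_, score):
            if len(self.entries) == self.capacity and score <= self.entries[-1][0]:
                return  # worse than the current worst: drop
            keys = [-e[0] for e in self.entries]        # negate for descending order
            pos = bisect.bisect_left(keys, -score)
            self.entries.insert(pos, (score, id_, False))
            del self.entries[self.capacity:]            # evict overflow

        def next_unvisited(self):
            for i, (score, id_, visited) in enumerate(self.entries):
                if not visited:
                    self.entries[i] = (score, id_, True)
                    return id_
            return None  # stopping condition: everything here was visited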
@@ -244,8 +232,9 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
     let mut candidate_index = 0;
     while neigh.len() < config.r && candidate_index < candidates.len() {
         let p_star = candidates[candidate_index].0;
+        let p_star_score = candidates[candidate_index].1;
         candidate_index += 1;
-        if p_star == p || p_star == u32::MAX {
+        if p_star == p || p_star_score == i64::MIN {
             continue;
         }
@@ -256,7 +245,7 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
         // mark remaining candidates as not-to-be-used if "not much better than" current candidate
         for i in (candidate_index+1)..candidates.len() {
             let p_prime = candidates[i].0;
-            if p_prime != u32::MAX {
+            if candidates[i].1 != i64::MIN {
                 scratch.robust_prune_scratch_buffer.push((i, p_prime));
             }
         }
@@ -268,7 +257,18 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
             let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16;

             if alpha_times_p_star_prime_score >= p_prime_p_score {
-                candidates[ci].0 = u32::MAX;
+                candidates[ci].1 = i64::MIN;
             }
         }
     }
+
+    if config.saturate_graph {
+        for &(id, _score) in candidates.iter() {
+            if neigh.len() == config.r {
+                return;
+            }
+            if !neigh.contains(&id) {
+                neigh.push(id);
+            }
+        }
+    }
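Scores and `alpha` are 16.16 fixed point throughout: 65536 represents 1.0, so the build configs' `alpha: 65200` is a prune threshold of roughly 0.995, and `(config.alpha * score) >> 16` multiplies a score by that factor using only integer arithmetic. A tiny sketch of the arithmetic:

    FIXED_ONE = 1 << 16   # 65536 represents 1.0 in 16.16 fixed point

    def alpha_scale(score, alpha=65200):
        # integer equivalent of floor(score * alpha / 65536); Python's >> on
        # negative ints floor-shifts, matching Rust's arithmetic shift on i64
        return (alpha * score) >> 16

    print(65200 / FIXED_ONE)       # 0.9948... (the effective alpha)
    print(alpha_scale(FIXED_ONE))  # 65200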
@@ -382,7 +382,7 @@ pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec
     })
 }

-pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig) {
+pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig, max_iters: usize) {
     let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
     rng.shuffle(&mut sigmas);

@@ -391,7 +391,7 @@ pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<
     sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| {
         let mut neighbours = graph.out_neighbours_mut(sigma_i);
         let mut i = 0;
-        while neighbours.len() < config.r && i < 100 {
+        while neighbours.len() < config.r && i < max_iters {
             let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap();
             let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap();
             if !neighbours.contains(&projected_neighbour) {
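`augment_bipartite` tops up each base node's out-edges by hopping base vector → one of its nearest queries → one of that query's nearest base vectors, RoarGraph-style; `max_iters` now caps the attempts instead of the hardcoded 100. The idea in Python (a sketch under assumed inputs, not the crate's API):

    import random

    def augment_one(neighbours, query_knns, query_knns_bwd, node, r, max_iters):
        # walk base -> nearby query -> that query's nearby base vectors,
        # adding the endpoint as an out-edge until the node has r edges
        i = 0
        while len(neighbours) < r and i < max_iters:
            query_neighbour = random.choice(query_knns[node])
            projected = random.choice(query_knns_bwd[query_neighbour])
            if projected not in neighbours:
                neighbours.append(projected)
            i += 1
        return neighbours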
@@ -52,7 +52,14 @@ fn main() -> Result<()> {
     let vecs = {
         let _timer = Timer::new("loaded vectors");

-        &load_file("query.bin", Some(D_EMB * n))?
+        &load_file("real.bin", None)?
+    };
+
+    println!("{} vectors", vecs.len());
+
+    let queries = {
+        let _timer = Timer::new("loaded queries");
+        &load_file("../query5.bin", None)?
     };

     let (graph, medioid) = {
@@ -63,9 +70,12 @@ fn main() -> Result<()> {
         l: 192,
         maxc: 750,
         alpha: 65200,
+        saturate_graph: false
     };

-    let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap);
+    let mut graph = IndexGraph::empty(vecs.len(), config.r);
+
+    random_fill_graph(&mut rng, &mut graph, config.r);

     let medioid = medioid(&vecs);

@@ -92,9 +102,10 @@ fn main() -> Result<()> {

     let mut config = IndexBuildConfig {
         r: 64,
-        l: 50,
+        l: 200,
         alpha: 65536,
         maxc: 0,
+        saturate_graph: false
     };

     let mut scratch = Scratch::new(config);
@@ -112,8 +123,8 @@ fn main() -> Result<()> {

     let end = time.elapsed();

-    println!("recall@1: {} ({}/{})", recall as f32 / n as f32, recall, n);
-    println!("cmps: {} ({}/{})", cmps_ctr as f32 / n as f32, cmps_ctr, n);
+    println!("recall@1: {} ({}/{})", recall as f32 / vecs.len() as f32, recall, vecs.len());
+    println!("cmps: {} ({}/{})", cmps_ctr as f32 / vecs.len() as f32, cmps_ctr, vecs.len());
     println!("median comparisons: {}", cmps[cmps.len() / 2]);
     //println!("brute force recall@1: {} ({}/{})", brute_force_recall as f32 / brute_force_queries as f32, brute_force_recall, brute_force_queries);
     println!("{} QPS", n as f32 / end.as_secs_f32());
@@ -176,7 +176,8 @@ pub mod index_config {
         r: 40,
         l: 200,
         maxc: 900,
-        alpha: 65200
+        alpha: 65200,
+        saturate_graph: false
     };

     pub const PROJECTION_CUT_POINT: usize = 3;
@@ -14,6 +14,7 @@ use itertools::Itertools;
 use simsimd::SpatialSimilarity;
 use std::hash::Hasher;
 use foldhash::{HashSet, HashSetExt};
+use std::os::unix::prelude::FileExt;

 use diskann::vector::{scale_dot_result_f64, ProductQuantizer};

@@ -160,15 +161,29 @@ fn main() -> Result<()> {
     let (mut queries_index, max_query_id) = if let Some(queries_file) = args.queries {
         println!("constructing index");
-        // not memory-efficient but this is small
-        let data = fs::read(queries_file).context("read queries file")?;
+        let mut file = fs::File::open(queries_file).context("read queries file")?;
+        let mut size = file.metadata()?.len();
         //let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?;
-        let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?;
+        let mut index = faiss::index_factory(D_EMB, "HNSW64,SQ8", faiss::MetricType::InnerProduct)?;
         //let mut index = faiss::index_factory(D_EMB, "IVF4096,SQfp16", faiss::MetricType::InnerProduct)?;
-        let unpacked = common::decode_fp16_buffer(&data);
-        index.train(&unpacked)?;
-        index.add(&unpacked)?;
+        let mut buf = vec![0; (D_EMB as usize) * (1<<18)];
+        loop {
+            if size == 0 {
+                break;
+            }
+            if size < (buf.len() as u64) {
+                buf.resize(size as usize, 0);
+            }
+            file.read_exact(&mut buf)?;
+            size -= buf.len() as u64;
+            let unpacked = common::decode_fp16_buffer(&buf);
+            if !index.is_trained() { index.train(&unpacked)?; print!("train"); }
+            index.add(&unpacked)?;
+            print!(".");
+        }
         println!("done");
-        (Some(index), unpacked.len() / D_EMB as usize)
+        let ntotal = index.ntotal();
+        (Some(index), ntotal as usize)
     } else {
         (None, 0)
     };
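The queries index is now built streaming: the file is read in fixed-size chunks, the scalar quantizer is trained on the first chunk, and every chunk is added, so the whole fp16 file never has to sit in memory at once. The same pattern in Python faiss (file name and dimensionality are assumptions, not from the commit):

    import faiss
    import numpy as np

    D_EMB = 1152            # hypothetical embedding dimensionality
    CHUNK_ROWS = 1 << 18    # rows per chunk; file assumed to be whole fp16 rows

    index = faiss.index_factory(D_EMB, "HNSW64,SQ8", faiss.METRIC_INNER_PRODUCT)
    with open("query.bin", "rb") as f:
        while True:
            buf = f.read(CHUNK_ROWS * D_EMB * 2)  # fp16 = 2 bytes per value
            if not buf:
                break
            chunk = np.frombuffer(buf, dtype=np.float16).reshape(-1, D_EMB).astype(np.float32)
            if not index.is_trained:
                index.train(chunk)  # the SQ8 scalar quantizer needs one training pass
            index.add(chunk)
    print("indexed", index.ntotal)

Training only on the first chunk is fine for a scalar quantizer, since it just needs per-dimension ranges rather than a globally representative sample.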
@@ -266,14 +281,16 @@ fn main() -> Result<()> {
         let shard = shard as usize;
         // this random access is almost certainly rather slow
         // parallelize?
-        files[shard].1.seek(SeekFrom::Start(offset))?;
         let mut buf = vec![0; len as usize];
-        files[shard].1.read_exact(&mut buf)?;
-        let mut s: Vec<u32> = bytemuck::allocation::pod_collect_to_vec(&*buf);
-        for within_shard_id in s.iter_mut() {
-            *within_shard_id = shard_id_mappings[shard].1[*within_shard_id as usize];
+        files[shard].1.read_exact_at(&mut buf, offset)?;
+        let s: &[u32] = bytemuck::cast_slice(&mut *buf);
+        for within_shard_id in s.iter() {
+            let global_id = shard_id_mappings[shard].1[*within_shard_id as usize];
+            if !out_vertices.contains(&global_id) {
+                out_vertices.push(global_id);
+            }
         }
-        out_vertices.extend(s.iter().unique());
     }

     Ok((out_vertices, shards))
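`read_exact_at` is the `std::os::unix` `FileExt` wrapper over `pread(2)`: it reads at an absolute offset without touching the shared file cursor, which saves a seek syscall per lookup and makes one handle safe to use from multiple threads. The Python equivalent, for reference:

    import os

    def read_at(fd, offset, length):
        # pread reads at an absolute offset without moving the file cursor,
        # so concurrent readers don't race on the shared position
        return os.pread(fd, length, offset)

    fd = os.open("index.bin", os.O_RDONLY)  # hypothetical file
    header = read_at(fd, 0, 4096)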
@@ -422,7 +439,7 @@ fn main() -> Result<()> {
         let codes = quantizer.quantize_batch(&batch_embeddings);

         for (i, (x, _embedding)) in batch.into_iter().enumerate() {
-            let (vertices, shards) = read_out_vertices(count)?; // TODO: could parallelize this given the batching
+            let (vertices, shards) = read_out_vertices(count + i as u32)?; // TODO: could parallelize this given the batching
             let mut entry = PackedIndexEntry {
                 id: count + i as u32,
                 vertices,
@@ -4,7 +4,7 @@ use std::collections::BinaryHeap;
 use std::io::{BufReader, BufWriter, Write};
 use std::fs;
 use rmp_serde::decode::{Error as DecodeError, from_read};
-use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList, scale_dot_result}, IndexBuildConfig, IndexGraph, Timer, report_degrees};
+use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList, scale_dot_result}, IndexBuildConfig, IndexGraph, Timer, report_degrees, medioid};
 use half::f16;

 mod common;
@@ -101,9 +101,7 @@ fn main() -> Result<()> {

     report_degrees(&graph);

-    let medioid = vecs.iter().position_max_by_key(|&v| {
-        dot(v, &centroid_fp16)
-    }).unwrap() as u32;
+    let medioid = medioid(&vecs);

     {
         let _timer = Timer::new("first pass");
@@ -2,6 +2,7 @@ use anyhow::{bail, Context, Result};
 use diskann::vector::scale_dot_result_f64;
 use serde::{Serialize, Deserialize};
 use std::io::{BufReader, Read, Seek, SeekFrom, Write};
+use std::os::unix::prelude::FileExt;
 use std::path::PathBuf;
 use std::fs;
 use base64::Engine;
@@ -31,15 +32,16 @@ struct CLIArguments {
     verbose: bool,
     #[argh(option, short='n', description="stop at n queries")]
     n: Option<usize>,
+    #[argh(option, short='L', description="search list size")]
+    search_list_size: Option<usize>,
     #[argh(switch, description="always use full-precision vectors (slow)")]
     disable_pq: bool
 }

 fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> {
     let offset = id as usize * header.record_pad_size;
-    data_file.seek(SeekFrom::Start(offset as u64))?;
     let mut buf = vec![0; header.record_pad_size as usize];
-    data_file.read_exact(&mut buf)?;
+    data_file.read_exact_at(&mut buf, offset as u64)?;
     let len = u16::from_le_bytes(buf[0..2].try_into().unwrap()) as usize;
     Ok(bitcode::decode(&buf[2..len+2])?)
 }
@@ -117,9 +119,11 @@ fn summary_stats(ranks: &mut [usize]) {
     ranks.sort_unstable();
     let median = ranks[ranks.len() / 2] + 1;
     let harmonic_mean = ranks.iter().map(|x| 1.0 / ((x+1) as f64)).sum::<f64>() / ranks.len() as f64;
-    println!("median {} mean {} max {} min {} harmonic mean {}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean);
+    println!("median {} mean {:.2} max {} min {} harmonic mean {:.2}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean);
 }

+const K: usize = 20;
+
 fn main() -> Result<()> {
     let args: CLIArguments = argh::from_env();

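A note on the statistics: the `harmonic_mean` accumulator is really the mean reciprocal rank (MRR) of the 1-based ranks, and its reciprocal (the harmonic mean of the ranks) is what the new `{:.2}` formatting prints. In Python terms, as a sketch:

    def summary_stats(ranks):
        # ranks are 0-based; all reporting is 1-based
        ranks = sorted(ranks)
        median = ranks[len(ranks) // 2] + 1
        mean = sum(r + 1 for r in ranks) / len(ranks)
        mrr = sum(1.0 / (r + 1) for r in ranks) / len(ranks)  # mean reciprocal rank
        print(f"median {median} mean {mean:.2f} max {ranks[-1] + 1} "
              f"min {ranks[0] + 1} harmonic mean {1.0 / mrr:.2f}")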
@@ -150,8 +154,11 @@ fn main() -> Result<()> {

     println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());

-    let mut top_20_ranks_best_shard = vec![];
+    let mut top_k_ranks_best_shard = vec![];
     let mut top_rank_best_shard = vec![];
+    let mut pq_cmps = vec![];
+    let mut cmps = vec![];
+    let mut recall_total = 0;

     for query_vector in queries.iter() {
         let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
@@ -183,55 +190,68 @@ fn main() -> Result<()> {
             println!("brute force: {} {} {} {:?}", id, distance, url, shards);
         }*/

-        let mut top_ranks = vec![usize::MAX; 20];
+        let mut top_ranks = vec![usize::MAX; K];

         for shard in 0..header.shards.len() {
             let selected_start = header.shards[shard].1;

             let mut scratch = Scratch {
                 visited: HashSet::new(),
-                neighbour_buffer: NeighbourBuffer::new(300),
+                neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
                 neighbour_pre_buffer: Vec::new(),
                 visited_list: Vec::new()
             };

             //let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng);
-            let cmps = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
+            let cmps_result = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
                 data_file: &mut data_file,
                 header: &header,
                 pq_codes: &pq_codes,
                 pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
             }, args.disable_pq)?;

+            // slightly dubious because this is across shards
+            pq_cmps.push(cmps_result.1);
+            cmps.push(cmps_result.0);
+
             if args.verbose {
                 println!("index scan {}: {:?} cmps", shard, cmps);
             }

             scratch.visited_list.sort_by_key(|x| -x.1);
             for (i, (id, distance, url, shards)) in scratch.visited_list.iter().take(20).enumerate() {
-                if args.verbose {
-                    println!("index scan: {} {} {} {:?}", id, distance, url, shards);
-                };
                 let found_id = match matches.binary_search(&(*id, 0)) {
                     Ok(pos) => pos,
                     Err(pos) => pos
                 };
                 if args.verbose {
-                    println!("rank {}", matches[found_id].1);
+                    println!("index scan: {} {} {} {:?}; rank {}", id, distance, url, shards, matches[found_id].1 + 1);
                 };
                 top_ranks[i] = std::cmp::min(top_ranks[i], matches[found_id].1);
             }
             if args.verbose { println!("") }
         }

+        // results list is always correctly sorted
+        for &rank in top_ranks.iter() {
+            if rank < K {
+                recall_total += 1;
+            }
+        }
+
         top_rank_best_shard.push(top_ranks[0]);
-        top_20_ranks_best_shard.extend(top_ranks);
+        top_k_ranks_best_shard.extend(top_ranks);
     }

     println!("ranks of top 20:");
-    summary_stats(&mut top_20_ranks_best_shard);
+    summary_stats(&mut top_k_ranks_best_shard);
     println!("ranks of top 1:");
     summary_stats(&mut top_rank_best_shard);
+    println!("pq comparisons:");
+    summary_stats(&mut pq_cmps);
+    println!("comparisons:");
+    summary_stats(&mut cmps);
+    println!("recall@{}: {}", K, recall_total as f64 / (K * queries.len()) as f64);

     Ok(())
 }
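The recall@K accounting introduced here works per query: each of the K returned results keeps the best true (brute-force) rank it attains across shards, counts as a hit if that rank is below K, and the hit total is normalized by K times the number of queries. Equivalent Python (sentinel and input shape assumed):

    K = 20

    def recall_at_k(top_ranks_per_query, k=K):
        # top_ranks_per_query: for each query, the best true rank achieved by
        # each of its k returned results (a huge sentinel marks "not found")
        hits = sum(1 for ranks in top_ranks_per_query
                     for rank in ranks if rank < k)
        return hits / (k * len(top_ranks_per_query))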