diff --git a/diskann/src/main.rs b/diskann/src/main.rs
index d687424..0d7593e 100644
--- a/diskann/src/main.rs
+++ b/diskann/src/main.rs
@@ -7,7 +7,7 @@ use std::{io::Read, time::Instant};
 
 use anyhow::Result;
 use half::f16;
-use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, dot, VectorList, self}, Timer};
+use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, SCALE, dot, VectorList, self}, Timer};
 use simsimd::SpatialSimilarity;
 
 const D_EMB: usize = 1152;
@@ -40,6 +40,7 @@ fn main() -> Result<()> {
         }
         let pq_scores = codec.asymmetric_dot_product(&query, &codes);
         for (x, y) in real_scores.iter().zip(pq_scores.iter()) {
+            let y = (*y as f32) / SCALE;
             println!("{} {} {} {}", x, y, x - y, (x - y) / x);
         }
     }
diff --git a/diskann/src/vector.rs b/diskann/src/vector.rs
index 2283976..7196164 100644
--- a/diskann/src/vector.rs
+++ b/diskann/src/vector.rs
@@ -4,7 +4,6 @@ use half::f16;
 use simsimd::SpatialSimilarity;
 use fastrand::Rng;
 use serde::{Serialize, Deserialize};
-use tracing_subscriber::field::RecordFields;
 
 #[derive(Debug, Clone)]
 pub struct Vector(Vec<f16>);
@@ -46,12 +45,12 @@ impl Vector {
 }
 
 // Floats are vaguely annoying and not sortable (trivially), so we mostly represent dot products as integers
-const SCALE: f32 = 281474976710656.0;
-const SCALE_F64: f64 = SCALE as f64;
+pub const SCALE: f32 = 1099511627776.0;
+pub const SCALE_F64: f64 = SCALE as f64;
 
-pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> f32 {
+pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> i64 {
     // safety is not real
-    (simsimd::f16::dot(unsafe { std::mem::transmute(x) }, unsafe { std::mem::transmute(y) }).unwrap()) as f32
+    scale_dot_result_f64(simsimd::f16::dot(unsafe { std::mem::transmute(x) }, unsafe { std::mem::transmute(y) }).unwrap())
 }
 
 pub fn to_svector(vec: VectorRef) -> SVector {
@@ -404,11 +403,17 @@ impl ProductQuantizer {
         // I have no idea why but we somehow have significant degradation in search quality
         // if this accumulates in integers. As such, do floats and convert at the end.
         // I'm sure there are fascinating reasons for this, but God is dead, God remains dead, etc.
-        scores.into_iter().map(|x| (x * SCALE) as i64).collect()
+        scores.into_iter().map(scale_dot_result).collect()
     }
 }
 
-pub fn scale_dot_result(x: f64) -> i64 {
+#[inline]
+pub fn scale_dot_result(x: f32) -> i64 {
+    (x * SCALE) as i64
+}
+
+#[inline]
+pub fn scale_dot_result_f64(x: f64) -> i64 {
     (x * SCALE_F64) as i64
 }
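
For context, a self-contained sketch (illustrative only, not part of the patch; the constant and helper names mirror the ones it adds) of the fixed-point convention the diff moves to: dot products are multiplied by SCALE (2^40 after this change, down from 2^48) and stored as i64, which gives them a total order for sorting, and are divided by SCALE to recover an approximate f32, as main.rs now does for the PQ scores.

// Standalone illustration of the SCALE fixed-point convention; assumes the same
// constants as the patched vector.rs.
const SCALE: f32 = 1099511627776.0; // 2^40
const SCALE_F64: f64 = SCALE as f64;

#[inline]
fn scale_dot_result(x: f32) -> i64 {
    (x * SCALE) as i64
}

#[inline]
fn scale_dot_result_f64(x: f64) -> i64 {
    (x * SCALE_F64) as i64
}

fn main() {
    let scores = [0.31_f32, 0.87, 0.54];
    // Scaled scores are plain i64s, so they sort with a total order (no NaN/PartialOrd issues).
    let mut scaled: Vec<i64> = scores.iter().copied().map(scale_dot_result).collect();
    scaled.sort_unstable();
    for s in &scaled {
        // Dividing by SCALE recovers the approximate float score, as in the main.rs change.
        println!("{} -> {}", s, *s as f32 / SCALE);
    }
    // The f64 variant is for f64 accumulators, e.g. the value simsimd's dot returns.
    assert_eq!(scale_dot_result_f64(0.5), (0.5 * SCALE_F64) as i64);
}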