mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-05-09 02:34:07 +00:00
fix overflow bug
This commit is contained in:
parent
e0cf65204b
commit
35df1201e2
@ -7,7 +7,7 @@ use std::{io::Read, time::Instant};
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use half::f16;
|
use half::f16;
|
||||||
|
|
||||||
use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, dot, VectorList, self}, Timer};
|
use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, SCALE, dot, VectorList, self}, Timer};
|
||||||
use simsimd::SpatialSimilarity;
|
use simsimd::SpatialSimilarity;
|
||||||
|
|
||||||
const D_EMB: usize = 1152;
|
const D_EMB: usize = 1152;
|
||||||
@ -40,6 +40,7 @@ fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
let pq_scores = codec.asymmetric_dot_product(&query, &codes);
|
let pq_scores = codec.asymmetric_dot_product(&query, &codes);
|
||||||
for (x, y) in real_scores.iter().zip(pq_scores.iter()) {
|
for (x, y) in real_scores.iter().zip(pq_scores.iter()) {
|
||||||
|
let y = (*y as f32) / SCALE;
|
||||||
println!("{} {} {} {}", x, y, x - y, (x - y) / x);
|
println!("{} {} {} {}", x, y, x - y, (x - y) / x);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,6 @@ use half::f16;
|
|||||||
use simsimd::SpatialSimilarity;
|
use simsimd::SpatialSimilarity;
|
||||||
use fastrand::Rng;
|
use fastrand::Rng;
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
use tracing_subscriber::field::RecordFields;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Vector(Vec<f16>);
|
pub struct Vector(Vec<f16>);
|
||||||
@ -46,12 +45,12 @@ impl Vector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Floats are vaguely annoying and not sortable (trivially), so we mostly represent dot products as integers
|
// Floats are vaguely annoying and not sortable (trivially), so we mostly represent dot products as integers
|
||||||
const SCALE: f32 = 281474976710656.0;
|
pub const SCALE: f32 = 1099511627776.0;
|
||||||
const SCALE_F64: f64 = SCALE as f64;
|
pub const SCALE_F64: f64 = SCALE as f64;
|
||||||
|
|
||||||
pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> f32 {
|
pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> i64 {
|
||||||
// safety is not real
|
// safety is not real
|
||||||
(simsimd::f16::dot(unsafe { std::mem::transmute(x) }, unsafe { std::mem::transmute(y) }).unwrap()) as f32
|
scale_dot_result_f64(simsimd::f16::dot(unsafe { std::mem::transmute(x) }, unsafe { std::mem::transmute(y) }).unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_svector(vec: VectorRef) -> SVector {
|
pub fn to_svector(vec: VectorRef) -> SVector {
|
||||||
@ -404,11 +403,17 @@ impl ProductQuantizer {
|
|||||||
// I have no idea why but we somehow have significant degradation in search quality
|
// I have no idea why but we somehow have significant degradation in search quality
|
||||||
// if this accumulates in integers. As such, do floats and convert at the end.
|
// if this accumulates in integers. As such, do floats and convert at the end.
|
||||||
// I'm sure there are fascinating reasons for this, but God is dead, God remains dead, etc.
|
// I'm sure there are fascinating reasons for this, but God is dead, God remains dead, etc.
|
||||||
scores.into_iter().map(|x| (x * SCALE) as i64).collect()
|
scores.into_iter().map(scale_dot_result).collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn scale_dot_result(x: f64) -> i64 {
|
#[inline]
|
||||||
|
pub fn scale_dot_result(x: f32) -> i64 {
|
||||||
|
(x * SCALE) as i64
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn scale_dot_result_f64(x: f64) -> i64 {
|
||||||
(x * SCALE_F64) as i64
|
(x * SCALE_F64) as i64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user