1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-02-23 06:20:04 +00:00

broken roargraph implementation

This commit is contained in:
osmarks 2025-01-02 19:57:17 +00:00
parent f1283137d6
commit 92396a68fb
3 changed files with 96 additions and 37 deletions

View File

@ -134,7 +134,8 @@ pub struct ShardedRecord {
pub id: u32,
#[serde(with="serde_bytes")]
pub vector: Vec<u8>, // FP16
pub query_knns: Vec<u32>
pub query_knns: Vec<u32>,
pub query_knns_distances: Vec<f32>
}
#[derive(Clone, Deserialize, Serialize, Debug)]
@ -167,3 +168,24 @@ pub struct IndexHeader {
pub record_pad_size: usize,
pub quantizer: diskann::vector::ProductQuantizer
}
// Central tuning knobs for the RoarGraph/DiskANN index build, shared by the
// sharding pass and the per-shard graph-build pass so they cannot drift apart.
pub mod index_config {
use diskann::IndexBuildConfig;
// Base Vamana build parameters; `alpha` here is a placeholder — the build
// passes overwrite it with FIRST_PASS_ALPHA / SECOND_PASS_ALPHA below.
pub const BASE_CONFIG: IndexBuildConfig = IndexBuildConfig {
r: 64, // target out-degree per node
r_cap: 80, // hard cap on out-degree before pruning
l: 500, // search list size during build
maxc: 750, // max candidates considered when pruning
alpha: 60000 // fixed-point pruning slack (overwritten per pass — TODO confirm units match FIRST/SECOND_PASS_ALPHA's 2^16 scale)
};
// RoarGraph projection: the closest PROJECTION_CUT_POINT base vectors to each
// query get a forward edge; the remainder go into the backward (augment) set.
pub const PROJECTION_CUT_POINT: usize = 1;
// Pruning alpha for each build pass, in 16.16-style fixed point (65536 == 1.0
// — presumably; verify against diskann's IndexBuildConfig docs).
pub const FIRST_PASS_ALPHA: i64 = 65536;
pub const SECOND_PASS_ALPHA: i64 = 62000;
pub const QUERY_SEARCH_K: usize = 200; // we want each query to have QUERY_REVERSE_K results, but some queries are likely more common than others in the top-k lists, so oversample a bit
// Max base vectors retained per query in the reverse (query -> base) lists.
pub const QUERY_REVERSE_K: usize = 100;
}

View File

@ -19,7 +19,7 @@ use diskann::vector::{scale_dot_result_f64, ProductQuantizer};
mod common;
use common::{ProcessedEntry, ShardInputHeader, ShardedRecord, ShardHeader, PackedIndexEntry, IndexHeader};
use common::{ProcessedEntry, ShardInputHeader, ShardedRecord, ShardHeader, PackedIndexEntry, IndexHeader, index_config::QUERY_SEARCH_K};
#[derive(FromArgs)]
#[argh(description="Process scraper dump files")]
@ -124,7 +124,6 @@ const SHARD_SPILL: usize = 2;
const RECORD_PAD_SIZE: usize = 4096; // NVMe disk sector size
const D_EMB: u32 = 1152;
const EMPTY_LOOKUP: (u32, u64, u32) = (u32::MAX, 0, 0);
const KNN_K: usize = 30;
const BALANCE_WEIGHT: f64 = 0.2;
const BATCH_SIZE: usize = 128;
@ -373,7 +372,7 @@ fn main() -> Result<()> {
}
let index = queries_index.as_mut().context("need queries")?;
let knn_result = index.search(&knn_query, KNN_K)?;
let knn_result = index.search(&knn_query, QUERY_SEARCH_K)?;
for (i, (x, embedding)) in batch.iter().enumerate() {
// closest matches first
@ -386,7 +385,8 @@ fn main() -> Result<()> {
let entry = ShardedRecord {
id: count + i as u32,
vector: x.embedding.clone(),
query_knns: knn_result.labels[i * KNN_K..(i + 1)*KNN_K].into_iter().map(|x| x.get().unwrap() as u32).collect()
query_knns: knn_result.labels[i * QUERY_SEARCH_K..(i + 1)*QUERY_SEARCH_K].into_iter().flat_map(|x| x.get().map(|x| x as u32)).collect(),
query_knns_distances: knn_result.distances[i * QUERY_SEARCH_K..(i + 1)*QUERY_SEARCH_K].into_iter().copied().collect()
};
let data = rmp_serde::to_vec(&entry)?;
for (_, file, shard_count, _shard_index) in shards[0..SHARD_SPILL].iter_mut() {

View File

@ -1,17 +1,31 @@
use anyhow::{Result, Context};
use itertools::Itertools;
use std::io::{BufReader, Write, BufWriter};
use std::io::{BufReader, Write, BufWriter, Seek};
use rmp_serde::decode::Error as DecodeError;
use std::fs;
use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer};
use std::collections::BinaryHeap;
use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList, scale_dot_result}, IndexBuildConfig, IndexGraph, Timer};
use half::f16;
mod common;
use common::{ShardInputHeader, ShardedRecord, ShardHeader};
use common::{index_config::{self, QUERY_REVERSE_K}, ShardHeader, ShardInputHeader, ShardedRecord};
const D_EMB: usize = 1152;
/// Print average and median out-degree of `graph` — a cheap diagnostic for
/// whether the fill/prune passes are producing a sensibly connected graph.
///
/// Read-locks each adjacency list briefly; never mutates the graph.
fn report_degrees(graph: &IndexGraph) {
    let mut total_degree = 0usize;
    let mut degrees = Vec::with_capacity(graph.graph.len());
    for out_neighbours in graph.graph.iter() {
        // each neighbour list is behind a RwLock; we only need its length
        let deg = out_neighbours.read().unwrap().len();
        total_degree += deg;
        degrees.push(deg);
    }
    // Guard the empty graph: the original indexed degrees[len / 2] which
    // panics on an empty Vec (and the average would divide by zero).
    if degrees.is_empty() {
        println!("empty graph: no degree statistics");
        return;
    }
    degrees.sort_unstable();
    println!("average degree {}", (total_degree as f32) / (degrees.len() as f32));
    // lower median for even-length lists; fine for a diagnostic printout
    println!("median degree {}", degrees[degrees.len() / 2]);
}
fn main() -> Result<()> {
let mut rng = fastrand::Rng::new();
@ -19,20 +33,33 @@ fn main() -> Result<()> {
let mut original_ids = vec![];
let mut vector_data = vec![];
let mut query_knns = vec![];
let header: ShardInputHeader = rmp_serde::from_read(&mut stream)?;
let centroid_fp16 = header.centroid.iter().map(|x| f16::from_f32(*x)).collect::<Vec<_>>();
let mut query_knns_bwd = vec![BinaryHeap::new(); header.max_query_id];
query_knns_bwd.fill_with(|| BinaryHeap::with_capacity(QUERY_REVERSE_K));
{
let _timer = Timer::new("read shard");
let _timer = Timer::new("read shard vectors");
loop {
let res: Result<ShardedRecord, DecodeError> = rmp_serde::from_read(&mut stream);
match res {
Ok(x) => {
let current_local_id = original_ids.len() as u32;
original_ids.push(x.id);
vector_data.extend(bytemuck::cast_slice(&x.vector));
query_knns.push(x.query_knns);
for (&query_id, &distance) in x.query_knns.iter().zip(x.query_knns_distances.iter()) {
let distance = scale_dot_result(distance);
// Rust BinaryHeap is a max-heap - we want the lowest-dot-product vectors to be discarded first
// So negate the distance
let knns = &mut query_knns_bwd[query_id as usize];
if knns.len() == QUERY_REVERSE_K {
knns.pop();
}
query_knns_bwd[query_id as usize].push((-distance, current_local_id));
}
},
Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(e).context("decode fail")
@ -40,13 +67,30 @@ fn main() -> Result<()> {
}
}
let mut config = IndexBuildConfig {
r: 64,
r_cap: 80,
l: 256,
maxc: 750,
alpha: 65536
};
let mut query_knns = vec![vec![]; original_ids.len()];
query_knns.fill_with(|| Vec::with_capacity(8));
let mut query_knns_bwd_out = vec![vec![]; header.max_query_id];
query_knns_bwd_out.fill_with(|| Vec::with_capacity(QUERY_REVERSE_K));
{
let _timer = Timer::new("initialize bipartite graph");
// RoarGraph: out-edge from closest base vector to each query vector
for (query_id, distance_id_pairs) in query_knns_bwd.into_iter().enumerate() {
let vec = distance_id_pairs.into_sorted_vec();
let it = vec.into_iter();
for (i, (_distance, id)) in it.enumerate() {
if i < index_config::PROJECTION_CUT_POINT {
query_knns[id as usize].push(query_id as u32);
} else {
query_knns_bwd_out[query_id].push(id);
}
}
}
}
let mut config = common::index_config::BASE_CONFIG;
let vecs = VectorList {
data: vector_data,
@ -57,48 +101,41 @@ fn main() -> Result<()> {
let mut graph = IndexGraph::empty(original_ids.len(), config.r_cap);
{
//let _timer = Timer::new("project bipartite");
//project_bipartite(&mut rng, &mut graph, &query_knns, &query_knns_bwd, config, &vecs);
let _timer = Timer::new("project bipartite");
project_bipartite(&mut rng, &mut graph, &query_knns, &query_knns_bwd_out, config, &vecs);
}
report_degrees(&graph);
{
let _timer = Timer::new("random fill");
random_fill_graph(&mut rng, &mut graph, config.r);
}
report_degrees(&graph);
let medioid = vecs.iter().position_max_by_key(|&v| {
dot(v, &centroid_fp16)
}).unwrap() as u32;
{
let _timer = Timer::new("first pass");
config.alpha = common::index_config::FIRST_PASS_ALPHA;
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
}
report_degrees(&graph);
{
let _timer = Timer::new("second pass");
config.alpha = 80000;
config.alpha = common::index_config::SECOND_PASS_ALPHA;
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
}
report_degrees(&graph);
std::mem::drop(vecs);
let mut query_knns_bwd = vec![vec![]; header.max_query_id];
{
let _timer = Timer::new("compute backward edges");
for (record_id, knns) in query_knns.iter().enumerate() {
for &k in knns {
query_knns_bwd[k as usize].push(record_id as u32);
}
}
}
{
let _timer = Timer::new("augment bipartite");
augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config);
}
let len = original_ids.len();
{