mirror of https://github.com/osmarks/meme-search-engine.git (synced 2025-02-23 06:20:04 +00:00)

commit 92396a68fb (parent f1283137d6)
broken roargraph implementation

@@ -134,7 +134,8 @@ pub struct ShardedRecord
     pub id: u32,
     #[serde(with="serde_bytes")]
     pub vector: Vec<u8>, // FP16
-    pub query_knns: Vec<u32>
+    pub query_knns: Vec<u32>,
+    pub query_knns_distances: Vec<f32>
 }
 
 #[derive(Clone, Deserialize, Serialize, Debug)]

@@ -167,3 +168,24 @@ pub struct IndexHeader
     pub record_pad_size: usize,
     pub quantizer: diskann::vector::ProductQuantizer
 }
+
+pub mod index_config {
+    use diskann::IndexBuildConfig;
+
+    pub const BASE_CONFIG: IndexBuildConfig = IndexBuildConfig {
+        r: 64,
+        r_cap: 80,
+        l: 500,
+        maxc: 750,
+        alpha: 60000
+    };
+
+    pub const PROJECTION_CUT_POINT: usize = 1;
+
+    pub const FIRST_PASS_ALPHA: i64 = 65536;
+
+    pub const SECOND_PASS_ALPHA: i64 = 62000;
+
+    pub const QUERY_SEARCH_K: usize = 200; // we want each query to have QUERY_REVERSE_K results, but some queries are likely more common than others in the top-k lists, so oversample a bit
+    pub const QUERY_REVERSE_K: usize = 100;
+}

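Note: the alpha constants appear to be fixed point — FIRST_PASS_ALPHA is exactly 65536 = 2^16, and dot products go through scale_dot_result before comparison. Under that assumption (not stated in the commit), the configured values convert as in this sketch:

    // Assumed interpretation: alpha is fixed point with 65536 (= 2^16)
    // representing 1.0, consistent with scale_dot_result elsewhere.
    const ALPHA_ONE: i64 = 65536;

    fn alpha_as_float(alpha: i64) -> f64 {
        alpha as f64 / ALPHA_ONE as f64
    }

    fn main() {
        assert_eq!(alpha_as_float(65536), 1.0);    // FIRST_PASS_ALPHA
        println!("{:.3}", alpha_as_float(62000));  // SECOND_PASS_ALPHA, ~0.946
        println!("{:.3}", alpha_as_float(60000));  // BASE_CONFIG.alpha, ~0.916
    }
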
@@ -19,7 +19,7 @@ use diskann::vector::{scale_dot_result_f64, ProductQuantizer};
 
 mod common;
 
-use common::{ProcessedEntry, ShardInputHeader, ShardedRecord, ShardHeader, PackedIndexEntry, IndexHeader};
+use common::{ProcessedEntry, ShardInputHeader, ShardedRecord, ShardHeader, PackedIndexEntry, IndexHeader, index_config::QUERY_SEARCH_K};
 
 #[derive(FromArgs)]
 #[argh(description="Process scraper dump files")]

@@ -124,7 +124,6 @@ const SHARD_SPILL: usize = 2;
 const RECORD_PAD_SIZE: usize = 4096; // NVMe disk sector size
 const D_EMB: u32 = 1152;
 const EMPTY_LOOKUP: (u32, u64, u32) = (u32::MAX, 0, 0);
-const KNN_K: usize = 30;
 const BALANCE_WEIGHT: f64 = 0.2;
 const BATCH_SIZE: usize = 128;
 

@@ -373,7 +372,7 @@ fn main() -> Result<()> {
     }
 
     let index = queries_index.as_mut().context("need queries")?;
-    let knn_result = index.search(&knn_query, KNN_K)?;
+    let knn_result = index.search(&knn_query, QUERY_SEARCH_K)?;
 
     for (i, (x, embedding)) in batch.iter().enumerate() {
         // closest matches first

@@ -386,7 +385,8 @@ fn main() -> Result<()> {
     let entry = ShardedRecord {
         id: count + i as u32,
         vector: x.embedding.clone(),
-        query_knns: knn_result.labels[i * KNN_K..(i + 1)*KNN_K].into_iter().map(|x| x.get().unwrap() as u32).collect()
+        query_knns: knn_result.labels[i * QUERY_SEARCH_K..(i + 1)*QUERY_SEARCH_K].into_iter().flat_map(|x| x.get().map(|x| x as u32)).collect(),
+        query_knns_distances: knn_result.distances[i * QUERY_SEARCH_K..(i + 1)*QUERY_SEARCH_K].into_iter().copied().collect()
     };
     let data = rmp_serde::to_vec(&entry)?;
     for (_, file, shard_count, _shard_index) in shards[0..SHARD_SPILL].iter_mut() {

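Note: the two new fields are collected independently — query_knns drops any label whose get() returns None (via flat_map), while query_knns_distances copies all QUERY_SEARCH_K distances unconditionally. A single missing label therefore shifts every later id against the wrong distance when the shard builder zips the two vectors back together. A sketch of filtering the pair together (hypothetical helper, assuming Option-style labels):

    // Toy sketch: filter ids and distances together so they stay aligned
    // even when a label is absent.
    fn aligned_knns(labels: &[Option<i64>], distances: &[f32]) -> (Vec<u32>, Vec<f32>) {
        labels.iter()
            .zip(distances)
            .filter_map(|(l, &d)| l.map(|l| (l as u32, d)))
            .unzip()
    }

    fn main() {
        let labels = vec![Some(3), None, Some(7)];
        let distances = vec![0.9, 0.1, 0.5];
        let (ids, dists) = aligned_knns(&labels, &distances);
        assert_eq!(ids, vec![3, 7]);
        assert_eq!(dists, vec![0.9, 0.5]);
    }
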
@@ -1,17 +1,31 @@
 use anyhow::{Result, Context};
 use itertools::Itertools;
-use std::io::{BufReader, Write, BufWriter};
+use std::io::{BufReader, Write, BufWriter, Seek};
 use rmp_serde::decode::Error as DecodeError;
 use std::fs;
-use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer};
+use std::collections::BinaryHeap;
+use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList, scale_dot_result}, IndexBuildConfig, IndexGraph, Timer};
 use half::f16;
 
 mod common;
 
-use common::{ShardInputHeader, ShardedRecord, ShardHeader};
+use common::{index_config::{self, QUERY_REVERSE_K}, ShardHeader, ShardInputHeader, ShardedRecord};
 
 const D_EMB: usize = 1152;
 
+fn report_degrees(graph: &IndexGraph) {
+    let mut total_degree = 0;
+    let mut degrees = Vec::with_capacity(graph.graph.len());
+    for out_neighbours in graph.graph.iter() {
+        let deg = out_neighbours.read().unwrap().len();
+        total_degree += deg;
+        degrees.push(deg);
+    }
+    degrees.sort_unstable();
+    println!("average degree {}", (total_degree as f32) / (graph.graph.len() as f32));
+    println!("median degree {}", degrees[degrees.len() / 2]);
+}
+
 fn main() -> Result<()> {
     let mut rng = fastrand::Rng::new();
 

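Note: report_degrees sorts the full degree list only to read off the median; a selection algorithm gets the same number in O(n) average time. A minimal alternative sketch (not part of the commit):

    // Median by selection instead of a full sort; select_nth_unstable runs
    // in O(n) on average, versus O(n log n) for sort_unstable.
    fn median_degree(mut degrees: Vec<usize>) -> usize {
        let mid = degrees.len() / 2;
        *degrees.select_nth_unstable(mid).1
    }

    fn main() {
        assert_eq!(median_degree(vec![5, 1, 9, 3, 7]), 5);
    }
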
@@ -19,20 +33,33 @@ fn main() -> Result<()> {
 
     let mut original_ids = vec![];
     let mut vector_data = vec![];
-    let mut query_knns = vec![];
 
     let header: ShardInputHeader = rmp_serde::from_read(&mut stream)?;
     let centroid_fp16 = header.centroid.iter().map(|x| f16::from_f32(*x)).collect::<Vec<_>>();
 
+    let mut query_knns_bwd = vec![BinaryHeap::new(); header.max_query_id];
+    query_knns_bwd.fill_with(|| BinaryHeap::with_capacity(QUERY_REVERSE_K));
+
     {
-        let _timer = Timer::new("read shard");
+        let _timer = Timer::new("read shard vectors");
         loop {
             let res: Result<ShardedRecord, DecodeError> = rmp_serde::from_read(&mut stream);
             match res {
                 Ok(x) => {
+                    let current_local_id = original_ids.len() as u32;
                     original_ids.push(x.id);
                     vector_data.extend(bytemuck::cast_slice(&x.vector));
-                    query_knns.push(x.query_knns);
+
+                    for (&query_id, &distance) in x.query_knns.iter().zip(x.query_knns_distances.iter()) {
+                        let distance = scale_dot_result(distance);
+                        // Rust BinaryHeap is a max-heap - we want the lowest-dot-product vectors to be discarded first
+                        // So negate the distance
+                        let knns = &mut query_knns_bwd[query_id as usize];
+                        if knns.len() == QUERY_REVERSE_K {
+                            knns.pop();
+                        }
+                        query_knns_bwd[query_id as usize].push((-distance, current_local_id));
+                    }
                 },
                 Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
                 Err(e) => return Err(e).context("decode fail")

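Note: the negation trick works because Rust's BinaryHeap is a max-heap — storing (-distance, id) puts the lowest-dot-product entry on top, where pop() discards it first. The loop above evicts before comparing, though, so once a heap is full every incoming pair displaces the current worst even when the newcomer is worse still. A compare-first variant looks like this sketch (assuming i64 fixed-point scores; not the committed code):

    use std::collections::BinaryHeap;

    // Bounded "keep the k highest-scoring entries" with a max-heap of
    // negated scores: the top of the heap is the worst retained entry.
    // This compares before evicting, so a newcomer worse than everything
    // already kept is simply dropped.
    fn push_bounded(heap: &mut BinaryHeap<(i64, u32)>, score: i64, id: u32, k: usize) {
        if heap.len() < k {
            heap.push((-score, id));
        } else if let Some(&(worst_neg, _)) = heap.peek() {
            if -score < worst_neg {
                heap.pop();
                heap.push((-score, id));
            }
        }
    }

    fn main() {
        let mut heap = BinaryHeap::with_capacity(2);
        for (score, id) in [(10, 0), (30, 1), (20, 2), (5, 3)] {
            push_bounded(&mut heap, score, id, 2);
        }
        let mut kept: Vec<u32> = heap.into_iter().map(|(_, id)| id).collect();
        kept.sort_unstable();
        assert_eq!(kept, vec![1, 2]); // scores 30 and 20 survive
    }
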
@@ -40,13 +67,30 @@ fn main() -> Result<()> {
         }
     }
 
-    let mut config = IndexBuildConfig {
-        r: 64,
-        r_cap: 80,
-        l: 256,
-        maxc: 750,
-        alpha: 65536
-    };
+    let mut query_knns = vec![vec![]; original_ids.len()];
+    query_knns.fill_with(|| Vec::with_capacity(8));
+    let mut query_knns_bwd_out = vec![vec![]; header.max_query_id];
+    query_knns_bwd_out.fill_with(|| Vec::with_capacity(QUERY_REVERSE_K));
+
+    {
+        let _timer = Timer::new("initialize bipartite graph");
+        // RoarGraph: out-edge from closest base vector to each query vector
+        for (query_id, distance_id_pairs) in query_knns_bwd.into_iter().enumerate() {
+            let vec = distance_id_pairs.into_sorted_vec();
+            let it = vec.into_iter();
+
+            for (i, (_distance, id)) in it.enumerate() {
+                if i < index_config::PROJECTION_CUT_POINT {
+                    query_knns[id as usize].push(query_id as u32);
+                } else {
+                    query_knns_bwd_out[query_id].push(id);
+                }
+            }
+        }
+    }
+
+    let mut config = common::index_config::BASE_CONFIG;
 
     let vecs = VectorList {
         data: vector_data,

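Note: since the heaps store negated distances, into_sorted_vec() returns pairs ascending by negated distance, i.e. highest dot product first. With PROJECTION_CUT_POINT = 1 that means exactly one out-edge goes from the single closest base vector to each query, and every remaining neighbour becomes a projection candidate in query_knns_bwd_out. A toy check of that ordering:

    use std::collections::BinaryHeap;

    fn main() {
        // Negated dot products sort ascending, so the pair with the highest
        // dot product (the closest base vector under inner-product
        // similarity) comes out of into_sorted_vec() first.
        let heap: BinaryHeap<(i64, u32)> = [(-90, 1), (-10, 2), (-50, 3)].into();
        assert_eq!(heap.into_sorted_vec(), vec![(-90, 1), (-50, 3), (-10, 2)]);
    }
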
@@ -57,48 +101,41 @@ fn main() -> Result<()> {
     let mut graph = IndexGraph::empty(original_ids.len(), config.r_cap);
 
     {
-        //let _timer = Timer::new("project bipartite");
-        //project_bipartite(&mut rng, &mut graph, &query_knns, &query_knns_bwd, config, &vecs);
+        let _timer = Timer::new("project bipartite");
+        project_bipartite(&mut rng, &mut graph, &query_knns, &query_knns_bwd_out, config, &vecs);
     }
 
+    report_degrees(&graph);
+
     {
         let _timer = Timer::new("random fill");
         random_fill_graph(&mut rng, &mut graph, config.r);
     }
 
+    report_degrees(&graph);
+
     let medioid = vecs.iter().position_max_by_key(|&v| {
         dot(v, &centroid_fp16)
     }).unwrap() as u32;
 
     {
         let _timer = Timer::new("first pass");
+        config.alpha = common::index_config::FIRST_PASS_ALPHA;
         build_graph(&mut rng, &mut graph, medioid, &vecs, config);
     }
 
+    report_degrees(&graph);
+
     {
         let _timer = Timer::new("second pass");
-        config.alpha = 80000;
+        config.alpha = common::index_config::SECOND_PASS_ALPHA;
         build_graph(&mut rng, &mut graph, medioid, &vecs, config);
     }
 
+    report_degrees(&graph);
+
     std::mem::drop(vecs);
 
-    let mut query_knns_bwd = vec![vec![]; header.max_query_id];
-
-    {
-        let _timer = Timer::new("compute backward edges");
-        for (record_id, knns) in query_knns.iter().enumerate() {
-            for &k in knns {
-                query_knns_bwd[k as usize].push(record_id as u32);
-            }
-        }
-    }
-
-    {
-        let _timer = Timer::new("augment bipartite");
-        augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config);
-    }
-
     let len = original_ids.len();
 
     {