mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-02-23 22:40:03 +00:00
broken roargraph implementation
This commit is contained in:
parent
f1283137d6
commit
92396a68fb
@ -134,7 +134,8 @@ pub struct ShardedRecord {
|
|||||||
pub id: u32,
|
pub id: u32,
|
||||||
#[serde(with="serde_bytes")]
|
#[serde(with="serde_bytes")]
|
||||||
pub vector: Vec<u8>, // FP16
|
pub vector: Vec<u8>, // FP16
|
||||||
pub query_knns: Vec<u32>
|
pub query_knns: Vec<u32>,
|
||||||
|
pub query_knns_distances: Vec<f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, Debug)]
|
#[derive(Clone, Deserialize, Serialize, Debug)]
|
||||||
@ -167,3 +168,24 @@ pub struct IndexHeader {
|
|||||||
pub record_pad_size: usize,
|
pub record_pad_size: usize,
|
||||||
pub quantizer: diskann::vector::ProductQuantizer
|
pub quantizer: diskann::vector::ProductQuantizer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub mod index_config {
|
||||||
|
use diskann::IndexBuildConfig;
|
||||||
|
|
||||||
|
pub const BASE_CONFIG: IndexBuildConfig = IndexBuildConfig {
|
||||||
|
r: 64,
|
||||||
|
r_cap: 80,
|
||||||
|
l: 500,
|
||||||
|
maxc: 750,
|
||||||
|
alpha: 60000
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const PROJECTION_CUT_POINT: usize = 1;
|
||||||
|
|
||||||
|
pub const FIRST_PASS_ALPHA: i64 = 65536;
|
||||||
|
|
||||||
|
pub const SECOND_PASS_ALPHA: i64 = 62000;
|
||||||
|
|
||||||
|
pub const QUERY_SEARCH_K: usize = 200; // we want each query to have QUERY_REVERSE_K results, but some queries are likely more common than others in the top-k lists, so oversample a bit
|
||||||
|
pub const QUERY_REVERSE_K: usize = 100;
|
||||||
|
}
|
||||||
|
@ -19,7 +19,7 @@ use diskann::vector::{scale_dot_result_f64, ProductQuantizer};
|
|||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
use common::{ProcessedEntry, ShardInputHeader, ShardedRecord, ShardHeader, PackedIndexEntry, IndexHeader};
|
use common::{ProcessedEntry, ShardInputHeader, ShardedRecord, ShardHeader, PackedIndexEntry, IndexHeader, index_config::QUERY_SEARCH_K};
|
||||||
|
|
||||||
#[derive(FromArgs)]
|
#[derive(FromArgs)]
|
||||||
#[argh(description="Process scraper dump files")]
|
#[argh(description="Process scraper dump files")]
|
||||||
@ -124,7 +124,6 @@ const SHARD_SPILL: usize = 2;
|
|||||||
const RECORD_PAD_SIZE: usize = 4096; // NVMe disk sector size
|
const RECORD_PAD_SIZE: usize = 4096; // NVMe disk sector size
|
||||||
const D_EMB: u32 = 1152;
|
const D_EMB: u32 = 1152;
|
||||||
const EMPTY_LOOKUP: (u32, u64, u32) = (u32::MAX, 0, 0);
|
const EMPTY_LOOKUP: (u32, u64, u32) = (u32::MAX, 0, 0);
|
||||||
const KNN_K: usize = 30;
|
|
||||||
const BALANCE_WEIGHT: f64 = 0.2;
|
const BALANCE_WEIGHT: f64 = 0.2;
|
||||||
const BATCH_SIZE: usize = 128;
|
const BATCH_SIZE: usize = 128;
|
||||||
|
|
||||||
@ -373,7 +372,7 @@ fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let index = queries_index.as_mut().context("need queries")?;
|
let index = queries_index.as_mut().context("need queries")?;
|
||||||
let knn_result = index.search(&knn_query, KNN_K)?;
|
let knn_result = index.search(&knn_query, QUERY_SEARCH_K)?;
|
||||||
|
|
||||||
for (i, (x, embedding)) in batch.iter().enumerate() {
|
for (i, (x, embedding)) in batch.iter().enumerate() {
|
||||||
// closest matches first
|
// closest matches first
|
||||||
@ -386,7 +385,8 @@ fn main() -> Result<()> {
|
|||||||
let entry = ShardedRecord {
|
let entry = ShardedRecord {
|
||||||
id: count + i as u32,
|
id: count + i as u32,
|
||||||
vector: x.embedding.clone(),
|
vector: x.embedding.clone(),
|
||||||
query_knns: knn_result.labels[i * KNN_K..(i + 1)*KNN_K].into_iter().map(|x| x.get().unwrap() as u32).collect()
|
query_knns: knn_result.labels[i * QUERY_SEARCH_K..(i + 1)*QUERY_SEARCH_K].into_iter().flat_map(|x| x.get().map(|x| x as u32)).collect(),
|
||||||
|
query_knns_distances: knn_result.distances[i * QUERY_SEARCH_K..(i + 1)*QUERY_SEARCH_K].into_iter().copied().collect()
|
||||||
};
|
};
|
||||||
let data = rmp_serde::to_vec(&entry)?;
|
let data = rmp_serde::to_vec(&entry)?;
|
||||||
for (_, file, shard_count, _shard_index) in shards[0..SHARD_SPILL].iter_mut() {
|
for (_, file, shard_count, _shard_index) in shards[0..SHARD_SPILL].iter_mut() {
|
||||||
|
@ -1,17 +1,31 @@
|
|||||||
use anyhow::{Result, Context};
|
use anyhow::{Result, Context};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use std::io::{BufReader, Write, BufWriter};
|
use std::io::{BufReader, Write, BufWriter, Seek};
|
||||||
use rmp_serde::decode::Error as DecodeError;
|
use rmp_serde::decode::Error as DecodeError;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer};
|
use std::collections::BinaryHeap;
|
||||||
|
use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList, scale_dot_result}, IndexBuildConfig, IndexGraph, Timer};
|
||||||
use half::f16;
|
use half::f16;
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
use common::{ShardInputHeader, ShardedRecord, ShardHeader};
|
use common::{index_config::{self, QUERY_REVERSE_K}, ShardHeader, ShardInputHeader, ShardedRecord};
|
||||||
|
|
||||||
const D_EMB: usize = 1152;
|
const D_EMB: usize = 1152;
|
||||||
|
|
||||||
|
fn report_degrees(graph: &IndexGraph) {
|
||||||
|
let mut total_degree = 0;
|
||||||
|
let mut degrees = Vec::with_capacity(graph.graph.len());
|
||||||
|
for out_neighbours in graph.graph.iter() {
|
||||||
|
let deg = out_neighbours.read().unwrap().len();
|
||||||
|
total_degree += deg;
|
||||||
|
degrees.push(deg);
|
||||||
|
}
|
||||||
|
degrees.sort_unstable();
|
||||||
|
println!("average degree {}", (total_degree as f32) / (graph.graph.len() as f32));
|
||||||
|
println!("median degree {}", degrees[degrees.len() / 2]);
|
||||||
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let mut rng = fastrand::Rng::new();
|
let mut rng = fastrand::Rng::new();
|
||||||
|
|
||||||
@ -19,20 +33,33 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let mut original_ids = vec![];
|
let mut original_ids = vec![];
|
||||||
let mut vector_data = vec![];
|
let mut vector_data = vec![];
|
||||||
let mut query_knns = vec![];
|
|
||||||
|
|
||||||
let header: ShardInputHeader = rmp_serde::from_read(&mut stream)?;
|
let header: ShardInputHeader = rmp_serde::from_read(&mut stream)?;
|
||||||
let centroid_fp16 = header.centroid.iter().map(|x| f16::from_f32(*x)).collect::<Vec<_>>();
|
let centroid_fp16 = header.centroid.iter().map(|x| f16::from_f32(*x)).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let mut query_knns_bwd = vec![BinaryHeap::new(); header.max_query_id];
|
||||||
|
query_knns_bwd.fill_with(|| BinaryHeap::with_capacity(QUERY_REVERSE_K));
|
||||||
|
|
||||||
{
|
{
|
||||||
let _timer = Timer::new("read shard");
|
let _timer = Timer::new("read shard vectors");
|
||||||
loop {
|
loop {
|
||||||
let res: Result<ShardedRecord, DecodeError> = rmp_serde::from_read(&mut stream);
|
let res: Result<ShardedRecord, DecodeError> = rmp_serde::from_read(&mut stream);
|
||||||
match res {
|
match res {
|
||||||
Ok(x) => {
|
Ok(x) => {
|
||||||
|
let current_local_id = original_ids.len() as u32;
|
||||||
original_ids.push(x.id);
|
original_ids.push(x.id);
|
||||||
vector_data.extend(bytemuck::cast_slice(&x.vector));
|
vector_data.extend(bytemuck::cast_slice(&x.vector));
|
||||||
query_knns.push(x.query_knns);
|
|
||||||
|
for (&query_id, &distance) in x.query_knns.iter().zip(x.query_knns_distances.iter()) {
|
||||||
|
let distance = scale_dot_result(distance);
|
||||||
|
// Rust BinaryHeap is a max-heap - we want the lowest-dot-product vectors to be discarded first
|
||||||
|
// So negate the distance
|
||||||
|
let knns = &mut query_knns_bwd[query_id as usize];
|
||||||
|
if knns.len() == QUERY_REVERSE_K {
|
||||||
|
knns.pop();
|
||||||
|
}
|
||||||
|
query_knns_bwd[query_id as usize].push((-distance, current_local_id));
|
||||||
|
}
|
||||||
},
|
},
|
||||||
Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
|
Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
|
||||||
Err(e) => return Err(e).context("decode fail")
|
Err(e) => return Err(e).context("decode fail")
|
||||||
@ -40,13 +67,30 @@ fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut config = IndexBuildConfig {
|
let mut query_knns = vec![vec![]; original_ids.len()];
|
||||||
r: 64,
|
query_knns.fill_with(|| Vec::with_capacity(8));
|
||||||
r_cap: 80,
|
let mut query_knns_bwd_out = vec![vec![]; header.max_query_id];
|
||||||
l: 256,
|
query_knns_bwd_out.fill_with(|| Vec::with_capacity(QUERY_REVERSE_K));
|
||||||
maxc: 750,
|
|
||||||
alpha: 65536
|
{
|
||||||
};
|
let _timer = Timer::new("initialize bipartite graph");
|
||||||
|
// RoarGraph: out-edge from closest base vector to each query vector
|
||||||
|
for (query_id, distance_id_pairs) in query_knns_bwd.into_iter().enumerate() {
|
||||||
|
let vec = distance_id_pairs.into_sorted_vec();
|
||||||
|
let it = vec.into_iter();
|
||||||
|
|
||||||
|
for (i, (_distance, id)) in it.enumerate() {
|
||||||
|
if i < index_config::PROJECTION_CUT_POINT {
|
||||||
|
query_knns[id as usize].push(query_id as u32);
|
||||||
|
} else {
|
||||||
|
query_knns_bwd_out[query_id].push(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut config = common::index_config::BASE_CONFIG;
|
||||||
|
|
||||||
let vecs = VectorList {
|
let vecs = VectorList {
|
||||||
data: vector_data,
|
data: vector_data,
|
||||||
@ -57,48 +101,41 @@ fn main() -> Result<()> {
|
|||||||
let mut graph = IndexGraph::empty(original_ids.len(), config.r_cap);
|
let mut graph = IndexGraph::empty(original_ids.len(), config.r_cap);
|
||||||
|
|
||||||
{
|
{
|
||||||
//let _timer = Timer::new("project bipartite");
|
let _timer = Timer::new("project bipartite");
|
||||||
//project_bipartite(&mut rng, &mut graph, &query_knns, &query_knns_bwd, config, &vecs);
|
project_bipartite(&mut rng, &mut graph, &query_knns, &query_knns_bwd_out, config, &vecs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
report_degrees(&graph);
|
||||||
|
|
||||||
{
|
{
|
||||||
let _timer = Timer::new("random fill");
|
let _timer = Timer::new("random fill");
|
||||||
random_fill_graph(&mut rng, &mut graph, config.r);
|
random_fill_graph(&mut rng, &mut graph, config.r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
report_degrees(&graph);
|
||||||
|
|
||||||
let medioid = vecs.iter().position_max_by_key(|&v| {
|
let medioid = vecs.iter().position_max_by_key(|&v| {
|
||||||
dot(v, &centroid_fp16)
|
dot(v, &centroid_fp16)
|
||||||
}).unwrap() as u32;
|
}).unwrap() as u32;
|
||||||
|
|
||||||
{
|
{
|
||||||
let _timer = Timer::new("first pass");
|
let _timer = Timer::new("first pass");
|
||||||
|
config.alpha = common::index_config::FIRST_PASS_ALPHA;
|
||||||
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
|
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
report_degrees(&graph);
|
||||||
|
|
||||||
{
|
{
|
||||||
let _timer = Timer::new("second pass");
|
let _timer = Timer::new("second pass");
|
||||||
config.alpha = 80000;
|
config.alpha = common::index_config::SECOND_PASS_ALPHA;
|
||||||
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
|
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
report_degrees(&graph);
|
||||||
|
|
||||||
std::mem::drop(vecs);
|
std::mem::drop(vecs);
|
||||||
|
|
||||||
let mut query_knns_bwd = vec![vec![]; header.max_query_id];
|
|
||||||
|
|
||||||
{
|
|
||||||
let _timer = Timer::new("compute backward edges");
|
|
||||||
for (record_id, knns) in query_knns.iter().enumerate() {
|
|
||||||
for &k in knns {
|
|
||||||
query_knns_bwd[k as usize].push(record_id as u32);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
let _timer = Timer::new("augment bipartite");
|
|
||||||
augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config);
|
|
||||||
}
|
|
||||||
|
|
||||||
let len = original_ids.len();
|
let len = original_ids.len();
|
||||||
|
|
||||||
{
|
{
|
||||||
|
Loading…
x
Reference in New Issue
Block a user