1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-11-07 18:54:06 +00:00

RobustVamana algorithm for big index run

This commit is contained in:
osmarks
2025-01-16 21:10:12 +00:00
parent d341a8c243
commit f4376f62ed
8 changed files with 243 additions and 174 deletions

40
Cargo.lock generated
View File

@@ -402,18 +402,6 @@ version = "2.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452"
[[package]]
name = "bitvec"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [
"funty",
"radium",
"tap",
"wyz",
]
[[package]]
name = "block-buffer"
version = "0.10.4"
@@ -808,7 +796,6 @@ name = "diskann"
version = "0.1.0"
dependencies = [
"anyhow",
"bitvec",
"bytemuck",
"fastrand",
"foldhash",
@@ -1050,12 +1037,6 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "funty"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "futures-channel"
version = "0.3.31"
@@ -2519,12 +2500,6 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "radium"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]]
name = "rand"
version = "0.8.5"
@@ -3461,12 +3436,6 @@ dependencies = [
"version-compare",
]
[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "target-lexicon"
version = "0.12.16"
@@ -4294,15 +4263,6 @@ dependencies = [
"memchr",
]
[[package]]
name = "wyz"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
dependencies = [
"tap",
]
[[package]]
name = "zerocopy"
version = "0.7.35"

3
build.rs Normal file
View File

@@ -0,0 +1,3 @@
fn main() {
println!("cargo::rustc-link-search=/usr/local/lib/");
}

View File

@@ -10,7 +10,6 @@ tracing = "0.1"
tracing-subscriber = "0.3"
simsimd = "6"
foldhash = "0.1"
bitvec = "1"
tqdm = "0.7"
anyhow = "1"
bytemuck = { version = "1", features = ["extern_crate_alloc"] }

View File

@@ -3,7 +3,7 @@
extern crate test;
use foldhash::{HashSet, HashMap, HashMapExt, HashSetExt};
use foldhash::{HashSet, HashSetExt};
use fastrand::Rng;
use rayon::prelude::*;
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, Mutex};
@@ -44,7 +44,10 @@ pub struct IndexBuildConfig {
pub l: usize,
pub maxc: usize,
pub alpha: i64,
pub saturate_graph: bool
pub saturate_graph: bool,
pub query_breakpoint: u32, // above this nodes are queries and not base vectors
pub max_add_per_stitch_iter: usize,
pub query_alpha: i64
}
@@ -160,7 +163,7 @@ pub struct Scratch {
}
impl Scratch {
pub fn new(IndexBuildConfig { l, r, maxc, .. }: IndexBuildConfig) -> Self {
pub fn new(IndexBuildConfig { l, r, .. }: IndexBuildConfig) -> Self {
Scratch {
visited: HashSet::with_capacity(l * 8),
neighbour_buffer: NeighbourBuffer::new(l),
@@ -177,7 +180,7 @@ pub struct GreedySearchCounters {
// Algorithm 1 from the DiskANN paper
// We support the dot product metric only, so we want to keep things with the HIGHEST dot product
pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs: &VectorList, graph: &IndexGraph, config: IndexBuildConfig) -> GreedySearchCounters {
pub fn greedy_search(scratch: &mut Scratch, start: u32, base_vectors_only: bool, query: VectorRef, vecs: &VectorList, graph: &IndexGraph, config: IndexBuildConfig) -> GreedySearchCounters {
scratch.visited.clear();
scratch.neighbour_buffer.clear();
scratch.visited_list.clear();
@@ -190,7 +193,8 @@ pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs:
while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
scratch.neighbour_pre_buffer.clear();
for &neighbour in graph.out_neighbours(pt).iter() {
if scratch.visited.insert(neighbour) {
let neighbour_is_query = neighbour >= config.query_breakpoint; // OOD-DiskANN page 4: if we are searching for a query, only consider results in base vectors
if scratch.visited.insert(neighbour) && !(base_vectors_only && neighbour_is_query) {
scratch.neighbour_pre_buffer.push(neighbour);
}
}
@@ -208,7 +212,7 @@ pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs:
type CandidateList = Vec<(u32, i64)>;
fn merge_existing_neighbours(candidates: &mut CandidateList, point: u32, neigh: &[u32], vecs: &VectorList, config: IndexBuildConfig) {
fn merge_existing_neighbours(candidates: &mut CandidateList, point: u32, neigh: &[u32], vecs: &VectorList) {
let p_vec = &vecs[point as usize];
for (i, &n) in neigh.iter().enumerate() {
let dot = fast_dot(p_vec, &vecs[n as usize], &vecs[neigh[(i + 1) % neigh.len() as usize] as usize]);
@@ -218,7 +222,7 @@ fn merge_existing_neighbours(candidates: &mut CandidateList, point: u32, neigh:
// "Robust prune" algorithm, kind of
// The algorithm in the paper does not actually match the code as implemented in microsoft/DiskANN
// and that's slightly different from the one in ParlayANN for no reason
// and that's slightly different from the one in ParlayANN for no clear reason
// This is closer to ParlayANN
fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &VectorList, config: IndexBuildConfig) {
neigh.clear();
@@ -254,7 +258,12 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
let next_vec = &vecs[scratch.robust_prune_scratch_buffer[(i + 1) % scratch.robust_prune_scratch_buffer.len()].0 as usize];
let p_star_prime_score = fast_dot(&vecs[p_prime as usize], &vecs[p_star as usize], next_vec);
let p_prime_p_score = candidates[ci].1;
let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16;
let con_alpha = if p_prime >= config.query_breakpoint {
config.query_alpha
} else {
config.alpha
};
let alpha_times_p_star_prime_score = (con_alpha * p_star_prime_score) >> 16;
if alpha_times_p_star_prime_score >= p_prime_p_score {
candidates[ci].1 = i64::MIN;
@@ -262,7 +271,8 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
}
}
if config.saturate_graph {
// saturate graph on for query points - otherwise they get no neighbours, more or less
if config.saturate_graph || p >= config.query_breakpoint {
for &(id, _score) in candidates.iter() {
if neigh.len() == config.r {
return;
@@ -276,21 +286,21 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &VectorList, config: IndexBuildConfig) {
assert!(vecs.len() < u32::MAX as usize);
assert_eq!(vecs.len(), graph.graph.len());
let mut sigmas: Vec<u32> = (0..(vecs.len() as u32)).collect();
rng.shuffle(&mut sigmas);
let rng = Mutex::new(rng.fork());
//let scratch = &mut Scratch::new(config);
//let mut rng = rng.lock().unwrap();
sigmas.into_par_iter().for_each_init(|| (Scratch::new(config), rng.lock().unwrap().fork()), |(scratch, rng), sigma_i| {
sigmas.into_par_iter().for_each_init(|| Scratch::new(config), |scratch, sigma_i| {
//sigmas.into_iter().for_each(|sigma_i| {
greedy_search(scratch, medioid, &vecs[sigma_i as usize], vecs, &graph, config);
let is_query = sigma_i >= config.query_breakpoint;
greedy_search(scratch, medioid, is_query, &vecs[sigma_i as usize], vecs, &graph, config);
{
let n = graph.out_neighbours(sigma_i);
merge_existing_neighbours(&mut scratch.visited_list, sigma_i, &*n, vecs, config);
merge_existing_neighbours(&mut scratch.visited_list, sigma_i, &*n, vecs);
}
{
@@ -303,8 +313,8 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V
let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour);
if neighbour_neighbours.len() == config.r {
scratch.visited_list.clear();
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config);
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &vec![sigma_i], vecs, config);
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs);
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &vec![sigma_i], vecs);
robust_prune(scratch, neighbour, &mut neighbour_neighbours, vecs, config);
} else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r {
neighbour_neighbours.push(sigma_i);
@@ -313,24 +323,56 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V
});
}
pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig, max_iters: usize) {
let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
rng.shuffle(&mut sigmas);
pub fn robust_stitch(rng: &mut Rng, graph: &mut IndexGraph, vecs: &VectorList, config: IndexBuildConfig) {
let n_queries = graph.graph.len() as u32 - config.query_breakpoint;
let mut in_edges = Vec::with_capacity(n_queries as usize);
for _i in 0..(n_queries as usize) {
in_edges.push(Vec::with_capacity(config.r as usize));
}
// Iterate through graph vertices in a random order
let rng = Mutex::new(rng.fork());
sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| {
let mut neighbours = graph.out_neighbours_mut(sigma_i);
let mut i = 0;
while neighbours.len() < config.r && i < max_iters {
let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap();
let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap();
if !neighbours.contains(&projected_neighbour) {
neighbours.push(projected_neighbour);
let mut queries_order = (config.query_breakpoint..(graph.graph.len() as u32)).collect::<Vec<u32>>();
rng.shuffle(&mut queries_order);
for base_i in 0..config.query_breakpoint {
let mut out_neighbours = graph.out_neighbours_mut(base_i);
// store out-edges (to queries) from each base data node with corresponding query node and drop out-edges to queries
out_neighbours.retain(|&out_neighbour_out_edge| {
let is_query = out_neighbour_out_edge >= config.query_breakpoint;
if is_query {
in_edges[(out_neighbour_out_edge - config.query_breakpoint) as usize].push(base_i);
}
i += 1;
!is_query
});
}
queries_order.into_par_iter().for_each(|query_i| {
// For each query, fill spare space at in-neighbours with query's out-neighbours
// The OOD-DiskANN paper itself seems to fill *all* the spare space at once with (out-neighbours of) the first query which is encountered, which feels like an odd choice.
// We have a switch for that instead.
let query_out_neighbours = graph.out_neighbours(query_i);
println!("{} has {} in {} out", query_i, in_edges[(query_i - config.query_breakpoint) as usize].len(), query_out_neighbours.len());
for &in_neighbour in in_edges[(query_i - config.query_breakpoint) as usize].iter() {
let mut candidates = Vec::with_capacity(query_out_neighbours.len());
for (i, &neigh) in query_out_neighbours.iter().enumerate() {
let score = fast_dot(&vecs[in_neighbour as usize], &vecs[neigh as usize], &vecs[query_out_neighbours[(i + 1) % query_out_neighbours.len()] as usize]);
candidates.push((neigh, score));
}
candidates.sort_unstable_by_key(|(_neigh, score)| -*score);
let mut in_neighbour_out_edges = graph.out_neighbours_mut(in_neighbour);
let mut added = 0;
for (neigh, _score) in candidates {
if added >= config.max_add_per_stitch_iter {
break;
}
if in_neighbour_out_edges.contains(&neigh) {
continue;
}
in_neighbour_out_edges.push(neigh);
added += 1;
}
println!("wrote {} out to {}", added, in_neighbour);
}
})
});
}
pub fn random_fill_graph(rng: &mut Rng, graph: &mut IndexGraph, r: usize) {

50
generate_queries_bin.py Normal file
View File

@@ -0,0 +1,50 @@
import sys, aiohttp, msgpack, numpy, pgvector.asyncpg, asyncio
async def use_emb_server(sess, query):
async with sess.post("http://100.64.0.10:1708/", data=msgpack.dumps(query), timeout=aiohttp.ClientTimeout(connect=5, sock_connect=5, sock_read=None)) as res:
response = msgpack.loads(await res.read())
if res.status == 200:
return response
else:
raise Exception(response if res.headers.get("content-type") == "application/msgpack" else (await res.text()))
BATCH_SIZE = 32
async def main():
with open("query_data.bin", "wb") as f:
with open("queries.txt", "r") as g:
write_lock = asyncio.Lock()
async with aiohttp.ClientSession() as sess:
async with asyncio.TaskGroup() as tg:
sem = asyncio.Semaphore(3)
async def process_batch(batch):
while True:
try:
embs = await use_emb_server(sess, { "text": batch })
async with write_lock:
f.write(b"".join(embs))
sys.stdout.write(".")
sys.stdout.flush()
break
except Exception as e:
print(e)
await asyncio.sleep(5)
sem.release()
async def dispatch(batch):
await sem.acquire()
tg.create_task(process_batch(batch))
batch = []
while line := g.readline():
if line.strip(): batch.append(line.strip())
if len(batch) == BATCH_SIZE:
await dispatch(batch)
batch = []
if len(batch) > 0:
await dispatch(batch)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -125,16 +125,14 @@ pub struct ProcessedEntry {
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct ShardInputHeader {
pub id: u32,
pub centroid: Vec<f32>,
pub max_query_id: usize
pub centroid: Vec<f32>
}
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct ShardedRecord {
pub id: u32,
#[serde(with="serde_bytes")]
pub vector: Vec<u8>, // FP16
pub query_knns: Vec<u32>
pub vector: Vec<u8> // FP16
}
#[derive(Clone, Deserialize, Serialize, Debug)]

View File

@@ -1,6 +1,6 @@
use anyhow::{bail, Context, Result};
use serde::{Serialize, Deserialize};
use std::io::{BufReader, Read, Seek, SeekFrom, Write, BufWriter};
use std::io::{BufReader, Write, BufWriter};
use std::path::PathBuf;
use rmp_serde::decode::Error as DecodeError;
use std::fs;
@@ -8,7 +8,6 @@ use base64::Engine;
use argh::FromArgs;
use chrono::{TimeZone, Utc, DateTime};
use std::collections::VecDeque;
use faiss::Index;
use std::sync::mpsc::{sync_channel, SyncSender};
use itertools::Itertools;
use simsimd::SpatialSimilarity;
@@ -47,8 +46,6 @@ struct CLIArguments {
centroids: Option<String>,
#[argh(option, short='S', description="index shard directory")]
shards_dir: Option<String>,
#[argh(option, short='Q', description="query vectors file")]
queries: Option<String>,
#[argh(option, short='d', description="random seed")]
seed: Option<u64>,
#[argh(option, short='i', description="index output directory")]
@@ -62,7 +59,9 @@ struct CLIArguments {
#[argh(option, short='q', description="product quantization codec path")]
pq_codec: Option<String>,
#[argh(switch, short='j', description="JSON output")]
json: bool
json: bool,
#[argh(option, short='f', description="k-means balance fudge factor", default="0.2")]
balance_fudge: f64,
}
#[derive(Clone, Deserialize, Serialize, Debug)]
@@ -125,9 +124,19 @@ const SHARD_SPILL: usize = 2;
const RECORD_PAD_SIZE: usize = 4096; // NVMe disk sector size
const D_EMB: u32 = 1152;
const EMPTY_LOOKUP: (u32, u64, u32) = (u32::MAX, 0, 0);
const KNN_K: usize = 30;
const BALANCE_WEIGHT: f64 = 0.2;
const BATCH_SIZE: usize = 128;
const BATCH_SIZE: usize = 1024;
#[derive(Clone, Serialize, Debug)]
pub struct JsonEntry<'a> {
pub url: &'a str,
pub id: &'a str,
pub title: &'a str,
pub subreddit: &'a str,
pub author: &'a str,
pub timestamp: u64,
pub embedding: &'a [f32],
pub metadata: common::OriginalImageMetadata
}
fn main() -> Result<()> {
let args: CLIArguments = argh::from_env();
@@ -158,37 +167,6 @@ fn main() -> Result<()> {
None
};
// construct FAISS index over query vectors for kNNs
let (mut queries_index, max_query_id) = if let Some(queries_file) = args.queries {
println!("constructing index");
// not memory-efficient but this is small
let mut file = fs::File::open(queries_file).context("read queries file")?;
let mut size = file.metadata()?.len();
//let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?;
let mut index = faiss::index_factory(D_EMB, "HNSW64,SQ8", faiss::MetricType::InnerProduct)?;
//let mut index = faiss::index_factory(D_EMB, "IVF4096,SQfp16", faiss::MetricType::InnerProduct)?;
let mut buf = vec![0; (D_EMB as usize) * (1<<18)];
loop {
if size == 0 {
break;
}
if size < (buf.len() as u64) {
buf.resize(size as usize, 0);
}
file.read_exact(&mut buf)?;
size -= buf.len() as u64;
let unpacked = common::decode_fp16_buffer(&buf);
if !index.is_trained() { index.train(&unpacked)?; print!("train"); }
index.add(&unpacked)?;
print!(".");
}
println!("done");
let ntotal = index.ntotal();
(Some(index), ntotal as usize)
} else {
(None, 0)
};
// if sufficient config to split index exists, set up output files
let mut shards_out = if let (Some(shards_dir), Some(centroids)) = (&args.shards_dir, &args.centroids) {
let mut shards = Vec::new();
@@ -202,10 +180,12 @@ fn main() -> Result<()> {
for i in 0..(centroids_data.len() / (D_EMB as usize)) {
let centroid = centroids_data[i * (D_EMB as usize)..(i + 1) * (D_EMB as usize)].to_vec();
let mut file = fs::File::create(PathBuf::from(shards_dir).join(format!("{}.shard.msgpack", i))).context("create shard file")?;
rmp_serde::encode::write(&mut file, &ShardInputHeader { id: i as u32, centroid: centroid.clone(), max_query_id })?;
rmp_serde::encode::write(&mut file, &ShardInputHeader { id: i as u32, centroid: centroid.clone() })?;
shards.push((centroid, file, 0, i));
}
println!("splitting into {} shards", shards.len());
Some(shards)
} else {
None
@@ -376,6 +356,21 @@ fn main() -> Result<()> {
if args.print_embeddings {
println!("https://mse.osmarks.net/?e={}", base64::engine::general_purpose::URL_SAFE.encode(&x.embedding));
}
// this is not a very compact format, but I am lazy and this will never be a performance bottleneck
if args.json {
let entry = JsonEntry {
url: &x.url,
id: &x.id,
title: &x.title,
subreddit: &x.subreddit,
author: &x.author,
timestamp: x.timestamp,
embedding: &embedding,
metadata: x.metadata.clone()
};
let data = serde_json::to_string(&entry).unwrap();
println!("{}", data);
}
Some((x, embedding))
};
@@ -395,26 +390,17 @@ fn main() -> Result<()> {
}
if let Some(shards) = &mut shards_out {
let mut knn_query = vec![];
for (_, embedding) in batch.iter() {
knn_query.extend(embedding);
}
let index = queries_index.as_mut().context("need queries")?;
let knn_result = index.search(&knn_query, KNN_K)?;
for (i, (x, embedding)) in batch.iter().enumerate() {
// closest matches first
shards.sort_by_cached_key(|&(ref centroid, _, shard_count, _shard_index)| {
let mut dot = SpatialSimilarity::dot(&centroid, &embedding).unwrap();
dot -= BALANCE_WEIGHT * (shard_count as f64 / bal_count as f64);
dot -= args.balance_fudge * (shard_count as f64 / bal_count as f64);
-scale_dot_result_f64(dot)
});
let entry = ShardedRecord {
id: count + i as u32,
vector: x.embedding.clone(),
query_knns: knn_result.labels[i * KNN_K..(i + 1)*KNN_K].into_iter().map(|x| x.get().unwrap() as u32).collect()
vector: x.embedding.clone()
};
let data = rmp_serde::to_vec(&entry)?;
for (_, file, shard_count, _shard_index) in shards[0..SHARD_SPILL].iter_mut() {

View File

@@ -1,28 +1,57 @@
use anyhow::{Result, Context};
use itertools::Itertools;
use std::io::{BufReader, BufWriter, Write};
use std::io::{BufReader, BufWriter, Write, Read};
use rmp_serde::decode::Error as DecodeError;
use std::fs;
use diskann::{augment_bipartite, build_graph, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer, report_degrees, medioid};
use diskann::{build_graph, random_fill_graph, vector::VectorList, IndexBuildConfig, IndexGraph, Timer, report_degrees, medioid, robust_stitch};
use half::f16;
use argh::FromArgs;
mod common;
use common::{ShardInputHeader, ShardedRecord, ShardHeader};
#[derive(FromArgs)]
#[argh(description="Generate indices from shard files")]
struct CLIArguments {
#[argh(positional)]
input_file: String,
#[argh(positional)]
out_dir: String,
#[argh(positional)]
queries_bin: Option<String>,
#[argh(option, short='L', default="192", description="search list size (higher is better but slower)")]
l: usize,
#[argh(option, short='R', default="64", description="graph degree")]
r: usize,
#[argh(option, short='C', default="750", description="max candidate list size")]
maxc: usize,
#[argh(option, short='A', default="65536", description="first pass relaxation factor (times 2^16)")]
alpha: i64,
#[argh(option, short='Q', default="65536", description="query set special relaxation factor (times 2^16)")]
query_alpha: i64,
#[argh(option, short='B', default="65536", description="second pass relaxation factor (times 2^16)")]
alpha_2: i64,
#[argh(switch, short='s', description="do second pass")]
second_pass: bool,
#[argh(option, short='N', description="number of vectors to allocate for")]
n: Option<usize>
}
const D_EMB: usize = 1152;
const READ_CHUNK_SIZE: usize = D_EMB * size_of::<f16>() * 1024;
fn main() -> Result<()> {
let args: CLIArguments = argh::from_env();
let mut rng = fastrand::Rng::new();
let mut stream = BufReader::new(fs::File::open(std::env::args().nth(1).unwrap()).context("read dump file")?);
let mut stream = BufReader::new(fs::File::open(args.input_file)?);
let mut original_ids = vec![];
let mut vector_data = vec![];
let mut query_knns = vec![];
// There is no convenient way to pass the actual size along, so accursedly do it manually
let mut original_ids = Vec::with_capacity(args.n.unwrap_or(0));
let mut vector_data = Vec::with_capacity(args.n.unwrap_or(0) * D_EMB);
let header: ShardInputHeader = rmp_serde::from_read(&mut stream)?;
let centroid_fp16 = header.centroid.iter().map(|x| f16::from_f32(*x)).collect::<Vec<_>>();
{
let _timer = Timer::new("read shard");
@@ -32,7 +61,6 @@ fn main() -> Result<()> {
Ok(x) => {
original_ids.push(x.id);
vector_data.extend(bytemuck::cast_slice(&x.vector));
query_knns.push(x.query_knns);
},
Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(e).context("decode fail")
@@ -40,26 +68,38 @@ fn main() -> Result<()> {
}
}
let query_breakpoint = original_ids.len();
if let Some(ref queries_bin) = args.queries_bin {
let mut queries_file = BufReader::new(fs::File::open(queries_bin)?);
let mut buf = vec![0; READ_CHUNK_SIZE];
loop {
let n = queries_file.by_ref().take(READ_CHUNK_SIZE as u64).read_to_end(&mut buf)?;
if n == 0 {
break;
}
vector_data.extend(bytemuck::cast_slice(&buf[..n]));
}
}
let mut config = IndexBuildConfig {
r: 64,
l: 192,
maxc: 750,
alpha: 65200,
saturate_graph: false
r: args.r,
l: args.l,
maxc: args.maxc,
alpha: args.alpha,
query_alpha: args.query_alpha,
saturate_graph: false,
query_breakpoint: query_breakpoint as u32,
max_add_per_stitch_iter: 16
};
let vecs = VectorList {
length: vector_data.len() / D_EMB,
data: vector_data,
d_emb: D_EMB,
length: original_ids.len()
d_emb: D_EMB
};
let mut graph = IndexGraph::empty(original_ids.len(), config.r);
{
//let _timer = Timer::new("project bipartite");
//project_bipartite(&mut rng, &mut graph, &query_knns, &query_knns_bwd, config, &vecs);
}
let mut graph = IndexGraph::empty(vecs.len(), config.r);
{
let _timer = Timer::new("random fill");
@@ -77,42 +117,33 @@ fn main() -> Result<()> {
report_degrees(&graph);
{
//let _timer = Timer::new("second pass");
//config.alpha = 62000;
//build_graph(&mut rng, &mut graph, medioid, &vecs, config);
if args.second_pass {
{
let _timer = Timer::new("second pass");
config.alpha = args.alpha_2;
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
}
report_degrees(&graph);
}
//report_degrees(&graph);
if query_breakpoint < graph.graph.len() {
let _timer = Timer::new("robust stitch");
robust_stitch(&mut rng, &mut graph, &vecs, config);
report_degrees(&graph);
}
std::mem::drop(vecs);
let mut query_knns_bwd = vec![vec![]; header.max_query_id];
{
let _timer = Timer::new("compute backward edges");
for (record_id, knns) in query_knns.iter().enumerate() {
for &k in knns {
query_knns_bwd[k as usize].push(record_id as u32);
}
}
}
{
let _timer = Timer::new("augment bipartite");
//augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config, 50);
//random_fill_graph(&mut rng, &mut graph, config.r);
}
let len = original_ids.len();
{
let _timer = Timer::new("write shard");
let mut graph_data = BufWriter::new(fs::File::create(&format!("{}.shard.bin", header.id))?);
let mut graph_data = BufWriter::new(fs::File::create(&format!("{}/{}.shard.bin", args.out_dir, header.id))?);
let mut offsets = Vec::with_capacity(original_ids.len());
let mut offset = 0;
for out_neighbours in graph.graph.iter() {
for (i, out_neighbours) in graph.graph.iter().enumerate() {
if i >= query_breakpoint { break; }
let out_neighbours = out_neighbours.read().unwrap();
offsets.push(offset);
let s: &[u8] = bytemuck::cast_slice(&*out_neighbours);
@@ -121,7 +152,7 @@ fn main() -> Result<()> {
}
offsets.push(offset); // dummy entry for convenience
let mut header_f = fs::File::create(&format!("{}.shard-header.msgpack", header.id))?;
let mut header_f = fs::File::create(&format!("{}/{}.shard-header.msgpack", args.out_dir, header.id))?;
header_f.write_all(&rmp_serde::to_vec(&ShardHeader {
id: header.id,
max: *original_ids.iter().max().unwrap(),