Mirror of https://github.com/osmarks/meme-search-engine.git
RobustVamana algorithm for big index run
Cargo.lock (generated): 40 changed lines
@@ -402,18 +402,6 @@ version = "2.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452"

[[package]]
name = "bitvec"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [
 "funty",
 "radium",
 "tap",
 "wyz",
]

[[package]]
name = "block-buffer"
version = "0.10.4"
@@ -808,7 +796,6 @@ name = "diskann"
version = "0.1.0"
dependencies = [
 "anyhow",
 "bitvec",
 "bytemuck",
 "fastrand",
 "foldhash",
@@ -1050,12 +1037,6 @@ dependencies = [
 "percent-encoding",
]

[[package]]
name = "funty"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"

[[package]]
name = "futures-channel"
version = "0.3.31"
@@ -2519,12 +2500,6 @@ dependencies = [
 "proc-macro2",
]

[[package]]
name = "radium"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"

[[package]]
name = "rand"
version = "0.8.5"
@@ -3461,12 +3436,6 @@ dependencies = [
 "version-compare",
]

[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"

[[package]]
name = "target-lexicon"
version = "0.12.16"
@@ -4294,15 +4263,6 @@ dependencies = [
 "memchr",
]

[[package]]
name = "wyz"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
dependencies = [
 "tap",
]

[[package]]
name = "zerocopy"
version = "0.7.35"
build.rs (new file): 3 lines added
@@ -0,0 +1,3 @@
fn main() {
    println!("cargo::rustc-link-search=/usr/local/lib/");
}
@@ -10,7 +10,6 @@ tracing = "0.1"
tracing-subscriber = "0.3"
simsimd = "6"
foldhash = "0.1"
bitvec = "1"
tqdm = "0.7"
anyhow = "1"
bytemuck = { version = "1", features = ["extern_crate_alloc"] }
@@ -3,7 +3,7 @@

extern crate test;

use foldhash::{HashSet, HashMap, HashMapExt, HashSetExt};
use foldhash::{HashSet, HashSetExt};
use fastrand::Rng;
use rayon::prelude::*;
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, Mutex};
@@ -44,7 +44,10 @@ pub struct IndexBuildConfig {
pub l: usize,
pub maxc: usize,
pub alpha: i64,
pub saturate_graph: bool
pub saturate_graph: bool,
pub query_breakpoint: u32, // above this nodes are queries and not base vectors
pub max_add_per_stitch_iter: usize,
pub query_alpha: i64
}

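The new fields split a shard's ID space in two: IDs below query_breakpoint are base vectors, IDs at or above it are the query vectors appended after them, and query_alpha is the pruning relaxation used when the candidate under consideration is a query node. A minimal illustrative config (values mirror the CLI defaults added later in this commit; the query_breakpoint value is a made-up example):

    // Illustrative only; not taken verbatim from the commit.
    let config = IndexBuildConfig {
        r: 64,
        l: 192,
        maxc: 750,
        alpha: 65536,                // 1.0 in 16.16 fixed point, used for base-vector candidates
        query_alpha: 65536,          // separate relaxation used when the pruned candidate is a query node
        saturate_graph: false,
        query_breakpoint: 1_000_000, // hypothetical: 1M base vectors, query vectors follow
        max_add_per_stitch_iter: 16, // cap on edges added per in-neighbour during robust_stitch
    };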
@@ -160,7 +163,7 @@ pub struct Scratch {
}

impl Scratch {
pub fn new(IndexBuildConfig { l, r, maxc, .. }: IndexBuildConfig) -> Self {
pub fn new(IndexBuildConfig { l, r, .. }: IndexBuildConfig) -> Self {
Scratch {
visited: HashSet::with_capacity(l * 8),
neighbour_buffer: NeighbourBuffer::new(l),
@@ -177,7 +180,7 @@ pub struct GreedySearchCounters {

// Algorithm 1 from the DiskANN paper
// We support the dot product metric only, so we want to keep things with the HIGHEST dot product
pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs: &VectorList, graph: &IndexGraph, config: IndexBuildConfig) -> GreedySearchCounters {
pub fn greedy_search(scratch: &mut Scratch, start: u32, base_vectors_only: bool, query: VectorRef, vecs: &VectorList, graph: &IndexGraph, config: IndexBuildConfig) -> GreedySearchCounters {
scratch.visited.clear();
scratch.neighbour_buffer.clear();
scratch.visited_list.clear();
@@ -190,7 +193,8 @@ pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs:
while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
scratch.neighbour_pre_buffer.clear();
for &neighbour in graph.out_neighbours(pt).iter() {
if scratch.visited.insert(neighbour) {
let neighbour_is_query = neighbour >= config.query_breakpoint; // OOD-DiskANN page 4: if we are searching for a query, only consider results in base vectors
if scratch.visited.insert(neighbour) && !(base_vectors_only && neighbour_is_query) {
scratch.neighbour_pre_buffer.push(neighbour);
}
}
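The new flag only filters expansion when the search is being run on behalf of a query point; base vectors may still route through query nodes. A hedged sketch of the visit rule as a standalone predicate (this helper does not exist in the codebase, it just restates the inlined check):

    // Illustrative predicate: whether greedy_search may expand `neighbour`.
    fn may_expand(neighbour: u32, base_vectors_only: bool, query_breakpoint: u32) -> bool {
        let neighbour_is_query = neighbour >= query_breakpoint;
        // Query-to-query expansion is suppressed (OOD-DiskANN); searches for base
        // vectors see the whole graph.
        !(base_vectors_only && neighbour_is_query)
    }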
@@ -208,7 +212,7 @@ pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs:

type CandidateList = Vec<(u32, i64)>;

fn merge_existing_neighbours(candidates: &mut CandidateList, point: u32, neigh: &[u32], vecs: &VectorList, config: IndexBuildConfig) {
fn merge_existing_neighbours(candidates: &mut CandidateList, point: u32, neigh: &[u32], vecs: &VectorList) {
let p_vec = &vecs[point as usize];
for (i, &n) in neigh.iter().enumerate() {
let dot = fast_dot(p_vec, &vecs[n as usize], &vecs[neigh[(i + 1) % neigh.len() as usize] as usize]);
@@ -218,7 +222,7 @@ fn merge_existing_neighbours(candidates: &mut CandidateList, point: u32, neigh:

// "Robust prune" algorithm, kind of
// The algorithm in the paper does not actually match the code as implemented in microsoft/DiskANN
// and that's slightly different from the one in ParlayANN for no reason
// and that's slightly different from the one in ParlayANN for no clear reason
// This is closer to ParlayANN
fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &VectorList, config: IndexBuildConfig) {
neigh.clear();
@@ -254,7 +258,12 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
let next_vec = &vecs[scratch.robust_prune_scratch_buffer[(i + 1) % scratch.robust_prune_scratch_buffer.len()].0 as usize];
let p_star_prime_score = fast_dot(&vecs[p_prime as usize], &vecs[p_star as usize], next_vec);
let p_prime_p_score = candidates[ci].1;
let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16;
let con_alpha = if p_prime >= config.query_breakpoint {
config.query_alpha
} else {
config.alpha
};
let alpha_times_p_star_prime_score = (con_alpha * p_star_prime_score) >> 16;

if alpha_times_p_star_prime_score >= p_prime_p_score {
candidates[ci].1 = i64::MIN;
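alpha and query_alpha are 16.16 fixed-point factors (the CLI flags added later in this commit describe them as "times 2^16"), so the comparison above is the integer form of con_alpha * dot(p_prime, p_star) >= dot(p_prime, p). A small worked example of the scaling, with invented scores:

    // Runnable sketch of the fixed-point arithmetic; numbers are made up.
    fn main() {
        // 65536 encodes 1.0, so 68000 encodes roughly 1.0376.
        let alpha: i64 = 68000;
        let score: i64 = 1_000_000;          // some scaled dot product
        let relaxed = (alpha * score) >> 16; // 1_037_597, i.e. about score * 1.0376
        println!("{}", relaxed);
        // Since a larger con_alpha makes the domination test fire more easily,
        // query_alpha tunes how readily candidates that are query nodes get pruned.
    }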
@@ -262,7 +271,8 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect
}
}

if config.saturate_graph {
// saturate graph on for query points - otherwise they get no neighbours, more or less
if config.saturate_graph || p >= config.query_breakpoint {
for &(id, _score) in candidates.iter() {
if neigh.len() == config.r {
return;
@@ -276,21 +286,21 @@ fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &Vect

pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &VectorList, config: IndexBuildConfig) {
assert!(vecs.len() < u32::MAX as usize);
assert_eq!(vecs.len(), graph.graph.len());

let mut sigmas: Vec<u32> = (0..(vecs.len() as u32)).collect();
rng.shuffle(&mut sigmas);

let rng = Mutex::new(rng.fork());

//let scratch = &mut Scratch::new(config);
//let mut rng = rng.lock().unwrap();
sigmas.into_par_iter().for_each_init(|| (Scratch::new(config), rng.lock().unwrap().fork()), |(scratch, rng), sigma_i| {
sigmas.into_par_iter().for_each_init(|| Scratch::new(config), |scratch, sigma_i| {
//sigmas.into_iter().for_each(|sigma_i| {
greedy_search(scratch, medioid, &vecs[sigma_i as usize], vecs, &graph, config);
let is_query = sigma_i >= config.query_breakpoint;
greedy_search(scratch, medioid, is_query, &vecs[sigma_i as usize], vecs, &graph, config);

{
let n = graph.out_neighbours(sigma_i);
merge_existing_neighbours(&mut scratch.visited_list, sigma_i, &*n, vecs, config);
merge_existing_neighbours(&mut scratch.visited_list, sigma_i, &*n, vecs);
}

{
@@ -303,8 +313,8 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V
let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour);
if neighbour_neighbours.len() == config.r {
scratch.visited_list.clear();
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config);
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &vec![sigma_i], vecs, config);
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs);
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &vec![sigma_i], vecs);
robust_prune(scratch, neighbour, &mut neighbour_neighbours, vecs, config);
} else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r {
neighbour_neighbours.push(sigma_i);
@@ -313,24 +323,56 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V
});
}

pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig, max_iters: usize) {
let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
rng.shuffle(&mut sigmas);
pub fn robust_stitch(rng: &mut Rng, graph: &mut IndexGraph, vecs: &VectorList, config: IndexBuildConfig) {
let n_queries = graph.graph.len() as u32 - config.query_breakpoint;
let mut in_edges = Vec::with_capacity(n_queries as usize);
for _i in 0..(n_queries as usize) {
in_edges.push(Vec::with_capacity(config.r as usize));
}

// Iterate through graph vertices in a random order
let rng = Mutex::new(rng.fork());
sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| {
let mut neighbours = graph.out_neighbours_mut(sigma_i);
let mut i = 0;
while neighbours.len() < config.r && i < max_iters {
let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap();
let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap();
if !neighbours.contains(&projected_neighbour) {
neighbours.push(projected_neighbour);
let mut queries_order = (config.query_breakpoint..(graph.graph.len() as u32)).collect::<Vec<u32>>();
rng.shuffle(&mut queries_order);

for base_i in 0..config.query_breakpoint {
let mut out_neighbours = graph.out_neighbours_mut(base_i);
// store out-edges (to queries) from each base data node with corresponding query node and drop out-edges to queries
out_neighbours.retain(|&out_neighbour_out_edge| {
let is_query = out_neighbour_out_edge >= config.query_breakpoint;
if is_query {
in_edges[(out_neighbour_out_edge - config.query_breakpoint) as usize].push(base_i);
}
i += 1;
!is_query
});
}

queries_order.into_par_iter().for_each(|query_i| {
// For each query, fill spare space at in-neighbours with query's out-neighbours
// The OOD-DiskANN paper itself seems to fill *all* the spare space at once with (out-neighbours of) the first query which is encountered, which feels like an odd choice.
// We have a switch for that instead.
let query_out_neighbours = graph.out_neighbours(query_i);
println!("{} has {} in {} out", query_i, in_edges[(query_i - config.query_breakpoint) as usize].len(), query_out_neighbours.len());
for &in_neighbour in in_edges[(query_i - config.query_breakpoint) as usize].iter() {
let mut candidates = Vec::with_capacity(query_out_neighbours.len());
for (i, &neigh) in query_out_neighbours.iter().enumerate() {
let score = fast_dot(&vecs[in_neighbour as usize], &vecs[neigh as usize], &vecs[query_out_neighbours[(i + 1) % query_out_neighbours.len()] as usize]);
candidates.push((neigh, score));
}
candidates.sort_unstable_by_key(|(_neigh, score)| -*score);
let mut in_neighbour_out_edges = graph.out_neighbours_mut(in_neighbour);
let mut added = 0;
for (neigh, _score) in candidates {
if added >= config.max_add_per_stitch_iter {
break;
}
if in_neighbour_out_edges.contains(&neigh) {
continue;
}
in_neighbour_out_edges.push(neigh);
added += 1;
}
println!("wrote {} out to {}", added, in_neighbour);
}
})
});
}

pub fn random_fill_graph(rng: &mut Rng, graph: &mut IndexGraph, r: usize) {
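Taken together, robust_stitch replaces the old augment_bipartite/query_knns machinery: it runs after graph construction, strips query nodes back out of the base-vector adjacency lists, and reuses them only as routing hints. A condensed sketch of the flow under simplified assumptions (plain nested Vecs instead of the locked IndexGraph, and no dot-product ordering of candidates):

    // Sketch only, not the actual implementation.
    fn robust_stitch_sketch(adj: &mut Vec<Vec<u32>>, query_breakpoint: u32, max_add: usize) {
        let n = adj.len() as u32;
        // 1. Record which base nodes pointed at each query node, and drop those edges.
        let mut in_edges: Vec<Vec<u32>> = vec![Vec::new(); (n - query_breakpoint) as usize];
        for base in 0..query_breakpoint {
            adj[base as usize].retain(|&dst| {
                if dst >= query_breakpoint {
                    in_edges[(dst - query_breakpoint) as usize].push(base);
                    false // remove the base -> query edge
                } else {
                    true
                }
            });
        }
        // 2. Each query hands its out-neighbours to its former in-neighbours, at most
        //    max_add per in-neighbour; the real code offers the best dot products first.
        for q in query_breakpoint..n {
            let candidates = adj[q as usize].clone();
            for &base in &in_edges[(q - query_breakpoint) as usize] {
                let out = &mut adj[base as usize];
                let mut added = 0;
                for &c in &candidates {
                    if added >= max_add { break; }
                    if c < query_breakpoint && !out.contains(&c) {
                        out.push(c);
                        added += 1;
                    }
                }
            }
        }
    }

The intended effect is that base -> query -> base two-hop paths collapse into direct base -> base edges, so the persisted shard no longer depends on the query vectors being present at search time.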
generate_queries_bin.py (new file): 50 lines added
@@ -0,0 +1,50 @@
import sys, aiohttp, msgpack, numpy, pgvector.asyncpg, asyncio

async def use_emb_server(sess, query):
    async with sess.post("http://100.64.0.10:1708/", data=msgpack.dumps(query), timeout=aiohttp.ClientTimeout(connect=5, sock_connect=5, sock_read=None)) as res:
        response = msgpack.loads(await res.read())
        if res.status == 200:
            return response
        else:
            raise Exception(response if res.headers.get("content-type") == "application/msgpack" else (await res.text()))

BATCH_SIZE = 32

async def main():
    with open("query_data.bin", "wb") as f:
        with open("queries.txt", "r") as g:
            write_lock = asyncio.Lock()
            async with aiohttp.ClientSession() as sess:
                async with asyncio.TaskGroup() as tg:
                    sem = asyncio.Semaphore(3)

                    async def process_batch(batch):
                        while True:
                            try:
                                embs = await use_emb_server(sess, { "text": batch })
                                async with write_lock:
                                    f.write(b"".join(embs))
                                sys.stdout.write(".")
                                sys.stdout.flush()
                                break
                            except Exception as e:
                                print(e)
                                await asyncio.sleep(5)

                        sem.release()

                    async def dispatch(batch):
                        await sem.acquire()
                        tg.create_task(process_batch(batch))

                    batch = []
                    while line := g.readline():
                        if line.strip(): batch.append(line.strip())
                        if len(batch) == BATCH_SIZE:
                            await dispatch(batch)
                            batch = []
                    if len(batch) > 0:
                        await dispatch(batch)

if __name__ == "__main__":
    asyncio.run(main())
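query_data.bin is therefore a flat concatenation of embeddings returned by the embedding server, with no header, and the shard builder later reads it as raw f16 values (D_EMB per vector) in READ_CHUNK_SIZE pieces. A small hedged sketch, in the project's main language, of sanity-checking such a file before use; the filename and constant mirror the rest of the commit but this check is not part of it:

    use std::fs;

    // Assumed layout: raw f16 values, D_EMB per query vector, no header.
    const D_EMB: usize = 1152;

    fn main() -> std::io::Result<()> {
        let len = fs::metadata("query_data.bin")?.len() as usize;
        let bytes_per_vector = D_EMB * 2; // f16 = 2 bytes
        assert!(len % bytes_per_vector == 0, "file is not a whole number of embeddings");
        println!("{} query vectors", len / bytes_per_vector);
        Ok(())
    }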
@@ -125,16 +125,14 @@ pub struct ProcessedEntry {
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct ShardInputHeader {
pub id: u32,
pub centroid: Vec<f32>,
pub max_query_id: usize
pub centroid: Vec<f32>
}

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct ShardedRecord {
pub id: u32,
#[serde(with="serde_bytes")]
pub vector: Vec<u8>, // FP16
pub query_knns: Vec<u32>
pub vector: Vec<u8> // FP16
}

#[derive(Clone, Deserialize, Serialize, Debug)]
@@ -1,6 +1,6 @@
use anyhow::{bail, Context, Result};
use serde::{Serialize, Deserialize};
use std::io::{BufReader, Read, Seek, SeekFrom, Write, BufWriter};
use std::io::{BufReader, Write, BufWriter};
use std::path::PathBuf;
use rmp_serde::decode::Error as DecodeError;
use std::fs;
@@ -8,7 +8,6 @@ use base64::Engine;
use argh::FromArgs;
use chrono::{TimeZone, Utc, DateTime};
use std::collections::VecDeque;
use faiss::Index;
use std::sync::mpsc::{sync_channel, SyncSender};
use itertools::Itertools;
use simsimd::SpatialSimilarity;
@@ -47,8 +46,6 @@ struct CLIArguments {
centroids: Option<String>,
#[argh(option, short='S', description="index shard directory")]
shards_dir: Option<String>,
#[argh(option, short='Q', description="query vectors file")]
queries: Option<String>,
#[argh(option, short='d', description="random seed")]
seed: Option<u64>,
#[argh(option, short='i', description="index output directory")]
@@ -62,7 +59,9 @@ struct CLIArguments {
#[argh(option, short='q', description="product quantization codec path")]
pq_codec: Option<String>,
#[argh(switch, short='j', description="JSON output")]
json: bool
json: bool,
#[argh(option, short='f', description="k-means balance fudge factor", default="0.2")]
balance_fudge: f64,
}

#[derive(Clone, Deserialize, Serialize, Debug)]
@@ -125,9 +124,19 @@ const SHARD_SPILL: usize = 2;
const RECORD_PAD_SIZE: usize = 4096; // NVMe disk sector size
const D_EMB: u32 = 1152;
const EMPTY_LOOKUP: (u32, u64, u32) = (u32::MAX, 0, 0);
const KNN_K: usize = 30;
const BALANCE_WEIGHT: f64 = 0.2;
const BATCH_SIZE: usize = 128;
const BATCH_SIZE: usize = 1024;

#[derive(Clone, Serialize, Debug)]
pub struct JsonEntry<'a> {
pub url: &'a str,
pub id: &'a str,
pub title: &'a str,
pub subreddit: &'a str,
pub author: &'a str,
pub timestamp: u64,
pub embedding: &'a [f32],
pub metadata: common::OriginalImageMetadata
}

fn main() -> Result<()> {
let args: CLIArguments = argh::from_env();
@@ -158,37 +167,6 @@ fn main() -> Result<()> {
None
};

// construct FAISS index over query vectors for kNNs
let (mut queries_index, max_query_id) = if let Some(queries_file) = args.queries {
println!("constructing index");
// not memory-efficient but this is small
let mut file = fs::File::open(queries_file).context("read queries file")?;
let mut size = file.metadata()?.len();
//let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?;
let mut index = faiss::index_factory(D_EMB, "HNSW64,SQ8", faiss::MetricType::InnerProduct)?;
//let mut index = faiss::index_factory(D_EMB, "IVF4096,SQfp16", faiss::MetricType::InnerProduct)?;
let mut buf = vec![0; (D_EMB as usize) * (1<<18)];
loop {
if size == 0 {
break;
}
if size < (buf.len() as u64) {
buf.resize(size as usize, 0);
}
file.read_exact(&mut buf)?;
size -= buf.len() as u64;
let unpacked = common::decode_fp16_buffer(&buf);
if !index.is_trained() { index.train(&unpacked)?; print!("train"); }
index.add(&unpacked)?;
print!(".");
}
println!("done");
let ntotal = index.ntotal();
(Some(index), ntotal as usize)
} else {
(None, 0)
};

// if sufficient config to split index exists, set up output files
let mut shards_out = if let (Some(shards_dir), Some(centroids)) = (&args.shards_dir, &args.centroids) {
let mut shards = Vec::new();
@@ -202,10 +180,12 @@ fn main() -> Result<()> {
for i in 0..(centroids_data.len() / (D_EMB as usize)) {
let centroid = centroids_data[i * (D_EMB as usize)..(i + 1) * (D_EMB as usize)].to_vec();
let mut file = fs::File::create(PathBuf::from(shards_dir).join(format!("{}.shard.msgpack", i))).context("create shard file")?;
rmp_serde::encode::write(&mut file, &ShardInputHeader { id: i as u32, centroid: centroid.clone(), max_query_id })?;
rmp_serde::encode::write(&mut file, &ShardInputHeader { id: i as u32, centroid: centroid.clone() })?;
shards.push((centroid, file, 0, i));
}

println!("splitting into {} shards", shards.len());

Some(shards)
} else {
None
@@ -376,6 +356,21 @@ fn main() -> Result<()> {
if args.print_embeddings {
println!("https://mse.osmarks.net/?e={}", base64::engine::general_purpose::URL_SAFE.encode(&x.embedding));
}
// this is not a very compact format, but I am lazy and this will never be a performance bottleneck
if args.json {
let entry = JsonEntry {
url: &x.url,
id: &x.id,
title: &x.title,
subreddit: &x.subreddit,
author: &x.author,
timestamp: x.timestamp,
embedding: &embedding,
metadata: x.metadata.clone()
};
let data = serde_json::to_string(&entry).unwrap();
println!("{}", data);
}

Some((x, embedding))
};
@@ -395,26 +390,17 @@ fn main() -> Result<()> {
}

if let Some(shards) = &mut shards_out {
let mut knn_query = vec![];
for (_, embedding) in batch.iter() {
knn_query.extend(embedding);
}

let index = queries_index.as_mut().context("need queries")?;
let knn_result = index.search(&knn_query, KNN_K)?;

for (i, (x, embedding)) in batch.iter().enumerate() {
// closest matches first
shards.sort_by_cached_key(|&(ref centroid, _, shard_count, _shard_index)| {
let mut dot = SpatialSimilarity::dot(&centroid, &embedding).unwrap();
dot -= BALANCE_WEIGHT * (shard_count as f64 / bal_count as f64);
dot -= args.balance_fudge * (shard_count as f64 / bal_count as f64);
-scale_dot_result_f64(dot)
});

let entry = ShardedRecord {
id: count + i as u32,
vector: x.embedding.clone(),
query_knns: knn_result.labels[i * KNN_K..(i + 1)*KNN_K].into_iter().map(|x| x.get().unwrap() as u32).collect()
vector: x.embedding.clone()
};
let data = rmp_serde::to_vec(&entry)?;
for (_, file, shard_count, _shard_index) in shards[0..SHARD_SPILL].iter_mut() {
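The hard-coded BALANCE_WEIGHT thus becomes the -f/--balance_fudge flag: each record goes to the shards whose centroid dot product, minus a penalty proportional to how full the shard already is, ranks highest. A small worked example of that scoring with two shards (all numbers invented):

    fn main() {
        let balance_fudge = 0.2;
        let bal_count = 1000.0; // normaliser for shard occupancy
        // (name, raw centroid similarity, current record count)
        let shards = [("a", 0.80_f64, 900.0_f64), ("b", 0.78, 100.0)];
        for (name, dot, shard_count) in shards {
            // Same adjustment as in sort_by_cached_key above: fuller shards are penalised.
            let adjusted = dot - balance_fudge * (shard_count / bal_count);
            println!("shard {name}: {adjusted:.3}");
        }
        // shard a: 0.800 - 0.2 * 0.9 = 0.620; shard b: 0.780 - 0.2 * 0.1 = 0.760,
        // so the record prefers shard b despite its slightly lower raw similarity.
    }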
@@ -1,28 +1,57 @@
use anyhow::{Result, Context};
use itertools::Itertools;
use std::io::{BufReader, BufWriter, Write};
use std::io::{BufReader, BufWriter, Write, Read};
use rmp_serde::decode::Error as DecodeError;
use std::fs;
use diskann::{augment_bipartite, build_graph, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer, report_degrees, medioid};
use diskann::{build_graph, random_fill_graph, vector::VectorList, IndexBuildConfig, IndexGraph, Timer, report_degrees, medioid, robust_stitch};
use half::f16;
use argh::FromArgs;

mod common;

use common::{ShardInputHeader, ShardedRecord, ShardHeader};

#[derive(FromArgs)]
#[argh(description="Generate indices from shard files")]
struct CLIArguments {
#[argh(positional)]
input_file: String,
#[argh(positional)]
out_dir: String,
#[argh(positional)]
queries_bin: Option<String>,
#[argh(option, short='L', default="192", description="search list size (higher is better but slower)")]
l: usize,
#[argh(option, short='R', default="64", description="graph degree")]
r: usize,
#[argh(option, short='C', default="750", description="max candidate list size")]
maxc: usize,
#[argh(option, short='A', default="65536", description="first pass relaxation factor (times 2^16)")]
alpha: i64,
#[argh(option, short='Q', default="65536", description="query set special relaxation factor (times 2^16)")]
query_alpha: i64,
#[argh(option, short='B', default="65536", description="second pass relaxation factor (times 2^16)")]
alpha_2: i64,
#[argh(switch, short='s', description="do second pass")]
second_pass: bool,
#[argh(option, short='N', description="number of vectors to allocate for")]
n: Option<usize>
}

const D_EMB: usize = 1152;
const READ_CHUNK_SIZE: usize = D_EMB * size_of::<f16>() * 1024;

fn main() -> Result<()> {
let args: CLIArguments = argh::from_env();

let mut rng = fastrand::Rng::new();

let mut stream = BufReader::new(fs::File::open(std::env::args().nth(1).unwrap()).context("read dump file")?);
let mut stream = BufReader::new(fs::File::open(args.input_file)?);

let mut original_ids = vec![];
let mut vector_data = vec![];
let mut query_knns = vec![];
// There is no convenient way to pass the actual size along, so accursedly do it manually
let mut original_ids = Vec::with_capacity(args.n.unwrap_or(0));
let mut vector_data = Vec::with_capacity(args.n.unwrap_or(0) * D_EMB);

let header: ShardInputHeader = rmp_serde::from_read(&mut stream)?;
let centroid_fp16 = header.centroid.iter().map(|x| f16::from_f32(*x)).collect::<Vec<_>>();

{
let _timer = Timer::new("read shard");
@@ -32,7 +61,6 @@ fn main() -> Result<()> {
Ok(x) => {
original_ids.push(x.id);
vector_data.extend(bytemuck::cast_slice(&x.vector));
query_knns.push(x.query_knns);
},
Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(e).context("decode fail")
@@ -40,26 +68,38 @@ fn main() -> Result<()> {
}
}

let query_breakpoint = original_ids.len();

if let Some(ref queries_bin) = args.queries_bin {
let mut queries_file = BufReader::new(fs::File::open(queries_bin)?);
let mut buf = vec![0; READ_CHUNK_SIZE];
loop {
let n = queries_file.by_ref().take(READ_CHUNK_SIZE as u64).read_to_end(&mut buf)?;
if n == 0 {
break;
}
vector_data.extend(bytemuck::cast_slice(&buf[..n]));
}
}

let mut config = IndexBuildConfig {
r: 64,
l: 192,
maxc: 750,
alpha: 65200,
saturate_graph: false
r: args.r,
l: args.l,
maxc: args.maxc,
alpha: args.alpha,
query_alpha: args.query_alpha,
saturate_graph: false,
query_breakpoint: query_breakpoint as u32,
max_add_per_stitch_iter: 16
};

let vecs = VectorList {
length: vector_data.len() / D_EMB,
data: vector_data,
d_emb: D_EMB,
length: original_ids.len()
d_emb: D_EMB
};

let mut graph = IndexGraph::empty(original_ids.len(), config.r);

{
//let _timer = Timer::new("project bipartite");
//project_bipartite(&mut rng, &mut graph, &query_knns, &query_knns_bwd, config, &vecs);
}
let mut graph = IndexGraph::empty(vecs.len(), config.r);

{
let _timer = Timer::new("random fill");
@@ -77,42 +117,33 @@ fn main() -> Result<()> {

report_degrees(&graph);

{
//let _timer = Timer::new("second pass");
//config.alpha = 62000;
//build_graph(&mut rng, &mut graph, medioid, &vecs, config);
if args.second_pass {
{
let _timer = Timer::new("second pass");
config.alpha = args.alpha_2;
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
}
report_degrees(&graph);
}

//report_degrees(&graph);
if query_breakpoint < graph.graph.len() {
let _timer = Timer::new("robust stitch");
robust_stitch(&mut rng, &mut graph, &vecs, config);
report_degrees(&graph);
}

std::mem::drop(vecs);

let mut query_knns_bwd = vec![vec![]; header.max_query_id];

{
let _timer = Timer::new("compute backward edges");
for (record_id, knns) in query_knns.iter().enumerate() {
for &k in knns {
query_knns_bwd[k as usize].push(record_id as u32);
}
}
}

{
let _timer = Timer::new("augment bipartite");
//augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config, 50);
//random_fill_graph(&mut rng, &mut graph, config.r);
}

let len = original_ids.len();

{
let _timer = Timer::new("write shard");
let mut graph_data = BufWriter::new(fs::File::create(&format!("{}.shard.bin", header.id))?);
let mut graph_data = BufWriter::new(fs::File::create(&format!("{}/{}.shard.bin", args.out_dir, header.id))?);

let mut offsets = Vec::with_capacity(original_ids.len());
let mut offset = 0;
for out_neighbours in graph.graph.iter() {
for (i, out_neighbours) in graph.graph.iter().enumerate() {
if i >= query_breakpoint { break; }
let out_neighbours = out_neighbours.read().unwrap();
offsets.push(offset);
let s: &[u8] = bytemuck::cast_slice(&*out_neighbours);
@@ -121,7 +152,7 @@ fn main() -> Result<()> {
}
offsets.push(offset); // dummy entry for convenience

let mut header_f = fs::File::create(&format!("{}.shard-header.msgpack", header.id))?;
let mut header_f = fs::File::create(&format!("{}/{}.shard-header.msgpack", args.out_dir, header.id))?;
header_f.write_all(&rmp_serde::to_vec(&ShardHeader {
id: header.id,
max: *original_ids.iter().max().unwrap(),
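With the new argh interface, a shard build might be invoked roughly like this (the binary name and paths are placeholders, queries_bin and the tuning flags are optional, and the numeric values are just the defaults or plausible picks, not taken from the commit):

    <shard-builder-binary> 0.shard.msgpack index_out/ query_data.bin -R 64 -L 192 -A 65536 -Q 60000 -s -B 58000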