1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-04-18 00:33:16 +00:00

fixed roargraph things

This commit is contained in:
osmarks 2025-01-11 11:17:51 +00:00
commit 9334fc189c
10 changed files with 575 additions and 123 deletions

2
.gitignore vendored
View File

@ -15,3 +15,5 @@ diskann/target
*.bin
*.msgpack
*/flamegraph.svg
*.hdf5
*.v

356
Cargo.lock generated
View File

@ -95,6 +95,37 @@ dependencies = [
"syn 2.0.79",
]
[[package]]
name = "argh"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7af5ba06967ff7214ce4c7419c7d185be7ecd6cc4965a8f6e1d8ce0398aad219"
dependencies = [
"argh_derive",
"argh_shared",
]
[[package]]
name = "argh_derive"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56df0aeedf6b7a2fc67d06db35b09684c3e8da0c95f8f27685cb17e08413d87a"
dependencies = [
"argh_shared",
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "argh_shared"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5693f39141bda5760ecc4111ab08da40565d1771038c4a0250f03457ec707531"
dependencies = [
"serde",
]
[[package]]
name = "arrayvec"
version = "0.7.6"
@ -317,6 +348,30 @@ version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc827186963e592360843fb5ba4b973e145841266c1357f7180c43526f2e5b61"
[[package]]
name = "bitcode"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee1bce7608560cd4bf0296a4262d0dbf13e6bcec5ff2105724c8ab88cc7fc784"
dependencies = [
"arrayvec",
"bitcode_derive",
"bytemuck",
"glam",
"serde",
]
[[package]]
name = "bitcode_derive"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a539389a13af092cd345a2b47ae7dec12deb306d660b2223d25cd3419b253ebe"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "bitflags"
version = "1.3.2"
@ -347,6 +402,18 @@ version = "2.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452"
[[package]]
name = "bitvec"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [
"funty",
"radium",
"tap",
"wyz",
]
[[package]]
name = "block-buffer"
version = "0.10.4"
@ -379,6 +446,20 @@ name = "bytemuck"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d"
dependencies = [
"bytemuck_derive",
]
[[package]]
name = "bytemuck_derive"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "byteorder"
@ -626,6 +707,31 @@ version = "0.8.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
[[package]]
name = "crossterm"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e64e6c0fbe2c17357405f7c758c1ef960fce08bdfb2c03d88d2a18d7e09c4b67"
dependencies = [
"bitflags 1.3.2",
"crossterm_winapi",
"libc",
"mio 0.8.11",
"parking_lot",
"signal-hook",
"signal-hook-mio",
"winapi 0.3.9",
]
[[package]]
name = "crossterm_winapi"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b"
dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "crunchy"
version = "0.2.2"
@ -697,6 +803,26 @@ dependencies = [
"subtle",
]
[[package]]
name = "diskann"
version = "0.1.0"
dependencies = [
"anyhow",
"bitvec",
"bytemuck",
"fastrand",
"foldhash",
"half",
"matrixmultiply",
"rayon",
"rmp-serde",
"serde",
"simsimd",
"tqdm",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "document-features"
version = "0.2.10"
@ -894,6 +1020,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
[[package]]
name = "foreign-types"
version = "0.3.2"
@ -918,6 +1050,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "funty"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "futures-channel"
version = "0.3.31"
@ -1039,6 +1177,12 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glam"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc46dd3ec48fdd8e693a98d2b8bafae273a2d54c1de02a2a7e3d57d501f39677"
[[package]]
name = "glob"
version = "0.3.1"
@ -1070,10 +1214,17 @@ version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [
"bytemuck",
"cfg-if",
"crunchy",
]
[[package]]
name = "hamming"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65043da274378d68241eb9a8f8f8aa54e349136f7b8e12f63e3ef44043cc30e1"
[[package]]
name = "hashbrown"
version = "0.12.3"
@ -1496,7 +1647,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7507624b29483431c0ba2d82aece8ca6cdba9382bff4ddd0f7490560c056098d"
dependencies = [
"winapi",
"winapi 0.2.8",
"winapi-build",
]
@ -1637,6 +1788,28 @@ dependencies = [
"rawpointer",
]
[[package]]
name = "maud"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df518b75016b4289cdddffa1b01f2122f4a49802c93191f3133f6dc2472ebcaa"
dependencies = [
"itoa",
"maud_macros",
]
[[package]]
name = "maud_macros"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa453238ec218da0af6b11fc5978d3b5c3a45ed97b722391a2a11f3306274e18"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "maybe-rayon"
version = "0.1.1"
@ -1667,23 +1840,31 @@ name = "meme-search-engine"
version = "0.1.0"
dependencies = [
"anyhow",
"argh",
"async-recursion",
"axum",
"base64 0.22.1",
"bitcode",
"bytemuck",
"chrono",
"compact_str",
"console-subscriber",
"diskann",
"faiss",
"fast_image_resize",
"fastrand",
"ffmpeg-the-third",
"fnv",
"foldhash",
"futures-util",
"half",
"hamming",
"image",
"itertools 0.13.0",
"json5",
"lazy_static",
"maud",
"memmap2",
"mimalloc",
"ndarray",
"num_cpus",
@ -1691,9 +1872,11 @@ dependencies = [
"regex",
"reqwest",
"rmp-serde",
"seahash",
"serde",
"serde_bytes",
"serde_json",
"simsimd",
"sonic-rs",
"sqlx",
"tokio",
@ -1701,11 +1884,21 @@ dependencies = [
"tower 0.4.13",
"tower-http",
"tracing",
"tracing-subscriber",
"url",
"walkdir",
"zstd",
]
[[package]]
name = "memmap2"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f"
dependencies = [
"libc",
]
[[package]]
name = "mimalloc"
version = "0.1.43"
@ -1756,6 +1949,18 @@ dependencies = [
"simd-adler32",
]
[[package]]
name = "mio"
version = "0.8.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
dependencies = [
"libc",
"log",
"wasi",
"windows-sys 0.48.0",
]
[[package]]
name = "mio"
version = "1.0.2"
@ -1843,6 +2048,16 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8"
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi 0.3.9",
]
[[package]]
name = "num-bigint"
version = "0.4.6"
@ -2000,6 +2215,12 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "parking_lot"
version = "0.12.3"
@ -2170,6 +2391,29 @@ dependencies = [
"zerocopy",
]
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]]
name = "proc-macro2"
version = "1.0.88"
@ -2275,6 +2519,12 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "radium"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]]
name = "rand"
version = "0.8.5"
@ -2625,7 +2875,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d931a44fdaa43b8637009e7632a02adc4f2b2e0733c08caa4cf00e8da4a117a7"
dependencies = [
"kernel32-sys",
"winapi",
"winapi 0.2.8",
]
[[package]]
@ -2643,6 +2893,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "seahash"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b"
[[package]]
name = "security-framework"
version = "2.11.1"
@ -2775,6 +3031,27 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801"
dependencies = [
"libc",
"signal-hook-registry",
]
[[package]]
name = "signal-hook-mio"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd"
dependencies = [
"libc",
"mio 0.8.11",
"signal-hook",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.2"
@ -2815,6 +3092,15 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e"
[[package]]
name = "simsimd"
version = "6.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb18072bc601c31152841c2a114b78c11fe15513f0767eacd57d68359b7130e3"
dependencies = [
"cc",
]
[[package]]
name = "slab"
version = "0.4.9"
@ -3175,6 +3461,12 @@ dependencies = [
"version-compare",
]
[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "target-lexicon"
version = "0.12.16"
@ -3259,7 +3551,7 @@ dependencies = [
"backtrace",
"bytes",
"libc",
"mio",
"mio 1.0.2",
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
@ -3453,6 +3745,17 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tqdm"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa2d2932240205a99b65f15d9861992c95fbb8c9fb280b3a1f17a92db6dc611f"
dependencies = [
"anyhow",
"crossterm",
"once_cell",
]
[[package]]
name = "tracing"
version = "0.1.40"
@ -3486,6 +3789,17 @@ dependencies = [
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.18"
@ -3493,12 +3807,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]
[[package]]
@ -3630,7 +3947,7 @@ checksum = "bb08f9e670fab86099470b97cd2b252d6527f0b3cc1401acdb595ffc9dd288ff"
dependencies = [
"kernel32-sys",
"same-file",
"winapi",
"winapi 0.2.8",
]
[[package]]
@ -3753,12 +4070,34 @@ version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "167dc9d6949a9b857f3451275e911c3f44255842c1f7a76f33c55103a909087a"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-build"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d315eee3b34aca4797b2da6b13ed88266e6d612562a0c46390af8299fc699bc"
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-core"
version = "0.52.0"
@ -3955,6 +4294,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "wyz"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
dependencies = [
"tap",
]
[[package]]
name = "zerocopy"
version = "0.7.35"

View File

@ -3,7 +3,9 @@ name = "meme-search-engine"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[profile.release-with-debug]
inherits = "release"
debug = true
[dependencies]
tokio = { version = "1", features = ["full", "tracing"] }
@ -44,6 +46,17 @@ compact_str = { version = "0.8.0-beta", features = ["serde"] }
itertools = "0.13"
async-recursion = "1"
fast_image_resize = { version = "5", features = ["image"] }
argh = "0.1"
maud = "0.26"
hamming = "0.1"
seahash = "4"
tracing-subscriber = "0.3"
diskann = { path = "./diskann" }
bytemuck = "1"
bitcode = "0.6"
simsimd = "6"
foldhash = "0.1"
memmap2 = "0.9"
[[bin]]
name = "reddit-dump"
@ -56,3 +69,11 @@ path = "src/video_reader.rs"
[[bin]]
name = "dump-processor"
path = "src/dump_processor.rs"
[[bin]]
name = "generate-index-shard"
path = "src/generate_index_shard.rs"
[[bin]]
name = "query-disk-index"
path = "src/query_disk_index.rs"

7
diskann/hdf5_to_bin.py Normal file
View File

@ -0,0 +1,7 @@
import h5py
import numpy as np

# Convert the "train" split of the GloVe-200 angular HDF5 dataset into a flat
# binary file of L2-normalised float16 vectors, written row after row.

# Open the dataset in a context manager so the HDF5 handle is closed as soon as
# the vectors have been copied into memory.
with h5py.File("glove-200-angular.hdf5", "r") as f:
    data = np.array(f["train"])

# Per-row L2 norms; keepdims=True keeps a trailing axis so the division below
# broadcasts across each row.
norms = np.linalg.norm(data, axis=1, keepdims=True)
# Leave all-zero rows as zeros instead of dividing by zero, which would write
# NaNs into the output file.
norms[norms == 0] = 1
data = data / norms

data.astype(np.float16).tofile("glove-200-angular.bin")

View File

@ -55,7 +55,6 @@ impl IndexGraph {
#[derive(Clone, Copy, Debug)]
pub struct IndexBuildConfig {
pub r: usize,
pub r_cap: usize,
pub l: usize,
pub maxc: usize,
pub alpha: i64
@ -146,7 +145,14 @@ impl NeighbourBuffer {
self.scores.truncate(self.size);
self.visited.truncate(self.size);
self.next_unvisited = Some(loc as u32);
match self.next_unvisited {
Some(ref mut next_unvisited) => {
*next_unvisited = (loc as u32).min(*next_unvisited);
},
None => {
self.next_unvisited = Some(loc as u32);
}
}
}
pub fn clear(&mut self) {
@ -194,7 +200,6 @@ pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs:
let mut counters = GreedySearchCounters { distances: 0 };
while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
//println!("pt {} {:?}", pt, graph.out_neighbours(pt));
scratch.neighbour_pre_buffer.clear();
for &neighbour in graph.out_neighbours(pt).iter() {
if scratch.visited.insert(neighbour) {
@ -299,7 +304,7 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V
// somewhat ugly deadlock avoidance hack - only hold writelock at end
let neighbour_neighbours = graph.out_neighbours(neighbour);
// To cut down pruning time slightly, allow accumulating more neighbours than usual limit
if neighbour_neighbours.len() == config.r_cap {
if neighbour_neighbours.len() == config.r {
let mut n = neighbour_neighbours.to_vec();
scratch.visited_list.clear();
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config);
@ -307,7 +312,7 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V
robust_prune(scratch, neighbour, &mut n, vecs, config);
std::mem::drop(neighbour_neighbours);
*graph.out_neighbours_mut(neighbour) = n;
} else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r_cap {
} else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r {
// we apparently cannot actually upgrade the read lock to a write lock
std::mem::drop(neighbour_neighbours);
let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour);
@ -360,7 +365,7 @@ pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec
// as in Vamana algorithm
// TODO factor out common code
let neighbour_neighbours = graph.out_neighbours(neighbour);
if neighbour_neighbours.len() == config.r_cap {
if neighbour_neighbours.len() == config.r {
let mut n = neighbour_neighbours.to_vec();
scratch.visited_list.clear();
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config);
@ -368,7 +373,7 @@ pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec
robust_prune(scratch, neighbour, &mut n, vecs, config);
std::mem::drop(neighbour_neighbours);
*graph.out_neighbours_mut(neighbour) = n;
} else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r_cap {
} else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r {
std::mem::drop(neighbour_neighbours);
let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour);
neighbour_neighbours.push(sigma_i);
@ -386,7 +391,7 @@ pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<
sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| {
let mut neighbours = graph.out_neighbours_mut(sigma_i);
let mut i = 0;
while neighbours.len() < config.r_cap && i < 100 {
while neighbours.len() < config.r && i < 100 {
let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap();
let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap();
if !neighbours.contains(&projected_neighbour) {
@ -423,3 +428,18 @@ impl Drop for Timer {
println!("{}: {:.2}s", self.0, self.1.elapsed().as_secs_f32());
}
}
/// Prints out-degree statistics for `graph` to stdout: the average, the median
/// (element at index `len / 2` of the sorted degrees), the minimum and the maximum.
///
/// For a graph with no vertices there is nothing to summarise, so a short notice
/// is printed instead of indexing into an empty list.
///
/// # Panics
///
/// Panics if `read()` on any adjacency list returns an error (the result is unwrapped).
pub fn report_degrees(graph: &IndexGraph) {
    let vertex_count = graph.graph.len();
    if vertex_count == 0 {
        // Indexing `degrees[0]` / `degrees[len - 1]` below would panic and the
        // average would be 0/0, so bail out early.
        println!("graph is empty, no degree statistics");
        return;
    }

    let mut total_degree = 0;
    let mut degrees = Vec::with_capacity(vertex_count);
    for out_neighbours in graph.graph.iter() {
        let deg = out_neighbours.read().unwrap().len();
        total_degree += deg;
        degrees.push(deg);
    }
    // Sorted order makes the median, minimum and maximum simple index lookups.
    degrees.sort_unstable();

    println!("average degree {}", (total_degree as f64) / (vertex_count as f64));
    println!("median degree {}", degrees[degrees.len() / 2]);
    println!("min degree {}", degrees[0]);
    println!("max degree {}", degrees[degrees.len() - 1]);
}

View File

@ -7,7 +7,7 @@ use std::{io::Read, time::Instant};
use anyhow::Result;
use half::f16;
use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, SCALE, dot, VectorList, self}, Timer};
use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, SCALE, dot, VectorList, self}, Timer, report_degrees, random_fill_graph};
use simsimd::SpatialSimilarity;
const D_EMB: usize = 1152;
@ -26,12 +26,13 @@ const PQ_TEST_SIZE: usize = 1000;
fn main() -> Result<()> {
tracing_subscriber::fmt::init();
/*/
{
let file = std::fs::File::open("opq.msgpack")?;
let codec: vector::ProductQuantizer = rmp_serde::from_read(file)?;
let input = load_file("embeddings.bin", Some(D_EMB * PQ_TEST_SIZE))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
let codes = codec.quantize_batch(&input);
println!("{:?}", codes);
//println!("{:?}", codes);
let raw_query = load_file("query.bin", Some(D_EMB))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
let query = codec.preprocess_query(&raw_query);
let mut real_scores = vec![];
@ -41,17 +42,17 @@ fn main() -> Result<()> {
let pq_scores = codec.asymmetric_dot_product(&query, &codes);
for (x, y) in real_scores.iter().zip(pq_scores.iter()) {
let y = (*y as f32) / SCALE;
println!("{} {} {} {}", x, y, x - y, (x - y) / x);
//println!("{} {} {} {}", x, y, x - y, (x - y) / x);
}
}
}*/
let mut rng = fastrand::Rng::with_seed(1);
let n = 100000;
let n = 100_000;
let vecs = {
let _timer = Timer::new("loaded vectors");
&load_file("embeddings.bin", Some(D_EMB * n))?
&load_file("query.bin", Some(D_EMB * n))?
};
let (graph, medioid) = {
@ -59,10 +60,9 @@ fn main() -> Result<()> {
let mut config = IndexBuildConfig {
r: 64,
r_cap: 80,
l: 128,
l: 192,
maxc: 750,
alpha: 65536,
alpha: 65200,
};
let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap);
@ -70,8 +70,11 @@ fn main() -> Result<()> {
let medioid = medioid(&vecs);
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
config.alpha = 58000;
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
report_degrees(&graph);
//random_fill_graph(&mut rng, &mut graph, config.r);
//config.alpha = 65536;
//build_graph(&mut rng, &mut graph, medioid, &vecs, config);
report_degrees(&graph);
(graph, medioid)
};
@ -82,8 +85,6 @@ fn main() -> Result<()> {
edge_ctr += adjlist.read().unwrap().len();
}
println!("average degree: {}", edge_ctr as f32 / graph.graph.len() as f32);
let time = Instant::now();
let mut recall = 0;
let mut cmps_ctr = 0;
@ -91,7 +92,6 @@ fn main() -> Result<()> {
let mut config = IndexBuildConfig {
r: 64,
r_cap: 64,
l: 50,
alpha: 65536,
maxc: 0,

View File

@ -173,18 +173,17 @@ pub mod index_config {
use diskann::IndexBuildConfig;
pub const BASE_CONFIG: IndexBuildConfig = IndexBuildConfig {
r: 64,
r_cap: 80,
l: 500,
maxc: 750,
alpha: 60000
r: 40,
l: 200,
maxc: 900,
alpha: 65200
};
pub const PROJECTION_CUT_POINT: usize = 1;
pub const PROJECTION_CUT_POINT: usize = 3;
pub const FIRST_PASS_ALPHA: i64 = 65536;
pub const FIRST_PASS_ALPHA: i64 = 65200;
pub const SECOND_PASS_ALPHA: i64 = 62000;
//pub const SECOND_PASS_ALPHA: i64 = 62000;
pub const QUERY_SEARCH_K: usize = 200; // we want each query to have QUERY_REVERSE_K results, but some queries are likely more common than others in the top-k lists, so oversample a bit
pub const QUERY_REVERSE_K: usize = 100;

View File

@ -55,7 +55,7 @@ struct CLIArguments {
#[argh(switch, short='t', description="print titles")]
titles: bool,
#[argh(option, description="truncate centroids list")]
clip_centroids: Option<usize>,
clip_shards: Option<usize>,
#[argh(switch, description="print original linked URL")]
original_url: bool,
#[argh(option, short='q', description="product quantization codec path")]
@ -179,7 +179,7 @@ fn main() -> Result<()> {
let centroids_data = fs::read(centroids).context("read centroids file")?;
let mut centroids_data = common::decode_fp16_buffer(&centroids_data);
if let Some(clip) = args.clip_centroids {
if let Some(clip) = args.clip_shards {
centroids_data.truncate(clip * D_EMB as usize);
}
@ -208,6 +208,14 @@ fn main() -> Result<()> {
let path = file.path();
let filename = path.file_name().unwrap().to_str().unwrap();
let (fst, snd) = filename.split_once(".").unwrap();
let id: u32 = str::parse(fst)?;
if let Some(clip) = args.clip_shards {
if id >= (clip as u32) {
continue;
}
}
if snd == "shard-header.msgpack" {
let header: ShardHeader = rmp_serde::from_read(BufReader::new(fs::File::open(path)?))?;
if original_ids_to_shards.len() < (header.max as usize + 1) {
@ -237,7 +245,6 @@ fn main() -> Result<()> {
shard_id_mappings.push((header.id, header.mapping));
} else if snd == "shard.bin" {
let file = fs::File::open(&path).context("open shard file")?;
let id: u32 = str::parse(fst)?;
files.push((id, file));
}
}
@ -245,11 +252,16 @@ fn main() -> Result<()> {
files.sort_by_key(|(id, _)| *id);
shard_id_mappings.sort_by_key(|(id, _)| *id);
let read_out_vertices =move |id: u32| -> Result<(Vec<u32>, Vec<u32>)> {
let read_out_vertices = move |id: u32| -> Result<(Vec<u32>, Vec<u32>)> {
let mut out_vertices: Vec<u32> = vec![];
let mut shards: Vec<u32> = vec![];
// look up each location in shard files
for &(shard, offset, len) in original_ids_to_shards[id as usize].iter() {
if (shard, offset, len) == EMPTY_LOOKUP {
continue;
}
shards.push(shard);
let shard = shard as usize;
// this random access is almost certainly rather slow
@ -257,7 +269,7 @@ fn main() -> Result<()> {
files[shard].1.seek(SeekFrom::Start(offset))?;
let mut buf = vec![0; len as usize];
files[shard].1.read_exact(&mut buf)?;
let s: &mut [u32] = bytemuck::cast_slice_mut(&mut *buf);
let mut s: Vec<u32> = bytemuck::allocation::pod_collect_to_vec(&*buf);
for within_shard_id in s.iter_mut() {
*within_shard_id = shard_id_mappings[shard].1[*within_shard_id as usize];
}

View File

@ -1,10 +1,10 @@
use anyhow::{Result, Context};
use itertools::Itertools;
use std::io::{BufReader, Write, BufWriter, Seek};
use rmp_serde::decode::Error as DecodeError;
use std::fs;
use std::collections::BinaryHeap;
use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList, scale_dot_result}, IndexBuildConfig, IndexGraph, Timer};
use std::io::{BufReader, BufWriter, Write};
use std::fs;
use rmp_serde::decode::{Error as DecodeError, from_read};
use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList, scale_dot_result}, IndexBuildConfig, IndexGraph, Timer, report_degrees};
use half::f16;
mod common;
@ -13,19 +13,6 @@ use common::{index_config::{self, QUERY_REVERSE_K}, ShardHeader, ShardInputHeade
const D_EMB: usize = 1152;
/// Prints the average and median out-degree of `graph` to stdout.
///
/// The median is the element at index `len / 2` of the sorted degree list.
/// For a graph with no vertices a short notice is printed instead, since there
/// is no element to index.
///
/// NOTE(review): this file also imports `report_degrees` from `diskann`; this
/// local definition may be the superseded copy — confirm before keeping both.
///
/// # Panics
///
/// Panics if `read()` on any adjacency list returns an error (the result is unwrapped).
fn report_degrees(graph: &IndexGraph) {
    let vertex_count = graph.graph.len();
    if vertex_count == 0 {
        // `degrees[degrees.len() / 2]` would index an empty vector and panic.
        println!("graph is empty, no degree statistics");
        return;
    }

    let mut total_degree = 0;
    let mut degrees = Vec::with_capacity(vertex_count);
    for out_neighbours in graph.graph.iter() {
        let deg = out_neighbours.read().unwrap().len();
        total_degree += deg;
        degrees.push(deg);
    }
    // Sorted order makes the median a simple index lookup.
    degrees.sort_unstable();

    println!("average degree {}", (total_degree as f32) / (vertex_count as f32));
    println!("median degree {}", degrees[degrees.len() / 2]);
}
fn main() -> Result<()> {
let mut rng = fastrand::Rng::new();
@ -98,7 +85,7 @@ fn main() -> Result<()> {
length: original_ids.len()
};
let mut graph = IndexGraph::empty(original_ids.len(), config.r_cap);
let mut graph = IndexGraph::empty(original_ids.len(), config.r);
{
let _timer = Timer::new("project bipartite");
@ -126,14 +113,6 @@ fn main() -> Result<()> {
report_degrees(&graph);
{
let _timer = Timer::new("second pass");
config.alpha = common::index_config::SECOND_PASS_ALPHA;
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
}
report_degrees(&graph);
std::mem::drop(vecs);
let len = original_ids.len();

View File

@ -7,7 +7,6 @@ use std::fs;
use base64::Engine;
use argh::FromArgs;
use chrono::{TimeZone, Utc, DateTime};
use std::collections::VecDeque;
use itertools::Itertools;
use foldhash::{HashSet, HashSetExt};
use half::f16;
@ -23,9 +22,17 @@ use common::{PackedIndexEntry, IndexHeader};
#[argh(description="Query disk index")]
struct CLIArguments {
#[argh(positional)]
query_vector: String,
#[argh(positional)]
index_path: String
index_path: String,
#[argh(option, short='q', description="query vector in base64")]
query_vector_base64: Option<String>,
#[argh(option, short='f', description="file of FP16 query vectors")]
query_vector_file: Option<String>,
#[argh(switch, short='v', description="verbose")]
verbose: bool,
#[argh(option, short='n', description="stop at n queries")]
n: Option<usize>,
#[argh(switch, description="always use full-precision vectors (slow)")]
disable_pq: bool
}
fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> {
@ -56,7 +63,7 @@ struct IndexRef<'a> {
pq_code_size: usize
}
fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &DistanceLUT, index: IndexRef) -> Result<(usize, usize)> {
fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &DistanceLUT, index: IndexRef, disable_pq: bool) -> Result<(usize, usize)> {
scratch.visited.clear();
scratch.neighbour_buffer.clear();
scratch.visited_list.clear();
@ -88,24 +95,48 @@ fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preproc
}
let approx_scores = index.header.quantizer.asymmetric_dot_product(&query_preprocessed, &pq_codes);
for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
//let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
//let node = read_node(neighbour, index.data_file, index.header)?;
//let vector = bytemuck::cast_slice(&node.vector);
//let distance = fast_dot_noprefetch(query, &vector);
pq_cmps += 1;
scratch.neighbour_buffer.insert(neighbour, approx_scores[i]);
//scratch.neighbour_buffer.insert(neighbour, distance);
if disable_pq {
//let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
let node = read_node(neighbour, index.data_file, index.header)?;
let vector = bytemuck::cast_slice(&node.vector);
let distance = fast_dot_noprefetch(query, &vector);
scratch.neighbour_buffer.insert(neighbour, distance);
} else {
scratch.neighbour_buffer.insert(neighbour, approx_scores[i]);
pq_cmps += 1;
}
}
}
Ok((cmps, pq_cmps))
}
/// Prints median, mean, max, min and harmonic mean of `ranks` to stdout.
///
/// Every reported value is offset by one relative to the stored values
/// (presumably the inputs are zero-based ranks and the output is one-based —
/// confirm against callers). The harmonic mean is computed over `rank + 1`, so a
/// stored value of zero never causes a division by zero.
///
/// `ranks` is sorted in place (ascending) as a side effect.
///
/// An empty slice has no statistics; a short notice is printed instead of
/// indexing into it.
fn summary_stats(ranks: &mut [usize]) {
    if ranks.is_empty() {
        // `ranks[0]` / `ranks[len - 1]` below would panic and the means would be 0/0.
        println!("no ranks to summarise");
        return;
    }

    let count = ranks.len() as f64;
    let sum = ranks.iter().sum::<usize>();
    let mean = sum as f64 / count + 1.0;

    // Sorted order makes the median, minimum and maximum simple index lookups.
    ranks.sort_unstable();
    let median = ranks[ranks.len() / 2] + 1;

    // Mean of the reciprocals; its inverse is the harmonic mean.
    let mean_reciprocal = ranks.iter().map(|x| 1.0 / ((x + 1) as f64)).sum::<f64>() / count;

    println!(
        "median {} mean {} max {} min {} harmonic mean {}",
        median,
        mean,
        ranks[ranks.len() - 1] + 1,
        ranks[0] + 1,
        1.0 / mean_reciprocal
    );
}
fn main() -> Result<()> {
let args: CLIArguments = argh::from_env();
let query_vector: Vec<f16> = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(args.query_vector.as_bytes()).context("invalid base64")?);
let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
let mut queries = vec![];
if let Some(query_vector_base64) = args.query_vector_base64 {
let query_vector: Vec<f16> = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(query_vector_base64.as_bytes()).context("invalid base64")?);
queries.push(query_vector);
}
if let Some(query_vector_file) = args.query_vector_file {
let query_vectors = fs::read(query_vector_file)?;
queries.extend(common::chunk_fp16_buffer(&query_vectors).chunks(1152).map(|x| x.to_vec()).collect::<Vec<_>>());
}
if let Some(n) = args.n {
queries.truncate(n);
}
let index_path = PathBuf::from(&args.index_path);
let header: IndexHeader = rmp_serde::from_read(BufReader::new(fs::File::open(index_path.join("index.msgpack"))?))?;
@ -117,57 +148,90 @@ fn main() -> Result<()> {
MmapOptions::new().populate().map(&pq_codes_file)?
};
let query_preprocessed = header.quantizer.preprocess_query(&query_vector_fp32);
println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
// TODO slightly dubious
let selected_shard = header.shards.iter().position_max_by_key(|x| {
scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query_vector_fp32).unwrap())
}).unwrap();
let mut top_20_ranks_best_shard = vec![];
let mut top_rank_best_shard = vec![];
println!("best shard is {}", selected_shard);
for query_vector in queries.iter() {
let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
let query_preprocessed = header.quantizer.preprocess_query(&query_vector_fp32);
for shard in 0..header.shards.len() {
let selected_start = header.shards[shard].1;
// TODO slightly dubious
let selected_shard = header.shards.iter().position_max_by_key(|x| {
scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query_vector_fp32).unwrap())
}).unwrap();
let mut scratch = Scratch {
visited: HashSet::new(),
neighbour_buffer: NeighbourBuffer::new(5000),
neighbour_pre_buffer: Vec::new(),
visited_list: Vec::new()
};
//let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng);
let cmps = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
data_file: &mut data_file,
header: &header,
pq_codes: &pq_codes,
pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
})?;
println!("index scan {}: {:?} cmps", shard, cmps);
scratch.visited_list.sort_by_key(|x| -x.1);
for (id, distance, url, shards) in scratch.visited_list.iter().take(20) {
println!("index scan: {} {} {} {:?}", id, distance, url, shards);
if args.verbose {
println!("selected shard is {}", selected_shard);
}
println!("");
let mut matches = vec![];
// brute force scan
for i in 0..header.count {
let node = read_node(i, &mut data_file, &header)?;
//println!("{} {}", i, node.url);
let vector = bytemuck::cast_slice(&node.vector);
matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards));
}
matches.sort_unstable_by_key(|x| -x.1);
let mut matches = matches.into_iter().enumerate().map(|(i, (id, distance, url, shards))| (id, i)).collect::<Vec<_>>();
matches.sort_unstable();
/*for (id, distance, url, shards) in matches.iter().take(20) {
println!("brute force: {} {} {} {:?}", id, distance, url, shards);
}*/
let mut top_ranks = vec![usize::MAX; 20];
for shard in 0..header.shards.len() {
let selected_start = header.shards[shard].1;
let mut scratch = Scratch {
visited: HashSet::new(),
neighbour_buffer: NeighbourBuffer::new(300),
neighbour_pre_buffer: Vec::new(),
visited_list: Vec::new()
};
//let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng);
let cmps = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
data_file: &mut data_file,
header: &header,
pq_codes: &pq_codes,
pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
}, args.disable_pq)?;
if args.verbose {
println!("index scan {}: {:?} cmps", shard, cmps);
}
scratch.visited_list.sort_by_key(|x| -x.1);
for (i, (id, distance, url, shards)) in scratch.visited_list.iter().take(20).enumerate() {
if args.verbose {
println!("index scan: {} {} {} {:?}", id, distance, url, shards);
};
let found_id = match matches.binary_search(&(*id, 0)) {
Ok(pos) => pos,
Err(pos) => pos
};
if args.verbose {
println!("rank {}", matches[found_id].1);
};
top_ranks[i] = std::cmp::min(top_ranks[i], matches[found_id].1);
}
if args.verbose { println!("") }
}
top_rank_best_shard.push(top_ranks[0]);
top_20_ranks_best_shard.extend(top_ranks);
}
let mut matches = vec![];
// brute force scan
for i in 0..header.count {
let node = read_node(i, &mut data_file, &header)?;
//println!("{} {}", i, node.url);
let vector = bytemuck::cast_slice(&node.vector);
matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards));
}
matches.sort_by_key(|x| -x.1);
for (id, distance, url, shards) in matches.iter().take(20) {
println!("brute force: {} {} {} {:?}", id, distance, url, shards);
}
println!("ranks of top 20:");
summary_stats(&mut top_20_ranks_best_shard);
println!("ranks of top 1:");
summary_stats(&mut top_rank_best_shard);
Ok(())
}