diff --git a/.gitignore b/.gitignore
index c804eb0..4fb341b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,9 @@ node_modules/*
node_modules
*sqlite3*
thumbtemp
-mse-test-db-small
\ No newline at end of file
+mse-test-db-small
+clipfront2/static/bg*
+diskann/target
+*.bin
+*.msgpack
+*/flamegraph.svg
diff --git a/diskann/Cargo.lock b/diskann/Cargo.lock
new file mode 100644
index 0000000..f9a41da
--- /dev/null
+++ b/diskann/Cargo.lock
@@ -0,0 +1,748 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "anyhow"
+version = "1.0.93"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775"
+
+[[package]]
+name = "autocfg"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
+
+[[package]]
+name = "bitflags"
+version = "1.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
+
+[[package]]
+name = "bitflags"
+version = "2.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
+
+[[package]]
+name = "bitvec"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
+dependencies = [
+ "funty",
+ "radium",
+ "tap",
+ "wyz",
+]
+
+[[package]]
+name = "bytemuck"
+version = "1.20.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a"
+dependencies = [
+ "bytemuck_derive",
+]
+
+[[package]]
+name = "bytemuck_derive"
+version = "1.7.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "byteorder"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
+
+[[package]]
+name = "cc"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47"
+dependencies = [
+ "shlex",
+]
+
+[[package]]
+name = "cfg-if"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
+
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
+dependencies = [
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
+dependencies = [
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-utils"
+version = "0.8.20"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
+
+[[package]]
+name = "crossterm"
+version = "0.25.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e64e6c0fbe2c17357405f7c758c1ef960fce08bdfb2c03d88d2a18d7e09c4b67"
+dependencies = [
+ "bitflags 1.3.2",
+ "crossterm_winapi",
+ "libc",
+ "mio",
+ "parking_lot",
+ "signal-hook",
+ "signal-hook-mio",
+ "winapi",
+]
+
+[[package]]
+name = "crossterm_winapi"
+version = "0.9.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b"
+dependencies = [
+ "winapi",
+]
+
+[[package]]
+name = "crunchy"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
+
+[[package]]
+name = "diskann"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "bitvec",
+ "bytemuck",
+ "fastrand",
+ "foldhash",
+ "half",
+ "matrixmultiply",
+ "rayon",
+ "rmp-serde",
+ "serde",
+ "simsimd",
+ "tqdm",
+ "tracing",
+ "tracing-subscriber",
+]
+
+[[package]]
+name = "either"
+version = "1.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
+
+[[package]]
+name = "fastrand"
+version = "2.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4"
+
+[[package]]
+name = "foldhash"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
+
+[[package]]
+name = "funty"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
+
+[[package]]
+name = "half"
+version = "2.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
+dependencies = [
+ "bytemuck",
+ "cfg-if",
+ "crunchy",
+]
+
+[[package]]
+name = "lazy_static"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
+
+[[package]]
+name = "libc"
+version = "0.2.164"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
+
+[[package]]
+name = "lock_api"
+version = "0.4.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
+dependencies = [
+ "autocfg",
+ "scopeguard",
+]
+
+[[package]]
+name = "log"
+version = "0.4.22"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
+
+[[package]]
+name = "matrixmultiply"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
+dependencies = [
+ "autocfg",
+ "rawpointer",
+]
+
+[[package]]
+name = "mio"
+version = "0.8.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
+dependencies = [
+ "libc",
+ "log",
+ "wasi",
+ "windows-sys",
+]
+
+[[package]]
+name = "nu-ansi-term"
+version = "0.46.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
+dependencies = [
+ "overload",
+ "winapi",
+]
+
+[[package]]
+name = "num-traits"
+version = "0.2.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "once_cell"
+version = "1.20.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
+
+[[package]]
+name = "overload"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
+
+[[package]]
+name = "parking_lot"
+version = "0.12.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
+dependencies = [
+ "lock_api",
+ "parking_lot_core",
+]
+
+[[package]]
+name = "parking_lot_core"
+version = "0.9.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "redox_syscall",
+ "smallvec",
+ "windows-targets 0.52.6",
+]
+
+[[package]]
+name = "paste"
+version = "1.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
+
+[[package]]
+name = "pin-project-lite"
+version = "0.2.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff"
+
+[[package]]
+name = "proc-macro2"
+version = "1.0.89"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e"
+dependencies = [
+ "unicode-ident",
+]
+
+[[package]]
+name = "quote"
+version = "1.0.37"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
+dependencies = [
+ "proc-macro2",
+]
+
+[[package]]
+name = "radium"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
+
+[[package]]
+name = "rawpointer"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
+
+[[package]]
+name = "rayon"
+version = "1.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
+dependencies = [
+ "either",
+ "rayon-core",
+]
+
+[[package]]
+name = "rayon-core"
+version = "1.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
+dependencies = [
+ "crossbeam-deque",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "redox_syscall"
+version = "0.5.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
+dependencies = [
+ "bitflags 2.6.0",
+]
+
+[[package]]
+name = "rmp"
+version = "0.8.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4"
+dependencies = [
+ "byteorder",
+ "num-traits",
+ "paste",
+]
+
+[[package]]
+name = "rmp-serde"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db"
+dependencies = [
+ "byteorder",
+ "rmp",
+ "serde",
+]
+
+[[package]]
+name = "scopeguard"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
+
+[[package]]
+name = "serde"
+version = "1.0.215"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f"
+dependencies = [
+ "serde_derive",
+]
+
+[[package]]
+name = "serde_derive"
+version = "1.0.215"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "sharded-slab"
+version = "0.1.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
+dependencies = [
+ "lazy_static",
+]
+
+[[package]]
+name = "shlex"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
+
+[[package]]
+name = "signal-hook"
+version = "0.3.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801"
+dependencies = [
+ "libc",
+ "signal-hook-registry",
+]
+
+[[package]]
+name = "signal-hook-mio"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd"
+dependencies = [
+ "libc",
+ "mio",
+ "signal-hook",
+]
+
+[[package]]
+name = "signal-hook-registry"
+version = "1.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1"
+dependencies = [
+ "libc",
+]
+
+[[package]]
+name = "simsimd"
+version = "6.0.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "be2ad0164e13e58a994d3dd1ff57d44cee87c445708e3acea7ad4f03a47092ce"
+dependencies = [
+ "cc",
+]
+
+[[package]]
+name = "smallvec"
+version = "1.13.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
+
+[[package]]
+name = "syn"
+version = "2.0.87"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-ident",
+]
+
+[[package]]
+name = "tap"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
+
+[[package]]
+name = "thread_local"
+version = "1.1.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
+dependencies = [
+ "cfg-if",
+ "once_cell",
+]
+
+[[package]]
+name = "tqdm"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aa2d2932240205a99b65f15d9861992c95fbb8c9fb280b3a1f17a92db6dc611f"
+dependencies = [
+ "anyhow",
+ "crossterm",
+ "once_cell",
+]
+
+[[package]]
+name = "tracing"
+version = "0.1.40"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
+dependencies = [
+ "pin-project-lite",
+ "tracing-attributes",
+ "tracing-core",
+]
+
+[[package]]
+name = "tracing-attributes"
+version = "0.1.27"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "tracing-core"
+version = "0.1.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
+dependencies = [
+ "once_cell",
+ "valuable",
+]
+
+[[package]]
+name = "tracing-log"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
+dependencies = [
+ "log",
+ "once_cell",
+ "tracing-core",
+]
+
+[[package]]
+name = "tracing-subscriber"
+version = "0.3.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
+dependencies = [
+ "nu-ansi-term",
+ "sharded-slab",
+ "smallvec",
+ "thread_local",
+ "tracing-core",
+ "tracing-log",
+]
+
+[[package]]
+name = "unicode-ident"
+version = "1.0.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
+
+[[package]]
+name = "valuable"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
+
+[[package]]
+name = "wasi"
+version = "0.11.0+wasi-snapshot-preview1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
+
+[[package]]
+name = "winapi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
+dependencies = [
+ "winapi-i686-pc-windows-gnu",
+ "winapi-x86_64-pc-windows-gnu",
+]
+
+[[package]]
+name = "winapi-i686-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
+
+[[package]]
+name = "winapi-x86_64-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
+
+[[package]]
+name = "windows-sys"
+version = "0.48.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
+dependencies = [
+ "windows-targets 0.48.5",
+]
+
+[[package]]
+name = "windows-targets"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
+dependencies = [
+ "windows_aarch64_gnullvm 0.48.5",
+ "windows_aarch64_msvc 0.48.5",
+ "windows_i686_gnu 0.48.5",
+ "windows_i686_msvc 0.48.5",
+ "windows_x86_64_gnu 0.48.5",
+ "windows_x86_64_gnullvm 0.48.5",
+ "windows_x86_64_msvc 0.48.5",
+]
+
+[[package]]
+name = "windows-targets"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
+dependencies = [
+ "windows_aarch64_gnullvm 0.52.6",
+ "windows_aarch64_msvc 0.52.6",
+ "windows_i686_gnu 0.52.6",
+ "windows_i686_gnullvm",
+ "windows_i686_msvc 0.52.6",
+ "windows_x86_64_gnu 0.52.6",
+ "windows_x86_64_gnullvm 0.52.6",
+ "windows_x86_64_msvc 0.52.6",
+]
+
+[[package]]
+name = "windows_aarch64_gnullvm"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
+
+[[package]]
+name = "windows_aarch64_gnullvm"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
+
+[[package]]
+name = "windows_aarch64_msvc"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
+
+[[package]]
+name = "windows_aarch64_msvc"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
+
+[[package]]
+name = "windows_i686_gnu"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
+
+[[package]]
+name = "windows_i686_gnu"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
+
+[[package]]
+name = "windows_i686_gnullvm"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
+
+[[package]]
+name = "windows_i686_msvc"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
+
+[[package]]
+name = "windows_i686_msvc"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
+
+[[package]]
+name = "windows_x86_64_gnu"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
+
+[[package]]
+name = "windows_x86_64_gnu"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
+
+[[package]]
+name = "windows_x86_64_gnullvm"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
+
+[[package]]
+name = "windows_x86_64_gnullvm"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
+
+[[package]]
+name = "windows_x86_64_msvc"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
+
+[[package]]
+name = "windows_x86_64_msvc"
+version = "0.52.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
+
+[[package]]
+name = "wyz"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
+dependencies = [
+ "tap",
+]
diff --git a/diskann/Cargo.toml b/diskann/Cargo.toml
new file mode 100644
index 0000000..13701ad
--- /dev/null
+++ b/diskann/Cargo.toml
@@ -0,0 +1,28 @@
+[package]
+name = "diskann"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+half = { version = "2", features = ["bytemuck"] }
+fastrand = "2"
+tracing = "0.1"
+tracing-subscriber = "0.3"
+simsimd = "6"
+foldhash = "0.1"
+bitvec = "1"
+tqdm = "0.7"
+anyhow = "1"
+bytemuck = { version = "1", features = ["extern_crate_alloc"] }
+serde = { version = "1", features = ["derive"] }
+rmp-serde = "1"
+rayon = "1"
+matrixmultiply = "0.3"
+
+[lib]
+name = "diskann"
+path = "src/lib.rs"
+
+[[bin]]
+name = "diskann"
+path = "src/main.rs"
diff --git a/diskann/aopq_train.py b/diskann/aopq_train.py
new file mode 100644
index 0000000..03c840e
--- /dev/null
+++ b/diskann/aopq_train.py
@@ -0,0 +1,113 @@
+import numpy as np
+import msgpack
+import math
+import torch
+from torch import autograd
+import faiss
+import tqdm
+
+n_dims = 1152
+output_code_size = 64
+output_code_bits = 8
+output_codebook_size = 2**output_code_bits
+n_dims_per_code = n_dims // output_code_size
+dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
+queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
+device = "cpu"
+
+index = faiss.index_factory(n_dims, "HNSW32,SQfp16", faiss.METRIC_INNER_PRODUCT)
+index.train(queryset)
+index.add(queryset)
+print("index ready")
+
+T = 64
+
+nearby_query_indices = torch.zeros((dataset.shape[0], T), dtype=torch.int32)
+
+SEARCH_BATCH_SIZE = 1024
+
+for i in range(0, len(dataset), SEARCH_BATCH_SIZE):
+ res = index.search(dataset[i:i+SEARCH_BATCH_SIZE], T)
+ nearby_query_indices[i:i+SEARCH_BATCH_SIZE] = torch.tensor(res[1])
+
+print("query indices ready")
+
+def pq_assign(centroids, batch):
+ quantized = torch.zeros_like(batch)
+
+ # Assign to nearest centroid in each subspace
+ for dmin in range(0, n_dims, n_dims_per_code):
+ dmax = dmin + n_dims_per_code
+ similarities = torch.matmul(batch[:, dmin:dmax], centroids[:, dmin:dmax].T)
+ assignments = similarities.argmax(dim=1)
+ quantized[:, dmin:dmax] = centroids[assignments, dmin:dmax]
+
+ return quantized
+
+# OOD-DiskANN (https://arxiv.org/abs/2211.12850) uses a more complicated scheme because it uses L2 norm
+# We only care about inner product so our quantization error (wrt a query) is just abs(dot(query, centroid - vector))
+# Directly optimize for this (wrt top queries; it might actually be better to use a random sample instead?)
+def partition(vectors, centroids, projection, opt, queries, nearby_query_indices, k, max_iter=100, batch_size=4096):
+ n_vectors = len(vectors)
+ perm = torch.randperm(n_vectors, device=device)
+
+ t = tqdm.trange(max_iter)
+ for iter in t:
+ total_loss = 0
+ opt.zero_grad(set_to_none=True)
+
+ for i in range(0, n_vectors, batch_size):
+ loss = torch.tensor(0.0, device=device)
+ batch = vectors[i:i+batch_size] @ projection
+ quantized = pq_assign(centroids, batch)
+ residuals = batch - quantized
+
+ # for each index in our set of nearby queries
+ for j in range(0, nearby_query_indices.shape[1]):
+ queries_for_batch_j = queries[nearby_query_indices[i:i+batch_size, j]]
+ # minimize quantization error in direction of query, i.e. mean abs(dot(query, centroid - vector))
+ # PyTorch won't do batched dot products cleanly, to spite me. Do componentwise multiplication and reduce.
+ sg_errs = (queries_for_batch_j * residuals).sum(dim=-1)
+ loss += torch.mean(torch.abs(sg_errs))
+
+ total_loss += loss.detach().item()
+ loss.backward()
+
+ opt.step()
+
+ t.set_description(f"loss: {total_loss:.4f}")
+
+def random_ortho(dim):
+ h = torch.randn(dim, dim, device=device)
+ q, r = torch.linalg.qr(h)
+ return q
+
+# non-parametric OPQ algorithm (roughly)
+# https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/opq_tr.pdf
+projection = random_ortho(n_dims)
+vectors = torch.tensor(dataset, device=device)
+queries = torch.tensor(queryset, device=device)
+perm = torch.randperm(len(vectors), device=device)
+centroids = vectors[perm[:output_codebook_size]]
+centroids.requires_grad = True
+opt = torch.optim.Adam([centroids], lr=0.001)
+for i in range(30):
+ # update centroids to minimize query-aware quantization loss
+ partition(vectors, centroids, projection, opt, queries, nearby_query_indices, output_codebook_size, max_iter=8)
+ # compute new projection as R = VU^T from XY^T = USV^T (SVD)
+ # where X is dataset vectors, Y is quantized dataset vectors
+ with torch.no_grad():
+ y = pq_assign(centroids, vectors)
+ # paper uses D*N and not N*D in its descriptions for whatever reason (so we transpose when they don't)
+ u, s, vt = torch.linalg.svd(vectors.T @ y)
+ projection = vt.T @ u.T
+
+print("done")
+
+with open("opq.msgpack", "wb") as f:
+ msgpack.pack({
+ "centroids": centroids.detach().cpu().numpy().flatten().tolist(),
+ "transform": projection.cpu().numpy().flatten().tolist(),
+ "n_dims_per_code": n_dims_per_code,
+ "n_dims": n_dims
+ }, f)
diff --git a/diskann/flamegraph.svg b/diskann/flamegraph.svg
new file mode 100644
index 0000000..9b564c4
--- /dev/null
+++ b/diskann/flamegraph.svg
@@ -0,0 +1,491 @@
+
\ No newline at end of file
diff --git a/diskann/opq_test.py b/diskann/opq_test.py
new file mode 100644
index 0000000..a15dd99
--- /dev/null
+++ b/diskann/opq_test.py
@@ -0,0 +1,44 @@
+import numpy as np
+import msgpack
+import math
+import torch
+import faiss
+import tqdm
+
+n_dims = 1152
+output_code_size = 64
+output_code_bits = 8
+output_codebook_size = 2**output_code_bits
+n_dims_per_code = n_dims // output_code_size
+dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
+queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
+device = "cpu"
+
+def pq_assign(centroids, batch):
+ quantized = torch.zeros_like(batch)
+
+ # Assign to nearest centroid in each subspace
+ for dmin in range(0, n_dims, n_dims_per_code):
+ dmax = dmin + n_dims_per_code
+ similarities = torch.matmul(batch[:, dmin:dmax], centroids[:, dmin:dmax].T)
+ assignments = similarities.argmax(dim=1)
+ quantized[:, dmin:dmax] = centroids[assignments, dmin:dmax]
+
+ return quantized
+
+with open("opq.msgpack", "rb") as f:
+ data = msgpack.unpack(f)
+ centroids = torch.tensor(data["centroids"], device=device).reshape(2**output_code_bits, n_dims)
+ projection = torch.tensor(data["transform"], device=device).reshape(n_dims, n_dims)
+
+vectors = torch.tensor(dataset, device=device)
+queries = torch.tensor(queryset, device=device)
+
+sample_size = 64
+qsample = pq_assign(centroids, vectors[:sample_size] @ projection)
+print(qsample)
+print(vectors[:sample_size])
+exact_results = vectors[:sample_size] @ queries[0]
+approx_results = qsample @ (projection @ queries[0])
+print(np.argsort(approx_results))
+print(np.argsort(exact_results))
diff --git a/diskann/rabitq.py b/diskann/rabitq.py
new file mode 100644
index 0000000..b606200
--- /dev/null
+++ b/diskann/rabitq.py
@@ -0,0 +1,68 @@
+# https://arxiv.org/pdf/2405.12497
+
+import numpy as np
+import msgpack
+import math
+import tqdm
+
+n_dims = 1152
+output_dims = 64*8
+scale = 1 / math.sqrt(n_dims)
+dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
+queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
+mean = np.mean(dataset, axis=0)
+
+centered_dataset = dataset - mean
+norms = np.linalg.norm(centered_dataset, axis=1)
+centered_dataset = centered_dataset / norms[:, np.newaxis]
+print(centered_dataset)
+
+sample = centered_dataset[:64]
+
+def random_ortho(dim):
+ h = np.random.randn(dim, dim)
+ q, r = np.linalg.qr(h)
+ return q
+
+p = random_ortho(n_dims) # algorithm only uses the inverse of P, so just sample that directly
+p = p[:output_dims, :]
+
+def quantize(datavecs):
+ xs = (p @ datavecs.T).T
+ quantized = xs > 0
+ dequantized = scale * (2 * quantized - 1)
+ dots = np.sum(dequantized * xs, axis=1) #
+ return quantized, dots
+
+qsample, dots = quantize(sample)
+print(qsample.sum(axis=1).mean())
+#print(dots)
+#print(dots.mean())
+
+def approx_dot(quantized_samples, dots, query):
+ mean_to_query = np.dot(mean, query)
+ print(mean_to_query)
+ dequantized = scale * (2 * quantized_samples - 1)
+ query_transformed = p @ query
+ o_bar_dot_q = np.sum(dequantized * query_transformed, axis=1)
+ return norms[:sample.shape[0]] * o_bar_dot_q * dots + mean_to_query
+
+print(norms)
+approx_results = approx_dot(qsample, dots, queryset[0])
+exact_results = sample @ queryset[0]
+
+for x in zip(approx_results, exact_results):
+ print(*x)
+
+print(*[ f"{x:.2f}" for x in (approx_results - exact_results) / np.abs(exact_results).mean() ])
+
+print(np.argsort(approx_results))
+print(np.argsort(exact_results))
+
+with open("rabitq.msgpack", "wb") as f:
+ msgpack.pack({
+ "mean": mean.flatten().tolist(),
+ "transform": p.flatten().tolist(),
+ "output_dims": output_dims,
+ "n_dims": n_dims
+ }, f)
diff --git a/diskann/scalar_quantize.py b/diskann/scalar_quantize.py
new file mode 100644
index 0000000..818610b
--- /dev/null
+++ b/diskann/scalar_quantize.py
@@ -0,0 +1,167 @@
import numpy as np
import msgpack
import math

# Per-dimension affine u8 quantization for 1152-d embeddings; "buckets" group
# dimensions that share one quantization scale (currently one per dimension).
n_dims = 1152
n_buckets = n_dims
#n_buckets = n_dims // 2 # we now have one quant scale per pair of components
#pair_separation = 16 # for efficient dot product computation, we need to have the second element of a pair exactly chunk_size after the first
n_dims_per_bucket = n_dims // n_buckets
data = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32) # sorry

# Fraction of outliers clipped at each tail when computing per-dimension ranges.
CUTOFF = 1e-3 / 2

print("computing quantiles")
smin = np.quantile(data, CUTOFF, axis=0)
smax = np.quantile(data, 1 - CUTOFF, axis=0)

# naive O(n²) greedy algorithm
# probably overbuilt for the 2-components-per-bucket case but I'm not getting rid of it
# NOTE(review): currently dead code — `order` below is just arange().
def assign_buckets():
    import random
    intervals = list(enumerate(zip(smin, smax)))
    random.shuffle(intervals)
    # Seed each bucket with one random interval.
    buckets = [ [ intervals.pop() ] for _ in range(n_buckets) ]
    def bucket_cost(bucket):
        # Cost = total slack between each member interval and the bucket's
        # covering interval.
        bmin = min(cmin for id, (cmin, cmax) in bucket)
        bmax = max(cmax for id, (cmin, cmax) in bucket)
        #print("MIN", bmin, "MAX", bmax)
        return sum(abs(cmin - bmin) + abs(cmax - bmax) for id, (cmin, cmax) in bucket)
    while len(intervals):
        # Round-robin: each bucket takes the interval that raises its cost least.
        for bucket in buckets:
            def new_interval_cost(interval):
                # `interval` is an (index, (id, (cmin, cmax))) pair from enumerate().
                return bucket_cost(bucket + [interval[1]])
            i, interval = min(enumerate(intervals), key=new_interval_cost)
            bucket.append(intervals.pop(i))
    return buckets
+
ranges = smax - smin
# TODO: it is possible to do better assignment to buckets
#order = np.argsort(ranges)
print("bucket assignment")
order = np.arange(n_dims) # np.concatenate(np.stack([ [ id for id, (cmin, cmax) in bucket ] for bucket in assign_buckets() ]))

# Covering interval statistics (min/range/centre/absmax) for each bucket.
bucket_ranges = []
bucket_centres = []
bucket_absmax = []
bucket_gmins = []

for bucket_min in range(0, n_dims, n_dims_per_bucket):
    bucket_max = bucket_min + n_dims_per_bucket
    indices = order[bucket_min:bucket_max]
    gmin = float(np.min(smin[indices]))
    gmax = float(np.max(smax[indices]))
    bucket_range = gmax - gmin
    bucket_centre = (gmax + gmin) / 2
    bucket_gmins.append(gmin)
    bucket_ranges.append(bucket_range)
    bucket_centres.append(bucket_centre)
    bucket_absmax.append(max(abs(gmin), abs(gmax)))

print("determining scales")
scales = [] # multiply by float and convert to quantize
offsets = []
q_offsets = [] # int16 value to add at dot time
q_scales = [] # rescales channel up at dot time; must be proportional(ish) to square of scale factor but NOT cause overflow in accumulation or PMULLW
scale_factor_bound = float("inf")
for bucket in range(n_buckets):
    # u8 quantization: 255 steps across the clipped range of this bucket.
    step_size = bucket_ranges[bucket] / 255
    scales.append(1 / step_size)
    q_offset = int(bucket_gmins[bucket] / step_size)
    q_offsets.append(q_offset)
    # Worst-case |(u8+offset)*(u8+offset)| summed over one bucket; bounds the
    # q_scale so the i32 accumulator cannot overflow (extra /2 safety margin).
    nsfb = (2**31 - 1) / (n_dims_per_bucket * abs((255**2) + 2 * q_offset * 255 + q_offset ** 2)) / 2
    # we are bounded both by overflow in accumulation and PMULLW (u8 plus offset times scale factor)
    scale_factor_bound = min(scale_factor_bound, nsfb, (2**15 - 1) // (q_offset + 255))
    offsets.append(bucket_gmins[bucket])

for bucket in range(n_buckets):
    # NOTE(review): `sfb` is loop-invariant and could be hoisted.
    sfb = scale_factor_bound / max(map(lambda x: x ** 2, bucket_ranges))
    sf = (bucket_ranges[bucket]) ** 2 * sfb
    q_scales.append(int(sf))

print(bucket_ranges, bucket_centres, bucket_absmax)
print(scales, offsets, q_offsets, q_scales)
+
+"""
+interleave = np.concatenate([
+ np.arange(0, n_dims, n_dims_per_bucket) + a
+ for a in range(n_dims_per_bucket)
+])
+"""
+
+"""
+interleave = np.arange(0, n_dims)
+for base in range(0, n_dims, 2 * pair_separation):
+ interleave[base:base + pair_separation] = np.arange(base, base + 2 * pair_separation, 2)
+ interleave[base + pair_separation:base + 2 * pair_separation] = np.arange(base + 1, base + 2 * pair_separation + 1, 2)
+"""
+
+#print(bucket_ranges, bucket_centres, order[interleave])
+#print(ranges[order][interleave].tolist())
+#print(ranges.tolist())
+
+with open("quantizer.msgpack", "wb") as f:
+ msgpack.pack({
+ "permutation": order.tolist(),
+ "offsets": offsets,
+ "scales": scales,
+ "q_offsets": q_offsets,
+ "q_scales": q_scales
+ }, f)
+
# NOTE(review): the helpers below index `order[interleave]`, but every
# definition of `interleave` above is commented out or inside a string literal,
# so calling any of them raises NameError as written.

def rquantize(vec):
    # Quantize one float vector to u8 using the per-bucket affine scales.
    out = np.zeros(len(vec), dtype=np.uint8)
    for i, p in enumerate(order[interleave]):
        bucket = p % n_buckets
        raw = vec[i]
        raw = (raw - offsets[bucket]) * scales[bucket]
        raw = min(max(raw, 0.0), 255.0)  # clip to the u8 range
        out[p] = round(raw)
    return out

def rdquantize(bytes):
    # Inverse of rquantize (up to rounding/clipping error).
    vec = np.zeros(n_dims, dtype=np.float32)
    for i, p in enumerate(order[interleave]):
        bucket = p % n_buckets
        raw = float(bytes[p])
        vec[i] = raw / scales[bucket] + offsets[bucket]
    return vec

def rdot(x, y):
    # Integer dot product of two quantized vectors, mimicking the SIMD kernel:
    # widen u8 to i16, add per-bucket offsets, scale one side, accumulate in i32.
    xq_offsets = np.array(q_offsets, dtype=np.int16)
    xq_scales = np.array(q_scales, dtype=np.int16)
    assert x.shape == y.shape
    assert x.dtype == np.uint8 == y.dtype
    acc = 0
    for i in range(0, len(x), n_buckets):
        x1 = x[i:i+n_buckets].astype(np.int16) + xq_offsets
        y1 = y[i:i+n_buckets].astype(np.int16) + xq_offsets
        x1 *= xq_scales
        acc += np.dot(x1.astype(np.int32), y1.astype(np.int32))
    return acc

def cmp(i, j):
    # Ratio of the exact to the quantized dot product for two dataset rows.
    return np.dot(data[i], data[j]) / rdot(rquantize(data[i]), rquantize(data[j]))

def rdot_cmp(a, b):
    # Like rdot, but prints per-chunk comparisons against the exact dot product.
    x = rquantize(a)
    y = rquantize(b)
    a = a[order[interleave]]
    b = b[order[interleave]]
    xq_offsets = np.array(q_offsets, dtype=np.int16)
    xq_scales = np.array(q_scales, dtype=np.int16)
    assert x.shape == y.shape
    assert x.dtype == np.uint8 == y.dtype
    acc = 0
    for i in range(0, len(x), n_buckets):
        x1 = x[i:i+n_buckets].astype(np.int16) + xq_offsets
        y1 = y[i:i+n_buckets].astype(np.int16) + xq_offsets
        x1 *= xq_scales
        component = np.dot(x1.astype(np.int32), y1.astype(np.int32))
        a1 = a[i:i+n_buckets]
        b1 = b[i:i+n_buckets]
        component_exact = np.dot(a1, b1)
        print(x1, a1, sep="\n")
        print(component, component_exact, component / component_exact)
        acc += component
    return acc
diff --git a/diskann/src/lib.rs b/diskann/src/lib.rs
new file mode 100644
index 0000000..bd2a866
--- /dev/null
+++ b/diskann/src/lib.rs
@@ -0,0 +1,389 @@
+#![feature(pointer_is_aligned_to)]
+#![feature(test)]
+
+extern crate test;
+
+use foldhash::{HashSet, HashMap, HashMapExt, HashSetExt};
+use fastrand::Rng;
+use rayon::prelude::*;
+use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, Mutex};
+
+pub mod vector;
+use vector::{dot, fast_dot, fast_dot_noprefetch, to_svector, VectorRef, SVector, VectorList};
+
+// ParlayANN improves parallelism by not using locks like this and instead using smarter batch operations
+// but I don't have enough cores that it matters
+#[derive(Debug)]
+pub struct IndexGraph {
+ pub graph: Vec>>
+}
+
+impl IndexGraph {
+ pub fn random_r_regular(rng: &mut Rng, n: usize, r: usize, capacity: usize) -> Self {
+ let mut graph = Vec::with_capacity(n);
+ for _ in 0..n {
+ let mut adjacency = Vec::with_capacity(capacity);
+ for _ in 0..r {
+ adjacency.push(rng.u32(0..(n as u32)));
+ }
+ graph.push(RwLock::new(adjacency));
+ }
+ IndexGraph {
+ graph
+ }
+ }
+
+ pub fn empty(n: usize, capacity: usize) -> IndexGraph {
+ let mut graph = Vec::with_capacity(n);
+ for _ in 0..n {
+ graph.push(RwLock::new(Vec::with_capacity(capacity)));
+ }
+ IndexGraph {
+ graph
+ }
+ }
+
+ fn out_neighbours(&self, pt: u32) -> RwLockReadGuard> {
+ self.graph[pt as usize].read().unwrap()
+ }
+
+ fn out_neighbours_mut(&self, pt: u32) -> RwLockWriteGuard> {
+ self.graph[pt as usize].write().unwrap()
+ }
+}
+
/// Parameters for Vamana graph construction and search.
#[derive(Clone, Copy, Debug)]
pub struct IndexBuildConfig {
    /// Target out-degree after pruning.
    pub r: usize,
    /// Degree cap at which a vertex's list is re-pruned during build.
    pub r_cap: usize,
    /// Search list (beam) width.
    pub l: usize,
    /// Maximum number of candidates robust_prune will consider.
    pub maxc: usize,
    /// Pruning slack α in 16.16 fixed point (65536 == 1.0; see robust_prune).
    pub alpha: i64
}
+
+
+fn centroid(vecs: &VectorList) -> SVector {
+ let mut centroid = SVector::zero(vecs.d_emb);
+
+ for (i, vec) in vecs.iter().enumerate() {
+ let weight = 1.0 / (i + 1) as f32;
+ centroid += (to_svector(vec) - ¢roid) * weight;
+ }
+
+ centroid
+}
+
+pub fn medioid(vecs: &VectorList) -> u32 {
+ let centroid = centroid(vecs).half();
+ vecs.iter().map(|vec| dot(vec, &*centroid)).enumerate().max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()).unwrap().0 as u32
+}
+
// neighbours list sorted by score descending
// TODO: this may actually be an awful datastructure
/// Fixed-capacity best-first candidate list used as the beam in greedy search.
/// (Element types restored: they were stripped in transit.)
#[derive(Clone, Debug)]
pub struct NeighbourBuffer {
    /// Candidate ids, sorted by score descending (parallel to `scores`).
    pub ids: Vec<u32>,
    scores: Vec<i64>,
    visited: Vec<bool>,
    /// Index of the first unvisited slot, if any.
    next_unvisited: Option<u32>,
    /// Maximum number of entries retained (search list size L).
    size: usize
}

impl NeighbourBuffer {
    pub fn new(size: usize) -> Self {
        NeighbourBuffer {
            // One extra slot so insert can splice before truncating back to size.
            ids: Vec::with_capacity(size + 1),
            scores: Vec::with_capacity(size + 1),
            visited: Vec::with_capacity(size + 1), //bitvec::vec::BitVec::with_capacity(size),
            next_unvisited: None,
            size
        }
    }

    /// Return the best unvisited candidate's id, marking it visited and
    /// advancing the cursor past any already-visited run that follows.
    pub fn next_unvisited(&mut self) -> Option<u32> {
        //println!("next_unvisited: {:?}", self);
        let mut cur = self.next_unvisited? as usize;
        let old_cur = cur;
        self.visited[cur] = true;
        while cur < self.len() && self.visited[cur] {
            cur += 1;
        }
        if cur == self.len() {
            self.next_unvisited = None;
        } else {
            self.next_unvisited = Some(cur as u32);
        }
        Some(self.ids[old_cur])
    }

    pub fn len(&self) -> usize {
        self.ids.len()
    }

    pub fn cap(&self) -> usize {
        self.size
    }

    /// Insert `id` with fixed-point `score`, keeping the arrays sorted by
    /// score descending and capped at `size` entries.
    pub fn insert(&mut self, id: u32, score: i64) {
        // Full and strictly worse than the current tail: drop immediately.
        if self.len() == self.cap() && self.scores[self.len() - 1] > score {
            return;
        }

        // `scores` is descending, so the comparator is target-vs-element.
        let loc = match self.scores.binary_search_by(|x| score.partial_cmp(&x).unwrap()) {
            Ok(loc) => loc,
            Err(loc) => loc
        };

        // Cheap duplicate check: only catches an equal-scored duplicate at loc.
        if self.ids.get(loc) == Some(&id) {
            return;
        }

        // slightly inefficient but we avoid unsafe code
        self.ids.insert(loc, id);
        self.scores.insert(loc, score);
        self.visited.insert(loc, false);
        self.ids.truncate(self.size);
        self.scores.truncate(self.size);
        self.visited.truncate(self.size);

        // NOTE(review): this can move the cursor *forward* past older
        // unvisited entries when loc > next_unvisited; search still
        // terminates, it just may expand fewer candidates — confirm intent.
        self.next_unvisited = Some(loc as u32);
    }

    pub fn clear(&mut self) {
        self.ids.clear();
        self.scores.clear();
        self.visited.clear();
        self.next_unvisited = None;
    }
}
+
+pub struct Scratch {
+ visited: HashSet,
+ pub neighbour_buffer: NeighbourBuffer,
+ neighbour_pre_buffer: Vec,
+ visited_list: Vec<(u32, i64)>,
+ robust_prune_scratch_buffer: Vec<(usize, u32)>
+}
+
+impl Scratch {
+ pub fn new(IndexBuildConfig { l, r, maxc, .. }: IndexBuildConfig) -> Self {
+ Scratch {
+ visited: HashSet::with_capacity(l * 8),
+ neighbour_buffer: NeighbourBuffer::new(l),
+ neighbour_pre_buffer: Vec::with_capacity(r),
+ visited_list: Vec::with_capacity(l * 8),
+ robust_prune_scratch_buffer: Vec::with_capacity(r)
+ }
+ }
+}
+
+pub struct GreedySearchCounters {
+ pub distances: usize
+}
+
// Algorithm 1 from the DiskANN paper
// We support the dot product metric only, so we want to keep things with the HIGHEST dot product
pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs: &VectorList, graph: &IndexGraph, config: IndexBuildConfig) -> GreedySearchCounters {
    scratch.visited.clear();
    scratch.neighbour_buffer.clear();
    scratch.visited_list.clear();

    scratch.neighbour_buffer.insert(start, fast_dot_noprefetch(query, &vecs[start as usize]));
    scratch.visited.insert(start);

    let mut counters = GreedySearchCounters { distances: 0 };

    // Repeatedly expand the best unvisited candidate until the beam is exhausted.
    while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
        //println!("pt {} {:?}", pt, graph.out_neighbours(pt));
        scratch.neighbour_pre_buffer.clear();
        // Gather not-yet-seen neighbours first so each score can prefetch the next.
        for &neighbour in graph.out_neighbours(pt).iter() {
            if scratch.visited.insert(neighbour) {
                scratch.neighbour_pre_buffer.push(neighbour);
            }
        }
        for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
            // Wraps to the first element on the last iteration (harmless re-prefetch).
            let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
            let distance = fast_dot(query, &vecs[neighbour as usize], &vecs[next_neighbour as usize]);
            counters.distances += 1;
            scratch.neighbour_buffer.insert(neighbour, distance);
            // visited_list keeps every scored candidate for later robust_prune.
            scratch.visited_list.push((neighbour, distance));
        }
    }

    counters
}
+
// (candidate id, fixed-point score) pairs accumulated during search/pruning.
type CandidateList = Vec<(u32, i64)>;

// Score `point` against each existing neighbour in `neigh`, appending the
// results to `candidates`. The `(i + 1) % len` element only selects the
// vector to prefetch for the next iteration; `config` is currently unused.
fn merge_existing_neighbours(candidates: &mut CandidateList, point: u32, neigh: &[u32], vecs: &VectorList, config: IndexBuildConfig) {
    let p_vec = &vecs[point as usize];
    for (i, &n) in neigh.iter().enumerate() {
        let dot = fast_dot(p_vec, &vecs[n as usize], &vecs[neigh[(i + 1) % neigh.len() as usize] as usize]);
        candidates.push((n, dot));
    }
}
+
+// "Robust prune" algorithm, kind of
+// The algorithm in the paper does not actually match the code as implemented in microsoft/DiskANN
+// and that's slightly different from the one in ParlayANN for no reason
+// This is closer to ParlayANN
+fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec, vecs: &VectorList, config: IndexBuildConfig) {
+ neigh.clear();
+
+ let candidates = &mut scratch.visited_list;
+
+ // distance low to high = score high to low
+ candidates.sort_unstable_by_key(|&(_id, score)| -score);
+ candidates.truncate(config.maxc);
+
+ let mut candidate_index = 0;
+ while neigh.len() < config.r && candidate_index < candidates.len() {
+ let p_star = candidates[candidate_index].0;
+ candidate_index += 1;
+ if p_star == p || p_star == u32::MAX {
+ continue;
+ }
+
+ neigh.push(p_star);
+
+ scratch.robust_prune_scratch_buffer.clear();
+
+ // mark remaining candidates as not-to-be-used if "not much better than" current candidate
+ for i in (candidate_index+1)..candidates.len() {
+ let p_prime = candidates[i].0;
+ if p_prime != u32::MAX {
+ scratch.robust_prune_scratch_buffer.push((i, p_prime));
+ }
+ }
+
+ for (i, &(ci, p_prime)) in scratch.robust_prune_scratch_buffer.iter().enumerate() {
+ let next_vec = &vecs[scratch.robust_prune_scratch_buffer[(i + 1) % scratch.robust_prune_scratch_buffer.len()].0 as usize];
+ let p_star_prime_score = fast_dot(&vecs[p_prime as usize], &vecs[p_star as usize], next_vec);
+ let p_prime_p_score = candidates[ci].1;
+ let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16;
+
+ if alpha_times_p_star_prime_score >= p_prime_p_score {
+ candidates[ci].0 = u32::MAX;
+ }
+ }
+ }
+}
+
+pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &VectorList, config: IndexBuildConfig) {
+ assert!(vecs.len() < u32::MAX as usize);
+
+ let mut sigmas: Vec = (0..(vecs.len() as u32)).collect();
+ rng.shuffle(&mut sigmas);
+
+ let rng = Mutex::new(rng.fork());
+
+ //let scratch = &mut Scratch::new(config);
+ //let mut rng = rng.lock().unwrap();
+ sigmas.into_par_iter().for_each_init(|| (Scratch::new(config), rng.lock().unwrap().fork()), |(scratch, rng), sigma_i| {
+ //sigmas.into_iter().for_each(|sigma_i| {
+ greedy_search(scratch, medioid, &vecs[sigma_i as usize], vecs, &graph, config);
+
+ {
+ let n = graph.out_neighbours(sigma_i);
+ merge_existing_neighbours(&mut scratch.visited_list, sigma_i, &*n, vecs, config);
+ }
+
+ {
+ let mut n = graph.out_neighbours_mut(sigma_i);
+ robust_prune(scratch, sigma_i, &mut *n, vecs, config);
+ }
+
+ let neighbours = graph.out_neighbours(sigma_i).to_owned();
+ for neighbour in neighbours {
+ let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour);
+ // To cut down pruning time slightly, allow accumulating more neighbours than usual limit
+ if neighbour_neighbours.len() == config.r_cap {
+ let mut n = neighbour_neighbours.to_vec();
+ scratch.visited_list.clear();
+ merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config);
+ merge_existing_neighbours(&mut scratch.visited_list, neighbour, &vec![sigma_i], vecs, config);
+ robust_prune(scratch, neighbour, &mut n, vecs, config);
+ } else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r_cap {
+ neighbour_neighbours.push(sigma_i);
+ }
+ }
+ });
+}
+
+// RoarGraph's AcquireNeighbours algorithm is actually almost identical to Vamana/DiskANN's RobustPrune, but with fixed α = 1.0.
+// We replace Vamana's random initialization of the graph with Neighbourhood-Aware Projection from RoarGraph - there's no way to use a large enough
+// query set that I would be confident in using *only* RoarGraph's algorithm
+pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec>, query_knns_bwd: &Vec>, config: IndexBuildConfig, vecs: &VectorList) {
+ let mut sigmas: Vec = (0..(graph.graph.len() as u32)).collect();
+ rng.shuffle(&mut sigmas);
+
+ // Iterate through graph vertices in a random order
+ let rng = Mutex::new(rng.fork());
+ sigmas.into_par_iter().for_each_init(|| (rng.lock().unwrap().fork(), Scratch::new(config)), |(rng, scratch), sigma_i| {
+ scratch.visited.clear();
+ scratch.visited_list.clear();
+ scratch.neighbour_pre_buffer.clear();
+ for &query_neighbour in query_knns[sigma_i as usize].iter() {
+ for &projected_neighbour in query_knns_bwd[query_neighbour as usize].iter() {
+ if scratch.visited.insert(projected_neighbour) {
+ scratch.neighbour_pre_buffer.push(projected_neighbour);
+ }
+ }
+ }
+ rng.shuffle(&mut scratch.neighbour_pre_buffer);
+ scratch.neighbour_pre_buffer.truncate(config.maxc * 2);
+ for (i, &projected_neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
+ let score = fast_dot(&vecs[sigma_i as usize], &vecs[projected_neighbour as usize], &vecs[scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()] as usize]);
+ scratch.visited_list.push((projected_neighbour, score));
+ }
+ let mut neighbours = graph.out_neighbours_mut(sigma_i);
+ robust_prune(scratch, sigma_i, &mut *neighbours, vecs, config);
+ })
+}
+
+pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec>, query_knns_bwd: Vec>, config: IndexBuildConfig) {
+ let mut sigmas: Vec = (0..(graph.graph.len() as u32)).collect();
+ rng.shuffle(&mut sigmas);
+
+ // Iterate through graph vertices in a random order
+ let rng = Mutex::new(rng.fork());
+ sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| {
+ let mut neighbours = graph.out_neighbours_mut(sigma_i);
+ let mut i = 0;
+ while neighbours.len() < config.r_cap && i < 100 {
+ let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap();
+ let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap();
+ if !neighbours.contains(&projected_neighbour) {
+ neighbours.push(projected_neighbour);
+ }
+ i += 1;
+ }
+ })
+}
+
// Pad every adjacency list up to degree `r` with distinct random vertices.
// NOTE(review): the inner while-loop does not terminate if `r` exceeds the
// number of distinct vertices available.
pub fn random_fill_graph(rng: &mut Rng, graph: &mut IndexGraph, r: usize) {
    let rng = Mutex::new(rng.fork());
    (0..graph.graph.len() as u32).into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, i| {
        let mut neighbours = graph.out_neighbours_mut(i);
        while neighbours.len() < r {
            let next = rng.u32(0..(graph.graph.len() as u32));
            if !neighbours.contains(&next) {
                neighbours.push(next);
            }
        }
    });
}
+
/// Scope timer: stores a label and the start instant, and prints the elapsed
/// wall-clock seconds (two decimals) when dropped.
pub struct Timer(&'static str, std::time::Instant);

impl Timer {
    /// Start timing now; `name` is printed on drop.
    pub fn new(name: &'static str) -> Self {
        Self(name, std::time::Instant::now())
    }
}

impl Drop for Timer {
    fn drop(&mut self) {
        let Timer(label, started) = self;
        println!("{}: {:.2}s", label, started.elapsed().as_secs_f32());
    }
}
diff --git a/diskann/src/main.rs b/diskann/src/main.rs
new file mode 100644
index 0000000..d687424
--- /dev/null
+++ b/diskann/src/main.rs
@@ -0,0 +1,121 @@
+#![feature(test)]
+#![feature(pointer_is_aligned_to)]
+
+extern crate test;
+
+use std::{io::Read, time::Instant};
+use anyhow::Result;
+use half::f16;
+
+use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, dot, VectorList, self}, Timer};
+use simsimd::SpatialSimilarity;
+
+const D_EMB: usize = 1152;
+
+fn load_file(path: &str, trunc: Option) -> Result {
+ let mut input = std::fs::File::open(path)?;
+ let mut buf = Vec::new();
+ input.read_to_end(&mut buf)?;
+ // TODO: this is not particularly efficient
+ let f16s = bytemuck::cast_slice::<_, f16>(&buf)[0..trunc.unwrap_or(buf.len()/2)].iter().copied().collect();
+ Ok(VectorList::from_f16s(f16s, D_EMB))
+}
+
+const PQ_TEST_SIZE: usize = 1000;
+
+fn main() -> Result<()> {
+ tracing_subscriber::fmt::init();
+
+ {
+ let file = std::fs::File::open("opq.msgpack")?;
+ let codec: vector::ProductQuantizer = rmp_serde::from_read(file)?;
+ let input = load_file("embeddings.bin", Some(D_EMB * PQ_TEST_SIZE))?.data.into_iter().map(|a| a.to_f32()).collect::>();
+ let codes = codec.quantize_batch(&input);
+ println!("{:?}", codes);
+ let raw_query = load_file("query.bin", Some(D_EMB))?.data.into_iter().map(|a| a.to_f32()).collect::>();
+ let query = codec.preprocess_query(&raw_query);
+ let mut real_scores = vec![];
+ for i in 0..PQ_TEST_SIZE {
+ real_scores.push(SpatialSimilarity::dot(&raw_query, &input[i * D_EMB .. (i + 1) * D_EMB]).unwrap() as f32);
+ }
+ let pq_scores = codec.asymmetric_dot_product(&query, &codes);
+ for (x, y) in real_scores.iter().zip(pq_scores.iter()) {
+ println!("{} {} {} {}", x, y, x - y, (x - y) / x);
+ }
+ }
+
+ let mut rng = fastrand::Rng::with_seed(1);
+
+ let n = 100000;
+ let vecs = {
+ let _timer = Timer::new("loaded vectors");
+
+ &load_file("embeddings.bin", Some(D_EMB * n))?
+ };
+
+ let (graph, medioid) = {
+ let _timer = Timer::new("index built");
+
+ let mut config = IndexBuildConfig {
+ r: 64,
+ r_cap: 80,
+ l: 128,
+ maxc: 750,
+ alpha: 65536,
+ };
+
+ let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap);
+
+ let medioid = medioid(&vecs);
+
+ build_graph(&mut rng, &mut graph, medioid, &vecs, config);
+ config.alpha = 58000;
+ build_graph(&mut rng, &mut graph, medioid, &vecs, config);
+
+ (graph, medioid)
+ };
+
+ let mut edge_ctr = 0;
+
+ for adjlist in graph.graph.iter() {
+ edge_ctr += adjlist.read().unwrap().len();
+ }
+
+ println!("average degree: {}", edge_ctr as f32 / graph.graph.len() as f32);
+
+ let time = Instant::now();
+ let mut recall = 0;
+ let mut cmps_ctr = 0;
+ let mut cmps = vec![];
+
+ let mut config = IndexBuildConfig {
+ r: 64,
+ r_cap: 64,
+ l: 50,
+ alpha: 65536,
+ maxc: 0,
+ };
+
+ let mut scratch = Scratch::new(config);
+
+ for (i, vec) in tqdm::tqdm(vecs.iter().enumerate()) {
+ let ctr = greedy_search(&mut scratch, medioid, &vec, &vecs, &graph, config);
+ cmps_ctr += ctr.distances;
+ cmps.push(ctr.distances);
+ if scratch.neighbour_buffer.ids[0] == (i as u32) {
+ recall += 1;
+ }
+ }
+
+ cmps.sort();
+
+ let end = time.elapsed();
+
+ println!("recall@1: {} ({}/{})", recall as f32 / n as f32, recall, n);
+ println!("cmps: {} ({}/{})", cmps_ctr as f32 / n as f32, cmps_ctr, n);
+ println!("median comparisons: {}", cmps[cmps.len() / 2]);
+ //println!("brute force recall@1: {} ({}/{})", brute_force_recall as f32 / brute_force_queries as f32, brute_force_recall, brute_force_queries);
+ println!("{} QPS", n as f32 / end.as_secs_f32());
+
+ Ok(())
+}
diff --git a/diskann/src/vector.rs b/diskann/src/vector.rs
new file mode 100644
index 0000000..2283976
--- /dev/null
+++ b/diskann/src/vector.rs
@@ -0,0 +1,449 @@
+use core::f32;
+
+use half::f16;
+use simsimd::SpatialSimilarity;
+use fastrand::Rng;
+use serde::{Serialize, Deserialize};
+use tracing_subscriber::field::RecordFields;
+
+#[derive(Debug, Clone)]
+pub struct Vector(Vec);
+#[derive(Debug, Clone)]
+pub struct SVector(Vec);
+
+pub type VectorRef<'a> = &'a [f16];
+pub type QVectorRef<'a> = &'a [u8];
+pub type SVectorRef<'a> = &'a [f32];
+
+impl SVector {
+ pub fn zero(d: usize) -> Self {
+ SVector(vec![0.0; d])
+ }
+ pub fn half(&self) -> Vector {
+ Vector(self.0.iter().map(|a| f16::from_f32(*a)).collect())
+ }
+}
+
// One standard-normal sample via the Box-Muller transform:
// x = cos(2*pi*v) * sqrt(-2 ln u). u == 0 makes ln(u) = -inf and the result
// non-finite, so retry until a finite sample comes out.
fn box_muller(rng: &mut Rng) -> f32 {
    loop {
        let u = rng.f32();
        let v = rng.f32();
        let x = (v * std::f32::consts::TAU).cos() * (-2.0 * u.ln()).sqrt();
        if x.is_finite() {
            return x;
        }
    }
}
+
impl Vector {
    /// All-zeros f16 vector of dimension `d`.
    pub fn zero(d: usize) -> Self {
        Vector(vec![f16::from_f32(0.0); d])
    }

    /// Vector of `d` i.i.d. standard-normal samples, rounded to f16.
    pub fn randn(rng: &mut Rng, d: usize) -> Self {
        Vector(Vec::from_iter((0..d).map(|_| f16::from_f32(box_muller(rng)))))
    }
}
+
// Floats are vaguely annoying and not sortable (trivially), so we mostly represent dot products as integers
// 2^48: fixed-point scale used when converting f32 dot products to i64.
const SCALE: f32 = 281474976710656.0;
const SCALE_F64: f64 = SCALE as f64;

// f16 dot product via simsimd.
pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> f32 {
    // safety is not real
    // (transmutes &[half::f16] to simsimd's f16 slice type; assumes identical
    // bit layout — TODO(review) confirm against simsimd's type definition)
    (simsimd::f16::dot(unsafe { std::mem::transmute(x) }, unsafe { std::mem::transmute(y) }).unwrap()) as f32
}

// Widen an f16 slice to an owned f32 SVector.
pub fn to_svector(vec: VectorRef) -> SVector {
    SVector(vec.iter().map(|a| a.to_f32()).collect())
}
+
+impl<'a> std::ops::AddAssign> for SVector {
+ fn add_assign(&mut self, other: VectorRef<'a>) {
+ self.0.iter_mut().zip(other.iter()).for_each(|(a, b)| *a += b.to_f32());
+ }
+}
+
+impl std::ops::Div for SVector {
+ type Output = Self;
+
+ fn div(self, b: f32) -> Self::Output {
+ SVector(self.0.iter().map(|a| a / b).collect())
+ }
+}
+
+impl std::ops::Deref for Vector {
+ type Target = [f16];
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl std::ops::Deref for SVector {
+ type Target = [f32];
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl std::ops::Add<&SVector> for SVector {
+ type Output = Self;
+
+ fn add(self, other: &Self) -> Self::Output {
+ SVector(self.0.iter().zip(other.0.iter()).map(|(a, b)| a + b).collect())
+ }
+}
+
+impl std::ops::Sub<&SVector> for SVector {
+ type Output = Self;
+
+ fn sub(self, other: &Self) -> Self::Output {
+ SVector(self.0.iter().zip(other.0.iter()).map(|(a, b)| a - b).collect())
+ }
+}
+
+impl std::ops::AddAssign for SVector {
+ fn add_assign(&mut self, other: Self) {
+ self.0.iter_mut().zip(other.0.iter()).for_each(|(a, b)| *a += b);
+ }
+}
+
+impl std::ops::Mul for SVector {
+ type Output = Self;
+
+ fn mul(self, other: f32) -> Self {
+ SVector(self.0.iter().map(|a| *a * other).collect())
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct VectorList {
+ pub d_emb: usize,
+ pub length: usize,
+ pub data: Vec
+}
+
+impl std::ops::Index for VectorList {
+ type Output = [f16];
+
+ fn index(&self, index: usize) -> &Self::Output {
+ &self.data[index * self.d_emb..(index + 1) * self.d_emb]
+ }
+}
+
+pub struct VectorListIterator<'a> {
+ list: &'a VectorList,
+ index: usize
+}
+
+impl<'a> Iterator for VectorListIterator<'a> {
+ type Item = VectorRef<'a>;
+
+ fn next(&mut self) -> Option {
+ if self.index < self.list.len() {
+ let ret = &self.list[self.index];
+ self.index += 1;
+ Some(ret)
+ } else {
+ None
+ }
+ }
+}
+
+
+impl VectorList {
+ pub fn len(&self) -> usize {
+ self.length
+ }
+
+ pub fn iter(&self) -> VectorListIterator {
+ VectorListIterator {
+ list: self,
+ index: 0
+ }
+ }
+
+ pub fn empty(d: usize) -> Self {
+ VectorList {
+ d_emb: d,
+ length: 0,
+ data: Vec::new()
+ }
+ }
+
+ pub fn from_f16s(f16s: Vec, d: usize) -> Self {
+ assert!(f16s.len() % d == 0);
+ VectorList {
+ d_emb: d,
+ length: f16s.len() / d,
+ data: f16s
+ }
+ }
+
+ pub fn push(&mut self, vec: VectorRef) {
+ self.length += 1;
+ self.data.extend_from_slice(vec);
+ }
+}
+
// SimSIMD has its own, but ours prefetches concurrently, is unrolled more, ignores inconveniently-sized vectors and does a cheaper reduction
// Also, we return an int because floats are annoying (not Ord)
// On Tiger Lake (i5-1135G7) we have about a 3x performance advantage ignoring the prefetching
// (it would be better to use AVX512 for said CPU but this also has to run on Zen 3)
// Returns the dot product as fixed-point i64 (f32 result * SCALE). `prefetch`
// is an equal-length vector whose cache lines are warmed for the *next* call;
// it does not affect the returned value.
pub fn fast_dot(x: VectorRef, y: VectorRef, prefetch: VectorRef) -> i64 {
    use std::arch::x86_64::*;

    debug_assert!(x.len() == y.len());
    debug_assert!(prefetch.len() == x.len());
    debug_assert!(x.len() % 64 == 0);

    // safety is not real
    // it's probably fine I guess
    // (unaligned loads via _mm256_loadu_si256; the length asserts above keep
    // every pointer in bounds; assumes F16C + FMA are available at runtime)
    unsafe {
        let mut x_ptr = x.as_ptr();
        let mut y_ptr = y.as_ptr();
        let end = x_ptr.add(x.len());
        let mut prefetch_ptr = prefetch.as_ptr();

        // Four independent accumulators hide FMA latency.
        let mut acc1 = _mm256_setzero_ps();
        let mut acc2 = _mm256_setzero_ps();
        let mut acc3 = _mm256_setzero_ps();
        let mut acc4 = _mm256_setzero_ps();

        while x_ptr < end {
            // fetch chunks and prefetch next vector
            let x1 = _mm256_loadu_si256(x_ptr as *const __m256i);
            let y1 = _mm256_loadu_si256(y_ptr as *const __m256i);
            let x2 = _mm256_loadu_si256(x_ptr.add(16) as *const __m256i);
            let y2 = _mm256_loadu_si256(y_ptr.add(16) as *const __m256i);
            // technically, we only have to do this once per cache line but I don't care enough to test every way to optimize this
            _mm_prefetch(prefetch_ptr as *const i8, _MM_HINT_T0);
            x_ptr = x_ptr.add(32); // advance 32 f16s (two 16-lane chunks) per iteration
            y_ptr = y_ptr.add(32);
            prefetch_ptr = prefetch_ptr.add(32);

            // widen f16 -> f32, one 128-bit half at a time
            let x1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 0));
            let x1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 1));
            let y1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 0));
            let y1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 1));
            let x2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 0));
            let x2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 1));
            let y2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 0));
            let y2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 1));

            acc1 = _mm256_fmadd_ps(x1lo, y1lo, acc1);
            acc2 = _mm256_fmadd_ps(x1hi, y1hi, acc2);
            acc3 = _mm256_fmadd_ps(x2lo, y2lo, acc3);
            acc4 = _mm256_fmadd_ps(x2hi, y2hi, acc4);
        }

        // reduce
        let acc1 = _mm256_add_ps(acc1, acc2);
        let acc2 = _mm256_add_ps(acc3, acc4);

        // Horizontal add, fold the two 128-bit halves, then sum the remaining
        // four f32 lanes in scalar code.
        let hsum = _mm256_hadd_ps(acc1, acc2);
        let hsum_lo = _mm256_extractf128_ps(hsum, 0);
        let hsum_hi = _mm256_extractf128_ps(hsum, 1);
        let hsum = _mm_add_ps(hsum_lo, hsum_hi);

        let floatval = f32::from_bits(_mm_extract_ps::<0>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<1>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<2>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<3>(hsum) as u32);
        (floatval * SCALE) as i64
    }
}
+
// same as above, without prefetch pointer
// Returns the dot product as fixed-point i64 (f32 result * SCALE).
pub fn fast_dot_noprefetch(x: VectorRef, y: VectorRef) -> i64 {
    use std::arch::x86_64::*;

    debug_assert!(x.len() == y.len());
    debug_assert!(x.len() % 64 == 0);

    // (unaligned loads; the length asserts keep pointers in bounds; assumes
    // F16C + FMA are available at runtime)
    unsafe {
        let mut x_ptr = x.as_ptr();
        let mut y_ptr = y.as_ptr();
        let end = x_ptr.add(x.len());

        // Four independent accumulators hide FMA latency.
        let mut acc1 = _mm256_setzero_ps();
        let mut acc2 = _mm256_setzero_ps();
        let mut acc3 = _mm256_setzero_ps();
        let mut acc4 = _mm256_setzero_ps();

        while x_ptr < end {
            let x1 = _mm256_loadu_si256(x_ptr as *const __m256i);
            let y1 = _mm256_loadu_si256(y_ptr as *const __m256i);
            let x2 = _mm256_loadu_si256(x_ptr.add(16) as *const __m256i);
            let y2 = _mm256_loadu_si256(y_ptr.add(16) as *const __m256i);
            // 32 f16 lanes consumed per iteration (two 16-lane chunks).
            x_ptr = x_ptr.add(32);
            y_ptr = y_ptr.add(32);

            // widen f16 -> f32, one 128-bit half at a time
            let x1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 0));
            let x1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 1));
            let y1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 0));
            let y1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 1));
            let x2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 0));
            let x2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 1));
            let y2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 0));
            let y2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 1));

            acc1 = _mm256_fmadd_ps(x1lo, y1lo, acc1);
            acc2 = _mm256_fmadd_ps(x1hi, y1hi, acc2);
            acc3 = _mm256_fmadd_ps(x2lo, y2lo, acc3);
            acc4 = _mm256_fmadd_ps(x2hi, y2hi, acc4);
        }

        // reduce
        let acc1 = _mm256_add_ps(acc1, acc2);
        let acc2 = _mm256_add_ps(acc3, acc4);

        // Horizontal add, fold the halves, then sum the last four f32 lanes.
        let hsum = _mm256_hadd_ps(acc1, acc2);
        let hsum_lo = _mm256_extractf128_ps(hsum, 0);
        let hsum_hi = _mm256_extractf128_ps(hsum, 1);
        let hsum = _mm_add_ps(hsum_lo, hsum_hi);

        let floatval = f32::from_bits(_mm_extract_ps::<0>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<1>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<2>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<3>(hsum) as u32);
        (floatval * SCALE) as i64
    }
}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ProductQuantizer {
+ centroids: Vec,
+ transform: Vec, // D*D orthonormal matrix
+ pub n_dims_per_code: usize,
+ pub n_dims: usize
+}
+
+// chunk * centroid_index
+pub struct DistanceLUT(Vec);
+
+impl ProductQuantizer {
+ pub fn apply_transform(&self, x: &[f32]) -> Vec {
+ let dim = self.n_dims;
+ let n_vectors = x.len() / dim;
+ let mut transformed = vec![0.0; n_vectors * dim];
+ // transform_matrix (D * D) @ batch.T (D * B)
+ unsafe {
+ matrixmultiply::sgemm(dim, dim, n_vectors, 1.0, self.transform.as_ptr(), dim as isize, 1, x.as_ptr(), 1, dim as isize, 0.0, transformed.as_mut_ptr(), 1, dim as isize);
+ }
+ transformed
+ }
+
+ pub fn quantize_batch(&self, x: &[f32]) -> Vec {
+ // x is B * D
+ let dim = self.n_dims;
+ assert_eq!(dim * dim, self.transform.len());
+ let n_vectors = x.len() / dim;
+ let n_centroids = self.centroids.len() / dim;
+ assert!(n_centroids <= 256);
+ let transformed = self.apply_transform(&x); // B * D, as we write sgemm result in a weird order
+ let mut codes = vec![0; n_vectors * dim / self.n_dims_per_code];
+ let vec_len_codes = dim / self.n_dims_per_code;
+
+ // B * C buffer of similarity of each vector to each centroid, within subspace
+ let mut scratch = vec![0.0; n_vectors * n_centroids];
+
+ for i in 0..(dim / self.n_dims_per_code) {
+ let offset = i * self.n_dims_per_code;
+ // transformed_batch[:, range] (B * D_r) @ centroids[:, range].T (D_r * C)
+ unsafe {
+ matrixmultiply::sgemm(n_vectors, self.n_dims_per_code, n_centroids, 1.0, transformed.as_ptr().add(offset), dim as isize, 1, self.centroids.as_ptr().add(offset), 1, dim as isize, 0.0, scratch.as_mut_ptr(), n_centroids as isize, 1);
+ }
+ // assign this component to best centroid
+ for i_vec in 0..n_vectors {
+ let mut best = f32::NEG_INFINITY;
+ for i_centroid in 0..n_centroids {
+ let score = scratch[i_vec * n_centroids + i_centroid];
+ if score > best {
+ best = score;
+ codes[i_vec * vec_len_codes + i] = i_centroid as u8;
+ }
+ }
+ }
+ }
+ codes
+ }
+
+ // not particularly performance-sensitive right now; do unbatched
+ pub fn preprocess_query(&self, query: &[f32]) -> DistanceLUT {
+ let transformed = self.apply_transform(query);
+ let n_chunks = self.n_dims / self.n_dims_per_code;
+ let n_centroids = self.centroids.len() / self.n_dims;
+ let mut lut = Vec::with_capacity(n_chunks * n_centroids);
+
+ for i in 0..n_chunks {
+ let vec_component = &transformed[i * self.n_dims_per_code..(i + 1) * self.n_dims_per_code];
+ for j in 0..n_centroids {
+ let centroid = &self.centroids[j * self.n_dims..(j + 1) * self.n_dims];
+ let centroid_component = ¢roid[i * self.n_dims_per_code..(i + 1) * self.n_dims_per_code];
+ let score = SpatialSimilarity::dot(vec_component, centroid_component).unwrap();
+ lut.push(score as f32);
+ }
+ }
+
+ DistanceLUT(lut)
+ }
+
+ // compute dot products of query against product-quantized vectors
+ pub fn asymmetric_dot_product(&self, query: &DistanceLUT, pq_vectors: &[u8]) -> Vec {
+ let n_chunks = self.n_dims / self.n_dims_per_code;
+ let n_vectors = pq_vectors.len() / n_chunks;
+ let mut scores = vec![0.0; n_vectors];
+ let n_centroids = self.centroids.len() / self.n_dims;
+
+ for i in 0..n_chunks {
+ for j in 0..n_vectors {
+ let code = pq_vectors[j * n_chunks + i];
+ let chunk_score = query.0[i * n_centroids + code as usize];
+ scores[j] += chunk_score;
+ }
+ }
+
+ // I have no idea why but we somehow have significant degradation in search quality
+ // if this accumulates in integers. As such, do floats and convert at the end.
+ // I'm sure there are fascinating reasons for this, but God is dead, God remains dead, etc.
+ scores.into_iter().map(|x| (x * SCALE) as i64).collect()
+ }
+}
+
+/// Convert a floating-point dot product into the fixed-point i64 domain used
+/// by the index, by multiplying with SCALE_F64 and truncating toward zero.
+pub fn scale_dot_result(x: f64) -> i64 {
+    let scaled = x * SCALE_F64;
+    scaled as i64
+}
+
+#[cfg(test)]
+mod bench {
+    use super::*;
+    use test::Bencher;
+
+    // Baseline dot product over two seeded 1024-dim random vectors.
+    #[bench]
+    fn bench_dot(bencher: &mut Bencher) {
+        let mut rng = fastrand::Rng::with_seed(1);
+        let a = Vector::randn(&mut rng, 1024);
+        let b = Vector::randn(&mut rng, 1024);
+        bencher.iter(|| dot(&a, &b));
+    }
+
+    // SIMD dot product variant; the third argument is presumably a prefetch
+    // target (cf. the noprefetch variant below) — reusing `a` keeps the
+    // benchmark self-contained. TODO confirm against fast_dot's definition.
+    #[bench]
+    fn bench_fastdot(bencher: &mut Bencher) {
+        let mut rng = fastrand::Rng::with_seed(1);
+        let a = Vector::randn(&mut rng, 1024);
+        let b = Vector::randn(&mut rng, 1024);
+        bencher.iter(|| fast_dot(&a, &b, &a));
+    }
+
+    // Same SIMD dot product without the prefetch argument.
+    #[bench]
+    fn bench_fastdot_noprefetch(bencher: &mut Bencher) {
+        let mut rng = fastrand::Rng::with_seed(1);
+        let a = Vector::randn(&mut rng, 1024);
+        let b = Vector::randn(&mut rng, 1024);
+        bencher.iter(|| fast_dot_noprefetch(&a, &b));
+    }
+}
diff --git a/diskann/vec_dist.py b/diskann/vec_dist.py
new file mode 100644
index 0000000..1fbe0a6
--- /dev/null
+++ b/diskann/vec_dist.py
@@ -0,0 +1,44 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+A = 0.6
+LOG_A = np.log(A)
+
+def scale(xs):
+ return np.sign(xs) * (np.log(np.abs(xs) + A) - LOG_A)
+
+n_dims = 1152
+n_used_dims = 32
+data = np.frombuffer(open("embeddings.bin", "rb").read(), dtype=np.float16).reshape(-1, n_dims).astype(np.float32) # TODO
+
+# Create histogram bins
+n_bins = 256
+s = __import__("math").sqrt(n_dims)
+hist_range = (-1.2, 1.2)
+histogram_data = np.zeros((n_used_dims, n_bins))
+
+# Calculate histograms for each dimension
+for dim in range(n_used_dims):
+ dbd = data[:, dim]
+ dbd = (dbd - np.mean(dbd)) / np.std(dbd)
+ dbd = scale(dbd)
+ hist, _ = np.histogram(dbd, bins=n_bins, range=hist_range, density=True)
+ histogram_data[dim] = hist
+
+# Create heatmap
+plt.figure(figsize=(12, 8))
+sns.heatmap(histogram_data,
+ cmap='viridis',
+ xticklabels=np.linspace(hist_range[0], hist_range[1], n_bins),
+ yticklabels=range(n_used_dims),
+ cbar_kws={'label': 'Density'})
+
+plt.xlabel('Value')
+plt.ylabel('Dimension')
+plt.title('Distribution Heatmap of First 16 Dimensions')
+
+# Adjust layout to prevent label cutoff
+plt.tight_layout()
+
+plt.show()