diff --git a/.gitignore b/.gitignore index c804eb0..4fb341b 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,9 @@ node_modules/* node_modules *sqlite3* thumbtemp -mse-test-db-small \ No newline at end of file +mse-test-db-small +clipfront2/static/bg* +diskann/target +*.bin +*.msgpack +*/flamegraph.svg diff --git a/diskann/Cargo.lock b/diskann/Cargo.lock new file mode 100644 index 0000000..f9a41da --- /dev/null +++ b/diskann/Cargo.lock @@ -0,0 +1,748 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "anyhow" +version = "1.0.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + +[[package]] +name = "bytemuck" +version = "1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cc" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" +dependencies = [ + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "crossterm" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64e6c0fbe2c17357405f7c758c1ef960fce08bdfb2c03d88d2a18d7e09c4b67" +dependencies = [ + "bitflags 1.3.2", + "crossterm_winapi", + "libc", + "mio", + "parking_lot", + "signal-hook", + "signal-hook-mio", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "diskann" +version = "0.1.0" +dependencies = [ + "anyhow", + "bitvec", + "bytemuck", + "fastrand", + "foldhash", + "half", + "matrixmultiply", + "rayon", + "rmp-serde", + "serde", + "simsimd", + "tqdm", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + +[[package]] +name = "fastrand" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" + +[[package]] +name = "foldhash" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" + +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "bytemuck", + "cfg-if", + "crunchy", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys", +] + +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.6", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pin-project-lite" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" + +[[package]] +name = "proc-macro2" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +dependencies = [ + "bitflags 2.6.0", +] + +[[package]] +name = "rmp" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "serde" +version = "1.0.215" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.215" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-mio" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" +dependencies = [ + "libc", + "mio", + "signal-hook", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + +[[package]] +name = "simsimd" +version = "6.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be2ad0164e13e58a994d3dd1ff57d44cee87c445708e3acea7ad4f03a47092ce" +dependencies = [ + "cc", +] + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "syn" +version = "2.0.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + +[[package]] +name = "tqdm" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2d2932240205a99b65f15d9861992c95fbb8c9fb280b3a1f17a92db6dc611f" +dependencies = [ + "anyhow", + "crossterm", + "once_cell", +] + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "nu-ansi-term", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" + +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] diff --git a/diskann/Cargo.toml b/diskann/Cargo.toml new file mode 100644 index 0000000..13701ad --- /dev/null +++ b/diskann/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "diskann" +version = "0.1.0" +edition = "2021" + +[dependencies] +half = { version = "2", features = ["bytemuck"] } +fastrand = "2" +tracing = "0.1" +tracing-subscriber = "0.3" +simsimd = "6" +foldhash = "0.1" +bitvec = "1" +tqdm = "0.7" +anyhow = "1" +bytemuck = { version = "1", features = ["extern_crate_alloc"] } +serde = { version = "1", features = ["derive"] } +rmp-serde = "1" +rayon = "1" +matrixmultiply = "0.3" + +[lib] +name = "diskann" +path = "src/lib.rs" + +[[bin]] +name = "diskann" +path = "src/main.rs" diff --git a/diskann/aopq_train.py b/diskann/aopq_train.py new file mode 100644 index 0000000..03c840e --- /dev/null +++ b/diskann/aopq_train.py @@ -0,0 +1,113 @@ +import numpy as np +import msgpack +import math +import torch +from torch import autograd +import faiss +import tqdm + +n_dims = 1152 +output_code_size = 64 +output_code_bits = 8 +output_codebook_size = 2**output_code_bits +n_dims_per_code = n_dims // output_code_size +dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32) +queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32) +device = "cpu" + +index = faiss.index_factory(n_dims, "HNSW32,SQfp16", faiss.METRIC_INNER_PRODUCT) +index.train(queryset) +index.add(queryset) +print("index ready") + +T = 64 + +nearby_query_indices = torch.zeros((dataset.shape[0], T), dtype=torch.int32) + +SEARCH_BATCH_SIZE = 1024 + +for i in range(0, len(dataset), SEARCH_BATCH_SIZE): + res = index.search(dataset[i:i+SEARCH_BATCH_SIZE], T) + nearby_query_indices[i:i+SEARCH_BATCH_SIZE] = torch.tensor(res[1]) + +print("query indices ready") + +def pq_assign(centroids, batch): + quantized = torch.zeros_like(batch) + + # Assign to nearest centroid in each subspace + for dmin in range(0, n_dims, n_dims_per_code): + dmax = dmin + n_dims_per_code + similarities = torch.matmul(batch[:, dmin:dmax], centroids[:, dmin:dmax].T) + assignments = similarities.argmax(dim=1) + quantized[:, dmin:dmax] = centroids[assignments, dmin:dmax] + + return quantized + +# OOD-DiskANN (https://arxiv.org/abs/2211.12850) uses a more complicated scheme because it uses L2 norm +# We only care about inner product so our quantization error (wrt a query) is just abs(dot(query, centroid - vector)) +# Directly optimize for this (wrt top queries; it might actually be better to use a random sample instead?) +def partition(vectors, centroids, projection, opt, queries, nearby_query_indices, k, max_iter=100, batch_size=4096): + n_vectors = len(vectors) + perm = torch.randperm(n_vectors, device=device) + + t = tqdm.trange(max_iter) + for iter in t: + total_loss = 0 + opt.zero_grad(set_to_none=True) + + for i in range(0, n_vectors, batch_size): + loss = torch.tensor(0.0, device=device) + batch = vectors[i:i+batch_size] @ projection + quantized = pq_assign(centroids, batch) + residuals = batch - quantized + + # for each index in our set of nearby queries + for j in range(0, nearby_query_indices.shape[1]): + queries_for_batch_j = queries[nearby_query_indices[i:i+batch_size, j]] + # minimize quantiation error in direction of query, i.e. mean abs(dot(query, centroid - vector)) + # PyTorch won't do batched dot products cleanly, to spite me. Do componentwise multiplication and reduce. + sg_errs = (queries_for_batch_j * residuals).sum(dim=-1) + loss += torch.mean(torch.abs(sg_errs)) + + total_loss += loss.detach().item() + loss.backward() + + opt.step() + + t.set_description(f"loss: {total_loss:.4f}") + +def random_ortho(dim): + h = torch.randn(dim, dim, device=device) + q, r = torch.linalg.qr(h) + return q + +# non-parametric OPQ algorithm (roughly) +# https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/opq_tr.pdf +projection = random_ortho(n_dims) +vectors = torch.tensor(dataset, device=device) +queries = torch.tensor(queryset, device=device) +perm = torch.randperm(len(vectors), device=device) +centroids = vectors[perm[:output_codebook_size]] +centroids.requires_grad = True +opt = torch.optim.Adam([centroids], lr=0.001) +for i in range(30): + # update centroids to minimize query-aware quantization loss + partition(vectors, centroids, projection, opt, queries, nearby_query_indices, output_codebook_size, max_iter=8) + # compute new projection as R = VU^T from XY^T = USV^T (SVD) + # where X is dataset vectors, Y is quantized dataset vectors + with torch.no_grad(): + y = pq_assign(centroids, vectors) + # paper uses D*N and not N*D in its descriptions for whatever reason (so we transpose when they don't) + u, s, vt = torch.linalg.svd(vectors.T @ y) + projection = vt.T @ u.T + +print("done") + +with open("opq.msgpack", "wb") as f: + msgpack.pack({ + "centroids": centroids.detach().cpu().numpy().flatten().tolist(), + "transform": projection.cpu().numpy().flatten().tolist(), + "n_dims_per_code": n_dims_per_code, + "n_dims": n_dims + }, f) diff --git a/diskann/flamegraph.svg b/diskann/flamegraph.svg new file mode 100644 index 0000000..9b564c4 --- /dev/null +++ b/diskann/flamegraph.svg @@ -0,0 +1,491 @@ +Flame Graph Reset ZoomSearch [[heap]] (10 samples, 0.02%)[libc.so.6] (12 samples, 0.02%)diskann::NeighbourBuffer::insert (19 samples, 0.03%)alloc::vec::Vec<T,A>::insert (19 samples, 0.03%)core::intrinsics::copy (19 samples, 0.03%)diskann::greedy_search_nosearchlist (82 samples, 0.13%)std::collections::hash::set::HashSet<T,S>::insert (59 samples, 0.10%)hashbrown::set::HashSet<T,S,A>::insert (59 samples, 0.10%)[[stack]] (186 samples, 0.30%)hashbrown::map::HashMap<K,V,S,A>::insert (77 samples, 0.12%)hashbrown::raw::RawTable<T,A>::find_or_find_insert_slot (8 samples, 0.01%)hashbrown::raw::RawTableInner::find_or_find_insert_slot_inner (8 samples, 0.01%)core::slice::sort::shared::pivot::median3_rec (10 samples, 0.02%)[anon] (13 samples, 0.02%)core::slice::sort::shared::pivot::median3_rec (9 samples, 0.01%)core::slice::sort::shared::smallsort::small_sort_general (8 samples, 0.01%)core::slice::sort::unstable::quicksort::quicksort (13 samples, 0.02%)diskann::NeighbourBuffer::insert (9 samples, 0.01%)alloc::vec::Vec<T,A>::insert (8 samples, 0.01%)core::intrinsics::copy (8 samples, 0.01%)core::core_arch::x86::f16c::_mm256_cvtph_ps (12 samples, 0.02%)diskann::vector::fast_dot (27 samples, 0.04%)core::core_arch::x86::sse::_mm_prefetch (10 samples, 0.02%)diskann::greedy_search_nosearchlist (49 samples, 0.08%)std::collections::hash::set::HashSet<T,S>::insert (13 samples, 0.02%)hashbrown::set::HashSet<T,S,A>::insert (13 samples, 0.02%)core::core_arch::x86::f16c::_mm256_cvtph_ps (45 samples, 0.07%)diskann::robust_prune (68 samples, 0.11%)diskann::vector::fast_dot (64 samples, 0.10%)[unknown] (217 samples, 0.35%)hashbrown::map::HashMap<K,V,S,A>::insert (57 samples, 0.09%)hashbrown::map::make_hash (41 samples, 0.07%)core::hash::BuildHasher::hash_one (41 samples, 0.07%)<foldhash::seed::fast::RandomState as core::hash::BuildHasher>::build_hasher (41 samples, 0.07%)foldhash::fast::FoldHasher::with_seed (41 samples, 0.07%)<T as alloc::borrow::ToOwned>::to_owned (22 samples, 0.04%)<alloc::vec::Vec<T,A> as core::clone::Clone>::clone (22 samples, 0.04%)alloc::slice::<impl [T]>::to_vec_in (22 samples, 0.04%)alloc::slice::hack::to_vec (22 samples, 0.04%)<T as alloc::slice::hack::ConvertVec>::to_vec (22 samples, 0.04%)alloc::vec::Vec<T,A>::with_capacity_in (22 samples, 0.04%)alloc::raw_vec::RawVec<T,A>::with_capacity_in (22 samples, 0.04%)alloc::raw_vec::RawVec<T,A>::try_allocate_in (22 samples, 0.04%)<alloc::alloc::Global as core::alloc::Allocator>::allocate (22 samples, 0.04%)alloc::alloc::Global::alloc_impl (22 samples, 0.04%)alloc::alloc::alloc (22 samples, 0.04%)malloc (15 samples, 0.02%)alloc::vec::Vec<T,A>::with_capacity_in (7 samples, 0.01%)alloc::raw_vec::RawVec<T,A>::with_capacity_in (7 samples, 0.01%)alloc::raw_vec::RawVec<T,A>::try_allocate_in (7 samples, 0.01%)alloc::slice::<impl [T]>::to_vec (21 samples, 0.03%)alloc::slice::<impl [T]>::to_vec_in (21 samples, 0.03%)alloc::slice::hack::to_vec (21 samples, 0.03%)<T as alloc::slice::hack::ConvertVec>::to_vec (21 samples, 0.03%)core::ptr::const_ptr::<impl *const T>::copy_to_nonoverlapping (14 samples, 0.02%)core::intrinsics::copy_nonoverlapping (14 samples, 0.02%)[libc.so.6] (14 samples, 0.02%)core::ptr::drop_in_place<alloc::vec::Vec<u32>> (8 samples, 0.01%)core::ptr::drop_in_place<alloc::raw_vec::RawVec<u32>> (8 samples, 0.01%)<alloc::raw_vec::RawVec<T,A> as core::ops::drop::Drop>::drop (8 samples, 0.01%)<alloc::alloc::Global as core::alloc::Allocator>::deallocate (8 samples, 0.01%)alloc::alloc::dealloc (8 samples, 0.01%)cfree (7 samples, 0.01%)core::ptr::drop_in_place<std::sync::rwlock::RwLockWriteGuard<alloc::vec::Vec<u32>>> (9 samples, 0.01%)<std::sync::rwlock::RwLockWriteGuard<T> as core::ops::drop::Drop>::drop (9 samples, 0.01%)std::sys::sync::rwlock::futex::RwLock::write_unlock (8 samples, 0.01%)core::sync::atomic::AtomicU32::fetch_sub (8 samples, 0.01%)core::sync::atomic::atomic_sub (8 samples, 0.01%)core::slice::<impl [T]>::contains (66 samples, 0.11%)<T as core::slice::cmp::SliceContains>::slice_contains (66 samples, 0.11%)<core::slice::iter::Iter<T> as core::iter::traits::iterator::Iterator>::any (66 samples, 0.11%)<T as core::slice::cmp::SliceContains>::slice_contains::_{{closure}} (56 samples, 0.09%)core::cmp::impls::<impl core::cmp::PartialEq for u32>::eq (56 samples, 0.09%)diskann::IndexGraph::out_neighbours_mut (56 samples, 0.09%)std::sync::rwlock::RwLock<T>::write (56 samples, 0.09%)std::sys::sync::rwlock::futex::RwLock::write (53 samples, 0.09%)core::sync::atomic::AtomicU32::compare_exchange_weak (52 samples, 0.08%)core::sync::atomic::atomic_compare_exchange_weak (52 samples, 0.08%)<core::iter::adapters::enumerate::Enumerate<I> as core::iter::traits::iterator::Iterator>::next (11 samples, 0.02%)<core::slice::iter::Iter<T> as core::iter::traits::iterator::Iterator>::next (11 samples, 0.02%)<core::ptr::non_null::NonNull<T> as core::cmp::PartialEq>::eq (11 samples, 0.02%)core::num::<impl usize>::checked_sub (23 samples, 0.04%)<diskann::vector::VectorList as core::ops::index::Index<usize>>::index (62 samples, 0.10%)<alloc::vec::Vec<T,A> as core::ops::index::Index<I>>::index (25 samples, 0.04%)core::slice::index::<impl core::ops::index::Index<I> for [T]>::index (25 samples, 0.04%)<core::ops::range::Range<usize> as core::slice::index::SliceIndex<[T]>>::index (25 samples, 0.04%)core::ptr::mut_ptr::<impl *mut T>::add (16 samples, 0.03%)alloc::vec::Vec<T,A>::push (115 samples, 0.19%)core::ptr::write (14 samples, 0.02%)core::ptr::drop_in_place<std::sync::rwlock::RwLockReadGuard<alloc::vec::Vec<u32>>> (26 samples, 0.04%)<std::sync::rwlock::RwLockReadGuard<T> as core::ops::drop::Drop>::drop (26 samples, 0.04%)std::sys::sync::rwlock::futex::RwLock::read_unlock (26 samples, 0.04%)core::sync::atomic::AtomicU32::fetch_sub (26 samples, 0.04%)core::sync::atomic::atomic_sub (26 samples, 0.04%)core::sync::atomic::AtomicU32::compare_exchange_weak (18 samples, 0.03%)core::sync::atomic::atomic_compare_exchange_weak (18 samples, 0.03%)core::sync::atomic::AtomicU32::load (208 samples, 0.34%)core::sync::atomic::atomic_load (208 samples, 0.34%)diskann::IndexGraph::out_neighbours (231 samples, 0.37%)std::sync::rwlock::RwLock<T>::read (231 samples, 0.37%)std::sys::sync::rwlock::futex::RwLock::read (230 samples, 0.37%)<core::option::Option<T> as core::cmp::PartialEq>::eq (12 samples, 0.02%)core::cmp::impls::<impl core::cmp::PartialEq<&B> for &A>::eq (11 samples, 0.02%)core::cmp::impls::<impl core::cmp::PartialEq for u32>::eq (11 samples, 0.02%)alloc::vec::Vec<T,A>::len (7 samples, 0.01%)core::intrinsics::copy (265 samples, 0.43%)[libc.so.6] (257 samples, 0.42%)alloc::vec::Vec<T,A>::insert (297 samples, 0.48%)<core::cmp::Ordering as core::cmp::PartialEq>::eq (115 samples, 0.19%)core::slice::<impl [T]>::binary_search_by (237 samples, 0.38%)diskann::NeighbourBuffer::insert::_{{closure}} (16 samples, 0.03%)core::cmp::impls::<impl core::cmp::PartialOrd for i64>::partial_cmp (16 samples, 0.03%)core::slice::<impl [T]>::get (9 samples, 0.01%)<usize as core::slice::index::SliceIndex<[T]>>::get (9 samples, 0.01%)diskann::NeighbourBuffer::insert (619 samples, 1.00%)diskann::NeighbourBuffer::next_unvisited (13 samples, 0.02%)asm_sysvec_apic_timer_interrupt (10 samples, 0.02%)sysvec_apic_timer_interrupt (9 samples, 0.01%)__sysvec_apic_timer_interrupt (9 samples, 0.01%)hrtimer_interrupt (8 samples, 0.01%)__hrtimer_run_queues (7 samples, 0.01%)tick_nohz_handler (7 samples, 0.01%)update_process_times (7 samples, 0.01%)core::core_arch::x86::avx::_mm256_add_ps (36 samples, 0.06%)core::core_arch::x86::avx::_mm256_extractf128_ps (36 samples, 0.06%)core::core_arch::x86::avx::_mm256_hadd_ps (47 samples, 0.08%)perf_event_task_tick (7 samples, 0.01%)update_process_times (27 samples, 0.04%)sched_tick (19 samples, 0.03%)task_tick_fair (9 samples, 0.01%)tick_nohz_handler (36 samples, 0.06%)update_wall_time (7 samples, 0.01%)timekeeping_advance (7 samples, 0.01%)__hrtimer_run_queues (40 samples, 0.06%)__sysvec_apic_timer_interrupt (42 samples, 0.07%)hrtimer_interrupt (41 samples, 0.07%)irq_exit_rcu (7 samples, 0.01%)handle_softirqs (7 samples, 0.01%)asm_sysvec_apic_timer_interrupt (57 samples, 0.09%)sysvec_apic_timer_interrupt (51 samples, 0.08%)__perf_event_task_sched_in (12 samples, 0.02%)__intel_pmu_enable_all.isra.0 (12 samples, 0.02%)native_write_msr (12 samples, 0.02%)core::core_arch::x86::f16c::_mm256_cvtph_ps (2,902 samples, 4.71%)core:..asm_sysvec_reschedule_ipi (14 samples, 0.02%)irqentry_exit_to_user_mode (14 samples, 0.02%)schedule (13 samples, 0.02%)__schedule (13 samples, 0.02%)finish_task_switch.isra.0 (13 samples, 0.02%)update_process_times (9 samples, 0.01%)sched_tick (9 samples, 0.01%)__sysvec_apic_timer_interrupt (11 samples, 0.02%)hrtimer_interrupt (11 samples, 0.02%)__hrtimer_run_queues (11 samples, 0.02%)tick_nohz_handler (11 samples, 0.02%)core::core_arch::x86::fma::_mm256_fmadd_ps (780 samples, 1.27%)asm_sysvec_apic_timer_interrupt (14 samples, 0.02%)sysvec_apic_timer_interrupt (14 samples, 0.02%)core::core_arch::x86::sse::_mm_add_ps (48 samples, 0.08%)diskann::vector::fast_dot (8,772 samples, 14.23%)diskann::vector::fast_..core::core_arch::x86::sse::_mm_prefetch (4,032 samples, 6.54%)core::cor..diskann::vector::fast_dot_noprefetch (19 samples, 0.03%)<foldhash::fast::FoldHasher as core::hash::Hasher>::finish (48 samples, 0.08%)foldhash::folded_multiply (12 samples, 0.02%)hashbrown::map::make_hash (49 samples, 0.08%)core::hash::BuildHasher::hash_one (49 samples, 0.08%)hashbrown::raw::RawTable<T,A>::reserve (30 samples, 0.05%)asm_sysvec_apic_timer_interrupt (7 samples, 0.01%)hashbrown::raw::bitmask::BitMask::lowest_set_bit (162 samples, 0.26%)hashbrown::raw::bitmask::BitMask::nonzero_trailing_zeros (22 samples, 0.04%)core::num::nonzero::NonZero<u16>::trailing_zeros (22 samples, 0.04%)<hashbrown::raw::bitmask::BitMaskIter as core::iter::traits::iterator::Iterator>::next (165 samples, 0.27%)core::option::Option<T>::is_none (32 samples, 0.05%)core::option::Option<T>::is_some (32 samples, 0.05%)hashbrown::raw::RawTable<T,A>::find_or_find_insert_slot::_{{closure}} (23 samples, 0.04%)hashbrown::map::equivalent_key::_{{closure}} (23 samples, 0.04%)<Q as hashbrown::Equivalent<K>>::equivalent (23 samples, 0.04%)core::cmp::impls::<impl core::cmp::PartialEq<&B> for &A>::eq (23 samples, 0.04%)hashbrown::raw::bitmask::BitMask::lowest_set_bit (18 samples, 0.03%)hashbrown::raw::RawTableInner::find_insert_slot_in_group (44 samples, 0.07%)hashbrown::raw::sse2::Group::match_empty_or_deleted (14 samples, 0.02%)core::core_arch::x86::sse2::_mm_movemask_epi8 (14 samples, 0.02%)hashbrown::raw::RawTableInner::fix_insert_slot (56 samples, 0.09%)hashbrown::raw::RawTableInner::is_bucket_full (27 samples, 0.04%)hashbrown::raw::bitmask::BitMask::any_bit_set (12 samples, 0.02%)hashbrown::raw::h2 (50 samples, 0.08%)hashbrown::raw::sse2::Group::load (122 samples, 0.20%)core::core_arch::x86::sse2::_mm_loadu_si128 (122 samples, 0.20%)core::intrinsics::copy_nonoverlapping (122 samples, 0.20%)hashbrown::raw::sse2::Group::match_byte (77 samples, 0.12%)core::core_arch::x86::sse2::_mm_cmpeq_epi8 (77 samples, 0.12%)hashbrown::raw::RawTable<T,A>::find_or_find_insert_slot (842 samples, 1.37%)hashbrown::raw::RawTableInner::find_or_find_insert_slot_inner (810 samples, 1.31%)core::ptr::mut_ptr::<impl *mut T>::write (37 samples, 0.06%)core::ptr::write (37 samples, 0.06%)hashbrown::raw::Bucket<T>::write (38 samples, 0.06%)hashbrown::raw::RawTable<T,A>::bucket (18 samples, 0.03%)hashbrown::raw::Bucket<T>::from_base_index (18 samples, 0.03%)core::ptr::mut_ptr::<impl *mut T>::sub (18 samples, 0.03%)core::ptr::mut_ptr::<impl *mut T>::offset (18 samples, 0.03%)core::convert::num::<impl core::convert::From<bool> for usize>::from (7 samples, 0.01%)hashbrown::raw::RawTableInner::set_ctrl_h2 (23 samples, 0.04%)hashbrown::raw::RawTableInner::set_ctrl (23 samples, 0.04%)diskann::greedy_search_nosearchlist (11,299 samples, 18.33%)diskann::greedy_search_nosear..std::collections::hash::set::HashSet<T,S>::insert (1,021 samples, 1.66%)hashbrown::set::HashSet<T,S,A>::insert (1,021 samples, 1.66%)hashbrown::map::HashMap<K,V,S,A>::insert (1,019 samples, 1.65%)hashbrown::raw::RawTable<T,A>::insert_in_slot (126 samples, 0.20%)hashbrown::raw::RawTableInner::record_item_insert_at (70 samples, 0.11%)<diskann::vector::VectorList as core::ops::index::Index<usize>>::index (9 samples, 0.01%)alloc::vec::Vec<T,A>::push (8 samples, 0.01%)core::core_arch::x86::avx::_mm256_extractf128_ps (11 samples, 0.02%)core::core_arch::x86::f16c::_mm256_cvtph_ps (397 samples, 0.64%)core::core_arch::x86::fma::_mm256_fmadd_ps (127 samples, 0.21%)core::core_arch::x86::sse::_mm_add_ps (10 samples, 0.02%)diskann::merge_existing_neighbours (1,074 samples, 1.74%)diskann::vector::fast_dot (1,040 samples, 1.69%)core::core_arch::x86::sse::_mm_prefetch (310 samples, 0.50%)alloc::vec::Vec<T,A>::as_ptr (139 samples, 0.23%)alloc::raw_vec::RawVec<T,A>::ptr (139 samples, 0.23%)<alloc::vec::Vec<T,A> as core::ops::deref::Deref>::deref (243 samples, 0.39%)<alloc::vec::Vec<T,A> as core::ops::index::Index<I>>::index (618 samples, 1.00%)core::slice::index::<impl core::ops::index::Index<I> for [T]>::index (375 samples, 0.61%)<usize as core::slice::index::SliceIndex<[T]>>::index (375 samples, 0.61%)<core::iter::adapters::enumerate::Enumerate<I> as core::iter::traits::iterator::Iterator>::next (20 samples, 0.03%)<core::slice::iter::Iter<T> as core::iter::traits::iterator::Iterator>::next (20 samples, 0.03%)<core::ptr::non_null::NonNull<T> as core::cmp::PartialEq>::eq (14 samples, 0.02%)<diskann::vector::VectorList as core::ops::index::Index<usize>>::index (442 samples, 0.72%)<alloc::vec::Vec<T,A> as core::ops::index::Index<I>>::index (141 samples, 0.23%)core::slice::index::<impl core::ops::index::Index<I> for [T]>::index (141 samples, 0.23%)<core::ops::range::Range<usize> as core::slice::index::SliceIndex<[T]>>::index (141 samples, 0.23%)core::num::<impl usize>::checked_sub (133 samples, 0.22%)alloc::vec::Vec<T,A>::len (15 samples, 0.02%)alloc::vec::Vec<T,A>::push (498 samples, 0.81%)core::ptr::write (218 samples, 0.35%)core::iter::range::<impl core::iter::traits::iterator::Iterator for core::ops::range::Range<A>>::next (22 samples, 0.04%)<core::ops::range::Range<T> as core::iter::range::RangeIteratorImpl>::spec_next (22 samples, 0.04%)core::cmp::impls::<impl core::cmp::PartialOrd for usize>::lt (21 samples, 0.03%)core::slice::sort::shared::smallsort::bidirectional_merge (11 samples, 0.02%)core::slice::sort::shared::smallsort::insert_tail (10 samples, 0.02%)core::slice::sort::shared::smallsort::small_sort_general (30 samples, 0.05%)core::slice::sort::shared::smallsort::small_sort_general_with_scratch (30 samples, 0.05%)core::slice::sort::shared::pivot::choose_pivot (14 samples, 0.02%)core::slice::sort::shared::pivot::median3_rec (11 samples, 0.02%)core::slice::sort::shared::pivot::median3_rec (7 samples, 0.01%)core::slice::sort::shared::smallsort::bidirectional_merge (26 samples, 0.04%)core::slice::sort::shared::smallsort::merge_up (12 samples, 0.02%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (7 samples, 0.01%)core::slice::sort::shared::smallsort::insert_tail (11 samples, 0.02%)core::slice::sort::shared::smallsort::merge_down (7 samples, 0.01%)core::slice::sort::shared::smallsort::bidirectional_merge (9 samples, 0.01%)core::slice::sort::shared::smallsort::small_sort_general (61 samples, 0.10%)core::slice::sort::shared::smallsort::small_sort_general_with_scratch (61 samples, 0.10%)core::slice::sort::shared::smallsort::sort8_stable (16 samples, 0.03%)core::slice::sort::shared::smallsort::sort4_stable (7 samples, 0.01%)core::intrinsics::copy (30 samples, 0.05%)core::intrinsics::copy_nonoverlapping (24 samples, 0.04%)core::cmp::impls::<impl core::cmp::PartialOrd for i64>::lt (29 samples, 0.05%)core::slice::sort::unstable::quicksort::partition (124 samples, 0.20%)core::slice::sort::unstable::quicksort::partition_lomuto_branchless_cyclic (121 samples, 0.20%)core::slice::sort::unstable::quicksort::partition_lomuto_branchless_cyclic::_{{closure}} (85 samples, 0.14%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (31 samples, 0.05%)core::slice::sort::shared::pivot::choose_pivot (9 samples, 0.01%)core::slice::sort::shared::smallsort::merge_down (8 samples, 0.01%)core::intrinsics::copy_nonoverlapping (7 samples, 0.01%)core::slice::sort::shared::smallsort::bidirectional_merge (47 samples, 0.08%)core::slice::sort::shared::smallsort::merge_up (20 samples, 0.03%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (7 samples, 0.01%)core::slice::sort::shared::smallsort::insert_tail (31 samples, 0.05%)core::slice::sort::shared::smallsort::bidirectional_merge (10 samples, 0.02%)core::slice::sort::shared::smallsort::small_sort_general (114 samples, 0.18%)core::slice::sort::shared::smallsort::small_sort_general_with_scratch (114 samples, 0.18%)core::slice::sort::shared::smallsort::sort8_stable (19 samples, 0.03%)core::slice::sort::shared::smallsort::sort4_stable (9 samples, 0.01%)core::intrinsics::copy (35 samples, 0.06%)core::intrinsics::copy_nonoverlapping (15 samples, 0.02%)core::cmp::impls::<impl core::cmp::PartialOrd for i64>::lt (31 samples, 0.05%)core::slice::sort::unstable::quicksort::partition (113 samples, 0.18%)core::slice::sort::unstable::quicksort::partition_lomuto_branchless_cyclic (110 samples, 0.18%)core::slice::sort::unstable::quicksort::partition_lomuto_branchless_cyclic::_{{closure}} (83 samples, 0.13%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (33 samples, 0.05%)core::intrinsics::copy_nonoverlapping (8 samples, 0.01%)core::slice::sort::shared::smallsort::merge_down (17 samples, 0.03%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (10 samples, 0.02%)core::intrinsics::copy_nonoverlapping (10 samples, 0.02%)core::slice::sort::shared::smallsort::bidirectional_merge (55 samples, 0.09%)core::slice::sort::shared::smallsort::merge_up (20 samples, 0.03%)core::slice::sort::shared::smallsort::insert_tail (42 samples, 0.07%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (7 samples, 0.01%)core::slice::sort::shared::smallsort::small_sort_general (140 samples, 0.23%)core::slice::sort::shared::smallsort::small_sort_general_with_scratch (140 samples, 0.23%)core::slice::sort::shared::smallsort::sort8_stable (21 samples, 0.03%)core::slice::sort::shared::smallsort::sort4_stable (16 samples, 0.03%)core::intrinsics::copy (24 samples, 0.04%)core::intrinsics::copy_nonoverlapping (8 samples, 0.01%)core::cmp::impls::<impl core::cmp::PartialOrd for i64>::lt (21 samples, 0.03%)core::slice::sort::unstable::quicksort::partition (83 samples, 0.13%)core::slice::sort::unstable::quicksort::partition_lomuto_branchless_cyclic (80 samples, 0.13%)core::slice::sort::unstable::quicksort::partition_lomuto_branchless_cyclic::_{{closure}} (56 samples, 0.09%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (23 samples, 0.04%)core::intrinsics::copy_nonoverlapping (8 samples, 0.01%)core::slice::sort::shared::smallsort::merge_down (15 samples, 0.02%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (7 samples, 0.01%)core::intrinsics::copy_nonoverlapping (10 samples, 0.02%)core::cmp::impls::<impl core::cmp::PartialOrd for i64>::lt (9 samples, 0.01%)core::slice::sort::shared::smallsort::bidirectional_merge (79 samples, 0.13%)core::slice::sort::shared::smallsort::merge_up (28 samples, 0.05%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (11 samples, 0.02%)core::slice::sort::shared::smallsort::insert_tail (46 samples, 0.07%)core::slice::sort::shared::smallsort::bidirectional_merge (8 samples, 0.01%)core::slice::sort::shared::smallsort::small_sort_general (161 samples, 0.26%)core::slice::sort::shared::smallsort::small_sort_general_with_scratch (161 samples, 0.26%)core::slice::sort::shared::smallsort::sort8_stable (15 samples, 0.02%)core::slice::sort::shared::smallsort::sort4_stable (7 samples, 0.01%)core::intrinsics::copy (16 samples, 0.03%)core::intrinsics::copy_nonoverlapping (13 samples, 0.02%)core::cmp::impls::<impl core::cmp::PartialOrd for i64>::lt (16 samples, 0.03%)core::slice::sort::unstable::quicksort::partition (65 samples, 0.11%)core::slice::sort::unstable::quicksort::partition_lomuto_branchless_cyclic (62 samples, 0.10%)core::slice::sort::unstable::quicksort::partition_lomuto_branchless_cyclic::_{{closure}} (49 samples, 0.08%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (18 samples, 0.03%)core::slice::sort::shared::smallsort::merge_down (10 samples, 0.02%)core::intrinsics::copy_nonoverlapping (8 samples, 0.01%)core::slice::sort::shared::smallsort::bidirectional_merge (50 samples, 0.08%)core::slice::sort::shared::smallsort::merge_up (23 samples, 0.04%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (9 samples, 0.01%)core::slice::sort::shared::smallsort::insert_tail (29 samples, 0.05%)core::slice::sort::shared::smallsort::small_sort_general (102 samples, 0.17%)core::slice::sort::shared::smallsort::small_sort_general_with_scratch (102 samples, 0.17%)core::slice::sort::shared::smallsort::sort8_stable (10 samples, 0.02%)core::intrinsics::copy (9 samples, 0.01%)core::cmp::impls::<impl core::cmp::PartialOrd for i64>::lt (11 samples, 0.02%)core::slice::sort::unstable::quicksort::partition (40 samples, 0.06%)core::slice::sort::unstable::quicksort::partition_lomuto_branchless_cyclic (39 samples, 0.06%)core::slice::sort::unstable::quicksort::partition_lomuto_branchless_cyclic::_{{closure}} (28 samples, 0.05%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (13 samples, 0.02%)core::slice::sort::shared::smallsort::merge_down (9 samples, 0.01%)core::slice::sort::shared::smallsort::bidirectional_merge (29 samples, 0.05%)core::slice::sort::shared::smallsort::merge_up (8 samples, 0.01%)core::slice::sort::shared::smallsort::insert_tail (11 samples, 0.02%)core::slice::sort::shared::smallsort::small_sort_general (57 samples, 0.09%)core::slice::sort::shared::smallsort::small_sort_general_with_scratch (57 samples, 0.09%)core::slice::sort::shared::smallsort::sort8_stable (8 samples, 0.01%)core::intrinsics::copy (8 samples, 0.01%)core::slice::sort::unstable::quicksort::partition (25 samples, 0.04%)core::slice::sort::unstable::quicksort::partition_lomuto_branchless_cyclic (23 samples, 0.04%)core::slice::sort::unstable::quicksort::partition_lomuto_branchless_cyclic::_{{closure}} (20 samples, 0.03%)core::slice::_<impl [T]>::sort_unstable_by_key::_{{closure}} (8 samples, 0.01%)core::cmp::impls::<impl core::cmp::PartialOrd for i64>::lt (8 samples, 0.01%)core::slice::sort::shared::smallsort::bidirectional_merge (8 samples, 0.01%)core::slice::sort::shared::smallsort::insert_tail (7 samples, 0.01%)core::slice::sort::shared::smallsort::small_sort_general (22 samples, 0.04%)core::slice::sort::shared::smallsort::small_sort_general_with_scratch (22 samples, 0.04%)core::slice::<impl [T]>::sort_unstable_by_key (1,231 samples, 2.00%)c..core::slice::sort::unstable::sort (1,231 samples, 2.00%)c..core::slice::sort::unstable::quicksort::quicksort (1,197 samples, 1.94%)c..core::slice::sort::unstable::quicksort::quicksort (994 samples, 1.61%)core::slice::sort::unstable::quicksort::quicksort (756 samples, 1.23%)core::slice::sort::unstable::quicksort::quicksort (522 samples, 0.85%)core::slice::sort::unstable::quicksort::quicksort (280 samples, 0.45%)core::slice::sort::unstable::quicksort::quicksort (127 samples, 0.21%)core::slice::sort::unstable::quicksort::quicksort (41 samples, 0.07%)core::slice::sort::unstable::quicksort::quicksort (10 samples, 0.02%)sched_tick (14 samples, 0.02%)update_process_times (24 samples, 0.04%)__hrtimer_run_queues (32 samples, 0.05%)tick_nohz_handler (29 samples, 0.05%)hrtimer_interrupt (34 samples, 0.06%)__sysvec_apic_timer_interrupt (36 samples, 0.06%)asm_sysvec_apic_timer_interrupt (47 samples, 0.08%)sysvec_apic_timer_interrupt (43 samples, 0.07%)irq_exit_rcu (7 samples, 0.01%)core::core_arch::x86::avx::_mm256_add_ps (315 samples, 0.51%)core::core_arch::x86::avx::_mm256_extractf128_ps (368 samples, 0.60%)core::core_arch::x86::avx::_mm256_hadd_ps (522 samples, 0.85%)handle_softirqs (7 samples, 0.01%)asm_common_interrupt (9 samples, 0.01%)common_interrupt (9 samples, 0.01%)irq_exit_rcu (8 samples, 0.01%)update_curr (12 samples, 0.02%)task_tick_fair (31 samples, 0.05%)sched_tick (49 samples, 0.08%)update_process_times (65 samples, 0.11%)tick_nohz_handler (78 samples, 0.13%)update_wall_time (12 samples, 0.02%)timekeeping_advance (12 samples, 0.02%)timekeeping_update (7 samples, 0.01%)__sysvec_apic_timer_interrupt (86 samples, 0.14%)hrtimer_interrupt (86 samples, 0.14%)__hrtimer_run_queues (86 samples, 0.14%)handle_softirqs (16 samples, 0.03%)irq_exit_rcu (17 samples, 0.03%)asm_sysvec_apic_timer_interrupt (114 samples, 0.18%)sysvec_apic_timer_interrupt (104 samples, 0.17%)core::core_arch::x86::f16c::_mm256_cvtph_ps (23,478 samples, 38.08%)core::core_arch::x86::f16c::_mm256_cvtph_pstask_tick_fair (10 samples, 0.02%)sched_tick (14 samples, 0.02%)update_process_times (17 samples, 0.03%)__hrtimer_run_queues (19 samples, 0.03%)tick_nohz_handler (19 samples, 0.03%)__sysvec_apic_timer_interrupt (23 samples, 0.04%)hrtimer_interrupt (23 samples, 0.04%)asm_sysvec_apic_timer_interrupt (28 samples, 0.05%)sysvec_apic_timer_interrupt (25 samples, 0.04%)core::core_arch::x86::fma::_mm256_fmadd_ps (7,844 samples, 12.72%)core::core_arch::x8..core::core_arch::x86::sse::_mm_add_ps (452 samples, 0.73%)core::iter::traits::iterator::Iterator::for_each::call::_{{closure}} (58,470 samples, 94.83%)core::iter::traits::iterator::Iterator::for_each::call::_{{closure}}diskann::build_graph::_{{closure}} (58,470 samples, 94.83%)diskann::build_graph::_{{closure}}diskann::robust_prune (45,899 samples, 74.44%)diskann::robust_prunediskann::vector::fast_dot (41,202 samples, 66.83%)diskann::vector::fast_dotcore::core_arch::x86::sse::_mm_prefetch (106 samples, 0.17%)<alloc::vec::into_iter::IntoIter<T,A> as core::iter::traits::iterator::Iterator>::fold (58,471 samples, 94.83%)<alloc::vec::into_iter::IntoIter<T,A> as core::iter::traits::iterator::Iterator>::folddiskann::build_graph (58,476 samples, 94.84%)diskann::build_graphcore::iter::traits::iterator::Iterator::for_each (58,476 samples, 94.84%)core::iter::traits::iterator::Iterator::for_each<diskann::vector::VectorList as core::ops::index::Index<usize>>::index (13 samples, 0.02%)alloc::vec::Vec<T,A>::push (22 samples, 0.04%)diskann::IndexGraph::out_neighbours (39 samples, 0.06%)std::sync::rwlock::RwLock<T>::read (39 samples, 0.06%)std::sys::sync::rwlock::futex::RwLock::read (39 samples, 0.06%)core::sync::atomic::AtomicU32::load (37 samples, 0.06%)core::sync::atomic::atomic_load (37 samples, 0.06%)alloc::vec::Vec<T,A>::insert (38 samples, 0.06%)core::intrinsics::copy (32 samples, 0.05%)[libc.so.6] (29 samples, 0.05%)<core::cmp::Ordering as core::cmp::PartialEq>::eq (20 samples, 0.03%)core::slice::<impl [T]>::binary_search_by (42 samples, 0.07%)diskann::NeighbourBuffer::insert (93 samples, 0.15%)core::core_arch::x86::avx::_mm256_add_ps (7 samples, 0.01%)core::core_arch::x86::avx::_mm256_extractf128_ps (7 samples, 0.01%)core::core_arch::x86::avx::_mm256_hadd_ps (8 samples, 0.01%)core::core_arch::x86::f16c::_mm256_cvtph_ps (503 samples, 0.82%)core::core_arch::x86::fma::_mm256_fmadd_ps (126 samples, 0.20%)core::core_arch::x86::sse::_mm_add_ps (13 samples, 0.02%)diskann::vector::fast_dot (1,307 samples, 2.12%)d..core::core_arch::x86::sse::_mm_prefetch (440 samples, 0.71%)hashbrown::map::make_hash (9 samples, 0.01%)core::hash::BuildHasher::hash_one (9 samples, 0.01%)<foldhash::fast::FoldHasher as core::hash::Hasher>::finish (9 samples, 0.01%)hashbrown::raw::RawTable<T,A>::reserve (9 samples, 0.01%)hashbrown::raw::bitmask::BitMask::lowest_set_bit (26 samples, 0.04%)<hashbrown::raw::bitmask::BitMaskIter as core::iter::traits::iterator::Iterator>::next (27 samples, 0.04%)core::option::Option<T>::is_none (11 samples, 0.02%)core::option::Option<T>::is_some (11 samples, 0.02%)hashbrown::raw::RawTableInner::fix_insert_slot (11 samples, 0.02%)hashbrown::raw::RawTableInner::is_bucket_full (9 samples, 0.01%)hashbrown::raw::h2 (19 samples, 0.03%)hashbrown::raw::sse2::Group::load (14 samples, 0.02%)core::core_arch::x86::sse2::_mm_loadu_si128 (14 samples, 0.02%)core::intrinsics::copy_nonoverlapping (14 samples, 0.02%)hashbrown::raw::RawTable<T,A>::find_or_find_insert_slot (146 samples, 0.24%)hashbrown::raw::RawTableInner::find_or_find_insert_slot_inner (136 samples, 0.22%)hashbrown::raw::sse2::Group::match_byte (9 samples, 0.01%)core::core_arch::x86::sse2::_mm_cmpeq_epi8 (9 samples, 0.01%)diskann::greedy_search_nosearchlist (1,727 samples, 2.80%)di..std::collections::hash::set::HashSet<T,S>::insert (181 samples, 0.29%)hashbrown::set::HashSet<T,S,A>::insert (181 samples, 0.29%)hashbrown::map::HashMap<K,V,S,A>::insert (179 samples, 0.29%)hashbrown::raw::RawTable<T,A>::insert_in_slot (23 samples, 0.04%)hashbrown::raw::RawTableInner::record_item_insert_at (18 samples, 0.03%)__lruvec_stat_mod_folio (16 samples, 0.03%)__pte_offset_map_lock (11 samples, 0.02%)folio_add_lru_vma (44 samples, 0.07%)lru_add_fn (30 samples, 0.05%)__folio_throttle_swaprate (17 samples, 0.03%)blk_cgroup_congested (16 samples, 0.03%)get_mem_cgroup_from_mm (16 samples, 0.03%)__count_memcg_events (13 samples, 0.02%)mem_cgroup_commit_charge (22 samples, 0.04%)__mem_cgroup_charge (78 samples, 0.13%)try_charge_memcg (34 samples, 0.06%)post_alloc_hook (24 samples, 0.04%)clear_page_erms (23 samples, 0.04%)get_page_from_freelist (96 samples, 0.16%)rmqueue_bulk (34 samples, 0.06%)folio_prealloc (205 samples, 0.33%)vma_alloc_folio_noprof (110 samples, 0.18%)alloc_pages_mpol_noprof (107 samples, 0.17%)__alloc_pages_noprof (103 samples, 0.17%)do_anonymous_page (300 samples, 0.49%)do_huge_pmd_anonymous_page (13 samples, 0.02%)vma_alloc_folio_noprof (8 samples, 0.01%)alloc_pages_mpol_noprof (8 samples, 0.01%)__alloc_pages_noprof (7 samples, 0.01%)handle_mm_fault (351 samples, 0.57%)lock_mm_and_find_vma (27 samples, 0.04%)find_vma (20 samples, 0.03%)mt_find (14 samples, 0.02%)asm_exc_page_fault (727 samples, 1.18%)exc_page_fault (402 samples, 0.65%)do_user_addr_fault (398 samples, 0.65%)copy_page_to_iter (824 samples, 1.34%)_copy_to_iter (822 samples, 1.33%)filemap_get_read_batch (14 samples, 0.02%)page_cache_ra_order (7 samples, 0.01%)filemap_get_pages (28 samples, 0.05%)entry_SYSCALL_64 (857 samples, 1.39%)do_syscall_64 (857 samples, 1.39%)__x64_sys_read (857 samples, 1.39%)vfs_read (857 samples, 1.39%)[[xfs]] (857 samples, 1.39%)[[xfs]] (857 samples, 1.39%)filemap_read (857 samples, 1.39%)<std::fs::File as std::io::Read>::read_to_end (859 samples, 1.39%)<&std::fs::File as std::io::Read>::read_to_end (859 samples, 1.39%)std::io::default_read_to_end (859 samples, 1.39%)<&std::fs::File as std::io::Read>::read_buf (859 samples, 1.39%)std::sys::pal::unix::fs::File::read_buf (859 samples, 1.39%)std::sys::pal::unix::fd::FileDesc::read_buf (859 samples, 1.39%)read (859 samples, 1.39%)_compound_head (20 samples, 0.03%)folio_remove_rmap_ptes (9 samples, 0.01%)__page_cache_release (16 samples, 0.03%)free_unref_page_commit (24 samples, 0.04%)free_pcppages_bulk (22 samples, 0.04%)tlb_flush_mmu (63 samples, 0.10%)free_pages_and_swap_cache (63 samples, 0.10%)folios_put_refs (57 samples, 0.09%)free_unref_folios (34 samples, 0.06%)diskann::load_file (966 samples, 1.57%)core::ptr::drop_in_place<alloc::vec::Vec<u8>> (102 samples, 0.17%)core::ptr::drop_in_place<alloc::raw_vec::RawVec<u8>> (102 samples, 0.17%)<alloc::raw_vec::RawVec<T,A> as core::ops::drop::Drop>::drop (102 samples, 0.17%)<alloc::alloc::Global as core::alloc::Allocator>::deallocate (102 samples, 0.17%)alloc::alloc::dealloc (102 samples, 0.17%)cfree (102 samples, 0.17%)__munmap (102 samples, 0.17%)entry_SYSCALL_64 (102 samples, 0.17%)do_syscall_64 (102 samples, 0.17%)__x64_sys_munmap (102 samples, 0.17%)__vm_munmap (102 samples, 0.17%)do_vmi_munmap (102 samples, 0.17%)do_vmi_align_munmap (102 samples, 0.17%)unmap_vmas (102 samples, 0.17%)unmap_page_range (102 samples, 0.17%)zap_pte_range (102 samples, 0.17%)_start (61,183 samples, 99.23%)_start__libc_start_main (61,183 samples, 99.23%)__libc_start_main[libc.so.6] (61,183 samples, 99.23%)[libc.so.6]main (61,183 samples, 99.23%)mainstd::rt::lang_start_internal (61,183 samples, 99.23%)std::rt::lang_start_internalstd::panic::catch_unwind (61,183 samples, 99.23%)std::panic::catch_unwindstd::panicking::try (61,183 samples, 99.23%)std::panicking::trystd::panicking::try::do_call (61,183 samples, 99.23%)std::panicking::try::do_callstd::rt::lang_start_internal::_{{closure}} (61,183 samples, 99.23%)std::rt::lang_start_internal::_{{closure}}std::panic::catch_unwind (61,183 samples, 99.23%)std::panic::catch_unwindstd::panicking::try (61,183 samples, 99.23%)std::panicking::trystd::panicking::try::do_call (61,183 samples, 99.23%)std::panicking::try::do_callcore::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once (61,183 samples, 99.23%)core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_oncestd::rt::lang_start::_{{closure}} (61,183 samples, 99.23%)std::rt::lang_start::_{{closure}}std::sys::backtrace::__rust_begin_short_backtrace (61,183 samples, 99.23%)std::sys::backtrace::__rust_begin_short_backtracecore::ops::function::FnOnce::call_once (61,183 samples, 99.23%)core::ops::function::FnOnce::call_oncediskann::main (61,183 samples, 99.23%)diskann::maindiskann::medioid (7 samples, 0.01%)asm_sysvec_apic_timer_interrupt (36 samples, 0.06%)diskann (61,652 samples, 99.99%)diskannall (61,656 samples, 100%) \ No newline at end of file diff --git a/diskann/opq_test.py b/diskann/opq_test.py new file mode 100644 index 0000000..a15dd99 --- /dev/null +++ b/diskann/opq_test.py @@ -0,0 +1,44 @@ +import numpy as np +import msgpack +import math +import torch +import faiss +import tqdm + +n_dims = 1152 +output_code_size = 64 +output_code_bits = 8 +output_codebook_size = 2**output_code_bits +n_dims_per_code = n_dims // output_code_size +dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32) +queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32) +device = "cpu" + +def pq_assign(centroids, batch): + quantized = torch.zeros_like(batch) + + # Assign to nearest centroid in each subspace + for dmin in range(0, n_dims, n_dims_per_code): + dmax = dmin + n_dims_per_code + similarities = torch.matmul(batch[:, dmin:dmax], centroids[:, dmin:dmax].T) + assignments = similarities.argmax(dim=1) + quantized[:, dmin:dmax] = centroids[assignments, dmin:dmax] + + return quantized + +with open("opq.msgpack", "rb") as f: + data = msgpack.unpack(f) + centroids = torch.tensor(data["centroids"], device=device).reshape(2**output_code_bits, n_dims) + projection = torch.tensor(data["transform"], device=device).reshape(n_dims, n_dims) + +vectors = torch.tensor(dataset, device=device) +queries = torch.tensor(queryset, device=device) + +sample_size = 64 +qsample = pq_assign(centroids, vectors[:sample_size] @ projection) +print(qsample) +print(vectors[:sample_size]) +exact_results = vectors[:sample_size] @ queries[0] +approx_results = qsample @ (projection @ queries[0]) +print(np.argsort(approx_results)) +print(np.argsort(exact_results)) diff --git a/diskann/rabitq.py b/diskann/rabitq.py new file mode 100644 index 0000000..b606200 --- /dev/null +++ b/diskann/rabitq.py @@ -0,0 +1,68 @@ +# https://arxiv.org/pdf/2405.12497 + +import numpy as np +import msgpack +import math +import tqdm + +n_dims = 1152 +output_dims = 64*8 +scale = 1 / math.sqrt(n_dims) +dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32) +queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32) +mean = np.mean(dataset, axis=0) + +centered_dataset = dataset - mean +norms = np.linalg.norm(centered_dataset, axis=1) +centered_dataset = centered_dataset / norms[:, np.newaxis] +print(centered_dataset) + +sample = centered_dataset[:64] + +def random_ortho(dim): + h = np.random.randn(dim, dim) + q, r = np.linalg.qr(h) + return q + +p = random_ortho(n_dims) # algorithm only uses the inverse of P, so just sample that directly +p = p[:output_dims, :] + +def quantize(datavecs): + xs = (p @ datavecs.T).T + quantized = xs > 0 + dequantized = scale * (2 * quantized - 1) + dots = np.sum(dequantized * xs, axis=1) # + return quantized, dots + +qsample, dots = quantize(sample) +print(qsample.sum(axis=1).mean()) +#print(dots) +#print(dots.mean()) + +def approx_dot(quantized_samples, dots, query): + mean_to_query = np.dot(mean, query) + print(mean_to_query) + dequantized = scale * (2 * quantized_samples - 1) + query_transformed = p @ query + o_bar_dot_q = np.sum(dequantized * query_transformed, axis=1) + return norms[:sample.shape[0]] * o_bar_dot_q * dots + mean_to_query + +print(norms) +approx_results = approx_dot(qsample, dots, queryset[0]) +exact_results = sample @ queryset[0] + +for x in zip(approx_results, exact_results): + print(*x) + +print(*[ f"{x:.2f}" for x in (approx_results - exact_results) / np.abs(exact_results).mean() ]) + +print(np.argsort(approx_results)) +print(np.argsort(exact_results)) + +with open("rabitq.msgpack", "wb") as f: + msgpack.pack({ + "mean": mean.flatten().tolist(), + "transform": p.flatten().tolist(), + "output_dims": output_dims, + "n_dims": n_dims + }, f) diff --git a/diskann/scalar_quantize.py b/diskann/scalar_quantize.py new file mode 100644 index 0000000..818610b --- /dev/null +++ b/diskann/scalar_quantize.py @@ -0,0 +1,167 @@ +import numpy as np +import msgpack +import math + +n_dims = 1152 +n_buckets = n_dims +#n_buckets = n_dims // 2 # we now have one quant scale per pair of components +#pair_separation = 16 # for efficient dot product computation, we need to have the second element of a pair exactly chunk_size after the first +n_dims_per_bucket = n_dims // n_buckets +data = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32) # sorry + +CUTOFF = 1e-3 / 2 + +print("computing quantiles") +smin = np.quantile(data, CUTOFF, axis=0) +smax = np.quantile(data, 1 - CUTOFF, axis=0) + +# naive O(n²) greedy algorithm +# probably overbuilt for the 2-components-per-bucket case but I'm not getting rid of it +def assign_buckets(): + import random + intervals = list(enumerate(zip(smin, smax))) + random.shuffle(intervals) + buckets = [ [ intervals.pop() ] for _ in range(n_buckets) ] + def bucket_cost(bucket): + bmin = min(cmin for id, (cmin, cmax) in bucket) + bmax = max(cmax for id, (cmin, cmax) in bucket) + #print("MIN", bmin, "MAX", bmax) + return sum(abs(cmin - bmin) + abs(cmax - bmax) for id, (cmin, cmax) in bucket) + while len(intervals): + for bucket in buckets: + def new_interval_cost(interval): + return bucket_cost(bucket + [interval[1]]) + i, interval = min(enumerate(intervals), key=new_interval_cost) + bucket.append(intervals.pop(i)) + return buckets + +ranges = smax - smin +# TODO: it is possible to do better assignment to buckets +#order = np.argsort(ranges) +print("bucket assignment") +order = np.arange(n_dims) # np.concatenate(np.stack([ [ id for id, (cmin, cmax) in bucket ] for bucket in assign_buckets() ])) + +bucket_ranges = [] +bucket_centres = [] +bucket_absmax = [] +bucket_gmins = [] + +for bucket_min in range(0, n_dims, n_dims_per_bucket): + bucket_max = bucket_min + n_dims_per_bucket + indices = order[bucket_min:bucket_max] + gmin = float(np.min(smin[indices])) + gmax = float(np.max(smax[indices])) + bucket_range = gmax - gmin + bucket_centre = (gmax + gmin) / 2 + bucket_gmins.append(gmin) + bucket_ranges.append(bucket_range) + bucket_centres.append(bucket_centre) + bucket_absmax.append(max(abs(gmin), abs(gmax))) + +print("determining scales") +scales = [] # multiply by float and convert to quantize +offsets = [] +q_offsets = [] # int16 value to add at dot time +q_scales = [] # rescales channel up at dot time; must be proportional(ish) to square of scale factor but NOT cause overflow in accumulation or PLMULLW +scale_factor_bound = float("inf") +for bucket in range(n_buckets): + step_size = bucket_ranges[bucket] / 255 + scales.append(1 / step_size) + q_offset = int(bucket_gmins[bucket] / step_size) + q_offsets.append(q_offset) + nsfb = (2**31 - 1) / (n_dims_per_bucket * abs((255**2) + 2 * q_offset * 255 + q_offset ** 2)) / 2 + # we are bounded both by overflow in accumulation and PLMULLW (u8 plus offset times scale factor) + scale_factor_bound = min(scale_factor_bound, nsfb, (2**15 - 1) // (q_offset + 255)) + offsets.append(bucket_gmins[bucket]) + +for bucket in range(n_buckets): + sfb = scale_factor_bound / max(map(lambda x: x ** 2, bucket_ranges)) + sf = (bucket_ranges[bucket]) ** 2 * sfb + q_scales.append(int(sf)) + +print(bucket_ranges, bucket_centres, bucket_absmax) +print(scales, offsets, q_offsets, q_scales) + +""" +interleave = np.concatenate([ + np.arange(0, n_dims, n_dims_per_bucket) + a + for a in range(n_dims_per_bucket) +]) +""" + +""" +interleave = np.arange(0, n_dims) +for base in range(0, n_dims, 2 * pair_separation): + interleave[base:base + pair_separation] = np.arange(base, base + 2 * pair_separation, 2) + interleave[base + pair_separation:base + 2 * pair_separation] = np.arange(base + 1, base + 2 * pair_separation + 1, 2) +""" + +#print(bucket_ranges, bucket_centres, order[interleave]) +#print(ranges[order][interleave].tolist()) +#print(ranges.tolist()) + +with open("quantizer.msgpack", "wb") as f: + msgpack.pack({ + "permutation": order.tolist(), + "offsets": offsets, + "scales": scales, + "q_offsets": q_offsets, + "q_scales": q_scales + }, f) + +def rquantize(vec): + out = np.zeros(len(vec), dtype=np.uint8) + for i, p in enumerate(order[interleave]): + bucket = p % n_buckets + raw = vec[i] + raw = (raw - offsets[bucket]) * scales[bucket] + raw = min(max(raw, 0.0), 255.0) + out[p] = round(raw) + return out + +def rdquantize(bytes): + vec = np.zeros(n_dims, dtype=np.float32) + for i, p in enumerate(order[interleave]): + bucket = p % n_buckets + raw = float(bytes[p]) + vec[i] = raw / scales[bucket] + offsets[bucket] + return vec + +def rdot(x, y): + xq_offsets = np.array(q_offsets, dtype=np.int16) + xq_scales = np.array(q_scales, dtype=np.int16) + assert x.shape == y.shape + assert x.dtype == np.uint8 == y.dtype + acc = 0 + for i in range(0, len(x), n_buckets): + x1 = x[i:i+n_buckets].astype(np.int16) + xq_offsets + y1 = y[i:i+n_buckets].astype(np.int16) + xq_offsets + x1 *= xq_scales + acc += np.dot(x1.astype(np.int32), y1.astype(np.int32)) + return acc + +def cmp(i, j): + return np.dot(data[i], data[j]) / rdot(rquantize(data[i]), rquantize(data[j])) + +def rdot_cmp(a, b): + x = rquantize(a) + y = rquantize(b) + a = a[order[interleave]] + b = b[order[interleave]] + xq_offsets = np.array(q_offsets, dtype=np.int16) + xq_scales = np.array(q_scales, dtype=np.int16) + assert x.shape == y.shape + assert x.dtype == np.uint8 == y.dtype + acc = 0 + for i in range(0, len(x), n_buckets): + x1 = x[i:i+n_buckets].astype(np.int16) + xq_offsets + y1 = y[i:i+n_buckets].astype(np.int16) + xq_offsets + x1 *= xq_scales + component = np.dot(x1.astype(np.int32), y1.astype(np.int32)) + a1 = a[i:i+n_buckets] + b1 = b[i:i+n_buckets] + component_exact = np.dot(a1, b1) + print(x1, a1, sep="\n") + print(component, component_exact, component / component_exact) + acc += component + return acc diff --git a/diskann/src/lib.rs b/diskann/src/lib.rs new file mode 100644 index 0000000..bd2a866 --- /dev/null +++ b/diskann/src/lib.rs @@ -0,0 +1,389 @@ +#![feature(pointer_is_aligned_to)] +#![feature(test)] + +extern crate test; + +use foldhash::{HashSet, HashMap, HashMapExt, HashSetExt}; +use fastrand::Rng; +use rayon::prelude::*; +use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, Mutex}; + +pub mod vector; +use vector::{dot, fast_dot, fast_dot_noprefetch, to_svector, VectorRef, SVector, VectorList}; + +// ParlayANN improves parallelism by not using locks like this and instead using smarter batch operations +// but I don't have enough cores that it matters +#[derive(Debug)] +pub struct IndexGraph { + pub graph: Vec>> +} + +impl IndexGraph { + pub fn random_r_regular(rng: &mut Rng, n: usize, r: usize, capacity: usize) -> Self { + let mut graph = Vec::with_capacity(n); + for _ in 0..n { + let mut adjacency = Vec::with_capacity(capacity); + for _ in 0..r { + adjacency.push(rng.u32(0..(n as u32))); + } + graph.push(RwLock::new(adjacency)); + } + IndexGraph { + graph + } + } + + pub fn empty(n: usize, capacity: usize) -> IndexGraph { + let mut graph = Vec::with_capacity(n); + for _ in 0..n { + graph.push(RwLock::new(Vec::with_capacity(capacity))); + } + IndexGraph { + graph + } + } + + fn out_neighbours(&self, pt: u32) -> RwLockReadGuard> { + self.graph[pt as usize].read().unwrap() + } + + fn out_neighbours_mut(&self, pt: u32) -> RwLockWriteGuard> { + self.graph[pt as usize].write().unwrap() + } +} + +#[derive(Clone, Copy, Debug)] +pub struct IndexBuildConfig { + pub r: usize, + pub r_cap: usize, + pub l: usize, + pub maxc: usize, + pub alpha: i64 +} + + +fn centroid(vecs: &VectorList) -> SVector { + let mut centroid = SVector::zero(vecs.d_emb); + + for (i, vec) in vecs.iter().enumerate() { + let weight = 1.0 / (i + 1) as f32; + centroid += (to_svector(vec) - ¢roid) * weight; + } + + centroid +} + +pub fn medioid(vecs: &VectorList) -> u32 { + let centroid = centroid(vecs).half(); + vecs.iter().map(|vec| dot(vec, &*centroid)).enumerate().max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()).unwrap().0 as u32 +} + +// neighbours list sorted by score descending +// TODO: this may actually be an awful datastructure +#[derive(Clone, Debug)] +pub struct NeighbourBuffer { + pub ids: Vec, + scores: Vec, + visited: Vec, + next_unvisited: Option, + size: usize +} + +impl NeighbourBuffer { + pub fn new(size: usize) -> Self { + NeighbourBuffer { + ids: Vec::with_capacity(size + 1), + scores: Vec::with_capacity(size + 1), + visited: Vec::with_capacity(size + 1), //bitvec::vec::BitVec::with_capacity(size), + next_unvisited: None, + size + } + } + + pub fn next_unvisited(&mut self) -> Option { + //println!("next_unvisited: {:?}", self); + let mut cur = self.next_unvisited? as usize; + let old_cur = cur; + self.visited[cur] = true; + while cur < self.len() && self.visited[cur] { + cur += 1; + } + if cur == self.len() { + self.next_unvisited = None; + } else { + self.next_unvisited = Some(cur as u32); + } + Some(self.ids[old_cur]) + } + + pub fn len(&self) -> usize { + self.ids.len() + } + + pub fn cap(&self) -> usize { + self.size + } + + pub fn insert(&mut self, id: u32, score: i64) { + if self.len() == self.cap() && self.scores[self.len() - 1] > score { + return; + } + + let loc = match self.scores.binary_search_by(|x| score.partial_cmp(&x).unwrap()) { + Ok(loc) => loc, + Err(loc) => loc + }; + + if self.ids.get(loc) == Some(&id) { + return; + } + + // slightly inefficient but we avoid unsafe code + self.ids.insert(loc, id); + self.scores.insert(loc, score); + self.visited.insert(loc, false); + self.ids.truncate(self.size); + self.scores.truncate(self.size); + self.visited.truncate(self.size); + + self.next_unvisited = Some(loc as u32); + } + + pub fn clear(&mut self) { + self.ids.clear(); + self.scores.clear(); + self.visited.clear(); + self.next_unvisited = None; + } +} + +pub struct Scratch { + visited: HashSet, + pub neighbour_buffer: NeighbourBuffer, + neighbour_pre_buffer: Vec, + visited_list: Vec<(u32, i64)>, + robust_prune_scratch_buffer: Vec<(usize, u32)> +} + +impl Scratch { + pub fn new(IndexBuildConfig { l, r, maxc, .. }: IndexBuildConfig) -> Self { + Scratch { + visited: HashSet::with_capacity(l * 8), + neighbour_buffer: NeighbourBuffer::new(l), + neighbour_pre_buffer: Vec::with_capacity(r), + visited_list: Vec::with_capacity(l * 8), + robust_prune_scratch_buffer: Vec::with_capacity(r) + } + } +} + +pub struct GreedySearchCounters { + pub distances: usize +} + +// Algorithm 1 from the DiskANN paper +// We support the dot product metric only, so we want to keep things with the HIGHEST dot product +pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs: &VectorList, graph: &IndexGraph, config: IndexBuildConfig) -> GreedySearchCounters { + scratch.visited.clear(); + scratch.neighbour_buffer.clear(); + scratch.visited_list.clear(); + + scratch.neighbour_buffer.insert(start, fast_dot_noprefetch(query, &vecs[start as usize])); + scratch.visited.insert(start); + + let mut counters = GreedySearchCounters { distances: 0 }; + + while let Some(pt) = scratch.neighbour_buffer.next_unvisited() { + //println!("pt {} {:?}", pt, graph.out_neighbours(pt)); + scratch.neighbour_pre_buffer.clear(); + for &neighbour in graph.out_neighbours(pt).iter() { + if scratch.visited.insert(neighbour) { + scratch.neighbour_pre_buffer.push(neighbour); + } + } + for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() { + let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO + let distance = fast_dot(query, &vecs[neighbour as usize], &vecs[next_neighbour as usize]); + counters.distances += 1; + scratch.neighbour_buffer.insert(neighbour, distance); + scratch.visited_list.push((neighbour, distance)); + } + } + + counters +} + +type CandidateList = Vec<(u32, i64)>; + +fn merge_existing_neighbours(candidates: &mut CandidateList, point: u32, neigh: &[u32], vecs: &VectorList, config: IndexBuildConfig) { + let p_vec = &vecs[point as usize]; + for (i, &n) in neigh.iter().enumerate() { + let dot = fast_dot(p_vec, &vecs[n as usize], &vecs[neigh[(i + 1) % neigh.len() as usize] as usize]); + candidates.push((n, dot)); + } +} + +// "Robust prune" algorithm, kind of +// The algorithm in the paper does not actually match the code as implemented in microsoft/DiskANN +// and that's slightly different from the one in ParlayANN for no reason +// This is closer to ParlayANN +fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec, vecs: &VectorList, config: IndexBuildConfig) { + neigh.clear(); + + let candidates = &mut scratch.visited_list; + + // distance low to high = score high to low + candidates.sort_unstable_by_key(|&(_id, score)| -score); + candidates.truncate(config.maxc); + + let mut candidate_index = 0; + while neigh.len() < config.r && candidate_index < candidates.len() { + let p_star = candidates[candidate_index].0; + candidate_index += 1; + if p_star == p || p_star == u32::MAX { + continue; + } + + neigh.push(p_star); + + scratch.robust_prune_scratch_buffer.clear(); + + // mark remaining candidates as not-to-be-used if "not much better than" current candidate + for i in (candidate_index+1)..candidates.len() { + let p_prime = candidates[i].0; + if p_prime != u32::MAX { + scratch.robust_prune_scratch_buffer.push((i, p_prime)); + } + } + + for (i, &(ci, p_prime)) in scratch.robust_prune_scratch_buffer.iter().enumerate() { + let next_vec = &vecs[scratch.robust_prune_scratch_buffer[(i + 1) % scratch.robust_prune_scratch_buffer.len()].0 as usize]; + let p_star_prime_score = fast_dot(&vecs[p_prime as usize], &vecs[p_star as usize], next_vec); + let p_prime_p_score = candidates[ci].1; + let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16; + + if alpha_times_p_star_prime_score >= p_prime_p_score { + candidates[ci].0 = u32::MAX; + } + } + } +} + +pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &VectorList, config: IndexBuildConfig) { + assert!(vecs.len() < u32::MAX as usize); + + let mut sigmas: Vec = (0..(vecs.len() as u32)).collect(); + rng.shuffle(&mut sigmas); + + let rng = Mutex::new(rng.fork()); + + //let scratch = &mut Scratch::new(config); + //let mut rng = rng.lock().unwrap(); + sigmas.into_par_iter().for_each_init(|| (Scratch::new(config), rng.lock().unwrap().fork()), |(scratch, rng), sigma_i| { + //sigmas.into_iter().for_each(|sigma_i| { + greedy_search(scratch, medioid, &vecs[sigma_i as usize], vecs, &graph, config); + + { + let n = graph.out_neighbours(sigma_i); + merge_existing_neighbours(&mut scratch.visited_list, sigma_i, &*n, vecs, config); + } + + { + let mut n = graph.out_neighbours_mut(sigma_i); + robust_prune(scratch, sigma_i, &mut *n, vecs, config); + } + + let neighbours = graph.out_neighbours(sigma_i).to_owned(); + for neighbour in neighbours { + let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour); + // To cut down pruning time slightly, allow accumulating more neighbours than usual limit + if neighbour_neighbours.len() == config.r_cap { + let mut n = neighbour_neighbours.to_vec(); + scratch.visited_list.clear(); + merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config); + merge_existing_neighbours(&mut scratch.visited_list, neighbour, &vec![sigma_i], vecs, config); + robust_prune(scratch, neighbour, &mut n, vecs, config); + } else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r_cap { + neighbour_neighbours.push(sigma_i); + } + } + }); +} + +// RoarGraph's AcquireNeighbours algorithm is actually almost identical to Vamana/DiskANN's RobustPrune, but with fixed α = 1.0. +// We replace Vamana's random initialization of the graph with Neighbourhood-Aware Projection from RoarGraph - there's no way to use a large enough +// query set that I would be confident in using *only* RoarGraph's algorithm +pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec>, query_knns_bwd: &Vec>, config: IndexBuildConfig, vecs: &VectorList) { + let mut sigmas: Vec = (0..(graph.graph.len() as u32)).collect(); + rng.shuffle(&mut sigmas); + + // Iterate through graph vertices in a random order + let rng = Mutex::new(rng.fork()); + sigmas.into_par_iter().for_each_init(|| (rng.lock().unwrap().fork(), Scratch::new(config)), |(rng, scratch), sigma_i| { + scratch.visited.clear(); + scratch.visited_list.clear(); + scratch.neighbour_pre_buffer.clear(); + for &query_neighbour in query_knns[sigma_i as usize].iter() { + for &projected_neighbour in query_knns_bwd[query_neighbour as usize].iter() { + if scratch.visited.insert(projected_neighbour) { + scratch.neighbour_pre_buffer.push(projected_neighbour); + } + } + } + rng.shuffle(&mut scratch.neighbour_pre_buffer); + scratch.neighbour_pre_buffer.truncate(config.maxc * 2); + for (i, &projected_neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() { + let score = fast_dot(&vecs[sigma_i as usize], &vecs[projected_neighbour as usize], &vecs[scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()] as usize]); + scratch.visited_list.push((projected_neighbour, score)); + } + let mut neighbours = graph.out_neighbours_mut(sigma_i); + robust_prune(scratch, sigma_i, &mut *neighbours, vecs, config); + }) +} + +pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec>, query_knns_bwd: Vec>, config: IndexBuildConfig) { + let mut sigmas: Vec = (0..(graph.graph.len() as u32)).collect(); + rng.shuffle(&mut sigmas); + + // Iterate through graph vertices in a random order + let rng = Mutex::new(rng.fork()); + sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| { + let mut neighbours = graph.out_neighbours_mut(sigma_i); + let mut i = 0; + while neighbours.len() < config.r_cap && i < 100 { + let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap(); + let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap(); + if !neighbours.contains(&projected_neighbour) { + neighbours.push(projected_neighbour); + } + i += 1; + } + }) +} + +pub fn random_fill_graph(rng: &mut Rng, graph: &mut IndexGraph, r: usize) { + let rng = Mutex::new(rng.fork()); + (0..graph.graph.len() as u32).into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, i| { + let mut neighbours = graph.out_neighbours_mut(i); + while neighbours.len() < r { + let next = rng.u32(0..(graph.graph.len() as u32)); + if !neighbours.contains(&next) { + neighbours.push(next); + } + } + }); +} + +pub struct Timer(&'static str, std::time::Instant); + +impl Timer { + pub fn new(name: &'static str) -> Self { + Timer(name, std::time::Instant::now()) + } +} + +impl Drop for Timer { + fn drop(&mut self) { + println!("{}: {:.2}s", self.0, self.1.elapsed().as_secs_f32()); + } +} diff --git a/diskann/src/main.rs b/diskann/src/main.rs new file mode 100644 index 0000000..d687424 --- /dev/null +++ b/diskann/src/main.rs @@ -0,0 +1,121 @@ +#![feature(test)] +#![feature(pointer_is_aligned_to)] + +extern crate test; + +use std::{io::Read, time::Instant}; +use anyhow::Result; +use half::f16; + +use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, dot, VectorList, self}, Timer}; +use simsimd::SpatialSimilarity; + +const D_EMB: usize = 1152; + +fn load_file(path: &str, trunc: Option) -> Result { + let mut input = std::fs::File::open(path)?; + let mut buf = Vec::new(); + input.read_to_end(&mut buf)?; + // TODO: this is not particularly efficient + let f16s = bytemuck::cast_slice::<_, f16>(&buf)[0..trunc.unwrap_or(buf.len()/2)].iter().copied().collect(); + Ok(VectorList::from_f16s(f16s, D_EMB)) +} + +const PQ_TEST_SIZE: usize = 1000; + +fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + + { + let file = std::fs::File::open("opq.msgpack")?; + let codec: vector::ProductQuantizer = rmp_serde::from_read(file)?; + let input = load_file("embeddings.bin", Some(D_EMB * PQ_TEST_SIZE))?.data.into_iter().map(|a| a.to_f32()).collect::>(); + let codes = codec.quantize_batch(&input); + println!("{:?}", codes); + let raw_query = load_file("query.bin", Some(D_EMB))?.data.into_iter().map(|a| a.to_f32()).collect::>(); + let query = codec.preprocess_query(&raw_query); + let mut real_scores = vec![]; + for i in 0..PQ_TEST_SIZE { + real_scores.push(SpatialSimilarity::dot(&raw_query, &input[i * D_EMB .. (i + 1) * D_EMB]).unwrap() as f32); + } + let pq_scores = codec.asymmetric_dot_product(&query, &codes); + for (x, y) in real_scores.iter().zip(pq_scores.iter()) { + println!("{} {} {} {}", x, y, x - y, (x - y) / x); + } + } + + let mut rng = fastrand::Rng::with_seed(1); + + let n = 100000; + let vecs = { + let _timer = Timer::new("loaded vectors"); + + &load_file("embeddings.bin", Some(D_EMB * n))? + }; + + let (graph, medioid) = { + let _timer = Timer::new("index built"); + + let mut config = IndexBuildConfig { + r: 64, + r_cap: 80, + l: 128, + maxc: 750, + alpha: 65536, + }; + + let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap); + + let medioid = medioid(&vecs); + + build_graph(&mut rng, &mut graph, medioid, &vecs, config); + config.alpha = 58000; + build_graph(&mut rng, &mut graph, medioid, &vecs, config); + + (graph, medioid) + }; + + let mut edge_ctr = 0; + + for adjlist in graph.graph.iter() { + edge_ctr += adjlist.read().unwrap().len(); + } + + println!("average degree: {}", edge_ctr as f32 / graph.graph.len() as f32); + + let time = Instant::now(); + let mut recall = 0; + let mut cmps_ctr = 0; + let mut cmps = vec![]; + + let mut config = IndexBuildConfig { + r: 64, + r_cap: 64, + l: 50, + alpha: 65536, + maxc: 0, + }; + + let mut scratch = Scratch::new(config); + + for (i, vec) in tqdm::tqdm(vecs.iter().enumerate()) { + let ctr = greedy_search(&mut scratch, medioid, &vec, &vecs, &graph, config); + cmps_ctr += ctr.distances; + cmps.push(ctr.distances); + if scratch.neighbour_buffer.ids[0] == (i as u32) { + recall += 1; + } + } + + cmps.sort(); + + let end = time.elapsed(); + + println!("recall@1: {} ({}/{})", recall as f32 / n as f32, recall, n); + println!("cmps: {} ({}/{})", cmps_ctr as f32 / n as f32, cmps_ctr, n); + println!("median comparisons: {}", cmps[cmps.len() / 2]); + //println!("brute force recall@1: {} ({}/{})", brute_force_recall as f32 / brute_force_queries as f32, brute_force_recall, brute_force_queries); + println!("{} QPS", n as f32 / end.as_secs_f32()); + + Ok(()) +} diff --git a/diskann/src/vector.rs b/diskann/src/vector.rs new file mode 100644 index 0000000..2283976 --- /dev/null +++ b/diskann/src/vector.rs @@ -0,0 +1,449 @@ +use core::f32; + +use half::f16; +use simsimd::SpatialSimilarity; +use fastrand::Rng; +use serde::{Serialize, Deserialize}; +use tracing_subscriber::field::RecordFields; + +#[derive(Debug, Clone)] +pub struct Vector(Vec); +#[derive(Debug, Clone)] +pub struct SVector(Vec); + +pub type VectorRef<'a> = &'a [f16]; +pub type QVectorRef<'a> = &'a [u8]; +pub type SVectorRef<'a> = &'a [f32]; + +impl SVector { + pub fn zero(d: usize) -> Self { + SVector(vec![0.0; d]) + } + pub fn half(&self) -> Vector { + Vector(self.0.iter().map(|a| f16::from_f32(*a)).collect()) + } +} + +fn box_muller(rng: &mut Rng) -> f32 { + loop { + let u = rng.f32(); + let v = rng.f32(); + let x = (v * std::f32::consts::TAU).cos() * (-2.0 * u.ln()).sqrt(); + if x.is_finite() { + return x; + } + } +} + +impl Vector { + pub fn zero(d: usize) -> Self { + Vector(vec![f16::from_f32(0.0); d]) + } + + pub fn randn(rng: &mut Rng, d: usize) -> Self { + Vector(Vec::from_iter((0..d).map(|_| f16::from_f32(box_muller(rng))))) + } +} + +// Floats are vaguely annoying and not sortable (trivially), so we mostly represent dot products as integers +const SCALE: f32 = 281474976710656.0; +const SCALE_F64: f64 = SCALE as f64; + +pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> f32 { + // safety is not real + (simsimd::f16::dot(unsafe { std::mem::transmute(x) }, unsafe { std::mem::transmute(y) }).unwrap()) as f32 +} + +pub fn to_svector(vec: VectorRef) -> SVector { + SVector(vec.iter().map(|a| a.to_f32()).collect()) +} + +impl<'a> std::ops::AddAssign> for SVector { + fn add_assign(&mut self, other: VectorRef<'a>) { + self.0.iter_mut().zip(other.iter()).for_each(|(a, b)| *a += b.to_f32()); + } +} + +impl std::ops::Div for SVector { + type Output = Self; + + fn div(self, b: f32) -> Self::Output { + SVector(self.0.iter().map(|a| a / b).collect()) + } +} + +impl std::ops::Deref for Vector { + type Target = [f16]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::Deref for SVector { + type Target = [f32]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::Add<&SVector> for SVector { + type Output = Self; + + fn add(self, other: &Self) -> Self::Output { + SVector(self.0.iter().zip(other.0.iter()).map(|(a, b)| a + b).collect()) + } +} + +impl std::ops::Sub<&SVector> for SVector { + type Output = Self; + + fn sub(self, other: &Self) -> Self::Output { + SVector(self.0.iter().zip(other.0.iter()).map(|(a, b)| a - b).collect()) + } +} + +impl std::ops::AddAssign for SVector { + fn add_assign(&mut self, other: Self) { + self.0.iter_mut().zip(other.0.iter()).for_each(|(a, b)| *a += b); + } +} + +impl std::ops::Mul for SVector { + type Output = Self; + + fn mul(self, other: f32) -> Self { + SVector(self.0.iter().map(|a| *a * other).collect()) + } +} + +#[derive(Debug, Clone)] +pub struct VectorList { + pub d_emb: usize, + pub length: usize, + pub data: Vec +} + +impl std::ops::Index for VectorList { + type Output = [f16]; + + fn index(&self, index: usize) -> &Self::Output { + &self.data[index * self.d_emb..(index + 1) * self.d_emb] + } +} + +pub struct VectorListIterator<'a> { + list: &'a VectorList, + index: usize +} + +impl<'a> Iterator for VectorListIterator<'a> { + type Item = VectorRef<'a>; + + fn next(&mut self) -> Option { + if self.index < self.list.len() { + let ret = &self.list[self.index]; + self.index += 1; + Some(ret) + } else { + None + } + } +} + + +impl VectorList { + pub fn len(&self) -> usize { + self.length + } + + pub fn iter(&self) -> VectorListIterator { + VectorListIterator { + list: self, + index: 0 + } + } + + pub fn empty(d: usize) -> Self { + VectorList { + d_emb: d, + length: 0, + data: Vec::new() + } + } + + pub fn from_f16s(f16s: Vec, d: usize) -> Self { + assert!(f16s.len() % d == 0); + VectorList { + d_emb: d, + length: f16s.len() / d, + data: f16s + } + } + + pub fn push(&mut self, vec: VectorRef) { + self.length += 1; + self.data.extend_from_slice(vec); + } +} + +// SimSIMD has its own, but ours prefetches concurrently, is unrolled more, ignores inconveniently-sized vectors and does a cheaper reduction +// Also, we return an int because floats are annoying (not Ord) +// On Tiger Lake (i5-1135G7) we have about a 3x performance advantage ignoring the prefetching +// (it would be better to use AVX512 for said CPU but this also has to run on Zen 3) +pub fn fast_dot(x: VectorRef, y: VectorRef, prefetch: VectorRef) -> i64 { + use std::arch::x86_64::*; + + debug_assert!(x.len() == y.len()); + debug_assert!(prefetch.len() == x.len()); + debug_assert!(x.len() % 64 == 0); + + // safety is not real + // it's probably fine I guess + unsafe { + let mut x_ptr = x.as_ptr(); + let mut y_ptr = y.as_ptr(); + let end = x_ptr.add(x.len()); + let mut prefetch_ptr = prefetch.as_ptr(); + + let mut acc1 = _mm256_setzero_ps(); + let mut acc2 = _mm256_setzero_ps(); + let mut acc3 = _mm256_setzero_ps(); + let mut acc4 = _mm256_setzero_ps(); + + while x_ptr < end { + // fetch chunks and prefetch next vector + let x1 = _mm256_loadu_si256(x_ptr as *const __m256i); + let y1 = _mm256_loadu_si256(y_ptr as *const __m256i); + let x2 = _mm256_loadu_si256(x_ptr.add(16) as *const __m256i); + let y2 = _mm256_loadu_si256(y_ptr.add(16) as *const __m256i); + // technically, we only have to do this once per cache line but I don't care enough to test every way to optimize this + _mm_prefetch(prefetch_ptr as *const i8, _MM_HINT_T0); + x_ptr = x_ptr.add(32); // move 16 f16s at a time + y_ptr = y_ptr.add(32); + prefetch_ptr = prefetch_ptr.add(32); + + // unpack f32 to f16 + let x1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 0)); + let x1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 1)); + let y1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 0)); + let y1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 1)); + let x2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 0)); + let x2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 1)); + let y2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 0)); + let y2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 1)); + + acc1 = _mm256_fmadd_ps(x1lo, y1lo, acc1); + acc2 = _mm256_fmadd_ps(x1hi, y1hi, acc2); + acc3 = _mm256_fmadd_ps(x2lo, y2lo, acc3); + acc4 = _mm256_fmadd_ps(x2hi, y2hi, acc4); + } + + // reduce + let acc1 = _mm256_add_ps(acc1, acc2); + let acc2 = _mm256_add_ps(acc3, acc4); + + let hsum = _mm256_hadd_ps(acc1, acc2); + let hsum_lo = _mm256_extractf128_ps(hsum, 0); + let hsum_hi = _mm256_extractf128_ps(hsum, 1); + let hsum = _mm_add_ps(hsum_lo, hsum_hi); + + let floatval = f32::from_bits(_mm_extract_ps::<0>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<1>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<2>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<3>(hsum) as u32); + (floatval * SCALE) as i64 + } +} + +// same as above, without prefetch pointer +pub fn fast_dot_noprefetch(x: VectorRef, y: VectorRef) -> i64 { + use std::arch::x86_64::*; + + debug_assert!(x.len() == y.len()); + debug_assert!(x.len() % 64 == 0); + + unsafe { + let mut x_ptr = x.as_ptr(); + let mut y_ptr = y.as_ptr(); + let end = x_ptr.add(x.len()); + + let mut acc1 = _mm256_setzero_ps(); + let mut acc2 = _mm256_setzero_ps(); + let mut acc3 = _mm256_setzero_ps(); + let mut acc4 = _mm256_setzero_ps(); + + while x_ptr < end { + let x1 = _mm256_loadu_si256(x_ptr as *const __m256i); + let y1 = _mm256_loadu_si256(y_ptr as *const __m256i); + let x2 = _mm256_loadu_si256(x_ptr.add(16) as *const __m256i); + let y2 = _mm256_loadu_si256(y_ptr.add(16) as *const __m256i); + x_ptr = x_ptr.add(32); + y_ptr = y_ptr.add(32); + + let x1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 0)); + let x1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 1)); + let y1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 0)); + let y1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 1)); + let x2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 0)); + let x2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 1)); + let y2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 0)); + let y2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 1)); + + acc1 = _mm256_fmadd_ps(x1lo, y1lo, acc1); + acc2 = _mm256_fmadd_ps(x1hi, y1hi, acc2); + acc3 = _mm256_fmadd_ps(x2lo, y2lo, acc3); + acc4 = _mm256_fmadd_ps(x2hi, y2hi, acc4); + } + + // reduce + let acc1 = _mm256_add_ps(acc1, acc2); + let acc2 = _mm256_add_ps(acc3, acc4); + + let hsum = _mm256_hadd_ps(acc1, acc2); + let hsum_lo = _mm256_extractf128_ps(hsum, 0); + let hsum_hi = _mm256_extractf128_ps(hsum, 1); + let hsum = _mm_add_ps(hsum_lo, hsum_hi); + + let floatval = f32::from_bits(_mm_extract_ps::<0>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<1>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<2>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<3>(hsum) as u32); + (floatval * SCALE) as i64 + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProductQuantizer { + centroids: Vec, + transform: Vec, // D*D orthonormal matrix + pub n_dims_per_code: usize, + pub n_dims: usize +} + +// chunk * centroid_index +pub struct DistanceLUT(Vec); + +impl ProductQuantizer { + pub fn apply_transform(&self, x: &[f32]) -> Vec { + let dim = self.n_dims; + let n_vectors = x.len() / dim; + let mut transformed = vec![0.0; n_vectors * dim]; + // transform_matrix (D * D) @ batch.T (D * B) + unsafe { + matrixmultiply::sgemm(dim, dim, n_vectors, 1.0, self.transform.as_ptr(), dim as isize, 1, x.as_ptr(), 1, dim as isize, 0.0, transformed.as_mut_ptr(), 1, dim as isize); + } + transformed + } + + pub fn quantize_batch(&self, x: &[f32]) -> Vec { + // x is B * D + let dim = self.n_dims; + assert_eq!(dim * dim, self.transform.len()); + let n_vectors = x.len() / dim; + let n_centroids = self.centroids.len() / dim; + assert!(n_centroids <= 256); + let transformed = self.apply_transform(&x); // B * D, as we write sgemm result in a weird order + let mut codes = vec![0; n_vectors * dim / self.n_dims_per_code]; + let vec_len_codes = dim / self.n_dims_per_code; + + // B * C buffer of similarity of each vector to each centroid, within subspace + let mut scratch = vec![0.0; n_vectors * n_centroids]; + + for i in 0..(dim / self.n_dims_per_code) { + let offset = i * self.n_dims_per_code; + // transformed_batch[:, range] (B * D_r) @ centroids[:, range].T (D_r * C) + unsafe { + matrixmultiply::sgemm(n_vectors, self.n_dims_per_code, n_centroids, 1.0, transformed.as_ptr().add(offset), dim as isize, 1, self.centroids.as_ptr().add(offset), 1, dim as isize, 0.0, scratch.as_mut_ptr(), n_centroids as isize, 1); + } + // assign this component to best centroid + for i_vec in 0..n_vectors { + let mut best = f32::NEG_INFINITY; + for i_centroid in 0..n_centroids { + let score = scratch[i_vec * n_centroids + i_centroid]; + if score > best { + best = score; + codes[i_vec * vec_len_codes + i] = i_centroid as u8; + } + } + } + } + codes + } + + // not particularly performance-sensitive right now; do unbatched + pub fn preprocess_query(&self, query: &[f32]) -> DistanceLUT { + let transformed = self.apply_transform(query); + let n_chunks = self.n_dims / self.n_dims_per_code; + let n_centroids = self.centroids.len() / self.n_dims; + let mut lut = Vec::with_capacity(n_chunks * n_centroids); + + for i in 0..n_chunks { + let vec_component = &transformed[i * self.n_dims_per_code..(i + 1) * self.n_dims_per_code]; + for j in 0..n_centroids { + let centroid = &self.centroids[j * self.n_dims..(j + 1) * self.n_dims]; + let centroid_component = ¢roid[i * self.n_dims_per_code..(i + 1) * self.n_dims_per_code]; + let score = SpatialSimilarity::dot(vec_component, centroid_component).unwrap(); + lut.push(score as f32); + } + } + + DistanceLUT(lut) + } + + // compute dot products of query against product-quantized vectors + pub fn asymmetric_dot_product(&self, query: &DistanceLUT, pq_vectors: &[u8]) -> Vec { + let n_chunks = self.n_dims / self.n_dims_per_code; + let n_vectors = pq_vectors.len() / n_chunks; + let mut scores = vec![0.0; n_vectors]; + let n_centroids = self.centroids.len() / self.n_dims; + + for i in 0..n_chunks { + for j in 0..n_vectors { + let code = pq_vectors[j * n_chunks + i]; + let chunk_score = query.0[i * n_centroids + code as usize]; + scores[j] += chunk_score; + } + } + + // I have no idea why but we somehow have significant degradation in search quality + // if this accumulates in integers. As such, do floats and convert at the end. + // I'm sure there are fascinating reasons for this, but God is dead, God remains dead, etc. + scores.into_iter().map(|x| (x * SCALE) as i64).collect() + } +} + +pub fn scale_dot_result(x: f64) -> i64 { + (x * SCALE_F64) as i64 +} + +#[cfg(test)] +mod bench { + use super::*; + use test::Bencher; + + #[bench] + fn bench_dot(be: &mut Bencher) { + let mut rng = fastrand::Rng::with_seed(1); + let a = Vector::randn(&mut rng, 1024); + let b = Vector::randn(&mut rng, 1024); + be.iter(|| { + dot(&a, &b) + }); + } + + #[bench] + fn bench_fastdot(be: &mut Bencher) { + let mut rng = fastrand::Rng::with_seed(1); + let a = Vector::randn(&mut rng, 1024); + let b = Vector::randn(&mut rng, 1024); + be.iter(|| { + fast_dot(&a, &b, &a) + }); + } + + #[bench] + fn bench_fastdot_noprefetch(be: &mut Bencher) { + let mut rng = fastrand::Rng::with_seed(1); + let a = Vector::randn(&mut rng, 1024); + let b = Vector::randn(&mut rng, 1024); + be.iter(|| { + fast_dot_noprefetch(&a, &b) + }); + } +} diff --git a/diskann/vec_dist.py b/diskann/vec_dist.py new file mode 100644 index 0000000..1fbe0a6 --- /dev/null +++ b/diskann/vec_dist.py @@ -0,0 +1,44 @@ +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +A = 0.6 +LOG_A = np.log(A) + +def scale(xs): + return np.sign(xs) * (np.log(np.abs(xs) + A) - LOG_A) + +n_dims = 1152 +n_used_dims = 32 +data = np.frombuffer(open("embeddings.bin", "rb").read(), dtype=np.float16).reshape(-1, n_dims).astype(np.float32) # TODO + +# Create histogram bins +n_bins = 256 +s = __import__("math").sqrt(n_dims) +hist_range = (-1.2, 1.2) +histogram_data = np.zeros((n_used_dims, n_bins)) + +# Calculate histograms for each dimension +for dim in range(n_used_dims): + dbd = data[:, dim] + dbd = (dbd - np.mean(dbd)) / np.std(dbd) + dbd = scale(dbd) + hist, _ = np.histogram(dbd, bins=n_bins, range=hist_range, density=True) + histogram_data[dim] = hist + +# Create heatmap +plt.figure(figsize=(12, 8)) +sns.heatmap(histogram_data, + cmap='viridis', + xticklabels=np.linspace(hist_range[0], hist_range[1], n_bins), + yticklabels=range(n_used_dims), + cbar_kws={'label': 'Density'}) + +plt.xlabel('Value') +plt.ylabel('Dimension') +plt.title('Distribution Heatmap of First 16 Dimensions') + +# Adjust layout to prevent label cutoff +plt.tight_layout() + +plt.show()