1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-04-07 19:26:39 +00:00

release early draft of index code

This commit is contained in:
osmarks 2024-12-31 23:05:48 +00:00
parent 512b776e10
commit e0cf65204b
12 changed files with 2668 additions and 1 deletions

7
.gitignore vendored
View File

@ -9,4 +9,9 @@ node_modules/*
node_modules
*sqlite3*
thumbtemp
mse-test-db-small
mse-test-db-small
clipfront2/static/bg*
diskann/target
*.bin
*.msgpack
*/flamegraph.svg

748
diskann/Cargo.lock generated Normal file
View File

@ -0,0 +1,748 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "anyhow"
version = "1.0.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775"
[[package]]
name = "autocfg"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]]
name = "bitvec"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [
"funty",
"radium",
"tap",
"wyz",
]
[[package]]
name = "bytemuck"
version = "1.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a"
dependencies = [
"bytemuck_derive",
]
[[package]]
name = "bytemuck_derive"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "cc"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47"
dependencies = [
"shlex",
]
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "crossbeam-deque"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
[[package]]
name = "crossterm"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e64e6c0fbe2c17357405f7c758c1ef960fce08bdfb2c03d88d2a18d7e09c4b67"
dependencies = [
"bitflags 1.3.2",
"crossterm_winapi",
"libc",
"mio",
"parking_lot",
"signal-hook",
"signal-hook-mio",
"winapi",
]
[[package]]
name = "crossterm_winapi"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b"
dependencies = [
"winapi",
]
[[package]]
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "diskann"
version = "0.1.0"
dependencies = [
"anyhow",
"bitvec",
"bytemuck",
"fastrand",
"foldhash",
"half",
"matrixmultiply",
"rayon",
"rmp-serde",
"serde",
"simsimd",
"tqdm",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "either"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]]
name = "fastrand"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4"
[[package]]
name = "foldhash"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
[[package]]
name = "funty"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "half"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [
"bytemuck",
"cfg-if",
"crunchy",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "libc"
version = "0.2.164"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
[[package]]
name = "lock_api"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
dependencies = [
"autocfg",
"scopeguard",
]
[[package]]
name = "log"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]]
name = "matrixmultiply"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]]
name = "mio"
version = "0.8.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
dependencies = [
"libc",
"log",
"wasi",
"windows-sys",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "once_cell"
version = "1.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "parking_lot"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
"windows-targets 0.52.6",
]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "pin-project-lite"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff"
[[package]]
name = "proc-macro2"
version = "1.0.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
dependencies = [
"proc-macro2",
]
[[package]]
name = "radium"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]]
name = "redox_syscall"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
dependencies = [
"bitflags 2.6.0",
]
[[package]]
name = "rmp"
version = "0.8.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4"
dependencies = [
"byteorder",
"num-traits",
"paste",
]
[[package]]
name = "rmp-serde"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db"
dependencies = [
"byteorder",
"rmp",
"serde",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "serde"
version = "1.0.215"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.215"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]]
name = "shlex"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801"
dependencies = [
"libc",
"signal-hook-registry",
]
[[package]]
name = "signal-hook-mio"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd"
dependencies = [
"libc",
"mio",
"signal-hook",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1"
dependencies = [
"libc",
]
[[package]]
name = "simsimd"
version = "6.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be2ad0164e13e58a994d3dd1ff57d44cee87c445708e3acea7ad4f03a47092ce"
dependencies = [
"cc",
]
[[package]]
name = "smallvec"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "syn"
version = "2.0.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "thread_local"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
dependencies = [
"cfg-if",
"once_cell",
]
[[package]]
name = "tqdm"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa2d2932240205a99b65f15d9861992c95fbb8c9fb280b3a1f17a92db6dc611f"
dependencies = [
"anyhow",
"crossterm",
"once_cell",
]
[[package]]
name = "tracing"
version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tracing-core"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [
"nu-ansi-term",
"sharded-slab",
"smallvec",
"thread_local",
"tracing-core",
"tracing-log",
]
[[package]]
name = "unicode-ident"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "windows-targets"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
dependencies = [
"windows_aarch64_gnullvm 0.48.5",
"windows_aarch64_msvc 0.48.5",
"windows_i686_gnu 0.48.5",
"windows_i686_msvc 0.48.5",
"windows_x86_64_gnu 0.48.5",
"windows_x86_64_gnullvm 0.48.5",
"windows_x86_64_msvc 0.48.5",
]
[[package]]
name = "windows-targets"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
"windows_aarch64_gnullvm 0.52.6",
"windows_aarch64_msvc 0.52.6",
"windows_i686_gnu 0.52.6",
"windows_i686_gnullvm",
"windows_i686_msvc 0.52.6",
"windows_x86_64_gnu 0.52.6",
"windows_x86_64_gnullvm 0.52.6",
"windows_x86_64_msvc 0.52.6",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "wyz"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
dependencies = [
"tap",
]

28
diskann/Cargo.toml Normal file
View File

@ -0,0 +1,28 @@
[package]
name = "diskann"
version = "0.1.0"
edition = "2021"
[dependencies]
half = { version = "2", features = ["bytemuck"] }
fastrand = "2"
tracing = "0.1"
tracing-subscriber = "0.3"
simsimd = "6"
foldhash = "0.1"
bitvec = "1"
tqdm = "0.7"
anyhow = "1"
bytemuck = { version = "1", features = ["extern_crate_alloc"] }
serde = { version = "1", features = ["derive"] }
rmp-serde = "1"
rayon = "1"
matrixmultiply = "0.3"
[lib]
name = "diskann"
path = "src/lib.rs"
[[bin]]
name = "diskann"
path = "src/main.rs"

113
diskann/aopq_train.py Normal file
View File

@ -0,0 +1,113 @@
import numpy as np
import msgpack
import math
import torch
from torch import autograd
import faiss
import tqdm

# Embedding dimensionality and product-quantization geometry:
# 64 codes of 8 bits each, so each code covers n_dims // 64 = 18 components.
n_dims = 1152
output_code_size = 64
output_code_bits = 8
output_codebook_size = 2**output_code_bits
n_dims_per_code = n_dims // output_code_size

# Raw fp16 embeddings on disk; truncated to 100k rows and upcast for training.
dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
device = "cpu"

# Approximate index over the QUERY set so each dataset vector can find the
# queries nearest to it (used to weight the quantization loss below).
index = faiss.index_factory(n_dims, "HNSW32,SQfp16", faiss.METRIC_INNER_PRODUCT)
index.train(queryset)
index.add(queryset)
print("index ready")

# For every dataset vector, record the indices of its T nearest queries.
T = 64
nearby_query_indices = torch.zeros((dataset.shape[0], T), dtype=torch.int32)
SEARCH_BATCH_SIZE = 1024
for i in range(0, len(dataset), SEARCH_BATCH_SIZE):
    res = index.search(dataset[i:i+SEARCH_BATCH_SIZE], T)
    # faiss returns (distances, labels); keep only the labels.
    nearby_query_indices[i:i+SEARCH_BATCH_SIZE] = torch.tensor(res[1])
print("query indices ready")
def pq_assign(centroids, batch):
    """Product-quantize `batch` against `centroids`.

    For each subspace of width n_dims_per_code, every row of `batch` is
    replaced by the slice of its most-similar centroid (inner-product
    argmax). Returns a tensor shaped like `batch`.
    """
    out = torch.zeros_like(batch)
    for lo in range(0, n_dims, n_dims_per_code):
        hi = lo + n_dims_per_code
        # nearest centroid per row, by inner product within this subspace
        best = (batch[:, lo:hi] @ centroids[:, lo:hi].T).argmax(dim=1)
        out[:, lo:hi] = centroids[best, lo:hi]
    return out
# OOD-DiskANN (https://arxiv.org/abs/2211.12850) uses a more complicated scheme because it uses L2 norm
# We only care about inner product so our quantization error (wrt a query) is just abs(dot(query, centroid - vector))
# Directly optimize for this (wrt top queries; it might actually be better to use a random sample instead?)
def partition(vectors, centroids, projection, opt, queries, nearby_query_indices, k, max_iter=100, batch_size=4096):
    # Gradient-descend `centroids` (the only tensor `opt` owns) to minimize
    # query-aware quantization error over `max_iter` passes of the dataset.
    # `projection` rotates vectors before assignment; `k` is unused here.
    n_vectors = len(vectors)
    perm = torch.randperm(n_vectors, device=device)  # NOTE(review): computed but never used — confirm intent
    t = tqdm.trange(max_iter)
    for iter in t:
        total_loss = 0
        # NOTE(review): zero_grad is per-epoch while backward()/step() run
        # per-batch below, so gradients accumulate across batches within an
        # epoch — confirm this is intentional rather than a misplaced call.
        opt.zero_grad(set_to_none=True)
        for i in range(0, n_vectors, batch_size):
            loss = torch.tensor(0.0, device=device)
            batch = vectors[i:i+batch_size] @ projection
            quantized = pq_assign(centroids, batch)
            residuals = batch - quantized
            # for each index in our set of nearby queries
            for j in range(0, nearby_query_indices.shape[1]):
                queries_for_batch_j = queries[nearby_query_indices[i:i+batch_size, j]]
                # minimize quantiation error in direction of query, i.e. mean abs(dot(query, centroid - vector))
                # PyTorch won't do batched dot products cleanly, to spite me. Do componentwise multiplication and reduce.
                sg_errs = (queries_for_batch_j * residuals).sum(dim=-1)
                loss += torch.mean(torch.abs(sg_errs))
            total_loss += loss.detach().item()
            loss.backward()
            opt.step()
        t.set_description(f"loss: {total_loss:.4f}")
def random_ortho(dim):
    """Sample a random orthogonal (dim, dim) matrix.

    QR-factorizes a standard Gaussian matrix and keeps the orthonormal Q
    factor.
    """
    gauss = torch.randn(dim, dim, device=device)
    return torch.linalg.qr(gauss).Q
# non-parametric OPQ algorithm (roughly)
# https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/opq_tr.pdf
projection = random_ortho(n_dims)
vectors = torch.tensor(dataset, device=device)
queries = torch.tensor(queryset, device=device)
# Initialize the codebook with a random sample of dataset vectors, then make
# it the (sole) trainable parameter.
perm = torch.randperm(len(vectors), device=device)
centroids = vectors[perm[:output_codebook_size]]
centroids.requires_grad = True
opt = torch.optim.Adam([centroids], lr=0.001)
# Alternate between centroid optimization and projection refitting, as in
# OPQ's non-parametric alternation.
for i in range(30):
    # update centroids to minimize query-aware quantization loss
    partition(vectors, centroids, projection, opt, queries, nearby_query_indices, output_codebook_size, max_iter=8)
    # compute new projection as R = VU^T from XY^T = USV^T (SVD)
    # where X is dataset vectors, Y is quantized dataset vectors
    with torch.no_grad():
        y = pq_assign(centroids, vectors)
        # paper uses D*N and not N*D in its descriptions for whatever reason (so we transpose when they don't)
        u, s, vt = torch.linalg.svd(vectors.T @ y)
        projection = vt.T @ u.T
print("done")
# Persist the trained codebook and rotation as flat float lists.
with open("opq.msgpack", "wb") as f:
    msgpack.pack({
        "centroids": centroids.detach().cpu().numpy().flatten().tolist(),
        "transform": projection.cpu().numpy().flatten().tolist(),
        "n_dims_per_code": n_dims_per_code,
        "n_dims": n_dims
    }, f)

491
diskann/flamegraph.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 139 KiB

44
diskann/opq_test.py Normal file
View File

@ -0,0 +1,44 @@
import numpy as np
import msgpack
import math
import torch
import faiss
import tqdm

# Must match the geometry used by aopq_train.py when producing opq.msgpack.
n_dims = 1152
output_code_size = 64
output_code_bits = 8
output_codebook_size = 2**output_code_bits
n_dims_per_code = n_dims // output_code_size

# Same raw fp16 inputs as training, truncated and upcast.
dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
device = "cpu"
def pq_assign(centroids, batch):
    """Snap each row of `batch` to its nearest centroid, subspace by subspace.

    "Nearest" is by inner product within each n_dims_per_code-wide slice;
    the output has the same shape as `batch`.
    """
    result = torch.zeros_like(batch)
    for start in range(0, n_dims, n_dims_per_code):
        stop = start + n_dims_per_code
        sub = batch[:, start:stop]
        # argmax of similarity picks the winning centroid per row
        winners = torch.matmul(sub, centroids[:, start:stop].T).argmax(dim=1)
        result[:, start:stop] = centroids[winners, start:stop]
    return result
# Load the trained codebook and rotation produced by aopq_train.py.
with open("opq.msgpack", "rb") as f:
    data = msgpack.unpack(f)
centroids = torch.tensor(data["centroids"], device=device).reshape(2**output_code_bits, n_dims)
projection = torch.tensor(data["transform"], device=device).reshape(n_dims, n_dims)
vectors = torch.tensor(dataset, device=device)
queries = torch.tensor(queryset, device=device)

# Quantize a small sample and compare approximate vs exact ranking for one query.
sample_size = 64
qsample = pq_assign(centroids, vectors[:sample_size] @ projection)
print(qsample)
print(vectors[:sample_size])
exact_results = vectors[:sample_size] @ queries[0]
# NOTE(review): data was rotated as v @ projection (rows times P); for an
# orthogonal P the matching query transform would be queries[0] @ projection
# (= P.T @ q), but this applies projection @ queries[0] — confirm intended.
approx_results = qsample @ (projection @ queries[0])
print(np.argsort(approx_results))
print(np.argsort(exact_results))

68
diskann/rabitq.py Normal file
View File

@ -0,0 +1,68 @@
# https://arxiv.org/pdf/2405.12497
import numpy as np
import msgpack
import math
import tqdm

# Binary quantization: n_dims floats -> output_dims sign bits (64 bytes).
n_dims = 1152
output_dims = 64*8
scale = 1 / math.sqrt(n_dims)

dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)

# Center the dataset and normalize rows to unit norm; the per-row norms are
# kept so dot products can be rescaled back later.
mean = np.mean(dataset, axis=0)
centered_dataset = dataset - mean
norms = np.linalg.norm(centered_dataset, axis=1)
centered_dataset = centered_dataset / norms[:, np.newaxis]
print(centered_dataset)
sample = centered_dataset[:64]
def random_ortho(dim):
    """Return a random orthogonal (dim, dim) matrix.

    Draws a standard Gaussian matrix and QR-factorizes it, keeping the
    orthonormal Q factor.
    """
    gauss = np.random.randn(dim, dim)
    return np.linalg.qr(gauss)[0]
p = random_ortho(n_dims) # algorithm only uses the inverse of P, so just sample that directly
# Keep only the first output_dims rows: project n_dims -> output_dims.
p = p[:output_dims, :]
def quantize(datavecs):
    """One-bit quantization of row vectors under the random projection `p`.

    Returns (bits, dots): `bits` holds the sign pattern of the projected
    vectors, and `dots` holds <o_bar, o> per row — the inner product between
    each projected vector and its scaled sign reconstruction.
    """
    projected = (p @ datavecs.T).T
    bits = projected > 0
    # map {0, 1} bits to {-scale, +scale}
    reconstructed = scale * (2 * bits - 1)
    dots = (reconstructed * projected).sum(axis=1)  # <o_bar, o>
    return bits, dots
# Quantize the 64-row sample; mean bit count ~ output_dims / 2 if balanced.
qsample, dots = quantize(sample)
print(qsample.sum(axis=1).mean())
#print(dots)
#print(dots.mean())
def approx_dot(quantized_samples, dots, query):
    # Estimate dot(original_vector, query) for each quantized sample row,
    # undoing the centering (mean) and per-row normalization (norms).
    # NOTE(review): reads module globals `norms` and `sample`, so it only
    # works for the first sample.shape[0] dataset rows — confirm callers.
    mean_to_query = np.dot(mean, query)
    print(mean_to_query)  # debug leftover
    dequantized = scale * (2 * quantized_samples - 1)
    query_transformed = p @ query
    o_bar_dot_q = np.sum(dequantized * query_transformed, axis=1)
    # NOTE(review): the RaBitQ estimator normalizes by <o_bar, o>; here `dots`
    # is multiplied rather than divided — confirm against the paper.
    return norms[:sample.shape[0]] * o_bar_dot_q * dots + mean_to_query
print(norms)
# Compare approximate vs exact dot products for the first query, both as raw
# pairs and as relative errors, then check whether the induced ranking agrees.
approx_results = approx_dot(qsample, dots, queryset[0])
exact_results = sample @ queryset[0]
for x in zip(approx_results, exact_results):
    print(*x)
print(*[ f"{x:.2f}" for x in (approx_results - exact_results) / np.abs(exact_results).mean() ])
print(np.argsort(approx_results))
print(np.argsort(exact_results))
# Persist the quantizer parameters (mean + projection) for the Rust side.
with open("rabitq.msgpack", "wb") as f:
    msgpack.pack({
        "mean": mean.flatten().tolist(),
        "transform": p.flatten().tolist(),
        "output_dims": output_dims,
        "n_dims": n_dims
    }, f)

167
diskann/scalar_quantize.py Normal file
View File

@ -0,0 +1,167 @@
import numpy as np
import msgpack
import math

# One quantization bucket per dimension in the current configuration
# (n_dims_per_bucket == 1); the commented variants group dimensions.
n_dims = 1152
n_buckets = n_dims
#n_buckets = n_dims // 2 # we now have one quant scale per pair of components
#pair_separation = 16 # for efficient dot product computation, we need to have the second element of a pair exactly chunk_size after the first
n_dims_per_bucket = n_dims // n_buckets
data = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32) # sorry

# Clip the top/bottom tail of each component's distribution when choosing
# quantization ranges (robust min/max via quantiles).
CUTOFF = 1e-3 / 2
print("computing quantiles")
smin = np.quantile(data, CUTOFF, axis=0)
smax = np.quantile(data, 1 - CUTOFF, axis=0)
# naive O(n²) greedy algorithm
# probably overbuilt for the 2-components-per-bucket case but I'm not getting rid of it
def assign_buckets():
    """Greedily group per-dimension (min, max) intervals into n_buckets buckets.

    Each bucket is seeded with one random interval, then intervals are added
    round-robin, each bucket taking the interval that least expands its hull
    cost. Returns a list of n_buckets lists of (dim_index, (min, max)) entries.
    Reads module globals smin, smax, n_buckets.
    """
    import random
    intervals = list(enumerate(zip(smin, smax)))
    random.shuffle(intervals)
    # seed each bucket with one interval
    buckets = [ [ intervals.pop() ] for _ in range(n_buckets) ]
    def bucket_cost(bucket):
        # total slack: how far each interval's endpoints sit from the bucket hull
        bmin = min(cmin for id, (cmin, cmax) in bucket)
        bmax = max(cmax for id, (cmin, cmax) in bucket)
        return sum(abs(cmin - bmin) + abs(cmax - bmax) for id, (cmin, cmax) in bucket)
    while len(intervals):
        for bucket in buckets:
            if not intervals:
                # fixed: the leftover interval count need not be a multiple of
                # n_buckets — without this guard, min() below raised ValueError
                # on an empty sequence partway through the round.
                break
            def new_interval_cost(interval):
                # `interval` is an (index, (dim, (min, max))) pair from enumerate
                return bucket_cost(bucket + [interval[1]])
            i, interval = min(enumerate(intervals), key=new_interval_cost)
            bucket.append(intervals.pop(i))
    return buckets
ranges = smax - smin
# TODO: it is possible to do better assignment to buckets
#order = np.argsort(ranges)
print("bucket assignment")
order = np.arange(n_dims) # np.concatenate(np.stack([ [ id for id, (cmin, cmax) in bucket ] for bucket in assign_buckets() ]))

# Per-bucket hull statistics over the dimensions assigned to each bucket.
bucket_ranges = []
bucket_centres = []
bucket_absmax = []
bucket_gmins = []
for bucket_min in range(0, n_dims, n_dims_per_bucket):
    bucket_max = bucket_min + n_dims_per_bucket
    indices = order[bucket_min:bucket_max]
    gmin = float(np.min(smin[indices]))
    gmax = float(np.max(smax[indices]))
    bucket_range = gmax - gmin
    bucket_centre = (gmax + gmin) / 2
    bucket_gmins.append(gmin)
    bucket_ranges.append(bucket_range)
    bucket_centres.append(bucket_centre)
    bucket_absmax.append(max(abs(gmin), abs(gmax)))

print("determining scales")
scales = [] # multiply by float and convert to quantize
offsets = []
q_offsets = [] # int16 value to add at dot time
q_scales = [] # rescales channel up at dot time; must be proportional(ish) to square of scale factor but NOT cause overflow in accumulation or PLMULLW
scale_factor_bound = float("inf")
for bucket in range(n_buckets):
    # 255 quantization steps spanning the bucket's value range
    step_size = bucket_ranges[bucket] / 255
    scales.append(1 / step_size)
    q_offset = int(bucket_gmins[bucket] / step_size)
    q_offsets.append(q_offset)
    nsfb = (2**31 - 1) / (n_dims_per_bucket * abs((255**2) + 2 * q_offset * 255 + q_offset ** 2)) / 2
    # we are bounded both by overflow in accumulation and PLMULLW (u8 plus offset times scale factor)
    scale_factor_bound = min(scale_factor_bound, nsfb, (2**15 - 1) // (q_offset + 255))
    offsets.append(bucket_gmins[bucket])
for bucket in range(n_buckets):
    # NOTE(review): sfb is loop-invariant (max over all bucket_ranges) and
    # could be hoisted out of this loop.
    sfb = scale_factor_bound / max(map(lambda x: x ** 2, bucket_ranges))
    sf = (bucket_ranges[bucket]) ** 2 * sfb
    q_scales.append(int(sf))
print(bucket_ranges, bucket_centres, bucket_absmax)
print(scales, offsets, q_offsets, q_scales)
# Alternative interleave layouts, currently disabled:
"""
interleave = np.concatenate([
    np.arange(0, n_dims, n_dims_per_bucket) + a
    for a in range(n_dims_per_bucket)
])
"""
"""
interleave = np.arange(0, n_dims)
for base in range(0, n_dims, 2 * pair_separation):
    interleave[base:base + pair_separation] = np.arange(base, base + 2 * pair_separation, 2)
    interleave[base + pair_separation:base + 2 * pair_separation] = np.arange(base + 1, base + 2 * pair_separation + 1, 2)
"""
# fixed: `interleave` was only assigned inside the disabled variants above,
# so rquantize/rdquantize below raised NameError; default to the identity
# layout, which matches the active identity `order`.
interleave = np.arange(n_dims)
#print(bucket_ranges, bucket_centres, order[interleave])
#print(ranges[order][interleave].tolist())
#print(ranges.tolist())
# Persist the quantizer parameters for the Rust side.
with open("quantizer.msgpack", "wb") as f:
    msgpack.pack({
        "permutation": order.tolist(),
        "offsets": offsets,
        "scales": scales,
        "q_offsets": q_offsets,
        "q_scales": q_scales
    }, f)
def rquantize(vec):
    """Quantise a float vector into per-bucket u8 codes (storage order)."""
    # NOTE(review): relies on module-level `interleave`, whose definitions are
    # commented out above — confirm it is defined before this runs.
    quantized = np.zeros(len(vec), dtype=np.uint8)
    for i, p in enumerate(order[interleave]):
        bucket = p % n_buckets
        value = (vec[i] - offsets[bucket]) * scales[bucket]
        clamped = min(max(value, 0.0), 255.0)
        quantized[p] = round(clamped)
    return quantized
def rdquantize(bytes):
    """Dequantise u8 codes back to an approximate float32 vector (inverse of rquantize up to rounding)."""
    result = np.zeros(n_dims, dtype=np.float32)
    for i, p in enumerate(order[interleave]):
        bucket = p % n_buckets
        result[i] = float(bytes[p]) / scales[bucket] + offsets[bucket]
    return result
def rdot(x, y):
    """Integer dot product mirroring the SIMD kernel: i16 offset add,
    i16 (wrapping) per-channel scale multiply, i32 multiply-accumulate."""
    offs = np.array(q_offsets, dtype=np.int16)
    scls = np.array(q_scales, dtype=np.int16)
    assert x.shape == y.shape
    assert x.dtype == np.uint8 == y.dtype
    total = 0
    for start in range(0, len(x), n_buckets):
        xs = x[start:start + n_buckets].astype(np.int16) + offs
        ys = y[start:start + n_buckets].astype(np.int16) + offs
        # Deliberately stays in int16: models PMULLW truncation behaviour.
        xs *= scls
        total += np.dot(xs.astype(np.int32), ys.astype(np.int32))
    return total
def cmp(i, j):
    # Fidelity check: ratio of the exact float dot product to the integer
    # approximation for dataset rows i and j (roughly constant if the
    # quantiser is good).
    return np.dot(data[i], data[j]) / rdot(rquantize(data[i]), rquantize(data[j]))
def rdot_cmp(a, b):
    # Debug helper: recompute the integer dot product chunk by chunk, printing
    # each integer chunk score next to the exact float chunk score and their
    # ratio. Returns the total integer accumulator.
    x = rquantize(a)
    y = rquantize(b)
    # Permute the float inputs into quantised storage order so chunks line up.
    a = a[order[interleave]]
    b = b[order[interleave]]
    xq_offsets = np.array(q_offsets, dtype=np.int16)
    xq_scales = np.array(q_scales, dtype=np.int16)
    assert x.shape == y.shape
    assert x.dtype == np.uint8 == y.dtype
    acc = 0
    for i in range(0, len(x), n_buckets):
        x1 = x[i:i+n_buckets].astype(np.int16) + xq_offsets
        y1 = y[i:i+n_buckets].astype(np.int16) + xq_offsets
        # i16 multiply, as the SIMD kernel (PMULLW) would do.
        x1 *= xq_scales
        component = np.dot(x1.astype(np.int32), y1.astype(np.int32))
        a1 = a[i:i+n_buckets]
        b1 = b[i:i+n_buckets]
        component_exact = np.dot(a1, b1)
        print(x1, a1, sep="\n")
        print(component, component_exact, component / component_exact)
        acc += component
    return acc

389
diskann/src/lib.rs Normal file
View File

@ -0,0 +1,389 @@
#![feature(pointer_is_aligned_to)]
#![feature(test)]
extern crate test;
use foldhash::{HashSet, HashMap, HashMapExt, HashSetExt};
use fastrand::Rng;
use rayon::prelude::*;
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, Mutex};
pub mod vector;
use vector::{dot, fast_dot, fast_dot_noprefetch, to_svector, VectorRef, SVector, VectorList};
// ParlayANN improves parallelism by not using locks like this and instead using smarter batch operations
// but I don't have enough cores that it matters
#[derive(Debug)]
pub struct IndexGraph {
    /// Out-adjacency lists, one `RwLock` per vertex so parallel build workers
    /// can lock vertices independently.
    pub graph: Vec<RwLock<Vec<u32>>>
}
impl IndexGraph {
    /// Random directed graph: every vertex gets `r` out-edges drawn uniformly
    /// (duplicates and self-loops possible); lists preallocated to `capacity`.
    pub fn random_r_regular(rng: &mut Rng, n: usize, r: usize, capacity: usize) -> Self {
        let graph = (0..n)
            .map(|_| {
                let mut adjacency = Vec::with_capacity(capacity);
                adjacency.extend((0..r).map(|_| rng.u32(0..(n as u32))));
                RwLock::new(adjacency)
            })
            .collect();
        IndexGraph { graph }
    }

    /// Edgeless graph with `n` vertices; lists preallocated to `capacity`.
    pub fn empty(n: usize, capacity: usize) -> IndexGraph {
        let graph = (0..n)
            .map(|_| RwLock::new(Vec::with_capacity(capacity)))
            .collect();
        IndexGraph { graph }
    }

    /// Shared read access to the out-neighbour list of `pt`.
    fn out_neighbours(&self, pt: u32) -> RwLockReadGuard<Vec<u32>> {
        self.graph[pt as usize].read().unwrap()
    }

    /// Exclusive write access to the out-neighbour list of `pt`.
    fn out_neighbours_mut(&self, pt: u32) -> RwLockWriteGuard<Vec<u32>> {
        self.graph[pt as usize].write().unwrap()
    }
}
#[derive(Clone, Copy, Debug)]
pub struct IndexBuildConfig {
    /// Target out-degree produced by pruning.
    pub r: usize,
    /// Hard cap on an adjacency list during build; lists may grow past `r`
    /// up to this before being re-pruned.
    pub r_cap: usize,
    /// Beam width (search list size) for greedy search.
    pub l: usize,
    /// Maximum number of candidates considered by robust_prune.
    pub maxc: usize,
    /// Pruning slack α in 16.16 fixed point (65536 == 1.0; see the `>> 16`
    /// in robust_prune).
    pub alpha: i64
}
/// Streaming centroid of all vectors, accumulated in f32.
fn centroid(vecs: &VectorList) -> SVector {
    let mut centroid = SVector::zero(vecs.d_emb);
    for (i, vec) in vecs.iter().enumerate() {
        // Running mean: after i+1 items, `centroid` is their exact average.
        let weight = 1.0 / (i + 1) as f32;
        centroid += (to_svector(vec) - &centroid) * weight;
    }
    centroid
}
/// Index of the vector with the highest dot product against the dataset
/// centroid — the "medoid" (spelled `medioid` throughout this crate), used as
/// the fixed entry point for greedy search.
pub fn medioid(vecs: &VectorList) -> u32 {
    let centroid = centroid(vecs).half();
    vecs.iter().map(|vec| dot(vec, &*centroid)).enumerate().max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()).unwrap().0 as u32
}
// Fixed-capacity candidate list, kept sorted by score descending, with a
// visited flag per entry and a cursor to the first unvisited entry. This
// backs the beam ("L list") of greedy search.
// TODO: this may actually be an awful datastructure
#[derive(Clone, Debug)]
pub struct NeighbourBuffer {
    pub ids: Vec<u32>,
    scores: Vec<i64>,
    visited: Vec<bool>,
    /// Index of the first unvisited entry, if any.
    next_unvisited: Option<u32>,
    size: usize
}

impl NeighbourBuffer {
    pub fn new(size: usize) -> Self {
        NeighbourBuffer {
            ids: Vec::with_capacity(size + 1),
            scores: Vec::with_capacity(size + 1),
            visited: Vec::with_capacity(size + 1), //bitvec::vec::BitVec::with_capacity(size),
            next_unvisited: None,
            size
        }
    }

    /// Pop the best unvisited entry: mark it visited, advance the cursor to
    /// the next unvisited entry, and return the popped id.
    pub fn next_unvisited(&mut self) -> Option<u32> {
        //println!("next_unvisited: {:?}", self);
        let mut cur = self.next_unvisited? as usize;
        let old_cur = cur;
        self.visited[cur] = true;
        while cur < self.len() && self.visited[cur] {
            cur += 1;
        }
        if cur == self.len() {
            self.next_unvisited = None;
        } else {
            self.next_unvisited = Some(cur as u32);
        }
        Some(self.ids[old_cur])
    }

    pub fn len(&self) -> usize {
        self.ids.len()
    }

    pub fn cap(&self) -> usize {
        self.size
    }

    /// Insert `id` with `score`, keeping the list sorted by score descending
    /// and truncated to capacity.
    pub fn insert(&mut self, id: u32, score: i64) {
        if self.len() == self.cap() && self.scores[self.len() - 1] > score {
            return;
        }
        // `scores` is descending, so the comparator is reversed.
        let loc = match self.scores.binary_search_by(|x| score.partial_cmp(&x).unwrap()) {
            Ok(loc) => loc,
            Err(loc) => loc
        };
        if self.ids.get(loc) == Some(&id) {
            return;
        }
        // slightly inefficient but we avoid unsafe code
        self.ids.insert(loc, id);
        self.scores.insert(loc, score);
        self.visited.insert(loc, false);
        self.ids.truncate(self.size);
        self.scores.truncate(self.size);
        self.visited.truncate(self.size);
        // BUGFIX: the cursor must point at the FIRST unvisited entry. The old
        // code unconditionally set it to `loc`, which skipped any unvisited
        // entry sitting at an earlier (better-scored) index, so the search
        // could permanently ignore its best candidates.
        self.next_unvisited = Some(match self.next_unvisited {
            Some(old) => old.min(loc as u32),
            None => loc as u32,
        });
    }

    pub fn clear(&mut self) {
        self.ids.clear();
        self.scores.clear();
        self.visited.clear();
        self.next_unvisited = None;
    }
}
/// Reusable per-thread working memory for greedy_search / robust_prune, so
/// nothing is reallocated per query.
pub struct Scratch {
    /// Points already scored during the current search.
    visited: HashSet<u32>,
    /// The beam: best candidates found so far, score-descending.
    pub neighbour_buffer: NeighbourBuffer,
    /// Unvisited neighbours of the current point, gathered before scoring.
    neighbour_pre_buffer: Vec<u32>,
    /// Every (id, score) pair computed during the search; reused as the
    /// candidate pool for pruning.
    visited_list: Vec<(u32, i64)>,
    /// (candidate index, id) pairs for the occlusion pass in robust_prune.
    robust_prune_scratch_buffer: Vec<(usize, u32)>
}
impl Scratch {
    /// Allocate scratch buffers sized for the given build configuration.
    pub fn new(config: IndexBuildConfig) -> Self {
        Scratch {
            visited: HashSet::with_capacity(config.l * 8),
            neighbour_buffer: NeighbourBuffer::new(config.l),
            neighbour_pre_buffer: Vec::with_capacity(config.r),
            visited_list: Vec::with_capacity(config.l * 8),
            robust_prune_scratch_buffer: Vec::with_capacity(config.r)
        }
    }
}
/// Instrumentation returned by `greedy_search`.
pub struct GreedySearchCounters {
    /// Number of distance (dot product) evaluations performed.
    pub distances: usize
}
// Algorithm 1 from the DiskANN paper
// We support the dot product metric only, so we want to keep things with the HIGHEST dot product
/// Beam-search `graph` from `start` for points similar to `query`. The beam
/// lives in `scratch.neighbour_buffer` (best-first); every point scored along
/// the way is also appended to `scratch.visited_list` so robust_prune can
/// reuse it as a candidate pool. Returns the distance-computation count.
pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs: &VectorList, graph: &IndexGraph, config: IndexBuildConfig) -> GreedySearchCounters {
    scratch.visited.clear();
    scratch.neighbour_buffer.clear();
    scratch.visited_list.clear();

    scratch.neighbour_buffer.insert(start, fast_dot_noprefetch(query, &vecs[start as usize]));
    scratch.visited.insert(start);

    let mut counters = GreedySearchCounters { distances: 0 };

    while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
        //println!("pt {} {:?}", pt, graph.out_neighbours(pt));
        // Collect not-yet-visited neighbours first so the scoring loop can
        // prefetch one vector ahead.
        scratch.neighbour_pre_buffer.clear();
        for &neighbour in graph.out_neighbours(pt).iter() {
            if scratch.visited.insert(neighbour) {
                scratch.neighbour_pre_buffer.push(neighbour);
            }
        }
        for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
            // Prefetch the next candidate's vector while scoring this one
            // (wraps around to the first at the end of the list).
            let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
            let distance = fast_dot(query, &vecs[neighbour as usize], &vecs[next_neighbour as usize]);
            counters.distances += 1;
            scratch.neighbour_buffer.insert(neighbour, distance);
            scratch.visited_list.push((neighbour, distance));
        }
    }

    counters
}
/// (id, score) candidate pool, score in SCALE fixed point.
type CandidateList = Vec<(u32, i64)>;

/// Score `point` against each id in `neigh` and append the (id, score) pairs
/// to `candidates`, prefetching the next neighbour's vector during each dot
/// product.
fn merge_existing_neighbours(candidates: &mut CandidateList, point: u32, neigh: &[u32], vecs: &VectorList, config: IndexBuildConfig) {
    let p_vec = &vecs[point as usize];
    for (i, &n) in neigh.iter().enumerate() {
        let dot = fast_dot(p_vec, &vecs[n as usize], &vecs[neigh[(i + 1) % neigh.len() as usize] as usize]);
        candidates.push((n, dot));
    }
}
// "Robust prune" algorithm, kind of
// The algorithm in the paper does not actually match the code as implemented in microsoft/DiskANN
// and that's slightly different from the one in ParlayANN for no reason
// This is closer to ParlayANN
fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &VectorList, config: IndexBuildConfig) {
neigh.clear();
let candidates = &mut scratch.visited_list;
// distance low to high = score high to low
candidates.sort_unstable_by_key(|&(_id, score)| -score);
candidates.truncate(config.maxc);
let mut candidate_index = 0;
while neigh.len() < config.r && candidate_index < candidates.len() {
let p_star = candidates[candidate_index].0;
candidate_index += 1;
if p_star == p || p_star == u32::MAX {
continue;
}
neigh.push(p_star);
scratch.robust_prune_scratch_buffer.clear();
// mark remaining candidates as not-to-be-used if "not much better than" current candidate
for i in (candidate_index+1)..candidates.len() {
let p_prime = candidates[i].0;
if p_prime != u32::MAX {
scratch.robust_prune_scratch_buffer.push((i, p_prime));
}
}
for (i, &(ci, p_prime)) in scratch.robust_prune_scratch_buffer.iter().enumerate() {
let next_vec = &vecs[scratch.robust_prune_scratch_buffer[(i + 1) % scratch.robust_prune_scratch_buffer.len()].0 as usize];
let p_star_prime_score = fast_dot(&vecs[p_prime as usize], &vecs[p_star as usize], next_vec);
let p_prime_p_score = candidates[ci].1;
let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16;
if alpha_times_p_star_prime_score >= p_prime_p_score {
candidates[ci].0 = u32::MAX;
}
}
}
}
/// Vamana/DiskANN build pass (Algorithm 2): visit points in random order; for
/// each, greedy-search from the medioid, robust-prune the search's visited
/// set into the point's out-neighbours, then add backward edges — re-pruning
/// any neighbour whose list has reached `r_cap`.
pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &VectorList, config: IndexBuildConfig) {
    // Point ids must fit in u32 (u32::MAX is a sentinel in robust_prune).
    assert!(vecs.len() < u32::MAX as usize);

    let mut sigmas: Vec<u32> = (0..(vecs.len() as u32)).collect();
    rng.shuffle(&mut sigmas);

    let rng = Mutex::new(rng.fork());

    //let scratch = &mut Scratch::new(config);
    //let mut rng = rng.lock().unwrap();
    sigmas.into_par_iter().for_each_init(|| (Scratch::new(config), rng.lock().unwrap().fork()), |(scratch, rng), sigma_i| {
        //sigmas.into_iter().for_each(|sigma_i| {
        // Fills scratch.visited_list with every point scored en route.
        greedy_search(scratch, medioid, &vecs[sigma_i as usize], vecs, &graph, config);

        {
            let n = graph.out_neighbours(sigma_i);
            merge_existing_neighbours(&mut scratch.visited_list, sigma_i, &*n, vecs, config);
        }

        {
            let mut n = graph.out_neighbours_mut(sigma_i);
            robust_prune(scratch, sigma_i, &mut *n, vecs, config);
        }

        // Backward edges: every selected neighbour should link back to sigma_i.
        let neighbours = graph.out_neighbours(sigma_i).to_owned();
        for neighbour in neighbours {
            let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour);
            // To cut down pruning time slightly, allow accumulating more neighbours than usual limit
            if neighbour_neighbours.len() == config.r_cap {
                let mut n = neighbour_neighbours.to_vec();
                scratch.visited_list.clear();
                merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config);
                merge_existing_neighbours(&mut scratch.visited_list, neighbour, &vec![sigma_i], vecs, config);
                robust_prune(scratch, neighbour, &mut n, vecs, config);
                // BUGFIX: write the pruned list back. Previously the pruned
                // result in `n` was silently discarded, so a full neighbour
                // list was never shrunk and never gained the backward edge —
                // this whole branch was a no-op.
                *neighbour_neighbours = n;
            } else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r_cap {
                neighbour_neighbours.push(sigma_i);
            }
        }
    });
}
// RoarGraph's AcquireNeighbours algorithm is actually almost identical to Vamana/DiskANN's RobustPrune, but with fixed α = 1.0.
// We replace Vamana's random initialization of the graph with Neighbourhood-Aware Projection from RoarGraph - there's no way to use a large enough
// query set that I would be confident in using *only* RoarGraph's algorithm
/// Initialise each point's neighbour list from the bipartite query/base kNN
/// structure: candidates are base points reachable in two hops (my nearest
/// queries → their nearest base points), randomly subsampled to 2·maxc,
/// scored, then robust-pruned.
pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec<Vec<u32>>, query_knns_bwd: &Vec<Vec<u32>>, config: IndexBuildConfig, vecs: &VectorList) {
    let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
    rng.shuffle(&mut sigmas);

    // Iterate through graph vertices in a random order
    let rng = Mutex::new(rng.fork());
    sigmas.into_par_iter().for_each_init(|| (rng.lock().unwrap().fork(), Scratch::new(config)), |(rng, scratch), sigma_i| {
        scratch.visited.clear();
        scratch.visited_list.clear();
        scratch.neighbour_pre_buffer.clear();
        // Gather unique 2-hop candidates.
        for &query_neighbour in query_knns[sigma_i as usize].iter() {
            for &projected_neighbour in query_knns_bwd[query_neighbour as usize].iter() {
                if scratch.visited.insert(projected_neighbour) {
                    scratch.neighbour_pre_buffer.push(projected_neighbour);
                }
            }
        }
        // Random subsample to bound pruning cost.
        rng.shuffle(&mut scratch.neighbour_pre_buffer);
        scratch.neighbour_pre_buffer.truncate(config.maxc * 2);
        // Score candidates, prefetching the next candidate's vector each time.
        for (i, &projected_neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
            let score = fast_dot(&vecs[sigma_i as usize], &vecs[projected_neighbour as usize], &vecs[scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()] as usize]);
            scratch.visited_list.push((projected_neighbour, score));
        }
        let mut neighbours = graph.out_neighbours_mut(sigma_i);
        robust_prune(scratch, sigma_i, &mut *neighbours, vecs, config);
    })
}
/// Top up each point's neighbour list with random 2-hop picks through the
/// query kNN lists: at most 100 attempts per point, stopping once the list
/// reaches `r_cap` edges.
pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig) {
    let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
    rng.shuffle(&mut sigmas);
    // Iterate through graph vertices in a random order
    let rng = Mutex::new(rng.fork());
    sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| {
        let mut neighbours = graph.out_neighbours_mut(sigma_i);
        let mut i = 0;
        while neighbours.len() < config.r_cap && i < 100 {
            // Random query neighbour of sigma_i, then a random base point
            // near that query.
            let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap();
            let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap();
            if !neighbours.contains(&projected_neighbour) {
                neighbours.push(projected_neighbour);
            }
            i += 1;
        }
    })
}
/// Pad every neighbour list up to `r` with uniformly random vertices.
/// NOTE(review): loops forever if `r` exceeds the number of distinct vertices,
/// and self-loops are possible — confirm both are acceptable to callers.
pub fn random_fill_graph(rng: &mut Rng, graph: &mut IndexGraph, r: usize) {
    let rng = Mutex::new(rng.fork());
    (0..graph.graph.len() as u32).into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, i| {
        let mut neighbours = graph.out_neighbours_mut(i);
        while neighbours.len() < r {
            let next = rng.u32(0..(graph.graph.len() as u32));
            if !neighbours.contains(&next) {
                neighbours.push(next);
            }
        }
    });
}
/// Scope timer: records a name and a start instant, and prints
/// "<name>: <seconds>s" when dropped.
pub struct Timer(&'static str, std::time::Instant);

impl Timer {
    /// Start timing now under the given label.
    pub fn new(name: &'static str) -> Self {
        Self(name, std::time::Instant::now())
    }
}

impl Drop for Timer {
    fn drop(&mut self) {
        let elapsed = self.1.elapsed().as_secs_f32();
        println!("{}: {:.2}s", self.0, elapsed);
    }
}

121
diskann/src/main.rs Normal file
View File

@ -0,0 +1,121 @@
#![feature(test)]
#![feature(pointer_is_aligned_to)]
extern crate test;
use std::{io::Read, time::Instant};
use anyhow::Result;
use half::f16;
use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, dot, VectorList, self}, Timer};
use simsimd::SpatialSimilarity;
const D_EMB: usize = 1152;
/// Load a flat little-endian f16 file as a `VectorList` of dimension D_EMB.
/// `trunc` limits the number of f16 *scalars* (not vectors) read.
/// NOTE(review): assumes the byte length is even and properly aligned for
/// f16 (bytemuck panics otherwise) — fine for this tool's own dumps.
fn load_file(path: &str, trunc: Option<usize>) -> Result<VectorList> {
    let mut input = std::fs::File::open(path)?;
    let mut buf = Vec::new();
    input.read_to_end(&mut buf)?;
    // TODO: this is not particularly efficient
    let f16s = bytemuck::cast_slice::<_, f16>(&buf)[0..trunc.unwrap_or(buf.len()/2)].iter().copied().collect();
    Ok(VectorList::from_f16s(f16s, D_EMB))
}
// Number of embeddings used for the product-quantiser accuracy check.
const PQ_TEST_SIZE: usize = 1000;

fn main() -> Result<()> {
    tracing_subscriber::fmt::init();

    // PQ sanity check: compare exact vs product-quantised dot products on the
    // first PQ_TEST_SIZE embeddings against a single query.
    // NOTE(review): `pq_scores` comes back as scaled integers while
    // `real_scores` are f32 — confirm the comparison arithmetic below is
    // doing what is intended.
    {
        let file = std::fs::File::open("opq.msgpack")?;
        let codec: vector::ProductQuantizer = rmp_serde::from_read(file)?;
        let input = load_file("embeddings.bin", Some(D_EMB * PQ_TEST_SIZE))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
        let codes = codec.quantize_batch(&input);
        println!("{:?}", codes);
        let raw_query = load_file("query.bin", Some(D_EMB))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
        let query = codec.preprocess_query(&raw_query);
        let mut real_scores = vec![];
        for i in 0..PQ_TEST_SIZE {
            real_scores.push(SpatialSimilarity::dot(&raw_query, &input[i * D_EMB .. (i + 1) * D_EMB]).unwrap() as f32);
        }
        let pq_scores = codec.asymmetric_dot_product(&query, &codes);
        for (x, y) in real_scores.iter().zip(pq_scores.iter()) {
            println!("{} {} {} {}", x, y, x - y, (x - y) / x);
        }
    }

    let mut rng = fastrand::Rng::with_seed(1);
    let n = 100000;

    let vecs = {
        let _timer = Timer::new("loaded vectors");
        &load_file("embeddings.bin", Some(D_EMB * n))?
    };

    // Build the index: random r-regular init, then two Vamana passes — first
    // with α = 1.0 (65536 in 16.16 fixed point), then with a different α.
    let (graph, medioid) = {
        let _timer = Timer::new("index built");

        let mut config = IndexBuildConfig {
            r: 64,
            r_cap: 80,
            l: 128,
            maxc: 750,
            alpha: 65536,
        };

        let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap);

        let medioid = medioid(&vecs);

        build_graph(&mut rng, &mut graph, medioid, &vecs, config);
        config.alpha = 58000;
        build_graph(&mut rng, &mut graph, medioid, &vecs, config);

        (graph, medioid)
    };

    let mut edge_ctr = 0;

    for adjlist in graph.graph.iter() {
        edge_ctr += adjlist.read().unwrap().len();
    }

    println!("average degree: {}", edge_ctr as f32 / graph.graph.len() as f32);

    let time = Instant::now();
    let mut recall = 0;
    let mut cmps_ctr = 0;
    let mut cmps = vec![];

    let mut config = IndexBuildConfig {
        r: 64,
        r_cap: 64,
        l: 50,
        alpha: 65536,
        maxc: 0,
    };

    let mut scratch = Scratch::new(config);

    // Self-recall benchmark: query every indexed vector; ideally its own id
    // comes back as the top result.
    for (i, vec) in tqdm::tqdm(vecs.iter().enumerate()) {
        let ctr = greedy_search(&mut scratch, medioid, &vec, &vecs, &graph, config);
        cmps_ctr += ctr.distances;
        cmps.push(ctr.distances);
        if scratch.neighbour_buffer.ids[0] == (i as u32) {
            recall += 1;
        }
    }

    cmps.sort();

    let end = time.elapsed();

    println!("recall@1: {} ({}/{})", recall as f32 / n as f32, recall, n);
    println!("cmps: {} ({}/{})", cmps_ctr as f32 / n as f32, cmps_ctr, n);
    println!("median comparisons: {}", cmps[cmps.len() / 2]);
    //println!("brute force recall@1: {} ({}/{})", brute_force_recall as f32 / brute_force_queries as f32, brute_force_recall, brute_force_queries);
    println!("{} QPS", n as f32 / end.as_secs_f32());

    Ok(())
}

449
diskann/src/vector.rs Normal file
View File

@ -0,0 +1,449 @@
use core::f32;
use half::f16;
use simsimd::SpatialSimilarity;
use fastrand::Rng;
use serde::{Serialize, Deserialize};
use tracing_subscriber::field::RecordFields;
/// Owned half-precision (f16) embedding vector.
#[derive(Debug, Clone)]
pub struct Vector(Vec<f16>);
/// Owned single-precision (f32) vector, used where f16 would lose precision
/// (e.g. centroid accumulation).
#[derive(Debug, Clone)]
pub struct SVector(Vec<f32>);
/// Borrowed f16 vector slice.
pub type VectorRef<'a> = &'a [f16];
/// Borrowed quantised (u8-coded) vector slice.
pub type QVectorRef<'a> = &'a [u8];
/// Borrowed f32 vector slice.
pub type SVectorRef<'a> = &'a [f32];
impl SVector {
pub fn zero(d: usize) -> Self {
SVector(vec![0.0; d])
}
pub fn half(&self) -> Vector {
Vector(self.0.iter().map(|a| f16::from_f32(*a)).collect())
}
}
/// Sample a standard normal via the Box-Muller transform, retrying the rare
/// non-finite results (e.g. ln(0) when the uniform sample is exactly zero).
fn box_muller(rng: &mut Rng) -> f32 {
    loop {
        let u = rng.f32();
        let v = rng.f32();
        let angle = (v * std::f32::consts::TAU).cos();
        let radius = (-2.0 * u.ln()).sqrt();
        let sample = angle * radius;
        if sample.is_finite() {
            return sample;
        }
    }
}
impl Vector {
    /// All-zeros f16 vector of dimension `d`.
    pub fn zero(d: usize) -> Self {
        Vector(vec![f16::from_f32(0.0); d])
    }

    /// Vector of `d` i.i.d. standard-normal samples, rounded to f16.
    pub fn randn(rng: &mut Rng, d: usize) -> Self {
        let mut samples = Vec::with_capacity(d);
        for _ in 0..d {
            samples.push(f16::from_f32(box_muller(rng)));
        }
        Vector(samples)
    }
}
// Floats are vaguely annoying and not sortable (trivially), so we mostly represent dot products as integers
// Fixed-point scale: 2^48, leaving plenty of fractional precision in an i64.
const SCALE: f32 = 281474976710656.0;
const SCALE_F64: f64 = SCALE as f64;

/// Exact f16 dot product via simsimd, as f32.
pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> f32 {
    // safety is not real
    // (transmutes between half::f16 and simsimd's f16 wrapper — same layout assumed)
    (simsimd::f16::dot(unsafe { std::mem::transmute(x) }, unsafe { std::mem::transmute(y) }).unwrap()) as f32
}

/// Widen an f16 slice into an owned f32 SVector.
pub fn to_svector(vec: VectorRef) -> SVector {
    SVector(vec.iter().map(|a| a.to_f32()).collect())
}
// Elementwise vector arithmetic. Every non-assign op allocates a fresh Vec;
// these run only on cold paths (centroid computation, setup).

// Accumulate an f16 vector into an f32 one, widening each component.
impl<'a> std::ops::AddAssign<VectorRef<'a>> for SVector {
    fn add_assign(&mut self, other: VectorRef<'a>) {
        self.0.iter_mut().zip(other.iter()).for_each(|(a, b)| *a += b.to_f32());
    }
}

impl std::ops::Div<f32> for SVector {
    type Output = Self;
    fn div(self, b: f32) -> Self::Output {
        SVector(self.0.iter().map(|a| a / b).collect())
    }
}

// Deref to slices so &Vector / &SVector coerce to VectorRef / SVectorRef.
impl std::ops::Deref for Vector {
    type Target = [f16];
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::ops::Deref for SVector {
    type Target = [f32];
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::ops::Add<&SVector> for SVector {
    type Output = Self;
    fn add(self, other: &Self) -> Self::Output {
        SVector(self.0.iter().zip(other.0.iter()).map(|(a, b)| a + b).collect())
    }
}

impl std::ops::Sub<&SVector> for SVector {
    type Output = Self;
    fn sub(self, other: &Self) -> Self::Output {
        SVector(self.0.iter().zip(other.0.iter()).map(|(a, b)| a - b).collect())
    }
}

impl std::ops::AddAssign for SVector {
    fn add_assign(&mut self, other: Self) {
        self.0.iter_mut().zip(other.0.iter()).for_each(|(a, b)| *a += b);
    }
}

// Scalar multiply.
impl std::ops::Mul<f32> for SVector {
    type Output = Self;
    fn mul(self, other: f32) -> Self {
        SVector(self.0.iter().map(|a| *a * other).collect())
    }
}
/// Dense row-major store of `length` f16 vectors, each of dimension `d_emb`,
/// held contiguously in `data`.
#[derive(Debug, Clone)]
pub struct VectorList {
    pub d_emb: usize,
    pub length: usize,
    pub data: Vec<f16>
}
impl std::ops::Index<usize> for VectorList {
    type Output = [f16];

    /// Row `index` as a slice of length `d_emb`.
    fn index(&self, index: usize) -> &Self::Output {
        let start = index * self.d_emb;
        &self.data[start..start + self.d_emb]
    }
}
/// Borrowing iterator over the rows of a `VectorList`.
pub struct VectorListIterator<'a> {
    list: &'a VectorList,
    index: usize
}

impl<'a> Iterator for VectorListIterator<'a> {
    type Item = VectorRef<'a>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.list.len() {
            return None;
        }
        let row = &self.list[self.index];
        self.index += 1;
        Some(row)
    }
}
impl VectorList {
    /// Number of vectors stored.
    pub fn len(&self) -> usize {
        self.length
    }

    /// Iterate over rows as f16 slices.
    pub fn iter(&self) -> VectorListIterator {
        VectorListIterator { list: self, index: 0 }
    }

    /// List of dimension `d` containing no vectors.
    pub fn empty(d: usize) -> Self {
        VectorList { d_emb: d, length: 0, data: Vec::new() }
    }

    /// Wrap a flat buffer; its length must be a multiple of `d`.
    pub fn from_f16s(f16s: Vec<f16>, d: usize) -> Self {
        assert!(f16s.len() % d == 0);
        let length = f16s.len() / d;
        VectorList { d_emb: d, length, data: f16s }
    }

    /// Append one vector (caller must pass a slice of length `d_emb`).
    pub fn push(&mut self, vec: VectorRef) {
        self.data.extend_from_slice(vec);
        self.length += 1;
    }
}
// SimSIMD has its own, but ours prefetches concurrently, is unrolled more, ignores inconveniently-sized vectors and does a cheaper reduction
// Also, we return an int because floats are annoying (not Ord)
// On Tiger Lake (i5-1135G7) we have about a 3x performance advantage ignoring the prefetching
// (it would be better to use AVX512 for said CPU but this also has to run on Zen 3)
/// AVX2+F16C f16 dot product of `x` and `y`, prefetching `prefetch` (the next
/// vector the caller intends to score) while streaming. Lengths must be equal
/// and a multiple of 64 (debug-asserted only). Result is in SCALE (2^48)
/// fixed point as i64.
pub fn fast_dot(x: VectorRef, y: VectorRef, prefetch: VectorRef) -> i64 {
    use std::arch::x86_64::*;

    debug_assert!(x.len() == y.len());
    debug_assert!(prefetch.len() == x.len());
    debug_assert!(x.len() % 64 == 0);

    // safety is not real
    // it's probably fine I guess
    unsafe {
        let mut x_ptr = x.as_ptr();
        let mut y_ptr = y.as_ptr();
        let end = x_ptr.add(x.len());
        let mut prefetch_ptr = prefetch.as_ptr();

        // Four independent accumulators to hide FMA latency.
        let mut acc1 = _mm256_setzero_ps();
        let mut acc2 = _mm256_setzero_ps();
        let mut acc3 = _mm256_setzero_ps();
        let mut acc4 = _mm256_setzero_ps();

        while x_ptr < end {
            // fetch chunks and prefetch next vector
            let x1 = _mm256_loadu_si256(x_ptr as *const __m256i);
            let y1 = _mm256_loadu_si256(y_ptr as *const __m256i);
            let x2 = _mm256_loadu_si256(x_ptr.add(16) as *const __m256i);
            let y2 = _mm256_loadu_si256(y_ptr.add(16) as *const __m256i);

            // technically, we only have to do this once per cache line but I don't care enough to test every way to optimize this
            _mm_prefetch(prefetch_ptr as *const i8, _MM_HINT_T0);

            x_ptr = x_ptr.add(32); // move 16 f16s at a time
            y_ptr = y_ptr.add(32);
            prefetch_ptr = prefetch_ptr.add(32);

            // widen f16 lanes to f32 (F16C)
            let x1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 0));
            let x1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 1));
            let y1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 0));
            let y1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 1));
            let x2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 0));
            let x2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 1));
            let y2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 0));
            let y2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 1));

            acc1 = _mm256_fmadd_ps(x1lo, y1lo, acc1);
            acc2 = _mm256_fmadd_ps(x1hi, y1hi, acc2);
            acc3 = _mm256_fmadd_ps(x2lo, y2lo, acc3);
            acc4 = _mm256_fmadd_ps(x2hi, y2hi, acc4);
        }

        // reduce
        // (lane interleaving from hadd doesn't matter: all 8 lanes get summed)
        let acc1 = _mm256_add_ps(acc1, acc2);
        let acc2 = _mm256_add_ps(acc3, acc4);
        let hsum = _mm256_hadd_ps(acc1, acc2);
        let hsum_lo = _mm256_extractf128_ps(hsum, 0);
        let hsum_hi = _mm256_extractf128_ps(hsum, 1);
        let hsum = _mm_add_ps(hsum_lo, hsum_hi);
        let floatval = f32::from_bits(_mm_extract_ps::<0>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<1>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<2>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<3>(hsum) as u32);
        (floatval * SCALE) as i64
    }
}
// same as above, without prefetch pointer
/// AVX2+F16C f16 dot product; identical to `fast_dot` minus the prefetch.
/// Returns SCALE (2^48) fixed point as i64.
pub fn fast_dot_noprefetch(x: VectorRef, y: VectorRef) -> i64 {
    use std::arch::x86_64::*;

    debug_assert!(x.len() == y.len());
    debug_assert!(x.len() % 64 == 0);

    unsafe {
        let mut x_ptr = x.as_ptr();
        let mut y_ptr = y.as_ptr();
        let end = x_ptr.add(x.len());

        // Four accumulators to hide FMA latency.
        let mut acc1 = _mm256_setzero_ps();
        let mut acc2 = _mm256_setzero_ps();
        let mut acc3 = _mm256_setzero_ps();
        let mut acc4 = _mm256_setzero_ps();

        while x_ptr < end {
            let x1 = _mm256_loadu_si256(x_ptr as *const __m256i);
            let y1 = _mm256_loadu_si256(y_ptr as *const __m256i);
            let x2 = _mm256_loadu_si256(x_ptr.add(16) as *const __m256i);
            let y2 = _mm256_loadu_si256(y_ptr.add(16) as *const __m256i);

            // Advance 32 f16s (two 256-bit chunks) per iteration.
            x_ptr = x_ptr.add(32);
            y_ptr = y_ptr.add(32);

            // Widen f16 lanes to f32 (F16C) and multiply-accumulate.
            let x1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 0));
            let x1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 1));
            let y1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 0));
            let y1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 1));
            let x2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 0));
            let x2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 1));
            let y2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 0));
            let y2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 1));

            acc1 = _mm256_fmadd_ps(x1lo, y1lo, acc1);
            acc2 = _mm256_fmadd_ps(x1hi, y1hi, acc2);
            acc3 = _mm256_fmadd_ps(x2lo, y2lo, acc3);
            acc4 = _mm256_fmadd_ps(x2hi, y2hi, acc4);
        }

        // reduce
        let acc1 = _mm256_add_ps(acc1, acc2);
        let acc2 = _mm256_add_ps(acc3, acc4);
        let hsum = _mm256_hadd_ps(acc1, acc2);
        let hsum_lo = _mm256_extractf128_ps(hsum, 0);
        let hsum_hi = _mm256_extractf128_ps(hsum, 1);
        let hsum = _mm_add_ps(hsum_lo, hsum_hi);
        let floatval = f32::from_bits(_mm_extract_ps::<0>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<1>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<2>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<3>(hsum) as u32);
        (floatval * SCALE) as i64
    }
}
/// OPQ-style product quantiser: an orthonormal rotation followed by per-chunk
/// nearest-centroid coding (≤256 centroids per chunk, i.e. one byte per chunk).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// Centroid matrix, row-major; row j is centroid j over the full
    /// dimension (chunk slices are taken out of it).
    centroids: Vec<f32>,
    transform: Vec<f32>, // D*D orthonormal matrix
    /// Dimensions covered by each one-byte code.
    pub n_dims_per_code: usize,
    /// Total (rotated) dimensionality D.
    pub n_dims: usize
}

// chunk * centroid_index
/// Precomputed query-to-centroid partial dot products, chunk-major.
pub struct DistanceLUT(Vec<f32>);
impl ProductQuantizer {
    /// Rotate a row-major B×D batch by the stored orthonormal transform;
    /// output is likewise B×D.
    pub fn apply_transform(&self, x: &[f32]) -> Vec<f32> {
        let dim = self.n_dims;
        let n_vectors = x.len() / dim;
        let mut transformed = vec![0.0; n_vectors * dim];
        // transform_matrix (D * D) @ batch.T (D * B)
        // SAFETY: sgemm only requires correctly sized/strided buffers, which
        // the allocations above provide.
        unsafe {
            matrixmultiply::sgemm(dim, dim, n_vectors, 1.0, self.transform.as_ptr(), dim as isize, 1, x.as_ptr(), 1, dim as isize, 0.0, transformed.as_mut_ptr(), 1, dim as isize);
        }
        transformed
    }

    /// Encode a B×D batch: one byte per chunk of `n_dims_per_code` dims,
    /// choosing the centroid with the *highest* partial dot product (this PQ
    /// targets the dot-product metric, not L2).
    pub fn quantize_batch(&self, x: &[f32]) -> Vec<u8> {
        // x is B * D
        let dim = self.n_dims;
        assert_eq!(dim * dim, self.transform.len());
        let n_vectors = x.len() / dim;
        let n_centroids = self.centroids.len() / dim;
        // Codes are single bytes, so at most 256 centroids.
        assert!(n_centroids <= 256);

        let transformed = self.apply_transform(&x); // B * D, as we write sgemm result in a weird order

        let mut codes = vec![0; n_vectors * dim / self.n_dims_per_code];
        let vec_len_codes = dim / self.n_dims_per_code;

        // B * C buffer of similarity of each vector to each centroid, within subspace
        let mut scratch = vec![0.0; n_vectors * n_centroids];
        for i in 0..(dim / self.n_dims_per_code) {
            let offset = i * self.n_dims_per_code;
            // transformed_batch[:, range] (B * D_r) @ centroids[:, range].T (D_r * C)
            unsafe {
                matrixmultiply::sgemm(n_vectors, self.n_dims_per_code, n_centroids, 1.0, transformed.as_ptr().add(offset), dim as isize, 1, self.centroids.as_ptr().add(offset), 1, dim as isize, 0.0, scratch.as_mut_ptr(), n_centroids as isize, 1);
            }
            // assign this component to best centroid
            for i_vec in 0..n_vectors {
                let mut best = f32::NEG_INFINITY;
                for i_centroid in 0..n_centroids {
                    let score = scratch[i_vec * n_centroids + i_centroid];
                    if score > best {
                        best = score;
                        codes[i_vec * vec_len_codes + i] = i_centroid as u8;
                    }
                }
            }
        }

        codes
    }

    // not particularly performance-sensitive right now; do unbatched
    /// Build the chunk-major LUT of query-to-centroid partial dot products
    /// consumed by `asymmetric_dot_product`.
    pub fn preprocess_query(&self, query: &[f32]) -> DistanceLUT {
        let transformed = self.apply_transform(query);
        let n_chunks = self.n_dims / self.n_dims_per_code;
        let n_centroids = self.centroids.len() / self.n_dims;

        let mut lut = Vec::with_capacity(n_chunks * n_centroids);

        for i in 0..n_chunks {
            let vec_component = &transformed[i * self.n_dims_per_code..(i + 1) * self.n_dims_per_code];
            for j in 0..n_centroids {
                let centroid = &self.centroids[j * self.n_dims..(j + 1) * self.n_dims];
                let centroid_component = &centroid[i * self.n_dims_per_code..(i + 1) * self.n_dims_per_code];
                let score = SpatialSimilarity::dot(vec_component, centroid_component).unwrap();
                lut.push(score as f32);
            }
        }

        DistanceLUT(lut)
    }

    // compute dot products of query against product-quantized vectors
    /// Approximate dot products of one preprocessed query against a batch of
    /// PQ codes: for each vector, sum the LUT entry selected by each code
    /// byte. Result is in SCALE fixed point (i64).
    pub fn asymmetric_dot_product(&self, query: &DistanceLUT, pq_vectors: &[u8]) -> Vec<i64> {
        let n_chunks = self.n_dims / self.n_dims_per_code;
        let n_vectors = pq_vectors.len() / n_chunks;
        let mut scores = vec![0.0; n_vectors];
        let n_centroids = self.centroids.len() / self.n_dims;

        for i in 0..n_chunks {
            for j in 0..n_vectors {
                let code = pq_vectors[j * n_chunks + i];
                let chunk_score = query.0[i * n_centroids + code as usize];
                scores[j] += chunk_score;
            }
        }

        // I have no idea why but we somehow have significant degradation in search quality
        // if this accumulates in integers. As such, do floats and convert at the end.
        // I'm sure there are fascinating reasons for this, but God is dead, God remains dead, etc.
        scores.into_iter().map(|x| (x * SCALE) as i64).collect()
    }
}
/// Convert an f64 dot product into the crate's SCALE (2^48) fixed-point i64.
pub fn scale_dot_result(x: f64) -> i64 {
    let scaled = x * SCALE_F64;
    scaled as i64
}
#[cfg(test)]
mod bench {
    //! Microbenchmarks (nightly `cargo bench`): simsimd's `dot` vs our AVX2
    //! kernels, on 1024-dim random f16 vectors.
    use super::*;
    use test::Bencher;

    #[bench]
    fn bench_dot(be: &mut Bencher) {
        let mut rng = fastrand::Rng::with_seed(1);
        let a = Vector::randn(&mut rng, 1024);
        let b = Vector::randn(&mut rng, 1024);
        be.iter(|| {
            dot(&a, &b)
        });
    }

    #[bench]
    fn bench_fastdot(be: &mut Bencher) {
        let mut rng = fastrand::Rng::with_seed(1);
        let a = Vector::randn(&mut rng, 1024);
        let b = Vector::randn(&mut rng, 1024);
        be.iter(|| {
            // `a` doubles as the prefetch target; irrelevant for timing the math.
            fast_dot(&a, &b, &a)
        });
    }

    #[bench]
    fn bench_fastdot_noprefetch(be: &mut Bencher) {
        let mut rng = fastrand::Rng::with_seed(1);
        let a = Vector::randn(&mut rng, 1024);
        let b = Vector::randn(&mut rng, 1024);
        be.iter(|| {
            fast_dot_noprefetch(&a, &b)
        });
    }
}

44
diskann/vec_dist.py Normal file
View File

@ -0,0 +1,44 @@
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
A = 0.6
LOG_A = np.log(A)

def scale(xs):
    """Signed log-compression: sign(xs) * (log(|xs| + A) - log(A)).

    Maps zero to zero; the offset A keeps the logarithm finite near zero.
    """
    compressed = np.log(np.abs(xs) + A) - LOG_A
    return np.sign(xs) * compressed
n_dims = 1152
# Only the first n_used_dims dimensions are visualised.
n_used_dims = 32
# Raw f16 embeddings, one row per vector, loaded wholesale into memory.
data = np.frombuffer(open("embeddings.bin", "rb").read(), dtype=np.float16).reshape(-1, n_dims).astype(np.float32) # TODO
# Create histogram bins
n_bins = 256
# NOTE(review): `s` appears unused below — confirm it can be removed.
s = __import__("math").sqrt(n_dims)
hist_range = (-1.2, 1.2)
histogram_data = np.zeros((n_used_dims, n_bins))
# Calculate histograms for each dimension
for dim in range(n_used_dims):
    dbd = data[:, dim]
    # Standardise, then log-compress so the distribution tails stay visible.
    dbd = (dbd - np.mean(dbd)) / np.std(dbd)
    dbd = scale(dbd)
    hist, _ = np.histogram(dbd, bins=n_bins, range=hist_range, density=True)
    histogram_data[dim] = hist
# Render the per-dimension value-distribution heatmap.
plt.figure(figsize=(12, 8))
sns.heatmap(histogram_data,
            cmap='viridis',
            xticklabels=np.linspace(hist_range[0], hist_range[1], n_bins),
            yticklabels=range(n_used_dims),
            cbar_kws={'label': 'Density'})
plt.xlabel('Value')
plt.ylabel('Dimension')
# BUGFIX: the title previously hard-coded "First 16 Dimensions" although
# n_used_dims is 32; derive the count from the variable instead.
plt.title(f'Distribution Heatmap of First {n_used_dims} Dimensions')
# Adjust layout to prevent label cutoff
plt.tight_layout()
plt.show()