Mirror of https://github.com/osmarks/meme-search-engine.git, synced 2025-04-07 19:26:39 +00:00

release early draft of index code

This commit is contained in:
parent 512b776e10
commit e0cf65204b
7 .gitignore vendored

@@ -9,4 +9,9 @@ node_modules/*
node_modules
*sqlite3*
thumbtemp
mse-test-db-small
clipfront2/static/bg*
diskann/target
*.bin
*.msgpack
*/flamegraph.svg
748 diskann/Cargo.lock generated Normal file

@@ -0,0 +1,748 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3

[[package]]
name = "anyhow"
version = "1.0.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775"

[[package]]
name = "autocfg"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"

[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"

[[package]]
name = "bitflags"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"

[[package]]
name = "bitvec"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [
 "funty",
 "radium",
 "tap",
 "wyz",
]

[[package]]
name = "bytemuck"
version = "1.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a"
dependencies = [
 "bytemuck_derive",
]

[[package]]
name = "bytemuck_derive"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26"
dependencies = [
 "proc-macro2",
 "quote",
 "syn",
]

[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"

[[package]]
name = "cc"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47"
dependencies = [
 "shlex",
]

[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"

[[package]]
name = "crossbeam-deque"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
dependencies = [
 "crossbeam-epoch",
 "crossbeam-utils",
]

[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
 "crossbeam-utils",
]

[[package]]
name = "crossbeam-utils"
version = "0.8.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"

[[package]]
name = "crossterm"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e64e6c0fbe2c17357405f7c758c1ef960fce08bdfb2c03d88d2a18d7e09c4b67"
dependencies = [
 "bitflags 1.3.2",
 "crossterm_winapi",
 "libc",
 "mio",
 "parking_lot",
 "signal-hook",
 "signal-hook-mio",
 "winapi",
]

[[package]]
name = "crossterm_winapi"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b"
dependencies = [
 "winapi",
]

[[package]]
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"

[[package]]
name = "diskann"
version = "0.1.0"
dependencies = [
 "anyhow",
 "bitvec",
 "bytemuck",
 "fastrand",
 "foldhash",
 "half",
 "matrixmultiply",
 "rayon",
 "rmp-serde",
 "serde",
 "simsimd",
 "tqdm",
 "tracing",
 "tracing-subscriber",
]

[[package]]
name = "either"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"

[[package]]
name = "fastrand"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4"

[[package]]
name = "foldhash"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"

[[package]]
name = "funty"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"

[[package]]
name = "half"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [
 "bytemuck",
 "cfg-if",
 "crunchy",
]

[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"

[[package]]
name = "libc"
version = "0.2.164"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"

[[package]]
name = "lock_api"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
dependencies = [
 "autocfg",
 "scopeguard",
]

[[package]]
name = "log"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"

[[package]]
name = "matrixmultiply"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
dependencies = [
 "autocfg",
 "rawpointer",
]

[[package]]
name = "mio"
version = "0.8.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
dependencies = [
 "libc",
 "log",
 "wasi",
 "windows-sys",
]

[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
 "overload",
 "winapi",
]

[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
 "autocfg",
]

[[package]]
name = "once_cell"
version = "1.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"

[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"

[[package]]
name = "parking_lot"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
dependencies = [
 "lock_api",
 "parking_lot_core",
]

[[package]]
name = "parking_lot_core"
version = "0.9.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
dependencies = [
 "cfg-if",
 "libc",
 "redox_syscall",
 "smallvec",
 "windows-targets 0.52.6",
]

[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"

[[package]]
name = "pin-project-lite"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff"

[[package]]
name = "proc-macro2"
version = "1.0.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e"
dependencies = [
 "unicode-ident",
]

[[package]]
name = "quote"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
dependencies = [
 "proc-macro2",
]

[[package]]
name = "radium"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"

[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"

[[package]]
name = "rayon"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
dependencies = [
 "either",
 "rayon-core",
]

[[package]]
name = "rayon-core"
version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
dependencies = [
 "crossbeam-deque",
 "crossbeam-utils",
]

[[package]]
name = "redox_syscall"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
dependencies = [
 "bitflags 2.6.0",
]

[[package]]
name = "rmp"
version = "0.8.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4"
dependencies = [
 "byteorder",
 "num-traits",
 "paste",
]

[[package]]
name = "rmp-serde"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db"
dependencies = [
 "byteorder",
 "rmp",
 "serde",
]

[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"

[[package]]
name = "serde"
version = "1.0.215"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f"
dependencies = [
 "serde_derive",
]

[[package]]
name = "serde_derive"
version = "1.0.215"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0"
dependencies = [
 "proc-macro2",
 "quote",
 "syn",
]

[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
 "lazy_static",
]

[[package]]
name = "shlex"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"

[[package]]
name = "signal-hook"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801"
dependencies = [
 "libc",
 "signal-hook-registry",
]

[[package]]
name = "signal-hook-mio"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd"
dependencies = [
 "libc",
 "mio",
 "signal-hook",
]

[[package]]
name = "signal-hook-registry"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1"
dependencies = [
 "libc",
]

[[package]]
name = "simsimd"
version = "6.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be2ad0164e13e58a994d3dd1ff57d44cee87c445708e3acea7ad4f03a47092ce"
dependencies = [
 "cc",
]

[[package]]
name = "smallvec"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"

[[package]]
name = "syn"
version = "2.0.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
dependencies = [
 "proc-macro2",
 "quote",
 "unicode-ident",
]

[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"

[[package]]
name = "thread_local"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
dependencies = [
 "cfg-if",
 "once_cell",
]

[[package]]
name = "tqdm"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa2d2932240205a99b65f15d9861992c95fbb8c9fb280b3a1f17a92db6dc611f"
dependencies = [
 "anyhow",
 "crossterm",
 "once_cell",
]

[[package]]
name = "tracing"
version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
 "pin-project-lite",
 "tracing-attributes",
 "tracing-core",
]

[[package]]
name = "tracing-attributes"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
 "proc-macro2",
 "quote",
 "syn",
]

[[package]]
name = "tracing-core"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
 "once_cell",
 "valuable",
]

[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
 "log",
 "once_cell",
 "tracing-core",
]

[[package]]
name = "tracing-subscriber"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [
 "nu-ansi-term",
 "sharded-slab",
 "smallvec",
 "thread_local",
 "tracing-core",
 "tracing-log",
]

[[package]]
name = "unicode-ident"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"

[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"

[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"

[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
 "winapi-i686-pc-windows-gnu",
 "winapi-x86_64-pc-windows-gnu",
]

[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"

[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"

[[package]]
name = "windows-sys"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
dependencies = [
 "windows-targets 0.48.5",
]

[[package]]
name = "windows-targets"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
dependencies = [
 "windows_aarch64_gnullvm 0.48.5",
 "windows_aarch64_msvc 0.48.5",
 "windows_i686_gnu 0.48.5",
 "windows_i686_msvc 0.48.5",
 "windows_x86_64_gnu 0.48.5",
 "windows_x86_64_gnullvm 0.48.5",
 "windows_x86_64_msvc 0.48.5",
]

[[package]]
name = "windows-targets"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
 "windows_aarch64_gnullvm 0.52.6",
 "windows_aarch64_msvc 0.52.6",
 "windows_i686_gnu 0.52.6",
 "windows_i686_gnullvm",
 "windows_i686_msvc 0.52.6",
 "windows_x86_64_gnu 0.52.6",
 "windows_x86_64_gnullvm 0.52.6",
 "windows_x86_64_msvc 0.52.6",
]

[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"

[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"

[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"

[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"

[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"

[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"

[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"

[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"

[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"

[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"

[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"

[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"

[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"

[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"

[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"

[[package]]
name = "wyz"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
dependencies = [
 "tap",
]
28 diskann/Cargo.toml Normal file

@@ -0,0 +1,28 @@
[package]
name = "diskann"
version = "0.1.0"
edition = "2021"

[dependencies]
half = { version = "2", features = ["bytemuck"] }
fastrand = "2"
tracing = "0.1"
tracing-subscriber = "0.3"
simsimd = "6"
foldhash = "0.1"
bitvec = "1"
tqdm = "0.7"
anyhow = "1"
bytemuck = { version = "1", features = ["extern_crate_alloc"] }
serde = { version = "1", features = ["derive"] }
rmp-serde = "1"
rayon = "1"
matrixmultiply = "0.3"

[lib]
name = "diskann"
path = "src/lib.rs"

[[bin]]
name = "diskann"
path = "src/main.rs"
113 diskann/aopq_train.py Normal file

@@ -0,0 +1,113 @@
import numpy as np
import msgpack
import math
import torch
from torch import autograd
import faiss
import tqdm

n_dims = 1152
output_code_size = 64
output_code_bits = 8
output_codebook_size = 2**output_code_bits
n_dims_per_code = n_dims // output_code_size
dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
device = "cpu"

index = faiss.index_factory(n_dims, "HNSW32,SQfp16", faiss.METRIC_INNER_PRODUCT)
index.train(queryset)
index.add(queryset)
print("index ready")

T = 64

nearby_query_indices = torch.zeros((dataset.shape[0], T), dtype=torch.int32)

SEARCH_BATCH_SIZE = 1024

for i in range(0, len(dataset), SEARCH_BATCH_SIZE):
    res = index.search(dataset[i:i+SEARCH_BATCH_SIZE], T)
    nearby_query_indices[i:i+SEARCH_BATCH_SIZE] = torch.tensor(res[1])

print("query indices ready")

def pq_assign(centroids, batch):
    quantized = torch.zeros_like(batch)

    # Assign to the nearest centroid in each subspace
    for dmin in range(0, n_dims, n_dims_per_code):
        dmax = dmin + n_dims_per_code
        similarities = torch.matmul(batch[:, dmin:dmax], centroids[:, dmin:dmax].T)
        assignments = similarities.argmax(dim=1)
        quantized[:, dmin:dmax] = centroids[assignments, dmin:dmax]

    return quantized

# OOD-DiskANN (https://arxiv.org/abs/2211.12850) uses a more complicated scheme because it uses the L2 norm.
# We only care about inner product, so our quantization error (wrt a query) is just abs(dot(query, centroid - vector)).
# Directly optimize for this (wrt top queries; it might actually be better to use a random sample instead?).
def partition(vectors, centroids, projection, opt, queries, nearby_query_indices, k, max_iter=100, batch_size=4096):
    n_vectors = len(vectors)
    perm = torch.randperm(n_vectors, device=device)

    t = tqdm.trange(max_iter)
    for it in t:
        total_loss = 0
        opt.zero_grad(set_to_none=True)

        for i in range(0, n_vectors, batch_size):
            loss = torch.tensor(0.0, device=device)
            batch = vectors[i:i+batch_size] @ projection
            quantized = pq_assign(centroids, batch)
            residuals = batch - quantized

            # for each index in our set of nearby queries
            for j in range(0, nearby_query_indices.shape[1]):
                queries_for_batch_j = queries[nearby_query_indices[i:i+batch_size, j]]
                # minimize quantization error in the direction of the query, i.e. mean abs(dot(query, centroid - vector))
                # PyTorch won't do batched dot products cleanly, to spite me. Do componentwise multiplication and reduce.
                sg_errs = (queries_for_batch_j * residuals).sum(dim=-1)
                loss += torch.mean(torch.abs(sg_errs))

            total_loss += loss.detach().item()
            loss.backward()

        opt.step()

        t.set_description(f"loss: {total_loss:.4f}")

def random_ortho(dim):
    h = torch.randn(dim, dim, device=device)
    q, r = torch.linalg.qr(h)
    return q

# non-parametric OPQ algorithm (roughly)
# https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/opq_tr.pdf
projection = random_ortho(n_dims)
vectors = torch.tensor(dataset, device=device)
queries = torch.tensor(queryset, device=device)
perm = torch.randperm(len(vectors), device=device)
centroids = vectors[perm[:output_codebook_size]]
centroids.requires_grad = True
opt = torch.optim.Adam([centroids], lr=0.001)
for i in range(30):
    # update centroids to minimize the query-aware quantization loss
    partition(vectors, centroids, projection, opt, queries, nearby_query_indices, output_codebook_size, max_iter=8)
    # compute the new projection as R = VU^T from XY^T = USV^T (SVD)
    # where X is the dataset vectors and Y is the quantized dataset vectors
    with torch.no_grad():
        y = pq_assign(centroids, vectors)
        # the paper uses D*N rather than N*D in its derivations (so we transpose where they don't)
        u, s, vt = torch.linalg.svd(vectors.T @ y)
        projection = vt.T @ u.T

print("done")

with open("opq.msgpack", "wb") as f:
    msgpack.pack({
        "centroids": centroids.detach().cpu().numpy().flatten().tolist(),
        "transform": projection.cpu().numpy().flatten().tolist(),
        "n_dims_per_code": n_dims_per_code,
        "n_dims": n_dims
    }, f)
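The rotation update in the loop above is the classical orthogonal Procrustes step. A standalone numpy check of that step (mine, not part of the commit; toy dimensions, synthetic data): given data X and quantized targets Y, the orthogonal R minimising ||XR - Y||_F is R = UV^T where X^T Y = USV^T. The script's `vt.T @ u.T` is the transpose of this, per its note about the paper's D*N convention.

import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 16))
R_true, _ = np.linalg.qr(rng.standard_normal((16, 16)))   # hidden rotation
Y = X @ R_true + 0.01 * rng.standard_normal(X.shape)      # stand-in for pq_assign output

u, s, vt = np.linalg.svd(X.T @ Y)
R = u @ vt                                                # argmin_R ||X R - Y||_F over orthogonal R
print(np.abs(R - R_true).max())                           # small (~1e-3): recovers the hidden rotation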
491 diskann/flamegraph.svg Normal file
File diff suppressed because one or more lines are too long (SVG image, 139 KiB)
44 diskann/opq_test.py Normal file

@@ -0,0 +1,44 @@
import numpy as np
import msgpack
import math
import torch
import faiss
import tqdm

n_dims = 1152
output_code_size = 64
output_code_bits = 8
output_codebook_size = 2**output_code_bits
n_dims_per_code = n_dims // output_code_size
dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
device = "cpu"

def pq_assign(centroids, batch):
    quantized = torch.zeros_like(batch)

    # Assign to the nearest centroid in each subspace
    for dmin in range(0, n_dims, n_dims_per_code):
        dmax = dmin + n_dims_per_code
        similarities = torch.matmul(batch[:, dmin:dmax], centroids[:, dmin:dmax].T)
        assignments = similarities.argmax(dim=1)
        quantized[:, dmin:dmax] = centroids[assignments, dmin:dmax]

    return quantized

with open("opq.msgpack", "rb") as f:
    data = msgpack.unpack(f)

centroids = torch.tensor(data["centroids"], device=device).reshape(2**output_code_bits, n_dims)
projection = torch.tensor(data["transform"], device=device).reshape(n_dims, n_dims)

vectors = torch.tensor(dataset, device=device)
queries = torch.tensor(queryset, device=device)

sample_size = 64
qsample = pq_assign(centroids, vectors[:sample_size] @ projection)
print(qsample)
print(vectors[:sample_size])
exact_results = vectors[:sample_size] @ queries[0]
approx_results = qsample @ (projection @ queries[0])
print(np.argsort(approx_results))
print(np.argsort(exact_results))
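The test above scores against full reconstructed vectors. At query time, an asymmetric dot product over PQ codes is usually computed with one small lookup table per subspace instead; a minimal numpy sketch of that pattern (mine; illustrative names, not necessarily how the Rust codec implements it):

import numpy as np

def build_luts(centroids, query, n_dims_per_code):
    # one (256,) table of partial dot products per subspace
    return [centroids[:, d:d + n_dims_per_code] @ query[d:d + n_dims_per_code]
            for d in range(0, centroids.shape[1], n_dims_per_code)]

def adc_dot(code, luts):
    # code: one uint8 centroid index per subspace; sum one table entry each
    return sum(lut[c] for lut, c in zip(luts, code))

This gives exactly dot(query, reconstruction), but each scored vector costs one table lookup per subspace rather than a full D-dimensional dot product.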
68 diskann/rabitq.py Normal file

@@ -0,0 +1,68 @@
# https://arxiv.org/pdf/2405.12497

import numpy as np
import msgpack
import math
import tqdm

n_dims = 1152
output_dims = 64*8
scale = 1 / math.sqrt(n_dims)
dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
mean = np.mean(dataset, axis=0)

centered_dataset = dataset - mean
norms = np.linalg.norm(centered_dataset, axis=1)
centered_dataset = centered_dataset / norms[:, np.newaxis]
print(centered_dataset)

sample = centered_dataset[:64]

def random_ortho(dim):
    h = np.random.randn(dim, dim)
    q, r = np.linalg.qr(h)
    return q

p = random_ortho(n_dims) # the algorithm only uses the inverse of P, so sample that directly
p = p[:output_dims, :]

def quantize(datavecs):
    xs = (p @ datavecs.T).T
    quantized = xs > 0
    dequantized = scale * (2 * quantized - 1)
    dots = np.sum(dequantized * xs, axis=1) # <o_bar, o>
    return quantized, dots

qsample, dots = quantize(sample)
print(qsample.sum(axis=1).mean())
#print(dots)
#print(dots.mean())

def approx_dot(quantized_samples, dots, query):
    mean_to_query = np.dot(mean, query)
    print(mean_to_query)
    dequantized = scale * (2 * quantized_samples - 1)
    query_transformed = p @ query
    o_bar_dot_q = np.sum(dequantized * query_transformed, axis=1)
    return norms[:sample.shape[0]] * o_bar_dot_q * dots + mean_to_query

print(norms)
approx_results = approx_dot(qsample, dots, queryset[0])
exact_results = sample @ queryset[0]

for x in zip(approx_results, exact_results):
    print(*x)

print(*[ f"{x:.2f}" for x in (approx_results - exact_results) / np.abs(exact_results).mean() ])

print(np.argsort(approx_results))
print(np.argsort(exact_results))

with open("rabitq.msgpack", "wb") as f:
    msgpack.pack({
        "mean": mean.flatten().tolist(),
        "transform": p.flatten().tolist(),
        "output_dims": output_dims,
        "n_dims": n_dims
    }, f)
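A standalone numpy check (mine) of the geometric fact the paper's estimator rests on: for a unit vector o and its sign code o_bar = sign(o)/sqrt(D), the self-dot <o_bar, o> concentrates near sqrt(2/pi) ≈ 0.80, and <o_bar, q>/<o_bar, o> estimates <o, q> with error shrinking as D grows:

import math
import numpy as np

rng = np.random.default_rng(0)
D = 1152
o = rng.standard_normal(D); o /= np.linalg.norm(o)
noise = rng.standard_normal(D); noise /= np.linalg.norm(noise)
q = o + 0.5 * noise; q /= np.linalg.norm(q)    # query correlated with o

o_bar = np.sign(o) / math.sqrt(D)              # 1 bit per dimension, rescaled
print(o_bar @ o, math.sqrt(2 / math.pi))       # both ~0.80
print((o_bar @ q) / (o_bar @ o), o @ q)        # estimate vs exact, close for large D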
167 diskann/scalar_quantize.py Normal file

@@ -0,0 +1,167 @@
import numpy as np
import msgpack
import math

n_dims = 1152
n_buckets = n_dims
#n_buckets = n_dims // 2 # we now have one quant scale per pair of components
#pair_separation = 16 # for efficient dot product computation, we need to have the second element of a pair exactly chunk_size after the first
n_dims_per_bucket = n_dims // n_buckets
data = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32) # sorry

CUTOFF = 1e-3 / 2

print("computing quantiles")
smin = np.quantile(data, CUTOFF, axis=0)
smax = np.quantile(data, 1 - CUTOFF, axis=0)

# naive O(n²) greedy algorithm
# probably overbuilt for the 2-components-per-bucket case, but I'm not getting rid of it
def assign_buckets():
    import random
    intervals = list(enumerate(zip(smin, smax)))
    random.shuffle(intervals)
    buckets = [ [ intervals.pop() ] for _ in range(n_buckets) ]
    def bucket_cost(bucket):
        bmin = min(cmin for id, (cmin, cmax) in bucket)
        bmax = max(cmax for id, (cmin, cmax) in bucket)
        #print("MIN", bmin, "MAX", bmax)
        return sum(abs(cmin - bmin) + abs(cmax - bmax) for id, (cmin, cmax) in bucket)
    while len(intervals):
        for bucket in buckets:
            def new_interval_cost(interval):
                return bucket_cost(bucket + [interval[1]])
            i, interval = min(enumerate(intervals), key=new_interval_cost)
            bucket.append(intervals.pop(i))
    return buckets

ranges = smax - smin
# TODO: it is possible to do better assignment to buckets
#order = np.argsort(ranges)
print("bucket assignment")
order = np.arange(n_dims) # np.concatenate(np.stack([ [ id for id, (cmin, cmax) in bucket ] for bucket in assign_buckets() ]))
interleave = np.arange(n_dims) # identity by default (added so rquantize/rdquantize below run; see the commented-out alternatives further down)

bucket_ranges = []
bucket_centres = []
bucket_absmax = []
bucket_gmins = []

for bucket_min in range(0, n_dims, n_dims_per_bucket):
    bucket_max = bucket_min + n_dims_per_bucket
    indices = order[bucket_min:bucket_max]
    gmin = float(np.min(smin[indices]))
    gmax = float(np.max(smax[indices]))
    bucket_range = gmax - gmin
    bucket_centre = (gmax + gmin) / 2
    bucket_gmins.append(gmin)
    bucket_ranges.append(bucket_range)
    bucket_centres.append(bucket_centre)
    bucket_absmax.append(max(abs(gmin), abs(gmax)))

print("determining scales")
scales = [] # multiply by float and convert to quantize
offsets = []
q_offsets = [] # int16 value to add at dot time
q_scales = [] # rescales channel up at dot time; must be proportional(ish) to the square of the scale factor but NOT cause overflow in accumulation or PMULLW
scale_factor_bound = float("inf")
for bucket in range(n_buckets):
    step_size = bucket_ranges[bucket] / 255
    scales.append(1 / step_size)
    q_offset = int(bucket_gmins[bucket] / step_size)
    q_offsets.append(q_offset)
    nsfb = (2**31 - 1) / (n_dims_per_bucket * abs((255**2) + 2 * q_offset * 255 + q_offset ** 2)) / 2
    # we are bounded both by overflow in accumulation and PMULLW (u8 plus offset, times scale factor)
    scale_factor_bound = min(scale_factor_bound, nsfb, (2**15 - 1) // (q_offset + 255))
    offsets.append(bucket_gmins[bucket])

for bucket in range(n_buckets):
    sfb = scale_factor_bound / max(map(lambda x: x ** 2, bucket_ranges))
    sf = (bucket_ranges[bucket]) ** 2 * sfb
    q_scales.append(int(sf))

print(bucket_ranges, bucket_centres, bucket_absmax)
print(scales, offsets, q_offsets, q_scales)

"""
interleave = np.concatenate([
    np.arange(0, n_dims, n_dims_per_bucket) + a
    for a in range(n_dims_per_bucket)
])
"""

"""
interleave = np.arange(0, n_dims)
for base in range(0, n_dims, 2 * pair_separation):
    interleave[base:base + pair_separation] = np.arange(base, base + 2 * pair_separation, 2)
    interleave[base + pair_separation:base + 2 * pair_separation] = np.arange(base + 1, base + 2 * pair_separation + 1, 2)
"""

#print(bucket_ranges, bucket_centres, order[interleave])
#print(ranges[order][interleave].tolist())
#print(ranges.tolist())

with open("quantizer.msgpack", "wb") as f:
    msgpack.pack({
        "permutation": order.tolist(),
        "offsets": offsets,
        "scales": scales,
        "q_offsets": q_offsets,
        "q_scales": q_scales
    }, f)

def rquantize(vec):
    out = np.zeros(len(vec), dtype=np.uint8)
    for i, p in enumerate(order[interleave]):
        bucket = p % n_buckets
        raw = vec[i]
        raw = (raw - offsets[bucket]) * scales[bucket]
        raw = min(max(raw, 0.0), 255.0)
        out[p] = round(raw)
    return out

def rdquantize(bytes):
    vec = np.zeros(n_dims, dtype=np.float32)
    for i, p in enumerate(order[interleave]):
        bucket = p % n_buckets
        raw = float(bytes[p])
        vec[i] = raw / scales[bucket] + offsets[bucket]
    return vec

def rdot(x, y):
    xq_offsets = np.array(q_offsets, dtype=np.int16)
    xq_scales = np.array(q_scales, dtype=np.int16)
    assert x.shape == y.shape
    assert x.dtype == np.uint8 == y.dtype
    acc = 0
    for i in range(0, len(x), n_buckets):
        x1 = x[i:i+n_buckets].astype(np.int16) + xq_offsets
        y1 = y[i:i+n_buckets].astype(np.int16) + xq_offsets
        x1 *= xq_scales
        acc += np.dot(x1.astype(np.int32), y1.astype(np.int32))
    return acc

def cmp(i, j):
    return np.dot(data[i], data[j]) / rdot(rquantize(data[i]), rquantize(data[j]))

def rdot_cmp(a, b):
    x = rquantize(a)
    y = rquantize(b)
    a = a[order[interleave]]
    b = b[order[interleave]]
    xq_offsets = np.array(q_offsets, dtype=np.int16)
    xq_scales = np.array(q_scales, dtype=np.int16)
    assert x.shape == y.shape
    assert x.dtype == np.uint8 == y.dtype
    acc = 0
    for i in range(0, len(x), n_buckets):
        x1 = x[i:i+n_buckets].astype(np.int16) + xq_offsets
        y1 = y[i:i+n_buckets].astype(np.int16) + xq_offsets
        x1 *= xq_scales
        component = np.dot(x1.astype(np.int32), y1.astype(np.int32))
        a1 = a[i:i+n_buckets]
        b1 = b[i:i+n_buckets]
        component_exact = np.dot(a1, b1)
        print(x1, a1, sep="\n")
        print(component, component_exact, component / component_exact)
        acc += component
    return acc
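A single-bucket sanity check (mine) of the identity the integer kernel relies on: with per-bucket step s and integer offset qo = gmin/s, a dequantized component is s*(xq + qo), so dot(x, y) ≈ sum over buckets of s_b² * sum((xq + qo_b) * (yq + qo_b)). The q_scales above bake the s_b² factor into int16 so the whole thing can run in integer SIMD (PMULLW plus accumulation), up to one global float rescale.

import numpy as np

rng = np.random.default_rng(0)
s, gmin = 0.01, -1.2                        # one bucket's step size and minimum
qo = int(gmin / s)                          # integer offset added at dot time
x, y = rng.uniform(gmin, gmin + 255 * s, (2, 256))
xq = np.round((x - gmin) / s).astype(np.int32)
yq = np.round((y - gmin) / s).astype(np.int32)
approx = s * s * np.sum((xq + qo) * (yq + qo))
print(approx, np.dot(x, y))                 # agree up to quantization error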
389
diskann/src/lib.rs
Normal file
389
diskann/src/lib.rs
Normal file
@ -0,0 +1,389 @@
|
||||
#![feature(pointer_is_aligned_to)]
|
||||
#![feature(test)]
|
||||
|
||||
extern crate test;
|
||||
|
||||
use foldhash::{HashSet, HashMap, HashMapExt, HashSetExt};
|
||||
use fastrand::Rng;
|
||||
use rayon::prelude::*;
|
||||
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, Mutex};
|
||||
|
||||
pub mod vector;
|
||||
use vector::{dot, fast_dot, fast_dot_noprefetch, to_svector, VectorRef, SVector, VectorList};
|
||||
|
||||
// ParlayANN improves parallelism by not using locks like this and instead using smarter batch operations
|
||||
// but I don't have enough cores that it matters
|
||||
#[derive(Debug)]
|
||||
pub struct IndexGraph {
|
||||
pub graph: Vec<RwLock<Vec<u32>>>
|
||||
}
|
||||
|
||||
impl IndexGraph {
|
||||
pub fn random_r_regular(rng: &mut Rng, n: usize, r: usize, capacity: usize) -> Self {
|
||||
let mut graph = Vec::with_capacity(n);
|
||||
for _ in 0..n {
|
||||
let mut adjacency = Vec::with_capacity(capacity);
|
||||
for _ in 0..r {
|
||||
adjacency.push(rng.u32(0..(n as u32)));
|
||||
}
|
||||
graph.push(RwLock::new(adjacency));
|
||||
}
|
||||
IndexGraph {
|
||||
graph
|
||||
}
|
||||
}
|
||||
|
||||
pub fn empty(n: usize, capacity: usize) -> IndexGraph {
|
||||
let mut graph = Vec::with_capacity(n);
|
||||
for _ in 0..n {
|
||||
graph.push(RwLock::new(Vec::with_capacity(capacity)));
|
||||
}
|
||||
IndexGraph {
|
||||
graph
|
||||
}
|
||||
}
|
||||
|
||||
fn out_neighbours(&self, pt: u32) -> RwLockReadGuard<Vec<u32>> {
|
||||
self.graph[pt as usize].read().unwrap()
|
||||
}
|
||||
|
||||
fn out_neighbours_mut(&self, pt: u32) -> RwLockWriteGuard<Vec<u32>> {
|
||||
self.graph[pt as usize].write().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct IndexBuildConfig {
|
||||
pub r: usize,
|
||||
pub r_cap: usize,
|
||||
pub l: usize,
|
||||
pub maxc: usize,
|
||||
pub alpha: i64
|
||||
}
|
||||
|
||||
|
||||
fn centroid(vecs: &VectorList) -> SVector {
|
||||
let mut centroid = SVector::zero(vecs.d_emb);
|
||||
|
||||
for (i, vec) in vecs.iter().enumerate() {
|
||||
let weight = 1.0 / (i + 1) as f32;
|
||||
centroid += (to_svector(vec) - ¢roid) * weight;
|
||||
}
|
||||
|
||||
centroid
|
||||
}
|
||||
|
||||
pub fn medioid(vecs: &VectorList) -> u32 {
|
||||
let centroid = centroid(vecs).half();
|
||||
vecs.iter().map(|vec| dot(vec, &*centroid)).enumerate().max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()).unwrap().0 as u32
|
||||
}
|
||||
|
||||
// neighbours list sorted by score descending
|
||||
// TODO: this may actually be an awful datastructure
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NeighbourBuffer {
|
||||
pub ids: Vec<u32>,
|
||||
scores: Vec<i64>,
|
||||
visited: Vec<bool>,
|
||||
next_unvisited: Option<u32>,
|
||||
size: usize
|
||||
}
|
||||
|
||||
impl NeighbourBuffer {
|
||||
pub fn new(size: usize) -> Self {
|
||||
NeighbourBuffer {
|
||||
ids: Vec::with_capacity(size + 1),
|
||||
scores: Vec::with_capacity(size + 1),
|
||||
visited: Vec::with_capacity(size + 1), //bitvec::vec::BitVec::with_capacity(size),
|
||||
next_unvisited: None,
|
||||
size
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_unvisited(&mut self) -> Option<u32> {
|
||||
//println!("next_unvisited: {:?}", self);
|
||||
let mut cur = self.next_unvisited? as usize;
|
||||
let old_cur = cur;
|
||||
self.visited[cur] = true;
|
||||
while cur < self.len() && self.visited[cur] {
|
||||
cur += 1;
|
||||
}
|
||||
if cur == self.len() {
|
||||
self.next_unvisited = None;
|
||||
} else {
|
||||
self.next_unvisited = Some(cur as u32);
|
||||
}
|
||||
Some(self.ids[old_cur])
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.ids.len()
|
||||
}
|
||||
|
||||
pub fn cap(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, id: u32, score: i64) {
|
||||
if self.len() == self.cap() && self.scores[self.len() - 1] > score {
|
||||
return;
|
||||
}
|
||||
|
||||
let loc = match self.scores.binary_search_by(|x| score.partial_cmp(&x).unwrap()) {
|
||||
Ok(loc) => loc,
|
||||
Err(loc) => loc
|
||||
};
|
||||
|
||||
if self.ids.get(loc) == Some(&id) {
|
||||
return;
|
||||
}
|
||||
|
||||
// slightly inefficient but we avoid unsafe code
|
||||
self.ids.insert(loc, id);
|
||||
self.scores.insert(loc, score);
|
||||
self.visited.insert(loc, false);
|
||||
self.ids.truncate(self.size);
|
||||
self.scores.truncate(self.size);
|
||||
self.visited.truncate(self.size);
|
||||
|
||||
self.next_unvisited = Some(loc as u32);
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.ids.clear();
|
||||
self.scores.clear();
|
||||
self.visited.clear();
|
||||
self.next_unvisited = None;
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Scratch {
|
||||
visited: HashSet<u32>,
|
||||
pub neighbour_buffer: NeighbourBuffer,
|
||||
neighbour_pre_buffer: Vec<u32>,
|
||||
visited_list: Vec<(u32, i64)>,
|
||||
robust_prune_scratch_buffer: Vec<(usize, u32)>
|
||||
}
|
||||
|
||||
impl Scratch {
|
||||
pub fn new(IndexBuildConfig { l, r, maxc, .. }: IndexBuildConfig) -> Self {
|
||||
Scratch {
|
||||
visited: HashSet::with_capacity(l * 8),
|
||||
neighbour_buffer: NeighbourBuffer::new(l),
|
||||
neighbour_pre_buffer: Vec::with_capacity(r),
|
||||
visited_list: Vec::with_capacity(l * 8),
|
||||
robust_prune_scratch_buffer: Vec::with_capacity(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GreedySearchCounters {
|
||||
pub distances: usize
|
||||
}
|
||||
|
||||
// Algorithm 1 from the DiskANN paper
|
||||
// We support the dot product metric only, so we want to keep things with the HIGHEST dot product
|
||||
pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs: &VectorList, graph: &IndexGraph, config: IndexBuildConfig) -> GreedySearchCounters {
|
||||
scratch.visited.clear();
|
||||
scratch.neighbour_buffer.clear();
|
||||
scratch.visited_list.clear();
|
||||
|
||||
scratch.neighbour_buffer.insert(start, fast_dot_noprefetch(query, &vecs[start as usize]));
|
||||
scratch.visited.insert(start);
|
||||
|
||||
let mut counters = GreedySearchCounters { distances: 0 };
|
||||
|
||||
while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
|
||||
//println!("pt {} {:?}", pt, graph.out_neighbours(pt));
|
||||
scratch.neighbour_pre_buffer.clear();
|
||||
for &neighbour in graph.out_neighbours(pt).iter() {
|
||||
if scratch.visited.insert(neighbour) {
|
||||
scratch.neighbour_pre_buffer.push(neighbour);
|
||||
}
|
||||
}
|
||||
for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
|
||||
let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
|
||||
let distance = fast_dot(query, &vecs[neighbour as usize], &vecs[next_neighbour as usize]);
|
||||
counters.distances += 1;
|
||||
scratch.neighbour_buffer.insert(neighbour, distance);
|
||||
scratch.visited_list.push((neighbour, distance));
|
||||
}
|
||||
}
|
||||
|
||||
counters
|
||||
}
|
||||
|
||||
type CandidateList = Vec<(u32, i64)>;
|
||||
|
||||
fn merge_existing_neighbours(candidates: &mut CandidateList, point: u32, neigh: &[u32], vecs: &VectorList, config: IndexBuildConfig) {
|
||||
let p_vec = &vecs[point as usize];
|
||||
for (i, &n) in neigh.iter().enumerate() {
|
||||
let dot = fast_dot(p_vec, &vecs[n as usize], &vecs[neigh[(i + 1) % neigh.len() as usize] as usize]);
|
||||
candidates.push((n, dot));
|
||||
}
|
||||
}
|
||||
|
||||
// "Robust prune" algorithm, kind of
|
||||
// The algorithm in the paper does not actually match the code as implemented in microsoft/DiskANN
|
||||
// and that's slightly different from the one in ParlayANN for no reason
|
||||
// This is closer to ParlayANN
|
||||
fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &VectorList, config: IndexBuildConfig) {
|
||||
neigh.clear();
|
||||
|
||||
let candidates = &mut scratch.visited_list;
|
||||
|
||||
// distance low to high = score high to low
|
||||
candidates.sort_unstable_by_key(|&(_id, score)| -score);
|
||||
candidates.truncate(config.maxc);
|
||||
|
||||
let mut candidate_index = 0;
|
||||
while neigh.len() < config.r && candidate_index < candidates.len() {
|
||||
let p_star = candidates[candidate_index].0;
|
||||
candidate_index += 1;
|
||||
if p_star == p || p_star == u32::MAX {
|
||||
continue;
|
||||
}
|
||||
|
||||
neigh.push(p_star);
|
||||
|
||||
scratch.robust_prune_scratch_buffer.clear();
|
||||
|
||||
// mark remaining candidates as not-to-be-used if "not much better than" current candidate
|
||||
for i in (candidate_index+1)..candidates.len() {
|
||||
let p_prime = candidates[i].0;
|
||||
if p_prime != u32::MAX {
|
||||
scratch.robust_prune_scratch_buffer.push((i, p_prime));
|
||||
}
|
||||
}
|
||||
|
||||
for (i, &(ci, p_prime)) in scratch.robust_prune_scratch_buffer.iter().enumerate() {
|
||||
let next_vec = &vecs[scratch.robust_prune_scratch_buffer[(i + 1) % scratch.robust_prune_scratch_buffer.len()].0 as usize];
|
||||
let p_star_prime_score = fast_dot(&vecs[p_prime as usize], &vecs[p_star as usize], next_vec);
|
||||
let p_prime_p_score = candidates[ci].1;
|
||||
let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16;
|
||||
|
||||
if alpha_times_p_star_prime_score >= p_prime_p_score {
|
||||
candidates[ci].0 = u32::MAX;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &VectorList, config: IndexBuildConfig) {
|
||||
assert!(vecs.len() < u32::MAX as usize);
|
||||
|
||||
let mut sigmas: Vec<u32> = (0..(vecs.len() as u32)).collect();
|
||||
rng.shuffle(&mut sigmas);
|
||||
|
||||
let rng = Mutex::new(rng.fork());
|
||||
|
||||
//let scratch = &mut Scratch::new(config);
|
||||
//let mut rng = rng.lock().unwrap();
|
||||
sigmas.into_par_iter().for_each_init(|| (Scratch::new(config), rng.lock().unwrap().fork()), |(scratch, rng), sigma_i| {
|
||||
//sigmas.into_iter().for_each(|sigma_i| {
|
||||
greedy_search(scratch, medioid, &vecs[sigma_i as usize], vecs, &graph, config);
|
||||
|
||||
{
|
||||
let n = graph.out_neighbours(sigma_i);
|
||||
merge_existing_neighbours(&mut scratch.visited_list, sigma_i, &*n, vecs, config);
|
||||
}
|
||||
|
||||
{
|
||||
let mut n = graph.out_neighbours_mut(sigma_i);
|
||||
robust_prune(scratch, sigma_i, &mut *n, vecs, config);
|
||||
}
|
||||
|
||||
let neighbours = graph.out_neighbours(sigma_i).to_owned();
|
||||
for neighbour in neighbours {
|
||||
let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour);
|
||||
// To cut down pruning time slightly, allow accumulating more neighbours than usual limit
|
||||
if neighbour_neighbours.len() == config.r_cap {
|
||||
let mut n = neighbour_neighbours.to_vec();
|
||||
scratch.visited_list.clear();
|
||||
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config);
|
||||
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &vec![sigma_i], vecs, config);
|
||||
robust_prune(scratch, neighbour, &mut n, vecs, config);
|
||||
} else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r_cap {
|
||||
neighbour_neighbours.push(sigma_i);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// RoarGraph's AcquireNeighbours algorithm is actually almost identical to Vamana/DiskANN's RobustPrune, but with fixed α = 1.0.
|
||||
// We replace Vamana's random initialization of the graph with Neighbourhood-Aware Projection from RoarGraph - there's no way to use a large enough
|
||||
// query set that I would be confident in using *only* RoarGraph's algorithm
|
||||
pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec<Vec<u32>>, query_knns_bwd: &Vec<Vec<u32>>, config: IndexBuildConfig, vecs: &VectorList) {
|
||||
let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
|
||||
rng.shuffle(&mut sigmas);
|
||||
|
||||
// Iterate through graph vertices in a random order
|
||||
let rng = Mutex::new(rng.fork());
|
||||
sigmas.into_par_iter().for_each_init(|| (rng.lock().unwrap().fork(), Scratch::new(config)), |(rng, scratch), sigma_i| {
|
||||
scratch.visited.clear();
|
||||
scratch.visited_list.clear();
|
||||
scratch.neighbour_pre_buffer.clear();
|
||||
for &query_neighbour in query_knns[sigma_i as usize].iter() {
|
||||
for &projected_neighbour in query_knns_bwd[query_neighbour as usize].iter() {
|
||||
if scratch.visited.insert(projected_neighbour) {
|
||||
scratch.neighbour_pre_buffer.push(projected_neighbour);
|
||||
}
|
||||
}
|
||||
}
|
||||
rng.shuffle(&mut scratch.neighbour_pre_buffer);
|
||||
scratch.neighbour_pre_buffer.truncate(config.maxc * 2);
|
||||
for (i, &projected_neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
|
||||
let score = fast_dot(&vecs[sigma_i as usize], &vecs[projected_neighbour as usize], &vecs[scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()] as usize]);
|
||||
scratch.visited_list.push((projected_neighbour, score));
|
||||
}
|
||||
let mut neighbours = graph.out_neighbours_mut(sigma_i);
|
||||
robust_prune(scratch, sigma_i, &mut *neighbours, vecs, config);
|
||||
})
|
||||
}

pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig) {
    let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
    rng.shuffle(&mut sigmas);

    // Iterate through graph vertices in a random order
    let rng = Mutex::new(rng.fork());
    sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| {
        let mut neighbours = graph.out_neighbours_mut(sigma_i);
        let mut i = 0;
        while neighbours.len() < config.r_cap && i < 100 {
            let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap();
            let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap();
            if !neighbours.contains(&projected_neighbour) {
                neighbours.push(projected_neighbour);
            }
            i += 1;
        }
    })
}

pub fn random_fill_graph(rng: &mut Rng, graph: &mut IndexGraph, r: usize) {
    let rng = Mutex::new(rng.fork());
    (0..graph.graph.len() as u32).into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, i| {
        let mut neighbours = graph.out_neighbours_mut(i);
        while neighbours.len() < r {
            let next = rng.u32(0..(graph.graph.len() as u32));
            if !neighbours.contains(&next) {
                neighbours.push(next);
            }
        }
    });
}

pub struct Timer(&'static str, std::time::Instant);

impl Timer {
    pub fn new(name: &'static str) -> Self {
        Timer(name, std::time::Instant::now())
    }
}

impl Drop for Timer {
    fn drop(&mut self) {
        println!("{}: {:.2}s", self.0, self.1.elapsed().as_secs_f32());
    }
}
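
// A minimal usage sketch (hypothetical call site; `expensive_step` is assumed): the timer
// prints when dropped, so scoping it to a block times that block.
//
//     let result = {
//         let _timer = Timer::new("expensive step");
//         expensive_step()
//     }; // prints e.g. "expensive step: 1.23s"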
121
diskann/src/main.rs
Normal file
@ -0,0 +1,121 @@
#![feature(test)]
#![feature(pointer_is_aligned_to)]

extern crate test;

use std::{io::Read, time::Instant};
use anyhow::Result;
use half::f16;

use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, dot, VectorList, self}, Timer};
use simsimd::SpatialSimilarity;

const D_EMB: usize = 1152;

fn load_file(path: &str, trunc: Option<usize>) -> Result<VectorList> {
    let mut input = std::fs::File::open(path)?;
    let mut buf = Vec::new();
    input.read_to_end(&mut buf)?;
    // TODO: this is not particularly efficient
    let f16s = bytemuck::cast_slice::<_, f16>(&buf)[0..trunc.unwrap_or(buf.len()/2)].iter().copied().collect();
    Ok(VectorList::from_f16s(f16s, D_EMB))
}
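
// Assumed on-disk format (inferred from the cast above): embeddings.bin is raw f16s in native
// byte order with no header, D_EMB values per vector; note that `trunc` counts f16 elements, not vectors.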

const PQ_TEST_SIZE: usize = 1000;

fn main() -> Result<()> {
    tracing_subscriber::fmt::init();

    {
        let file = std::fs::File::open("opq.msgpack")?;
        let codec: vector::ProductQuantizer = rmp_serde::from_read(file)?;
        let input = load_file("embeddings.bin", Some(D_EMB * PQ_TEST_SIZE))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
        let codes = codec.quantize_batch(&input);
        println!("{:?}", codes);
        let raw_query = load_file("query.bin", Some(D_EMB))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
        let query = codec.preprocess_query(&raw_query);
        let mut real_scores = vec![];
        for i in 0..PQ_TEST_SIZE {
            real_scores.push(SpatialSimilarity::dot(&raw_query, &input[i * D_EMB .. (i + 1) * D_EMB]).unwrap() as f32);
        }
        let pq_scores = codec.asymmetric_dot_product(&query, &codes);
        for (x, y) in real_scores.iter().zip(pq_scores.iter()) {
            // asymmetric_dot_product returns fixed-point i64s; undo the 2^48 SCALE from vector.rs so the scores are comparable
            let y = *y as f32 / 281474976710656.0;
            println!("{} {} {} {}", x, y, x - y, (x - y) / x);
        }
    }

    let mut rng = fastrand::Rng::with_seed(1);

    let n = 100000;
    let vecs = {
        let _timer = Timer::new("loaded vectors");

        &load_file("embeddings.bin", Some(D_EMB * n))?
    };

    let (graph, medioid) = {
        let _timer = Timer::new("index built");

        let mut config = IndexBuildConfig {
            r: 64,
            r_cap: 80,
            l: 128,
            maxc: 750,
            alpha: 65536,
        };

        let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap);

        let medioid = medioid(&vecs);

        build_graph(&mut rng, &mut graph, medioid, &vecs, config);
        config.alpha = 58000;
        build_graph(&mut rng, &mut graph, medioid, &vecs, config);

        (graph, medioid)
    };

    let mut edge_ctr = 0;

    for adjlist in graph.graph.iter() {
        edge_ctr += adjlist.read().unwrap().len();
    }

    println!("average degree: {}", edge_ctr as f32 / graph.graph.len() as f32);

    let time = Instant::now();
    let mut recall = 0;
    let mut cmps_ctr = 0;
    let mut cmps = vec![];

    let mut config = IndexBuildConfig {
        r: 64,
        r_cap: 64,
        l: 50,
        alpha: 65536,
        maxc: 0,
    };

    let mut scratch = Scratch::new(config);
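
    // Evaluation protocol (as implemented below): each indexed vector is reused as its own query,
    // so a "hit" means greedy_search returns that exact vector at rank 0.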

    for (i, vec) in tqdm::tqdm(vecs.iter().enumerate()) {
        let ctr = greedy_search(&mut scratch, medioid, &vec, &vecs, &graph, config);
        cmps_ctr += ctr.distances;
        cmps.push(ctr.distances);
        if scratch.neighbour_buffer.ids[0] == (i as u32) {
            recall += 1;
        }
    }

    cmps.sort();

    let end = time.elapsed();

    println!("recall@1: {} ({}/{})", recall as f32 / n as f32, recall, n);
    println!("cmps: {} ({}/{})", cmps_ctr as f32 / n as f32, cmps_ctr, n);
    println!("median comparisons: {}", cmps[cmps.len() / 2]);
    //println!("brute force recall@1: {} ({}/{})", brute_force_recall as f32 / brute_force_queries as f32, brute_force_recall, brute_force_queries);
    println!("{} QPS", n as f32 / end.as_secs_f32());

    Ok(())
}
449
diskann/src/vector.rs
Normal file
@ -0,0 +1,449 @@
use core::f32;

use half::f16;
use simsimd::SpatialSimilarity;
use fastrand::Rng;
use serde::{Serialize, Deserialize};

#[derive(Debug, Clone)]
pub struct Vector(Vec<f16>);
#[derive(Debug, Clone)]
pub struct SVector(Vec<f32>);

pub type VectorRef<'a> = &'a [f16];
pub type QVectorRef<'a> = &'a [u8];
pub type SVectorRef<'a> = &'a [f32];

impl SVector {
    pub fn zero(d: usize) -> Self {
        SVector(vec![0.0; d])
    }
    pub fn half(&self) -> Vector {
        Vector(self.0.iter().map(|a| f16::from_f32(*a)).collect())
    }
}

fn box_muller(rng: &mut Rng) -> f32 {
    loop {
        let u = rng.f32();
        let v = rng.f32();
        let x = (v * std::f32::consts::TAU).cos() * (-2.0 * u.ln()).sqrt();
        if x.is_finite() {
            return x;
        }
    }
}
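
// (Box–Muller: for independent u, v ~ U(0, 1), cos(2πv) · sqrt(-2 ln u) is a standard normal
// sample; the retry loop just rejects the non-finite result from u = 0.)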

impl Vector {
    pub fn zero(d: usize) -> Self {
        Vector(vec![f16::from_f32(0.0); d])
    }

    pub fn randn(rng: &mut Rng, d: usize) -> Self {
        Vector(Vec::from_iter((0..d).map(|_| f16::from_f32(box_muller(rng)))))
    }
}

// Floats are vaguely annoying and not (trivially) sortable, so we mostly represent dot products as scaled integers
const SCALE: f32 = 281474976710656.0;
const SCALE_F64: f64 = SCALE as f64;
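// (SCALE is 2^48, so e.g. a dot product of 0.5 maps to (0.5 * SCALE) as i64 = 140_737_488_355_328.)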

pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> f32 {
    // safety is not real
    (simsimd::f16::dot(unsafe { std::mem::transmute(x) }, unsafe { std::mem::transmute(y) }).unwrap()) as f32
}

pub fn to_svector(vec: VectorRef) -> SVector {
    SVector(vec.iter().map(|a| a.to_f32()).collect())
}

impl<'a> std::ops::AddAssign<VectorRef<'a>> for SVector {
    fn add_assign(&mut self, other: VectorRef<'a>) {
        self.0.iter_mut().zip(other.iter()).for_each(|(a, b)| *a += b.to_f32());
    }
}

impl std::ops::Div<f32> for SVector {
    type Output = Self;

    fn div(self, b: f32) -> Self::Output {
        SVector(self.0.iter().map(|a| a / b).collect())
    }
}

impl std::ops::Deref for Vector {
    type Target = [f16];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::ops::Deref for SVector {
    type Target = [f32];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::ops::Add<&SVector> for SVector {
    type Output = Self;

    fn add(self, other: &Self) -> Self::Output {
        SVector(self.0.iter().zip(other.0.iter()).map(|(a, b)| a + b).collect())
    }
}

impl std::ops::Sub<&SVector> for SVector {
    type Output = Self;

    fn sub(self, other: &Self) -> Self::Output {
        SVector(self.0.iter().zip(other.0.iter()).map(|(a, b)| a - b).collect())
    }
}

impl std::ops::AddAssign for SVector {
    fn add_assign(&mut self, other: Self) {
        self.0.iter_mut().zip(other.0.iter()).for_each(|(a, b)| *a += b);
    }
}

impl std::ops::Mul<f32> for SVector {
    type Output = Self;

    fn mul(self, other: f32) -> Self {
        SVector(self.0.iter().map(|a| *a * other).collect())
    }
}

#[derive(Debug, Clone)]
pub struct VectorList {
    pub d_emb: usize,
    pub length: usize,
    pub data: Vec<f16>
}

impl std::ops::Index<usize> for VectorList {
    type Output = [f16];

    fn index(&self, index: usize) -> &Self::Output {
        &self.data[index * self.d_emb..(index + 1) * self.d_emb]
    }
}

pub struct VectorListIterator<'a> {
    list: &'a VectorList,
    index: usize
}

impl<'a> Iterator for VectorListIterator<'a> {
    type Item = VectorRef<'a>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.list.len() {
            let ret = &self.list[self.index];
            self.index += 1;
            Some(ret)
        } else {
            None
        }
    }
}

impl VectorList {
    pub fn len(&self) -> usize {
        self.length
    }

    pub fn iter(&self) -> VectorListIterator {
        VectorListIterator {
            list: self,
            index: 0
        }
    }

    pub fn empty(d: usize) -> Self {
        VectorList {
            d_emb: d,
            length: 0,
            data: Vec::new()
        }
    }

    pub fn from_f16s(f16s: Vec<f16>, d: usize) -> Self {
        assert!(f16s.len() % d == 0);
        VectorList {
            d_emb: d,
            length: f16s.len() / d,
            data: f16s
        }
    }

    pub fn push(&mut self, vec: VectorRef) {
        self.length += 1;
        self.data.extend_from_slice(vec);
    }
}

// SimSIMD has its own f16 dot product, but ours prefetches concurrently, is unrolled more, ignores inconveniently-sized vectors and does a cheaper reduction.
// Also, we return an int because floats are annoying (not Ord).
// On Tiger Lake (i5-1135G7) we have about a 3x performance advantage, ignoring the prefetching
// (it would be better to use AVX512 for said CPU, but this also has to run on Zen 3).
pub fn fast_dot(x: VectorRef, y: VectorRef, prefetch: VectorRef) -> i64 {
    use std::arch::x86_64::*;

    debug_assert!(x.len() == y.len());
    debug_assert!(prefetch.len() == x.len());
    debug_assert!(x.len() % 64 == 0);

    // safety is not real
    // it's probably fine I guess
    unsafe {
        let mut x_ptr = x.as_ptr();
        let mut y_ptr = y.as_ptr();
        let end = x_ptr.add(x.len());
        let mut prefetch_ptr = prefetch.as_ptr();

        let mut acc1 = _mm256_setzero_ps();
        let mut acc2 = _mm256_setzero_ps();
        let mut acc3 = _mm256_setzero_ps();
        let mut acc4 = _mm256_setzero_ps();

        while x_ptr < end {
            // fetch chunks and prefetch next vector
            let x1 = _mm256_loadu_si256(x_ptr as *const __m256i);
            let y1 = _mm256_loadu_si256(y_ptr as *const __m256i);
            let x2 = _mm256_loadu_si256(x_ptr.add(16) as *const __m256i);
            let y2 = _mm256_loadu_si256(y_ptr.add(16) as *const __m256i);
            // technically, we only have to do this once per cache line but I don't care enough to test every way to optimize this
            _mm_prefetch(prefetch_ptr as *const i8, _MM_HINT_T0);
            x_ptr = x_ptr.add(32); // advance past the 32 f16s consumed per iteration
            y_ptr = y_ptr.add(32);
            prefetch_ptr = prefetch_ptr.add(32);

            // unpack f16 to f32
            let x1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 0));
            let x1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 1));
            let y1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 0));
            let y1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 1));
            let x2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 0));
            let x2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 1));
            let y2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 0));
            let y2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 1));

            acc1 = _mm256_fmadd_ps(x1lo, y1lo, acc1);
            acc2 = _mm256_fmadd_ps(x1hi, y1hi, acc2);
            acc3 = _mm256_fmadd_ps(x2lo, y2lo, acc3);
            acc4 = _mm256_fmadd_ps(x2hi, y2hi, acc4);
        }

        // reduce
        let acc1 = _mm256_add_ps(acc1, acc2);
        let acc2 = _mm256_add_ps(acc3, acc4);

        let hsum = _mm256_hadd_ps(acc1, acc2);
        let hsum_lo = _mm256_extractf128_ps(hsum, 0);
        let hsum_hi = _mm256_extractf128_ps(hsum, 1);
        let hsum = _mm_add_ps(hsum_lo, hsum_hi);

        let floatval = f32::from_bits(_mm_extract_ps::<0>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<1>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<2>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<3>(hsum) as u32);
        (floatval * SCALE) as i64
    }
}

// same as above, without the prefetch pointer
pub fn fast_dot_noprefetch(x: VectorRef, y: VectorRef) -> i64 {
    use std::arch::x86_64::*;

    debug_assert!(x.len() == y.len());
    debug_assert!(x.len() % 64 == 0);

    unsafe {
        let mut x_ptr = x.as_ptr();
        let mut y_ptr = y.as_ptr();
        let end = x_ptr.add(x.len());

        let mut acc1 = _mm256_setzero_ps();
        let mut acc2 = _mm256_setzero_ps();
        let mut acc3 = _mm256_setzero_ps();
        let mut acc4 = _mm256_setzero_ps();

        while x_ptr < end {
            let x1 = _mm256_loadu_si256(x_ptr as *const __m256i);
            let y1 = _mm256_loadu_si256(y_ptr as *const __m256i);
            let x2 = _mm256_loadu_si256(x_ptr.add(16) as *const __m256i);
            let y2 = _mm256_loadu_si256(y_ptr.add(16) as *const __m256i);
            x_ptr = x_ptr.add(32);
            y_ptr = y_ptr.add(32);

            let x1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 0));
            let x1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 1));
            let y1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 0));
            let y1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 1));
            let x2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 0));
            let x2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 1));
            let y2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 0));
            let y2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 1));

            acc1 = _mm256_fmadd_ps(x1lo, y1lo, acc1);
            acc2 = _mm256_fmadd_ps(x1hi, y1hi, acc2);
            acc3 = _mm256_fmadd_ps(x2lo, y2lo, acc3);
            acc4 = _mm256_fmadd_ps(x2hi, y2hi, acc4);
        }

        // reduce
        let acc1 = _mm256_add_ps(acc1, acc2);
        let acc2 = _mm256_add_ps(acc3, acc4);

        let hsum = _mm256_hadd_ps(acc1, acc2);
        let hsum_lo = _mm256_extractf128_ps(hsum, 0);
        let hsum_hi = _mm256_extractf128_ps(hsum, 1);
        let hsum = _mm_add_ps(hsum_lo, hsum_hi);

        let floatval = f32::from_bits(_mm_extract_ps::<0>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<1>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<2>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<3>(hsum) as u32);
        (floatval * SCALE) as i64
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    centroids: Vec<f32>,
    transform: Vec<f32>, // D*D orthonormal matrix
    pub n_dims_per_code: usize,
    pub n_dims: usize
}

// chunk * centroid_index
pub struct DistanceLUT(Vec<f32>);
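// (Layout note: the score for centroid c within chunk i lives at lut[i * n_centroids + c],
// matching the indexing in preprocess_query and asymmetric_dot_product below.)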

impl ProductQuantizer {
    pub fn apply_transform(&self, x: &[f32]) -> Vec<f32> {
        let dim = self.n_dims;
        let n_vectors = x.len() / dim;
        let mut transformed = vec![0.0; n_vectors * dim];
        // transform_matrix (D * D) @ batch.T (D * B)
        unsafe {
            matrixmultiply::sgemm(dim, dim, n_vectors, 1.0, self.transform.as_ptr(), dim as isize, 1, x.as_ptr(), 1, dim as isize, 0.0, transformed.as_mut_ptr(), 1, dim as isize);
        }
        transformed
    }

    pub fn quantize_batch(&self, x: &[f32]) -> Vec<u8> {
        // x is B * D
        let dim = self.n_dims;
        assert_eq!(dim * dim, self.transform.len());
        let n_vectors = x.len() / dim;
        let n_centroids = self.centroids.len() / dim;
        assert!(n_centroids <= 256);
        let transformed = self.apply_transform(&x); // B * D, as we write the sgemm result in a weird order
        let mut codes = vec![0; n_vectors * dim / self.n_dims_per_code];
        let vec_len_codes = dim / self.n_dims_per_code;

        // B * C buffer of similarity of each vector to each centroid, within a subspace
        let mut scratch = vec![0.0; n_vectors * n_centroids];

        for i in 0..(dim / self.n_dims_per_code) {
            let offset = i * self.n_dims_per_code;
            // transformed_batch[:, range] (B * D_r) @ centroids[:, range].T (D_r * C)
            unsafe {
                matrixmultiply::sgemm(n_vectors, self.n_dims_per_code, n_centroids, 1.0, transformed.as_ptr().add(offset), dim as isize, 1, self.centroids.as_ptr().add(offset), 1, dim as isize, 0.0, scratch.as_mut_ptr(), n_centroids as isize, 1);
            }
            // assign this component to the best centroid
            for i_vec in 0..n_vectors {
                let mut best = f32::NEG_INFINITY;
                for i_centroid in 0..n_centroids {
                    let score = scratch[i_vec * n_centroids + i_centroid];
                    if score > best {
                        best = score;
                        codes[i_vec * vec_len_codes + i] = i_centroid as u8;
                    }
                }
            }
        }
        codes
    }

    // not particularly performance-sensitive right now; do it unbatched
    pub fn preprocess_query(&self, query: &[f32]) -> DistanceLUT {
        let transformed = self.apply_transform(query);
        let n_chunks = self.n_dims / self.n_dims_per_code;
        let n_centroids = self.centroids.len() / self.n_dims;
        let mut lut = Vec::with_capacity(n_chunks * n_centroids);

        for i in 0..n_chunks {
            let vec_component = &transformed[i * self.n_dims_per_code..(i + 1) * self.n_dims_per_code];
            for j in 0..n_centroids {
                let centroid = &self.centroids[j * self.n_dims..(j + 1) * self.n_dims];
                let centroid_component = &centroid[i * self.n_dims_per_code..(i + 1) * self.n_dims_per_code];
                let score = SpatialSimilarity::dot(vec_component, centroid_component).unwrap();
                lut.push(score as f32);
            }
        }

        DistanceLUT(lut)
    }

    // compute dot products of the query against product-quantized vectors
    pub fn asymmetric_dot_product(&self, query: &DistanceLUT, pq_vectors: &[u8]) -> Vec<i64> {
        let n_chunks = self.n_dims / self.n_dims_per_code;
        let n_vectors = pq_vectors.len() / n_chunks;
        let mut scores = vec![0.0; n_vectors];
        let n_centroids = self.centroids.len() / self.n_dims;

        for i in 0..n_chunks {
            for j in 0..n_vectors {
                let code = pq_vectors[j * n_chunks + i];
                let chunk_score = query.0[i * n_centroids + code as usize];
                scores[j] += chunk_score;
            }
        }

        // I have no idea why, but we somehow get significant degradation in search quality
        // if this accumulates in integers. As such, do floats and convert at the end.
        // I'm sure there are fascinating reasons for this, but God is dead, God remains dead, etc.
        scores.into_iter().map(|x| (x * SCALE) as i64).collect()
    }
}
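
// A minimal end-to-end sketch (hypothetical values; a real codec comes from opq.msgpack as in
// main.rs): quantize a batch once, build the query LUT once, then score many codes cheaply.
//
//     let codes = codec.quantize_batch(&base_vectors);         // base_vectors: flat B * n_dims f32s
//     let lut = codec.preprocess_query(&query);                // query: n_dims f32s
//     let scores = codec.asymmetric_dot_product(&lut, &codes); // scaled-i64 dots, one per base vector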

pub fn scale_dot_result(x: f64) -> i64 {
    (x * SCALE_F64) as i64
}

#[cfg(test)]
mod bench {
    use super::*;
    use test::Bencher;

    #[bench]
    fn bench_dot(be: &mut Bencher) {
        let mut rng = fastrand::Rng::with_seed(1);
        let a = Vector::randn(&mut rng, 1024);
        let b = Vector::randn(&mut rng, 1024);
        be.iter(|| {
            dot(&a, &b)
        });
    }

    #[bench]
    fn bench_fastdot(be: &mut Bencher) {
        let mut rng = fastrand::Rng::with_seed(1);
        let a = Vector::randn(&mut rng, 1024);
        let b = Vector::randn(&mut rng, 1024);
        be.iter(|| {
            fast_dot(&a, &b, &a)
        });
    }

    #[bench]
    fn bench_fastdot_noprefetch(be: &mut Bencher) {
        let mut rng = fastrand::Rng::with_seed(1);
        let a = Vector::randn(&mut rng, 1024);
        let b = Vector::randn(&mut rng, 1024);
        be.iter(|| {
            fast_dot_noprefetch(&a, &b)
        });
    }
}
44
diskann/vec_dist.py
Normal file
@ -0,0 +1,44 @@
import math

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

A = 0.6
LOG_A = np.log(A)

def scale(xs):
    return np.sign(xs) * (np.log(np.abs(xs) + A) - LOG_A)
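
# (A signed log-like squash: scale(0) == 0 and scale(±1) == ±(log(1.6) - log(0.6)) ≈ ±0.98,
# so values near zero are spread out while the tails are compressed.)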

n_dims = 1152
n_used_dims = 32
data = np.frombuffer(open("embeddings.bin", "rb").read(), dtype=np.float16).reshape(-1, n_dims).astype(np.float32) # TODO

# Create histogram bins
n_bins = 256
s = math.sqrt(n_dims)
hist_range = (-1.2, 1.2)
histogram_data = np.zeros((n_used_dims, n_bins))

# Calculate histograms for each dimension
for dim in range(n_used_dims):
    dbd = data[:, dim]
    dbd = (dbd - np.mean(dbd)) / np.std(dbd)
    dbd = scale(dbd)
    hist, _ = np.histogram(dbd, bins=n_bins, range=hist_range, density=True)
    histogram_data[dim] = hist

# Create heatmap
plt.figure(figsize=(12, 8))
sns.heatmap(histogram_data,
            cmap='viridis',
            xticklabels=np.linspace(hist_range[0], hist_range[1], n_bins),
            yticklabels=range(n_used_dims),
            cbar_kws={'label': 'Density'})

plt.xlabel('Value')
plt.ylabel('Dimension')
plt.title(f'Distribution Heatmap of First {n_used_dims} Dimensions')

# Adjust layout to prevent label cutoff
plt.tight_layout()

plt.show()