mirror of
				https://github.com/osmarks/meme-search-engine.git
				synced 2025-10-30 23:12:58 +00:00 
			
		
		
		
	release early draft of index code
This commit is contained in:
		
							
								
								
									
										7
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -9,4 +9,9 @@ node_modules/* | ||||
| node_modules | ||||
| *sqlite3* | ||||
| thumbtemp | ||||
| mse-test-db-small | ||||
| mse-test-db-small | ||||
| clipfront2/static/bg* | ||||
| diskann/target | ||||
| *.bin | ||||
| *.msgpack | ||||
| */flamegraph.svg | ||||
|   | ||||
							
								
								
									
										748
									
								
								diskann/Cargo.lock
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						
									
										748
									
								
								diskann/Cargo.lock
									
									
									
										generated
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,748 @@ | ||||
| # This file is automatically @generated by Cargo. | ||||
| # It is not intended for manual editing. | ||||
| version = 3 | ||||
|  | ||||
| [[package]] | ||||
| name = "anyhow" | ||||
| version = "1.0.93" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" | ||||
|  | ||||
| [[package]] | ||||
| name = "autocfg" | ||||
| version = "1.4.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" | ||||
|  | ||||
| [[package]] | ||||
| name = "bitflags" | ||||
| version = "1.3.2" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" | ||||
|  | ||||
| [[package]] | ||||
| name = "bitflags" | ||||
| version = "2.6.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" | ||||
|  | ||||
| [[package]] | ||||
| name = "bitvec" | ||||
| version = "1.0.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" | ||||
| dependencies = [ | ||||
|  "funty", | ||||
|  "radium", | ||||
|  "tap", | ||||
|  "wyz", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "bytemuck" | ||||
| version = "1.20.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a" | ||||
| dependencies = [ | ||||
|  "bytemuck_derive", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "bytemuck_derive" | ||||
| version = "1.7.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26" | ||||
| dependencies = [ | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
|  "syn", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "byteorder" | ||||
| version = "1.5.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" | ||||
|  | ||||
| [[package]] | ||||
| name = "cc" | ||||
| version = "1.2.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" | ||||
| dependencies = [ | ||||
|  "shlex", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "cfg-if" | ||||
| version = "1.0.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" | ||||
|  | ||||
| [[package]] | ||||
| name = "crossbeam-deque" | ||||
| version = "0.8.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" | ||||
| dependencies = [ | ||||
|  "crossbeam-epoch", | ||||
|  "crossbeam-utils", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "crossbeam-epoch" | ||||
| version = "0.9.18" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" | ||||
| dependencies = [ | ||||
|  "crossbeam-utils", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "crossbeam-utils" | ||||
| version = "0.8.20" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" | ||||
|  | ||||
| [[package]] | ||||
| name = "crossterm" | ||||
| version = "0.25.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "e64e6c0fbe2c17357405f7c758c1ef960fce08bdfb2c03d88d2a18d7e09c4b67" | ||||
| dependencies = [ | ||||
|  "bitflags 1.3.2", | ||||
|  "crossterm_winapi", | ||||
|  "libc", | ||||
|  "mio", | ||||
|  "parking_lot", | ||||
|  "signal-hook", | ||||
|  "signal-hook-mio", | ||||
|  "winapi", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "crossterm_winapi" | ||||
| version = "0.9.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" | ||||
| dependencies = [ | ||||
|  "winapi", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "crunchy" | ||||
| version = "0.2.2" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" | ||||
|  | ||||
| [[package]] | ||||
| name = "diskann" | ||||
| version = "0.1.0" | ||||
| dependencies = [ | ||||
|  "anyhow", | ||||
|  "bitvec", | ||||
|  "bytemuck", | ||||
|  "fastrand", | ||||
|  "foldhash", | ||||
|  "half", | ||||
|  "matrixmultiply", | ||||
|  "rayon", | ||||
|  "rmp-serde", | ||||
|  "serde", | ||||
|  "simsimd", | ||||
|  "tqdm", | ||||
|  "tracing", | ||||
|  "tracing-subscriber", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "either" | ||||
| version = "1.13.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" | ||||
|  | ||||
| [[package]] | ||||
| name = "fastrand" | ||||
| version = "2.2.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" | ||||
|  | ||||
| [[package]] | ||||
| name = "foldhash" | ||||
| version = "0.1.3" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" | ||||
|  | ||||
| [[package]] | ||||
| name = "funty" | ||||
| version = "2.0.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" | ||||
|  | ||||
| [[package]] | ||||
| name = "half" | ||||
| version = "2.4.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" | ||||
| dependencies = [ | ||||
|  "bytemuck", | ||||
|  "cfg-if", | ||||
|  "crunchy", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "lazy_static" | ||||
| version = "1.5.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" | ||||
|  | ||||
| [[package]] | ||||
| name = "libc" | ||||
| version = "0.2.164" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" | ||||
|  | ||||
| [[package]] | ||||
| name = "lock_api" | ||||
| version = "0.4.12" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" | ||||
| dependencies = [ | ||||
|  "autocfg", | ||||
|  "scopeguard", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "log" | ||||
| version = "0.4.22" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" | ||||
|  | ||||
| [[package]] | ||||
| name = "matrixmultiply" | ||||
| version = "0.3.9" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" | ||||
| dependencies = [ | ||||
|  "autocfg", | ||||
|  "rawpointer", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "mio" | ||||
| version = "0.8.11" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" | ||||
| dependencies = [ | ||||
|  "libc", | ||||
|  "log", | ||||
|  "wasi", | ||||
|  "windows-sys", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "nu-ansi-term" | ||||
| version = "0.46.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" | ||||
| dependencies = [ | ||||
|  "overload", | ||||
|  "winapi", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "num-traits" | ||||
| version = "0.2.19" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" | ||||
| dependencies = [ | ||||
|  "autocfg", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "once_cell" | ||||
| version = "1.20.2" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" | ||||
|  | ||||
| [[package]] | ||||
| name = "overload" | ||||
| version = "0.1.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" | ||||
|  | ||||
| [[package]] | ||||
| name = "parking_lot" | ||||
| version = "0.12.3" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" | ||||
| dependencies = [ | ||||
|  "lock_api", | ||||
|  "parking_lot_core", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "parking_lot_core" | ||||
| version = "0.9.10" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" | ||||
| dependencies = [ | ||||
|  "cfg-if", | ||||
|  "libc", | ||||
|  "redox_syscall", | ||||
|  "smallvec", | ||||
|  "windows-targets 0.52.6", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "paste" | ||||
| version = "1.0.15" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" | ||||
|  | ||||
| [[package]] | ||||
| name = "pin-project-lite" | ||||
| version = "0.2.15" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" | ||||
|  | ||||
| [[package]] | ||||
| name = "proc-macro2" | ||||
| version = "1.0.89" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" | ||||
| dependencies = [ | ||||
|  "unicode-ident", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "quote" | ||||
| version = "1.0.37" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" | ||||
| dependencies = [ | ||||
|  "proc-macro2", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "radium" | ||||
| version = "0.7.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" | ||||
|  | ||||
| [[package]] | ||||
| name = "rawpointer" | ||||
| version = "0.2.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" | ||||
|  | ||||
| [[package]] | ||||
| name = "rayon" | ||||
| version = "1.10.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" | ||||
| dependencies = [ | ||||
|  "either", | ||||
|  "rayon-core", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "rayon-core" | ||||
| version = "1.12.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" | ||||
| dependencies = [ | ||||
|  "crossbeam-deque", | ||||
|  "crossbeam-utils", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "redox_syscall" | ||||
| version = "0.5.7" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" | ||||
| dependencies = [ | ||||
|  "bitflags 2.6.0", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "rmp" | ||||
| version = "0.8.14" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" | ||||
| dependencies = [ | ||||
|  "byteorder", | ||||
|  "num-traits", | ||||
|  "paste", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "rmp-serde" | ||||
| version = "1.3.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" | ||||
| dependencies = [ | ||||
|  "byteorder", | ||||
|  "rmp", | ||||
|  "serde", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "scopeguard" | ||||
| version = "1.2.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" | ||||
|  | ||||
| [[package]] | ||||
| name = "serde" | ||||
| version = "1.0.215" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" | ||||
| dependencies = [ | ||||
|  "serde_derive", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "serde_derive" | ||||
| version = "1.0.215" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" | ||||
| dependencies = [ | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
|  "syn", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "sharded-slab" | ||||
| version = "0.1.7" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" | ||||
| dependencies = [ | ||||
|  "lazy_static", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "shlex" | ||||
| version = "1.3.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" | ||||
|  | ||||
| [[package]] | ||||
| name = "signal-hook" | ||||
| version = "0.3.17" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" | ||||
| dependencies = [ | ||||
|  "libc", | ||||
|  "signal-hook-registry", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "signal-hook-mio" | ||||
| version = "0.2.4" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" | ||||
| dependencies = [ | ||||
|  "libc", | ||||
|  "mio", | ||||
|  "signal-hook", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "signal-hook-registry" | ||||
| version = "1.4.2" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" | ||||
| dependencies = [ | ||||
|  "libc", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "simsimd" | ||||
| version = "6.0.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "be2ad0164e13e58a994d3dd1ff57d44cee87c445708e3acea7ad4f03a47092ce" | ||||
| dependencies = [ | ||||
|  "cc", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "smallvec" | ||||
| version = "1.13.2" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" | ||||
|  | ||||
| [[package]] | ||||
| name = "syn" | ||||
| version = "2.0.87" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" | ||||
| dependencies = [ | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
|  "unicode-ident", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "tap" | ||||
| version = "1.0.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" | ||||
|  | ||||
| [[package]] | ||||
| name = "thread_local" | ||||
| version = "1.1.8" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" | ||||
| dependencies = [ | ||||
|  "cfg-if", | ||||
|  "once_cell", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "tqdm" | ||||
| version = "0.7.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "aa2d2932240205a99b65f15d9861992c95fbb8c9fb280b3a1f17a92db6dc611f" | ||||
| dependencies = [ | ||||
|  "anyhow", | ||||
|  "crossterm", | ||||
|  "once_cell", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "tracing" | ||||
| version = "0.1.40" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" | ||||
| dependencies = [ | ||||
|  "pin-project-lite", | ||||
|  "tracing-attributes", | ||||
|  "tracing-core", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "tracing-attributes" | ||||
| version = "0.1.27" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" | ||||
| dependencies = [ | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
|  "syn", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "tracing-core" | ||||
| version = "0.1.32" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" | ||||
| dependencies = [ | ||||
|  "once_cell", | ||||
|  "valuable", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "tracing-log" | ||||
| version = "0.2.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" | ||||
| dependencies = [ | ||||
|  "log", | ||||
|  "once_cell", | ||||
|  "tracing-core", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "tracing-subscriber" | ||||
| version = "0.3.18" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" | ||||
| dependencies = [ | ||||
|  "nu-ansi-term", | ||||
|  "sharded-slab", | ||||
|  "smallvec", | ||||
|  "thread_local", | ||||
|  "tracing-core", | ||||
|  "tracing-log", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "unicode-ident" | ||||
| version = "1.0.13" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" | ||||
|  | ||||
| [[package]] | ||||
| name = "valuable" | ||||
| version = "0.1.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" | ||||
|  | ||||
| [[package]] | ||||
| name = "wasi" | ||||
| version = "0.11.0+wasi-snapshot-preview1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" | ||||
|  | ||||
| [[package]] | ||||
| name = "winapi" | ||||
| version = "0.3.9" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" | ||||
| dependencies = [ | ||||
|  "winapi-i686-pc-windows-gnu", | ||||
|  "winapi-x86_64-pc-windows-gnu", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "winapi-i686-pc-windows-gnu" | ||||
| version = "0.4.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" | ||||
|  | ||||
| [[package]] | ||||
| name = "winapi-x86_64-pc-windows-gnu" | ||||
| version = "0.4.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows-sys" | ||||
| version = "0.48.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" | ||||
| dependencies = [ | ||||
|  "windows-targets 0.48.5", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "windows-targets" | ||||
| version = "0.48.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" | ||||
| dependencies = [ | ||||
|  "windows_aarch64_gnullvm 0.48.5", | ||||
|  "windows_aarch64_msvc 0.48.5", | ||||
|  "windows_i686_gnu 0.48.5", | ||||
|  "windows_i686_msvc 0.48.5", | ||||
|  "windows_x86_64_gnu 0.48.5", | ||||
|  "windows_x86_64_gnullvm 0.48.5", | ||||
|  "windows_x86_64_msvc 0.48.5", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "windows-targets" | ||||
| version = "0.52.6" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" | ||||
| dependencies = [ | ||||
|  "windows_aarch64_gnullvm 0.52.6", | ||||
|  "windows_aarch64_msvc 0.52.6", | ||||
|  "windows_i686_gnu 0.52.6", | ||||
|  "windows_i686_gnullvm", | ||||
|  "windows_i686_msvc 0.52.6", | ||||
|  "windows_x86_64_gnu 0.52.6", | ||||
|  "windows_x86_64_gnullvm 0.52.6", | ||||
|  "windows_x86_64_msvc 0.52.6", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_aarch64_gnullvm" | ||||
| version = "0.48.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_aarch64_gnullvm" | ||||
| version = "0.52.6" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_aarch64_msvc" | ||||
| version = "0.48.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_aarch64_msvc" | ||||
| version = "0.52.6" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_i686_gnu" | ||||
| version = "0.48.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_i686_gnu" | ||||
| version = "0.52.6" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_i686_gnullvm" | ||||
| version = "0.52.6" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_i686_msvc" | ||||
| version = "0.48.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_i686_msvc" | ||||
| version = "0.52.6" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_x86_64_gnu" | ||||
| version = "0.48.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_x86_64_gnu" | ||||
| version = "0.52.6" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_x86_64_gnullvm" | ||||
| version = "0.48.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_x86_64_gnullvm" | ||||
| version = "0.52.6" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_x86_64_msvc" | ||||
| version = "0.48.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" | ||||
|  | ||||
| [[package]] | ||||
| name = "windows_x86_64_msvc" | ||||
| version = "0.52.6" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" | ||||
|  | ||||
| [[package]] | ||||
| name = "wyz" | ||||
| version = "0.5.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" | ||||
| dependencies = [ | ||||
|  "tap", | ||||
| ] | ||||
							
								
								
									
										28
									
								
								diskann/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								diskann/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,28 @@ | ||||
| [package] | ||||
| name = "diskann" | ||||
| version = "0.1.0" | ||||
| edition = "2021" | ||||
|  | ||||
| [dependencies] | ||||
| half = { version = "2", features = ["bytemuck"] } | ||||
| fastrand = "2" | ||||
| tracing = "0.1" | ||||
| tracing-subscriber = "0.3" | ||||
| simsimd = "6" | ||||
| foldhash = "0.1" | ||||
| bitvec = "1" | ||||
| tqdm = "0.7" | ||||
| anyhow = "1" | ||||
| bytemuck = { version = "1", features = ["extern_crate_alloc"] } | ||||
| serde = { version = "1", features = ["derive"] } | ||||
| rmp-serde = "1" | ||||
| rayon = "1" | ||||
| matrixmultiply = "0.3" | ||||
|  | ||||
| [lib] | ||||
| name = "diskann" | ||||
| path = "src/lib.rs" | ||||
|  | ||||
| [[bin]] | ||||
| name = "diskann" | ||||
| path = "src/main.rs" | ||||
							
								
								
									
										113
									
								
								diskann/aopq_train.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								diskann/aopq_train.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,113 @@ | ||||
| import numpy as np | ||||
| import msgpack | ||||
| import math | ||||
| import torch | ||||
| from torch import autograd | ||||
| import faiss | ||||
| import tqdm | ||||
|  | ||||
# Quantization configuration: each n_dims-wide embedding is split into
# output_code_size subspaces, each encoded by one output_code_bits-bit code
# (so one codebook entry out of output_codebook_size per subspace).
n_dims = 1152  # embedding dimensionality of the vectors in embeddings.bin
output_code_size = 64  # number of PQ subspaces / bytes per compressed vector
output_code_bits = 8
output_codebook_size = 2**output_code_bits  # 256 centroids per subspace
n_dims_per_code = n_dims // output_code_size  # width of each subspace (18 dims)
# Raw files are float16; capped at 100k vectors and widened to float32 for
# training. NOTE(review): assumes both files hold flat fp16 rows of n_dims —
# confirm against whatever wrote embeddings.bin / query.bin.
dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
device = "cpu"

# Approximate index over the QUERY set (inner-product metric), so that for
# each dataset vector we can look up the queries most likely to retrieve it.
index = faiss.index_factory(n_dims, "HNSW32,SQfp16", faiss.METRIC_INNER_PRODUCT)
index.train(queryset)
index.add(queryset)
print("index ready")

T = 64  # how many nearby queries to associate with each dataset vector

nearby_query_indices = torch.zeros((dataset.shape[0], T), dtype=torch.int32)

SEARCH_BATCH_SIZE = 1024

# Batched search: res[1] holds the T nearest query indices per dataset row.
for i in range(0, len(dataset), SEARCH_BATCH_SIZE):
    res = index.search(dataset[i:i+SEARCH_BATCH_SIZE], T)
    nearby_query_indices[i:i+SEARCH_BATCH_SIZE] = torch.tensor(res[1])

print("query indices ready")
|  | ||||
def pq_assign(centroids, batch, dims=None, dims_per_code=None):
    """Product-quantize `batch` against `centroids`, one subspace at a time.

    Args:
        centroids: (codebook_size, dims) tensor of centroid vectors.
        batch: (batch_size, dims) tensor of vectors to quantize.
        dims: total dimensionality; defaults to the module-level ``n_dims``.
        dims_per_code: width of each PQ subspace; defaults to the
            module-level ``n_dims_per_code``.

    Returns:
        A tensor shaped like `batch` in which every subspace slice has been
        replaced by the corresponding slice of its best-matching centroid.
    """
    # Backward-compatible generalization: fall back to the script's globals
    # when the caller does not pass explicit sizes.
    if dims is None:
        dims = n_dims
    if dims_per_code is None:
        dims_per_code = n_dims_per_code

    quantized = torch.zeros_like(batch)

    # Assign each subspace independently to its nearest centroid. "Nearest"
    # is by maximum inner product, matching the search metric used elsewhere
    # in this script (not L2 distance).
    for dmin in range(0, dims, dims_per_code):
        dmax = dmin + dims_per_code
        similarities = torch.matmul(batch[:, dmin:dmax], centroids[:, dmin:dmax].T)
        assignments = similarities.argmax(dim=1)
        quantized[:, dmin:dmax] = centroids[assignments, dmin:dmax]

    return quantized
|  | ||||
# OOD-DiskANN (https://arxiv.org/abs/2211.12850) uses a more complicated scheme because it uses L2 norm
# We only care about inner product so our quantization error (wrt a query) is just abs(dot(query, centroid - vector))
# Directly optimize for this (wrt top queries; it might actually be better to use a random sample instead?)
def partition(vectors, centroids, projection, opt, queries, nearby_query_indices, k, max_iter=100, batch_size=4096):
    """Optimize `centroids` in place (via `opt`) to minimize query-weighted
    quantization error of the projected `vectors`.

    Loss per vector = mean over its nearby queries of
    abs(dot(query, quantized - vector)); one optimizer step per epoch, with
    gradients accumulated across all batches.

    Args:
        vectors: (n, n_dims) dataset tensor.
        centroids: (codebook_size, n_dims) tensor with requires_grad=True.
        projection: (n_dims, n_dims) orthogonal rotation applied to vectors.
        opt: optimizer over `centroids`.
        queries: (m, n_dims) query tensor.
        nearby_query_indices: (n, T) indices into `queries` per dataset vector.
        k: unused; retained for call-site compatibility.
        max_iter: number of epochs.
        batch_size: dataset rows per gradient accumulation step.
    """
    n_vectors = len(vectors)

    t = tqdm.trange(max_iter)
    for step in t:  # renamed from `iter`, which shadowed the builtin
        total_loss = 0
        opt.zero_grad(set_to_none=True)

        for i in range(0, n_vectors, batch_size):
            loss = torch.tensor(0.0, device=device)
            batch = vectors[i:i+batch_size] @ projection
            quantized = pq_assign(centroids, batch)
            residuals = batch - quantized

            # for each index in our set of nearby queries
            for j in range(0, nearby_query_indices.shape[1]):
                queries_for_batch_j = queries[nearby_query_indices[i:i+batch_size, j]]
                # minimize quantization error in direction of query, i.e. mean abs(dot(query, centroid - vector))
                # PyTorch won't do batched dot products cleanly, to spite me. Do componentwise multiplication and reduce.
                sg_errs = (queries_for_batch_j * residuals).sum(dim=-1)
                loss += torch.mean(torch.abs(sg_errs))

            total_loss += loss.detach().item()
            loss.backward()

        opt.step()

        t.set_description(f"loss: {total_loss:.4f}")
|  | ||||
def random_ortho(dim):
    """Sample a random orthogonal matrix as the Q factor of a Gaussian matrix."""
    gaussian = torch.randn(dim, dim, device=device)
    q, _ = torch.linalg.qr(gaussian)
    return q
|  | ||||
# non-parametric OPQ algorithm (roughly)
# https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/opq_tr.pdf
projection = random_ortho(n_dims)
vectors = torch.tensor(dataset, device=device)
queries = torch.tensor(queryset, device=device)
# initialize the codebook from a random sample of dataset vectors
perm = torch.randperm(len(vectors), device=device)
centroids = vectors[perm[:output_codebook_size]]
centroids.requires_grad = True
opt = torch.optim.Adam([centroids], lr=0.001)
# Alternate: optimize centroids with the projection fixed, then re-fit the
# projection with the centroids fixed.
for i in range(30):
    # update centroids to minimize query-aware quantization loss
    partition(vectors, centroids, projection, opt, queries, nearby_query_indices, output_codebook_size, max_iter=8)
    # compute new projection as R = VU^T from XY^T = USV^T (SVD)
    # where X is dataset vectors, Y is quantized dataset vectors
    with torch.no_grad():
        # NOTE(review): `partition` quantizes vectors *after* projection, but here
        # the unprojected `vectors` are quantized — confirm this is intended.
        y = pq_assign(centroids, vectors)
        # paper uses D*N and not N*D in its descriptions for whatever reason (so we transpose when they don't)
        u, s, vt = torch.linalg.svd(vectors.T @ y)
        projection = vt.T @ u.T

print("done")

# Serialize everything the index reader needs (flattened row-major).
with open("opq.msgpack", "wb") as f:
    msgpack.pack({
        "centroids": centroids.detach().cpu().numpy().flatten().tolist(),
        "transform": projection.cpu().numpy().flatten().tolist(),
        "n_dims_per_code": n_dims_per_code,
        "n_dims": n_dims
    }, f)
							
								
								
									
										491
									
								
								diskann/flamegraph.svg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										491
									
								
								diskann/flamegraph.svg
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| After Width: | Height: | Size: 139 KiB | 
							
								
								
									
										44
									
								
								diskann/opq_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								diskann/opq_test.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | ||||
import numpy as np
import msgpack
import math
import torch
import faiss
import tqdm

# These constants must match the script that produced opq.msgpack.
n_dims = 1152
output_code_size = 64
output_code_bits = 8
output_codebook_size = 2**output_code_bits
n_dims_per_code = n_dims // output_code_size
# fp16 on disk; cast up to fp32
dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
device = "cpu"
|  | ||||
def pq_assign(centroids, batch, dims=None, dims_per_code=None):
    """Product-quantize `batch` against `centroids` (inner-product metric).

    Each contiguous subspace of `dims_per_code` components of every row is
    replaced by the corresponding components of the centroid with the highest
    inner product over that subspace.

    Args:
        centroids: (codebook_size, dims) tensor of centroids.
        batch: (n, dims) tensor of vectors to quantize.
        dims: total dimensionality; defaults to the module-level `n_dims`.
        dims_per_code: subspace width; defaults to module-level `n_dims_per_code`.

    Returns:
        (n, dims) tensor of quantized vectors.
    """
    if dims is None:
        dims = n_dims
    if dims_per_code is None:
        dims_per_code = n_dims_per_code

    quantized = torch.zeros_like(batch)

    # Assign to nearest (highest-dot-product) centroid in each subspace
    for dmin in range(0, dims, dims_per_code):
        dmax = dmin + dims_per_code
        similarities = torch.matmul(batch[:, dmin:dmax], centroids[:, dmin:dmax].T)
        assignments = similarities.argmax(dim=1)
        quantized[:, dmin:dmax] = centroids[assignments, dmin:dmax]

    return quantized
|  | ||||
# Load the trained codebook and rotation produced by the OPQ training script.
with open("opq.msgpack", "rb") as f:
    data = msgpack.unpack(f)
    centroids = torch.tensor(data["centroids"], device=device).reshape(2**output_code_bits, n_dims)
    projection = torch.tensor(data["transform"], device=device).reshape(n_dims, n_dims)

vectors = torch.tensor(dataset, device=device)
queries = torch.tensor(queryset, device=device)

sample_size = 64
# Stored vectors are projected (v @ P) before product quantization.
qsample = pq_assign(centroids, vectors[:sample_size] @ projection)
print(qsample)
print(vectors[:sample_size])
exact_results = vectors[:sample_size] @ queries[0]
# P is orthogonal, so dot(v, q) == dot(v @ P, P.T @ q): the query must be
# rotated with the *transpose*. (Was `projection @ queries[0]`, which is only
# equivalent if P is symmetric.)
approx_results = qsample @ (projection.T @ queries[0])
# Rank agreement between approximate and exact scores is what matters for retrieval.
print(np.argsort(approx_results))
print(np.argsort(exact_results))
							
								
								
									
										68
									
								
								diskann/rabitq.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								diskann/rabitq.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,68 @@ | ||||
# https://arxiv.org/pdf/2405.12497

import numpy as np
import msgpack
import math
import tqdm

n_dims = 1152
output_dims = 64*8  # one bit per output dimension = 64 bytes per vector
scale = 1 / math.sqrt(n_dims)  # magnitude of each dequantized component
dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
mean = np.mean(dataset, axis=0)

# RaBitQ quantizes unit vectors: centre the dataset and normalize each row,
# keeping the norms so dot products can be rescaled at query time.
centered_dataset = dataset - mean
norms = np.linalg.norm(centered_dataset, axis=1)
centered_dataset = centered_dataset / norms[:, np.newaxis]
print(centered_dataset)

# small sample used for the sanity checks below
sample = centered_dataset[:64]
|  | ||||
def random_ortho(dim):
    """Sample a random orthogonal matrix as the Q factor of a Gaussian matrix."""
    q, _ = np.linalg.qr(np.random.randn(dim, dim))
    return q
|  | ||||
# Random rotation; keeping only the first `output_dims` rows also reduces
# dimensionality as part of quantization.
p = random_ortho(n_dims) # algorithm only uses the inverse of P, so just sample that directly
p = p[:output_dims, :]
|  | ||||
def quantize(datavecs, transform=None, quant_scale=None):
    """RaBitQ-style 1-bit quantization of rotated vectors.

    Args:
        datavecs: (n, d) array of vectors (unit-norm, centred in this script).
        transform: (output_dims, d) rotation; defaults to module-level `p`.
        quant_scale: dequantized component magnitude; defaults to module-level `scale`.

    Returns:
        (quantized, dots): boolean sign bits of the rotated vectors, and
        <o_bar, o> — the dot of each dequantized vector with its rotated
        original, used later to correct the dot-product estimate.
    """
    if transform is None:
        transform = p
    if quant_scale is None:
        quant_scale = scale
    xs = (transform @ datavecs.T).T
    quantized = xs > 0
    dequantized = quant_scale * (2 * quantized - 1)
    dots = np.sum(dequantized * xs, axis=1) # <o_bar, o>
    return quantized, dots
|  | ||||
qsample, dots = quantize(sample)
# mean popcount per code; expect roughly output_dims / 2 for centred data
print(qsample.sum(axis=1).mean())
#print(dots)
#print(dots.mean())
|  | ||||
def approx_dot(quantized_samples, dots, query):
    """Estimate dot(original_vector, query) from the 1-bit codes.

    Reads module-level `mean`, `norms`, `p`, `scale`, and assumes
    `quantized_samples` corresponds to the first rows of the dataset
    (hence the `norms[:sample.shape[0]]` slice).
    """
    # dot(v, q) = dot(v - mean, q) + dot(mean, q); the second term is exact
    mean_to_query = np.dot(mean, query)
    print(mean_to_query)  # leftover debug output
    dequantized = scale * (2 * quantized_samples - 1)
    query_transformed = p @ query
    # componentwise multiply + reduce = batched dot of each code with the rotated query
    o_bar_dot_q = np.sum(dequantized * query_transformed, axis=1)
    # rescale by the original norm; `dots` (<o_bar, o>) corrects for quantization
    # magnitude — NOTE(review): the RaBitQ paper's estimator divides by
    # <o_bar, o> rather than multiplying; confirm against the paper.
    return norms[:sample.shape[0]] * o_bar_dot_q * dots + mean_to_query
|  | ||||
print(norms)
approx_results = approx_dot(qsample, dots, queryset[0])
# NOTE(review): `sample` is centred and normalized, while approx_dot estimates
# the dot with the *original* vector — these are different quantities; confirm
# which comparison is intended.
exact_results = sample @ queryset[0]

# side-by-side estimated vs exact dot products
for x in zip(approx_results, exact_results):
    print(*x)

# per-element error normalized by the mean magnitude of the exact results
print(*[ f"{x:.2f}" for x in (approx_results - exact_results) / np.abs(exact_results).mean() ])

# rank agreement matters more than absolute error for retrieval
print(np.argsort(approx_results))
print(np.argsort(exact_results))

# persist everything a reader needs to quantize/score vectors the same way
with open("rabitq.msgpack", "wb") as f:
    msgpack.pack({
        "mean": mean.flatten().tolist(),
        "transform": p.flatten().tolist(),
        "output_dims": output_dims,
        "n_dims": n_dims
    }, f)
							
								
								
									
										167
									
								
								diskann/scalar_quantize.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										167
									
								
								diskann/scalar_quantize.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,167 @@ | ||||
import numpy as np
import msgpack
import math

n_dims = 1152
n_buckets = n_dims  # currently one quantization scale per component
#n_buckets = n_dims // 2 # we now have one quant scale per pair of components
#pair_separation = 16 # for efficient dot product computation, we need to have the second element of a pair exactly chunk_size after the first
n_dims_per_bucket = n_dims // n_buckets
data = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32) # sorry

# fraction clipped at each tail when computing the quantization range
CUTOFF = 1e-3 / 2

print("computing quantiles")
# per-dimension robust min/max (outliers beyond CUTOFF are clamped)
smin = np.quantile(data, CUTOFF, axis=0)
smax = np.quantile(data, 1 - CUTOFF, axis=0)
|  | ||||
# naive O(n²) greedy algorithm
# probably overbuilt for the 2-components-per-bucket case but I'm not getting rid of it
def assign_buckets():
    """Greedily group per-dimension (smin, smax) intervals into `n_buckets`
    buckets so that dimensions sharing a bucket have similar ranges.

    Reads module-level `smin`, `smax`, `n_buckets`. Each bucket is seeded with
    one random interval, then rounds of greedy assignment add the
    cheapest-fitting remaining interval to each bucket in turn.
    """
    import random
    intervals = list(enumerate(zip(smin, smax)))
    random.shuffle(intervals)
    # seed each bucket with one interval (requires len(intervals) >= n_buckets)
    buckets = [ [ intervals.pop() ] for _ in range(n_buckets) ]
    def bucket_cost(bucket):
        # total slack: how far each interval's bounds sit from the bucket's envelope
        bmin = min(cmin for id, (cmin, cmax) in bucket)
        bmax = max(cmax for id, (cmin, cmax) in bucket)
        return sum(abs(cmin - bmin) + abs(cmax - bmax) for id, (cmin, cmax) in bucket)
    while intervals:
        for bucket in buckets:
            if not intervals:
                # intervals can run out partway through a round; previously this
                # fell through to min() over an empty sequence (ValueError)
                break
            def new_interval_cost(interval):
                return bucket_cost(bucket + [interval[1]])
            i, interval = min(enumerate(intervals), key=new_interval_cost)
            bucket.append(intervals.pop(i))
    return buckets
|  | ||||
ranges = smax - smin
# TODO: it is possible to do better assignment to buckets
#order = np.argsort(ranges)
print("bucket assignment")
# identity permutation; the assign_buckets()-based ordering is disabled
order = np.arange(n_dims) # np.concatenate(np.stack([ [ id for id, (cmin, cmax) in bucket ] for bucket in assign_buckets() ]))

# per-bucket envelope statistics over the dimensions assigned to each bucket
bucket_ranges = []
bucket_centres = []
bucket_absmax = []
bucket_gmins = []

for bucket_min in range(0, n_dims, n_dims_per_bucket):
    bucket_max = bucket_min + n_dims_per_bucket
    indices = order[bucket_min:bucket_max]
    gmin = float(np.min(smin[indices]))
    gmax = float(np.max(smax[indices]))
    bucket_range = gmax - gmin
    bucket_centre = (gmax + gmin) / 2
    bucket_gmins.append(gmin)
    bucket_ranges.append(bucket_range)
    bucket_centres.append(bucket_centre)
    bucket_absmax.append(max(abs(gmin), abs(gmax)))

print("determining scales")
scales = [] # multiply by float and convert to quantize
offsets = []
q_offsets = [] # int16 value to add at dot time
q_scales = [] # rescales channel up at dot time; must be proportional(ish) to square of scale factor but NOT cause overflow in accumulation or PLMULLW
scale_factor_bound = float("inf")
for bucket in range(n_buckets):
    # 8-bit codes: map [gmin, gmax] onto [0, 255]
    step_size = bucket_ranges[bucket] / 255
    scales.append(1 / step_size)
    q_offset = int(bucket_gmins[bucket] / step_size)
    q_offsets.append(q_offset)
    # largest scale factor that keeps worst-case (code + offset)^2 sums within
    # an i32 accumulator, halved for margin
    nsfb = (2**31 - 1) / (n_dims_per_bucket * abs((255**2) + 2 * q_offset * 255 + q_offset ** 2)) / 2
    # we are bounded both by overflow in accumulation and PLMULLW (u8 plus offset times scale factor)
    scale_factor_bound = min(scale_factor_bound, nsfb, (2**15 - 1) // (q_offset + 255))
    offsets.append(bucket_gmins[bucket])

# q_scale compensates for differing bucket ranges: a wider range means coarser
# steps, so each quantized unit contributes proportionally to range^2 in a dot
for bucket in range(n_buckets):
    sfb = scale_factor_bound / max(map(lambda x: x ** 2, bucket_ranges))  # loop-invariant
    sf = (bucket_ranges[bucket]) ** 2 * sfb
    q_scales.append(int(sf))

print(bucket_ranges, bucket_centres, bucket_absmax)
print(scales, offsets, q_offsets, q_scales)
|  | ||||
# disabled interleaving schemes for multi-dimension buckets, kept for reference
"""
interleave = np.concatenate([
    np.arange(0, n_dims, n_dims_per_bucket) + a
    for a in range(n_dims_per_bucket)
])
"""

"""
interleave = np.arange(0, n_dims)
for base in range(0, n_dims, 2 * pair_separation):
    interleave[base:base + pair_separation] = np.arange(base, base + 2 * pair_separation, 2)
    interleave[base + pair_separation:base + 2 * pair_separation] = np.arange(base + 1, base + 2 * pair_separation + 1, 2)
"""

#print(bucket_ranges, bucket_centres, order[interleave])
#print(ranges[order][interleave].tolist())
#print(ranges.tolist())

# everything the index reader needs to quantize vectors the same way
with open("quantizer.msgpack", "wb") as f:
    msgpack.pack({
        "permutation": order.tolist(),
        "offsets": offsets,
        "scales": scales,
        "q_offsets": q_offsets,
        "q_scales": q_scales
    }, f)
|  | ||||
def rquantize(vec):
    """Quantize a float vector to uint8 codes using the per-bucket affine scales.

    NOTE(review): this previously iterated `order[interleave]`, but `interleave`
    is only defined inside commented-out code above, so calling it raised
    NameError. With interleaving disabled the storage permutation is just `order`.
    """
    out = np.zeros(len(vec), dtype=np.uint8)
    for i, p in enumerate(order):
        bucket = p % n_buckets
        raw = vec[i]
        # shift/scale into [0, 255] and clamp to the quantile range
        raw = (raw - offsets[bucket]) * scales[bucket]
        raw = min(max(raw, 0.0), 255.0)
        out[p] = round(raw)
    return out
|  | ||||
def rdquantize(bytes):
    """Inverse of rquantize: reconstruct an approximate float vector from codes.

    NOTE(review): this previously iterated `order[interleave]`, but `interleave`
    is only defined inside commented-out code above (NameError when called);
    with interleaving disabled the storage permutation is just `order`.
    """
    vec = np.zeros(n_dims, dtype=np.float32)
    for i, p in enumerate(order):
        bucket = p % n_buckets
        raw = float(bytes[p])
        # undo the affine code mapping
        vec[i] = raw / scales[bucket] + offsets[bucket]
    return vec
|  | ||||
def rdot(x, y):
    """Integer dot product of two quantized vectors, mirroring the SIMD kernel.

    Per chunk of `n_buckets` codes: undo the quantization offset in i16, apply
    the per-channel q_scale to one side only, then accumulate products in i32.
    The `x1 *= xq_scales` product wraps at i16 — the q_scale bound computed
    above (`(2**15 - 1) // (q_offset + 255)`) is what keeps it in range.
    """
    xq_offsets = np.array(q_offsets, dtype=np.int16)
    xq_scales = np.array(q_scales, dtype=np.int16)
    assert x.shape == y.shape
    assert x.dtype == np.uint8 == y.dtype
    acc = 0
    for i in range(0, len(x), n_buckets):
        x1 = x[i:i+n_buckets].astype(np.int16) + xq_offsets
        y1 = y[i:i+n_buckets].astype(np.int16) + xq_offsets
        x1 *= xq_scales
        acc += np.dot(x1.astype(np.int32), y1.astype(np.int32))
    return acc
|  | ||||
def cmp(i, j):
    """Ratio of the exact dot product of dataset rows i, j to its quantized estimate."""
    exact = np.dot(data[i], data[j])
    approx = rdot(rquantize(data[i]), rquantize(data[j]))
    return exact / approx
|  | ||||
def rdot_cmp(a, b):
    """Debug variant of `rdot`: quantize `a` and `b`, then print each chunk's
    integer dot product alongside the exact float dot over the same (permuted)
    components, returning the total integer accumulator.

    NOTE(review): this previously used `order[interleave]`, but `interleave` is
    only defined inside commented-out code above (NameError when called); with
    interleaving disabled the storage permutation is just `order`.
    """
    x = rquantize(a)
    y = rquantize(b)
    # permute the float vectors into quantized storage order for comparison
    a = a[order]
    b = b[order]
    xq_offsets = np.array(q_offsets, dtype=np.int16)
    xq_scales = np.array(q_scales, dtype=np.int16)
    assert x.shape == y.shape
    assert x.dtype == np.uint8 == y.dtype
    acc = 0
    for i in range(0, len(x), n_buckets):
        x1 = x[i:i+n_buckets].astype(np.int16) + xq_offsets
        y1 = y[i:i+n_buckets].astype(np.int16) + xq_offsets
        x1 *= xq_scales
        component = np.dot(x1.astype(np.int32), y1.astype(np.int32))
        a1 = a[i:i+n_buckets]
        b1 = b[i:i+n_buckets]
        component_exact = np.dot(a1, b1)
        print(x1, a1, sep="\n")
        print(component, component_exact, component / component_exact)
        acc += component
    return acc
							
								
								
									
										389
									
								
								diskann/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										389
									
								
								diskann/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,389 @@ | ||||
| #![feature(pointer_is_aligned_to)] | ||||
| #![feature(test)] | ||||
|  | ||||
| extern crate test; | ||||
|  | ||||
| use foldhash::{HashSet, HashMap, HashMapExt, HashSetExt}; | ||||
| use fastrand::Rng; | ||||
| use rayon::prelude::*; | ||||
| use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, Mutex}; | ||||
|  | ||||
| pub mod vector; | ||||
| use vector::{dot, fast_dot, fast_dot_noprefetch, to_svector, VectorRef, SVector, VectorList}; | ||||
|  | ||||
| // ParlayANN improves parallelism by not using locks like this and instead using smarter batch operations | ||||
| // but I don't have enough cores that it matters | ||||
#[derive(Debug)]
pub struct IndexGraph {
    // adjacency lists: graph[i] = out-neighbours of vertex i, individually locked
    pub graph: Vec<RwLock<Vec<u32>>>
}

impl IndexGraph {
    /// Graph with `r` random out-edges per vertex (duplicates possible);
    /// each adjacency list reserves `capacity` slots up front.
    pub fn random_r_regular(rng: &mut Rng, n: usize, r: usize, capacity: usize) -> Self {
        let mut graph = Vec::with_capacity(n);
        for _ in 0..n {
            let mut adjacency = Vec::with_capacity(capacity);
            for _ in 0..r {
                adjacency.push(rng.u32(0..(n as u32)));
            }
            graph.push(RwLock::new(adjacency));
        }
        IndexGraph {
            graph
        }
    }

    /// Graph with `n` vertices and no edges; lists reserve `capacity` slots.
    pub fn empty(n: usize, capacity: usize) -> IndexGraph {
        let mut graph = Vec::with_capacity(n);
        for _ in 0..n {
            graph.push(RwLock::new(Vec::with_capacity(capacity)));
        }
        IndexGraph {
            graph
        }
    }

    // Read-locked view of a vertex's out-neighbours.
    fn out_neighbours(&self, pt: u32) -> RwLockReadGuard<Vec<u32>> {
        self.graph[pt as usize].read().unwrap()
    }

    // Write-locked view of a vertex's out-neighbours.
    fn out_neighbours_mut(&self, pt: u32) -> RwLockWriteGuard<Vec<u32>> {
        self.graph[pt as usize].write().unwrap()
    }
}
|  | ||||
#[derive(Clone, Copy, Debug)]
pub struct IndexBuildConfig {
    pub r: usize,     // target out-degree after pruning
    pub r_cap: usize, // degree at which an overfull list triggers a re-prune
    pub l: usize,     // search list (NeighbourBuffer) size
    pub maxc: usize,  // max candidates considered by robust_prune
    pub alpha: i64    // pruning slack; 16.16 fixed point (see `>> 16` in robust_prune)
}
|  | ||||
|  | ||||
| fn centroid(vecs: &VectorList) -> SVector { | ||||
|     let mut centroid = SVector::zero(vecs.d_emb); | ||||
|  | ||||
|     for (i, vec) in vecs.iter().enumerate() { | ||||
|         let weight = 1.0 / (i + 1) as f32; | ||||
|         centroid += (to_svector(vec) - ¢roid) * weight; | ||||
|     } | ||||
|  | ||||
|     centroid | ||||
| } | ||||
|  | ||||
/// Index of the vector with the highest dot product with the dataset centroid;
/// used as the fixed entry point for greedy search.
pub fn medioid(vecs: &VectorList) -> u32 {
    let centroid = centroid(vecs).half();
    vecs.iter().map(|vec| dot(vec, &*centroid)).enumerate().max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()).unwrap().0 as u32
}
|  | ||||
| // neighbours list sorted by score descending | ||||
| // TODO: this may actually be an awful datastructure | ||||
#[derive(Clone, Debug)]
pub struct NeighbourBuffer {
    pub ids: Vec<u32>,           // candidate ids, sorted by score descending
    scores: Vec<i64>,            // scores parallel to `ids`
    visited: Vec<bool>,          // whether each candidate has been expanded
    next_unvisited: Option<u32>, // index of the next candidate to expand, if any
    size: usize                  // maximum number of candidates retained
}
|  | ||||
impl NeighbourBuffer {
    /// Empty buffer holding at most `size` candidates.
    pub fn new(size: usize) -> Self {
        NeighbourBuffer {
            ids: Vec::with_capacity(size + 1),
            scores: Vec::with_capacity(size + 1),
            visited: Vec::with_capacity(size + 1), //bitvec::vec::BitVec::with_capacity(size),
            next_unvisited: None,
            size
        }
    }

    /// Return the id at the cursor, marking it visited and advancing the
    /// cursor to the next unvisited slot (or None if all are visited).
    pub fn next_unvisited(&mut self) -> Option<u32> {
        //println!("next_unvisited: {:?}", self);
        let mut cur = self.next_unvisited? as usize;
        let old_cur = cur;
        self.visited[cur] = true;
        // scan forward for the next unexpanded candidate
        while cur < self.len() && self.visited[cur] {
            cur += 1;
        }
        if cur == self.len() {
            self.next_unvisited = None;
        } else {
            self.next_unvisited = Some(cur as u32);
        }
        Some(self.ids[old_cur])
    }

    pub fn len(&self) -> usize {
        self.ids.len()
    }

    pub fn cap(&self) -> usize {
        self.size
    }

    /// Insert `id` keeping the lists sorted by score descending; drops the
    /// worst entry when full and skips an insert landing on a slot that
    /// already holds the same id.
    pub fn insert(&mut self, id: u32, score: i64) {
        // full and strictly worse than the current worst: discard
        if self.len() == self.cap() && self.scores[self.len() - 1] > score {
            return;
        }

        // comparator arguments are flipped because `scores` is sorted descending
        let loc = match self.scores.binary_search_by(|x| score.partial_cmp(&x).unwrap()) {
            Ok(loc) => loc,
            Err(loc) => loc
        };

        if self.ids.get(loc) == Some(&id) {
            return;
        }

        // slightly inefficient but we avoid unsafe code
        self.ids.insert(loc, id);
        self.scores.insert(loc, score);
        self.visited.insert(loc, false);
        self.ids.truncate(self.size);
        self.scores.truncate(self.size);
        self.visited.truncate(self.size);

        // NOTE(review): unconditionally moving the cursor to `loc` can skip an
        // earlier still-unvisited entry when the cursor previously pointed
        // before `loc`; also, an equal id at a *different* slot is not
        // deduplicated. Confirm both are acceptable approximations.
        self.next_unvisited = Some(loc as u32);
    }

    pub fn clear(&mut self) {
        self.ids.clear();
        self.scores.clear();
        self.visited.clear();
        self.next_unvisited = None;
    }
}
|  | ||||
/// Per-worker reusable working memory for search and pruning, pre-sized from
/// the build config so the hot loops avoid reallocation.
pub struct Scratch {
    visited: HashSet<u32>,                         // ids already scored this search
    pub neighbour_buffer: NeighbourBuffer,         // best-L candidate queue
    neighbour_pre_buffer: Vec<u32>,                // unvisited neighbours of the current node
    visited_list: Vec<(u32, i64)>,                 // every (id, score) seen; input to robust_prune
    robust_prune_scratch_buffer: Vec<(usize, u32)> // (candidate-list index, id) pairs
}

impl Scratch {
    pub fn new(IndexBuildConfig { l, r, maxc, .. }: IndexBuildConfig) -> Self {
        Scratch {
            visited: HashSet::with_capacity(l * 8),
            neighbour_buffer: NeighbourBuffer::new(l),
            neighbour_pre_buffer: Vec::with_capacity(r),
            visited_list: Vec::with_capacity(l * 8),
            robust_prune_scratch_buffer: Vec::with_capacity(r)
        }
    }
}
|  | ||||
/// Instrumentation from a single greedy_search call.
pub struct GreedySearchCounters {
    pub distances: usize // number of dot products computed
}
|  | ||||
| // Algorithm 1 from the DiskANN paper | ||||
| // We support the dot product metric only, so we want to keep things with the HIGHEST dot product | ||||
/// Beam-search `graph` from `start` for points maximizing dot product with
/// `query`. Results are left in `scratch`: the beam in `neighbour_buffer` and
/// every scored point (with its score) in `visited_list`.
pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs: &VectorList, graph: &IndexGraph, config: IndexBuildConfig) -> GreedySearchCounters {
    scratch.visited.clear();
    scratch.neighbour_buffer.clear();
    scratch.visited_list.clear();

    scratch.neighbour_buffer.insert(start, fast_dot_noprefetch(query, &vecs[start as usize]));
    scratch.visited.insert(start);

    let mut counters = GreedySearchCounters { distances: 0 };

    // Expand the best unvisited candidate until the beam has no unvisited entries.
    while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
        //println!("pt {} {:?}", pt, graph.out_neighbours(pt));
        scratch.neighbour_pre_buffer.clear();
        // collect this node's not-yet-scored neighbours
        for &neighbour in graph.out_neighbours(pt).iter() {
            if scratch.visited.insert(neighbour) {
                scratch.neighbour_pre_buffer.push(neighbour);
            }
        }
        for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
            // the following neighbour's vector is passed so fast_dot can prefetch it
            let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
            let distance = fast_dot(query, &vecs[neighbour as usize], &vecs[next_neighbour as usize]);
            counters.distances += 1;
            scratch.neighbour_buffer.insert(neighbour, distance);
            scratch.visited_list.push((neighbour, distance));
        }
    }

    counters
}
|  | ||||
// (vertex id, score) pairs fed into robust_prune
type CandidateList = Vec<(u32, i64)>;

/// Score `point` against its current out-neighbours and append the
/// (id, score) pairs to `candidates`. `config` is currently unused.
fn merge_existing_neighbours(candidates: &mut CandidateList, point: u32, neigh: &[u32], vecs: &VectorList, config: IndexBuildConfig) {
    let p_vec = &vecs[point as usize];
    for (i, &n) in neigh.iter().enumerate() {
        // third argument is the next vector, used as a prefetch hint by fast_dot
        let dot = fast_dot(p_vec, &vecs[n as usize], &vecs[neigh[(i + 1) % neigh.len() as usize] as usize]);
        candidates.push((n, dot));
    }
}
|  | ||||
| // "Robust prune" algorithm, kind of | ||||
| // The algorithm in the paper does not actually match the code as implemented in microsoft/DiskANN | ||||
| // and that's slightly different from the one in ParlayANN for no reason | ||||
| // This is closer to ParlayANN | ||||
| fn robust_prune(scratch: &mut Scratch, p: u32, neigh: &mut Vec<u32>, vecs: &VectorList, config: IndexBuildConfig) { | ||||
|     neigh.clear(); | ||||
|  | ||||
|     let candidates = &mut scratch.visited_list; | ||||
|  | ||||
|     // distance low to high = score high to low | ||||
|     candidates.sort_unstable_by_key(|&(_id, score)| -score); | ||||
|     candidates.truncate(config.maxc); | ||||
|  | ||||
|     let mut candidate_index = 0; | ||||
|     while neigh.len() < config.r && candidate_index < candidates.len() { | ||||
|         let p_star = candidates[candidate_index].0; | ||||
|         candidate_index += 1; | ||||
|         if p_star == p || p_star == u32::MAX { | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         neigh.push(p_star); | ||||
|  | ||||
|         scratch.robust_prune_scratch_buffer.clear(); | ||||
|  | ||||
|         // mark remaining candidates as not-to-be-used if "not much better than" current candidate | ||||
|         for i in (candidate_index+1)..candidates.len() { | ||||
|             let p_prime = candidates[i].0; | ||||
|             if p_prime != u32::MAX { | ||||
|                 scratch.robust_prune_scratch_buffer.push((i, p_prime)); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         for (i, &(ci, p_prime)) in scratch.robust_prune_scratch_buffer.iter().enumerate() { | ||||
|             let next_vec = &vecs[scratch.robust_prune_scratch_buffer[(i + 1) % scratch.robust_prune_scratch_buffer.len()].0 as usize]; | ||||
|             let p_star_prime_score = fast_dot(&vecs[p_prime as usize], &vecs[p_star as usize], next_vec); | ||||
|             let p_prime_p_score = candidates[ci].1; | ||||
|             let alpha_times_p_star_prime_score = (config.alpha * p_star_prime_score) >> 16; | ||||
|  | ||||
|             if alpha_times_p_star_prime_score >= p_prime_p_score { | ||||
|                 candidates[ci].0 = u32::MAX; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &VectorList, config: IndexBuildConfig) { | ||||
|     assert!(vecs.len() < u32::MAX as usize); | ||||
|  | ||||
|     let mut sigmas: Vec<u32> = (0..(vecs.len() as u32)).collect(); | ||||
|     rng.shuffle(&mut sigmas); | ||||
|  | ||||
|     let rng = Mutex::new(rng.fork()); | ||||
|  | ||||
|     //let scratch = &mut Scratch::new(config); | ||||
|     //let mut rng = rng.lock().unwrap(); | ||||
|     sigmas.into_par_iter().for_each_init(|| (Scratch::new(config), rng.lock().unwrap().fork()), |(scratch, rng), sigma_i| { | ||||
|     //sigmas.into_iter().for_each(|sigma_i| { | ||||
|         greedy_search(scratch, medioid, &vecs[sigma_i as usize], vecs, &graph, config); | ||||
|  | ||||
|         { | ||||
|             let n = graph.out_neighbours(sigma_i); | ||||
|             merge_existing_neighbours(&mut scratch.visited_list, sigma_i, &*n, vecs, config); | ||||
|         } | ||||
|  | ||||
|         { | ||||
|             let mut n = graph.out_neighbours_mut(sigma_i); | ||||
|             robust_prune(scratch, sigma_i, &mut *n, vecs, config); | ||||
|         } | ||||
|  | ||||
|         let neighbours = graph.out_neighbours(sigma_i).to_owned(); | ||||
|         for neighbour in neighbours { | ||||
|             let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour); | ||||
|             // To cut down pruning time slightly, allow accumulating more neighbours than usual limit | ||||
|             if neighbour_neighbours.len() == config.r_cap { | ||||
|                 let mut n = neighbour_neighbours.to_vec(); | ||||
|                 scratch.visited_list.clear(); | ||||
|                 merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config); | ||||
|                 merge_existing_neighbours(&mut scratch.visited_list, neighbour, &vec![sigma_i], vecs, config); | ||||
|                 robust_prune(scratch, neighbour, &mut n, vecs, config); | ||||
|             } else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r_cap { | ||||
|                 neighbour_neighbours.push(sigma_i); | ||||
|             } | ||||
|         } | ||||
|     }); | ||||
| } | ||||
|  | ||||
| // RoarGraph's AcquireNeighbours algorithm is actually almost identical to Vamana/DiskANN's RobustPrune, but with fixed α = 1.0. | ||||
| // We replace Vamana's random initialization of the graph with Neighbourhood-Aware Projection from RoarGraph - there's no way to use a large enough | ||||
| // query set that I would be confident in using *only* RoarGraph's algorithm | ||||
/// Initialize out-neighbours from the query/base bipartite kNN relation:
/// a base vector's candidates are the other base vectors that are near the
/// same queries (`query_knns` maps base -> queries, `query_knns_bwd` maps
/// query -> base), pruned with robust_prune.
pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec<Vec<u32>>, query_knns_bwd: &Vec<Vec<u32>>, config: IndexBuildConfig, vecs: &VectorList) {
    let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect();
    rng.shuffle(&mut sigmas);

    // Iterate through graph vertices in a random order
    let rng = Mutex::new(rng.fork());
    sigmas.into_par_iter().for_each_init(|| (rng.lock().unwrap().fork(), Scratch::new(config)), |(rng, scratch), sigma_i| {
        scratch.visited.clear();
        scratch.visited_list.clear();
        scratch.neighbour_pre_buffer.clear();
        // candidate set: base vectors sharing a nearby query with sigma_i (deduplicated)
        for &query_neighbour in query_knns[sigma_i as usize].iter() {
            for &projected_neighbour in query_knns_bwd[query_neighbour as usize].iter() {
                if scratch.visited.insert(projected_neighbour) {
                    scratch.neighbour_pre_buffer.push(projected_neighbour);
                }
            }
        }
        // random subsample to bound pruning cost
        rng.shuffle(&mut scratch.neighbour_pre_buffer);
        scratch.neighbour_pre_buffer.truncate(config.maxc * 2);
        for (i, &projected_neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
            // third fast_dot argument is the next candidate's vector (prefetch hint)
            let score = fast_dot(&vecs[sigma_i as usize], &vecs[projected_neighbour as usize], &vecs[scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()] as usize]);
            scratch.visited_list.push((projected_neighbour, score));
        }
        let mut neighbours = graph.out_neighbours_mut(sigma_i);
        robust_prune(scratch, sigma_i, &mut *neighbours, vecs, config);
    })
}
|  | ||||
| pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<Vec<u32>>, query_knns_bwd: Vec<Vec<u32>>, config: IndexBuildConfig) { | ||||
|     let mut sigmas: Vec<u32> = (0..(graph.graph.len() as u32)).collect(); | ||||
|     rng.shuffle(&mut sigmas); | ||||
|  | ||||
|     // Iterate through graph vertices in a random order | ||||
|     let rng = Mutex::new(rng.fork()); | ||||
|     sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| { | ||||
|         let mut neighbours = graph.out_neighbours_mut(sigma_i); | ||||
|         let mut i = 0; | ||||
|         while neighbours.len() < config.r_cap && i < 100 { | ||||
|             let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap(); | ||||
|             let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap(); | ||||
|             if !neighbours.contains(&projected_neighbour) { | ||||
|                 neighbours.push(projected_neighbour); | ||||
|             } | ||||
|             i += 1; | ||||
|         } | ||||
|     }) | ||||
| } | ||||
|  | ||||
| pub fn random_fill_graph(rng: &mut Rng, graph: &mut IndexGraph, r: usize) { | ||||
|     let rng = Mutex::new(rng.fork()); | ||||
|     (0..graph.graph.len() as u32).into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, i| { | ||||
|         let mut neighbours = graph.out_neighbours_mut(i); | ||||
|         while neighbours.len() < r { | ||||
|             let next = rng.u32(0..(graph.graph.len() as u32)); | ||||
|             if !neighbours.contains(&next) { | ||||
|                 neighbours.push(next); | ||||
|             } | ||||
|         } | ||||
|     }); | ||||
| } | ||||
|  | ||||
/// Scope-based wall-clock timer: records its creation instant and prints
/// "<name>: <seconds>s" when dropped.
pub struct Timer(&'static str, std::time::Instant);

impl Timer {
    /// Start timing now, labelled with `name`.
    pub fn new(name: &'static str) -> Self {
        Self(name, std::time::Instant::now())
    }
}

impl Drop for Timer {
    fn drop(&mut self) {
        // report elapsed wall time with two decimal places
        println!("{}: {:.2}s", self.0, self.1.elapsed().as_secs_f32());
    }
}
							
								
								
									
										121
									
								
								diskann/src/main.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								diskann/src/main.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,121 @@ | ||||
| #![feature(test)] | ||||
| #![feature(pointer_is_aligned_to)] | ||||
|  | ||||
| extern crate test; | ||||
|  | ||||
| use std::{io::Read, time::Instant}; | ||||
| use anyhow::Result; | ||||
| use half::f16; | ||||
|  | ||||
| use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, dot, VectorList, self}, Timer}; | ||||
| use simsimd::SpatialSimilarity; | ||||
|  | ||||
| const D_EMB: usize = 1152; | ||||
|  | ||||
| fn load_file(path: &str, trunc: Option<usize>) -> Result<VectorList> { | ||||
|     let mut input = std::fs::File::open(path)?; | ||||
|     let mut buf = Vec::new(); | ||||
|     input.read_to_end(&mut buf)?; | ||||
|     // TODO: this is not particularly efficient | ||||
|     let f16s = bytemuck::cast_slice::<_, f16>(&buf)[0..trunc.unwrap_or(buf.len()/2)].iter().copied().collect(); | ||||
|     Ok(VectorList::from_f16s(f16s, D_EMB)) | ||||
| } | ||||
|  | ||||
| const PQ_TEST_SIZE: usize = 1000; | ||||
|  | ||||
/// Driver: sanity-checks the product quantizer against exact dot products, builds a
/// Vamana-style index over the first `n` embeddings (two build passes with decreasing
/// alpha), then measures recall@1 and comparison counts by querying every indexed vector
/// against the index itself.
fn main() -> Result<()> {
    tracing_subscriber::fmt::init();

    {
        // PQ smoke test: quantize the first PQ_TEST_SIZE embeddings, score them against
        // one query both exactly and via the PQ lookup table, and print per-vector error.
        let file = std::fs::File::open("opq.msgpack")?;
        let codec: vector::ProductQuantizer = rmp_serde::from_read(file)?;
        let input = load_file("embeddings.bin", Some(D_EMB * PQ_TEST_SIZE))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
        let codes = codec.quantize_batch(&input);
        println!("{:?}", codes);
        let raw_query = load_file("query.bin", Some(D_EMB))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
        let query = codec.preprocess_query(&raw_query);
        let mut real_scores = vec![];
        for i in 0..PQ_TEST_SIZE {
            // exact f32 dot product as ground truth
            real_scores.push(SpatialSimilarity::dot(&raw_query, &input[i * D_EMB .. (i + 1) * D_EMB]).unwrap() as f32);
        }
        let pq_scores = codec.asymmetric_dot_product(&query, &codes);
        // NOTE(review): elsewhere PQ scores are fixed-point (scaled) values — confirm the
        // units here match the raw f32 scores before trusting the printed deltas.
        for (x, y) in real_scores.iter().zip(pq_scores.iter()) {
            // columns: exact score, PQ score, absolute error, relative error
            println!("{} {} {} {}", x, y, x - y, (x - y) / x);
        }
    }

    let mut rng = fastrand::Rng::with_seed(1);

    // number of vectors to index
    let n = 100000;
    let vecs = {
        let _timer = Timer::new("loaded vectors");

        &load_file("embeddings.bin", Some(D_EMB * n))?
    };

    let (graph, medioid) = {
        let _timer = Timer::new("index built");

        // build-time parameters; alpha is a fixed-point pruning slack (65536 == 1.0)
        let mut config = IndexBuildConfig {
            r: 64,
            r_cap: 80,
            l: 128,
            maxc: 750,
            alpha: 65536,
        };

        // start from a random r-regular graph, then refine it
        let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap);

        // entry point for all greedy searches
        let medioid = medioid(&vecs);

        // two passes: second pass tightens alpha slightly
        build_graph(&mut rng, &mut graph, medioid, &vecs, config);
        config.alpha = 58000;
        build_graph(&mut rng, &mut graph, medioid, &vecs, config);

        (graph, medioid)
    };

    // report the average out-degree of the finished graph
    let mut edge_ctr = 0;

    for adjlist in graph.graph.iter() {
        edge_ctr += adjlist.read().unwrap().len();
    }

    println!("average degree: {}", edge_ctr as f32 / graph.graph.len() as f32);

    let time = Instant::now();
    let mut recall = 0;
    let mut cmps_ctr = 0;
    let mut cmps = vec![];

    // search-time parameters (only l/alpha matter for greedy_search here)
    let mut config = IndexBuildConfig {
        r: 64,
        r_cap: 64,
        l: 50,
        alpha: 65536,
        maxc: 0,
    };

    let mut scratch = Scratch::new(config);

    // self-query every indexed vector: recall@1 means the vector finds itself first
    for (i, vec) in tqdm::tqdm(vecs.iter().enumerate()) {
        let ctr = greedy_search(&mut scratch, medioid, &vec, &vecs, &graph, config);
        cmps_ctr += ctr.distances;
        cmps.push(ctr.distances);
        if scratch.neighbour_buffer.ids[0] == (i as u32) {
            recall += 1;
        }
    }

    cmps.sort();

    let end = time.elapsed();

    println!("recall@1: {} ({}/{})", recall as f32 / n as f32, recall, n);
    println!("cmps: {} ({}/{})", cmps_ctr as f32 / n as f32, cmps_ctr, n);
    println!("median comparisons: {}", cmps[cmps.len() / 2]);
    //println!("brute force recall@1: {} ({}/{})", brute_force_recall as f32 / brute_force_queries as f32, brute_force_recall, brute_force_queries);
    println!("{} QPS", n as f32 / end.as_secs_f32());

    Ok(())
}
							
								
								
									
										449
									
								
								diskann/src/vector.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										449
									
								
								diskann/src/vector.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,449 @@ | ||||
| use core::f32; | ||||
|  | ||||
| use half::f16; | ||||
| use simsimd::SpatialSimilarity; | ||||
| use fastrand::Rng; | ||||
| use serde::{Serialize, Deserialize}; | ||||
| use tracing_subscriber::field::RecordFields; | ||||
|  | ||||
// Owned dense vector of half-precision components.
#[derive(Debug, Clone)]
pub struct Vector(Vec<f16>);
// Owned dense vector of single-precision components ("S" = single), used where
// f16 would lose too much precision (e.g. accumulation).
#[derive(Debug, Clone)]
pub struct SVector(Vec<f32>);

// Borrowed views: half-precision, quantized-code, and single-precision slices.
pub type VectorRef<'a> = &'a [f16];
pub type QVectorRef<'a> = &'a [u8];
pub type SVectorRef<'a> = &'a [f32];
|  | ||||
| impl SVector { | ||||
|     pub fn zero(d: usize) -> Self { | ||||
|         SVector(vec![0.0; d]) | ||||
|     } | ||||
|     pub fn half(&self) -> Vector { | ||||
|         Vector(self.0.iter().map(|a| f16::from_f32(*a)).collect()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn box_muller(rng: &mut Rng) -> f32 { | ||||
|     loop { | ||||
|         let u = rng.f32(); | ||||
|         let v = rng.f32(); | ||||
|         let x = (v * std::f32::consts::TAU).cos() * (-2.0 * u.ln()).sqrt(); | ||||
|         if x.is_finite() { | ||||
|             return x; | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Vector { | ||||
|     pub fn zero(d: usize) -> Self { | ||||
|         Vector(vec![f16::from_f32(0.0); d]) | ||||
|     } | ||||
|  | ||||
|     pub fn randn(rng: &mut Rng, d: usize) -> Self { | ||||
|         Vector(Vec::from_iter((0..d).map(|_| f16::from_f32(box_muller(rng))))) | ||||
|     } | ||||
| } | ||||
|  | ||||
// Floats are vaguely annoying and not sortable (trivially), so we mostly represent dot products as integers
// scaled by 2^48 — large enough to keep fractional precision, small enough that unit-scale
// dot products fit comfortably in an i64.
const SCALE: f32 = 281474976710656.0;
const SCALE_F64: f64 = SCALE as f64;
|  | ||||
/// Exact f16 dot product via SimSIMD, returned as f32.
/// SAFETY(review): transmutes `half::f16` slices into SimSIMD's own f16 slice type —
/// presumably both are bit-compatible u16 wrappers, but this is layout-dependent;
/// re-verify whenever either crate is upgraded.
pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> f32 {
    // safety is not real
    (simsimd::f16::dot(unsafe { std::mem::transmute(x) }, unsafe { std::mem::transmute(y) }).unwrap()) as f32
}
|  | ||||
| pub fn to_svector(vec: VectorRef) -> SVector { | ||||
|     SVector(vec.iter().map(|a| a.to_f32()).collect()) | ||||
| } | ||||
|  | ||||
| impl<'a> std::ops::AddAssign<VectorRef<'a>> for SVector { | ||||
|     fn add_assign(&mut self, other: VectorRef<'a>) { | ||||
|         self.0.iter_mut().zip(other.iter()).for_each(|(a, b)| *a += b.to_f32()); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl std::ops::Div<f32> for SVector { | ||||
|     type Output = Self; | ||||
|  | ||||
|     fn div(self, b: f32) -> Self::Output { | ||||
|         SVector(self.0.iter().map(|a| a / b).collect()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl std::ops::Deref for Vector { | ||||
|     type Target = [f16]; | ||||
|  | ||||
|     fn deref(&self) -> &Self::Target { | ||||
|         &self.0 | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl std::ops::Deref for SVector { | ||||
|     type Target = [f32]; | ||||
|  | ||||
|     fn deref(&self) -> &Self::Target { | ||||
|         &self.0 | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl std::ops::Add<&SVector> for SVector { | ||||
|     type Output = Self; | ||||
|  | ||||
|     fn add(self, other: &Self) -> Self::Output { | ||||
|         SVector(self.0.iter().zip(other.0.iter()).map(|(a, b)| a + b).collect()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl std::ops::Sub<&SVector> for SVector { | ||||
|     type Output = Self; | ||||
|  | ||||
|     fn sub(self, other: &Self) -> Self::Output { | ||||
|         SVector(self.0.iter().zip(other.0.iter()).map(|(a, b)| a - b).collect()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl std::ops::AddAssign for SVector { | ||||
|     fn add_assign(&mut self, other: Self) { | ||||
|         self.0.iter_mut().zip(other.0.iter()).for_each(|(a, b)| *a += b); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl std::ops::Mul<f32> for SVector { | ||||
|     type Output = Self; | ||||
|  | ||||
|     fn mul(self, other: f32) -> Self { | ||||
|         SVector(self.0.iter().map(|a| *a * other).collect()) | ||||
|     } | ||||
| } | ||||
|  | ||||
// Contiguous row-major matrix of `length` vectors, each `d_emb` f16 components.
#[derive(Debug, Clone)]
pub struct VectorList {
    pub d_emb: usize,   // dimensionality of each stored vector
    pub length: usize,  // number of vectors stored
    pub data: Vec<f16>  // backing storage; length * d_emb entries
}
|  | ||||
| impl std::ops::Index<usize> for VectorList { | ||||
|     type Output = [f16]; | ||||
|  | ||||
|     fn index(&self, index: usize) -> &Self::Output { | ||||
|         &self.data[index * self.d_emb..(index + 1) * self.d_emb] | ||||
|     } | ||||
| } | ||||
|  | ||||
// Borrowing iterator over the vectors of a `VectorList`, in storage order.
pub struct VectorListIterator<'a> {
    list: &'a VectorList,   // list being traversed
    index: usize            // index of the next vector to yield
}
|  | ||||
| impl<'a> Iterator for VectorListIterator<'a> { | ||||
|     type Item = VectorRef<'a>; | ||||
|  | ||||
|     fn next(&mut self) -> Option<Self::Item> { | ||||
|         if self.index < self.list.len() { | ||||
|             let ret = &self.list[self.index]; | ||||
|             self.index += 1; | ||||
|             Some(ret) | ||||
|         } else { | ||||
|             None | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
|  | ||||
| impl VectorList { | ||||
|     pub fn len(&self) -> usize { | ||||
|         self.length | ||||
|     } | ||||
|  | ||||
|     pub fn iter(&self) -> VectorListIterator { | ||||
|         VectorListIterator { | ||||
|             list: self, | ||||
|             index: 0 | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn empty(d: usize) -> Self { | ||||
|         VectorList { | ||||
|             d_emb: d, | ||||
|             length: 0, | ||||
|             data: Vec::new() | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn from_f16s(f16s: Vec<f16>, d: usize) -> Self { | ||||
|         assert!(f16s.len() % d == 0); | ||||
|         VectorList { | ||||
|             d_emb: d, | ||||
|             length: f16s.len() / d, | ||||
|             data: f16s | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn push(&mut self, vec: VectorRef) { | ||||
|         self.length += 1; | ||||
|         self.data.extend_from_slice(vec); | ||||
|     } | ||||
| } | ||||
|  | ||||
// SimSIMD has its own, but ours prefetches concurrently, is unrolled more, ignores inconveniently-sized vectors and does a cheaper reduction
// Also, we return an int because floats are annoying (not Ord)
// On Tiger Lake (i5-1135G7) we have about a 3x performance advantage ignoring the prefetching
// (it would be better to use AVX512 for said CPU but this also has to run on Zen 3)
//
// Fixed-point (see SCALE) dot product of x and y that concurrently prefetches `prefetch`,
// the vector the caller intends to score next, to hide memory latency.
// NOTE(review): assumes AVX2 + F16C + FMA are enabled at compile time; there is no runtime
// feature detection here — confirm the build sets target-cpu/target-feature accordingly.
pub fn fast_dot(x: VectorRef, y: VectorRef, prefetch: VectorRef) -> i64 {
    use std::arch::x86_64::*;

    // debug-only contract checks: equal lengths, and a multiple of 64 f16s
    debug_assert!(x.len() == y.len());
    debug_assert!(prefetch.len() == x.len());
    debug_assert!(x.len() % 64 == 0);

    // safety is not real
    // it's probably fine I guess
    unsafe {
        let mut x_ptr = x.as_ptr();
        let mut y_ptr = y.as_ptr();
        let end = x_ptr.add(x.len());
        let mut prefetch_ptr = prefetch.as_ptr();

        // four independent accumulators so the FMA dependency chains can overlap
        let mut acc1 = _mm256_setzero_ps();
        let mut acc2 = _mm256_setzero_ps();
        let mut acc3 = _mm256_setzero_ps();
        let mut acc4 = _mm256_setzero_ps();

        while x_ptr < end {
            // fetch chunks and prefetch next vector
            let x1 = _mm256_loadu_si256(x_ptr as *const __m256i);
            let y1 = _mm256_loadu_si256(y_ptr as *const __m256i);
            let x2 = _mm256_loadu_si256(x_ptr.add(16) as *const __m256i);
            let y2 = _mm256_loadu_si256(y_ptr.add(16) as *const __m256i);
            // technically, we only have to do this once per cache line but I don't care enough to test every way to optimize this
            _mm_prefetch(prefetch_ptr as *const i8, _MM_HINT_T0);
            x_ptr = x_ptr.add(32); // advance 32 f16s (two 16-lane chunks) per iteration
            y_ptr = y_ptr.add(32);
            prefetch_ptr = prefetch_ptr.add(32);

            // widen each 16-lane f16 chunk into two 8-lane f32 vectors (F16C)
            let x1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 0));
            let x1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 1));
            let y1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 0));
            let y1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 1));
            let x2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 0));
            let x2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 1));
            let y2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 0));
            let y2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 1));

            acc1 = _mm256_fmadd_ps(x1lo, y1lo, acc1);
            acc2 = _mm256_fmadd_ps(x1hi, y1hi, acc2);
            acc3 = _mm256_fmadd_ps(x2lo, y2lo, acc3);
            acc4 = _mm256_fmadd_ps(x2hi, y2hi, acc4);
        }

        // reduce: pairwise-combine the four accumulators, horizontally add across lanes
        let acc1 = _mm256_add_ps(acc1, acc2);
        let acc2 = _mm256_add_ps(acc3, acc4);

        let hsum = _mm256_hadd_ps(acc1, acc2);
        let hsum_lo = _mm256_extractf128_ps(hsum, 0);
        let hsum_hi = _mm256_extractf128_ps(hsum, 1);
        let hsum = _mm_add_ps(hsum_lo, hsum_hi);

        // final 4-lane scalar reduction, then convert to the shared fixed-point scale
        let floatval = f32::from_bits(_mm_extract_ps::<0>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<1>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<2>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<3>(hsum) as u32);
        (floatval * SCALE) as i64
    }
}
|  | ||||
// same as above, without prefetch pointer
// Fixed-point (see SCALE) dot product of x and y; use when there is no known next vector
// to prefetch. Same AVX2/F16C/FMA compile-time assumptions as `fast_dot`.
pub fn fast_dot_noprefetch(x: VectorRef, y: VectorRef) -> i64 {
    use std::arch::x86_64::*;

    // debug-only contract checks: equal lengths, and a multiple of 64 f16s
    debug_assert!(x.len() == y.len());
    debug_assert!(x.len() % 64 == 0);

    unsafe {
        let mut x_ptr = x.as_ptr();
        let mut y_ptr = y.as_ptr();
        let end = x_ptr.add(x.len());

        // four independent accumulators so the FMA dependency chains can overlap
        let mut acc1 = _mm256_setzero_ps();
        let mut acc2 = _mm256_setzero_ps();
        let mut acc3 = _mm256_setzero_ps();
        let mut acc4 = _mm256_setzero_ps();

        while x_ptr < end {
            // load two 16-lane f16 chunks from each operand; advance 32 f16s per iteration
            let x1 = _mm256_loadu_si256(x_ptr as *const __m256i);
            let y1 = _mm256_loadu_si256(y_ptr as *const __m256i);
            let x2 = _mm256_loadu_si256(x_ptr.add(16) as *const __m256i);
            let y2 = _mm256_loadu_si256(y_ptr.add(16) as *const __m256i);
            x_ptr = x_ptr.add(32);
            y_ptr = y_ptr.add(32);

            // widen each 16-lane f16 chunk into two 8-lane f32 vectors (F16C)
            let x1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 0));
            let x1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x1, 1));
            let y1lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 0));
            let y1hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y1, 1));
            let x2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 0));
            let x2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(x2, 1));
            let y2lo = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 0));
            let y2hi = _mm256_cvtph_ps(_mm256_extractf128_si256(y2, 1));

            acc1 = _mm256_fmadd_ps(x1lo, y1lo, acc1);
            acc2 = _mm256_fmadd_ps(x1hi, y1hi, acc2);
            acc3 = _mm256_fmadd_ps(x2lo, y2lo, acc3);
            acc4 = _mm256_fmadd_ps(x2hi, y2hi, acc4);
        }

        // reduce
        let acc1 = _mm256_add_ps(acc1, acc2);
        let acc2 = _mm256_add_ps(acc3, acc4);

        let hsum = _mm256_hadd_ps(acc1, acc2);
        let hsum_lo = _mm256_extractf128_ps(hsum, 0);
        let hsum_hi = _mm256_extractf128_ps(hsum, 1);
        let hsum = _mm_add_ps(hsum_lo, hsum_hi);

        // final 4-lane scalar reduction, then convert to the shared fixed-point scale
        let floatval = f32::from_bits(_mm_extract_ps::<0>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<1>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<2>(hsum) as u32) + f32::from_bits(_mm_extract_ps::<3>(hsum) as u32);
        (floatval * SCALE) as i64
    }
}
|  | ||||
// Product quantizer: vectors are rotated by `transform`, split into
// `n_dims / n_dims_per_code` chunks, and each chunk is replaced by the one-byte index of
// its best-matching centroid chunk (by dot product).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    centroids: Vec<f32>,        // flattened table; centroid j occupies [j*n_dims, (j+1)*n_dims)
    transform: Vec<f32>, // D*D orthonormal matrix
    pub n_dims_per_code: usize, // dimensions covered by each one-byte code
    pub n_dims: usize           // total dimensionality D
}

// Precomputed query-to-centroid scores, laid out as [chunk * n_centroids + centroid_index];
// each entry is dot(query_chunk, centroid_chunk).
pub struct DistanceLUT(Vec<f32>);
|  | ||||
impl ProductQuantizer {
    /// Rotate a batch of `x.len() / n_dims` vectors by the stored transform matrix.
    /// The result keeps the input's B*D layout (see the layout note in `quantize_batch`).
    pub fn apply_transform(&self, x: &[f32]) -> Vec<f32> {
        let dim = self.n_dims;
        let n_vectors = x.len() / dim;
        let mut transformed = vec![0.0; n_vectors * dim];
        // transform_matrix (D * D) @ batch.T (D * B)
        // safety(review): strides are consistent with the buffer sizes allocated above — confirm
        // against matrixmultiply's sgemm documentation if the layout ever changes
        unsafe {
            matrixmultiply::sgemm(dim, dim, n_vectors, 1.0, self.transform.as_ptr(), dim as isize, 1, x.as_ptr(), 1, dim as isize, 0.0, transformed.as_mut_ptr(), 1, dim as isize);
        }
        transformed
    }

    /// Encode a batch: each `n_dims_per_code`-wide chunk of each (rotated) vector becomes
    /// the index of the centroid chunk with the highest dot product. Returns
    /// B * (n_dims / n_dims_per_code) one-byte codes.
    pub fn quantize_batch(&self, x: &[f32]) -> Vec<u8> {
        // x is B * D
        let dim = self.n_dims;
        assert_eq!(dim * dim, self.transform.len());
        let n_vectors = x.len() / dim;
        let n_centroids = self.centroids.len() / dim;
        // codes are u8, so at most 256 centroids are addressable
        assert!(n_centroids <= 256);
        let transformed = self.apply_transform(&x); // B * D, as we write sgemm result in a weird order
        let mut codes = vec![0; n_vectors * dim / self.n_dims_per_code];
        let vec_len_codes = dim / self.n_dims_per_code;

        // B * C buffer of similarity of each vector to each centroid, within subspace
        let mut scratch = vec![0.0; n_vectors * n_centroids];

        for i in 0..(dim / self.n_dims_per_code) {
            let offset = i * self.n_dims_per_code;
            // transformed_batch[:, range] (B * D_r) @ centroids[:, range].T (D_r * C)
            unsafe {
                matrixmultiply::sgemm(n_vectors, self.n_dims_per_code, n_centroids, 1.0, transformed.as_ptr().add(offset), dim as isize, 1, self.centroids.as_ptr().add(offset), 1, dim as isize, 0.0, scratch.as_mut_ptr(), n_centroids as isize, 1);
            }
            // assign this component to best centroid
            for i_vec in 0..n_vectors {
                let mut best = f32::NEG_INFINITY;
                for i_centroid in 0..n_centroids {
                    let score = scratch[i_vec * n_centroids + i_centroid];
                    if score > best {
                        best = score;
                        codes[i_vec * vec_len_codes + i] = i_centroid as u8;
                    }
                }
            }
        }
        codes
    }

    // not particularly performance-sensitive right now; do unbatched
    /// Build the per-chunk lookup table of dot(query_chunk, centroid_chunk) consumed by
    /// `asymmetric_dot_product`. The query is rotated first, like the database vectors.
    pub fn preprocess_query(&self, query: &[f32]) -> DistanceLUT {
        let transformed = self.apply_transform(query);
        let n_chunks = self.n_dims / self.n_dims_per_code;
        let n_centroids = self.centroids.len() / self.n_dims;
        let mut lut = Vec::with_capacity(n_chunks * n_centroids);

        for i in 0..n_chunks {
            let vec_component = &transformed[i * self.n_dims_per_code..(i + 1) * self.n_dims_per_code];
            for j in 0..n_centroids {
                let centroid = &self.centroids[j * self.n_dims..(j + 1) * self.n_dims];
                let centroid_component = &centroid[i * self.n_dims_per_code..(i + 1) * self.n_dims_per_code];
                let score = SpatialSimilarity::dot(vec_component, centroid_component).unwrap();
                lut.push(score as f32);
            }
        }

        DistanceLUT(lut)
    }

    // compute dot products of query against product-quantized vectors
    /// Each vector's score is the sum over chunks of its chosen centroid's LUT entry,
    /// returned in the shared fixed-point i64 representation (see SCALE).
    pub fn asymmetric_dot_product(&self, query: &DistanceLUT, pq_vectors: &[u8]) -> Vec<i64> {
        let n_chunks = self.n_dims / self.n_dims_per_code;
        let n_vectors = pq_vectors.len() / n_chunks;
        let mut scores = vec![0.0; n_vectors];
        let n_centroids = self.centroids.len() / self.n_dims;

        // chunk-major iteration: for each chunk, add every vector's LUT entry
        for i in 0..n_chunks {
            for j in 0..n_vectors {
                let code = pq_vectors[j * n_chunks + i];
                let chunk_score = query.0[i * n_centroids + code as usize];
                scores[j] += chunk_score;
            }
        }

        // I have no idea why but we somehow have significant degradation in search quality
        // if this accumulates in integers. As such, do floats and convert at the end.
        // I'm sure there are fascinating reasons for this, but God is dead, God remains dead, etc.
        scores.into_iter().map(|x| (x * SCALE) as i64).collect()
    }
}
|  | ||||
| pub fn scale_dot_result(x: f64) -> i64 { | ||||
|     (x * SCALE_F64) as i64 | ||||
| } | ||||
|  | ||||
#[cfg(test)]
mod bench {
    use super::*;
    use test::Bencher;

    /// Two deterministic random 1024-dim vectors (1024 is a multiple of 64,
    /// satisfying fast_dot's length requirement). Seeded identically to keep
    /// benchmark inputs comparable across runs.
    fn bench_vectors() -> (Vector, Vector) {
        let mut rng = fastrand::Rng::with_seed(1);
        let a = Vector::randn(&mut rng, 1024);
        let b = Vector::randn(&mut rng, 1024);
        (a, b)
    }

    #[bench]
    fn bench_dot(be: &mut Bencher) {
        let (a, b) = bench_vectors();
        be.iter(|| dot(&a, &b));
    }

    #[bench]
    fn bench_fastdot(be: &mut Bencher) {
        let (a, b) = bench_vectors();
        be.iter(|| fast_dot(&a, &b, &a));
    }

    #[bench]
    fn bench_fastdot_noprefetch(be: &mut Bencher) {
        let (a, b) = bench_vectors();
        be.iter(|| fast_dot_noprefetch(&a, &b));
    }
}
							
								
								
									
										44
									
								
								diskann/vec_dist.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								diskann/vec_dist.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | ||||
"""Visualize per-dimension value distributions of embedding vectors as a heatmap.

Each of the first `n_used_dims` dimensions is standardized, passed through a
signed log transform (to spread the heavy center of the distribution), binned
into a histogram, and rendered as one row of a seaborn heatmap.
"""
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Signed log transform parameters: odd-symmetric, roughly linear near zero,
# logarithmic in the tails; A controls where the transition happens.
A = 0.6
LOG_A = np.log(A)

def scale(xs):
    """Signed log transform; maps 0 -> 0 and preserves sign."""
    return np.sign(xs) * (np.log(np.abs(xs) + A) - LOG_A)

n_dims = 1152
n_used_dims = 32
# f16 on disk; widen to f32 for stable statistics.
# Context manager fixes the previous leaked file handle. TODO: parameterize the path.
with open("embeddings.bin", "rb") as f:
    data = np.frombuffer(f.read(), dtype=np.float16).reshape(-1, n_dims).astype(np.float32)

# Create histogram bins
n_bins = 256
hist_range = (-1.2, 1.2)
histogram_data = np.zeros((n_used_dims, n_bins))

# Calculate histograms for each dimension
for dim in range(n_used_dims):
    dbd = data[:, dim]
    dbd = (dbd - np.mean(dbd)) / np.std(dbd)  # standardize before transforming
    dbd = scale(dbd)
    hist, _ = np.histogram(dbd, bins=n_bins, range=hist_range, density=True)
    histogram_data[dim] = hist

# Create heatmap
plt.figure(figsize=(12, 8))
sns.heatmap(histogram_data,
            cmap='viridis',
            xticklabels=np.linspace(hist_range[0], hist_range[1], n_bins),
            yticklabels=range(n_used_dims),
            cbar_kws={'label': 'Density'})

plt.xlabel('Value')
plt.ylabel('Dimension')
# Title previously hardcoded "16" while n_used_dims was 32; derive it instead.
plt.title(f'Distribution Heatmap of First {n_used_dims} Dimensions')

# Adjust layout to prevent label cutoff
plt.tight_layout()

plt.show()
		Reference in New Issue
	
	Block a user
	 osmarks
					osmarks