1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-12-15 04:48:05 +00:00

fixed roargraph things

This commit is contained in:
osmarks
2025-01-11 11:17:51 +00:00
10 changed files with 575 additions and 123 deletions

2
.gitignore vendored
View File

@@ -15,3 +15,5 @@ diskann/target
*.bin *.bin
*.msgpack *.msgpack
*/flamegraph.svg */flamegraph.svg
*.hdf5
*.v

356
Cargo.lock generated
View File

@@ -95,6 +95,37 @@ dependencies = [
"syn 2.0.79", "syn 2.0.79",
] ]
[[package]]
name = "argh"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7af5ba06967ff7214ce4c7419c7d185be7ecd6cc4965a8f6e1d8ce0398aad219"
dependencies = [
"argh_derive",
"argh_shared",
]
[[package]]
name = "argh_derive"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56df0aeedf6b7a2fc67d06db35b09684c3e8da0c95f8f27685cb17e08413d87a"
dependencies = [
"argh_shared",
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "argh_shared"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5693f39141bda5760ecc4111ab08da40565d1771038c4a0250f03457ec707531"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "arrayvec" name = "arrayvec"
version = "0.7.6" version = "0.7.6"
@@ -317,6 +348,30 @@ version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc827186963e592360843fb5ba4b973e145841266c1357f7180c43526f2e5b61" checksum = "dc827186963e592360843fb5ba4b973e145841266c1357f7180c43526f2e5b61"
[[package]]
name = "bitcode"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee1bce7608560cd4bf0296a4262d0dbf13e6bcec5ff2105724c8ab88cc7fc784"
dependencies = [
"arrayvec",
"bitcode_derive",
"bytemuck",
"glam",
"serde",
]
[[package]]
name = "bitcode_derive"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a539389a13af092cd345a2b47ae7dec12deb306d660b2223d25cd3419b253ebe"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
@@ -347,6 +402,18 @@ version = "2.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452" checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452"
[[package]]
name = "bitvec"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [
"funty",
"radium",
"tap",
"wyz",
]
[[package]] [[package]]
name = "block-buffer" name = "block-buffer"
version = "0.10.4" version = "0.10.4"
@@ -379,6 +446,20 @@ name = "bytemuck"
version = "1.19.0" version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d"
dependencies = [
"bytemuck_derive",
]
[[package]]
name = "bytemuck_derive"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]] [[package]]
name = "byteorder" name = "byteorder"
@@ -626,6 +707,31 @@ version = "0.8.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
[[package]]
name = "crossterm"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e64e6c0fbe2c17357405f7c758c1ef960fce08bdfb2c03d88d2a18d7e09c4b67"
dependencies = [
"bitflags 1.3.2",
"crossterm_winapi",
"libc",
"mio 0.8.11",
"parking_lot",
"signal-hook",
"signal-hook-mio",
"winapi 0.3.9",
]
[[package]]
name = "crossterm_winapi"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b"
dependencies = [
"winapi 0.3.9",
]
[[package]] [[package]]
name = "crunchy" name = "crunchy"
version = "0.2.2" version = "0.2.2"
@@ -697,6 +803,26 @@ dependencies = [
"subtle", "subtle",
] ]
[[package]]
name = "diskann"
version = "0.1.0"
dependencies = [
"anyhow",
"bitvec",
"bytemuck",
"fastrand",
"foldhash",
"half",
"matrixmultiply",
"rayon",
"rmp-serde",
"serde",
"simsimd",
"tqdm",
"tracing",
"tracing-subscriber",
]
[[package]] [[package]]
name = "document-features" name = "document-features"
version = "0.2.10" version = "0.2.10"
@@ -894,6 +1020,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
[[package]] [[package]]
name = "foreign-types" name = "foreign-types"
version = "0.3.2" version = "0.3.2"
@@ -918,6 +1050,12 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "funty"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.31" version = "0.3.31"
@@ -1039,6 +1177,12 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glam"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc46dd3ec48fdd8e693a98d2b8bafae273a2d54c1de02a2a7e3d57d501f39677"
[[package]] [[package]]
name = "glob" name = "glob"
version = "0.3.1" version = "0.3.1"
@@ -1070,10 +1214,17 @@ version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [ dependencies = [
"bytemuck",
"cfg-if", "cfg-if",
"crunchy", "crunchy",
] ]
[[package]]
name = "hamming"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65043da274378d68241eb9a8f8f8aa54e349136f7b8e12f63e3ef44043cc30e1"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.12.3" version = "0.12.3"
@@ -1496,7 +1647,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7507624b29483431c0ba2d82aece8ca6cdba9382bff4ddd0f7490560c056098d" checksum = "7507624b29483431c0ba2d82aece8ca6cdba9382bff4ddd0f7490560c056098d"
dependencies = [ dependencies = [
"winapi", "winapi 0.2.8",
"winapi-build", "winapi-build",
] ]
@@ -1637,6 +1788,28 @@ dependencies = [
"rawpointer", "rawpointer",
] ]
[[package]]
name = "maud"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df518b75016b4289cdddffa1b01f2122f4a49802c93191f3133f6dc2472ebcaa"
dependencies = [
"itoa",
"maud_macros",
]
[[package]]
name = "maud_macros"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa453238ec218da0af6b11fc5978d3b5c3a45ed97b722391a2a11f3306274e18"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]] [[package]]
name = "maybe-rayon" name = "maybe-rayon"
version = "0.1.1" version = "0.1.1"
@@ -1667,23 +1840,31 @@ name = "meme-search-engine"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"argh",
"async-recursion", "async-recursion",
"axum", "axum",
"base64 0.22.1", "base64 0.22.1",
"bitcode",
"bytemuck",
"chrono", "chrono",
"compact_str", "compact_str",
"console-subscriber", "console-subscriber",
"diskann",
"faiss", "faiss",
"fast_image_resize", "fast_image_resize",
"fastrand", "fastrand",
"ffmpeg-the-third", "ffmpeg-the-third",
"fnv", "fnv",
"foldhash",
"futures-util", "futures-util",
"half", "half",
"hamming",
"image", "image",
"itertools 0.13.0", "itertools 0.13.0",
"json5", "json5",
"lazy_static", "lazy_static",
"maud",
"memmap2",
"mimalloc", "mimalloc",
"ndarray", "ndarray",
"num_cpus", "num_cpus",
@@ -1691,9 +1872,11 @@ dependencies = [
"regex", "regex",
"reqwest", "reqwest",
"rmp-serde", "rmp-serde",
"seahash",
"serde", "serde",
"serde_bytes", "serde_bytes",
"serde_json", "serde_json",
"simsimd",
"sonic-rs", "sonic-rs",
"sqlx", "sqlx",
"tokio", "tokio",
@@ -1701,11 +1884,21 @@ dependencies = [
"tower 0.4.13", "tower 0.4.13",
"tower-http", "tower-http",
"tracing", "tracing",
"tracing-subscriber",
"url", "url",
"walkdir", "walkdir",
"zstd", "zstd",
] ]
[[package]]
name = "memmap2"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "mimalloc" name = "mimalloc"
version = "0.1.43" version = "0.1.43"
@@ -1756,6 +1949,18 @@ dependencies = [
"simd-adler32", "simd-adler32",
] ]
[[package]]
name = "mio"
version = "0.8.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
dependencies = [
"libc",
"log",
"wasi",
"windows-sys 0.48.0",
]
[[package]] [[package]]
name = "mio" name = "mio"
version = "1.0.2" version = "1.0.2"
@@ -1843,6 +2048,16 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8"
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi 0.3.9",
]
[[package]] [[package]]
name = "num-bigint" name = "num-bigint"
version = "0.4.6" version = "0.4.6"
@@ -2000,6 +2215,12 @@ dependencies = [
"vcpkg", "vcpkg",
] ]
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
version = "0.12.3" version = "0.12.3"
@@ -2170,6 +2391,29 @@ dependencies = [
"zerocopy", "zerocopy",
] ]
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.88" version = "1.0.88"
@@ -2275,6 +2519,12 @@ dependencies = [
"proc-macro2", "proc-macro2",
] ]
[[package]]
name = "radium"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]] [[package]]
name = "rand" name = "rand"
version = "0.8.5" version = "0.8.5"
@@ -2625,7 +2875,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d931a44fdaa43b8637009e7632a02adc4f2b2e0733c08caa4cf00e8da4a117a7" checksum = "d931a44fdaa43b8637009e7632a02adc4f2b2e0733c08caa4cf00e8da4a117a7"
dependencies = [ dependencies = [
"kernel32-sys", "kernel32-sys",
"winapi", "winapi 0.2.8",
] ]
[[package]] [[package]]
@@ -2643,6 +2893,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "seahash"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b"
[[package]] [[package]]
name = "security-framework" name = "security-framework"
version = "2.11.1" version = "2.11.1"
@@ -2775,6 +3031,27 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801"
dependencies = [
"libc",
"signal-hook-registry",
]
[[package]]
name = "signal-hook-mio"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd"
dependencies = [
"libc",
"mio 0.8.11",
"signal-hook",
]
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.2" version = "1.4.2"
@@ -2815,6 +3092,15 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e"
[[package]]
name = "simsimd"
version = "6.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb18072bc601c31152841c2a114b78c11fe15513f0767eacd57d68359b7130e3"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.9" version = "0.4.9"
@@ -3175,6 +3461,12 @@ dependencies = [
"version-compare", "version-compare",
] ]
[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]] [[package]]
name = "target-lexicon" name = "target-lexicon"
version = "0.12.16" version = "0.12.16"
@@ -3259,7 +3551,7 @@ dependencies = [
"backtrace", "backtrace",
"bytes", "bytes",
"libc", "libc",
"mio", "mio 1.0.2",
"parking_lot", "parking_lot",
"pin-project-lite", "pin-project-lite",
"signal-hook-registry", "signal-hook-registry",
@@ -3453,6 +3745,17 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tqdm"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa2d2932240205a99b65f15d9861992c95fbb8c9fb280b3a1f17a92db6dc611f"
dependencies = [
"anyhow",
"crossterm",
"once_cell",
]
[[package]] [[package]]
name = "tracing" name = "tracing"
version = "0.1.40" version = "0.1.40"
@@ -3486,6 +3789,17 @@ dependencies = [
"valuable", "valuable",
] ]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]] [[package]]
name = "tracing-subscriber" name = "tracing-subscriber"
version = "0.3.18" version = "0.3.18"
@@ -3493,12 +3807,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [ dependencies = [
"matchers", "matchers",
"nu-ansi-term",
"once_cell", "once_cell",
"regex", "regex",
"sharded-slab", "sharded-slab",
"smallvec",
"thread_local", "thread_local",
"tracing", "tracing",
"tracing-core", "tracing-core",
"tracing-log",
] ]
[[package]] [[package]]
@@ -3630,7 +3947,7 @@ checksum = "bb08f9e670fab86099470b97cd2b252d6527f0b3cc1401acdb595ffc9dd288ff"
dependencies = [ dependencies = [
"kernel32-sys", "kernel32-sys",
"same-file", "same-file",
"winapi", "winapi 0.2.8",
] ]
[[package]] [[package]]
@@ -3753,12 +4070,34 @@ version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "167dc9d6949a9b857f3451275e911c3f44255842c1f7a76f33c55103a909087a" checksum = "167dc9d6949a9b857f3451275e911c3f44255842c1f7a76f33c55103a909087a"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]] [[package]]
name = "winapi-build" name = "winapi-build"
version = "0.1.1" version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d315eee3b34aca4797b2da6b13ed88266e6d612562a0c46390af8299fc699bc" checksum = "2d315eee3b34aca4797b2da6b13ed88266e6d612562a0c46390af8299fc699bc"
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]] [[package]]
name = "windows-core" name = "windows-core"
version = "0.52.0" version = "0.52.0"
@@ -3955,6 +4294,15 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "wyz"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
dependencies = [
"tap",
]
[[package]] [[package]]
name = "zerocopy" name = "zerocopy"
version = "0.7.35" version = "0.7.35"

View File

@@ -3,7 +3,9 @@ name = "meme-search-engine"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [profile.release-with-debug]
inherits = "release"
debug = true
[dependencies] [dependencies]
tokio = { version = "1", features = ["full", "tracing"] } tokio = { version = "1", features = ["full", "tracing"] }
@@ -44,6 +46,17 @@ compact_str = { version = "0.8.0-beta", features = ["serde"] }
itertools = "0.13" itertools = "0.13"
async-recursion = "1" async-recursion = "1"
fast_image_resize = { version = "5", features = ["image"] } fast_image_resize = { version = "5", features = ["image"] }
argh = "0.1"
maud = "0.26"
hamming = "0.1"
seahash = "4"
tracing-subscriber = "0.3"
diskann = { path = "./diskann" }
bytemuck = "1"
bitcode = "0.6"
simsimd = "6"
foldhash = "0.1"
memmap2 = "0.9"
[[bin]] [[bin]]
name = "reddit-dump" name = "reddit-dump"
@@ -56,3 +69,11 @@ path = "src/video_reader.rs"
[[bin]] [[bin]]
name = "dump-processor" name = "dump-processor"
path = "src/dump_processor.rs" path = "src/dump_processor.rs"
[[bin]]
name = "generate-index-shard"
path = "src/generate_index_shard.rs"
[[bin]]
name = "query-disk-index"
path = "src/query_disk_index.rs"

7
diskann/hdf5_to_bin.py Normal file
View File

@@ -0,0 +1,7 @@
import h5py
import numpy as np
f = h5py.File("glove-200-angular.hdf5", "r")
data = np.array(f["train"])
norms = np.linalg.norm(data, axis=1)
data = data / norms[:, np.newaxis]
data.astype(np.float16).tofile("glove-200-angular.bin")

View File

@@ -55,7 +55,6 @@ impl IndexGraph {
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
pub struct IndexBuildConfig { pub struct IndexBuildConfig {
pub r: usize, pub r: usize,
pub r_cap: usize,
pub l: usize, pub l: usize,
pub maxc: usize, pub maxc: usize,
pub alpha: i64 pub alpha: i64
@@ -146,7 +145,14 @@ impl NeighbourBuffer {
self.scores.truncate(self.size); self.scores.truncate(self.size);
self.visited.truncate(self.size); self.visited.truncate(self.size);
self.next_unvisited = Some(loc as u32); match self.next_unvisited {
Some(ref mut next_unvisited) => {
*next_unvisited = (loc as u32).min(*next_unvisited);
},
None => {
self.next_unvisited = Some(loc as u32);
}
}
} }
pub fn clear(&mut self) { pub fn clear(&mut self) {
@@ -194,7 +200,6 @@ pub fn greedy_search(scratch: &mut Scratch, start: u32, query: VectorRef, vecs:
let mut counters = GreedySearchCounters { distances: 0 }; let mut counters = GreedySearchCounters { distances: 0 };
while let Some(pt) = scratch.neighbour_buffer.next_unvisited() { while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
//println!("pt {} {:?}", pt, graph.out_neighbours(pt));
scratch.neighbour_pre_buffer.clear(); scratch.neighbour_pre_buffer.clear();
for &neighbour in graph.out_neighbours(pt).iter() { for &neighbour in graph.out_neighbours(pt).iter() {
if scratch.visited.insert(neighbour) { if scratch.visited.insert(neighbour) {
@@ -299,7 +304,7 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V
// somewhat ugly deadlock avoidance hack - only hold writelock at end // somewhat ugly deadlock avoidance hack - only hold writelock at end
let neighbour_neighbours = graph.out_neighbours(neighbour); let neighbour_neighbours = graph.out_neighbours(neighbour);
// To cut down pruning time slightly, allow accumulating more neighbours than usual limit // To cut down pruning time slightly, allow accumulating more neighbours than usual limit
if neighbour_neighbours.len() == config.r_cap { if neighbour_neighbours.len() == config.r {
let mut n = neighbour_neighbours.to_vec(); let mut n = neighbour_neighbours.to_vec();
scratch.visited_list.clear(); scratch.visited_list.clear();
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config); merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config);
@@ -307,7 +312,7 @@ pub fn build_graph(rng: &mut Rng, graph: &mut IndexGraph, medioid: u32, vecs: &V
robust_prune(scratch, neighbour, &mut n, vecs, config); robust_prune(scratch, neighbour, &mut n, vecs, config);
std::mem::drop(neighbour_neighbours); std::mem::drop(neighbour_neighbours);
*graph.out_neighbours_mut(neighbour) = n; *graph.out_neighbours_mut(neighbour) = n;
} else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r_cap { } else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r {
// we apparently cannot actually upgrade the read lock to a write lock // we apparently cannot actually upgrade the read lock to a write lock
std::mem::drop(neighbour_neighbours); std::mem::drop(neighbour_neighbours);
let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour); let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour);
@@ -360,7 +365,7 @@ pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec
// as in Vamana algorithm // as in Vamana algorithm
// TODO factor out common code // TODO factor out common code
let neighbour_neighbours = graph.out_neighbours(neighbour); let neighbour_neighbours = graph.out_neighbours(neighbour);
if neighbour_neighbours.len() == config.r_cap { if neighbour_neighbours.len() == config.r {
let mut n = neighbour_neighbours.to_vec(); let mut n = neighbour_neighbours.to_vec();
scratch.visited_list.clear(); scratch.visited_list.clear();
merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config); merge_existing_neighbours(&mut scratch.visited_list, neighbour, &neighbour_neighbours, vecs, config);
@@ -368,7 +373,7 @@ pub fn project_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: &Vec
robust_prune(scratch, neighbour, &mut n, vecs, config); robust_prune(scratch, neighbour, &mut n, vecs, config);
std::mem::drop(neighbour_neighbours); std::mem::drop(neighbour_neighbours);
*graph.out_neighbours_mut(neighbour) = n; *graph.out_neighbours_mut(neighbour) = n;
} else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r_cap { } else if !neighbour_neighbours.contains(&sigma_i) && neighbour_neighbours.len() < config.r {
std::mem::drop(neighbour_neighbours); std::mem::drop(neighbour_neighbours);
let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour); let mut neighbour_neighbours = graph.out_neighbours_mut(neighbour);
neighbour_neighbours.push(sigma_i); neighbour_neighbours.push(sigma_i);
@@ -386,7 +391,7 @@ pub fn augment_bipartite(rng: &mut Rng, graph: &mut IndexGraph, query_knns: Vec<
sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| { sigmas.into_par_iter().for_each_init(|| rng.lock().unwrap().fork(), |rng, sigma_i| {
let mut neighbours = graph.out_neighbours_mut(sigma_i); let mut neighbours = graph.out_neighbours_mut(sigma_i);
let mut i = 0; let mut i = 0;
while neighbours.len() < config.r_cap && i < 100 { while neighbours.len() < config.r && i < 100 {
let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap(); let query_neighbour = *rng.choice(&query_knns[sigma_i as usize]).unwrap();
let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap(); let projected_neighbour = *rng.choice(&query_knns_bwd[query_neighbour as usize]).unwrap();
if !neighbours.contains(&projected_neighbour) { if !neighbours.contains(&projected_neighbour) {
@@ -423,3 +428,18 @@ impl Drop for Timer {
println!("{}: {:.2}s", self.0, self.1.elapsed().as_secs_f32()); println!("{}: {:.2}s", self.0, self.1.elapsed().as_secs_f32());
} }
} }
pub fn report_degrees(graph: &IndexGraph) {
let mut total_degree = 0;
let mut degrees = Vec::with_capacity(graph.graph.len());
for out_neighbours in graph.graph.iter() {
let deg = out_neighbours.read().unwrap().len();
total_degree += deg;
degrees.push(deg);
}
degrees.sort_unstable();
println!("average degree {}", (total_degree as f64) / (graph.graph.len() as f64));
println!("median degree {}", degrees[degrees.len() / 2]);
println!("min degree {}", degrees[0]);
println!("max degree {}", degrees[degrees.len() - 1]);
}

View File

@@ -7,7 +7,7 @@ use std::{io::Read, time::Instant};
use anyhow::Result; use anyhow::Result;
use half::f16; use half::f16;
use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, SCALE, dot, VectorList, self}, Timer}; use diskann::{build_graph, IndexBuildConfig, medioid, IndexGraph, greedy_search, Scratch, vector::{fast_dot, SCALE, dot, VectorList, self}, Timer, report_degrees, random_fill_graph};
use simsimd::SpatialSimilarity; use simsimd::SpatialSimilarity;
const D_EMB: usize = 1152; const D_EMB: usize = 1152;
@@ -26,12 +26,13 @@ const PQ_TEST_SIZE: usize = 1000;
fn main() -> Result<()> { fn main() -> Result<()> {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
/*/
{ {
let file = std::fs::File::open("opq.msgpack")?; let file = std::fs::File::open("opq.msgpack")?;
let codec: vector::ProductQuantizer = rmp_serde::from_read(file)?; let codec: vector::ProductQuantizer = rmp_serde::from_read(file)?;
let input = load_file("embeddings.bin", Some(D_EMB * PQ_TEST_SIZE))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>(); let input = load_file("embeddings.bin", Some(D_EMB * PQ_TEST_SIZE))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
let codes = codec.quantize_batch(&input); let codes = codec.quantize_batch(&input);
println!("{:?}", codes); //println!("{:?}", codes);
let raw_query = load_file("query.bin", Some(D_EMB))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>(); let raw_query = load_file("query.bin", Some(D_EMB))?.data.into_iter().map(|a| a.to_f32()).collect::<Vec<_>>();
let query = codec.preprocess_query(&raw_query); let query = codec.preprocess_query(&raw_query);
let mut real_scores = vec![]; let mut real_scores = vec![];
@@ -41,17 +42,17 @@ fn main() -> Result<()> {
let pq_scores = codec.asymmetric_dot_product(&query, &codes); let pq_scores = codec.asymmetric_dot_product(&query, &codes);
for (x, y) in real_scores.iter().zip(pq_scores.iter()) { for (x, y) in real_scores.iter().zip(pq_scores.iter()) {
let y = (*y as f32) / SCALE; let y = (*y as f32) / SCALE;
println!("{} {} {} {}", x, y, x - y, (x - y) / x); //println!("{} {} {} {}", x, y, x - y, (x - y) / x);
} }
} }*/
let mut rng = fastrand::Rng::with_seed(1); let mut rng = fastrand::Rng::with_seed(1);
let n = 100000; let n = 100_000;
let vecs = { let vecs = {
let _timer = Timer::new("loaded vectors"); let _timer = Timer::new("loaded vectors");
&load_file("embeddings.bin", Some(D_EMB * n))? &load_file("query.bin", Some(D_EMB * n))?
}; };
let (graph, medioid) = { let (graph, medioid) = {
@@ -59,10 +60,9 @@ fn main() -> Result<()> {
let mut config = IndexBuildConfig { let mut config = IndexBuildConfig {
r: 64, r: 64,
r_cap: 80, l: 192,
l: 128,
maxc: 750, maxc: 750,
alpha: 65536, alpha: 65200,
}; };
let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap); let mut graph = IndexGraph::random_r_regular(&mut rng, vecs.len(), config.r, config.r_cap);
@@ -70,8 +70,11 @@ fn main() -> Result<()> {
let medioid = medioid(&vecs); let medioid = medioid(&vecs);
build_graph(&mut rng, &mut graph, medioid, &vecs, config); build_graph(&mut rng, &mut graph, medioid, &vecs, config);
config.alpha = 58000; report_degrees(&graph);
build_graph(&mut rng, &mut graph, medioid, &vecs, config); //random_fill_graph(&mut rng, &mut graph, config.r);
//config.alpha = 65536;
//build_graph(&mut rng, &mut graph, medioid, &vecs, config);
report_degrees(&graph);
(graph, medioid) (graph, medioid)
}; };
@@ -82,8 +85,6 @@ fn main() -> Result<()> {
edge_ctr += adjlist.read().unwrap().len(); edge_ctr += adjlist.read().unwrap().len();
} }
println!("average degree: {}", edge_ctr as f32 / graph.graph.len() as f32);
let time = Instant::now(); let time = Instant::now();
let mut recall = 0; let mut recall = 0;
let mut cmps_ctr = 0; let mut cmps_ctr = 0;
@@ -91,7 +92,6 @@ fn main() -> Result<()> {
let mut config = IndexBuildConfig { let mut config = IndexBuildConfig {
r: 64, r: 64,
r_cap: 64,
l: 50, l: 50,
alpha: 65536, alpha: 65536,
maxc: 0, maxc: 0,

View File

@@ -173,18 +173,17 @@ pub mod index_config {
use diskann::IndexBuildConfig; use diskann::IndexBuildConfig;
pub const BASE_CONFIG: IndexBuildConfig = IndexBuildConfig { pub const BASE_CONFIG: IndexBuildConfig = IndexBuildConfig {
r: 64, r: 40,
r_cap: 80, l: 200,
l: 500, maxc: 900,
maxc: 750, alpha: 65200
alpha: 60000
}; };
pub const PROJECTION_CUT_POINT: usize = 1; pub const PROJECTION_CUT_POINT: usize = 3;
pub const FIRST_PASS_ALPHA: i64 = 65536; pub const FIRST_PASS_ALPHA: i64 = 65200;
pub const SECOND_PASS_ALPHA: i64 = 62000; //pub const SECOND_PASS_ALPHA: i64 = 62000;
pub const QUERY_SEARCH_K: usize = 200; // we want each query to have QUERY_REVERSE_K results, but some queries are likely more common than others in the top-k lists, so oversample a bit pub const QUERY_SEARCH_K: usize = 200; // we want each query to have QUERY_REVERSE_K results, but some queries are likely more common than others in the top-k lists, so oversample a bit
pub const QUERY_REVERSE_K: usize = 100; pub const QUERY_REVERSE_K: usize = 100;

View File

@@ -55,7 +55,7 @@ struct CLIArguments {
#[argh(switch, short='t', description="print titles")] #[argh(switch, short='t', description="print titles")]
titles: bool, titles: bool,
#[argh(option, description="truncate centroids list")] #[argh(option, description="truncate centroids list")]
clip_centroids: Option<usize>, clip_shards: Option<usize>,
#[argh(switch, description="print original linked URL")] #[argh(switch, description="print original linked URL")]
original_url: bool, original_url: bool,
#[argh(option, short='q', description="product quantization codec path")] #[argh(option, short='q', description="product quantization codec path")]
@@ -179,7 +179,7 @@ fn main() -> Result<()> {
let centroids_data = fs::read(centroids).context("read centroids file")?; let centroids_data = fs::read(centroids).context("read centroids file")?;
let mut centroids_data = common::decode_fp16_buffer(&centroids_data); let mut centroids_data = common::decode_fp16_buffer(&centroids_data);
if let Some(clip) = args.clip_centroids { if let Some(clip) = args.clip_shards {
centroids_data.truncate(clip * D_EMB as usize); centroids_data.truncate(clip * D_EMB as usize);
} }
@@ -208,6 +208,14 @@ fn main() -> Result<()> {
let path = file.path(); let path = file.path();
let filename = path.file_name().unwrap().to_str().unwrap(); let filename = path.file_name().unwrap().to_str().unwrap();
let (fst, snd) = filename.split_once(".").unwrap(); let (fst, snd) = filename.split_once(".").unwrap();
let id: u32 = str::parse(fst)?;
if let Some(clip) = args.clip_shards {
if id >= (clip as u32) {
continue;
}
}
if snd == "shard-header.msgpack" { if snd == "shard-header.msgpack" {
let header: ShardHeader = rmp_serde::from_read(BufReader::new(fs::File::open(path)?))?; let header: ShardHeader = rmp_serde::from_read(BufReader::new(fs::File::open(path)?))?;
if original_ids_to_shards.len() < (header.max as usize + 1) { if original_ids_to_shards.len() < (header.max as usize + 1) {
@@ -237,7 +245,6 @@ fn main() -> Result<()> {
shard_id_mappings.push((header.id, header.mapping)); shard_id_mappings.push((header.id, header.mapping));
} else if snd == "shard.bin" { } else if snd == "shard.bin" {
let file = fs::File::open(&path).context("open shard file")?; let file = fs::File::open(&path).context("open shard file")?;
let id: u32 = str::parse(fst)?;
files.push((id, file)); files.push((id, file));
} }
} }
@@ -245,11 +252,16 @@ fn main() -> Result<()> {
files.sort_by_key(|(id, _)| *id); files.sort_by_key(|(id, _)| *id);
shard_id_mappings.sort_by_key(|(id, _)| *id); shard_id_mappings.sort_by_key(|(id, _)| *id);
let read_out_vertices =move |id: u32| -> Result<(Vec<u32>, Vec<u32>)> {
let read_out_vertices = move |id: u32| -> Result<(Vec<u32>, Vec<u32>)> {
let mut out_vertices: Vec<u32> = vec![]; let mut out_vertices: Vec<u32> = vec![];
let mut shards: Vec<u32> = vec![]; let mut shards: Vec<u32> = vec![];
// look up each location in shard files // look up each location in shard files
for &(shard, offset, len) in original_ids_to_shards[id as usize].iter() { for &(shard, offset, len) in original_ids_to_shards[id as usize].iter() {
if (shard, offset, len) == EMPTY_LOOKUP {
continue;
}
shards.push(shard); shards.push(shard);
let shard = shard as usize; let shard = shard as usize;
// this random access is almost certainly rather slow // this random access is almost certainly rather slow
@@ -257,7 +269,7 @@ fn main() -> Result<()> {
files[shard].1.seek(SeekFrom::Start(offset))?; files[shard].1.seek(SeekFrom::Start(offset))?;
let mut buf = vec![0; len as usize]; let mut buf = vec![0; len as usize];
files[shard].1.read_exact(&mut buf)?; files[shard].1.read_exact(&mut buf)?;
let s: &mut [u32] = bytemuck::cast_slice_mut(&mut *buf); let mut s: Vec<u32> = bytemuck::allocation::pod_collect_to_vec(&*buf);
for within_shard_id in s.iter_mut() { for within_shard_id in s.iter_mut() {
*within_shard_id = shard_id_mappings[shard].1[*within_shard_id as usize]; *within_shard_id = shard_id_mappings[shard].1[*within_shard_id as usize];
} }

View File

@@ -1,10 +1,10 @@
use anyhow::{Result, Context}; use anyhow::{Result, Context};
use itertools::Itertools; use itertools::Itertools;
use std::io::{BufReader, Write, BufWriter, Seek};
use rmp_serde::decode::Error as DecodeError;
use std::fs;
use std::collections::BinaryHeap; use std::collections::BinaryHeap;
use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList, scale_dot_result}, IndexBuildConfig, IndexGraph, Timer}; use std::io::{BufReader, BufWriter, Write};
use std::fs;
use rmp_serde::decode::{Error as DecodeError, from_read};
use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList, scale_dot_result}, IndexBuildConfig, IndexGraph, Timer, report_degrees};
use half::f16; use half::f16;
mod common; mod common;
@@ -13,19 +13,6 @@ use common::{index_config::{self, QUERY_REVERSE_K}, ShardHeader, ShardInputHeade
const D_EMB: usize = 1152; const D_EMB: usize = 1152;
fn report_degrees(graph: &IndexGraph) {
let mut total_degree = 0;
let mut degrees = Vec::with_capacity(graph.graph.len());
for out_neighbours in graph.graph.iter() {
let deg = out_neighbours.read().unwrap().len();
total_degree += deg;
degrees.push(deg);
}
degrees.sort_unstable();
println!("average degree {}", (total_degree as f32) / (graph.graph.len() as f32));
println!("median degree {}", degrees[degrees.len() / 2]);
}
fn main() -> Result<()> { fn main() -> Result<()> {
let mut rng = fastrand::Rng::new(); let mut rng = fastrand::Rng::new();
@@ -98,7 +85,7 @@ fn main() -> Result<()> {
length: original_ids.len() length: original_ids.len()
}; };
let mut graph = IndexGraph::empty(original_ids.len(), config.r_cap); let mut graph = IndexGraph::empty(original_ids.len(), config.r);
{ {
let _timer = Timer::new("project bipartite"); let _timer = Timer::new("project bipartite");
@@ -126,14 +113,6 @@ fn main() -> Result<()> {
report_degrees(&graph); report_degrees(&graph);
{
let _timer = Timer::new("second pass");
config.alpha = common::index_config::SECOND_PASS_ALPHA;
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
}
report_degrees(&graph);
std::mem::drop(vecs); std::mem::drop(vecs);
let len = original_ids.len(); let len = original_ids.len();

View File

@@ -7,7 +7,6 @@ use std::fs;
use base64::Engine; use base64::Engine;
use argh::FromArgs; use argh::FromArgs;
use chrono::{TimeZone, Utc, DateTime}; use chrono::{TimeZone, Utc, DateTime};
use std::collections::VecDeque;
use itertools::Itertools; use itertools::Itertools;
use foldhash::{HashSet, HashSetExt}; use foldhash::{HashSet, HashSetExt};
use half::f16; use half::f16;
@@ -23,9 +22,17 @@ use common::{PackedIndexEntry, IndexHeader};
#[argh(description="Query disk index")] #[argh(description="Query disk index")]
struct CLIArguments { struct CLIArguments {
#[argh(positional)] #[argh(positional)]
query_vector: String, index_path: String,
#[argh(positional)] #[argh(option, short='q', description="query vector in base64")]
index_path: String query_vector_base64: Option<String>,
#[argh(option, short='f', description="file of FP16 query vectors")]
query_vector_file: Option<String>,
#[argh(switch, short='v', description="verbose")]
verbose: bool,
#[argh(option, short='n', description="stop at n queries")]
n: Option<usize>,
#[argh(switch, description="always use full-precision vectors (slow)")]
disable_pq: bool
} }
fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> { fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> {
@@ -56,7 +63,7 @@ struct IndexRef<'a> {
pq_code_size: usize pq_code_size: usize
} }
fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &DistanceLUT, index: IndexRef) -> Result<(usize, usize)> { fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &DistanceLUT, index: IndexRef, disable_pq: bool) -> Result<(usize, usize)> {
scratch.visited.clear(); scratch.visited.clear();
scratch.neighbour_buffer.clear(); scratch.neighbour_buffer.clear();
scratch.visited_list.clear(); scratch.visited_list.clear();
@@ -88,24 +95,48 @@ fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preproc
} }
let approx_scores = index.header.quantizer.asymmetric_dot_product(&query_preprocessed, &pq_codes); let approx_scores = index.header.quantizer.asymmetric_dot_product(&query_preprocessed, &pq_codes);
for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() { for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
//let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO if disable_pq {
//let node = read_node(neighbour, index.data_file, index.header)?; //let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
//let vector = bytemuck::cast_slice(&node.vector); let node = read_node(neighbour, index.data_file, index.header)?;
//let distance = fast_dot_noprefetch(query, &vector); let vector = bytemuck::cast_slice(&node.vector);
pq_cmps += 1; let distance = fast_dot_noprefetch(query, &vector);
scratch.neighbour_buffer.insert(neighbour, approx_scores[i]); scratch.neighbour_buffer.insert(neighbour, distance);
//scratch.neighbour_buffer.insert(neighbour, distance); } else {
scratch.neighbour_buffer.insert(neighbour, approx_scores[i]);
pq_cmps += 1;
}
} }
} }
Ok((cmps, pq_cmps)) Ok((cmps, pq_cmps))
} }
fn summary_stats(ranks: &mut [usize]) {
let sum = ranks.iter().sum::<usize>();
let mean = sum as f64 / ranks.len() as f64 + 1.0;
ranks.sort_unstable();
let median = ranks[ranks.len() / 2] + 1;
let harmonic_mean = ranks.iter().map(|x| 1.0 / ((x+1) as f64)).sum::<f64>() / ranks.len() as f64;
println!("median {} mean {} max {} min {} harmonic mean {}", median, mean, ranks[ranks.len() - 1] + 1, ranks[0] + 1, 1.0 / harmonic_mean);
}
fn main() -> Result<()> { fn main() -> Result<()> {
let args: CLIArguments = argh::from_env(); let args: CLIArguments = argh::from_env();
let query_vector: Vec<f16> = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(args.query_vector.as_bytes()).context("invalid base64")?); let mut queries = vec![];
let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
if let Some(query_vector_base64) = args.query_vector_base64 {
let query_vector: Vec<f16> = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(query_vector_base64.as_bytes()).context("invalid base64")?);
queries.push(query_vector);
}
if let Some(query_vector_file) = args.query_vector_file {
let query_vectors = fs::read(query_vector_file)?;
queries.extend(common::chunk_fp16_buffer(&query_vectors).chunks(1152).map(|x| x.to_vec()).collect::<Vec<_>>());
}
if let Some(n) = args.n {
queries.truncate(n);
}
let index_path = PathBuf::from(&args.index_path); let index_path = PathBuf::from(&args.index_path);
let header: IndexHeader = rmp_serde::from_read(BufReader::new(fs::File::open(index_path.join("index.msgpack"))?))?; let header: IndexHeader = rmp_serde::from_read(BufReader::new(fs::File::open(index_path.join("index.msgpack"))?))?;
@@ -117,57 +148,90 @@ fn main() -> Result<()> {
MmapOptions::new().populate().map(&pq_codes_file)? MmapOptions::new().populate().map(&pq_codes_file)?
}; };
let query_preprocessed = header.quantizer.preprocess_query(&query_vector_fp32);
println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len()); println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());
// TODO slightly dubious let mut top_20_ranks_best_shard = vec![];
let selected_shard = header.shards.iter().position_max_by_key(|x| { let mut top_rank_best_shard = vec![];
scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query_vector_fp32).unwrap())
}).unwrap();
println!("best shard is {}", selected_shard); for query_vector in queries.iter() {
let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();
let query_preprocessed = header.quantizer.preprocess_query(&query_vector_fp32);
for shard in 0..header.shards.len() { // TODO slightly dubious
let selected_start = header.shards[shard].1; let selected_shard = header.shards.iter().position_max_by_key(|x| {
scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query_vector_fp32).unwrap())
}).unwrap();
let mut scratch = Scratch { if args.verbose {
visited: HashSet::new(), println!("selected shard is {}", selected_shard);
neighbour_buffer: NeighbourBuffer::new(5000),
neighbour_pre_buffer: Vec::new(),
visited_list: Vec::new()
};
//let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng);
let cmps = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
data_file: &mut data_file,
header: &header,
pq_codes: &pq_codes,
pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
})?;
println!("index scan {}: {:?} cmps", shard, cmps);
scratch.visited_list.sort_by_key(|x| -x.1);
for (id, distance, url, shards) in scratch.visited_list.iter().take(20) {
println!("index scan: {} {} {} {:?}", id, distance, url, shards);
} }
println!("");
let mut matches = vec![];
// brute force scan
for i in 0..header.count {
let node = read_node(i, &mut data_file, &header)?;
//println!("{} {}", i, node.url);
let vector = bytemuck::cast_slice(&node.vector);
matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards));
}
matches.sort_unstable_by_key(|x| -x.1);
let mut matches = matches.into_iter().enumerate().map(|(i, (id, distance, url, shards))| (id, i)).collect::<Vec<_>>();
matches.sort_unstable();
/*for (id, distance, url, shards) in matches.iter().take(20) {
println!("brute force: {} {} {} {:?}", id, distance, url, shards);
}*/
let mut top_ranks = vec![usize::MAX; 20];
for shard in 0..header.shards.len() {
let selected_start = header.shards[shard].1;
let mut scratch = Scratch {
visited: HashSet::new(),
neighbour_buffer: NeighbourBuffer::new(300),
neighbour_pre_buffer: Vec::new(),
visited_list: Vec::new()
};
//let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng);
let cmps = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
data_file: &mut data_file,
header: &header,
pq_codes: &pq_codes,
pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
}, args.disable_pq)?;
if args.verbose {
println!("index scan {}: {:?} cmps", shard, cmps);
}
scratch.visited_list.sort_by_key(|x| -x.1);
for (i, (id, distance, url, shards)) in scratch.visited_list.iter().take(20).enumerate() {
if args.verbose {
println!("index scan: {} {} {} {:?}", id, distance, url, shards);
};
let found_id = match matches.binary_search(&(*id, 0)) {
Ok(pos) => pos,
Err(pos) => pos
};
if args.verbose {
println!("rank {}", matches[found_id].1);
};
top_ranks[i] = std::cmp::min(top_ranks[i], matches[found_id].1);
}
if args.verbose { println!("") }
}
top_rank_best_shard.push(top_ranks[0]);
top_20_ranks_best_shard.extend(top_ranks);
} }
let mut matches = vec![]; println!("ranks of top 20:");
// brute force scan summary_stats(&mut top_20_ranks_best_shard);
for i in 0..header.count { println!("ranks of top 1:");
let node = read_node(i, &mut data_file, &header)?; summary_stats(&mut top_rank_best_shard);
//println!("{} {}", i, node.url);
let vector = bytemuck::cast_slice(&node.vector);
matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards));
}
matches.sort_by_key(|x| -x.1);
for (id, distance, url, shards) in matches.iter().take(20) {
println!("brute force: {} {} {} {:?}", id, distance, url, shards);
}
Ok(()) Ok(())
} }