From ce590298a7e4f93dfbadfe0584dfbeabbf063d30 Mon Sep 17 00:00:00 2001 From: osmarks Date: Wed, 22 May 2024 18:25:50 +0100 Subject: [PATCH] concurrent index queries and fix database typo yet again --- Cargo.lock | 106 +++++++++++++++++++++++++++++++++++++++++++++++++++- Cargo.toml | 2 +- src/main.rs | 21 ++++++----- 3 files changed, 116 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 543b422..0457eae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -127,6 +127,20 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "av-data" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d75b98a3525d00f920df9a2d44cc99b9cc5b7dc70d7fbb612cd755270dbe6552" +dependencies = [ + "byte-slice-cast", + "bytes", + "num-derive", + "num-rational", + "num-traits", + "thiserror", +] + [[package]] name = "av1-grain" version = "0.2.3" @@ -259,6 +273,15 @@ dependencies = [ "serde", ] +[[package]] +name = "bitreader" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdd859c9d97f7c468252795b35aeccc412bdbb1e90ee6969c4fa6328272eaeff" +dependencies = [ + "cfg-if", +] + [[package]] name = "bitstream-io" version = "2.3.0" @@ -286,6 +309,12 @@ version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +[[package]] +name = "byte-slice-cast" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" + [[package]] name = "bytemuck" version = "1.16.0" @@ -462,6 +491,38 @@ dependencies = [ "typenum", ] +[[package]] +name = "dav1d" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d4b54a40baf633a71c6f0fb49494a7e4ee7bc26f3e727212b6cb915aa1ea1e1" +dependencies = [ + "av-data", + "bitflags 2.5.0", + "dav1d-sys", + "static_assertions", +] + +[[package]] +name = "dav1d-sys" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ecb1c5e8f4dc438eedc1b534a54672fb0e0a56035dae6b50162787bd2c50e95" +dependencies = [ + "libc", + "system-deps", +] + +[[package]] +name = "dcv-color-primitives" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07ad62edfed069700a5b33af6babd29c498d7e33eb01d96ffa8841ee1841634c" +dependencies = [ + "paste", + "wasm-bindgen", +] + [[package]] name = "der" version = "0.7.9" @@ -586,6 +647,15 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b9c008fc56422bf34357f17226d9c5a5c2ef6245b4774759c5f67112e46915e" +[[package]] +name = "fallible_collections" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a88c69768c0a15262df21899142bc6df9b9b823546d4b4b9a7bc2d6c448ec6fd" +dependencies = [ + "hashbrown 0.13.2", +] + [[package]] name = "fastrand" version = "2.1.0" @@ -808,6 +878,15 @@ dependencies = [ "crunchy", ] +[[package]] +name = "hashbrown" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -824,7 +903,7 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown", + "hashbrown 0.14.5", ] [[package]] @@ -1032,9 +1111,12 @@ dependencies = [ "bytemuck", "byteorder", "color_quant", + "dav1d", + "dcv-color-primitives", "exr", "gif", "image-webp", + "mp4parse", "num-traits", "png", "qoi", @@ -1069,7 +1151,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.14.5", ] [[package]] @@ -1357,6 +1439,20 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mp4parse" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63a35203d3c6ce92d5251c77520acb2e57108c88728695aa883f70023624c570" +dependencies = [ + "bitreader", + "byteorder", + "fallible_collections", + "log", + "num-traits", + "static_assertions", +] + [[package]] name = "native-tls" version = "0.2.11" @@ -2519,6 +2615,12 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "stringprep" version = "0.1.4" diff --git a/Cargo.toml b/Cargo.toml index f84055e..dcc064e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ edition = "2021" [dependencies] tokio = { version = "1", features = ["full"] } axum = "0.7" -image = { version = "0.25", features = ["avif"] } +image = { version = "0.25", features = ["avif", "avif-native"] } reqwest = { version = "0.12", features = ["multipart"] } serde = { version = "1", features = ["derive"] } sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] } diff --git a/src/main.rs b/src/main.rs index fc0bcdd..89b03ce 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,10 +20,11 @@ use tokio::sync::{broadcast, mpsc}; use tokio::task::JoinHandle; use walkdir::WalkDir; use base64::prelude::*; -use faiss::Index; +use faiss::{ConcurrentIndex, Index}; use futures_util::stream::{StreamExt, TryStreamExt}; use tokio_stream::wrappers::ReceiverStream; use tower_http::cors::CorsLayer; +use faiss::index::scalar_quantizer; mod ocr; @@ -51,7 +52,7 @@ struct Config { #[derive(Debug)] struct IIndex { - vectors: faiss::index::IndexImpl, + vectors: scalar_quantizer::ScalarQuantizerIndexImpl, filenames: Vec, format_codes: Vec, format_names: Vec, @@ -86,7 +87,7 @@ END; CREATE TRIGGER IF NOT EXISTS ocr_fts_upd AFTER UPDATE ON files BEGIN INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, '')); - INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, COALESCE(new.ocr, '')); + INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, '')); END; "#; @@ -590,7 +591,7 @@ async fn build_index(config: Arc, backend: Arc) - let pool = initialize_database(&config).await?; let mut index = IIndex { - vectors: faiss::index_factory(backend.embedding_size as u32, "SQfp16", faiss::MetricType::InnerProduct)?, + vectors: scalar_quantizer::ScalarQuantizerIndexImpl::new(backend.embedding_size as u32, scalar_quantizer::QuantizerType::QT_fp16, faiss::MetricType::InnerProduct)?, filenames: Vec::new(), format_codes: Vec::new(), format_names: Vec::new(), @@ -680,7 +681,7 @@ struct QueryRequest { k: Option, } -async fn query_index(index: &mut IIndex, query: EmbeddingVector, k: usize) -> Result { +async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize) -> Result { let result = index.vectors.search(&query, k as usize)?; let items = result.distances @@ -708,7 +709,7 @@ async fn handle_request( config: &Config, backend_config: Arc, client: Arc, - index: &mut IIndex, + index: &IIndex, req: Json, ) -> Result> { let mut total_embedding = ndarray::Array::from(vec![0.0; backend_config.embedding_size]); @@ -808,7 +809,7 @@ async fn main() -> Result<()> { let (request_ingest_tx, mut request_ingest_rx) = mpsc::channel(1); - let index = Arc::new(tokio::sync::Mutex::new(build_index(config.clone(), backend.clone()).await?)); + let index = Arc::new(tokio::sync::RwLock::new(build_index(config.clone(), backend.clone()).await?)); let (ingest_done_tx, _ingest_done_rx) = broadcast::channel(1); let done_tx = Arc::new(ingest_done_tx.clone()); @@ -824,7 +825,7 @@ async fn main() -> Result<()> { Ok(_) => { match build_index(config.clone(), backend.clone()).await { Ok(new_index) => { - *index.lock().await = new_index; + *index.write().await = new_index; } Err(e) => { log::error!("Index build failed: {:?}", e); @@ -851,9 +852,9 @@ async fn main() -> Result<()> { .route("/", post(|req| async move { let config = config.clone(); let backend_config = backend.clone(); - let mut index = index.lock().await; // TODO: use ConcurrentIndex here + let index = index.read().await; // TODO: use ConcurrentIndex here let client = client.clone(); - handle_request(&config, backend_config, client.clone(), &mut index, req).await.map_err(|e| format!("{:?}", e)) + handle_request(&config, backend_config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e)) })) .route("/", get(|_req: axum::http::Request| async move { "OK"