1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-11-10 22:09:54 +00:00

Allow concurrent index queries, fix the database trigger typo yet again, and enable the image crate's avif-native feature

This commit is contained in:
osmarks 2024-05-22 18:25:50 +01:00
parent 349fe802f7
commit ce590298a7
3 changed files with 116 additions and 13 deletions

106
Cargo.lock generated
View File

@ -127,6 +127,20 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "av-data"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d75b98a3525d00f920df9a2d44cc99b9cc5b7dc70d7fbb612cd755270dbe6552"
dependencies = [
"byte-slice-cast",
"bytes",
"num-derive",
"num-rational",
"num-traits",
"thiserror",
]
[[package]]
name = "av1-grain"
version = "0.2.3"
@ -259,6 +273,15 @@ dependencies = [
"serde",
]
[[package]]
name = "bitreader"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdd859c9d97f7c468252795b35aeccc412bdbb1e90ee6969c4fa6328272eaeff"
dependencies = [
"cfg-if",
]
[[package]]
name = "bitstream-io"
version = "2.3.0"
@ -286,6 +309,12 @@ version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]]
name = "byte-slice-cast"
version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c"
[[package]]
name = "bytemuck"
version = "1.16.0"
@ -462,6 +491,38 @@ dependencies = [
"typenum",
]
[[package]]
name = "dav1d"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d4b54a40baf633a71c6f0fb49494a7e4ee7bc26f3e727212b6cb915aa1ea1e1"
dependencies = [
"av-data",
"bitflags 2.5.0",
"dav1d-sys",
"static_assertions",
]
[[package]]
name = "dav1d-sys"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ecb1c5e8f4dc438eedc1b534a54672fb0e0a56035dae6b50162787bd2c50e95"
dependencies = [
"libc",
"system-deps",
]
[[package]]
name = "dcv-color-primitives"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07ad62edfed069700a5b33af6babd29c498d7e33eb01d96ffa8841ee1841634c"
dependencies = [
"paste",
"wasm-bindgen",
]
[[package]]
name = "der"
version = "0.7.9"
@ -586,6 +647,15 @@ version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b9c008fc56422bf34357f17226d9c5a5c2ef6245b4774759c5f67112e46915e"
[[package]]
name = "fallible_collections"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a88c69768c0a15262df21899142bc6df9b9b823546d4b4b9a7bc2d6c448ec6fd"
dependencies = [
"hashbrown 0.13.2",
]
[[package]]
name = "fastrand"
version = "2.1.0"
@ -808,6 +878,15 @@ dependencies = [
"crunchy",
]
[[package]]
name = "hashbrown"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
dependencies = [
"ahash",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
@ -824,7 +903,7 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7"
dependencies = [
"hashbrown",
"hashbrown 0.14.5",
]
[[package]]
@ -1032,9 +1111,12 @@ dependencies = [
"bytemuck",
"byteorder",
"color_quant",
"dav1d",
"dcv-color-primitives",
"exr",
"gif",
"image-webp",
"mp4parse",
"num-traits",
"png",
"qoi",
@ -1069,7 +1151,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26"
dependencies = [
"equivalent",
"hashbrown",
"hashbrown 0.14.5",
]
[[package]]
@ -1357,6 +1439,20 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "mp4parse"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63a35203d3c6ce92d5251c77520acb2e57108c88728695aa883f70023624c570"
dependencies = [
"bitreader",
"byteorder",
"fallible_collections",
"log",
"num-traits",
"static_assertions",
]
[[package]]
name = "native-tls"
version = "0.2.11"
@ -2519,6 +2615,12 @@ dependencies = [
"urlencoding",
]
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "stringprep"
version = "0.1.4"

View File

@ -8,7 +8,7 @@ edition = "2021"
[dependencies]
tokio = { version = "1", features = ["full"] }
axum = "0.7"
image = { version = "0.25", features = ["avif"] }
image = { version = "0.25", features = ["avif", "avif-native"] }
reqwest = { version = "0.12", features = ["multipart"] }
serde = { version = "1", features = ["derive"] }
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] }

View File

@ -20,10 +20,11 @@ use tokio::sync::{broadcast, mpsc};
use tokio::task::JoinHandle;
use walkdir::WalkDir;
use base64::prelude::*;
use faiss::Index;
use faiss::{ConcurrentIndex, Index};
use futures_util::stream::{StreamExt, TryStreamExt};
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::CorsLayer;
use faiss::index::scalar_quantizer;
mod ocr;
@ -51,7 +52,7 @@ struct Config {
#[derive(Debug)]
struct IIndex {
vectors: faiss::index::IndexImpl,
vectors: scalar_quantizer::ScalarQuantizerIndexImpl,
filenames: Vec<String>,
format_codes: Vec<u64>,
format_names: Vec<String>,
@ -86,7 +87,7 @@ END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_upd AFTER UPDATE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
"#;
@ -590,7 +591,7 @@ async fn build_index(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -
let pool = initialize_database(&config).await?;
let mut index = IIndex {
vectors: faiss::index_factory(backend.embedding_size as u32, "SQfp16", faiss::MetricType::InnerProduct)?,
vectors: scalar_quantizer::ScalarQuantizerIndexImpl::new(backend.embedding_size as u32, scalar_quantizer::QuantizerType::QT_fp16, faiss::MetricType::InnerProduct)?,
filenames: Vec::new(),
format_codes: Vec::new(),
format_names: Vec::new(),
@ -680,7 +681,7 @@ struct QueryRequest {
k: Option<usize>,
}
async fn query_index(index: &mut IIndex, query: EmbeddingVector, k: usize) -> Result<QueryResult> {
async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize) -> Result<QueryResult> {
let result = index.vectors.search(&query, k as usize)?;
let items = result.distances
@ -708,7 +709,7 @@ async fn handle_request(
config: &Config,
backend_config: Arc<InferenceServerConfig>,
client: Arc<Client>,
index: &mut IIndex,
index: &IIndex,
req: Json<QueryRequest>,
) -> Result<Response<Body>> {
let mut total_embedding = ndarray::Array::from(vec![0.0; backend_config.embedding_size]);
@ -808,7 +809,7 @@ async fn main() -> Result<()> {
let (request_ingest_tx, mut request_ingest_rx) = mpsc::channel(1);
let index = Arc::new(tokio::sync::Mutex::new(build_index(config.clone(), backend.clone()).await?));
let index = Arc::new(tokio::sync::RwLock::new(build_index(config.clone(), backend.clone()).await?));
let (ingest_done_tx, _ingest_done_rx) = broadcast::channel(1);
let done_tx = Arc::new(ingest_done_tx.clone());
@ -824,7 +825,7 @@ async fn main() -> Result<()> {
Ok(_) => {
match build_index(config.clone(), backend.clone()).await {
Ok(new_index) => {
*index.lock().await = new_index;
*index.write().await = new_index;
}
Err(e) => {
log::error!("Index build failed: {:?}", e);
@ -851,9 +852,9 @@ async fn main() -> Result<()> {
.route("/", post(|req| async move {
let config = config.clone();
let backend_config = backend.clone();
let mut index = index.lock().await; // TODO: use ConcurrentIndex here
let index = index.read().await; // TODO: use ConcurrentIndex here
let client = client.clone();
handle_request(&config, backend_config, client.clone(), &mut index, req).await.map_err(|e| format!("{:?}", e))
handle_request(&config, backend_config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e))
}))
.route("/", get(|_req: axum::http::Request<axum::body::Body>| async move {
"OK"