mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-10 22:09:54 +00:00
concurrent index queries and fix database typo yet again
This commit is contained in:
parent
349fe802f7
commit
ce590298a7
106
Cargo.lock
generated
106
Cargo.lock
generated
@ -127,6 +127,20 @@ version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
|
||||
|
||||
[[package]]
|
||||
name = "av-data"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d75b98a3525d00f920df9a2d44cc99b9cc5b7dc70d7fbb612cd755270dbe6552"
|
||||
dependencies = [
|
||||
"byte-slice-cast",
|
||||
"bytes",
|
||||
"num-derive",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "av1-grain"
|
||||
version = "0.2.3"
|
||||
@ -259,6 +273,15 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitreader"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bdd859c9d97f7c468252795b35aeccc412bdbb1e90ee6969c4fa6328272eaeff"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitstream-io"
|
||||
version = "2.3.0"
|
||||
@ -286,6 +309,12 @@ version = "3.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
|
||||
|
||||
[[package]]
|
||||
name = "byte-slice-cast"
|
||||
version = "1.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.16.0"
|
||||
@ -462,6 +491,38 @@ dependencies = [
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dav1d"
|
||||
version = "0.10.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d4b54a40baf633a71c6f0fb49494a7e4ee7bc26f3e727212b6cb915aa1ea1e1"
|
||||
dependencies = [
|
||||
"av-data",
|
||||
"bitflags 2.5.0",
|
||||
"dav1d-sys",
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dav1d-sys"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ecb1c5e8f4dc438eedc1b534a54672fb0e0a56035dae6b50162787bd2c50e95"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"system-deps",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dcv-color-primitives"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07ad62edfed069700a5b33af6babd29c498d7e33eb01d96ffa8841ee1841634c"
|
||||
dependencies = [
|
||||
"paste",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "der"
|
||||
version = "0.7.9"
|
||||
@ -586,6 +647,15 @@ version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b9c008fc56422bf34357f17226d9c5a5c2ef6245b4774759c5f67112e46915e"
|
||||
|
||||
[[package]]
|
||||
name = "fallible_collections"
|
||||
version = "0.4.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a88c69768c0a15262df21899142bc6df9b9b823546d4b4b9a7bc2d6c448ec6fd"
|
||||
dependencies = [
|
||||
"hashbrown 0.13.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.1.0"
|
||||
@ -808,6 +878,15 @@ dependencies = [
|
||||
"crunchy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.5"
|
||||
@ -824,7 +903,7 @@ version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7"
|
||||
dependencies = [
|
||||
"hashbrown",
|
||||
"hashbrown 0.14.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1032,9 +1111,12 @@ dependencies = [
|
||||
"bytemuck",
|
||||
"byteorder",
|
||||
"color_quant",
|
||||
"dav1d",
|
||||
"dcv-color-primitives",
|
||||
"exr",
|
||||
"gif",
|
||||
"image-webp",
|
||||
"mp4parse",
|
||||
"num-traits",
|
||||
"png",
|
||||
"qoi",
|
||||
@ -1069,7 +1151,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown",
|
||||
"hashbrown 0.14.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1357,6 +1439,20 @@ dependencies = [
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mp4parse"
|
||||
version = "0.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "63a35203d3c6ce92d5251c77520acb2e57108c88728695aa883f70023624c570"
|
||||
dependencies = [
|
||||
"bitreader",
|
||||
"byteorder",
|
||||
"fallible_collections",
|
||||
"log",
|
||||
"num-traits",
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.11"
|
||||
@ -2519,6 +2615,12 @@ dependencies = [
|
||||
"urlencoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "static_assertions"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "stringprep"
|
||||
version = "0.1.4"
|
||||
|
@ -8,7 +8,7 @@ edition = "2021"
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
axum = "0.7"
|
||||
image = { version = "0.25", features = ["avif"] }
|
||||
image = { version = "0.25", features = ["avif", "avif-native"] }
|
||||
reqwest = { version = "0.12", features = ["multipart"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] }
|
||||
|
21
src/main.rs
21
src/main.rs
@ -20,10 +20,11 @@ use tokio::sync::{broadcast, mpsc};
|
||||
use tokio::task::JoinHandle;
|
||||
use walkdir::WalkDir;
|
||||
use base64::prelude::*;
|
||||
use faiss::Index;
|
||||
use faiss::{ConcurrentIndex, Index};
|
||||
use futures_util::stream::{StreamExt, TryStreamExt};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use faiss::index::scalar_quantizer;
|
||||
|
||||
mod ocr;
|
||||
|
||||
@ -51,7 +52,7 @@ struct Config {
|
||||
|
||||
#[derive(Debug)]
|
||||
struct IIndex {
|
||||
vectors: faiss::index::IndexImpl,
|
||||
vectors: scalar_quantizer::ScalarQuantizerIndexImpl,
|
||||
filenames: Vec<String>,
|
||||
format_codes: Vec<u64>,
|
||||
format_names: Vec<String>,
|
||||
@ -86,7 +87,7 @@ END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_upd AFTER UPDATE ON files BEGIN
|
||||
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
|
||||
INSERT INTO ocr_fts (rowid, filename, text) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
|
||||
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
|
||||
END;
|
||||
"#;
|
||||
|
||||
@ -590,7 +591,7 @@ async fn build_index(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -
|
||||
let pool = initialize_database(&config).await?;
|
||||
|
||||
let mut index = IIndex {
|
||||
vectors: faiss::index_factory(backend.embedding_size as u32, "SQfp16", faiss::MetricType::InnerProduct)?,
|
||||
vectors: scalar_quantizer::ScalarQuantizerIndexImpl::new(backend.embedding_size as u32, scalar_quantizer::QuantizerType::QT_fp16, faiss::MetricType::InnerProduct)?,
|
||||
filenames: Vec::new(),
|
||||
format_codes: Vec::new(),
|
||||
format_names: Vec::new(),
|
||||
@ -680,7 +681,7 @@ struct QueryRequest {
|
||||
k: Option<usize>,
|
||||
}
|
||||
|
||||
async fn query_index(index: &mut IIndex, query: EmbeddingVector, k: usize) -> Result<QueryResult> {
|
||||
async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize) -> Result<QueryResult> {
|
||||
let result = index.vectors.search(&query, k as usize)?;
|
||||
|
||||
let items = result.distances
|
||||
@ -708,7 +709,7 @@ async fn handle_request(
|
||||
config: &Config,
|
||||
backend_config: Arc<InferenceServerConfig>,
|
||||
client: Arc<Client>,
|
||||
index: &mut IIndex,
|
||||
index: &IIndex,
|
||||
req: Json<QueryRequest>,
|
||||
) -> Result<Response<Body>> {
|
||||
let mut total_embedding = ndarray::Array::from(vec![0.0; backend_config.embedding_size]);
|
||||
@ -808,7 +809,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
let (request_ingest_tx, mut request_ingest_rx) = mpsc::channel(1);
|
||||
|
||||
let index = Arc::new(tokio::sync::Mutex::new(build_index(config.clone(), backend.clone()).await?));
|
||||
let index = Arc::new(tokio::sync::RwLock::new(build_index(config.clone(), backend.clone()).await?));
|
||||
|
||||
let (ingest_done_tx, _ingest_done_rx) = broadcast::channel(1);
|
||||
let done_tx = Arc::new(ingest_done_tx.clone());
|
||||
@ -824,7 +825,7 @@ async fn main() -> Result<()> {
|
||||
Ok(_) => {
|
||||
match build_index(config.clone(), backend.clone()).await {
|
||||
Ok(new_index) => {
|
||||
*index.lock().await = new_index;
|
||||
*index.write().await = new_index;
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Index build failed: {:?}", e);
|
||||
@ -851,9 +852,9 @@ async fn main() -> Result<()> {
|
||||
.route("/", post(|req| async move {
|
||||
let config = config.clone();
|
||||
let backend_config = backend.clone();
|
||||
let mut index = index.lock().await; // TODO: use ConcurrentIndex here
|
||||
let index = index.read().await; // TODO: use ConcurrentIndex here
|
||||
let client = client.clone();
|
||||
handle_request(&config, backend_config, client.clone(), &mut index, req).await.map_err(|e| format!("{:?}", e))
|
||||
handle_request(&config, backend_config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e))
|
||||
}))
|
||||
.route("/", get(|_req: axum::http::Request<axum::body::Body>| async move {
|
||||
"OK"
|
||||
|
Loading…
Reference in New Issue
Block a user