use std::collections::HashSet; use std::{collections::HashMap, io::Cursor}; use std::path::Path; use std::sync::Arc; use anyhow::{Result, Context}; use axum::body::Body; use axum::response::Response; use axum::{ extract::{Json, DefaultBodyLimit}, response::IntoResponse, routing::{get, post}, Router, http::StatusCode }; use common::{resize_for_embed_sync, FrontendInit}; use compact_str::CompactString; use image::RgbImage; use image::{imageops::FilterType, ImageReader, DynamicImage, ImageFormat}; use reqwest::Client; use serde::{Deserialize, Serialize}; use sqlx::SqliteConnection; use sqlx::{sqlite::SqliteConnectOptions, SqlitePool}; use tokio::sync::{broadcast, mpsc, RwLock}; use tokio::task::JoinHandle; use walkdir::WalkDir; use faiss::{ConcurrentIndex, Index}; use futures_util::stream::{StreamExt, TryStreamExt}; use tokio_stream::wrappers::ReceiverStream; use tower_http::cors::CorsLayer; use faiss::index::scalar_quantizer; use lazy_static::lazy_static; use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec}; use tracing::instrument; use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine}; use mimalloc::MiMalloc; mod ocr; mod common; mod video_reader; use crate::ocr::scan_image; use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server, decode_fp16_buffer, QueryRequest, QueryResult, EmbeddingVector}; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; lazy_static! { static ref RELOADS_COUNTER: IntCounter = register_int_counter!("mse_reloads", "reloads executed").unwrap(); static ref QUERIES_COUNTER: IntCounter = register_int_counter!("mse_queries", "queries executed").unwrap(); static ref IMAGES_LOADED_COUNTER: IntCounter = register_int_counter!("mse_loads", "images loaded by ingest process").unwrap(); static ref IMAGES_LOADED_ERROR_COUNTER: IntCounter = register_int_counter!("mse_load_errors", "image load fails by ingest process").unwrap(); static ref VIDEOS_LOADED_COUNTER: IntCounter = register_int_counter!("mse_video_loads", "video loaded by ingest process").unwrap(); static ref VIDEOS_LOADED_ERROR_COUNTER: IntCounter = register_int_counter!("mse_video_load_errors", "video load fails by ingest process").unwrap(); static ref IMAGES_EMBEDDED_COUNTER: IntCounter = register_int_counter!("mse_embeds", "images embedded by ingest process").unwrap(); static ref IMAGES_OCRED_COUNTER: IntCounter = register_int_counter!("mse_ocrs", "images OCRed by ingest process").unwrap(); static ref IMAGES_OCRED_ERROR_COUNTER: IntCounter = register_int_counter!("mse_ocr_errors", "image OCR fails by ingest process").unwrap(); static ref IMAGES_THUMBNAILED_COUNTER: IntCounter = register_int_counter!("mse_thumbnails", "images thumbnailed by ingest process").unwrap(); static ref THUMBNAILS_GENERATED_COUNTER: IntCounterVec = register_int_counter_vec!("mse_thumbnail_outputs", "thumbnails produced by ingest process", &["output_format"]).unwrap(); static ref LAST_INDEX_SIZE: IntGauge = register_int_gauge!("mse_index_size", "images in loaded index").unwrap(); } fn function_which_returns_50() -> usize { 50 } fn function_which_will_return_the_integer_one_successor_of_zero_but_as_a_float() -> f32 { 1.0 } #[derive(Debug, Deserialize, Clone)] struct Config { clip_server: String, db_path: String, port: u16, files: String, #[serde(default)] enable_ocr: bool, #[serde(default)] thumbs_path: String, #[serde(default)] enable_thumbs: bool, #[serde(default="function_which_returns_50")] ocr_concurrency: usize, #[serde(default)] no_run_server: bool, #[serde(default="function_which_will_return_the_integer_one_successor_of_zero_but_as_a_float")] video_frame_interval: f32 } #[derive(Debug, Clone)] struct WConfig { backend: InferenceServerConfig, service: Config, predefined_embeddings: HashMap, ndarray::Dim<[usize; 1]>>> } #[derive(Debug)] struct IIndex { vectors: scalar_quantizer::ScalarQuantizerIndexImpl, filenames: Vec, format_codes: Vec, format_names: Vec, metadata: Vec> } const SCHEMA: &[&str] = &[ r#" CREATE TABLE IF NOT EXISTS files ( filename TEXT NOT NULL PRIMARY KEY, embedding_time INTEGER, ocr_time INTEGER, thumbnail_time INTEGER, embedding BLOB, ocr TEXT, raw_ocr_segments BLOB, thumbnails BLOB ); CREATE TABLE IF NOT EXISTS predefined_embeddings ( name TEXT NOT NULL PRIMARY KEY, embedding BLOB NOT NULL ); DROP TRIGGER IF EXISTS ocr_fts_upd; DROP TRIGGER IF EXISTS ocr_fts_ins; DROP TRIGGER IF EXISTS ocr_fts_del; DROP TABLE IF EXISTS ocr_fts; "#, r#" ALTER TABLE files ADD COLUMN metadata BLOB; "#]; #[derive(Debug, Clone, Serialize, Deserialize)] struct FileMetadata { width: u32, height: u32, frames: Option } #[derive(Debug, sqlx::FromRow, Clone)] struct RawFileRecord { filename: Vec, embedding_time: Option, ocr_time: Option, thumbnail_time: Option, embedding: Option>, // this totally "will" be used later ocr: Option, raw_ocr_segments: Option>, thumbnails: Option>, metadata: Option> } #[derive(Debug, Clone)] struct FileRecord { filename: CompactString, needs_embed: bool, needs_ocr: bool, needs_thumbnail: bool, needs_metadata: bool } #[derive(Debug)] struct LoadedImage { image: Arc, filename: Filename, original_filesize: Option, fast_thumbnails_only: bool } #[derive(Debug, Clone, Serialize, Deserialize, Hash)] enum Filename { Actual(CompactString), VideoFrame(CompactString, u32) } // this is a somewhat horrible hack, but probably nobody has NUL bytes at the start of filenames? impl Filename { fn decode(buf: Vec) -> Result { Ok(match buf.strip_prefix(&[0]) { Some(remainder) => rmp_serde::from_read(&*remainder)?, None => Filename::Actual(CompactString::from_utf8(buf)?) }) } fn encode(&self) -> Result> { match self { Self::Actual(s) => Ok(s.to_string().into_bytes()), x => { let mut out = rmp_serde::to_vec(x).context("should not happen")?; out.insert(0, 0); Ok(out) } } } fn container_filename(&self) -> String { match self { Self::Actual(s) => s.to_string(), Self::VideoFrame(s, _) => s.to_string() } } } #[derive(Debug)] struct EmbeddingInput { image: Vec, filename: Filename, } fn timestamp() -> i64 { chrono::Utc::now().timestamp_micros() } #[derive(Debug, Clone)] struct ImageFormatConfig { target_width: u32, target_filesize: usize, quality: u8, format: ImageFormat, extension: String, is_fast: bool } fn generate_filename_hash(filename: &Filename) -> String { use std::hash::{Hash, Hasher}; let mut hasher = fnv::FnvHasher::default(); match filename { Filename::Actual(x) => x.hash(&mut hasher), _ => filename.hash(&mut hasher) }; BASE64_URL_SAFE_NO_PAD.encode(hasher.finish().to_le_bytes()) } fn generate_thumbnail_filename( filename: &Filename, format_name: &str, format_config: &ImageFormatConfig, ) -> String { format!( "{}{}.{}", generate_filename_hash(filename), format_name, format_config.extension ) } async fn initialize_database(config: &Config) -> Result { let connection_options = SqliteConnectOptions::new() .filename(&config.db_path) .create_if_missing(true); let pool = SqlitePool::connect_with(connection_options).await?; let mut tx = pool.begin().await?; let version = sqlx::query_scalar!("PRAGMA user_version").fetch_one(&mut *tx).await?.unwrap(); for (index, sql) in SCHEMA.iter().enumerate() { if (index as i32) < version { continue } tracing::info!("Migrating to DB version {}", index); sqlx::query(sql).execute(&mut *tx).await?; sqlx::query(&format!("PRAGMA user_version = {}", index + 1)).execute(&mut *tx).await?; } tx.commit().await?; Ok(pool) } fn image_formats(_config: &Config) -> HashMap { let mut formats = HashMap::new(); formats.insert( "jpegl".to_string(), ImageFormatConfig { target_width: 800, target_filesize: 0, quality: 70, format: ImageFormat::Jpeg, extension: "jpg".to_string(), is_fast: true }, ); formats.insert( "jpegh".to_string(), ImageFormatConfig { target_width: 1600, target_filesize: 0, quality: 80, format: ImageFormat::Jpeg, extension: "jpg".to_string(), is_fast: true }, ); formats.insert( "jpeg256kb".to_string(), ImageFormatConfig { target_width: 500, target_filesize: 256000, quality: 0, format: ImageFormat::Jpeg, extension: "jpg".to_string(), is_fast: false }, ); formats.insert( "avifh".to_string(), ImageFormatConfig { target_width: 1600, target_filesize: 0, quality: 80, format: ImageFormat::Avif, extension: "avif".to_string(), is_fast: false }, ); formats.insert( "avifl".to_string(), ImageFormatConfig { target_width: 800, target_filesize: 0, quality: 70, format: ImageFormat::Avif, extension: "avif".to_string(), is_fast: false }, ); formats } #[instrument(skip_all)] async fn ensure_filename_record_exists(conn: &mut SqliteConnection, filename_enc: &Vec) -> Result<()> { sqlx::query!("INSERT OR IGNORE INTO files (filename) VALUES (?)", filename_enc) .execute(conn) .await?; Ok(()) } #[instrument(skip_all)] async fn write_metadata(conn: &mut SqliteConnection, filename_enc: &Vec, metadata: FileMetadata) -> Result<()> { ensure_filename_record_exists(conn, filename_enc).await?; let metadata_serialized = rmp_serde::to_vec_named(&metadata)?; sqlx::query!("UPDATE files SET metadata = ? WHERE filename = ?", metadata_serialized, filename_enc) .execute(conn) .await?; Ok(()) } #[instrument] async fn handle_embedding_batch(client: reqwest::Client, config: Arc, pool: SqlitePool, batch: Vec, video_embed_times: Arc>>) -> Result<()> { let result: Vec = query_clip_server( &client, &config.service.clip_server, "", EmbeddingRequest::Images { images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(), }, ).await.context("querying CLIP server")?; let mut tx = pool.begin().await?; let ts = timestamp(); for (i, vector) in result.into_iter().enumerate() { let vector = vector.into_vec(); tracing::debug!("embedded {:?}", batch[i].filename); let encoded_filename = batch[i].filename.encode()?; IMAGES_EMBEDDED_COUNTER.inc(); ensure_filename_record_exists(&mut *tx, &encoded_filename).await?; match &batch[i].filename { Filename::VideoFrame(container, _) => { video_embed_times.write().await.insert(container.clone(), timestamp()); }, _ => () } sqlx::query!( "UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?", ts, vector, encoded_filename ) .execute(&mut *tx) .await?; } tx.commit().await?; anyhow::Result::Ok(()) } #[instrument(skip(to_embed_tx, to_thumbnail_tx, to_ocr_tx, to_metadata_write_tx, video_meta))] async fn load_image(record: FileRecord, to_embed_tx: mpsc::Sender, to_thumbnail_tx: mpsc::Sender, to_ocr_tx: mpsc::Sender, to_metadata_write_tx: mpsc::Sender<(Filename, FileMetadata)>, config: Arc, video_meta: Arc>>) -> Result<()> { let path = Path::new(&config.service.files).join(&*record.filename); let image: Result> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?))); let image = match image { Ok(image) => image, Err(e) => { tracing::warn!("Could not read {} as image: {}", record.filename, e); let filename = record.filename.clone(); IMAGES_LOADED_ERROR_COUNTER.inc(); let meta = tokio::task::spawn_blocking(move || -> Result> { let mut i: u32 = 0; let mut last_metadata = None; let callback = |frame: RgbImage| { let frame: Arc = Arc::new(frame.into()); let embed_buf = resize_for_embed_sync(&config.backend, frame.clone())?; let filename = Filename::VideoFrame(filename.clone(), i); to_embed_tx.blocking_send(EmbeddingInput { image: embed_buf, filename: filename.clone() })?; let meta = FileMetadata { height: frame.height(), width: frame.width(), frames: Some(i + 1) }; last_metadata = Some(meta.clone()); to_metadata_write_tx.blocking_send((filename.clone(), meta))?; if config.service.enable_thumbs { to_thumbnail_tx.blocking_send(LoadedImage { image: frame.clone(), filename, original_filesize: None, fast_thumbnails_only: true })?; } i += 1; Ok(()) }; match video_reader::run(&path, callback, config.service.video_frame_interval) { Ok(()) => { VIDEOS_LOADED_COUNTER.inc(); return anyhow::Result::Ok(last_metadata) }, Err(e) => { tracing::error!("Could not read {} as video: {}", filename, e); VIDEOS_LOADED_ERROR_COUNTER.inc(); } } return anyhow::Result::Ok(last_metadata) }).await??; if let Some(meta) = meta { video_meta.write().await.insert(record.filename, meta); } return Ok(()) } }; let filename = Filename::Actual(record.filename); if record.needs_metadata { let metadata = FileMetadata { width: image.width(), height: image.height(), frames: None }; to_metadata_write_tx.send((filename.clone(), metadata)).await?; } IMAGES_LOADED_COUNTER.inc(); if record.needs_embed { let resized = resize_for_embed(config.backend.clone(), image.clone()).await?; to_embed_tx.send(EmbeddingInput { image: resized, filename: filename.clone() }).await? } if record.needs_thumbnail { to_thumbnail_tx .send(LoadedImage { image: image.clone(), filename: filename.clone(), original_filesize: Some(std::fs::metadata(&path)?.len() as usize), fast_thumbnails_only: false }) .await?; } if record.needs_ocr { to_ocr_tx .send(LoadedImage { image, filename: filename.clone(), original_filesize: None, fast_thumbnails_only: true }) .await?; } Ok(()) } #[instrument(skip(video_thumb_times, pool, formats))] async fn generate_thumbnail(image: LoadedImage, config: Arc, video_thumb_times: Arc>>, pool: SqlitePool, formats: Arc>) -> Result<()> { use image::codecs::*; let filename = image.filename.clone(); tracing::debug!("thumbnailing {:?}", filename); let generated_formats = tokio::task::spawn_blocking(move || { let mut generated_formats = Vec::new(); let rgb = DynamicImage::from(image.image.to_rgb8()); for (format_name, format_config) in &*formats { if !format_config.is_fast && image.fast_thumbnails_only { continue } let resized = if format_config.target_filesize != 0 { let mut lb = 1; let mut ub = 100; loop { let quality = (lb + ub) / 2; let thumbnail = rgb.resize( format_config.target_width.min(rgb.width()), u32::MAX, FilterType::Lanczos3, ); let mut buf: Vec = Vec::new(); let mut csr = Cursor::new(&mut buf); // this is ugly but I don't actually know how to fix it (cannot factor it out due to issues with dyn Trait) match format_config.format { ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, quality)), ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, quality)), _ => unimplemented!() }?; if buf.len() > format_config.target_filesize { ub = quality; } else { lb = quality + 1; } if lb >= ub { break buf; } } } else { let thumbnail = rgb.resize( format_config.target_width.min(rgb.width()), u32::MAX, FilterType::Lanczos3, ); let mut buf: Vec = Vec::new(); let mut csr = Cursor::new(&mut buf); match format_config.format { ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, format_config.quality)), ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, format_config.quality)), ImageFormat::WebP => thumbnail.write_with_encoder(webp::WebPEncoder::new_lossless(&mut csr)), _ => unimplemented!() }?; buf }; if resized.len() < image.original_filesize.unwrap_or(usize::MAX) { generated_formats.push(format_name.clone()); let thumbnail_path = Path::new(&config.service.thumbs_path).join( generate_thumbnail_filename( &image.filename, format_name, format_config, ), ); THUMBNAILS_GENERATED_COUNTER.get_metric_with_label_values(&[format_name]).unwrap().inc(); std::fs::write(thumbnail_path, resized)?; } } Ok::, anyhow::Error>(generated_formats) }).await??; IMAGES_THUMBNAILED_COUNTER.inc(); let formats_data = rmp_serde::to_vec(&generated_formats)?; let ts = timestamp(); let filename_enc = filename.encode()?; let mut conn = pool.acquire().await?; ensure_filename_record_exists(&mut conn, &filename_enc).await?; match filename { Filename::VideoFrame(container, _) => { video_thumb_times.write().await.insert(container.clone(), timestamp()); }, _ => () } sqlx::query!( "UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?", formats_data, ts, filename_enc ) .execute(&mut *conn) .await?; Ok(()) } #[instrument] async fn do_ocr(image: LoadedImage, config: Arc, client: Client, pool: SqlitePool) -> Result<()> { tracing::debug!("OCRing {:?}", image.filename); let scan = match scan_image(&client, &image.image).await { Ok(scan) => scan, Err(e) => { IMAGES_OCRED_ERROR_COUNTER.inc(); tracing::error!("OCR failure {:?}: {}", image.filename, e); return Ok(()) } }; IMAGES_OCRED_COUNTER.inc(); let ocr_text = scan .iter() .map(|segment| segment.text.clone()) .collect::>() .join("\n"); let ocr_data = rmp_serde::to_vec(&scan)?; let ts = timestamp(); let filename_enc = image.filename.encode()?; let mut conn = pool.acquire().await?; ensure_filename_record_exists(&mut conn, &filename_enc).await?; sqlx::query!( "UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?", ocr_text, ocr_data, ts, filename_enc ) .execute(&mut *conn) .await?; Ok(()) } #[instrument] async fn ingest_files(config: Arc) -> Result<()> { let pool = initialize_database(&config.service).await?; let client = Client::new(); let formats = image_formats(&config.service); let (to_process_tx, to_process_rx) = mpsc::channel::(100); let (to_embed_tx, to_embed_rx) = mpsc::channel(config.backend.batch as usize); let (to_thumbnail_tx, to_thumbnail_rx) = mpsc::channel(30); let (to_ocr_tx, to_ocr_rx) = mpsc::channel(30); let (to_metadata_write_tx, mut to_metadata_write_rx) = mpsc::channel::<(Filename, FileMetadata)>(100); let cpus = num_cpus::get(); let video_meta = Arc::new(RwLock::new(HashMap::new())); let video_thumb_times = Arc::new(RwLock::new(HashMap::new())); let video_embed_times = Arc::new(RwLock::new(HashMap::new())); // Image loading and preliminary resizing let image_loading: JoinHandle> = tokio::spawn({ let config = config.clone(); let video_meta = video_meta.clone(); let stream = ReceiverStream::new(to_process_rx).map(Ok); stream.try_for_each_concurrent(Some(cpus), move |record| { let config = config.clone(); let to_embed_tx = to_embed_tx.clone(); let to_thumbnail_tx = to_thumbnail_tx.clone(); let to_ocr_tx = to_ocr_tx.clone(); let video_meta = video_meta.clone(); let to_metadata_write_tx = to_metadata_write_tx.clone(); load_image(record, to_embed_tx, to_thumbnail_tx, to_ocr_tx, to_metadata_write_tx, config, video_meta) }) }); let metadata_writer: JoinHandle> = tokio::spawn({ let pool = pool.clone(); async move { while let Some((filename, metadata)) = to_metadata_write_rx.recv().await { write_metadata(&mut *pool.acquire().await?, &filename.encode()?, metadata).await?; } Ok(()) } }); let thumbnail_generation: Option>> = if config.service.enable_thumbs { let config = config.clone(); let pool = pool.clone(); let stream = ReceiverStream::new(to_thumbnail_rx).map(Ok); let formats = Arc::new(formats); let video_thumb_times = video_thumb_times.clone(); Some(tokio::spawn({ stream.try_for_each_concurrent(Some(cpus), move |image| { let formats = formats.clone(); let config = config.clone(); let pool = pool.clone(); let video_thumb_times = video_thumb_times.clone(); generate_thumbnail(image, config, video_thumb_times, pool, formats) }) })) } else { None }; // TODO: save OCR errors and don't retry let ocr: Option>> = if config.service.enable_ocr { let client = client.clone(); let pool = pool.clone(); let config = config.clone(); let stream = ReceiverStream::new(to_ocr_rx).map(Ok); Some(tokio::spawn({ stream.try_for_each_concurrent(Some(config.service.ocr_concurrency), move |image| { let client = client.clone(); let pool = pool.clone(); let config = config.clone(); do_ocr(image, config, client, pool) }) })) } else { None }; let embedding_generation: JoinHandle> = tokio::spawn({ let stream = ReceiverStream::new(to_embed_rx).chunks(config.backend.batch); let client = client.clone(); let config = config.clone(); let pool = pool.clone(); let video_embed_times = video_embed_times.clone(); // keep multiple embedding requests in flight stream.map(Ok).try_for_each_concurrent(Some(3), move |batch| { let client = client.clone(); let config = config.clone(); let pool = pool.clone(); let video_embed_times = video_embed_times.clone(); handle_embedding_batch(client, config, pool, batch, video_embed_times) }) }); let mut actual_filenames = HashMap::new(); // blocking OS calls tokio::task::block_in_place(|| -> anyhow::Result<()> { for entry in WalkDir::new(config.service.files.as_str()) { let entry = entry?; let path = entry.path(); if path.is_file() { let filename = CompactString::from(path.strip_prefix(&config.service.files)?.to_str().unwrap()); let modtime = entry.metadata()?.modified()?.duration_since(std::time::UNIX_EPOCH)?; let modtime = modtime.as_micros() as i64; actual_filenames.insert(filename.clone(), (path.to_path_buf(), modtime)); } } Ok(()) })?; tracing::debug!("finished reading filenames"); for (filename, (_path, modtime)) in actual_filenames.iter() { let modtime = *modtime; let filename_arr = filename.as_bytes(); let record = sqlx::query_as!(RawFileRecord, "SELECT * FROM files WHERE filename = ?", filename_arr) .fetch_optional(&pool) .await?; let new_record = match record { None => Some(FileRecord { filename: filename.clone(), needs_embed: true, needs_ocr: config.service.enable_ocr, needs_thumbnail: config.service.enable_thumbs, needs_metadata: true }), Some(r) => { let needs_embed = modtime > r.embedding_time.unwrap_or(i64::MIN); let needs_ocr = modtime > r.ocr_time.unwrap_or(i64::MIN) && config.service.enable_ocr; let needs_thumbnail = modtime > r.thumbnail_time.unwrap_or(i64::MIN) && config.service.enable_thumbs; let needs_metadata = modtime > r.embedding_time.unwrap_or(i64::MIN) || r.metadata.is_none(); // we don't store metadata acquisition time so assume it happens roughly when embedding does if needs_embed || needs_ocr || needs_thumbnail || needs_metadata { Some(FileRecord { filename: filename.clone(), needs_embed, needs_ocr, needs_thumbnail, needs_metadata }) } else { None } } }; if let Some(record) = new_record { tracing::debug!("processing {}", record.filename); // we need to exit here to actually capture the error if !to_process_tx.send(record).await.is_ok() { break } } } drop(to_process_tx); embedding_generation.await?.context("generating embeddings")?; metadata_writer.await?.context("writing metadata")?; if let Some(thumbnail_generation) = thumbnail_generation { thumbnail_generation.await?.context("generating thumbnails")?; } if let Some(ocr) = ocr { ocr.await?.context("OCRing")?; } image_loading.await?.context("loading images")?; let stored: Vec> = sqlx::query_scalar("SELECT filename FROM files").fetch_all(&pool).await?; let mut tx = pool.begin().await?; let video_meta = video_meta.read().await; for filename in stored { let parsed_filename = Filename::decode(filename.clone())?; match parsed_filename { Filename::Actual(s) => { let s = &*s; let raw = &filename; if !actual_filenames.contains_key(s) { sqlx::query!("DELETE FROM files WHERE filename = ?", raw) .execute(&mut *tx) .await?; } }, // This might fail in some cases where for whatever reason a video is replaced with a file of the same name which is not a video. Don't do that. Filename::VideoFrame(container, frame) => { // We don't necessarily have video lengths accessible, but any time a video is modified they will be available. if !actual_filenames.contains_key(&container) || frame > video_meta.get(&container).map(|x| x.frames.unwrap()).unwrap_or(u32::MAX) { sqlx::query!("DELETE FROM files WHERE filename = ?", filename) .execute(&mut *tx) .await?; } } } } let video_thumb_times = video_thumb_times.read().await; let video_embed_times = video_embed_times.read().await; for (container_filename, metadata) in video_meta.iter() { let embed_time = video_embed_times.get(container_filename); let thumb_time = video_thumb_times.get(container_filename); let container_filename: &[u8] = container_filename.as_bytes(); let metadata = rmp_serde::to_vec_named(metadata)?; sqlx::query!("INSERT OR REPLACE INTO files (filename, embedding_time, thumbnail_time, metadata) VALUES (?, ?, ?, ?)", container_filename, embed_time, thumb_time, metadata) .execute(&mut *tx) .await?; } tx.commit().await?; tracing::info!("Ingest done"); Result::Ok(()) } const INDEX_ADD_BATCH: usize = 1024; #[instrument] async fn build_index(config: Arc) -> Result { let pool = initialize_database(&config.service).await?; let mut index = IIndex { vectors: scalar_quantizer::ScalarQuantizerIndexImpl::new(config.backend.embedding_size as u32, scalar_quantizer::QuantizerType::QT_fp16, faiss::MetricType::InnerProduct)?, filenames: Vec::new(), format_codes: Vec::new(), format_names: Vec::new(), metadata: Vec::new() }; let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM files") .fetch_one(&pool) .await?; index.filenames = Vec::with_capacity(count as usize); index.format_codes = Vec::with_capacity(count as usize); let mut buffer = Vec::with_capacity(INDEX_ADD_BATCH * config.backend.embedding_size as usize); index.format_names = Vec::with_capacity(5); index.format_names.push(String::from("VIDEO")); let video_format_code = 1<<0; let mut rows = sqlx::query_as::<_, RawFileRecord>("SELECT * FROM files").fetch(&pool); while let Some(record) = rows.try_next().await? { if let Some(emb) = record.embedding { let parsed = Filename::decode(record.filename)?; let mut format_code = match parsed { Filename::VideoFrame(_, _) => video_format_code, _ => 0 }; index.filenames.push(parsed); for i in (0..emb.len()).step_by(2) { buffer.push( half::f16::from_le_bytes([emb[i], emb[i + 1]]) .to_f32(), ); } if buffer.len() == buffer.capacity() { index.vectors.add(&buffer)?; buffer.clear(); } let mut formats: Vec = Vec::new(); if let Some(t) = record.thumbnails { formats = rmp_serde::from_slice(&t)?; } if let Some(m) = record.metadata { index.metadata.push(Some(rmp_serde::from_slice(&m)?)); } else { index.metadata.push(None); } for format_string in &formats { let mut found = false; for (i, name) in index.format_names.iter().enumerate() { if name == format_string { format_code |= 1 << i; found = true; break; } } if !found { let new_index = index.format_names.len(); format_code |= 1 << new_index; index.format_names.push(format_string.clone()); } } index.format_codes.push(format_code); } } if !buffer.is_empty() { index.vectors.add(&buffer)?; } Ok(index) } #[instrument(skip(index))] async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result> { let result = index.vectors.search(&query, k as usize)?; let mut seen_videos = HashSet::new(); let items = result.distances .into_iter() .zip(result.labels) .filter_map(|(distance, id)| { let id = id.get()? as usize; match (video, &index.filenames[id]) { (_, Filename::Actual(_)) => (), (false, Filename::VideoFrame(_, _)) => return None, (true, Filename::VideoFrame(container, _)) => { if !seen_videos.insert(container) { return None } } } Some(( distance, index.filenames[id].container_filename(), generate_filename_hash(&index.filenames[id as usize]).clone(), index.format_codes[id], index.metadata[id].as_ref().map(|x| (x.width, x.height)), Option::<()>::None )) }) .collect(); Ok(QueryResult { matches: items, formats: index.format_names.clone(), extensions: HashMap::new(), }) } #[instrument(skip(config, client, index))] async fn handle_request(config: Arc, client: Arc, index: &IIndex, req: Json) -> Result> { let embedding = common::get_total_embedding( &req.terms, &config.backend, |batch, (config, client)| async move { query_clip_server(&client, &config.service.clip_server, "", batch).await }, |image, config| async move { let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&image))?); Ok(serde_bytes::ByteBuf::from(resize_for_embed(config.backend.clone(), image).await?)) }, &config.clone().predefined_embeddings, config.clone(), (config.clone(), client.clone())).await?; let k = req.k.unwrap_or(1000); let qres = query_index(index, embedding, k, req.include_video).await?; let mut extensions = HashMap::new(); for (k, v) in image_formats(&config.service) { extensions.insert(k, v.extension); } Ok(Json(QueryResult { matches: qres.matches, formats: qres.formats, extensions, }).into_response()) } #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt().init(); let config_path = std::env::args().nth(1).expect("Missing config file path"); let config: Config = serde_json::from_slice(&std::fs::read(config_path)?)?; let backend = get_backend_config(&config.clip_server).await; let mut predefined_embeddings = HashMap::new(); { let db = initialize_database(&config).await?; let result = sqlx::query!("SELECT * FROM predefined_embeddings") .fetch_all(&db).await?; for row in result { predefined_embeddings.insert(row.name, ndarray::Array::from(decode_fp16_buffer(&row.embedding))); } } let config = Arc::new(WConfig { service: config, backend, predefined_embeddings }); if config.service.no_run_server { ingest_files(config.clone()).await?; return Ok(()) } let (request_ingest_tx, mut request_ingest_rx) = mpsc::channel(1); let index = Arc::new(tokio::sync::RwLock::new(build_index(config.clone()).await?)); let (ingest_done_tx, _ingest_done_rx) = broadcast::channel(1); let done_tx = Arc::new(ingest_done_tx.clone()); let _ingest_task = tokio::spawn({ let config = config.clone(); let index = index.clone(); async move { loop { tracing::info!("Ingest running"); match ingest_files(config.clone()).await { Ok(_) => { match build_index(config.clone()).await { Ok(new_index) => { LAST_INDEX_SIZE.set(new_index.vectors.ntotal() as i64); *index.write().await = new_index; tracing::info!("Index loaded"); } Err(e) => { tracing::error!("Index build failed: {:?}", e); ingest_done_tx.send((false, format!("{:?}", e))).unwrap(); } } } Err(e) => { tracing::error!("Ingest failed: {:?}", e); ingest_done_tx.send((false, format!("{:?}", e))).unwrap(); } } ingest_done_tx.send((true, format!("OK"))).unwrap(); RELOADS_COUNTER.inc(); request_ingest_rx.recv().await; } } }); let cors = CorsLayer::permissive(); let config_ = config.clone(); let client = Arc::new(Client::new()); let index_ = index.clone(); let config__ = config.clone(); let app = Router::new() .route("/", post(|req| async move { let config = config.clone(); let index = index.read().await; // TODO: use ConcurrentIndex here let client = client.clone(); QUERIES_COUNTER.inc(); handle_request(config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e)) }).layer(DefaultBodyLimit::max(2<<24))) .route("/", get(|_req: ()| async move { Json(FrontendInit { n_total: index_.read().await.vectors.ntotal(), predefined_embedding_names: config__.predefined_embeddings.keys().cloned().collect(), d_emb: config__.backend.embedding_size }) })) .route("/reload", post(|_req: ()| async move { tracing::info!("Requesting index reload"); let mut done_rx = done_tx.clone().subscribe(); let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full match done_rx.recv().await { Ok((true, status)) => { let mut res = status.into_response(); *res.status_mut() = StatusCode::OK; res }, Ok((false, status)) => { let mut res = status.into_response(); *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; res }, Err(_) => { let mut res = "internal error".into_response(); *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; res } } })) .route("/metrics", get(|_req: ()| async move { let mut buffer = Vec::new(); let encoder = prometheus::TextEncoder::new(); let metric_families = prometheus::gather(); encoder.encode(&metric_families, &mut buffer).unwrap(); buffer })) .layer(cors); let addr = format!("0.0.0.0:{}", config_.service.port); tracing::info!("Starting server on {}", addr); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); axum::serve(listener, app).await?; Ok(()) }