diff --git a/.sqlx/query-fccbb4262990c387079141c60a96d4e030ff82b93975f13d96539957b24f3c13.json b/.sqlx/query-fccbb4262990c387079141c60a96d4e030ff82b93975f13d96539957b24f3c13.json new file mode 100644 index 0000000..a90ef33 --- /dev/null +++ b/.sqlx/query-fccbb4262990c387079141c60a96d4e030ff82b93975f13d96539957b24f3c13.json @@ -0,0 +1,12 @@ +{ + "db_name": "SQLite", + "query": "INSERT OR REPLACE INTO files (filename, embedding_time, thumbnail_time) VALUES (?, ?, ?)", + "describe": { + "columns": [], + "parameters": { + "Right": 3 + }, + "nullable": [] + }, + "hash": "fccbb4262990c387079141c60a96d4e030ff82b93975f13d96539957b24f3c13" +} diff --git a/clipfront2/src/App.svelte b/clipfront2/src/App.svelte index b0fa151..6b6ca54 100644 --- a/clipfront2/src/App.svelte +++ b/clipfront2/src/App.svelte @@ -67,7 +67,7 @@ border: 1px solid gray * display: block - .result img + .result img, .result video width: 100% @@ -84,6 +84,7 @@
  • Capitalization is ignored.
  • Only English is supported. Other languages might work to some extent.
  • Sliders are generated from PCA on the index. The human-readable labels are approximate.
  • + Want your own deployment? Use the open-source code on GitHub.
  @@ -138,15 +139,21 @@ {#key `${queryCounter}${result.file}`}
    - - {#if util.hasFormat(results, result, "avifl")} - - {/if} - {#if util.hasFormat(results, result, "jpegl")} - - {/if} - {result[1]} - + {#if util.hasFormat(results, result, "VIDEO")} + + {:else} + + {#if util.hasFormat(results, result, "avifl")} + + {/if} + {#if util.hasFormat(results, result, "jpegl")} + + {/if} + {result[1]} + + {/if}
  {/key} @@ -240,7 +247,10 @@ let displayedResults = [] const runSearch = async () => { if (!resultPromise) { - let args = {"terms": queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))} + let args = { + "terms": queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] })), + "include_video": true + } queryCounter += 1 resultPromise = util.doQuery(args).then(res => { error = null diff --git a/src/common.rs b/src/common.rs index 86e705e..7b207df 100644 --- a/src/common.rs +++ b/src/common.rs @@ -12,18 +12,20 @@ pub struct InferenceServerConfig { pub embedding_size: usize, } +pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> { + let new = image.borrow().resize( + config.image_size.0, + config.image_size.1, + FilterType::Lanczos3 + ); + let mut buf = Vec::new(); + let mut csr = Cursor::new(&mut buf); + new.write_to(&mut csr, ImageFormat::Png)?; + Ok::<Vec<u8>, anyhow::Error>(buf) +} + pub async fn resize_for_embed<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> { - let resized = tokio::task::spawn_blocking(move || { - let new = image.borrow().resize( - config.image_size.0, - config.image_size.1, - FilterType::Lanczos3 - ); - let mut buf = Vec::new(); - let mut csr = Cursor::new(&mut buf); - new.write_to(&mut csr, ImageFormat::Png)?; - Ok::<Vec<u8>, anyhow::Error>(buf) - }).await??; + let resized = tokio::task::spawn_blocking(move || resize_for_embed_sync(config, image)).await??; Ok(resized) } diff --git a/src/main.rs b/src/main.rs index 1a6b740..d9e4e67 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::{collections::HashMap, io::Cursor}; use std::path::Path; use std::sync::Arc; @@ -12,11 +13,15 @@ use axum::{ Router, http::StatusCode }; +use common::resize_for_embed_sync; +use ffmpeg_the_third::device::input::video; +use image::RgbImage; use image::{imageops::FilterType, io::Reader as ImageReader, DynamicImage, ImageFormat}; use reqwest::Client; use serde::{Deserialize, Serialize}; +use sqlx::SqliteConnection; use sqlx::{sqlite::SqliteConnectOptions, SqlitePool}; -use tokio::sync::{broadcast, mpsc}; +use tokio::sync::{broadcast, mpsc, RwLock}; use tokio::task::JoinHandle; use walkdir::WalkDir; use base64::prelude::*; @@ -31,6 +36,7 @@ use ndarray::ArrayBase; mod ocr; mod common; +mod video_reader; use crate::ocr::scan_image; use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server}; @@ -41,6 +47,8 @@ lazy_static!
{ static ref TERMS_COUNTER: IntCounterVec = register_int_counter_vec!("mse_terms", "terms used in queries, by type", &["type"]).unwrap(); static ref IMAGES_LOADED_COUNTER: IntCounter = register_int_counter!("mse_loads", "images loaded by ingest process").unwrap(); static ref IMAGES_LOADED_ERROR_COUNTER: IntCounter = register_int_counter!("mse_load_errors", "image load fails by ingest process").unwrap(); + static ref VIDEOS_LOADED_COUNTER: IntCounter = register_int_counter!("mse_video_loads", "videos loaded by ingest process").unwrap(); + static ref VIDEOS_LOADED_ERROR_COUNTER: IntCounter = register_int_counter!("mse_video_load_errors", "video load fails by ingest process").unwrap(); static ref IMAGES_EMBEDDED_COUNTER: IntCounter = register_int_counter!("mse_embeds", "images embedded by ingest process").unwrap(); static ref IMAGES_OCRED_COUNTER: IntCounter = register_int_counter!("mse_ocrs", "images OCRed by ingest process").unwrap(); static ref IMAGES_OCRED_ERROR_COUNTER: IntCounter = register_int_counter!("mse_ocr_errors", "image OCR fails by ingest process").unwrap(); @@ -72,7 +80,7 @@ struct Config { #[derive(Debug)] struct IIndex { vectors: scalar_quantizer::ScalarQuantizerIndexImpl, - filenames: Vec<String>, + filenames: Vec<Filename>, format_codes: Vec<u64>, format_names: Vec<String>, } @@ -89,35 +97,20 @@ CREATE TABLE IF NOT EXISTS files ( thumbnails BLOB ); -CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 ( - filename, - ocr, - tokenize='unicode61 remove_diacritics 2', - content='files' -); - CREATE TABLE IF NOT EXISTS predefined_embeddings ( name TEXT NOT NULL PRIMARY KEY, embedding BLOB NOT NULL ); -CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN - INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, '')); -END; - -CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON files BEGIN - INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, '')); -END; - -CREATE TRIGGER IF NOT EXISTS ocr_fts_upd AFTER UPDATE ON files BEGIN - INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, '')); - INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, '')); -END; +DROP TRIGGER IF EXISTS ocr_fts_upd; +DROP TRIGGER IF EXISTS ocr_fts_ins; +DROP TRIGGER IF EXISTS ocr_fts_del; +DROP TABLE IF EXISTS ocr_fts; "#; -#[derive(Debug, sqlx::FromRow, Clone, Default)] -struct FileRecord { - filename: String, +#[derive(Debug, sqlx::FromRow, Clone)] +struct RawFileRecord { + filename: Vec<u8>, embedding_time: Option<i64>, ocr_time: Option<i64>, thumbnail_time: Option<i64>, @@ -128,6 +121,14 @@ struct FileRecord { thumbnails: Option<Vec<u8>>, } +#[derive(Debug, Clone)] +struct FileRecord { + filename: String, + needs_embed: bool, + needs_ocr: bool, + needs_thumbnail: bool +} + #[derive(Debug, Clone)] struct WConfig { backend: InferenceServerConfig, @@ -138,14 +139,49 @@ struct LoadedImage { image: Arc<DynamicImage>, - filename: String, - original_size: usize, + filename: Filename, + original_size: Option<usize>, + fast_thumbnails_only: bool +} + +#[derive(Debug, Clone, Serialize, Deserialize, Hash)] +enum Filename { + Actual(String), + VideoFrame(String, u64) +} + +// this is a somewhat horrible hack, but probably nobody has NUL bytes at the start of filenames?
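+// Encoding scheme (see Filename::encode/decode below): an Actual filename is stored as
+// its raw UTF-8 bytes, while a VideoFrame key is stored as a single 0x00 byte followed
+// by the msgpack serialization of the enum, so decode() can distinguish the two cases
+// by the first byte.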
+impl Filename { + fn decode(buf: Vec<u8>) -> Result<Self> { + Ok(match buf.strip_prefix(&[0]) { + Some(remainder) => rmp_serde::from_read(&*remainder)?, + None => Filename::Actual(String::from_utf8(buf)?.to_string()) + }) + } + + fn encode(&self) -> Result<Vec<u8>> { + match self { + Self::Actual(s) => Ok(s.to_string().into_bytes()), + x => { + let mut out = rmp_serde::to_vec(x).context("should not happen")?; + out.insert(0, 0); + Ok(out) + } + } + } + + fn container_filename(&self) -> String { + match self { + Self::Actual(s) => s.to_string(), + Self::VideoFrame(s, _) => s.to_string() + } + } } #[derive(Debug)] struct EmbeddingInput { image: Vec<u8>, - filename: String, + filename: Filename, } fn timestamp() -> i64 { @@ -155,21 +191,25 @@ #[derive(Debug, Clone)] struct ImageFormatConfig { target_width: u32, - target_filesize: u32, + target_filesize: usize, quality: u8, format: ImageFormat, extension: String, + is_fast: bool } -fn generate_filename_hash(filename: &str) -> String { +fn generate_filename_hash(filename: &Filename) -> String { use std::hash::{Hash, Hasher}; let mut hasher = fnv::FnvHasher::default(); - filename.hash(&mut hasher); + match filename { + Filename::Actual(x) => x.hash(&mut hasher), + _ => filename.hash(&mut hasher) + }; BASE64_URL_SAFE_NO_PAD.encode(hasher.finish().to_le_bytes()) } fn generate_thumbnail_filename( - filename: &str, + filename: &Filename, format_name: &str, format_config: &ImageFormatConfig, ) -> String { @@ -200,6 +240,7 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> { quality: 70, format: ImageFormat::Jpeg, extension: "jpg".to_string(), + is_fast: true }, ); formats.insert( @@ -210,6 +251,7 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> { quality: 80, format: ImageFormat::Jpeg, extension: "jpg".to_string(), + is_fast: true }, ); formats.insert( @@ -220,6 +262,7 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> { quality: 0, format: ImageFormat::Jpeg, extension: "jpg".to_string(), + is_fast: false }, ); formats.insert( @@ -230,6 +273,7 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> { quality: 80, format: ImageFormat::Avif, extension: "avif".to_string(), + is_fast: false }, ); formats.insert( @@ -240,11 +284,19 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> { quality: 70, format: ImageFormat::Avif, extension: "avif".to_string(), + is_fast: false }, ); formats } +async fn ensure_filename_record_exists(conn: &mut SqliteConnection, filename_enc: &Vec<u8>) -> Result<()> { + sqlx::query!("INSERT OR IGNORE INTO files (filename) VALUES (?)", filename_enc) + .execute(conn) + .await?; + Ok(()) +} + async fn ingest_files(config: Arc<WConfig>) -> Result<()> { let pool = initialize_database(&config.service).await?; let client = Client::new(); @@ -258,47 +310,89 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> { let cpus = num_cpus::get(); + let video_lengths = Arc::new(RwLock::new(HashMap::new())); + let video_thumb_times = Arc::new(RwLock::new(HashMap::new())); + let video_embed_times = Arc::new(RwLock::new(HashMap::new())); + // Image loading and preliminary resizing let image_loading: JoinHandle<Result<()>> = tokio::spawn({ let config = config.clone(); + let video_lengths = video_lengths.clone(); let stream = ReceiverStream::new(to_process_rx).map(Ok); stream.try_for_each_concurrent(Some(cpus), move |record| { let config = config.clone(); let to_embed_tx = to_embed_tx.clone(); let to_thumbnail_tx = to_thumbnail_tx.clone(); let to_ocr_tx = to_ocr_tx.clone(); + let video_lengths = video_lengths.clone(); async move { let path =
 Path::new(&config.service.files).join(&record.filename); let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?))); let image = match image { Ok(image) => image, Err(e) => { - log::error!("Could not read {}: {}", record.filename, e); + log::warn!("Could not read {} as image: {}", record.filename, e); + let filename = record.filename.clone(); IMAGES_LOADED_ERROR_COUNTER.inc(); + let video_length = tokio::task::spawn_blocking(move || -> Result<Option<u64>> { + let mut i = 0; + let callback = |frame: RgbImage| { + let frame: Arc<DynamicImage> = Arc::new(frame.into()); + let embed_buf = resize_for_embed_sync(config.backend.clone(), frame.clone())?; + to_embed_tx.blocking_send(EmbeddingInput { + image: embed_buf, + filename: Filename::VideoFrame(filename.clone(), i) + })?; + to_thumbnail_tx.blocking_send(LoadedImage { + image: frame.clone(), + filename: Filename::VideoFrame(filename.clone(), i), + original_size: None, + fast_thumbnails_only: true + })?; + i += 1; + Ok(()) + }; + match video_reader::run(&path, callback) { + Ok(()) => { + VIDEOS_LOADED_COUNTER.inc(); + return anyhow::Result::Ok(Some(i)) + }, + Err(e) => { + log::error!("Could not read {} as video: {}", filename, e); + VIDEOS_LOADED_ERROR_COUNTER.inc(); + } + } + return anyhow::Result::Ok(None) + }).await??; + if let Some(length) = video_length { + video_lengths.write().await.insert(record.filename, length); + } return Ok(()) } }; IMAGES_LOADED_COUNTER.inc(); - if record.embedding.is_none() { + if record.needs_embed { let resized = resize_for_embed(config.backend.clone(), image.clone()).await?; - to_embed_tx.send(EmbeddingInput { image: resized, filename: record.filename.clone() }).await? + to_embed_tx.send(EmbeddingInput { image: resized, filename: Filename::Actual(record.filename.clone()) }).await?
 } - if record.thumbnails.is_none() && config.service.enable_thumbs { + if record.needs_thumbnail { to_thumbnail_tx .send(LoadedImage { image: image.clone(), - filename: record.filename.clone(), - original_size: std::fs::metadata(&path)?.len() as usize, + filename: Filename::Actual(record.filename.clone()), + original_size: Some(std::fs::metadata(&path)?.len() as usize), + fast_thumbnails_only: false }) .await?; } - if record.raw_ocr_segments.is_none() && config.service.enable_ocr { + if record.needs_ocr { to_ocr_tx .send(LoadedImage { image, - filename: record.filename.clone(), - original_size: 0, + filename: Filename::Actual(record.filename.clone()), + original_size: None, + fast_thumbnails_only: true }) .await?; } @@ -313,6 +407,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> { let pool = pool.clone(); let stream = ReceiverStream::new(to_thumbnail_rx).map(Ok); let formats = Arc::new(formats); + let video_thumb_times = video_thumb_times.clone(); Some(tokio::spawn({ stream.try_for_each_concurrent(Some(cpus), move |image| { use image::codecs::*; let formats = formats.clone(); let config = config.clone(); let pool = pool.clone(); + let video_thumb_times = video_thumb_times.clone(); async move { let filename = image.filename.clone(); - log::debug!("thumbnailing {}", filename); + log::debug!("thumbnailing {:?}", filename); let generated_formats = tokio::task::spawn_blocking(move || { let mut generated_formats = Vec::new(); let rgb = DynamicImage::from(image.image.to_rgb8()); for (format_name, format_config) in &*formats { + if !format_config.is_fast && image.fast_thumbnails_only { continue } let resized = if format_config.target_filesize != 0 { let mut lb = 1; let mut ub = 100; @@ -345,7 +442,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> { ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, quality)), _ => unimplemented!() }?; - if buf.len() > image.original_size { + if buf.len() > format_config.target_filesize { ub = quality; } else { lb = quality + 1; @@ -370,7 +467,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> { }?; buf }; - if resized.len() < image.original_size { + if resized.len() < image.original_size.unwrap_or(usize::MAX) { generated_formats.push(format_name.clone()); let thumbnail_path = Path::new(&config.service.thumbs_path).join( generate_thumbnail_filename( @@ -388,13 +485,20 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> { IMAGES_THUMBNAILED_COUNTER.inc(); let formats_data = rmp_serde::to_vec(&generated_formats)?; let ts = timestamp(); + let filename_enc = filename.encode()?; + let mut conn = pool.acquire().await?; + ensure_filename_record_exists(&mut conn, &filename_enc).await?; + match filename { + Filename::VideoFrame(container, _) => { video_thumb_times.write().await.insert(container.to_string(), timestamp()); }, + _ => () + } sqlx::query!( "UPDATE files SET thumbnails = ?, thumbnail_time = ?
WHERE filename = ?", formats_data, ts, - filename + filename_enc ) - .execute(&pool) + .execute(&mut *conn) .await?; Ok(()) } @@ -405,6 +509,7 @@ async fn ingest_files(config: Arc) -> Result<()> { }; // OCR + // TODO: save OCR errors and don't retry let ocr: Option>> = if config.service.enable_ocr { let client = client.clone(); let pool = pool.clone(); @@ -414,12 +519,12 @@ async fn ingest_files(config: Arc) -> Result<()> { let client = client.clone(); let pool = pool.clone(); async move { - log::debug!("OCRing {}", image.filename); + log::debug!("OCRing {:?}", image.filename); let scan = match scan_image(&client, &image.image).await { Ok(scan) => scan, Err(e) => { IMAGES_OCRED_ERROR_COUNTER.inc(); - log::error!("OCR failure {}: {}", image.filename, e); + log::error!("OCR failure {:?}: {}", image.filename, e); return Ok(()) } }; @@ -431,14 +536,17 @@ async fn ingest_files(config: Arc) -> Result<()> { .join("\n"); let ocr_data = rmp_serde::to_vec(&scan)?; let ts = timestamp(); + let filename_enc = image.filename.encode()?; + let mut conn = pool.acquire().await?; + ensure_filename_record_exists(&mut conn, &filename_enc).await?; sqlx::query!( "UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?", ocr_text, ocr_data, ts, - image.filename + filename_enc ) - .execute(&pool) + .execute(&mut *conn) .await?; Ok(()) } @@ -453,11 +561,13 @@ async fn ingest_files(config: Arc) -> Result<()> { let client = client.clone(); let config = config.clone(); let pool = pool.clone(); + let video_embed_times = video_embed_times.clone(); // keep multiple embedding requests in flight stream.map(Ok).try_for_each_concurrent(Some(3), move |batch| { let client = client.clone(); let config = config.clone(); let pool = pool.clone(); + let video_embed_times = video_embed_times.clone(); async move { let result: Vec = query_clip_server( &client, @@ -472,13 +582,19 @@ async fn ingest_files(config: Arc) -> Result<()> { let ts = timestamp(); for (i, vector) in result.into_iter().enumerate() { let vector = vector.into_vec(); - log::debug!("embedded {}", batch[i].filename); + log::debug!("embedded {:?}", batch[i].filename); + let encoded_filename = batch[i].filename.encode()?; IMAGES_EMBEDDED_COUNTER.inc(); + ensure_filename_record_exists(&mut *tx, &encoded_filename).await?; + match &batch[i].filename { + Filename::VideoFrame(container, _) => { video_embed_times.write().await.insert(container.to_string(), timestamp()); }, + _ => () + } sqlx::query!( "UPDATE files SET embedding_time = ?, embedding = ? 
WHERE filename = ?", ts, vector, - batch[i].filename + encoded_filename ) .execute(&mut *tx) .await?; @@ -489,7 +605,7 @@ async fn ingest_files(config: Arc) -> Result<()> { }) }); - let mut filenames = HashMap::new(); + let mut actual_filenames = HashMap::new(); // blocking OS calls tokio::task::block_in_place(|| -> anyhow::Result<()> { @@ -500,7 +616,7 @@ async fn ingest_files(config: Arc) -> Result<()> { let filename = path.strip_prefix(&config.service.files)?.to_str().unwrap().to_string(); let modtime = entry.metadata()?.modified()?.duration_since(std::time::UNIX_EPOCH)?; let modtime = modtime.as_micros() as i64; - filenames.insert(filename.clone(), (path.to_path_buf(), modtime)); + actual_filenames.insert(filename.clone(), (path.to_path_buf(), modtime)); } } Ok(()) @@ -508,36 +624,35 @@ async fn ingest_files(config: Arc) -> Result<()> { log::debug!("finished reading filenames"); - for (filename, (_path, modtime)) in filenames.iter() { + for (filename, (_path, modtime)) in actual_filenames.iter() { let modtime = *modtime; - let record = sqlx::query_as!(FileRecord, "SELECT * FROM files WHERE filename = ?", filename) + let record = sqlx::query_as!(RawFileRecord, "SELECT * FROM files WHERE filename = ?", filename) .fetch_optional(&pool) .await?; let new_record = match record { None => Some(FileRecord { filename: filename.clone(), - ..Default::default() + needs_embed: true, + needs_ocr: true, + needs_thumbnail: true }), - Some(r) if modtime > r.embedding_time.unwrap_or(i64::MIN) || (modtime > r.ocr_time.unwrap_or(i64::MIN) && config.service.enable_ocr) || (modtime > r.thumbnail_time.unwrap_or(i64::MIN) && config.service.enable_thumbs) => { - Some(r) - }, - _ => None + Some(r) => { + let needs_embed = modtime > r.embedding_time.unwrap_or(i64::MIN); + let needs_ocr = modtime > r.ocr_time.unwrap_or(i64::MIN) && config.service.enable_ocr; + let needs_thumbnail = modtime > r.thumbnail_time.unwrap_or(i64::MIN) && config.service.enable_thumbs; + if needs_embed || needs_ocr || needs_thumbnail { + Some(FileRecord { + filename: filename.clone(), + needs_embed, needs_ocr, needs_thumbnail + }) + } else { + None + } + } }; - if let Some(mut record) = new_record { + if let Some(record) = new_record { log::debug!("processing {}", record.filename); - sqlx::query!("INSERT OR IGNORE INTO files (filename) VALUES (?)", filename) - .execute(&pool) - .await?; - if modtime > record.embedding_time.unwrap_or(i64::MIN) { - record.embedding = None; - } - if modtime > record.ocr_time.unwrap_or(i64::MIN) { - record.raw_ocr_segments = None; - } - if modtime > record.thumbnail_time.unwrap_or(i64::MIN) { - record.thumbnails = None; - } // we need to exit here to actually capture the error if !to_process_tx.send(record).await.is_ok() { break @@ -559,18 +674,43 @@ async fn ingest_files(config: Arc) -> Result<()> { image_loading.await?.context("loading images")?; - let stored: Vec = sqlx::query_scalar("SELECT filename FROM files").fetch_all(&pool).await?; + let stored: Vec> = sqlx::query_scalar("SELECT filename FROM files").fetch_all(&pool).await?; let mut tx = pool.begin().await?; + let video_lengths = video_lengths.read().await; for filename in stored { - if !filenames.contains_key(&filename) { - sqlx::query!("DELETE FROM files WHERE filename = ?", filename) - .execute(&mut *tx) - .await?; + let parsed_filename = Filename::decode(filename.clone())?; + match parsed_filename { + Filename::Actual(s) => if !actual_filenames.contains_key(&s) { + sqlx::query!("DELETE FROM files WHERE filename = ?", s) + .execute(&mut *tx) + 
 .await?; + }, + // This might fail if a video is replaced with a non-video file of the same name. Don't do that. + Filename::VideoFrame(container, frame) => if !actual_filenames.contains_key(&container) { + if let Some(length) = video_lengths.get(&container) { + if frame > *length { + sqlx::query!("DELETE FROM files WHERE filename = ?", filename) + .execute(&mut *tx) + .await?; + } + } + } } } + + let video_thumb_times = video_thumb_times.read().await; + let video_embed_times = video_embed_times.read().await; + for container_filename in video_lengths.keys() { + let embed_time = video_embed_times.get(container_filename); + let thumb_time = video_thumb_times.get(container_filename); + sqlx::query!("INSERT OR REPLACE INTO files (filename, embedding_time, thumbnail_time) VALUES (?, ?, ?)", container_filename, embed_time, thumb_time) + .execute(&mut *tx) + .await?; + } + tx.commit().await?; - log::info!("ingest done"); + log::info!("Ingest done"); Result::Ok(()) } @@ -595,11 +735,20 @@ async fn build_index(config: Arc<WConfig>) -> Result<IIndex> { index.format_codes = Vec::with_capacity(count as usize); let mut buffer = Vec::with_capacity(INDEX_ADD_BATCH * config.backend.embedding_size as usize); index.format_names = Vec::with_capacity(5); + index.format_names.push(String::from("VIDEO")); + let video_format_code = 1<<0; - let mut rows = sqlx::query_as::<_, FileRecord>("SELECT * FROM files").fetch(&pool); + let mut rows = sqlx::query_as::<_, RawFileRecord>("SELECT * FROM files").fetch(&pool); while let Some(record) = rows.try_next().await? { if let Some(emb) = record.embedding { - index.filenames.push(record.filename); + let parsed = Filename::decode(record.filename)?; + + let mut format_code = match parsed { + Filename::VideoFrame(_, _) => video_format_code, + _ => 0 + }; + + index.filenames.push(parsed); for i in (0..emb.len()).step_by(2) { buffer.push( half::f16::from_le_bytes([emb[i], emb[i + 1]]) @@ -615,8 +764,7 @@ async fn build_index(config: Arc<WConfig>) -> Result<IIndex> { if let Some(t) = record.thumbnails { formats = rmp_serde::from_slice(&t)?; } - - let mut format_code = 0; + for format_string in &formats { let mut found = false; for (i, name) in index.format_names.iter().enumerate() { @@ -670,19 +818,32 @@ struct QueryTerm { struct QueryRequest { terms: Vec<QueryTerm>, k: Option<usize>, + #[serde(default)] + include_video: bool } -async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize) -> Result { +async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result { let result = index.vectors.search(&query, k as usize)?; + let mut seen_videos = HashSet::new(); + let items = result.distances .into_iter() .zip(result.labels) .filter_map(|(distance, id)| { let id = id.get()?
 as usize; + match (video, &index.filenames[id]) { + (_, Filename::Actual(_)) => (), + (false, Filename::VideoFrame(_, _)) => return None, + (true, Filename::VideoFrame(container, _)) => { + if !seen_videos.insert(container) { + return None + } + } + } Some(( distance, - index.filenames[id].clone(), + index.filenames[id].container_filename(), generate_filename_hash(&index.filenames[id as usize]).clone(), index.format_codes[id] )) @@ -755,7 +916,7 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde } let k = req.k.unwrap_or(1000); - let qres = query_index(index, total_embedding.to_vec(), k).await?; + let qres = query_index(index, total_embedding.to_vec(), k, req.include_video).await?; let mut extensions = HashMap::new(); for (k, v) in image_formats(&config.service) { @@ -828,6 +989,7 @@ async fn main() -> Result<()> { Ok(new_index) => { LAST_INDEX_SIZE.set(new_index.vectors.ntotal() as i64); *index.write().await = new_index; + log::info!("Index loaded"); } Err(e) => { log::error!("Index build failed: {:?}", e); diff --git a/src/video_reader.rs b/src/video_reader.rs index ca9b2b8..7bf39db 100644 --- a/src/video_reader.rs +++ b/src/video_reader.rs @@ -2,10 +2,12 @@ extern crate ffmpeg_the_third as ffmpeg; use anyhow::{Result, Context}; use image::RgbImage; use std::env; -use ffmpeg::{codec, filter, format::{self, Pixel}, media::Type, software::scaling, util::frame::video::Video}; +use ffmpeg::{codec, filter, format::{self, Pixel}, media::Type, util::frame::video::Video}; -fn main() -> Result<()> { - let mut ictx = format::input(&env::args().nth(1).unwrap()).context("parsing video")?; +const BYTES_PER_PIXEL: usize = 3; + +pub fn run<P: AsRef<std::path::Path>, F: FnMut(RgbImage) -> Result<()>>(path: P, mut frame_callback: F) -> Result<()> { + let mut ictx = format::input(&path).context("parsing video")?; let video = ictx.streams().best(Type::Video).context("no video stream")?; let video_index = video.index(); @@ -15,31 +17,18 @@ fn main() -> Result<()> { let mut graph = filter::Graph::new(); let afr = video.avg_frame_rate(); let afr = (((afr.0 as f32) / (afr.1 as f32)).round() as i64).max(1); - // passing in the actual timebase breaks something, and the thumbnail filter should not need it graph.add(&filter::find("buffer").unwrap(), "in", &format!("video_size={}x{}:pix_fmt={}:time_base={}/{}:pixel_aspect={}/{}", decoder.width(), decoder.height(), decoder.format().descriptor().unwrap().name(), video.time_base().0, video.time_base().1, decoder.aspect_ratio().0, decoder.aspect_ratio().1))?; graph.add(&filter::find("buffersink").unwrap(), "out", "")?; - graph.output("in", 0)?.input("out", 0)?.parse(&format!("[in] thumbnail=n={} [thumbs]; [thumbs] select='gt(scene,0.05)+eq(n,0)' [out]", afr)).context("filtergraph parse failed")?; + graph.output("in", 0)?.input("out", 0)?.parse(&format!("[in] thumbnail=n={}:log=verbose [thumbs]; [thumbs] select='gt(scene,0.05)+eq(n,0)' [out]", afr)).context("filtergraph parse failed")?; let mut out = graph.get("out").unwrap(); - out.set_pixel_format(decoder.format()); + out.set_pixel_format(Pixel::RGB24); graph.validate().context("filtergraph build failed")?; - let mut scaler = scaling::Context::get( - decoder.format(), - decoder.width(), - decoder.height(), - Pixel::RGB24, - 384, - 384, - scaling::Flags::LANCZOS, - )?; - - let mut count = 0; let mut receive_and_process_decoded_frames = |decoder: &mut ffmpeg::decoder::Video, filter_graph: &mut filter::Graph| -> Result<()> { let mut decoded = Video::empty(); let mut filtered = Video::empty(); - let mut rgb_frame = Video::empty();
loop { if !decoder.receive_frame(&mut decoded).is_ok() { break } @@ -48,11 +37,15 @@ fn main() -> Result<()> { src.add(&decoded).context("add frame")?; while filter_graph.get("out").unwrap().sink().frame(&mut filtered).is_ok() { - scaler.run(&filtered, &mut rgb_frame).context("scaler")?; - println!("frame gotten {}x{} {:?} {}", rgb_frame.width(), rgb_frame.height(), rgb_frame.data(0).len(), count); - let image = RgbImage::from_vec(rgb_frame.width(), rgb_frame.height(), rgb_frame.data(0).to_vec()).unwrap(); // unfortunately, we have to copy - image.save(format!("/tmp/output-{}.png", count))?; - count += 1; + let mut image = vec![0u8; filtered.width() as usize * filtered.height() as usize * BYTES_PER_PIXEL]; + let stride = filtered.stride(0); + let data = filtered.data(0); + let width = filtered.width() as usize * BYTES_PER_PIXEL; + let height = filtered.height() as usize; + for y in 0..height { + image[y * width .. (y + 1) * width].copy_from_slice(&data[y * stride .. y * stride + width]); + } + frame_callback(image::ImageBuffer::from_vec(filtered.width(), filtered.height(), image).unwrap())?; } } Ok(()) @@ -61,11 +54,21 @@ fn main() -> Result<()> { for (stream, packet) in ictx.packets().filter_map(Result::ok) { if stream.index() == video_index { decoder.send_packet(&packet).context("decoder")?; - receive_and_process_decoded_frames(&mut decoder, &mut graph).context("processing")?; + receive_and_process_decoded_frames(&mut decoder, &mut graph).context("processing frame")?; } } decoder.send_eof()?; receive_and_process_decoded_frames(&mut decoder, &mut graph)?; Ok(()) +} + +fn main() -> Result<()> { + let mut count = 0; + let callback = |frame: RgbImage| { + frame.save(format!("/tmp/output-{}.png", count))?; + count += 1; + Ok(()) + }; + run(&env::args().nth(1).unwrap(), callback) } \ No newline at end of file
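For reference, a minimal round-trip check of the Filename encoding introduced in src/main.rs above. This is a hypothetical test module (not part of the diff) with made-up filenames; it relies only on the 0x00-prefix contract and rmp_serde's default enum encoding:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn filename_roundtrip() {
        // Plain filenames encode as their raw UTF-8 bytes, with no prefix.
        let plain = Filename::Actual("memes/cat.png".to_string());
        assert_eq!(plain.encode().unwrap(), b"memes/cat.png");

        // Video frame keys get a leading 0x00 byte and decode back losslessly.
        let frame = Filename::VideoFrame("memes/cat.mp4".to_string(), 3);
        let encoded = frame.encode().unwrap();
        assert_eq!(encoded[0], 0);
        match Filename::decode(encoded).unwrap() {
            Filename::VideoFrame(container, index) => {
                assert_eq!(container, "memes/cat.mp4");
                assert_eq!(index, 3);
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }
}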