Mirror of https://github.com/osmarks/meme-search-engine.git (synced 2025-02-08 07:00:06 +00:00)

improve observability and fix up Reddit dump for full-scale run

This commit is contained in:
parent 1d0ff95955
commit 7fa14d45ae
1359  Cargo.lock (generated)
File diff suppressed because it is too large. Load Diff

15  Cargo.toml
@@ -6,14 +6,13 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-tokio = { version = "1", features = ["full"] }
+tokio = { version = "1", features = ["full", "tracing"] }
 axum = "0.7"
 image = { version = "0.25", features = ["avif", "avif-native", "nasm"] }
 reqwest = { version = "0.12", features = ["multipart"] }
 serde = { version = "1", features = ["derive"] }
 sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] }
 walkdir = "1"
-log = "0.4"
 rmp-serde = "1"
 serde_json = "1"
 chrono = "0.4"
@@ -24,7 +23,8 @@ faiss = "0.12"
 ndarray = "0.15"
 half = { version = "2" }
 regex = "1"
-pretty_env_logger = "0.5"
+tracing = "0.1"
+console-subscriber = "0.4"
 futures-util = "0.3"
 tokio-stream = "0.1"
 num_cpus = "1"
@@ -41,9 +41,8 @@ mimalloc = "0.1"
 sonic-rs = "0.3"
 ffmpeg-the-third = "2.0"
 compact_str = { version = "0.8.0-beta", features = ["serde"] }
+itertools = "0.13"
-[patch.crates-io]
+async-recursion = "1"
-image = { git = "https://github.com/fintelia/image/", branch = "upgrade-zune-jpeg" }
 
 [[bin]]
 name = "reddit-dump"
@@ -52,3 +51,7 @@ path = "src/reddit_dump.rs"
 [[bin]]
 name = "video-reader"
 path = "src/video_reader.rs"
+
+[[bin]]
+name = "dump-processor"
+path = "src/dump_processor.rs"
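Aside (not part of the commit): the dependency swap above (log and pretty_env_logger out, tracing and console-subscriber in, plus tokio's "tracing" feature) is what the later console_subscriber::init() and #[instrument] changes rely on. Below is a minimal sketch of how those pieces typically fit together; do_work is a made-up function, and the tokio_unstable note reflects tokio-console's usual requirement rather than anything stated in this commit.

// Minimal sketch of the observability setup the new dependencies enable:
// console-subscriber feeds tokio-console, and #[instrument] creates a span
// per call with the arguments recorded as fields.
use tracing::instrument;

#[instrument] // span named "do_work", with `n` recorded as a field
async fn do_work(n: u64) {
    tracing::info!("working on {}", n);
    tokio::time::sleep(std::time::Duration::from_millis(n)).await;
}

#[tokio::main]
async fn main() {
    // Replaces pretty_env_logger::init(); tokio-console additionally needs the
    // tokio "tracing" feature and RUSTFLAGS="--cfg tokio_unstable" for task data.
    console_subscriber::init();
    do_work(5).await;
}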
src/common.rs
@@ -4,6 +4,7 @@ use image::{DynamicImage, imageops::FilterType, ImageFormat};
 use anyhow::Result;
 use std::io::Cursor;
 use reqwest::Client;
+use tracing::instrument;
 
 #[derive(Debug, Deserialize, Clone)]
 pub struct InferenceServerConfig {
@@ -13,11 +14,13 @@ pub struct InferenceServerConfig {
 }
 
 pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
-    let new = image.borrow().resize(
+    // the model currently in use wants aspect ratio 1:1 regardless of input
+    // I think this was previously being handled in the CLIP server but that is slightly lossy
+    let new = image.borrow().resize_exact(
         config.image_size.0,
         config.image_size.1,
-        FilterType::Lanczos3
-    );
+        FilterType::CatmullRom
+    ).into_rgb8();
     let mut buf = Vec::new();
     let mut csr = Cursor::new(&mut buf);
     new.write_to(&mut csr, ImageFormat::Png)?;
@@ -46,13 +49,14 @@ pub async fn get_backend_config(clip_server: &str) -> InferenceServerConfig {
         match fetch_backend_config(&clip_server).await {
             Ok(backend) => break backend,
             Err(e) => {
-                log::error!("Backend failed (fetch): {}", e);
+                tracing::error!("Backend failed (fetch): {}", e);
                 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
            }
        }
    }
 }
 
+#[instrument(skip(client, data))]
 pub async fn query_clip_server<I, O>(client: &Client, base_url: &str, path: &str, data: I) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned,
 {
     let response = client
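Aside (not part of the commit): the change above swaps resize for resize_exact and Lanczos3 for CatmullRom. In the image crate, resize fits the result inside the given bounds while preserving aspect ratio, whereas resize_exact stretches to exactly those dimensions, which is what a model expecting a fixed square input needs. A small illustrative sketch; the dimensions are arbitrary.

use image::{DynamicImage, imageops::FilterType};

fn demo(img: &DynamicImage) {
    // Suppose the source is 800x600 and the model wants 224x224.
    // `resize` preserves aspect ratio: the result is 224x168 (fits inside 224x224).
    let fitted = img.resize(224, 224, FilterType::CatmullRom);
    // `resize_exact` ignores aspect ratio: the result is exactly 224x224 (stretched).
    let exact = img.resize_exact(224, 224, FilterType::CatmullRom);
    assert_eq!((exact.width(), exact.height()), (224, 224));
    let _ = fitted;
}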
53  src/dump_processor.rs (new file)
@@ -0,0 +1,53 @@
+use anyhow::{Result, Context};
+use serde::{Serialize, Deserialize};
+use std::io::BufReader;
+use rmp_serde::decode::Error as DecodeError;
+use std::fs;
+
+// TODO refactor
+#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
+struct OriginalImageMetadata {
+    mime_type: String,
+    original_file_size: usize,
+    dimension: (u32, u32),
+    final_url: String
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+struct ProcessedEntry {
+    url: String,
+    id: String,
+    title: String,
+    subreddit: String,
+    author: String,
+    timestamp: u64,
+    #[serde(with = "serde_bytes")]
+    embedding: Vec<u8>,
+    metadata: OriginalImageMetadata
+}
+
+fn main() -> Result<()> {
+    let path = std::env::args().nth(1).context("missing path")?;
+    let stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
+    let mut stream = BufReader::new(stream);
+    let mut latest_timestamp = 0;
+    let mut count = 0;
+    loop {
+        let res: Result<ProcessedEntry, DecodeError> = rmp_serde::from_read(&mut stream);
+        if res.is_ok() {
+            count += 1;
+        }
+        match res {
+            Ok(x) => {
+                if x.timestamp > latest_timestamp {
+                    println!("{} {} https://reddit.com/r/{}/comments/{}", x.timestamp, count, x.subreddit, x.id);
+                    latest_timestamp = x.timestamp;
+                }
+            },
+            Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
+            Err(e) => return Err(e).context("decode fail")
+        }
+    }
+    println!("{} {}", latest_timestamp, count);
+    Ok(())
+}
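Aside (not part of the commit): ProcessedEntry marks the embedding with #[serde(with = "serde_bytes")]. With rmp-serde, a plain Vec<u8> serializes as a MessagePack array of individual integers, while serde_bytes routes it through serialize_bytes and produces the compact bin format, which matters when every record carries an embedding blob. A standalone sketch of the difference, assuming serde, serde_bytes and rmp-serde are available as they are elsewhere in this codebase; the struct names are made up.

use serde::Serialize;

#[derive(Serialize)]
struct PlainBlob {
    data: Vec<u8>, // encoded as a MessagePack array of integers
}

#[derive(Serialize)]
struct BinBlob {
    #[serde(with = "serde_bytes")]
    data: Vec<u8>, // encoded as a MessagePack bin payload
}

fn main() {
    let payload = vec![0xABu8; 1024];
    let plain = rmp_serde::to_vec(&PlainBlob { data: payload.clone() }).unwrap();
    let bin = rmp_serde::to_vec(&BinBlob { data: payload }).unwrap();
    // Bytes >= 128 cost two bytes each in the array encoding, but are stored
    // verbatim (plus a small header) in the bin encoding.
    assert!(bin.len() < plain.len());
}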
545  src/main.rs
@@ -33,6 +33,7 @@ use faiss::index::scalar_quantizer;
 use lazy_static::lazy_static;
 use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
 use ndarray::ArrayBase;
+use tracing::instrument;
 
 mod ocr;
 mod common;
@@ -249,7 +250,7 @@ async fn initialize_database(config: &Config) -> Result<SqlitePool> {
         if (index as i32) < version {
             continue
         }
-        log::info!("Migrating to DB version {}", index);
+        tracing::info!("Migrating to DB version {}", index);
         sqlx::query(sql).execute(&mut *tx).await?;
         sqlx::query(&format!("PRAGMA user_version = {}", index + 1)).execute(&mut *tx).await?;
     }
@@ -317,6 +318,7 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> {
     formats
 }
 
+#[instrument(skip_all)]
 async fn ensure_filename_record_exists(conn: &mut SqliteConnection, filename_enc: &Vec<u8>) -> Result<()> {
     sqlx::query!("INSERT OR IGNORE INTO files (filename) VALUES (?)", filename_enc)
         .execute(conn)
@@ -324,6 +326,7 @@ async fn ensure_filename_record_exists(conn: &mut SqliteConnection, filename_enc: &Vec<u8>) -> Result<()> {
     Ok(())
 }
 
+#[instrument(skip_all)]
 async fn write_metadata(conn: &mut SqliteConnection, filename_enc: &Vec<u8>, metadata: FileMetadata) -> Result<()> {
     ensure_filename_record_exists(conn, filename_enc).await?;
     let metadata_serialized = rmp_serde::to_vec_named(&metadata)?;
@@ -333,6 +336,264 @@ async fn write_metadata(conn: &mut SqliteConnection, filename_enc: &Vec<u8>, metadata: FileMetadata) -> Result<()> {
     Ok(())
 }
 
+#[instrument]
+async fn handle_embedding_batch(client: reqwest::Client, config: Arc<WConfig>, pool: SqlitePool, batch: Vec<EmbeddingInput>, video_embed_times: Arc<RwLock<HashMap<CompactString, i64>>>) -> Result<()> {
+    let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
+        &client,
+        &config.service.clip_server,
+        "",
+        EmbeddingRequest::Images {
+            images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
+        },
+    ).await.context("querying CLIP server")?;
+
+    let mut tx = pool.begin().await?;
+    let ts = timestamp();
+    for (i, vector) in result.into_iter().enumerate() {
+        let vector = vector.into_vec();
+        tracing::debug!("embedded {:?}", batch[i].filename);
+        let encoded_filename = batch[i].filename.encode()?;
+        IMAGES_EMBEDDED_COUNTER.inc();
+        ensure_filename_record_exists(&mut *tx, &encoded_filename).await?;
+        match &batch[i].filename {
+            Filename::VideoFrame(container, _) => { video_embed_times.write().await.insert(container.clone(), timestamp()); },
+            _ => ()
+        }
+        sqlx::query!(
+            "UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?",
+            ts,
+            vector,
+            encoded_filename
+        )
+            .execute(&mut *tx)
+            .await?;
+    }
+    tx.commit().await?;
+    anyhow::Result::Ok(())
+}
+
+#[instrument(skip(to_embed_tx, to_thumbnail_tx, to_ocr_tx, to_metadata_write_tx, video_meta))]
+async fn load_image(record: FileRecord, to_embed_tx: mpsc::Sender<EmbeddingInput>, to_thumbnail_tx: mpsc::Sender<LoadedImage>, to_ocr_tx: mpsc::Sender<LoadedImage>, to_metadata_write_tx: mpsc::Sender<(Filename, FileMetadata)>, config: Arc<WConfig>, video_meta: Arc<RwLock<HashMap<CompactString, FileMetadata>>>) -> Result<()> {
+    let path = Path::new(&config.service.files).join(&*record.filename);
+    let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?)));
+    let image = match image {
+        Ok(image) => image,
+        Err(e) => {
+            tracing::warn!("Could not read {} as image: {}", record.filename, e);
+            let filename = record.filename.clone();
+            IMAGES_LOADED_ERROR_COUNTER.inc();
+            let meta = tokio::task::spawn_blocking(move || -> Result<Option<FileMetadata>> {
+                let mut i: u32 = 0;
+                let mut last_metadata = None;
+                let callback = |frame: RgbImage| {
+                    let frame: Arc<DynamicImage> = Arc::new(frame.into());
+                    let embed_buf = resize_for_embed_sync(config.backend.clone(), frame.clone())?;
+                    let filename = Filename::VideoFrame(filename.clone(), i);
+                    to_embed_tx.blocking_send(EmbeddingInput {
+                        image: embed_buf,
+                        filename: filename.clone()
+                    })?;
+                    let meta = FileMetadata {
+                        height: frame.height(),
+                        width: frame.width(),
+                        frames: Some(i + 1)
+                    };
+                    last_metadata = Some(meta.clone());
+                    to_metadata_write_tx.blocking_send((filename.clone(), meta))?;
+                    if config.service.enable_thumbs {
+                        to_thumbnail_tx.blocking_send(LoadedImage {
+                            image: frame.clone(),
+                            filename,
+                            original_filesize: None,
+                            fast_thumbnails_only: true
+                        })?;
+                    }
+                    i += 1;
+                    Ok(())
+                };
+                match video_reader::run(&path, callback, config.service.video_frame_interval) {
+                    Ok(()) => {
+                        VIDEOS_LOADED_COUNTER.inc();
+                        return anyhow::Result::Ok(last_metadata)
+                    },
+                    Err(e) => {
+                        tracing::error!("Could not read {} as video: {}", filename, e);
+                        VIDEOS_LOADED_ERROR_COUNTER.inc();
+                    }
+                }
+                return anyhow::Result::Ok(last_metadata)
+            }).await??;
+            if let Some(meta) = meta {
+                video_meta.write().await.insert(record.filename, meta);
+            }
+            return Ok(())
+        }
+    };
+    let filename = Filename::Actual(record.filename);
+    if record.needs_metadata {
+        let metadata = FileMetadata {
+            width: image.width(),
+            height: image.height(),
+            frames: None
+        };
+        to_metadata_write_tx.send((filename.clone(), metadata)).await?;
+    }
+    IMAGES_LOADED_COUNTER.inc();
+    if record.needs_embed {
+        let resized = resize_for_embed(config.backend.clone(), image.clone()).await?;
+
+        to_embed_tx.send(EmbeddingInput { image: resized, filename: filename.clone() }).await?
+    }
+    if record.needs_thumbnail {
+        to_thumbnail_tx
+            .send(LoadedImage {
+                image: image.clone(),
+                filename: filename.clone(),
+                original_filesize: Some(std::fs::metadata(&path)?.len() as usize),
+                fast_thumbnails_only: false
+            })
+            .await?;
+    }
+    if record.needs_ocr {
+        to_ocr_tx
+            .send(LoadedImage {
+                image,
+                filename: filename.clone(),
+                original_filesize: None,
+                fast_thumbnails_only: true
+            })
+            .await?;
+    }
+    Ok(())
+}
+
+#[instrument(skip(video_thumb_times, pool, formats))]
+async fn generate_thumbnail(image: LoadedImage, config: Arc<WConfig>, video_thumb_times: Arc<RwLock<HashMap<CompactString, i64>>>, pool: SqlitePool, formats: Arc<HashMap<String, ImageFormatConfig>>) -> Result<()> {
+    use image::codecs::*;
+
+    let filename = image.filename.clone();
+    tracing::debug!("thumbnailing {:?}", filename);
+
+    let generated_formats = tokio::task::spawn_blocking(move || {
+        let mut generated_formats = Vec::new();
+        let rgb = DynamicImage::from(image.image.to_rgb8());
+        for (format_name, format_config) in &*formats {
+            if !format_config.is_fast && image.fast_thumbnails_only { continue }
+            let resized = if format_config.target_filesize != 0 {
+                let mut lb = 1;
+                let mut ub = 100;
+                loop {
+                    let quality = (lb + ub) / 2;
+                    let thumbnail = rgb.resize(
+                        format_config.target_width.min(rgb.width()),
+                        u32::MAX,
+                        FilterType::Lanczos3,
+                    );
+                    let mut buf: Vec<u8> = Vec::new();
+                    let mut csr = Cursor::new(&mut buf);
+                    // this is ugly but I don't actually know how to fix it (cannot factor it out due to issues with dyn Trait)
+                    match format_config.format {
+                        ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, quality)),
+                        ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, quality)),
+                        _ => unimplemented!()
+                    }?;
+                    if buf.len() > format_config.target_filesize {
+                        ub = quality;
+                    } else {
+                        lb = quality + 1;
+                    }
+                    if lb >= ub {
+                        break buf;
+                    }
+                }
+            } else {
+                let thumbnail = rgb.resize(
+                    format_config.target_width.min(rgb.width()),
+                    u32::MAX,
+                    FilterType::Lanczos3,
+                );
+                let mut buf: Vec<u8> = Vec::new();
+                let mut csr = Cursor::new(&mut buf);
+                match format_config.format {
+                    ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, format_config.quality)),
+                    ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, format_config.quality)),
+                    ImageFormat::WebP => thumbnail.write_with_encoder(webp::WebPEncoder::new_lossless(&mut csr)),
+                    _ => unimplemented!()
+                }?;
+                buf
+            };
+            if resized.len() < image.original_filesize.unwrap_or(usize::MAX) {
+                generated_formats.push(format_name.clone());
+                let thumbnail_path = Path::new(&config.service.thumbs_path).join(
+                    generate_thumbnail_filename(
+                        &image.filename,
+                        format_name,
+                        format_config,
+                    ),
+                );
+                THUMBNAILS_GENERATED_COUNTER.get_metric_with_label_values(&[format_name]).unwrap().inc();
+                std::fs::write(thumbnail_path, resized)?;
+            }
+        }
+        Ok::<Vec<String>, anyhow::Error>(generated_formats)
+    }).await??;
+
+    IMAGES_THUMBNAILED_COUNTER.inc();
+    let formats_data = rmp_serde::to_vec(&generated_formats)?;
+    let ts = timestamp();
+    let filename_enc = filename.encode()?;
+    let mut conn = pool.acquire().await?;
+    ensure_filename_record_exists(&mut conn, &filename_enc).await?;
+    match filename {
+        Filename::VideoFrame(container, _) => { video_thumb_times.write().await.insert(container.clone(), timestamp()); },
+        _ => ()
+    }
+    sqlx::query!(
+        "UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?",
+        formats_data,
+        ts,
+        filename_enc
+    )
+        .execute(&mut *conn)
+        .await?;
+    Ok(())
+}
+
+#[instrument]
+async fn do_ocr(image: LoadedImage, config: Arc<WConfig>, client: Client, pool: SqlitePool) -> Result<()> {
+    tracing::debug!("OCRing {:?}", image.filename);
+    let scan = match scan_image(&client, &image.image).await {
+        Ok(scan) => scan,
+        Err(e) => {
+            IMAGES_OCRED_ERROR_COUNTER.inc();
+            tracing::error!("OCR failure {:?}: {}", image.filename, e);
+            return Ok(())
+        }
+    };
+    IMAGES_OCRED_COUNTER.inc();
+    let ocr_text = scan
+        .iter()
+        .map(|segment| segment.text.clone())
+        .collect::<Vec<_>>()
+        .join("\n");
+    let ocr_data = rmp_serde::to_vec(&scan)?;
+    let ts = timestamp();
+    let filename_enc = image.filename.encode()?;
+    let mut conn = pool.acquire().await?;
+    ensure_filename_record_exists(&mut conn, &filename_enc).await?;
+    sqlx::query!(
+        "UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?",
+        ocr_text,
+        ocr_data,
+        ts,
+        filename_enc
+    )
+        .execute(&mut *conn)
+        .await?;
+    Ok(())
+}
+
+#[instrument]
 async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
     let pool = initialize_database(&config.service).await?;
     let client = Client::new();
@@ -363,99 +624,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
             let to_ocr_tx = to_ocr_tx.clone();
             let video_meta = video_meta.clone();
             let to_metadata_write_tx = to_metadata_write_tx.clone();
-            async move {
-                let path = Path::new(&config.service.files).join(&*record.filename);
-                let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?)));
-                let image = match image {
-                    Ok(image) => image,
-                    Err(e) => {
-                        log::warn!("Could not read {} as image: {}", record.filename, e);
-                        let filename = record.filename.clone();
-                        IMAGES_LOADED_ERROR_COUNTER.inc();
-                        let meta = tokio::task::spawn_blocking(move || -> Result<Option<FileMetadata>> {
-                            let mut i: u32 = 0;
-                            let mut last_metadata = None;
-                            let callback = |frame: RgbImage| {
-                                let frame: Arc<DynamicImage> = Arc::new(frame.into());
-                                let embed_buf = resize_for_embed_sync(config.backend.clone(), frame.clone())?;
-                                let filename = Filename::VideoFrame(filename.clone(), i);
-                                to_embed_tx.blocking_send(EmbeddingInput {
-                                    image: embed_buf,
-                                    filename: filename.clone()
-                                })?;
-                                let meta = FileMetadata {
-                                    height: frame.height(),
-                                    width: frame.width(),
-                                    frames: Some(i + 1)
-                                };
-                                last_metadata = Some(meta.clone());
-                                to_metadata_write_tx.blocking_send((filename.clone(), meta))?;
-                                if config.service.enable_thumbs {
-                                    to_thumbnail_tx.blocking_send(LoadedImage {
-                                        image: frame.clone(),
-                                        filename,
-                                        original_filesize: None,
-                                        fast_thumbnails_only: true
-                                    })?;
-                                }
-                                i += 1;
-                                Ok(())
-                            };
-                            match video_reader::run(&path, callback, config.service.video_frame_interval) {
-                                Ok(()) => {
-                                    VIDEOS_LOADED_COUNTER.inc();
-                                    return anyhow::Result::Ok(last_metadata)
-                                },
-                                Err(e) => {
-                                    log::error!("Could not read {} as video: {}", filename, e);
-                                    VIDEOS_LOADED_ERROR_COUNTER.inc();
-                                }
-                            }
-                            return anyhow::Result::Ok(last_metadata)
-                        }).await??;
-                        if let Some(meta) = meta {
-                            video_meta.write().await.insert(record.filename, meta);
-                        }
-                        return Ok(())
-                    }
-                };
-                let filename = Filename::Actual(record.filename);
-                if record.needs_metadata {
-                    let metadata = FileMetadata {
-                        width: image.width(),
-                        height: image.height(),
-                        frames: None
-                    };
-                    to_metadata_write_tx.send((filename.clone(), metadata)).await?;
-                }
-                IMAGES_LOADED_COUNTER.inc();
-                if record.needs_embed {
-                    let resized = resize_for_embed(config.backend.clone(), image.clone()).await?;
-
-                    to_embed_tx.send(EmbeddingInput { image: resized, filename: filename.clone() }).await?
-                }
-                if record.needs_thumbnail {
-                    to_thumbnail_tx
-                        .send(LoadedImage {
-                            image: image.clone(),
-                            filename: filename.clone(),
-                            original_filesize: Some(std::fs::metadata(&path)?.len() as usize),
-                            fast_thumbnails_only: false
-                        })
-                        .await?;
-                }
-                if record.needs_ocr {
-                    to_ocr_tx
-                        .send(LoadedImage {
-                            image,
-                            filename: filename.clone(),
-                            original_filesize: None,
-                            fast_thumbnails_only: true
-                        })
-                        .await?;
-                }
-                Ok(())
-            }
+            load_image(record, to_embed_tx, to_thumbnail_tx, to_ocr_tx, to_metadata_write_tx, config, video_meta)
         })
     });
 
@@ -477,98 +646,11 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
         let video_thumb_times = video_thumb_times.clone();
         Some(tokio::spawn({
             stream.try_for_each_concurrent(Some(cpus), move |image| {
-                use image::codecs::*;
-
                 let formats = formats.clone();
                 let config = config.clone();
                 let pool = pool.clone();
                 let video_thumb_times = video_thumb_times.clone();
-                async move {
-                    let filename = image.filename.clone();
-                    log::debug!("thumbnailing {:?}", filename);
-                    let generated_formats = tokio::task::spawn_blocking(move || {
-                        let mut generated_formats = Vec::new();
-                        let rgb = DynamicImage::from(image.image.to_rgb8());
-                        for (format_name, format_config) in &*formats {
-                            if !format_config.is_fast && image.fast_thumbnails_only { continue }
-                            let resized = if format_config.target_filesize != 0 {
-                                let mut lb = 1;
-                                let mut ub = 100;
-                                loop {
-                                    let quality = (lb + ub) / 2;
-                                    let thumbnail = rgb.resize(
-                                        format_config.target_width.min(rgb.width()),
-                                        u32::MAX,
-                                        FilterType::Lanczos3,
-                                    );
-                                    let mut buf: Vec<u8> = Vec::new();
-                                    let mut csr = Cursor::new(&mut buf);
-                                    // this is ugly but I don't actually know how to fix it (cannot factor it out due to issues with dyn Trait)
-                                    match format_config.format {
-                                        ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, quality)),
-                                        ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, quality)),
-                                        _ => unimplemented!()
-                                    }?;
-                                    if buf.len() > format_config.target_filesize {
-                                        ub = quality;
-                                    } else {
-                                        lb = quality + 1;
-                                    }
-                                    if lb >= ub {
-                                        break buf;
-                                    }
-                                }
-                            } else {
-                                let thumbnail = rgb.resize(
-                                    format_config.target_width.min(rgb.width()),
-                                    u32::MAX,
-                                    FilterType::Lanczos3,
-                                );
-                                let mut buf: Vec<u8> = Vec::new();
-                                let mut csr = Cursor::new(&mut buf);
-                                match format_config.format {
-                                    ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, format_config.quality)),
-                                    ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, format_config.quality)),
-                                    ImageFormat::WebP => thumbnail.write_with_encoder(webp::WebPEncoder::new_lossless(&mut csr)),
-                                    _ => unimplemented!()
-                                }?;
-                                buf
-                            };
-                            if resized.len() < image.original_filesize.unwrap_or(usize::MAX) {
-                                generated_formats.push(format_name.clone());
-                                let thumbnail_path = Path::new(&config.service.thumbs_path).join(
-                                    generate_thumbnail_filename(
-                                        &image.filename,
-                                        format_name,
-                                        format_config,
-                                    ),
-                                );
-                                THUMBNAILS_GENERATED_COUNTER.get_metric_with_label_values(&[format_name]).unwrap().inc();
-                                std::fs::write(thumbnail_path, resized)?;
-                            }
-                        }
-                        Ok::<Vec<String>, anyhow::Error>(generated_formats)
-                    }).await??;
-                    IMAGES_THUMBNAILED_COUNTER.inc();
-                    let formats_data = rmp_serde::to_vec(&generated_formats)?;
-                    let ts = timestamp();
-                    let filename_enc = filename.encode()?;
-                    let mut conn = pool.acquire().await?;
-                    ensure_filename_record_exists(&mut conn, &filename_enc).await?;
-                    match filename {
-                        Filename::VideoFrame(container, _) => { video_thumb_times.write().await.insert(container.clone(), timestamp()); },
-                        _ => ()
-                    }
-                    sqlx::query!(
-                        "UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?",
-                        formats_data,
-                        ts,
-                        filename_enc
-                    )
-                        .execute(&mut *conn)
-                        .await?;
-                    Ok(())
-                }
+                generate_thumbnail(image, config, video_thumb_times, pool, formats)
             })
         }))
     } else {
@@ -579,43 +661,14 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
     let ocr: Option<JoinHandle<Result<()>>> = if config.service.enable_ocr {
         let client = client.clone();
         let pool = pool.clone();
+        let config = config.clone();
         let stream = ReceiverStream::new(to_ocr_rx).map(Ok);
         Some(tokio::spawn({
             stream.try_for_each_concurrent(Some(config.service.ocr_concurrency), move |image| {
                 let client = client.clone();
                 let pool = pool.clone();
-                async move {
-                    log::debug!("OCRing {:?}", image.filename);
-                    let scan = match scan_image(&client, &image.image).await {
-                        Ok(scan) => scan,
-                        Err(e) => {
-                            IMAGES_OCRED_ERROR_COUNTER.inc();
-                            log::error!("OCR failure {:?}: {}", image.filename, e);
-                            return Ok(())
-                        }
-                    };
-                    IMAGES_OCRED_COUNTER.inc();
-                    let ocr_text = scan
-                        .iter()
-                        .map(|segment| segment.text.clone())
-                        .collect::<Vec<_>>()
-                        .join("\n");
-                    let ocr_data = rmp_serde::to_vec(&scan)?;
-                    let ts = timestamp();
-                    let filename_enc = image.filename.encode()?;
-                    let mut conn = pool.acquire().await?;
-                    ensure_filename_record_exists(&mut conn, &filename_enc).await?;
-                    sqlx::query!(
-                        "UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?",
-                        ocr_text,
-                        ocr_data,
-                        ts,
-                        filename_enc
-                    )
-                        .execute(&mut *conn)
-                        .await?;
-                    Ok(())
-                }
+                let config = config.clone();
+                do_ocr(image, config, client, pool)
             })
         }))
     } else {
@@ -634,40 +687,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
             let config = config.clone();
             let pool = pool.clone();
             let video_embed_times = video_embed_times.clone();
-            async move {
-                let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
-                    &client,
-                    &config.service.clip_server,
-                    "",
-                    EmbeddingRequest::Images {
-                        images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
-                    },
-                ).await.context("querying CLIP server")?;
-
-                let mut tx = pool.begin().await?;
-                let ts = timestamp();
-                for (i, vector) in result.into_iter().enumerate() {
-                    let vector = vector.into_vec();
-                    log::debug!("embedded {:?}", batch[i].filename);
-                    let encoded_filename = batch[i].filename.encode()?;
-                    IMAGES_EMBEDDED_COUNTER.inc();
-                    ensure_filename_record_exists(&mut *tx, &encoded_filename).await?;
-                    match &batch[i].filename {
-                        Filename::VideoFrame(container, _) => { video_embed_times.write().await.insert(container.clone(), timestamp()); },
-                        _ => ()
-                    }
-                    sqlx::query!(
-                        "UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?",
-                        ts,
-                        vector,
-                        encoded_filename
-                    )
-                        .execute(&mut *tx)
-                        .await?;
-                }
-                tx.commit().await?;
-                anyhow::Result::Ok(())
-            }
+            handle_embedding_batch(client, config, pool, batch, video_embed_times)
         })
     });
 
@@ -688,7 +708,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
         Ok(())
     })?;
 
-    log::debug!("finished reading filenames");
+    tracing::debug!("finished reading filenames");
 
     for (filename, (_path, modtime)) in actual_filenames.iter() {
         let modtime = *modtime;
@@ -721,7 +741,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
             }
         };
         if let Some(record) = new_record {
-            log::debug!("processing {}", record.filename);
+            tracing::debug!("processing {}", record.filename);
             // we need to exit here to actually capture the error
             if !to_process_tx.send(record).await.is_ok() {
                 break
@@ -785,13 +805,14 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
 
     tx.commit().await?;
 
-    log::info!("Ingest done");
+    tracing::info!("Ingest done");
 
     Result::Ok(())
 }
 
 const INDEX_ADD_BATCH: usize = 512;
 
+#[instrument]
 async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
     let pool = initialize_database(&config.service).await?;
 
@@ -904,6 +925,7 @@ struct QueryRequest {
     include_video: bool
 }
 
+#[instrument(skip(index))]
 async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult> {
     let result = index.vectors.search(&query, k as usize)?;
 
@@ -940,6 +962,7 @@ async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult> {
     })
 }
 
+#[instrument(skip(config, client, index))]
 async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IIndex, req: Json<QueryRequest>) -> Result<Response<Body>> {
     let mut total_embedding = ndarray::Array::from(vec![0.0; config.backend.embedding_size]);
 
@@ -1016,12 +1039,13 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IIndex, req: Json<QueryRequest>) -> Result<Response<Body>> {
 #[derive(Serialize, Deserialize)]
 struct FrontendInit {
     n_total: u64,
-    predefined_embedding_names: Vec<String>
+    predefined_embedding_names: Vec<String>,
+    d_emb: usize
 }
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    pretty_env_logger::init();
+    console_subscriber::init();
 
     let config_path = std::env::args().nth(1).expect("Missing config file path");
     let config: Config = serde_json::from_slice(&std::fs::read(config_path)?)?;
@@ -1062,23 +1086,23 @@ async fn main() -> Result<()> {
         let index = index.clone();
         async move {
             loop {
-                log::info!("Ingest running");
+                tracing::info!("Ingest running");
                 match ingest_files(config.clone()).await {
                     Ok(_) => {
                         match build_index(config.clone()).await {
                             Ok(new_index) => {
                                 LAST_INDEX_SIZE.set(new_index.vectors.ntotal() as i64);
                                 *index.write().await = new_index;
-                                log::info!("Index loaded");
+                                tracing::info!("Index loaded");
                             }
                             Err(e) => {
-                                log::error!("Index build failed: {:?}", e);
+                                tracing::error!("Index build failed: {:?}", e);
                                 ingest_done_tx.send((false, format!("{:?}", e))).unwrap();
                             }
                         }
                     }
                     Err(e) => {
-                        log::error!("Ingest failed: {:?}", e);
+                        tracing::error!("Ingest failed: {:?}", e);
                         ingest_done_tx.send((false, format!("{:?}", e))).unwrap();
                     }
                 }
@@ -1106,11 +1130,12 @@ async fn main() -> Result<()> {
         .route("/", get(|_req: ()| async move {
             Json(FrontendInit {
                 n_total: index_.read().await.vectors.ntotal(),
-                predefined_embedding_names: config__.predefined_embeddings.keys().cloned().collect()
+                predefined_embedding_names: config__.predefined_embeddings.keys().cloned().collect(),
+                d_emb: config__.backend.embedding_size
             })
         }))
         .route("/reload", post(|_req: ()| async move {
-            log::info!("Requesting index reload");
+            tracing::info!("Requesting index reload");
             let mut done_rx = done_tx.clone().subscribe();
             let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full
             match done_rx.recv().await {
@@ -1141,7 +1166,7 @@ async fn main() -> Result<()> {
         .layer(cors);
 
     let addr = format!("0.0.0.0:{}", config_.service.port);
-    log::info!("Starting server on {}", addr);
+    tracing::info!("Starting server on {}", addr);
     let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
     axum::serve(listener, app).await?;
 
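Aside (not part of the commit): when a thumbnail format has target_filesize set, the new generate_thumbnail above bisects on encoder quality between 1 and 100, re-encoding until the candidate lands near the size budget. A self-contained sketch of the same search against a stand-in encoder; find_quality and encode_at are hypothetical names, not APIs from this repository.

/// Bisect on quality in [1, 100] so the encoded size lands at or below `target_size`.
/// `encode_at` stands in for the real AVIF/JPEG encoding step.
fn find_quality(target_size: usize, encode_at: impl Fn(u8) -> Vec<u8>) -> Vec<u8> {
    let mut lb: u8 = 1;
    let mut ub: u8 = 100;
    loop {
        let quality = (lb + ub) / 2;
        let buf = encode_at(quality);
        if buf.len() > target_size {
            ub = quality; // too big: lower the quality ceiling
        } else {
            lb = quality + 1; // fits: try a higher quality next
        }
        if lb >= ub {
            break buf; // converged; return the last encoding tried
        }
    }
}

fn main() {
    // Fake encoder whose output size grows with quality; target 5000 "bytes".
    let result = find_quality(5000, |q| vec![0u8; (q as usize) * 100]);
    println!("converged at {} bytes", result.len());
}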
src/ocr.rs
@@ -9,6 +9,7 @@ use reqwest::{
 use serde_json::Value;
 use std::{io::Cursor, time::{SystemTime, UNIX_EPOCH}};
 use serde::{Deserialize, Serialize};
+use tracing::instrument;
 
 const CALLBACK_REGEX: &str = r">AF_initDataCallback\((\{key: 'ds:1'.*?\})\);</script>";
 const MAX_DIM: u32 = 1024;
@@ -45,6 +46,7 @@ fn rationalize_coords_format1(
     }
 }
 
+#[instrument(skip(client, image))]
 async fn scan_image_chunk(
     client: &Client,
     image: &[u8],
@@ -130,13 +132,14 @@ async fn scan_image_chunk(
         .collect())
 }
 
+#[instrument(skip(client))]
 pub async fn scan_image(client: &Client, image: &DynamicImage) -> Result<ScanResult> {
     let mut result = ScanResult::new();
     let (width, height) = image.dimensions();
 
     let (width, height, image) = if width > MAX_DIM {
         let height = ((height as f64) * (MAX_DIM as f64) / (width as f64)).round() as u32;
-        let new_image = tokio::task::block_in_place(|| image.resize_exact(MAX_DIM, height, image::imageops::FilterType::Lanczos3));
+        let new_image = tokio::task::block_in_place(|| image.resize_exact(MAX_DIM, height, image::imageops::FilterType::CatmullRom));
         (MAX_DIM, height, std::borrow::Cow::Owned(new_image))
     } else {
         (width, height, std::borrow::Cow::Borrowed(image))
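Aside (not part of the commit): the #[instrument] attributes added throughout record each argument as a span field via its Debug implementation, which is why large or uninteresting values such as raw image buffers and HTTP clients are explicitly skipped. A minimal sketch of the pattern; process is a made-up function.

use tracing::instrument;

// Without `skip`, #[instrument] would record `bytes` (via Debug) as a span
// field on every call, which is wasteful for large buffers. `url` is still
// captured, which is usually what you want for correlating logs.
#[instrument(skip(bytes))]
fn process(url: &str, bytes: &[u8]) -> usize {
    tracing::debug!(len = bytes.len(), "processing");
    bytes.len()
}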
src/reddit_dump.rs
@@ -1,15 +1,18 @@
 use anyhow::{anyhow, Context, Result};
 use common::resize_for_embed;
+use itertools::Itertools;
 use std::{collections::HashSet, fs, io::{BufReader, Cursor, BufRead, BufWriter}, time::Duration, sync::Arc, str::FromStr, path::PathBuf};
 use serde::{Serialize, Deserialize};
 use lazy_static::lazy_static;
-use regex::{RegexSet, bytes, Regex};
+use regex::{bytes, Regex, RegexSet, RegexSetBuilder};
 use tokio::{sync::{mpsc::{self, Receiver}, Semaphore}, task::{JoinHandle, JoinSet}};
 use tokio_stream::wrappers::ReceiverStream;
 use reqwest::Client;
 use futures_util::stream::{StreamExt, TryStreamExt};
-use image::{DynamicImage, io::Reader as ImageReader};
+use image::{DynamicImage, ImageReader};
 use mimalloc::MiMalloc;
+use tracing::instrument;
+use prometheus::{Encoder, register_int_counter, IntCounter, register_histogram_vec, HistogramVec};
 
 #[global_allocator]
 static GLOBAL: MiMalloc = MiMalloc;
@@ -50,6 +53,14 @@ struct Entry {
     id: String
 }
 
+#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
+struct OriginalImageMetadata {
+    mime_type: String,
+    original_file_size: usize,
+    dimension: (u32, u32),
+    final_url: String
+}
+
 #[derive(Clone, Deserialize, Serialize, Debug)]
 struct ProcessedEntry {
     url: String,
@@ -58,10 +69,13 @@ struct ProcessedEntry {
     subreddit: String,
     author: String,
     timestamp: u64,
-    blob: Vec<u8>
+    #[serde(with = "serde_bytes")]
+    embedding: Vec<u8>,
+    metadata: OriginalImageMetadata
 }
 
 lazy_static! {
+    // we do exclude galleries doing this but there don't seem to be any in the dataset
     static ref URL_IGNORE: RegexSet = RegexSet::new([
         r"//reddit\.com",
         r"\.html?",
@@ -69,16 +83,39 @@ lazy_static! {
         r"\?articleid=",
         r"\.aspx?",
         r"\.xml",
-        r"//youtube\.com",
         r"/rss/",
         r"//vimeo\.com",
-        r"//www\.youtube\.com",
-        r"//youtu\.be",
         r"//www\.reddit\.com",
+        r"//v\.redd\.it",
+        r"\.gifv$",
+        r"youtube\.com/user/"
         // TODO fill in more things, maybe try and collect thumbnails or something
     ]).unwrap();
-    static ref ACCEPTABLE_FILETYPES: HashSet<&'static [u8]> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/webp", "image/apng", "image/bmp", "image/tiff"]
-        .into_iter().map(str::as_bytes).collect();
+    static ref URL_MUST_CONTAIN: RegexSet = RegexSetBuilder::new([
+        "jpg",
+        "jpeg",
+        "png",
+        "webp",
+        r"\.gif",
+        "=gif",
+        "jpeg",
+        "bmp",
+        "tiff",
+        "avif",
+        "webp",
+        "imgur",
+        "image",
+        r"//i\.",
+        "img",
+        r"cdn\.",
+        r"media\.",
+        "/i/",
+        "/media",
+        r"youtu\.be",
+        r"youtube\.com",
+    ]).case_insensitive(true).build().unwrap();
+    static ref ACCEPTABLE_FILETYPES: HashSet<&'static str> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/webp", "image/apng", "image/bmp", "image/tiff"]
+        .into_iter().collect();
     static ref OBJECT_HACKY_IGNORE: bytes::RegexSet = bytes::RegexSet::new([
         r#""author":"\[deleted\]""#,
         r#""promoted":true"#, // these seem to be ads which are in the data for some reason, and lack some important fields
@@ -86,11 +123,34 @@ lazy_static! {
         r"\x00" // for SOME REASON one of the JSON files contains a lot of null bytes before one particular record, so just ignore that record
     ]).unwrap();
     static ref URL_REPLACEMENT_RULES: Vec<(Regex, &'static str)> = [
-        (r"//imgur.com/([A-Za-z0-9]+)", r"//i.imgur.com/$1.jpg"),
-        (r"^http://", r"https://")
+        (r"imgur\.com/([A-Za-z0-9]+),", r"imgur.com/$1"),
+        (r"//imgur\.com/([A-Za-z0-9]+)$", r"//i.imgur.com/$1.jpg"),
+        (r"//www\.imgur\.com/([A-Za-z0-9]+)$", r"//i.imgur.com/$1.jpg"),
+        (r"//m\.imgur\.com/([A-Za-z0-9]+)$", r"//i.imgur.com/$1.jpg"),
+        (r"^http://", r"https://"),
+        (r"//youtu\.be/(.*)", r"//youtube.com/watch?v=$1"),
+        (r"//[a-z]+\.youtube\.com/(.*)", r"//youtube.com/$1"),
+        (r"//www.youtube.com/attribution_link?.*v%3D([A-Za-z0-9_-]+).*", r"//i.ytimg.com/vi/$1/maxresdefault.jpg"), // redirect to youtube thumbnail API
+        (r"//youtube.com/embed/([A-Za-z0-9_-]+)", r"//i.ytimg.com/vi/$1/maxresdefault.jpg"),
+        (r"//youtube\.com/(?:.*)v=([A-Za-z0-9_-]+)(?:.*)", r"//i.ytimg.com/vi/$1/maxresdefault.jpg"),
+        (r"&amp;", "&") // this is such an intensely cursed feature of the dumps
     ].into_iter().map(|(r, e)| (Regex::new(r).unwrap(), e)).collect();
+
+    static ref HTML_EXTRACTION_RULES: Vec<(Regex, Regex)> = [
+        (r"//imgur\.com/a/[A-Za-z0-9]+", r#"<meta name="twitter:image" data-react-helmet="true" content="([^"]+)">"#),
+        (r"//imgur\.com/gallery/[A-Za-z0-9]+", r#"<meta name="twitter:image" data-react-helmet="true" content="([^"]+)">"#),
+    ].into_iter().map(|(r, e)| (Regex::new(r).unwrap(), Regex::new(e).unwrap())).collect();
+
+    static ref IMAGES_FETCHED_COUNTER: IntCounter = register_int_counter!("mse_scrape_images_fetched", "images fetched").unwrap();
+    static ref IMAGES_PROCESSED_COUNTER: IntCounter = register_int_counter!("mse_scrape_images_processed", "images processed").unwrap();
+    static ref ENTRIES_PROCESSED_COUNTER: IntCounter = register_int_counter!("mse_scrape_entries_processed", "entries processed").unwrap();
+    static ref IMAGES_FAILED_COUNTER: IntCounter = register_int_counter!("mse_scrape_images_failed", "images failed").unwrap();
+    static ref IMAGE_FILESIZES_HISTOGRAM: HistogramVec = register_histogram_vec!("mse_scrape_image_filesizes", "filesizes of successfully fetched images", &["format"], prometheus::exponential_buckets(100.0, 1.5, 29).unwrap()).unwrap();
+    static ref IMAGE_PIXELS_HISTOGRAM: HistogramVec = register_histogram_vec!("mse_scrape_image_pixels", "pixel count of successfully fetched images", &["format"], prometheus::exponential_buckets(100.0, 1.3, 53).unwrap()).unwrap();
+    static ref HTML_EXTRACTS_COUNTER: IntCounter = register_int_counter!("mse_scrape_html_extracts", "html extraction operations").unwrap();
 }
 
+#[instrument(skip(tx))]
 fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Option<u64>) -> Result<()> {
     let mut stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
     stream.window_log_max(31)?;
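Aside (not part of the commit): URL_REPLACEMENT_RULES is applied in order, each regex rewriting the URL string before the fetch, as shown in fetch_file below. A standalone sketch of the same pattern with two of the rules, showing how an imgur page link ends up pointing at the direct image:

use regex::Regex;

fn main() {
    // Two of the rules from the list above, applied in sequence like fetch_file does.
    let rules: Vec<(Regex, &str)> = vec![
        (Regex::new(r"//imgur\.com/([A-Za-z0-9]+)$").unwrap(), r"//i.imgur.com/$1.jpg"),
        (Regex::new(r"^http://").unwrap(), r"https://"),
    ];
    let mut url = String::from("http://imgur.com/AbC123");
    for (regex, replacement) in rules.iter() {
        url = regex.replace(&url, *replacement).to_string();
    }
    // The page URL has been rewritten to the direct image URL over HTTPS.
    assert_eq!(url, "https://i.imgur.com/AbC123.jpg");
}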
@@ -105,15 +165,16 @@ fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Option<u64>) -> Result<()> {
             buf.clear();
             continue;
         }
+        ENTRIES_PROCESSED_COUNTER.inc();
         let entry = match sonic_rs::serde::from_slice::<Entry>(buf.as_slice()) {
             Ok(x) => x,
             Err(e) => {
-                log::warn!("parse failed, please validate {:?} {:?}", e, String::from_utf8_lossy(&buf));
+                tracing::warn!("parse failed, please validate {:?} {:?}", e, String::from_utf8_lossy(&buf));
                 return Ok(())
             }
         };
         if entry.selftext.is_empty() && !entry.over_18 && entry.author.is_some() && entry.subreddit.is_some() {
-            if !URL_IGNORE.is_match(&entry.url) {
+            if !URL_IGNORE.is_match(&entry.url) && URL_MUST_CONTAIN.is_match(&entry.url) {
                 match &entry.post_hint {
                     Some(x) if x == "na" || x == "image" => {
                         // Technically this is slightly wrong because we reorder images slightly, but as long as it is not restarted all the time this is "fine".
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
struct Config {
|
struct Config {
|
||||||
max_content_length: usize,
|
max_content_length: usize,
|
||||||
input: String,
|
input: String,
|
||||||
output: String,
|
output: String,
|
||||||
backend: String,
|
backend: String,
|
||||||
mode: OperatingMode,
|
mode: OperatingMode,
|
||||||
filename_threshold: Option<String>
|
filename_threshold: Option<String>,
|
||||||
|
metrics_addr: String,
|
||||||
|
contact_info: String
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) -> Result<Vec<u8>> {
|
#[instrument(skip(client, config))]
|
||||||
// inelegant but I can't get it to work using Cows
|
#[async_recursion::async_recursion]
|
||||||
|
async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) -> Result<(Vec<u8>, String, String)> {
|
||||||
let mut url = url.to_string();
|
let mut url = url.to_string();
|
||||||
for (regex, replacement) in URL_REPLACEMENT_RULES.iter() {
|
for (regex, replacement) in URL_REPLACEMENT_RULES.iter() {
|
||||||
url = regex.replace(&url, *replacement).to_string();
|
url = regex.replace(&url, *replacement).to_string();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut html_extract_rule = None;
|
||||||
|
|
||||||
|
for (url_rule, extract_rule) in HTML_EXTRACTION_RULES.iter() {
|
||||||
|
if url_rule.is_match(&url) {
|
||||||
|
html_extract_rule = Some(extract_rule);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let mut response = client.get(&*url).send().await?;
|
let mut response = client.get(&*url).send().await?;
|
||||||
if !ACCEPTABLE_FILETYPES.contains(response.headers().get(reqwest::header::CONTENT_TYPE).context("no contept type")?.as_bytes()) {
|
let content_type = std::str::from_utf8(&response.headers().get(reqwest::header::CONTENT_TYPE).context("no content type")?.as_bytes())?.to_owned();
|
||||||
|
if !(ACCEPTABLE_FILETYPES.contains(&content_type[..]) || (html_extract_rule.is_some() && content_type == "text/html")) {
|
||||||
return Err(anyhow!("invalid Content-Type"));
|
return Err(anyhow!("invalid Content-Type"));
|
||||||
}
|
}
|
||||||
match response.content_length() {
|
match response.content_length() {
|
||||||
@ -169,11 +245,24 @@ async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) ->
|
|||||||
return Err(anyhow!("response too large"));
|
return Err(anyhow!("response too large"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(buffer)
|
if let Some(extract_rule) = html_extract_rule {
|
||||||
|
if content_type == "text/html" {
|
||||||
|
let buffer = String::from_utf8_lossy(&buffer).to_string();
|
||||||
|
if let Some(mat) = extract_rule.captures(&buffer) {
|
||||||
|
let new_url = mat.get(1).unwrap().as_str();
|
||||||
|
HTML_EXTRACTS_COUNTER.inc();
|
||||||
|
tracing::debug!("found new URL: {}", new_url);
|
||||||
|
return fetch_file(client, config, new_url).await;
|
||||||
|
} else {
|
||||||
|
return Err(anyhow!("no extraction match"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((buffer, content_type, response.url().to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
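URL_REPLACEMENT_RULES, HTML_EXTRACTION_RULES, ACCEPTABLE_FILETYPES and HTML_EXTRACTS_COUNTER are statics defined elsewhere in reddit_dump.rs and are not part of this hunk. As a rough sketch of the shape the two rule tables could take (the regexes here are made-up placeholders, and the real file may use lazy_static rather than std::sync::LazyLock):

use std::sync::LazyLock;
use regex::Regex;

// Hypothetical examples only; the real rules live elsewhere in reddit_dump.rs.
static URL_REPLACEMENT_RULES: LazyLock<Vec<(Regex, &'static str)>> = LazyLock::new(|| vec![
    // e.g. rewrite an imgur page URL into the direct image URL
    (Regex::new(r"^https?://imgur\.com/([A-Za-z0-9]+)$").unwrap(), "https://i.imgur.com/$1.jpg"),
]);

// URLs whose HTML body must be scanned for the real image URL;
// capture group 1 of the extraction regex is what fetch_file recurses into.
static HTML_EXTRACTION_RULES: LazyLock<Vec<(Regex, Regex)>> = LazyLock::new(|| vec![
    (Regex::new(r"^https?://example\.invalid/gallery/").unwrap(),
     Regex::new(r#"<meta property="og:image" content="([^"]+)""#).unwrap()),
]);
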
fn write_output(config: Arc<Config>, mut rx: Receiver<ProcessedEntry>) -> Result<()> {
-let mut out = fs::File::options().append(true).open(&config.output)?;
+let mut out = fs::File::options().create(true).append(true).open(&config.output)?;
let stream = zstd::Encoder::new(&mut out, 15)?.auto_finish();
let mut buf_stream = BufWriter::new(stream);
while let Some(x) = rx.blocking_recv() {
@@ -182,12 +271,14 @@ fn write_output(config: Arc<Config>, mut rx: Receiver<ProcessedEntry>) -> Result
Ok(())
}

+#[derive(Debug)]
enum OperatingMode {
Count,
Sample(f32),
FullRun
}

+#[instrument]
fn readback_output(path: &str) -> Result<(u64, usize)> {
use rmp_serde::decode::Error;
let stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
@@ -208,27 +299,47 @@ fn readback_output(path: &str) -> Result<(u64, usize)> {
Ok((latest_timestamp, count))
}

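Most of readback_output's body falls outside this hunk. Its role is to scan the existing zstd/MessagePack output on startup so a restarted run can skip submissions older than the newest timestamp already written (the threshold later passed to process_file). A speculative sketch of that loop, not the actual source:

fn readback_output(path: &str) -> anyhow::Result<(u64, usize)> {
    use rmp_serde::decode::Error;
    let mut stream = zstd::stream::Decoder::new(std::fs::File::open(path)?)?;
    let mut latest_timestamp = 0u64;
    let mut count = 0usize;
    loop {
        match rmp_serde::from_read::<_, ProcessedEntry>(&mut stream) {
            Ok(entry) => {
                latest_timestamp = latest_timestamp.max(entry.timestamp);
                count += 1;
            }
            // hitting EOF at a record boundary just means we reached the end of the dump
            Err(Error::InvalidMarkerRead(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
            Err(e) => return Err(e.into()),
        }
    }
    Ok((latest_timestamp, count))
}
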
+async fn serve_metrics(config: Arc<Config>) -> Result<()> {
+let metrics = axum::Router::new().route("/metrics", axum::routing::get(|| async move {
+let mut buffer = Vec::new();
+let encoder = prometheus::TextEncoder::new();
+let metric_families = prometheus::gather();
+encoder.encode(&metric_families, &mut buffer).unwrap();
+buffer
+}));
+let listener = tokio::net::TcpListener::bind(&config.metrics_addr).await?;
+tokio::task::spawn(async move {
+let _ = axum::serve(listener, metrics).await;
+});
+Ok(())
+}

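serve_metrics exports whatever has been registered with the prometheus default registry via prometheus::gather(). The counters and histograms used later in this diff (IMAGES_FETCHED_COUNTER, IMAGES_FAILED_COUNTER, IMAGES_PROCESSED_COUNTER, HTML_EXTRACTS_COUNTER, IMAGE_FILESIZES_HISTOGRAM, IMAGE_PIXELS_HISTOGRAM) are declared elsewhere in the file; a plausible sketch of those declarations, with metric names and buckets invented here (the real code may use lazy_static rather than LazyLock):

use std::sync::LazyLock;
use prometheus::{register_int_counter, register_histogram_vec, exponential_buckets, IntCounter, HistogramVec};

// Names and help strings below are assumptions, not copied from the source.
static IMAGES_FETCHED_COUNTER: LazyLock<IntCounter> = LazyLock::new(||
    register_int_counter!("reddit_dump_images_fetched", "images successfully downloaded").unwrap());
static IMAGES_FAILED_COUNTER: LazyLock<IntCounter> = LazyLock::new(||
    register_int_counter!("reddit_dump_images_failed", "image downloads that errored").unwrap());
static IMAGES_PROCESSED_COUNTER: LazyLock<IntCounter> = LazyLock::new(||
    register_int_counter!("reddit_dump_images_processed", "images embedded and written out").unwrap());
static HTML_EXTRACTS_COUNTER: LazyLock<IntCounter> = LazyLock::new(||
    register_int_counter!("reddit_dump_html_extracts", "image URLs recovered from HTML pages").unwrap());
// Labelled by MIME type, matching .with_label_values(&[&mime_type]) below.
static IMAGE_FILESIZES_HISTOGRAM: LazyLock<HistogramVec> = LazyLock::new(||
    register_histogram_vec!("reddit_dump_image_filesize_bytes", "downloaded file sizes",
        &["mime_type"], exponential_buckets(1024.0, 2.0, 16).unwrap()).unwrap());
static IMAGE_PIXELS_HISTOGRAM: LazyLock<HistogramVec> = LazyLock::new(||
    register_histogram_vec!("reddit_dump_image_pixels", "decoded image area in pixels",
        &["mime_type"], exponential_buckets(4096.0, 2.0, 16).unwrap()).unwrap());
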
#[tokio::main]
async fn main() -> Result<()> {
-pretty_env_logger::init();
+console_subscriber::init();

let cpus = num_cpus::get();

let config = Arc::new(Config {
-max_content_length: 1<<23,
+max_content_length: 1<<24,
-input: String::from("./submissions"),
+input: String::from("./reddit_subs_202212/"),
output: String::from("./sample.zst"),
backend: String::from("http://localhost:1708"),
-mode: OperatingMode::Sample(0.004),
+mode: OperatingMode::FullRun,
-filename_threshold: None
+filename_threshold: None,
+metrics_addr: String::from("0.0.0.0:9914"),
+contact_info: String::from("scraping-ops@osmarks.net")
});

+serve_metrics(config.clone()).await?;

let timestamp_threshold = match config.mode {
OperatingMode::Count => None,
_ => {
match readback_output(&config.output) {
Ok(x) => Some(x),
Err(e) => {
-log::warn!("could not read output: {}", e);
+tracing::warn!("could not read output: {}", e);
None
}
}
@@ -237,19 +348,19 @@ async fn main() -> Result<()> {

if let Some((threshold, count)) = timestamp_threshold {
-log::info!("threshold is {}, {} items", threshold, count);
+tracing::info!("threshold is {}, {} items", threshold, count);
}

let backend = get_backend_config(&config.backend).await;

-log::info!("connected to inference server");
+tracing::info!("connected to inference server");

let (entries_tx, mut entries_rx) = mpsc::channel::<Entry>(32768);
let (buffers_tx, buffers_rx) = mpsc::channel(128);
let (resized_tx, resized_rx) = mpsc::channel(backend.batch);
let (final_write_tx, final_write_rx) = mpsc::channel::<ProcessedEntry>(32768);
let client = Client::builder()
-.user_agent(concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")))
+.user_agent(format!("{}/{} (contact {})", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"), config.contact_info))
.timeout(Duration::from_secs(30))
.build()?;

@@ -278,11 +389,13 @@ async fn main() -> Result<()> {
}
match fetch_file(client, config.clone(), &entry.url).await {
Ok(buf) => {
-log::debug!("got {}", &entry.url);
+IMAGES_FETCHED_COUNTER.inc();
+tracing::debug!("got {}", &entry.url);
buffers_tx.send((entry, buf)).await?;
},
Err(e) => {
-log::warn!("{} failed: {}", &entry.url, e)
+IMAGES_FAILED_COUNTER.inc();
+tracing::debug!("{} failed: {}", &entry.url, e)
}
}
Ok(())
@@ -296,8 +409,10 @@ async fn main() -> Result<()> {
_ => Some(tokio::task::spawn({
let stream = ReceiverStream::new(buffers_rx);
let backend = backend.clone();
-stream.map(Ok).try_for_each_concurrent(Some(cpus), move |(entry, buffer)| {
+stream.map(Ok).try_for_each_concurrent(Some(cpus), move |(entry, (buffer, mime_type, final_url))| {
let backend = backend.clone();
+let size = buffer.len();
+IMAGE_FILESIZES_HISTOGRAM.with_label_values(&[&mime_type]).observe(size as f64);
let resized_tx = resized_tx.clone();
async move {
let image_result = tokio::task::spawn_blocking(|| {
@@ -308,12 +423,20 @@ async fn main() -> Result<()> {
let image = match image_result {
Ok(image) => image,
Err(e) => {
-log::warn!("loading {} failed: {}", entry.url, e);
+tracing::debug!("loading {} failed: {}", entry.url, e);
return Result::<(), anyhow::Error>::Ok(());
}
};
+let dim = (image.width(), image.height());
+IMAGE_PIXELS_HISTOGRAM.with_label_values(&[&mime_type]).observe(dim.0 as f64 * dim.1 as f64);
+let metadata = OriginalImageMetadata {
+mime_type,
+original_file_size: size,
+dimension: dim,
+final_url
+};
let resized = resize_for_embed(backend.clone(), image).await?;
-resized_tx.send((entry, resized)).await?;
+resized_tx.send((entry, resized, metadata)).await?;
Ok(())
}
})
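OriginalImageMetadata is new in this commit but its definition is outside the hunk; from the construction above it is presumably something close to the following sketch (derives and exact field types are assumptions):

use serde::{Serialize, Deserialize};

// Inferred from the usage above; the real definition lives elsewhere in reddit_dump.rs.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OriginalImageMetadata {
    mime_type: String,
    original_file_size: usize,
    dimension: (u32, u32),
    final_url: String
}
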
@@ -328,7 +451,7 @@ async fn main() -> Result<()> {
let config = config.clone();
// keep multiple embedding requests in flight
stream.map(Ok).try_for_each_concurrent(Some(3), move |batch| {
-let (entries, bytes): (Vec<Entry>, Vec<Vec<u8>>) = batch.into_iter().unzip();
+let (entries, bytes, batch_dimensions): (Vec<Entry>, Vec<Vec<u8>>, Vec<OriginalImageMetadata>) = batch.into_iter().multiunzip();
let client = client.clone();
let config = config.clone();
let final_write_tx = final_write_tx.clone();
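multiunzip comes from the newly added itertools dependency and generalises Iterator::unzip to tuples of more than two elements; the Itertools trait must be in scope. A small self-contained example of the pattern (not taken from the source):

use itertools::Itertools;

fn main() {
    let batch = vec![(1u32, "a", 0.5f32), (2, "b", 1.5), (3, "c", 2.5)];
    // Splits an iterator of 3-tuples into three parallel Vecs in one pass.
    let (ids, names, scores): (Vec<u32>, Vec<&str>, Vec<f32>) = batch.into_iter().multiunzip();
    assert_eq!(ids, vec![1, 2, 3]);
    assert_eq!(names, vec!["a", "b", "c"]);
    assert_eq!(scores, vec![0.5, 1.5, 2.5]);
}
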
@@ -342,16 +465,19 @@ async fn main() -> Result<()> {
},
).await.context("querying CLIP server")?;

-for (vector, entry) in result.into_iter().zip(entries) {
+for (vector, entry, metadata) in itertools::izip!(result.into_iter(), entries, batch_dimensions) {
final_write_tx.send(ProcessedEntry {
url: entry.url,
id: entry.id,
title: entry.title,
subreddit: entry.subreddit.unwrap(),
author: entry.author.unwrap(),
-blob: vector.into_vec(),
+embedding: vector.into_vec(),
-timestamp: entry.created_utc.to_u64()?
+timestamp: entry.created_utc.to_u64()?,
+metadata
}).await?;
+IMAGES_PROCESSED_COUNTER.inc();
}
anyhow::Result::Ok(())
}
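The blob field of ProcessedEntry is renamed to embedding and the per-image metadata is now stored alongside it. The struct definition is outside the hunk; inferred from the fields written above, it is roughly the following (the embedding element type and derives are assumptions):

use serde::{Serialize, Deserialize};

// Inferred from the writer loop above; the actual definition is elsewhere in reddit_dump.rs.
#[derive(Debug, Serialize, Deserialize)]
struct ProcessedEntry {
    url: String,
    id: String,
    title: String,
    subreddit: String,
    author: String,
    // CLIP embedding; half-precision is plausible given the half dependency, but this is a guess
    embedding: Vec<half::f16>,
    timestamp: u64,
    metadata: OriginalImageMetadata
}
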
@@ -365,7 +491,7 @@ async fn main() -> Result<()> {
_ => None
};

-log::info!("working...");
+tracing::info!("working...");

let mut paths = vec![];
for file in fs::read_dir(&config.input)? {
@@ -381,36 +507,26 @@ async fn main() -> Result<()> {

let mut file_readers = JoinSet::new();

-match config.mode {
-OperatingMode::Count | OperatingMode::Sample(_) => {
-let semaphore = Arc::new(Semaphore::new(cpus));
-for path in paths {
-let semaphore = semaphore.clone();
-let permit = semaphore.acquire_owned().await?;
-let entries_tx = entries_tx.clone();
-let path_ = path.clone();
-log::info!("reading {:?}", path);
-file_readers.spawn_blocking(move || {
-match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
-Ok(_) => (),
-Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
-}
-std::mem::drop(permit);
-});
-}
-},
-OperatingMode::FullRun => {
-for path in paths {
-let entries_tx = entries_tx.clone();
-let path_ = path.clone();
-log::info!("reading {:?}", path);
-file_readers.spawn_blocking(move || match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
-Ok(_) => (),
-Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
-});
-}
-}
+let readers = match config.mode {
+OperatingMode::Count | OperatingMode::Sample(_) => cpus,
+OperatingMode::FullRun => 1
+};
+
+let semaphore = Arc::new(Semaphore::new(readers));
+
+for path in paths {
+let semaphore = semaphore.clone();
+let permit = semaphore.acquire_owned().await?;
+let entries_tx = entries_tx.clone();
+let path_ = path.clone();
+tracing::info!("reading {:?}", path);
+file_readers.spawn_blocking(move || {
+match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
+Ok(_) => (),
+Err(e) => tracing::error!("could not parse {:?} {:?}", &path, e)
+}
+std::mem::drop(permit);
+});
}

while let Some(x) = file_readers.try_join_next() {
@@ -419,9 +535,9 @@ async fn main() -> Result<()> {

std::mem::drop(entries_tx);
println!("{:?}", load_task.await?);
-if let Some(task) = resize_task { println!("{:?}", task.await?); }
+if let Some(task) = resize_task { println!("resize: {:?}", task.await?); }
-if let Some(task) = embedding_generation_task { println!("{:?}", task.await?) };
+if let Some(task) = embedding_generation_task { println!("embedding: {:?}", task.await?) };
-if let Some(task) = output_writer_task { println!("{:?}", task.await?) };
+if let Some(task) = output_writer_task { println!("output: {:?}", task.await?) };

Ok(())
}