Mirror of https://github.com/osmarks/meme-search-engine.git
improve observability and fix up Reddit dump for full-scale run
This commit is contained in:
commit 7fa14d45ae (parent 1d0ff95955)
Cargo.lock (generated, 1359 lines changed)
File diff suppressed because it is too large.
Cargo.toml (15 lines changed)
@@ -6,14 +6,13 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tokio = { version = "1", features = ["full"] }
tokio = { version = "1", features = ["full", "tracing"] }
axum = "0.7"
image = { version = "0.25", features = ["avif", "avif-native", "nasm"] }
reqwest = { version = "0.12", features = ["multipart"] }
serde = { version = "1", features = ["derive"] }
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] }
walkdir = "1"
log = "0.4"
rmp-serde = "1"
serde_json = "1"
chrono = "0.4"
@@ -24,7 +23,8 @@ faiss = "0.12"
ndarray = "0.15"
half = { version = "2" }
regex = "1"
pretty_env_logger = "0.5"
tracing = "0.1"
console-subscriber = "0.4"
futures-util = "0.3"
tokio-stream = "0.1"
num_cpus = "1"
@@ -41,9 +41,8 @@ mimalloc = "0.1"
sonic-rs = "0.3"
ffmpeg-the-third = "2.0"
compact_str = { version = "0.8.0-beta", features = ["serde"] }

[patch.crates-io]
image = { git = "https://github.com/fintelia/image/", branch = "upgrade-zune-jpeg" }
itertools = "0.13"
async-recursion = "1"

[[bin]]
name = "reddit-dump"
@@ -52,3 +51,7 @@ path = "src/reddit_dump.rs"

[[bin]]
name = "video-reader"
path = "src/video_reader.rs"

[[bin]]
name = "dump-processor"
path = "src/dump_processor.rs"
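The "tracing" feature on tokio plus the new console-subscriber dependency is what allows tokio-console to attach to the running binaries; main.rs and reddit_dump.rs below swap pretty_env_logger::init() for console_subscriber::init() accordingly. A minimal, standalone sketch of that wiring (not code from this commit; the port and RUSTFLAGS requirement come from console-subscriber's documented defaults):

// Sketch: enabling tokio-console instrumentation.
// Requires the "tracing" feature on tokio (as added above) and building
// with RUSTFLAGS="--cfg tokio_unstable" so task instrumentation is compiled in.
#[tokio::main]
async fn main() {
    // Starts the console-subscriber layer; `tokio-console` can then attach
    // to the default endpoint (127.0.0.1:6669) and show per-task timings.
    console_subscriber::init();

    tokio::spawn(async {
        // work here shows up in tokio-console as an instrumented task
    });
}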
src/common.rs
@@ -4,6 +4,7 @@ use image::{DynamicImage, imageops::FilterType, ImageFormat};
use anyhow::Result;
use std::io::Cursor;
use reqwest::Client;
use tracing::instrument;

#[derive(Debug, Deserialize, Clone)]
pub struct InferenceServerConfig {
@@ -13,11 +14,13 @@ pub struct InferenceServerConfig {
}

pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
let new = image.borrow().resize(
// the model currently in use wants aspect ratio 1:1 regardless of input
// I think this was previously being handled in the CLIP server but that is slightly lossy
let new = image.borrow().resize_exact(
config.image_size.0,
config.image_size.1,
FilterType::Lanczos3
);
FilterType::CatmullRom
).into_rgb8();
let mut buf = Vec::new();
let mut csr = Cursor::new(&mut buf);
new.write_to(&mut csr, ImageFormat::Png)?;
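The switch from resize to resize_exact changes how the frame is fitted to the model input: resize preserves aspect ratio and only fits within the bounds, while resize_exact stretches to exactly image_size, which is what the comment above says the model expects. A small sketch of the difference, with illustrative numbers (640x480 input, 224x224 target) that are not taken from this repository:

use image::{DynamicImage, imageops::FilterType};

fn demo(img: DynamicImage) {
    // resize: keeps aspect ratio, so a 640x480 image becomes 224x168
    let fitted = img.resize(224, 224, FilterType::CatmullRom);
    assert_eq!((fitted.width(), fitted.height()), (224, 168));

    // resize_exact: ignores aspect ratio, always yields exactly 224x224
    let stretched = img.resize_exact(224, 224, FilterType::CatmullRom);
    assert_eq!((stretched.width(), stretched.height()), (224, 224));
}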
@@ -46,13 +49,14 @@ pub async fn get_backend_config(clip_server: &str) -> InferenceServerConfig {
match fetch_backend_config(&clip_server).await {
Ok(backend) => break backend,
Err(e) => {
log::error!("Backend failed (fetch): {}", e);
tracing::error!("Backend failed (fetch): {}", e);
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
}
}

#[instrument(skip(client, data))]
pub async fn query_clip_server<I, O>(client: &Client, base_url: &str, path: &str, data: I) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned,
{
let response = client
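The #[instrument] attributes added throughout this commit open one tracing span per call, recording the function's arguments as span fields unless they are skipped; skip(client, data) here keeps the HTTP client and the request payload out of every span. A minimal sketch of the pattern (standalone example, not from this diff):

use tracing::instrument;

// Each call opens a span named "lookup" carrying `id` as a field;
// `conn` is skipped because it is large and not useful as a span field.
#[instrument(skip(conn))]
async fn lookup(conn: &sqlx::SqlitePool, id: u32) -> anyhow::Result<()> {
    tracing::debug!("looking up {}", id); // event attached to the current span
    Ok(())
}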
src/dump_processor.rs (new file, 53 lines)
@@ -0,0 +1,53 @@
use anyhow::{Result, Context};
use serde::{Serialize, Deserialize};
use std::io::BufReader;
use rmp_serde::decode::Error as DecodeError;
use std::fs;

// TODO refactor
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
struct OriginalImageMetadata {
    mime_type: String,
    original_file_size: usize,
    dimension: (u32, u32),
    final_url: String
}

#[derive(Clone, Deserialize, Serialize, Debug)]
struct ProcessedEntry {
    url: String,
    id: String,
    title: String,
    subreddit: String,
    author: String,
    timestamp: u64,
    #[serde(with = "serde_bytes")]
    embedding: Vec<u8>,
    metadata: OriginalImageMetadata
}

fn main() -> Result<()> {
    let path = std::env::args().nth(1).context("missing path")?;
    let stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
    let mut stream = BufReader::new(stream);
    let mut latest_timestamp = 0;
    let mut count = 0;
    loop {
        let res: Result<ProcessedEntry, DecodeError> = rmp_serde::from_read(&mut stream);
        if res.is_ok() {
            count += 1;
        }
        match res {
            Ok(x) => {
                if x.timestamp > latest_timestamp {
                    println!("{} {} https://reddit.com/r/{}/comments/{}", x.timestamp, count, x.subreddit, x.id);
                    latest_timestamp = x.timestamp;
                }
            },
            Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
            Err(e) => return Err(e).context("decode fail")
        }
    }
    println!("{} {}", latest_timestamp, count);
    Ok(())
}
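dump-processor reads back the append-only dump that reddit-dump produces: a zstd stream of MessagePack records, terminated when the decoder hits end-of-file. For context, a rough sketch of the writing side it expects, simplified from write_output in src/reddit_dump.rs below (the exact serialization call is an assumption, since that hunk is not shown in full here):

use std::{fs, io::BufWriter};
use rmp_serde::encode;

fn append_entry(path: &str, entry: &ProcessedEntry) -> anyhow::Result<()> {
    // Append-mode file plus zstd encoder, mirroring write_output below;
    // auto_finish() flushes the zstd frame when the encoder is dropped.
    let mut out = fs::File::options().create(true).append(true).open(path)?;
    let stream = zstd::Encoder::new(&mut out, 15)?.auto_finish();
    let mut buf = BufWriter::new(stream);
    // Named (map-style) serialization matches the #[derive(Serialize)] structs above.
    encode::write_named(&mut buf, entry)?;
    Ok(())
}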
src/main.rs (579 lines changed)
@@ -33,6 +33,7 @@ use faiss::index::scalar_quantizer;
use lazy_static::lazy_static;
use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
use ndarray::ArrayBase;
use tracing::instrument;

mod ocr;
mod common;
@ -249,7 +250,7 @@ async fn initialize_database(config: &Config) -> Result<SqlitePool> {
|
||||
if (index as i32) < version {
|
||||
continue
|
||||
}
|
||||
log::info!("Migrating to DB version {}", index);
|
||||
tracing::info!("Migrating to DB version {}", index);
|
||||
sqlx::query(sql).execute(&mut *tx).await?;
|
||||
sqlx::query(&format!("PRAGMA user_version = {}", index + 1)).execute(&mut *tx).await?;
|
||||
}
|
||||
@ -317,6 +318,7 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> {
|
||||
formats
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn ensure_filename_record_exists(conn: &mut SqliteConnection, filename_enc: &Vec<u8>) -> Result<()> {
|
||||
sqlx::query!("INSERT OR IGNORE INTO files (filename) VALUES (?)", filename_enc)
|
||||
.execute(conn)
|
||||
@ -324,6 +326,7 @@ async fn ensure_filename_record_exists(conn: &mut SqliteConnection, filename_enc
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn write_metadata(conn: &mut SqliteConnection, filename_enc: &Vec<u8>, metadata: FileMetadata) -> Result<()> {
|
||||
ensure_filename_record_exists(conn, filename_enc).await?;
|
||||
let metadata_serialized = rmp_serde::to_vec_named(&metadata)?;
|
||||
@ -333,18 +336,276 @@ async fn write_metadata(conn: &mut SqliteConnection, filename_enc: &Vec<u8>, met
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument]
|
||||
async fn handle_embedding_batch(client: reqwest::Client, config: Arc<WConfig>, pool: SqlitePool, batch: Vec<EmbeddingInput>, video_embed_times: Arc<RwLock<HashMap<CompactString, i64>>>) -> Result<()> {
|
||||
let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
|
||||
&client,
|
||||
&config.service.clip_server,
|
||||
"",
|
||||
EmbeddingRequest::Images {
|
||||
images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
|
||||
},
|
||||
).await.context("querying CLIP server")?;
|
||||
|
||||
let mut tx = pool.begin().await?;
|
||||
let ts = timestamp();
|
||||
for (i, vector) in result.into_iter().enumerate() {
|
||||
let vector = vector.into_vec();
|
||||
tracing::debug!("embedded {:?}", batch[i].filename);
|
||||
let encoded_filename = batch[i].filename.encode()?;
|
||||
IMAGES_EMBEDDED_COUNTER.inc();
|
||||
ensure_filename_record_exists(&mut *tx, &encoded_filename).await?;
|
||||
match &batch[i].filename {
|
||||
Filename::VideoFrame(container, _) => { video_embed_times.write().await.insert(container.clone(), timestamp()); },
|
||||
_ => ()
|
||||
}
|
||||
sqlx::query!(
|
||||
"UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?",
|
||||
ts,
|
||||
vector,
|
||||
encoded_filename
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
tx.commit().await?;
|
||||
anyhow::Result::Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip(to_embed_tx, to_thumbnail_tx, to_ocr_tx, to_metadata_write_tx, video_meta))]
|
||||
async fn load_image(record: FileRecord, to_embed_tx: mpsc::Sender<EmbeddingInput>, to_thumbnail_tx: mpsc::Sender<LoadedImage>, to_ocr_tx: mpsc::Sender<LoadedImage>, to_metadata_write_tx: mpsc::Sender<(Filename, FileMetadata)>, config: Arc<WConfig>, video_meta: Arc<RwLock<HashMap<CompactString, FileMetadata>>>) -> Result<()> {
|
||||
let path = Path::new(&config.service.files).join(&*record.filename);
|
||||
let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?)));
|
||||
let image = match image {
|
||||
Ok(image) => image,
|
||||
Err(e) => {
|
||||
tracing::warn!("Could not read {} as image: {}", record.filename, e);
|
||||
let filename = record.filename.clone();
|
||||
IMAGES_LOADED_ERROR_COUNTER.inc();
|
||||
let meta = tokio::task::spawn_blocking(move || -> Result<Option<FileMetadata>> {
|
||||
let mut i: u32 = 0;
|
||||
let mut last_metadata = None;
|
||||
let callback = |frame: RgbImage| {
|
||||
let frame: Arc<DynamicImage> = Arc::new(frame.into());
|
||||
let embed_buf = resize_for_embed_sync(config.backend.clone(), frame.clone())?;
|
||||
let filename = Filename::VideoFrame(filename.clone(), i);
|
||||
to_embed_tx.blocking_send(EmbeddingInput {
|
||||
image: embed_buf,
|
||||
filename: filename.clone()
|
||||
})?;
|
||||
let meta = FileMetadata {
|
||||
height: frame.height(),
|
||||
width: frame.width(),
|
||||
frames: Some(i + 1)
|
||||
};
|
||||
last_metadata = Some(meta.clone());
|
||||
to_metadata_write_tx.blocking_send((filename.clone(), meta))?;
|
||||
if config.service.enable_thumbs {
|
||||
to_thumbnail_tx.blocking_send(LoadedImage {
|
||||
image: frame.clone(),
|
||||
filename,
|
||||
original_filesize: None,
|
||||
fast_thumbnails_only: true
|
||||
})?;
|
||||
}
|
||||
i += 1;
|
||||
Ok(())
|
||||
};
|
||||
match video_reader::run(&path, callback, config.service.video_frame_interval) {
|
||||
Ok(()) => {
|
||||
VIDEOS_LOADED_COUNTER.inc();
|
||||
return anyhow::Result::Ok(last_metadata)
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Could not read {} as video: {}", filename, e);
|
||||
VIDEOS_LOADED_ERROR_COUNTER.inc();
|
||||
}
|
||||
}
|
||||
return anyhow::Result::Ok(last_metadata)
|
||||
}).await??;
|
||||
if let Some(meta) = meta {
|
||||
video_meta.write().await.insert(record.filename, meta);
|
||||
}
|
||||
return Ok(())
|
||||
}
|
||||
};
|
||||
let filename = Filename::Actual(record.filename);
|
||||
if record.needs_metadata {
|
||||
let metadata = FileMetadata {
|
||||
width: image.width(),
|
||||
height: image.height(),
|
||||
frames: None
|
||||
};
|
||||
to_metadata_write_tx.send((filename.clone(), metadata)).await?;
|
||||
}
|
||||
IMAGES_LOADED_COUNTER.inc();
|
||||
if record.needs_embed {
|
||||
let resized = resize_for_embed(config.backend.clone(), image.clone()).await?;
|
||||
|
||||
to_embed_tx.send(EmbeddingInput { image: resized, filename: filename.clone() }).await?
|
||||
}
|
||||
if record.needs_thumbnail {
|
||||
to_thumbnail_tx
|
||||
.send(LoadedImage {
|
||||
image: image.clone(),
|
||||
filename: filename.clone(),
|
||||
original_filesize: Some(std::fs::metadata(&path)?.len() as usize),
|
||||
fast_thumbnails_only: false
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
if record.needs_ocr {
|
||||
to_ocr_tx
|
||||
.send(LoadedImage {
|
||||
image,
|
||||
filename: filename.clone(),
|
||||
original_filesize: None,
|
||||
fast_thumbnails_only: true
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip(video_thumb_times, pool, formats))]
|
||||
async fn generate_thumbnail(image: LoadedImage, config: Arc<WConfig>, video_thumb_times: Arc<RwLock<HashMap<CompactString, i64>>>, pool: SqlitePool, formats: Arc<HashMap<String, ImageFormatConfig>>) -> Result<()> {
|
||||
use image::codecs::*;
|
||||
|
||||
let filename = image.filename.clone();
|
||||
tracing::debug!("thumbnailing {:?}", filename);
|
||||
|
||||
let generated_formats = tokio::task::spawn_blocking(move || {
|
||||
let mut generated_formats = Vec::new();
|
||||
let rgb = DynamicImage::from(image.image.to_rgb8());
|
||||
for (format_name, format_config) in &*formats {
|
||||
if !format_config.is_fast && image.fast_thumbnails_only { continue }
|
||||
let resized = if format_config.target_filesize != 0 {
|
||||
let mut lb = 1;
|
||||
let mut ub = 100;
|
||||
loop {
|
||||
let quality = (lb + ub) / 2;
|
||||
let thumbnail = rgb.resize(
|
||||
format_config.target_width.min(rgb.width()),
|
||||
u32::MAX,
|
||||
FilterType::Lanczos3,
|
||||
);
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let mut csr = Cursor::new(&mut buf);
|
||||
// this is ugly but I don't actually know how to fix it (cannot factor it out due to issues with dyn Trait)
|
||||
match format_config.format {
|
||||
ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, quality)),
|
||||
ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, quality)),
|
||||
_ => unimplemented!()
|
||||
}?;
|
||||
if buf.len() > format_config.target_filesize {
|
||||
ub = quality;
|
||||
} else {
|
||||
lb = quality + 1;
|
||||
}
|
||||
if lb >= ub {
|
||||
break buf;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let thumbnail = rgb.resize(
|
||||
format_config.target_width.min(rgb.width()),
|
||||
u32::MAX,
|
||||
FilterType::Lanczos3,
|
||||
);
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let mut csr = Cursor::new(&mut buf);
|
||||
match format_config.format {
|
||||
ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, format_config.quality)),
|
||||
ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, format_config.quality)),
|
||||
ImageFormat::WebP => thumbnail.write_with_encoder(webp::WebPEncoder::new_lossless(&mut csr)),
|
||||
_ => unimplemented!()
|
||||
}?;
|
||||
buf
|
||||
};
|
||||
if resized.len() < image.original_filesize.unwrap_or(usize::MAX) {
|
||||
generated_formats.push(format_name.clone());
|
||||
let thumbnail_path = Path::new(&config.service.thumbs_path).join(
|
||||
generate_thumbnail_filename(
|
||||
&image.filename,
|
||||
format_name,
|
||||
format_config,
|
||||
),
|
||||
);
|
||||
THUMBNAILS_GENERATED_COUNTER.get_metric_with_label_values(&[format_name]).unwrap().inc();
|
||||
std::fs::write(thumbnail_path, resized)?;
|
||||
}
|
||||
}
|
||||
Ok::<Vec<String>, anyhow::Error>(generated_formats)
|
||||
}).await??;
|
||||
|
||||
IMAGES_THUMBNAILED_COUNTER.inc();
|
||||
let formats_data = rmp_serde::to_vec(&generated_formats)?;
|
||||
let ts = timestamp();
|
||||
let filename_enc = filename.encode()?;
|
||||
let mut conn = pool.acquire().await?;
|
||||
ensure_filename_record_exists(&mut conn, &filename_enc).await?;
|
||||
match filename {
|
||||
Filename::VideoFrame(container, _) => { video_thumb_times.write().await.insert(container.clone(), timestamp()); },
|
||||
_ => ()
|
||||
}
|
||||
sqlx::query!(
|
||||
"UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?",
|
||||
formats_data,
|
||||
ts,
|
||||
filename_enc
|
||||
)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
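generate_thumbnail above binary-searches the encoder quality (lb/ub over 1..=100) until the encoded thumbnail fits under target_filesize. The same idea in isolation, with a generic encode closure standing in for the AVIF/JPEG branches (hypothetical helper, not in the codebase):

// Find the highest quality whose encoded size stays at or under `target_bytes`.
// `encode` stands in for the AVIF/JPEG encoder call in generate_thumbnail.
fn fit_to_size(target_bytes: usize, mut encode: impl FnMut(u8) -> Vec<u8>) -> Vec<u8> {
    let (mut lb, mut ub) = (1u8, 100u8);
    loop {
        let quality = (lb + ub) / 2;
        let buf = encode(quality);
        if buf.len() > target_bytes {
            ub = quality;     // too big: search lower qualities
        } else {
            lb = quality + 1; // fits: try higher quality
        }
        if lb >= ub {
            break buf;        // the last attempt is returned, as in the original
        }
    }
}

Note that the original also re-resizes the image on every iteration; only the encode-and-compare step is shown here.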
#[instrument]
|
||||
async fn do_ocr(image: LoadedImage, config: Arc<WConfig>, client: Client, pool: SqlitePool) -> Result<()> {
|
||||
tracing::debug!("OCRing {:?}", image.filename);
|
||||
let scan = match scan_image(&client, &image.image).await {
|
||||
Ok(scan) => scan,
|
||||
Err(e) => {
|
||||
IMAGES_OCRED_ERROR_COUNTER.inc();
|
||||
tracing::error!("OCR failure {:?}: {}", image.filename, e);
|
||||
return Ok(())
|
||||
}
|
||||
};
|
||||
IMAGES_OCRED_COUNTER.inc();
|
||||
let ocr_text = scan
|
||||
.iter()
|
||||
.map(|segment| segment.text.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let ocr_data = rmp_serde::to_vec(&scan)?;
|
||||
let ts = timestamp();
|
||||
let filename_enc = image.filename.encode()?;
|
||||
let mut conn = pool.acquire().await?;
|
||||
ensure_filename_record_exists(&mut conn, &filename_enc).await?;
|
||||
sqlx::query!(
|
||||
"UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?",
|
||||
ocr_text,
|
||||
ocr_data,
|
||||
ts,
|
||||
filename_enc
|
||||
)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument]
|
||||
async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
|
||||
let pool = initialize_database(&config.service).await?;
|
||||
let client = Client::new();
|
||||
|
||||
|
||||
let formats = image_formats(&config.service);
|
||||
|
||||
|
||||
let (to_process_tx, to_process_rx) = mpsc::channel::<FileRecord>(100);
|
||||
let (to_embed_tx, to_embed_rx) = mpsc::channel(config.backend.batch as usize);
|
||||
let (to_thumbnail_tx, to_thumbnail_rx) = mpsc::channel(30);
|
||||
let (to_ocr_tx, to_ocr_rx) = mpsc::channel(30);
|
||||
let (to_metadata_write_tx, mut to_metadata_write_rx) = mpsc::channel::<(Filename, FileMetadata)>(100);
|
||||
|
||||
|
||||
let cpus = num_cpus::get();
|
||||
|
||||
let video_meta = Arc::new(RwLock::new(HashMap::new()));
|
||||
@ -363,102 +624,10 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
|
||||
let to_ocr_tx = to_ocr_tx.clone();
|
||||
let video_meta = video_meta.clone();
|
||||
let to_metadata_write_tx = to_metadata_write_tx.clone();
|
||||
async move {
|
||||
let path = Path::new(&config.service.files).join(&*record.filename);
|
||||
let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?)));
|
||||
let image = match image {
|
||||
Ok(image) => image,
|
||||
Err(e) => {
|
||||
log::warn!("Could not read {} as image: {}", record.filename, e);
|
||||
let filename = record.filename.clone();
|
||||
IMAGES_LOADED_ERROR_COUNTER.inc();
|
||||
let meta = tokio::task::spawn_blocking(move || -> Result<Option<FileMetadata>> {
|
||||
let mut i: u32 = 0;
|
||||
let mut last_metadata = None;
|
||||
let callback = |frame: RgbImage| {
|
||||
let frame: Arc<DynamicImage> = Arc::new(frame.into());
|
||||
let embed_buf = resize_for_embed_sync(config.backend.clone(), frame.clone())?;
|
||||
let filename = Filename::VideoFrame(filename.clone(), i);
|
||||
to_embed_tx.blocking_send(EmbeddingInput {
|
||||
image: embed_buf,
|
||||
filename: filename.clone()
|
||||
})?;
|
||||
let meta = FileMetadata {
|
||||
height: frame.height(),
|
||||
width: frame.width(),
|
||||
frames: Some(i + 1)
|
||||
};
|
||||
last_metadata = Some(meta.clone());
|
||||
to_metadata_write_tx.blocking_send((filename.clone(), meta))?;
|
||||
if config.service.enable_thumbs {
|
||||
to_thumbnail_tx.blocking_send(LoadedImage {
|
||||
image: frame.clone(),
|
||||
filename,
|
||||
original_filesize: None,
|
||||
fast_thumbnails_only: true
|
||||
})?;
|
||||
}
|
||||
i += 1;
|
||||
Ok(())
|
||||
};
|
||||
match video_reader::run(&path, callback, config.service.video_frame_interval) {
|
||||
Ok(()) => {
|
||||
VIDEOS_LOADED_COUNTER.inc();
|
||||
return anyhow::Result::Ok(last_metadata)
|
||||
},
|
||||
Err(e) => {
|
||||
log::error!("Could not read {} as video: {}", filename, e);
|
||||
VIDEOS_LOADED_ERROR_COUNTER.inc();
|
||||
}
|
||||
}
|
||||
return anyhow::Result::Ok(last_metadata)
|
||||
}).await??;
|
||||
if let Some(meta) = meta {
|
||||
video_meta.write().await.insert(record.filename, meta);
|
||||
}
|
||||
return Ok(())
|
||||
}
|
||||
};
|
||||
let filename = Filename::Actual(record.filename);
|
||||
if record.needs_metadata {
|
||||
let metadata = FileMetadata {
|
||||
width: image.width(),
|
||||
height: image.height(),
|
||||
frames: None
|
||||
};
|
||||
to_metadata_write_tx.send((filename.clone(), metadata)).await?;
|
||||
}
|
||||
IMAGES_LOADED_COUNTER.inc();
|
||||
if record.needs_embed {
|
||||
let resized = resize_for_embed(config.backend.clone(), image.clone()).await?;
|
||||
|
||||
to_embed_tx.send(EmbeddingInput { image: resized, filename: filename.clone() }).await?
|
||||
}
|
||||
if record.needs_thumbnail {
|
||||
to_thumbnail_tx
|
||||
.send(LoadedImage {
|
||||
image: image.clone(),
|
||||
filename: filename.clone(),
|
||||
original_filesize: Some(std::fs::metadata(&path)?.len() as usize),
|
||||
fast_thumbnails_only: false
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
if record.needs_ocr {
|
||||
to_ocr_tx
|
||||
.send(LoadedImage {
|
||||
image,
|
||||
filename: filename.clone(),
|
||||
original_filesize: None,
|
||||
fast_thumbnails_only: true
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
load_image(record, to_embed_tx, to_thumbnail_tx, to_ocr_tx, to_metadata_write_tx, config, video_meta)
|
||||
})
|
||||
});
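The refactor above replaces the large inline closure with a call to load_image, so each pipeline stage is now a named, #[instrument]-able function driven by the same channel-plus-stream pattern used throughout ingest_files. The skeleton of that pattern, reduced to its parts (toy payload and worker, not the real types):

use futures_util::TryStreamExt;
use tokio::sync::mpsc;
use tokio_stream::{wrappers::ReceiverStream, StreamExt};

async fn process(item: u32) -> anyhow::Result<()> {
    // stands in for load_image / generate_thumbnail / do_ocr
    tracing::debug!("processing {}", item);
    Ok(())
}

async fn run_stage(rx: mpsc::Receiver<u32>, concurrency: usize) -> anyhow::Result<()> {
    // Turn the channel into a stream and run at most `concurrency`
    // work items at once; the first error aborts the stage.
    ReceiverStream::new(rx)
        .map(Ok)
        .try_for_each_concurrent(Some(concurrency), |item| process(item))
        .await
}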
let metadata_writer: JoinHandle<Result<()>> = tokio::spawn({
|
||||
let pool = pool.clone();
|
||||
async move {
|
||||
@ -468,7 +637,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
let thumbnail_generation: Option<JoinHandle<Result<()>>> = if config.service.enable_thumbs {
|
||||
let config = config.clone();
|
||||
let pool = pool.clone();
|
||||
@ -477,145 +646,29 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
|
||||
let video_thumb_times = video_thumb_times.clone();
|
||||
Some(tokio::spawn({
|
||||
stream.try_for_each_concurrent(Some(cpus), move |image| {
|
||||
use image::codecs::*;
|
||||
|
||||
let formats = formats.clone();
|
||||
let config = config.clone();
|
||||
let pool = pool.clone();
|
||||
let video_thumb_times = video_thumb_times.clone();
|
||||
async move {
|
||||
let filename = image.filename.clone();
|
||||
log::debug!("thumbnailing {:?}", filename);
|
||||
let generated_formats = tokio::task::spawn_blocking(move || {
|
||||
let mut generated_formats = Vec::new();
|
||||
let rgb = DynamicImage::from(image.image.to_rgb8());
|
||||
for (format_name, format_config) in &*formats {
|
||||
if !format_config.is_fast && image.fast_thumbnails_only { continue }
|
||||
let resized = if format_config.target_filesize != 0 {
|
||||
let mut lb = 1;
|
||||
let mut ub = 100;
|
||||
loop {
|
||||
let quality = (lb + ub) / 2;
|
||||
let thumbnail = rgb.resize(
|
||||
format_config.target_width.min(rgb.width()),
|
||||
u32::MAX,
|
||||
FilterType::Lanczos3,
|
||||
);
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let mut csr = Cursor::new(&mut buf);
|
||||
// this is ugly but I don't actually know how to fix it (cannot factor it out due to issues with dyn Trait)
|
||||
match format_config.format {
|
||||
ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, quality)),
|
||||
ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, quality)),
|
||||
_ => unimplemented!()
|
||||
}?;
|
||||
if buf.len() > format_config.target_filesize {
|
||||
ub = quality;
|
||||
} else {
|
||||
lb = quality + 1;
|
||||
}
|
||||
if lb >= ub {
|
||||
break buf;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let thumbnail = rgb.resize(
|
||||
format_config.target_width.min(rgb.width()),
|
||||
u32::MAX,
|
||||
FilterType::Lanczos3,
|
||||
);
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let mut csr = Cursor::new(&mut buf);
|
||||
match format_config.format {
|
||||
ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, format_config.quality)),
|
||||
ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, format_config.quality)),
|
||||
ImageFormat::WebP => thumbnail.write_with_encoder(webp::WebPEncoder::new_lossless(&mut csr)),
|
||||
_ => unimplemented!()
|
||||
}?;
|
||||
buf
|
||||
};
|
||||
if resized.len() < image.original_filesize.unwrap_or(usize::MAX) {
|
||||
generated_formats.push(format_name.clone());
|
||||
let thumbnail_path = Path::new(&config.service.thumbs_path).join(
|
||||
generate_thumbnail_filename(
|
||||
&image.filename,
|
||||
format_name,
|
||||
format_config,
|
||||
),
|
||||
);
|
||||
THUMBNAILS_GENERATED_COUNTER.get_metric_with_label_values(&[format_name]).unwrap().inc();
|
||||
std::fs::write(thumbnail_path, resized)?;
|
||||
}
|
||||
}
|
||||
Ok::<Vec<String>, anyhow::Error>(generated_formats)
|
||||
}).await??;
|
||||
IMAGES_THUMBNAILED_COUNTER.inc();
|
||||
let formats_data = rmp_serde::to_vec(&generated_formats)?;
|
||||
let ts = timestamp();
|
||||
let filename_enc = filename.encode()?;
|
||||
let mut conn = pool.acquire().await?;
|
||||
ensure_filename_record_exists(&mut conn, &filename_enc).await?;
|
||||
match filename {
|
||||
Filename::VideoFrame(container, _) => { video_thumb_times.write().await.insert(container.clone(), timestamp()); },
|
||||
_ => ()
|
||||
}
|
||||
sqlx::query!(
|
||||
"UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?",
|
||||
formats_data,
|
||||
ts,
|
||||
filename_enc
|
||||
)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
generate_thumbnail(image, config, video_thumb_times, pool, formats)
|
||||
})
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
|
||||
// TODO: save OCR errors and don't retry
|
||||
let ocr: Option<JoinHandle<Result<()>>> = if config.service.enable_ocr {
|
||||
let client = client.clone();
|
||||
let pool = pool.clone();
|
||||
let config = config.clone();
|
||||
let stream = ReceiverStream::new(to_ocr_rx).map(Ok);
|
||||
Some(tokio::spawn({
|
||||
stream.try_for_each_concurrent(Some(config.service.ocr_concurrency), move |image| {
|
||||
let client = client.clone();
|
||||
let pool = pool.clone();
|
||||
async move {
|
||||
log::debug!("OCRing {:?}", image.filename);
|
||||
let scan = match scan_image(&client, &image.image).await {
|
||||
Ok(scan) => scan,
|
||||
Err(e) => {
|
||||
IMAGES_OCRED_ERROR_COUNTER.inc();
|
||||
log::error!("OCR failure {:?}: {}", image.filename, e);
|
||||
return Ok(())
|
||||
}
|
||||
};
|
||||
IMAGES_OCRED_COUNTER.inc();
|
||||
let ocr_text = scan
|
||||
.iter()
|
||||
.map(|segment| segment.text.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let ocr_data = rmp_serde::to_vec(&scan)?;
|
||||
let ts = timestamp();
|
||||
let filename_enc = image.filename.encode()?;
|
||||
let mut conn = pool.acquire().await?;
|
||||
ensure_filename_record_exists(&mut conn, &filename_enc).await?;
|
||||
sqlx::query!(
|
||||
"UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?",
|
||||
ocr_text,
|
||||
ocr_data,
|
||||
ts,
|
||||
filename_enc
|
||||
)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
let config = config.clone();
|
||||
do_ocr(image, config, client, pool)
|
||||
})
|
||||
}))
|
||||
} else {
|
||||
@ -634,45 +687,12 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
|
||||
let config = config.clone();
|
||||
let pool = pool.clone();
|
||||
let video_embed_times = video_embed_times.clone();
|
||||
async move {
|
||||
let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
|
||||
&client,
|
||||
&config.service.clip_server,
|
||||
"",
|
||||
EmbeddingRequest::Images {
|
||||
images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
|
||||
},
|
||||
).await.context("querying CLIP server")?;
|
||||
|
||||
let mut tx = pool.begin().await?;
|
||||
let ts = timestamp();
|
||||
for (i, vector) in result.into_iter().enumerate() {
|
||||
let vector = vector.into_vec();
|
||||
log::debug!("embedded {:?}", batch[i].filename);
|
||||
let encoded_filename = batch[i].filename.encode()?;
|
||||
IMAGES_EMBEDDED_COUNTER.inc();
|
||||
ensure_filename_record_exists(&mut *tx, &encoded_filename).await?;
|
||||
match &batch[i].filename {
|
||||
Filename::VideoFrame(container, _) => { video_embed_times.write().await.insert(container.clone(), timestamp()); },
|
||||
_ => ()
|
||||
}
|
||||
sqlx::query!(
|
||||
"UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?",
|
||||
ts,
|
||||
vector,
|
||||
encoded_filename
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
tx.commit().await?;
|
||||
anyhow::Result::Ok(())
|
||||
}
|
||||
handle_embedding_batch(client, config, pool, batch, video_embed_times)
|
||||
})
|
||||
});
|
||||
|
||||
|
||||
let mut actual_filenames = HashMap::new();
|
||||
|
||||
|
||||
// blocking OS calls
|
||||
tokio::task::block_in_place(|| -> anyhow::Result<()> {
|
||||
for entry in WalkDir::new(config.service.files.as_str()) {
|
||||
@ -688,7 +708,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
log::debug!("finished reading filenames");
|
||||
tracing::debug!("finished reading filenames");
|
||||
|
||||
for (filename, (_path, modtime)) in actual_filenames.iter() {
|
||||
let modtime = *modtime;
|
||||
@ -721,7 +741,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
|
||||
}
|
||||
};
|
||||
if let Some(record) = new_record {
|
||||
log::debug!("processing {}", record.filename);
|
||||
tracing::debug!("processing {}", record.filename);
|
||||
// we need to exit here to actually capture the error
|
||||
if !to_process_tx.send(record).await.is_ok() {
|
||||
break
|
||||
@ -730,20 +750,20 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
|
||||
}
|
||||
|
||||
drop(to_process_tx);
|
||||
|
||||
|
||||
embedding_generation.await?.context("generating embeddings")?;
|
||||
metadata_writer.await?.context("writing metadata")?;
|
||||
|
||||
|
||||
if let Some(thumbnail_generation) = thumbnail_generation {
|
||||
thumbnail_generation.await?.context("generating thumbnails")?;
|
||||
}
|
||||
|
||||
|
||||
if let Some(ocr) = ocr {
|
||||
ocr.await?.context("OCRing")?;
|
||||
}
|
||||
|
||||
image_loading.await?.context("loading images")?;
|
||||
|
||||
|
||||
let stored: Vec<Vec<u8>> = sqlx::query_scalar("SELECT filename FROM files").fetch_all(&pool).await?;
|
||||
let mut tx = pool.begin().await?;
|
||||
let video_meta = video_meta.read().await;
|
||||
@ -785,13 +805,14 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
log::info!("Ingest done");
|
||||
|
||||
tracing::info!("Ingest done");
|
||||
|
||||
Result::Ok(())
|
||||
}
|
||||
|
||||
const INDEX_ADD_BATCH: usize = 512;
|
||||
|
||||
#[instrument]
|
||||
async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
|
||||
let pool = initialize_database(&config.service).await?;
|
||||
|
||||
@ -846,7 +867,7 @@ async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
|
||||
} else {
|
||||
index.metadata.push(None);
|
||||
}
|
||||
|
||||
|
||||
for format_string in &formats {
|
||||
let mut found = false;
|
||||
for (i, name) in index.format_names.iter().enumerate() {
|
||||
@ -904,6 +925,7 @@ struct QueryRequest {
|
||||
include_video: bool
|
||||
}
|
||||
|
||||
#[instrument(skip(index))]
|
||||
async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult> {
|
||||
let result = index.vectors.search(&query, k as usize)?;
|
||||
|
||||
@ -940,6 +962,7 @@ async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bo
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip(config, client, index))]
|
||||
async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IIndex, req: Json<QueryRequest>) -> Result<Response<Body>> {
|
||||
let mut total_embedding = ndarray::Array::from(vec![0.0; config.backend.embedding_size]);
|
||||
|
||||
@ -973,8 +996,8 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
|
||||
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
|
||||
}
|
||||
}
|
||||
|
||||
let mut batches = vec![];
|
||||
|
||||
let mut batches = vec![];
|
||||
|
||||
if !image_batch.is_empty() {
|
||||
batches.push(
|
||||
@ -1016,12 +1039,13 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct FrontendInit {
|
||||
n_total: u64,
|
||||
predefined_embedding_names: Vec<String>
|
||||
predefined_embedding_names: Vec<String>,
|
||||
d_emb: usize
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
pretty_env_logger::init();
|
||||
console_subscriber::init();
|
||||
|
||||
let config_path = std::env::args().nth(1).expect("Missing config file path");
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_path)?)?;
|
||||
@ -1062,23 +1086,23 @@ async fn main() -> Result<()> {
|
||||
let index = index.clone();
|
||||
async move {
|
||||
loop {
|
||||
log::info!("Ingest running");
|
||||
tracing::info!("Ingest running");
|
||||
match ingest_files(config.clone()).await {
|
||||
Ok(_) => {
|
||||
match build_index(config.clone()).await {
|
||||
Ok(new_index) => {
|
||||
LAST_INDEX_SIZE.set(new_index.vectors.ntotal() as i64);
|
||||
*index.write().await = new_index;
|
||||
log::info!("Index loaded");
|
||||
tracing::info!("Index loaded");
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Index build failed: {:?}", e);
|
||||
tracing::error!("Index build failed: {:?}", e);
|
||||
ingest_done_tx.send((false, format!("{:?}", e))).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Ingest failed: {:?}", e);
|
||||
tracing::error!("Ingest failed: {:?}", e);
|
||||
ingest_done_tx.send((false, format!("{:?}", e))).unwrap();
|
||||
}
|
||||
}
|
||||
@ -1106,11 +1130,12 @@ async fn main() -> Result<()> {
|
||||
.route("/", get(|_req: ()| async move {
|
||||
Json(FrontendInit {
|
||||
n_total: index_.read().await.vectors.ntotal(),
|
||||
predefined_embedding_names: config__.predefined_embeddings.keys().cloned().collect()
|
||||
predefined_embedding_names: config__.predefined_embeddings.keys().cloned().collect(),
|
||||
d_emb: config__.backend.embedding_size
|
||||
})
|
||||
}))
|
||||
.route("/reload", post(|_req: ()| async move {
|
||||
log::info!("Requesting index reload");
|
||||
tracing::info!("Requesting index reload");
|
||||
let mut done_rx = done_tx.clone().subscribe();
|
||||
let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full
|
||||
match done_rx.recv().await {
|
||||
@ -1141,9 +1166,9 @@ async fn main() -> Result<()> {
|
||||
.layer(cors);
|
||||
|
||||
let addr = format!("0.0.0.0:{}", config_.service.port);
|
||||
log::info!("Starting server on {}", addr);
|
||||
tracing::info!("Starting server on {}", addr);
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
src/ocr.rs
@@ -9,6 +9,7 @@ use reqwest::{
use serde_json::Value;
use std::{io::Cursor, time::{SystemTime, UNIX_EPOCH}};
use serde::{Deserialize, Serialize};
use tracing::instrument;

const CALLBACK_REGEX: &str = r">AF_initDataCallback\((\{key: 'ds:1'.*?\})\);</script>";
const MAX_DIM: u32 = 1024;
@@ -45,6 +46,7 @@ fn rationalize_coords_format1(
}
}

#[instrument(skip(client, image))]
async fn scan_image_chunk(
client: &Client,
image: &[u8],
@@ -130,13 +132,14 @@ async fn scan_image_chunk(
.collect())
}

#[instrument(skip(client))]
pub async fn scan_image(client: &Client, image: &DynamicImage) -> Result<ScanResult> {
let mut result = ScanResult::new();
let (width, height) = image.dimensions();

let (width, height, image) = if width > MAX_DIM {
let height = ((height as f64) * (MAX_DIM as f64) / (width as f64)).round() as u32;
let new_image = tokio::task::block_in_place(|| image.resize_exact(MAX_DIM, height, image::imageops::FilterType::Lanczos3));
let new_image = tokio::task::block_in_place(|| image.resize_exact(MAX_DIM, height, image::imageops::FilterType::CatmullRom));
(MAX_DIM, height, std::borrow::Cow::Owned(new_image))
} else {
(width, height, std::borrow::Cow::Borrowed(image))
@@ -170,4 +173,4 @@ pub async fn scan_image(client: &Client, image: &DynamicImage) -> Result<ScanRes
}

Ok(result)
}
}
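scan_image above only allocates a new image when it actually has to downscale; the Cow lets both branches return the same type while the common case borrows the caller's image. The shape of that pattern on its own (hypothetical helper, mirroring the branch above):

use std::borrow::Cow;

// Pay for a copy only when a transformation is actually needed.
fn maybe_downscale(image: &image::DynamicImage, max_dim: u32) -> Cow<'_, image::DynamicImage> {
    if image.width() > max_dim {
        let h = ((image.height() as f64) * (max_dim as f64) / (image.width() as f64)).round() as u32;
        Cow::Owned(image.resize_exact(max_dim, h, image::imageops::FilterType::CatmullRom))
    } else {
        Cow::Borrowed(image)
    }
}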
src/reddit_dump.rs
@@ -1,15 +1,18 @@
use anyhow::{anyhow, Context, Result};
use common::resize_for_embed;
use itertools::Itertools;
use std::{collections::HashSet, fs, io::{BufReader, Cursor, BufRead, BufWriter}, time::Duration, sync::Arc, str::FromStr, path::PathBuf};
use serde::{Serialize, Deserialize};
use lazy_static::lazy_static;
use regex::{RegexSet, bytes, Regex};
use regex::{bytes, Regex, RegexSet, RegexSetBuilder};
use tokio::{sync::{mpsc::{self, Receiver}, Semaphore}, task::{JoinHandle, JoinSet}};
use tokio_stream::wrappers::ReceiverStream;
use reqwest::Client;
use futures_util::stream::{StreamExt, TryStreamExt};
use image::{DynamicImage, io::Reader as ImageReader};
use image::{DynamicImage, ImageReader};
use mimalloc::MiMalloc;
use tracing::instrument;
use prometheus::{Encoder, register_int_counter, IntCounter, register_histogram_vec, HistogramVec};

#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
@ -50,6 +53,14 @@ struct Entry {
|
||||
id: String
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
|
||||
struct OriginalImageMetadata {
|
||||
mime_type: String,
|
||||
original_file_size: usize,
|
||||
dimension: (u32, u32),
|
||||
final_url: String
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, Debug)]
|
||||
struct ProcessedEntry {
|
||||
url: String,
|
||||
@ -58,10 +69,13 @@ struct ProcessedEntry {
|
||||
subreddit: String,
|
||||
author: String,
|
||||
timestamp: u64,
|
||||
blob: Vec<u8>
|
||||
#[serde(with = "serde_bytes")]
|
||||
embedding: Vec<u8>,
|
||||
metadata: OriginalImageMetadata
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
// we do exclude galleries doing this but there don't seem to be any in the dataset
|
||||
static ref URL_IGNORE: RegexSet = RegexSet::new([
|
||||
r"//reddit\.com",
|
||||
r"\.html?",
|
||||
@ -69,16 +83,39 @@ lazy_static! {
|
||||
r"\?articleid=",
|
||||
r"\.aspx?",
|
||||
r"\.xml",
|
||||
r"//youtube\.com",
|
||||
r"/rss/",
|
||||
r"//vimeo\.com",
|
||||
r"//www\.youtube\.com",
|
||||
r"//youtu\.be",
|
||||
r"//www\.reddit\.com",
|
||||
r"//v\.redd\.it",
|
||||
r"\.gifv$",
|
||||
r"youtube\.com/user/"
|
||||
// TODO fill in more things, maybe try and collect thumbnails or something
|
||||
]).unwrap();
|
||||
static ref ACCEPTABLE_FILETYPES: HashSet<&'static [u8]> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/webp", "image/apng", "image/bmp", "image/tiff"]
|
||||
.into_iter().map(str::as_bytes).collect();
|
||||
static ref URL_MUST_CONTAIN: RegexSet = RegexSetBuilder::new([
|
||||
"jpg",
|
||||
"jpeg",
|
||||
"png",
|
||||
"webp",
|
||||
r"\.gif",
|
||||
"=gif",
|
||||
"jpeg",
|
||||
"bmp",
|
||||
"tiff",
|
||||
"avif",
|
||||
"webp",
|
||||
"imgur",
|
||||
"image",
|
||||
r"//i\.",
|
||||
"img",
|
||||
r"cdn\.",
|
||||
r"media\.",
|
||||
"/i/",
|
||||
"/media",
|
||||
r"youtu\.be",
|
||||
r"youtube\.com",
|
||||
]).case_insensitive(true).build().unwrap();
|
||||
static ref ACCEPTABLE_FILETYPES: HashSet<&'static str> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/webp", "image/apng", "image/bmp", "image/tiff"]
|
||||
.into_iter().collect();
|
||||
static ref OBJECT_HACKY_IGNORE: bytes::RegexSet = bytes::RegexSet::new([
|
||||
r#""author":"\[deleted\]""#,
|
||||
r#""promoted":true"#, // these seem to be ads which are in the data for some reason, and lack some important fields
|
||||
@ -86,11 +123,34 @@ lazy_static! {
|
||||
r"\x00" // for SOME REASON one of the JSON files contains a lot of null bytes before one particular record, so just ignore that record
|
||||
]).unwrap();
|
||||
static ref URL_REPLACEMENT_RULES: Vec<(Regex, &'static str)> = [
|
||||
(r"//imgur.com/([A-Za-z0-9]+)", r"//i.imgur.com/$1.jpg"),
|
||||
(r"^http://", r"https://")
|
||||
(r"imgur\.com/([A-Za-z0-9]+),", r"imgur.com/$1"),
|
||||
(r"//imgur\.com/([A-Za-z0-9]+)$", r"//i.imgur.com/$1.jpg"),
|
||||
(r"//www\.imgur\.com/([A-Za-z0-9]+)$", r"//i.imgur.com/$1.jpg"),
|
||||
(r"//m\.imgur\.com/([A-Za-z0-9]+)$", r"//i.imgur.com/$1.jpg"),
|
||||
(r"^http://", r"https://"),
|
||||
(r"//youtu\.be/(.*)", r"//youtube.com/watch?v=$1"),
|
||||
(r"//[a-z]+\.youtube\.com/(.*)", r"//youtube.com/$1"),
|
||||
(r"//www.youtube.com/attribution_link?.*v%3D([A-Za-z0-9_-]+).*", r"//i.ytimg.com/vi/$1/maxresdefault.jpg"), // redirect to youtube thumbnail API
|
||||
(r"//youtube.com/embed/([A-Za-z0-9_-]+)", r"//i.ytimg.com/vi/$1/maxresdefault.jpg"),
|
||||
(r"//youtube\.com/(?:.*)v=([A-Za-z0-9_-]+)(?:.*)", r"//i.ytimg.com/vi/$1/maxresdefault.jpg"),
|
||||
(r"&", "&") // this is such an intensely cursed feature of the dumps
|
||||
].into_iter().map(|(r, e)| (Regex::new(r).unwrap(), e)).collect();
|
||||
|
||||
static ref HTML_EXTRACTION_RULES: Vec<(Regex, Regex)> = [
|
||||
(r"//imgur\.com/a/[A-Za-z0-9]+", r#"<meta name="twitter:image" data-react-helmet="true" content="([^"]+)">"#),
|
||||
(r"//imgur\.com/gallery/[A-Za-z0-9]+", r#"<meta name="twitter:image" data-react-helmet="true" content="([^"]+)">"#),
|
||||
].into_iter().map(|(r, e)| (Regex::new(r).unwrap(), Regex::new(e).unwrap())).collect();
|
||||
|
||||
static ref IMAGES_FETCHED_COUNTER: IntCounter = register_int_counter!("mse_scrape_images_fetched", "images fetched").unwrap();
|
||||
static ref IMAGES_PROCESSED_COUNTER: IntCounter = register_int_counter!("mse_scrape_images_processed", "images processed").unwrap();
|
||||
static ref ENTRIES_PROCESSED_COUNTER: IntCounter = register_int_counter!("mse_scrape_entries_processed", "entries processed").unwrap();
|
||||
static ref IMAGES_FAILED_COUNTER: IntCounter = register_int_counter!("mse_scrape_images_failed", "images failed").unwrap();
|
||||
static ref IMAGE_FILESIZES_HISTOGRAM: HistogramVec = register_histogram_vec!("mse_scrape_image_filesizes", "filesizes of successfully fetched images", &["format"], prometheus::exponential_buckets(100.0, 1.5, 29).unwrap()).unwrap();
|
||||
static ref IMAGE_PIXELS_HISTOGRAM: HistogramVec = register_histogram_vec!("mse_scrape_image_pixels", "pixel count of successfully fetched images", &["format"], prometheus::exponential_buckets(100.0, 1.3, 53).unwrap()).unwrap();
|
||||
static ref HTML_EXTRACTS_COUNTER: IntCounter = register_int_counter!("mse_scrape_html_extracts", "html extraction operations").unwrap();
|
||||
}
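URL_REPLACEMENT_RULES is applied in order inside fetch_file, so one submission URL can go through several rewrites before the request is made. A worked example on hand-traced, made-up inputs, plus the rewrite loop as it appears in fetch_file below:

// Hand-traced examples of the rules above (inputs are illustrative):
//   http://imgur.com/AbC123
//     -> http://i.imgur.com/AbC123.jpg          (//imgur.com/<id> at end of URL -> direct image)
//     -> https://i.imgur.com/AbC123.jpg          (^http:// -> https://)
//   https://youtu.be/dQw4w9WgXcQ
//     -> https://youtube.com/watch?v=dQw4w9WgXcQ (//youtu.be/<id> -> watch URL)
//     -> https://i.ytimg.com/vi/dQw4w9WgXcQ/maxresdefault.jpg (watch URL -> thumbnail API)
fn rewrite(mut url: String) -> String {
    for (regex, replacement) in URL_REPLACEMENT_RULES.iter() {
        url = regex.replace(&url, *replacement).to_string();
    }
    url
}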
#[instrument(skip(tx))]
|
||||
fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Option<u64>) -> Result<()> {
|
||||
let mut stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
|
||||
stream.window_log_max(31)?;
|
||||
@ -105,15 +165,16 @@ fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Opt
|
||||
buf.clear();
|
||||
continue;
|
||||
}
|
||||
ENTRIES_PROCESSED_COUNTER.inc();
|
||||
let entry = match sonic_rs::serde::from_slice::<Entry>(buf.as_slice()) {
|
||||
Ok(x) => x,
|
||||
Err(e) => {
|
||||
log::warn!("parse failed, please validate {:?} {:?}", e, String::from_utf8_lossy(&buf));
|
||||
tracing::warn!("parse failed, please validate {:?} {:?}", e, String::from_utf8_lossy(&buf));
|
||||
return Ok(())
|
||||
}
|
||||
};
|
||||
if entry.selftext.is_empty() && !entry.over_18 && entry.author.is_some() && entry.subreddit.is_some() {
|
||||
if !URL_IGNORE.is_match(&entry.url) {
|
||||
if !URL_IGNORE.is_match(&entry.url) && URL_MUST_CONTAIN.is_match(&entry.url) {
|
||||
match &entry.post_hint {
|
||||
Some(x) if x == "na" || x == "image" => {
|
||||
// Technically this is slightly wrong because we reorder images slightly, but as long as it is not restarted all the time this is "fine".
|
||||
@ -127,7 +188,7 @@ fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Opt
|
||||
},
|
||||
None => true
|
||||
};
|
||||
|
||||
|
||||
if after_threshold { tx.blocking_send(entry)?; }
|
||||
},
|
||||
_ => ()
|
||||
@ -139,23 +200,38 @@ fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Opt
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Config {
|
||||
max_content_length: usize,
|
||||
input: String,
|
||||
output: String,
|
||||
backend: String,
|
||||
mode: OperatingMode,
|
||||
filename_threshold: Option<String>
|
||||
filename_threshold: Option<String>,
|
||||
metrics_addr: String,
|
||||
contact_info: String
|
||||
}
|
||||
|
||||
async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) -> Result<Vec<u8>> {
|
||||
// inelegant but I can't get it to work using Cows
|
||||
#[instrument(skip(client, config))]
|
||||
#[async_recursion::async_recursion]
|
||||
async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) -> Result<(Vec<u8>, String, String)> {
|
||||
let mut url = url.to_string();
|
||||
for (regex, replacement) in URL_REPLACEMENT_RULES.iter() {
|
||||
url = regex.replace(&url, *replacement).to_string();
|
||||
}
|
||||
|
||||
let mut html_extract_rule = None;
|
||||
|
||||
for (url_rule, extract_rule) in HTML_EXTRACTION_RULES.iter() {
|
||||
if url_rule.is_match(&url) {
|
||||
html_extract_rule = Some(extract_rule);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut response = client.get(&*url).send().await?;
|
||||
if !ACCEPTABLE_FILETYPES.contains(response.headers().get(reqwest::header::CONTENT_TYPE).context("no contept type")?.as_bytes()) {
|
||||
let content_type = std::str::from_utf8(&response.headers().get(reqwest::header::CONTENT_TYPE).context("no content type")?.as_bytes())?.to_owned();
|
||||
if !(ACCEPTABLE_FILETYPES.contains(&content_type[..]) || (html_extract_rule.is_some() && content_type == "text/html")) {
|
||||
return Err(anyhow!("invalid Content-Type"));
|
||||
}
|
||||
match response.content_length() {
|
||||
@ -169,11 +245,24 @@ async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) ->
|
||||
return Err(anyhow!("response too large"));
|
||||
}
|
||||
}
|
||||
Ok(buffer)
|
||||
if let Some(extract_rule) = html_extract_rule {
|
||||
if content_type == "text/html" {
|
||||
let buffer = String::from_utf8_lossy(&buffer).to_string();
|
||||
if let Some(mat) = extract_rule.captures(&buffer) {
|
||||
let new_url = mat.get(1).unwrap().as_str();
|
||||
HTML_EXTRACTS_COUNTER.inc();
|
||||
tracing::debug!("found new URL: {}", new_url);
|
||||
return fetch_file(client, config, new_url).await;
|
||||
} else {
|
||||
return Err(anyhow!("no extraction match"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok((buffer, content_type, response.url().to_string()))
|
||||
}
|
||||
|
||||
fn write_output(config: Arc<Config>, mut rx: Receiver<ProcessedEntry>) -> Result<()> {
|
||||
let mut out = fs::File::options().append(true).open(&config.output)?;
|
||||
let mut out = fs::File::options().create(true).append(true).open(&config.output)?;
|
||||
let stream = zstd::Encoder::new(&mut out, 15)?.auto_finish();
|
||||
let mut buf_stream = BufWriter::new(stream);
|
||||
while let Some(x) = rx.blocking_recv() {
|
||||
@ -182,12 +271,14 @@ fn write_output(config: Arc<Config>, mut rx: Receiver<ProcessedEntry>) -> Result
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum OperatingMode {
|
||||
Count,
|
||||
Sample(f32),
|
||||
FullRun
|
||||
}
|
||||
|
||||
#[instrument]
|
||||
fn readback_output(path: &str) -> Result<(u64, usize)> {
|
||||
use rmp_serde::decode::Error;
|
||||
let stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
|
||||
@ -208,27 +299,47 @@ fn readback_output(path: &str) -> Result<(u64, usize)> {
|
||||
Ok((latest_timestamp, count))
|
||||
}
|
||||
|
||||
async fn serve_metrics(config: Arc<Config>) -> Result<()> {
|
||||
let metrics = axum::Router::new().route("/metrics", axum::routing::get(|| async move {
|
||||
let mut buffer = Vec::new();
|
||||
let encoder = prometheus::TextEncoder::new();
|
||||
let metric_families = prometheus::gather();
|
||||
encoder.encode(&metric_families, &mut buffer).unwrap();
|
||||
buffer
|
||||
}));
|
||||
let listener = tokio::net::TcpListener::bind(&config.metrics_addr).await?;
|
||||
tokio::task::spawn(async move {
|
||||
let _ = axum::serve(listener, metrics).await;
|
||||
});
|
||||
Ok(())
|
||||
}
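serve_metrics above exposes the default prometheus registry over HTTP on config.metrics_addr (0.0.0.0:9914 in the config below), using the standard text exposition format. A scrape of /metrics would look roughly like this for the counters registered above (values are illustrative):

# HELP mse_scrape_images_fetched images fetched
# TYPE mse_scrape_images_fetched counter
mse_scrape_images_fetched 41234
# HELP mse_scrape_images_failed images failed
# TYPE mse_scrape_images_failed counter
mse_scrape_images_failed 1523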
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
pretty_env_logger::init();
|
||||
console_subscriber::init();
|
||||
|
||||
let cpus = num_cpus::get();
|
||||
|
||||
let config = Arc::new(Config {
|
||||
max_content_length: 1<<23,
|
||||
input: String::from("./submissions"),
|
||||
max_content_length: 1<<24,
|
||||
input: String::from("./reddit_subs_202212/"),
|
||||
output: String::from("./sample.zst"),
|
||||
backend: String::from("http://localhost:1708"),
|
||||
mode: OperatingMode::Sample(0.004),
|
||||
filename_threshold: None
|
||||
mode: OperatingMode::FullRun,
|
||||
filename_threshold: None,
|
||||
metrics_addr: String::from("0.0.0.0:9914"),
|
||||
contact_info: String::from("scraping-ops@osmarks.net")
|
||||
});
|
||||
|
||||
serve_metrics(config.clone()).await?;
|
||||
|
||||
let timestamp_threshold = match config.mode {
|
||||
OperatingMode::Count => None,
|
||||
_ => {
|
||||
match readback_output(&config.output) {
|
||||
Ok(x) => Some(x),
|
||||
Err(e) => {
|
||||
log::warn!("could not read output: {}", e);
|
||||
tracing::warn!("could not read output: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
@ -237,19 +348,19 @@ async fn main() -> Result<()> {
|
||||
|
||||
|
||||
if let Some((threshold, count)) = timestamp_threshold {
|
||||
log::info!("threshold is {}, {} items", threshold, count);
|
||||
tracing::info!("threshold is {}, {} items", threshold, count);
|
||||
}
|
||||
|
||||
|
||||
let backend = get_backend_config(&config.backend).await;
|
||||
|
||||
log::info!("connected to inference server");
|
||||
tracing::info!("connected to inference server");
|
||||
|
||||
let (entries_tx, mut entries_rx) = mpsc::channel::<Entry>(32768);
|
||||
let (buffers_tx, buffers_rx) = mpsc::channel(128);
|
||||
let (resized_tx, resized_rx) = mpsc::channel(backend.batch);
|
||||
let (final_write_tx, final_write_rx) = mpsc::channel::<ProcessedEntry>(32768);
|
||||
let client = Client::builder()
|
||||
.user_agent(concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")))
|
||||
.user_agent(format!("{}/{} (contact {})", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"), config.contact_info))
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
@ -278,11 +389,13 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
match fetch_file(client, config.clone(), &entry.url).await {
|
||||
Ok(buf) => {
|
||||
log::debug!("got {}", &entry.url);
|
||||
IMAGES_FETCHED_COUNTER.inc();
|
||||
tracing::debug!("got {}", &entry.url);
|
||||
buffers_tx.send((entry, buf)).await?;
|
||||
},
|
||||
Err(e) => {
|
||||
log::warn!("{} failed: {}", &entry.url, e)
|
||||
IMAGES_FAILED_COUNTER.inc();
|
||||
tracing::debug!("{} failed: {}", &entry.url, e)
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
@ -296,8 +409,10 @@ async fn main() -> Result<()> {
|
||||
_ => Some(tokio::task::spawn({
|
||||
let stream = ReceiverStream::new(buffers_rx);
|
||||
let backend = backend.clone();
|
||||
stream.map(Ok).try_for_each_concurrent(Some(cpus), move |(entry, buffer)| {
|
||||
stream.map(Ok).try_for_each_concurrent(Some(cpus), move |(entry, (buffer, mime_type, final_url))| {
|
||||
let backend = backend.clone();
|
||||
let size = buffer.len();
|
||||
IMAGE_FILESIZES_HISTOGRAM.with_label_values(&[&mime_type]).observe(size as f64);
|
||||
let resized_tx = resized_tx.clone();
|
||||
async move {
|
||||
let image_result = tokio::task::spawn_blocking(|| {
|
||||
@ -308,12 +423,20 @@ async fn main() -> Result<()> {
|
||||
let image = match image_result {
|
||||
Ok(image) => image,
|
||||
Err(e) => {
|
||||
log::warn!("loading {} failed: {}", entry.url, e);
|
||||
tracing::debug!("loading {} failed: {}", entry.url, e);
|
||||
return Result::<(), anyhow::Error>::Ok(());
|
||||
}
|
||||
};
|
||||
let dim = (image.width(), image.height());
|
||||
IMAGE_PIXELS_HISTOGRAM.with_label_values(&[&mime_type]).observe(dim.0 as f64 * dim.1 as f64);
|
||||
let metadata = OriginalImageMetadata {
|
||||
mime_type,
|
||||
original_file_size: size,
|
||||
dimension: dim,
|
||||
final_url
|
||||
};
|
||||
let resized = resize_for_embed(backend.clone(), image).await?;
|
||||
resized_tx.send((entry, resized)).await?;
|
||||
resized_tx.send((entry, resized, metadata)).await?;
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
@ -328,7 +451,7 @@ async fn main() -> Result<()> {
|
||||
let config = config.clone();
|
||||
// keep multiple embedding requests in flight
|
||||
stream.map(Ok).try_for_each_concurrent(Some(3), move |batch| {
|
||||
let (entries, bytes): (Vec<Entry>, Vec<Vec<u8>>) = batch.into_iter().unzip();
|
||||
let (entries, bytes, batch_dimensions): (Vec<Entry>, Vec<Vec<u8>>, Vec<OriginalImageMetadata>) = batch.into_iter().multiunzip();
|
||||
let client = client.clone();
|
||||
let config = config.clone();
|
||||
let final_write_tx = final_write_tx.clone();
|
||||
@ -341,17 +464,20 @@ async fn main() -> Result<()> {
|
||||
images: bytes.into_iter().map(serde_bytes::ByteBuf::from).collect(),
|
||||
},
|
||||
).await.context("querying CLIP server")?;
|
||||
|
||||
for (vector, entry) in result.into_iter().zip(entries) {
|
||||
|
||||
for (vector, entry, metadata) in itertools::izip!(result.into_iter(), entries, batch_dimensions) {
|
||||
final_write_tx.send(ProcessedEntry {
|
||||
url: entry.url,
|
||||
id: entry.id,
|
||||
title: entry.title,
|
||||
subreddit: entry.subreddit.unwrap(),
|
||||
author: entry.author.unwrap(),
|
||||
blob: vector.into_vec(),
|
||||
timestamp: entry.created_utc.to_u64()?
|
||||
embedding: vector.into_vec(),
|
||||
timestamp: entry.created_utc.to_u64()?,
|
||||
metadata
|
||||
}).await?;
|
||||
IMAGES_PROCESSED_COUNTER.inc();
|
||||
}
|
||||
anyhow::Result::Ok(())
|
||||
}
|
||||
@ -365,7 +491,7 @@ async fn main() -> Result<()> {
|
||||
_ => None
|
||||
};
|
||||
|
||||
log::info!("working...");
|
||||
tracing::info!("working...");
|
||||
|
||||
let mut paths = vec![];
|
||||
for file in fs::read_dir(&config.input)? {
|
||||
@ -381,36 +507,26 @@ async fn main() -> Result<()> {
|
||||
|
||||
let mut file_readers = JoinSet::new();
|
||||
|
||||
match config.mode {
|
||||
OperatingMode::Count | OperatingMode::Sample(_) => {
|
||||
let semaphore = Arc::new(Semaphore::new(cpus));
|
||||
let readers = match config.mode {
|
||||
OperatingMode::Count | OperatingMode::Sample(_) => cpus,
|
||||
OperatingMode::FullRun => 1
|
||||
};
|
||||
|
||||
for path in paths {
|
||||
let semaphore = semaphore.clone();
|
||||
let permit = semaphore.acquire_owned().await?;
|
||||
let entries_tx = entries_tx.clone();
|
||||
let path_ = path.clone();
|
||||
log::info!("reading {:?}", path);
|
||||
file_readers.spawn_blocking(move || {
|
||||
match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
|
||||
Ok(_) => (),
|
||||
Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
|
||||
}
|
||||
std::mem::drop(permit);
|
||||
});
|
||||
let semaphore = Arc::new(Semaphore::new(readers));
|
||||
|
||||
for path in paths {
|
||||
let semaphore = semaphore.clone();
|
||||
let permit = semaphore.acquire_owned().await?;
|
||||
let entries_tx = entries_tx.clone();
|
||||
let path_ = path.clone();
|
||||
tracing::info!("reading {:?}", path);
|
||||
file_readers.spawn_blocking(move || {
|
||||
match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
|
||||
Ok(_) => (),
|
||||
Err(e) => tracing::error!("could not parse {:?} {:?}", &path, e)
|
||||
}
|
||||
},
|
||||
OperatingMode::FullRun => {
|
||||
for path in paths {
|
||||
let entries_tx = entries_tx.clone();
|
||||
let path_ = path.clone();
|
||||
log::info!("reading {:?}", path);
|
||||
file_readers.spawn_blocking(move || match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
|
||||
Ok(_) => (),
|
||||
Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
|
||||
});
|
||||
}
|
||||
}
|
||||
std::mem::drop(permit);
|
||||
});
|
||||
}
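The reworked reader loop above covers all modes with one semaphore: readers is cpus for Count/Sample but 1 for FullRun, presumably so a full run decompresses dump files one at a time and records stay roughly in timestamp order for the resume threshold. The core of that permit-per-blocking-task pattern in isolation (toy work function; the real code feeds process_file):

use std::sync::Arc;
use tokio::{sync::Semaphore, task::JoinSet};

async fn read_all(paths: Vec<std::path::PathBuf>, readers: usize) -> anyhow::Result<()> {
    let semaphore = Arc::new(Semaphore::new(readers));
    let mut tasks = JoinSet::new();
    for path in paths {
        // Waits until one of the `readers` permits is free, so at most
        // `readers` blocking decode tasks run at once.
        let permit = semaphore.clone().acquire_owned().await?;
        tasks.spawn_blocking(move || {
            // stands in for process_file(path, entries_tx, ...)
            println!("reading {:?}", path);
            std::mem::drop(permit); // release the slot when this file is done
        });
    }
    while tasks.join_next().await.is_some() {}
    Ok(())
}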
while let Some(x) = file_readers.try_join_next() {
|
||||
@ -419,9 +535,9 @@ async fn main() -> Result<()> {
|
||||
|
||||
std::mem::drop(entries_tx);
|
||||
println!("{:?}", load_task.await?);
|
||||
if let Some(task) = resize_task { println!("{:?}", task.await?); }
|
||||
if let Some(task) = embedding_generation_task { println!("{:?}", task.await?) };
|
||||
if let Some(task) = output_writer_task { println!("{:?}", task.await?) };
|
||||
if let Some(task) = resize_task { println!("resize: {:?}", task.await?); }
|
||||
if let Some(task) = embedding_generation_task { println!("embedding: {:?}", task.await?) };
|
||||
if let Some(task) = output_writer_task { println!("output: {:?}", task.await?) };
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|