mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-11-13 07:19:54 +00:00

improve observability and fix up Reddit dump for full-scale run

osmarks 2024-11-02 19:38:05 +00:00
parent 1d0ff95955
commit 7fa14d45ae
7 changed files with 1394 additions and 891 deletions

Cargo.lock (generated): diff suppressed because it is too large (1359 changed lines).


@ -6,14 +6,13 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = { version = "1", features = ["full"] }
tokio = { version = "1", features = ["full", "tracing"] }
axum = "0.7"
image = { version = "0.25", features = ["avif", "avif-native", "nasm"] }
reqwest = { version = "0.12", features = ["multipart"] }
serde = { version = "1", features = ["derive"] }
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] }
walkdir = "1"
log = "0.4"
rmp-serde = "1"
serde_json = "1"
chrono = "0.4"
@ -24,7 +23,8 @@ faiss = "0.12"
ndarray = "0.15"
half = { version = "2" }
regex = "1"
pretty_env_logger = "0.5"
tracing = "0.1"
console-subscriber = "0.4"
futures-util = "0.3"
tokio-stream = "0.1"
num_cpus = "1"
@ -41,9 +41,8 @@ mimalloc = "0.1"
sonic-rs = "0.3"
ffmpeg-the-third = "2.0"
compact_str = { version = "0.8.0-beta", features = ["serde"] }
[patch.crates-io]
image = { git = "https://github.com/fintelia/image/", branch = "upgrade-zune-jpeg" }
itertools = "0.13"
async-recursion = "1"
[[bin]]
name = "reddit-dump"
@ -52,3 +51,7 @@ path = "src/reddit_dump.rs"
[[bin]]
name = "video-reader"
path = "src/video_reader.rs"
[[bin]]
name = "dump-processor"
path = "src/dump_processor.rs"


@ -4,6 +4,7 @@ use image::{DynamicImage, imageops::FilterType, ImageFormat};
use anyhow::Result;
use std::io::Cursor;
use reqwest::Client;
use tracing::instrument;
#[derive(Debug, Deserialize, Clone)]
pub struct InferenceServerConfig {
@ -13,11 +14,13 @@ pub struct InferenceServerConfig {
}
pub fn resize_for_embed_sync<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
let new = image.borrow().resize(
// the model currently in use wants aspect ratio 1:1 regardless of input
// I think this was previously being handled in the CLIP server but that is slightly lossy
let new = image.borrow().resize_exact(
config.image_size.0,
config.image_size.1,
FilterType::Lanczos3
);
FilterType::CatmullRom
).into_rgb8();
let mut buf = Vec::new();
let mut csr = Cursor::new(&mut buf);
new.write_to(&mut csr, ImageFormat::Png)?;
@ -46,13 +49,14 @@ pub async fn get_backend_config(clip_server: &str) -> InferenceServerConfig {
match fetch_backend_config(&clip_server).await {
Ok(backend) => break backend,
Err(e) => {
log::error!("Backend failed (fetch): {}", e);
tracing::error!("Backend failed (fetch): {}", e);
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
}
}
#[instrument(skip(client, data))]
pub async fn query_clip_server<I, O>(client: &Client, base_url: &str, path: &str, data: I) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned,
{
let response = client

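The resize change in resize_for_embed_sync is behavioural, not just cosmetic: DynamicImage::resize fits the image inside the target box and keeps its aspect ratio, while resize_exact stretches to exactly the requested dimensions, matching a model that wants 1:1 input; CatmullRom is also typically cheaper than Lanczos3. A small illustrative sketch (the 224x224 side length is a hypothetical example, not taken from the config):

use image::{DynamicImage, RgbImage, imageops::FilterType};

fn to_model_input(img: &DynamicImage, side: u32) -> RgbImage {
    // resize() on an 800x600 input with side = 224 would give 224x168 (aspect preserved);
    // resize_exact() always gives side x side, stretching if necessary.
    img.resize_exact(side, side, FilterType::CatmullRom).into_rgb8()
}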
src/dump_processor.rs (new file, 53 lines)

@ -0,0 +1,53 @@
use anyhow::{Result, Context};
use serde::{Serialize, Deserialize};
use std::io::BufReader;
use rmp_serde::decode::Error as DecodeError;
use std::fs;
// TODO refactor
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
struct OriginalImageMetadata {
mime_type: String,
original_file_size: usize,
dimension: (u32, u32),
final_url: String
}
#[derive(Clone, Deserialize, Serialize, Debug)]
struct ProcessedEntry {
url: String,
id: String,
title: String,
subreddit: String,
author: String,
timestamp: u64,
#[serde(with = "serde_bytes")]
embedding: Vec<u8>,
metadata: OriginalImageMetadata
}
fn main() -> Result<()> {
let path = std::env::args().nth(1).context("missing path")?;
let stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
let mut stream = BufReader::new(stream);
let mut latest_timestamp = 0;
let mut count = 0;
loop {
let res: Result<ProcessedEntry, DecodeError> = rmp_serde::from_read(&mut stream);
if res.is_ok() {
count += 1;
}
match res {
Ok(x) => {
if x.timestamp > latest_timestamp {
println!("{} {} https://reddit.com/r/{}/comments/{}", x.timestamp, count, x.subreddit, x.id);
latest_timestamp = x.timestamp;
}
},
Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(e).context("decode fail")
}
}
println!("{} {}", latest_timestamp, count);
Ok(())
}
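dump-processor is a small standalone reader for the scrape output: it streams the zstd-compressed MessagePack dump, prints a progress line whenever the record timestamp advances, and ends with the latest timestamp and total record count. It is invoked with the dump path as the first argument, something like: cargo run --release --bin dump-processor -- ./sample.zst (sample.zst being the default output path configured in reddit_dump.rs).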


@ -33,6 +33,7 @@ use faiss::index::scalar_quantizer;
use lazy_static::lazy_static;
use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
use ndarray::ArrayBase;
use tracing::instrument;
mod ocr;
mod common;
@ -249,7 +250,7 @@ async fn initialize_database(config: &Config) -> Result<SqlitePool> {
if (index as i32) < version {
continue
}
log::info!("Migrating to DB version {}", index);
tracing::info!("Migrating to DB version {}", index);
sqlx::query(sql).execute(&mut *tx).await?;
sqlx::query(&format!("PRAGMA user_version = {}", index + 1)).execute(&mut *tx).await?;
}
@ -317,6 +318,7 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> {
formats
}
#[instrument(skip_all)]
async fn ensure_filename_record_exists(conn: &mut SqliteConnection, filename_enc: &Vec<u8>) -> Result<()> {
sqlx::query!("INSERT OR IGNORE INTO files (filename) VALUES (?)", filename_enc)
.execute(conn)
@ -324,6 +326,7 @@ async fn ensure_filename_record_exists(conn: &mut SqliteConnection, filename_enc
Ok(())
}
#[instrument(skip_all)]
async fn write_metadata(conn: &mut SqliteConnection, filename_enc: &Vec<u8>, metadata: FileMetadata) -> Result<()> {
ensure_filename_record_exists(conn, filename_enc).await?;
let metadata_serialized = rmp_serde::to_vec_named(&metadata)?;
@ -333,18 +336,276 @@ async fn write_metadata(conn: &mut SqliteConnection, filename_enc: &Vec<u8>, met
Ok(())
}
#[instrument]
async fn handle_embedding_batch(client: reqwest::Client, config: Arc<WConfig>, pool: SqlitePool, batch: Vec<EmbeddingInput>, video_embed_times: Arc<RwLock<HashMap<CompactString, i64>>>) -> Result<()> {
let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
&client,
&config.service.clip_server,
"",
EmbeddingRequest::Images {
images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
},
).await.context("querying CLIP server")?;
let mut tx = pool.begin().await?;
let ts = timestamp();
for (i, vector) in result.into_iter().enumerate() {
let vector = vector.into_vec();
tracing::debug!("embedded {:?}", batch[i].filename);
let encoded_filename = batch[i].filename.encode()?;
IMAGES_EMBEDDED_COUNTER.inc();
ensure_filename_record_exists(&mut *tx, &encoded_filename).await?;
match &batch[i].filename {
Filename::VideoFrame(container, _) => { video_embed_times.write().await.insert(container.clone(), timestamp()); },
_ => ()
}
sqlx::query!(
"UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?",
ts,
vector,
encoded_filename
)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
anyhow::Result::Ok(())
}
#[instrument(skip(to_embed_tx, to_thumbnail_tx, to_ocr_tx, to_metadata_write_tx, video_meta))]
async fn load_image(record: FileRecord, to_embed_tx: mpsc::Sender<EmbeddingInput>, to_thumbnail_tx: mpsc::Sender<LoadedImage>, to_ocr_tx: mpsc::Sender<LoadedImage>, to_metadata_write_tx: mpsc::Sender<(Filename, FileMetadata)>, config: Arc<WConfig>, video_meta: Arc<RwLock<HashMap<CompactString, FileMetadata>>>) -> Result<()> {
let path = Path::new(&config.service.files).join(&*record.filename);
let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?)));
let image = match image {
Ok(image) => image,
Err(e) => {
tracing::warn!("Could not read {} as image: {}", record.filename, e);
let filename = record.filename.clone();
IMAGES_LOADED_ERROR_COUNTER.inc();
let meta = tokio::task::spawn_blocking(move || -> Result<Option<FileMetadata>> {
let mut i: u32 = 0;
let mut last_metadata = None;
let callback = |frame: RgbImage| {
let frame: Arc<DynamicImage> = Arc::new(frame.into());
let embed_buf = resize_for_embed_sync(config.backend.clone(), frame.clone())?;
let filename = Filename::VideoFrame(filename.clone(), i);
to_embed_tx.blocking_send(EmbeddingInput {
image: embed_buf,
filename: filename.clone()
})?;
let meta = FileMetadata {
height: frame.height(),
width: frame.width(),
frames: Some(i + 1)
};
last_metadata = Some(meta.clone());
to_metadata_write_tx.blocking_send((filename.clone(), meta))?;
if config.service.enable_thumbs {
to_thumbnail_tx.blocking_send(LoadedImage {
image: frame.clone(),
filename,
original_filesize: None,
fast_thumbnails_only: true
})?;
}
i += 1;
Ok(())
};
match video_reader::run(&path, callback, config.service.video_frame_interval) {
Ok(()) => {
VIDEOS_LOADED_COUNTER.inc();
return anyhow::Result::Ok(last_metadata)
},
Err(e) => {
tracing::error!("Could not read {} as video: {}", filename, e);
VIDEOS_LOADED_ERROR_COUNTER.inc();
}
}
return anyhow::Result::Ok(last_metadata)
}).await??;
if let Some(meta) = meta {
video_meta.write().await.insert(record.filename, meta);
}
return Ok(())
}
};
let filename = Filename::Actual(record.filename);
if record.needs_metadata {
let metadata = FileMetadata {
width: image.width(),
height: image.height(),
frames: None
};
to_metadata_write_tx.send((filename.clone(), metadata)).await?;
}
IMAGES_LOADED_COUNTER.inc();
if record.needs_embed {
let resized = resize_for_embed(config.backend.clone(), image.clone()).await?;
to_embed_tx.send(EmbeddingInput { image: resized, filename: filename.clone() }).await?
}
if record.needs_thumbnail {
to_thumbnail_tx
.send(LoadedImage {
image: image.clone(),
filename: filename.clone(),
original_filesize: Some(std::fs::metadata(&path)?.len() as usize),
fast_thumbnails_only: false
})
.await?;
}
if record.needs_ocr {
to_ocr_tx
.send(LoadedImage {
image,
filename: filename.clone(),
original_filesize: None,
fast_thumbnails_only: true
})
.await?;
}
Ok(())
}
#[instrument(skip(video_thumb_times, pool, formats))]
async fn generate_thumbnail(image: LoadedImage, config: Arc<WConfig>, video_thumb_times: Arc<RwLock<HashMap<CompactString, i64>>>, pool: SqlitePool, formats: Arc<HashMap<String, ImageFormatConfig>>) -> Result<()> {
use image::codecs::*;
let filename = image.filename.clone();
tracing::debug!("thumbnailing {:?}", filename);
let generated_formats = tokio::task::spawn_blocking(move || {
let mut generated_formats = Vec::new();
let rgb = DynamicImage::from(image.image.to_rgb8());
for (format_name, format_config) in &*formats {
if !format_config.is_fast && image.fast_thumbnails_only { continue }
let resized = if format_config.target_filesize != 0 {
let mut lb = 1;
let mut ub = 100;
loop {
let quality = (lb + ub) / 2;
let thumbnail = rgb.resize(
format_config.target_width.min(rgb.width()),
u32::MAX,
FilterType::Lanczos3,
);
let mut buf: Vec<u8> = Vec::new();
let mut csr = Cursor::new(&mut buf);
// this is ugly but I don't actually know how to fix it (cannot factor it out due to issues with dyn Trait)
match format_config.format {
ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, quality)),
ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, quality)),
_ => unimplemented!()
}?;
if buf.len() > format_config.target_filesize {
ub = quality;
} else {
lb = quality + 1;
}
if lb >= ub {
break buf;
}
}
} else {
let thumbnail = rgb.resize(
format_config.target_width.min(rgb.width()),
u32::MAX,
FilterType::Lanczos3,
);
let mut buf: Vec<u8> = Vec::new();
let mut csr = Cursor::new(&mut buf);
match format_config.format {
ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, format_config.quality)),
ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, format_config.quality)),
ImageFormat::WebP => thumbnail.write_with_encoder(webp::WebPEncoder::new_lossless(&mut csr)),
_ => unimplemented!()
}?;
buf
};
if resized.len() < image.original_filesize.unwrap_or(usize::MAX) {
generated_formats.push(format_name.clone());
let thumbnail_path = Path::new(&config.service.thumbs_path).join(
generate_thumbnail_filename(
&image.filename,
format_name,
format_config,
),
);
THUMBNAILS_GENERATED_COUNTER.get_metric_with_label_values(&[format_name]).unwrap().inc();
std::fs::write(thumbnail_path, resized)?;
}
}
Ok::<Vec<String>, anyhow::Error>(generated_formats)
}).await??;
IMAGES_THUMBNAILED_COUNTER.inc();
let formats_data = rmp_serde::to_vec(&generated_formats)?;
let ts = timestamp();
let filename_enc = filename.encode()?;
let mut conn = pool.acquire().await?;
ensure_filename_record_exists(&mut conn, &filename_enc).await?;
match filename {
Filename::VideoFrame(container, _) => { video_thumb_times.write().await.insert(container.clone(), timestamp()); },
_ => ()
}
sqlx::query!(
"UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?",
formats_data,
ts,
filename_enc
)
.execute(&mut *conn)
.await?;
Ok(())
}
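// A minimal sketch of the quality search in generate_thumbnail above
// (illustrative, not part of the commit): it is a bisection over the encoder
// quality parameter, assuming encoded size is monotonically non-decreasing in
// quality, and it converges on the highest quality whose output still fits
// under target_filesize.
fn pick_quality(mut encoded_len: impl FnMut(u8) -> usize, target_filesize: usize) -> u8 {
    let (mut lb, mut ub) = (1u8, 100u8);
    loop {
        let quality = (lb + ub) / 2;
        if encoded_len(quality) > target_filesize { ub = quality } else { lb = quality + 1 }
        if lb >= ub { return quality }
    }
}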
#[instrument]
async fn do_ocr(image: LoadedImage, config: Arc<WConfig>, client: Client, pool: SqlitePool) -> Result<()> {
tracing::debug!("OCRing {:?}", image.filename);
let scan = match scan_image(&client, &image.image).await {
Ok(scan) => scan,
Err(e) => {
IMAGES_OCRED_ERROR_COUNTER.inc();
tracing::error!("OCR failure {:?}: {}", image.filename, e);
return Ok(())
}
};
IMAGES_OCRED_COUNTER.inc();
let ocr_text = scan
.iter()
.map(|segment| segment.text.clone())
.collect::<Vec<_>>()
.join("\n");
let ocr_data = rmp_serde::to_vec(&scan)?;
let ts = timestamp();
let filename_enc = image.filename.encode()?;
let mut conn = pool.acquire().await?;
ensure_filename_record_exists(&mut conn, &filename_enc).await?;
sqlx::query!(
"UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?",
ocr_text,
ocr_data,
ts,
filename_enc
)
.execute(&mut *conn)
.await?;
Ok(())
}
#[instrument]
async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
let pool = initialize_database(&config.service).await?;
let client = Client::new();
let formats = image_formats(&config.service);
let (to_process_tx, to_process_rx) = mpsc::channel::<FileRecord>(100);
let (to_embed_tx, to_embed_rx) = mpsc::channel(config.backend.batch as usize);
let (to_thumbnail_tx, to_thumbnail_rx) = mpsc::channel(30);
let (to_ocr_tx, to_ocr_rx) = mpsc::channel(30);
let (to_metadata_write_tx, mut to_metadata_write_rx) = mpsc::channel::<(Filename, FileMetadata)>(100);
let cpus = num_cpus::get();
let video_meta = Arc::new(RwLock::new(HashMap::new()));
@ -363,102 +624,10 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
let to_ocr_tx = to_ocr_tx.clone();
let video_meta = video_meta.clone();
let to_metadata_write_tx = to_metadata_write_tx.clone();
async move {
let path = Path::new(&config.service.files).join(&*record.filename);
let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?)));
let image = match image {
Ok(image) => image,
Err(e) => {
log::warn!("Could not read {} as image: {}", record.filename, e);
let filename = record.filename.clone();
IMAGES_LOADED_ERROR_COUNTER.inc();
let meta = tokio::task::spawn_blocking(move || -> Result<Option<FileMetadata>> {
let mut i: u32 = 0;
let mut last_metadata = None;
let callback = |frame: RgbImage| {
let frame: Arc<DynamicImage> = Arc::new(frame.into());
let embed_buf = resize_for_embed_sync(config.backend.clone(), frame.clone())?;
let filename = Filename::VideoFrame(filename.clone(), i);
to_embed_tx.blocking_send(EmbeddingInput {
image: embed_buf,
filename: filename.clone()
})?;
let meta = FileMetadata {
height: frame.height(),
width: frame.width(),
frames: Some(i + 1)
};
last_metadata = Some(meta.clone());
to_metadata_write_tx.blocking_send((filename.clone(), meta))?;
if config.service.enable_thumbs {
to_thumbnail_tx.blocking_send(LoadedImage {
image: frame.clone(),
filename,
original_filesize: None,
fast_thumbnails_only: true
})?;
}
i += 1;
Ok(())
};
match video_reader::run(&path, callback, config.service.video_frame_interval) {
Ok(()) => {
VIDEOS_LOADED_COUNTER.inc();
return anyhow::Result::Ok(last_metadata)
},
Err(e) => {
log::error!("Could not read {} as video: {}", filename, e);
VIDEOS_LOADED_ERROR_COUNTER.inc();
}
}
return anyhow::Result::Ok(last_metadata)
}).await??;
if let Some(meta) = meta {
video_meta.write().await.insert(record.filename, meta);
}
return Ok(())
}
};
let filename = Filename::Actual(record.filename);
if record.needs_metadata {
let metadata = FileMetadata {
width: image.width(),
height: image.height(),
frames: None
};
to_metadata_write_tx.send((filename.clone(), metadata)).await?;
}
IMAGES_LOADED_COUNTER.inc();
if record.needs_embed {
let resized = resize_for_embed(config.backend.clone(), image.clone()).await?;
to_embed_tx.send(EmbeddingInput { image: resized, filename: filename.clone() }).await?
}
if record.needs_thumbnail {
to_thumbnail_tx
.send(LoadedImage {
image: image.clone(),
filename: filename.clone(),
original_filesize: Some(std::fs::metadata(&path)?.len() as usize),
fast_thumbnails_only: false
})
.await?;
}
if record.needs_ocr {
to_ocr_tx
.send(LoadedImage {
image,
filename: filename.clone(),
original_filesize: None,
fast_thumbnails_only: true
})
.await?;
}
Ok(())
}
load_image(record, to_embed_tx, to_thumbnail_tx, to_ocr_tx, to_metadata_write_tx, config, video_meta)
})
});
let metadata_writer: JoinHandle<Result<()>> = tokio::spawn({
let pool = pool.clone();
async move {
@ -468,7 +637,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
Ok(())
}
});
let thumbnail_generation: Option<JoinHandle<Result<()>>> = if config.service.enable_thumbs {
let config = config.clone();
let pool = pool.clone();
@ -477,145 +646,29 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
let video_thumb_times = video_thumb_times.clone();
Some(tokio::spawn({
stream.try_for_each_concurrent(Some(cpus), move |image| {
use image::codecs::*;
let formats = formats.clone();
let config = config.clone();
let pool = pool.clone();
let video_thumb_times = video_thumb_times.clone();
async move {
let filename = image.filename.clone();
log::debug!("thumbnailing {:?}", filename);
let generated_formats = tokio::task::spawn_blocking(move || {
let mut generated_formats = Vec::new();
let rgb = DynamicImage::from(image.image.to_rgb8());
for (format_name, format_config) in &*formats {
if !format_config.is_fast && image.fast_thumbnails_only { continue }
let resized = if format_config.target_filesize != 0 {
let mut lb = 1;
let mut ub = 100;
loop {
let quality = (lb + ub) / 2;
let thumbnail = rgb.resize(
format_config.target_width.min(rgb.width()),
u32::MAX,
FilterType::Lanczos3,
);
let mut buf: Vec<u8> = Vec::new();
let mut csr = Cursor::new(&mut buf);
// this is ugly but I don't actually know how to fix it (cannot factor it out due to issues with dyn Trait)
match format_config.format {
ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, quality)),
ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, quality)),
_ => unimplemented!()
}?;
if buf.len() > format_config.target_filesize {
ub = quality;
} else {
lb = quality + 1;
}
if lb >= ub {
break buf;
}
}
} else {
let thumbnail = rgb.resize(
format_config.target_width.min(rgb.width()),
u32::MAX,
FilterType::Lanczos3,
);
let mut buf: Vec<u8> = Vec::new();
let mut csr = Cursor::new(&mut buf);
match format_config.format {
ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, format_config.quality)),
ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, format_config.quality)),
ImageFormat::WebP => thumbnail.write_with_encoder(webp::WebPEncoder::new_lossless(&mut csr)),
_ => unimplemented!()
}?;
buf
};
if resized.len() < image.original_filesize.unwrap_or(usize::MAX) {
generated_formats.push(format_name.clone());
let thumbnail_path = Path::new(&config.service.thumbs_path).join(
generate_thumbnail_filename(
&image.filename,
format_name,
format_config,
),
);
THUMBNAILS_GENERATED_COUNTER.get_metric_with_label_values(&[format_name]).unwrap().inc();
std::fs::write(thumbnail_path, resized)?;
}
}
Ok::<Vec<String>, anyhow::Error>(generated_formats)
}).await??;
IMAGES_THUMBNAILED_COUNTER.inc();
let formats_data = rmp_serde::to_vec(&generated_formats)?;
let ts = timestamp();
let filename_enc = filename.encode()?;
let mut conn = pool.acquire().await?;
ensure_filename_record_exists(&mut conn, &filename_enc).await?;
match filename {
Filename::VideoFrame(container, _) => { video_thumb_times.write().await.insert(container.clone(), timestamp()); },
_ => ()
}
sqlx::query!(
"UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?",
formats_data,
ts,
filename_enc
)
.execute(&mut *conn)
.await?;
Ok(())
}
generate_thumbnail(image, config, video_thumb_times, pool, formats)
})
}))
} else {
None
};
// TODO: save OCR errors and don't retry
let ocr: Option<JoinHandle<Result<()>>> = if config.service.enable_ocr {
let client = client.clone();
let pool = pool.clone();
let config = config.clone();
let stream = ReceiverStream::new(to_ocr_rx).map(Ok);
Some(tokio::spawn({
stream.try_for_each_concurrent(Some(config.service.ocr_concurrency), move |image| {
let client = client.clone();
let pool = pool.clone();
async move {
log::debug!("OCRing {:?}", image.filename);
let scan = match scan_image(&client, &image.image).await {
Ok(scan) => scan,
Err(e) => {
IMAGES_OCRED_ERROR_COUNTER.inc();
log::error!("OCR failure {:?}: {}", image.filename, e);
return Ok(())
}
};
IMAGES_OCRED_COUNTER.inc();
let ocr_text = scan
.iter()
.map(|segment| segment.text.clone())
.collect::<Vec<_>>()
.join("\n");
let ocr_data = rmp_serde::to_vec(&scan)?;
let ts = timestamp();
let filename_enc = image.filename.encode()?;
let mut conn = pool.acquire().await?;
ensure_filename_record_exists(&mut conn, &filename_enc).await?;
sqlx::query!(
"UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?",
ocr_text,
ocr_data,
ts,
filename_enc
)
.execute(&mut *conn)
.await?;
Ok(())
}
let config = config.clone();
do_ocr(image, config, client, pool)
})
}))
} else {
@ -634,45 +687,12 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
let config = config.clone();
let pool = pool.clone();
let video_embed_times = video_embed_times.clone();
async move {
let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
&client,
&config.service.clip_server,
"",
EmbeddingRequest::Images {
images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
},
).await.context("querying CLIP server")?;
let mut tx = pool.begin().await?;
let ts = timestamp();
for (i, vector) in result.into_iter().enumerate() {
let vector = vector.into_vec();
log::debug!("embedded {:?}", batch[i].filename);
let encoded_filename = batch[i].filename.encode()?;
IMAGES_EMBEDDED_COUNTER.inc();
ensure_filename_record_exists(&mut *tx, &encoded_filename).await?;
match &batch[i].filename {
Filename::VideoFrame(container, _) => { video_embed_times.write().await.insert(container.clone(), timestamp()); },
_ => ()
}
sqlx::query!(
"UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?",
ts,
vector,
encoded_filename
)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
anyhow::Result::Ok(())
}
handle_embedding_batch(client, config, pool, batch, video_embed_times)
})
});
let mut actual_filenames = HashMap::new();
// blocking OS calls
tokio::task::block_in_place(|| -> anyhow::Result<()> {
for entry in WalkDir::new(config.service.files.as_str()) {
@ -688,7 +708,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
Ok(())
})?;
log::debug!("finished reading filenames");
tracing::debug!("finished reading filenames");
for (filename, (_path, modtime)) in actual_filenames.iter() {
let modtime = *modtime;
@ -721,7 +741,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
}
};
if let Some(record) = new_record {
log::debug!("processing {}", record.filename);
tracing::debug!("processing {}", record.filename);
// we need to exit here to actually capture the error
if !to_process_tx.send(record).await.is_ok() {
break
@ -730,20 +750,20 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
}
drop(to_process_tx);
embedding_generation.await?.context("generating embeddings")?;
metadata_writer.await?.context("writing metadata")?;
if let Some(thumbnail_generation) = thumbnail_generation {
thumbnail_generation.await?.context("generating thumbnails")?;
}
if let Some(ocr) = ocr {
ocr.await?.context("OCRing")?;
}
image_loading.await?.context("loading images")?;
let stored: Vec<Vec<u8>> = sqlx::query_scalar("SELECT filename FROM files").fetch_all(&pool).await?;
let mut tx = pool.begin().await?;
let video_meta = video_meta.read().await;
@ -785,13 +805,14 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
tx.commit().await?;
log::info!("Ingest done");
tracing::info!("Ingest done");
Result::Ok(())
}
const INDEX_ADD_BATCH: usize = 512;
#[instrument]
async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
let pool = initialize_database(&config.service).await?;
@ -846,7 +867,7 @@ async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
} else {
index.metadata.push(None);
}
for format_string in &formats {
let mut found = false;
for (i, name) in index.format_names.iter().enumerate() {
@ -904,6 +925,7 @@ struct QueryRequest {
include_video: bool
}
#[instrument(skip(index))]
async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult> {
let result = index.vectors.search(&query, k as usize)?;
@ -940,6 +962,7 @@ async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bo
})
}
#[instrument(skip(config, client, index))]
async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IIndex, req: Json<QueryRequest>) -> Result<Response<Body>> {
let mut total_embedding = ndarray::Array::from(vec![0.0; config.backend.embedding_size]);
@ -973,8 +996,8 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
}
}
let mut batches = vec![];
let mut batches = vec![];
if !image_batch.is_empty() {
batches.push(
@ -1016,12 +1039,13 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
#[derive(Serialize, Deserialize)]
struct FrontendInit {
n_total: u64,
predefined_embedding_names: Vec<String>
predefined_embedding_names: Vec<String>,
d_emb: usize
}
#[tokio::main]
async fn main() -> Result<()> {
pretty_env_logger::init();
console_subscriber::init();
let config_path = std::env::args().nth(1).expect("Missing config file path");
let config: Config = serde_json::from_slice(&std::fs::read(config_path)?)?;
@ -1062,23 +1086,23 @@ async fn main() -> Result<()> {
let index = index.clone();
async move {
loop {
log::info!("Ingest running");
tracing::info!("Ingest running");
match ingest_files(config.clone()).await {
Ok(_) => {
match build_index(config.clone()).await {
Ok(new_index) => {
LAST_INDEX_SIZE.set(new_index.vectors.ntotal() as i64);
*index.write().await = new_index;
log::info!("Index loaded");
tracing::info!("Index loaded");
}
Err(e) => {
log::error!("Index build failed: {:?}", e);
tracing::error!("Index build failed: {:?}", e);
ingest_done_tx.send((false, format!("{:?}", e))).unwrap();
}
}
}
Err(e) => {
log::error!("Ingest failed: {:?}", e);
tracing::error!("Ingest failed: {:?}", e);
ingest_done_tx.send((false, format!("{:?}", e))).unwrap();
}
}
@ -1106,11 +1130,12 @@ async fn main() -> Result<()> {
.route("/", get(|_req: ()| async move {
Json(FrontendInit {
n_total: index_.read().await.vectors.ntotal(),
predefined_embedding_names: config__.predefined_embeddings.keys().cloned().collect()
predefined_embedding_names: config__.predefined_embeddings.keys().cloned().collect(),
d_emb: config__.backend.embedding_size
})
}))
.route("/reload", post(|_req: ()| async move {
log::info!("Requesting index reload");
tracing::info!("Requesting index reload");
let mut done_rx = done_tx.clone().subscribe();
let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full
match done_rx.recv().await {
@ -1141,9 +1166,9 @@ async fn main() -> Result<()> {
.layer(cors);
let addr = format!("0.0.0.0:{}", config_.service.port);
log::info!("Starting server on {}", addr);
tracing::info!("Starting server on {}", addr);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app).await?;
Ok(())
}
}
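Most of the churn above is mechanical: the large inline async closures for image loading, thumbnailing, OCR and embedding are hoisted out into the named functions load_image, generate_thumbnail, do_ocr and handle_embedding_batch, each carrying #[instrument] so every unit of work runs inside its own tracing span. A hedged sketch of what the attribute does (function and field names here are hypothetical):

use tracing::instrument;

// Each call opens a span named after the function; arguments that are not
// skipped are recorded as span fields via their Debug impl, so large payloads
// are skipped and summarised manually instead.
#[instrument(skip(payload))]
async fn process_one(filename: String, payload: Vec<u8>) {
    tracing::debug!(len = payload.len(), "processing");
}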


@ -9,6 +9,7 @@ use reqwest::{
use serde_json::Value;
use std::{io::Cursor, time::{SystemTime, UNIX_EPOCH}};
use serde::{Deserialize, Serialize};
use tracing::instrument;
const CALLBACK_REGEX: &str = r">AF_initDataCallback\((\{key: 'ds:1'.*?\})\);</script>";
const MAX_DIM: u32 = 1024;
@ -45,6 +46,7 @@ fn rationalize_coords_format1(
}
}
#[instrument(skip(client, image))]
async fn scan_image_chunk(
client: &Client,
image: &[u8],
@ -130,13 +132,14 @@ async fn scan_image_chunk(
.collect())
}
#[instrument(skip(client))]
pub async fn scan_image(client: &Client, image: &DynamicImage) -> Result<ScanResult> {
let mut result = ScanResult::new();
let (width, height) = image.dimensions();
let (width, height, image) = if width > MAX_DIM {
let height = ((height as f64) * (MAX_DIM as f64) / (width as f64)).round() as u32;
let new_image = tokio::task::block_in_place(|| image.resize_exact(MAX_DIM, height, image::imageops::FilterType::Lanczos3));
let new_image = tokio::task::block_in_place(|| image.resize_exact(MAX_DIM, height, image::imageops::FilterType::CatmullRom));
(MAX_DIM, height, std::borrow::Cow::Owned(new_image))
} else {
(width, height, std::borrow::Cow::Borrowed(image))
@ -170,4 +173,4 @@ pub async fn scan_image(client: &Client, image: &DynamicImage) -> Result<ScanRes
}
Ok(result)
}
}


@ -1,15 +1,18 @@
use anyhow::{anyhow, Context, Result};
use common::resize_for_embed;
use itertools::Itertools;
use std::{collections::HashSet, fs, io::{BufReader, Cursor, BufRead, BufWriter}, time::Duration, sync::Arc, str::FromStr, path::PathBuf};
use serde::{Serialize, Deserialize};
use lazy_static::lazy_static;
use regex::{RegexSet, bytes, Regex};
use regex::{bytes, Regex, RegexSet, RegexSetBuilder};
use tokio::{sync::{mpsc::{self, Receiver}, Semaphore}, task::{JoinHandle, JoinSet}};
use tokio_stream::wrappers::ReceiverStream;
use reqwest::Client;
use futures_util::stream::{StreamExt, TryStreamExt};
use image::{DynamicImage, io::Reader as ImageReader};
use image::{DynamicImage, ImageReader};
use mimalloc::MiMalloc;
use tracing::instrument;
use prometheus::{Encoder, register_int_counter, IntCounter, register_histogram_vec, HistogramVec};
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
@ -50,6 +53,14 @@ struct Entry {
id: String
}
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
struct OriginalImageMetadata {
mime_type: String,
original_file_size: usize,
dimension: (u32, u32),
final_url: String
}
#[derive(Clone, Deserialize, Serialize, Debug)]
struct ProcessedEntry {
url: String,
@ -58,10 +69,13 @@ struct ProcessedEntry {
subreddit: String,
author: String,
timestamp: u64,
blob: Vec<u8>
#[serde(with = "serde_bytes")]
embedding: Vec<u8>,
metadata: OriginalImageMetadata
}
lazy_static! {
// we do exclude galleries doing this but there don't seem to be any in the dataset
static ref URL_IGNORE: RegexSet = RegexSet::new([
r"//reddit\.com",
r"\.html?",
@ -69,16 +83,39 @@ lazy_static! {
r"\?articleid=",
r"\.aspx?",
r"\.xml",
r"//youtube\.com",
r"/rss/",
r"//vimeo\.com",
r"//www\.youtube\.com",
r"//youtu\.be",
r"//www\.reddit\.com",
r"//v\.redd\.it",
r"\.gifv$",
r"youtube\.com/user/"
// TODO fill in more things, maybe try and collect thumbnails or something
]).unwrap();
static ref ACCEPTABLE_FILETYPES: HashSet<&'static [u8]> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/webp", "image/apng", "image/bmp", "image/tiff"]
.into_iter().map(str::as_bytes).collect();
static ref URL_MUST_CONTAIN: RegexSet = RegexSetBuilder::new([
"jpg",
"jpeg",
"png",
"webp",
r"\.gif",
"=gif",
"jpeg",
"bmp",
"tiff",
"avif",
"webp",
"imgur",
"image",
r"//i\.",
"img",
r"cdn\.",
r"media\.",
"/i/",
"/media",
r"youtu\.be",
r"youtube\.com",
]).case_insensitive(true).build().unwrap();
static ref ACCEPTABLE_FILETYPES: HashSet<&'static str> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/webp", "image/apng", "image/bmp", "image/tiff"]
.into_iter().collect();
static ref OBJECT_HACKY_IGNORE: bytes::RegexSet = bytes::RegexSet::new([
r#""author":"\[deleted\]""#,
r#""promoted":true"#, // these seem to be ads which are in the data for some reason, and lack some important fields
@ -86,11 +123,34 @@ lazy_static! {
r"\x00" // for SOME REASON one of the JSON files contains a lot of null bytes before one particular record, so just ignore that record
]).unwrap();
static ref URL_REPLACEMENT_RULES: Vec<(Regex, &'static str)> = [
(r"//imgur.com/([A-Za-z0-9]+)", r"//i.imgur.com/$1.jpg"),
(r"^http://", r"https://")
(r"imgur\.com/([A-Za-z0-9]+),", r"imgur.com/$1"),
(r"//imgur\.com/([A-Za-z0-9]+)$", r"//i.imgur.com/$1.jpg"),
(r"//www\.imgur\.com/([A-Za-z0-9]+)$", r"//i.imgur.com/$1.jpg"),
(r"//m\.imgur\.com/([A-Za-z0-9]+)$", r"//i.imgur.com/$1.jpg"),
(r"^http://", r"https://"),
(r"//youtu\.be/(.*)", r"//youtube.com/watch?v=$1"),
(r"//[a-z]+\.youtube\.com/(.*)", r"//youtube.com/$1"),
(r"//www.youtube.com/attribution_link?.*v%3D([A-Za-z0-9_-]+).*", r"//i.ytimg.com/vi/$1/maxresdefault.jpg"), // redirect to youtube thumbnail API
(r"//youtube.com/embed/([A-Za-z0-9_-]+)", r"//i.ytimg.com/vi/$1/maxresdefault.jpg"),
(r"//youtube\.com/(?:.*)v=([A-Za-z0-9_-]+)(?:.*)", r"//i.ytimg.com/vi/$1/maxresdefault.jpg"),
(r"&amp;", "&") // this is such an intensely cursed feature of the dumps
].into_iter().map(|(r, e)| (Regex::new(r).unwrap(), e)).collect();
static ref HTML_EXTRACTION_RULES: Vec<(Regex, Regex)> = [
(r"//imgur\.com/a/[A-Za-z0-9]+", r#"<meta name="twitter:image" data-react-helmet="true" content="([^"]+)">"#),
(r"//imgur\.com/gallery/[A-Za-z0-9]+", r#"<meta name="twitter:image" data-react-helmet="true" content="([^"]+)">"#),
].into_iter().map(|(r, e)| (Regex::new(r).unwrap(), Regex::new(e).unwrap())).collect();
static ref IMAGES_FETCHED_COUNTER: IntCounter = register_int_counter!("mse_scrape_images_fetched", "images fetched").unwrap();
static ref IMAGES_PROCESSED_COUNTER: IntCounter = register_int_counter!("mse_scrape_images_processed", "images processed").unwrap();
static ref ENTRIES_PROCESSED_COUNTER: IntCounter = register_int_counter!("mse_scrape_entries_processed", "entries processed").unwrap();
static ref IMAGES_FAILED_COUNTER: IntCounter = register_int_counter!("mse_scrape_images_failed", "images failed").unwrap();
static ref IMAGE_FILESIZES_HISTOGRAM: HistogramVec = register_histogram_vec!("mse_scrape_image_filesizes", "filesizes of successfully fetched images", &["format"], prometheus::exponential_buckets(100.0, 1.5, 29).unwrap()).unwrap();
static ref IMAGE_PIXELS_HISTOGRAM: HistogramVec = register_histogram_vec!("mse_scrape_image_pixels", "pixel count of successfully fetched images", &["format"], prometheus::exponential_buckets(100.0, 1.3, 53).unwrap()).unwrap();
static ref HTML_EXTRACTS_COUNTER: IntCounter = register_int_counter!("mse_scrape_html_extracts", "html extraction operations").unwrap();
}
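// Illustrative check, not part of the commit: the ordered URL_REPLACEMENT_RULES
// compose, so the youtu.be rule first rewrites a link to a youtube.com watch
// URL and the later thumbnail rule then turns that into an i.ytimg.com image
// URL the scraper can actually fetch. The two rules are copied from above; the
// helper function and the placeholder id are hypothetical.
use regex::Regex;

fn rewrite_youtube(url: &str) -> String {
    let rules = [
        (Regex::new(r"//youtu\.be/(.*)").unwrap(), r"//youtube.com/watch?v=$1"),
        (Regex::new(r"//youtube\.com/(?:.*)v=([A-Za-z0-9_-]+)(?:.*)").unwrap(), r"//i.ytimg.com/vi/$1/maxresdefault.jpg"),
    ];
    let mut url = url.to_string();
    for (regex, replacement) in &rules {
        url = regex.replace(&url, *replacement).to_string();
    }
    // "https://youtu.be/VIDEOID" ends up as "https://i.ytimg.com/vi/VIDEOID/maxresdefault.jpg"
    url
}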
#[instrument(skip(tx))]
fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Option<u64>) -> Result<()> {
let mut stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
stream.window_log_max(31)?;
@ -105,15 +165,16 @@ fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Opt
buf.clear();
continue;
}
ENTRIES_PROCESSED_COUNTER.inc();
let entry = match sonic_rs::serde::from_slice::<Entry>(buf.as_slice()) {
Ok(x) => x,
Err(e) => {
log::warn!("parse failed, please validate {:?} {:?}", e, String::from_utf8_lossy(&buf));
tracing::warn!("parse failed, please validate {:?} {:?}", e, String::from_utf8_lossy(&buf));
return Ok(())
}
};
if entry.selftext.is_empty() && !entry.over_18 && entry.author.is_some() && entry.subreddit.is_some() {
if !URL_IGNORE.is_match(&entry.url) {
if !URL_IGNORE.is_match(&entry.url) && URL_MUST_CONTAIN.is_match(&entry.url) {
match &entry.post_hint {
Some(x) if x == "na" || x == "image" => {
// Technically this is slightly wrong because we reorder images slightly, but as long as it is not restarted all the time this is "fine".
@ -127,7 +188,7 @@ fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Opt
},
None => true
};
if after_threshold { tx.blocking_send(entry)?; }
},
_ => ()
@ -139,23 +200,38 @@ fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Opt
Ok(())
}
#[derive(Debug)]
struct Config {
max_content_length: usize,
input: String,
output: String,
backend: String,
mode: OperatingMode,
filename_threshold: Option<String>
filename_threshold: Option<String>,
metrics_addr: String,
contact_info: String
}
async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) -> Result<Vec<u8>> {
// inelegant but I can't get it to work using Cows
#[instrument(skip(client, config))]
#[async_recursion::async_recursion]
async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) -> Result<(Vec<u8>, String, String)> {
let mut url = url.to_string();
for (regex, replacement) in URL_REPLACEMENT_RULES.iter() {
url = regex.replace(&url, *replacement).to_string();
}
let mut html_extract_rule = None;
for (url_rule, extract_rule) in HTML_EXTRACTION_RULES.iter() {
if url_rule.is_match(&url) {
html_extract_rule = Some(extract_rule);
break;
}
}
let mut response = client.get(&*url).send().await?;
if !ACCEPTABLE_FILETYPES.contains(response.headers().get(reqwest::header::CONTENT_TYPE).context("no contept type")?.as_bytes()) {
let content_type = std::str::from_utf8(&response.headers().get(reqwest::header::CONTENT_TYPE).context("no content type")?.as_bytes())?.to_owned();
if !(ACCEPTABLE_FILETYPES.contains(&content_type[..]) || (html_extract_rule.is_some() && content_type == "text/html")) {
return Err(anyhow!("invalid Content-Type"));
}
match response.content_length() {
@ -169,11 +245,24 @@ async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) ->
return Err(anyhow!("response too large"));
}
}
Ok(buffer)
if let Some(extract_rule) = html_extract_rule {
if content_type == "text/html" {
let buffer = String::from_utf8_lossy(&buffer).to_string();
if let Some(mat) = extract_rule.captures(&buffer) {
let new_url = mat.get(1).unwrap().as_str();
HTML_EXTRACTS_COUNTER.inc();
tracing::debug!("found new URL: {}", new_url);
return fetch_file(client, config, new_url).await;
} else {
return Err(anyhow!("no extraction match"));
}
}
}
Ok((buffer, content_type, response.url().to_string()))
}
fn write_output(config: Arc<Config>, mut rx: Receiver<ProcessedEntry>) -> Result<()> {
let mut out = fs::File::options().append(true).open(&config.output)?;
let mut out = fs::File::options().create(true).append(true).open(&config.output)?;
let stream = zstd::Encoder::new(&mut out, 15)?.auto_finish();
let mut buf_stream = BufWriter::new(stream);
while let Some(x) = rx.blocking_recv() {
@ -182,12 +271,14 @@ fn write_output(config: Arc<Config>, mut rx: Receiver<ProcessedEntry>) -> Result
Ok(())
}
#[derive(Debug)]
enum OperatingMode {
Count,
Sample(f32),
FullRun
}
#[instrument]
fn readback_output(path: &str) -> Result<(u64, usize)> {
use rmp_serde::decode::Error;
let stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
@ -208,27 +299,47 @@ fn readback_output(path: &str) -> Result<(u64, usize)> {
Ok((latest_timestamp, count))
}
async fn serve_metrics(config: Arc<Config>) -> Result<()> {
let metrics = axum::Router::new().route("/metrics", axum::routing::get(|| async move {
let mut buffer = Vec::new();
let encoder = prometheus::TextEncoder::new();
let metric_families = prometheus::gather();
encoder.encode(&metric_families, &mut buffer).unwrap();
buffer
}));
let listener = tokio::net::TcpListener::bind(&config.metrics_addr).await?;
tokio::task::spawn(async move {
let _ = axum::serve(listener, metrics).await;
});
Ok(())
}
#[tokio::main]
async fn main() -> Result<()> {
pretty_env_logger::init();
console_subscriber::init();
let cpus = num_cpus::get();
let config = Arc::new(Config {
max_content_length: 1<<23,
input: String::from("./submissions"),
max_content_length: 1<<24,
input: String::from("./reddit_subs_202212/"),
output: String::from("./sample.zst"),
backend: String::from("http://localhost:1708"),
mode: OperatingMode::Sample(0.004),
filename_threshold: None
mode: OperatingMode::FullRun,
filename_threshold: None,
metrics_addr: String::from("0.0.0.0:9914"),
contact_info: String::from("scraping-ops@osmarks.net")
});
serve_metrics(config.clone()).await?;
let timestamp_threshold = match config.mode {
OperatingMode::Count => None,
_ => {
match readback_output(&config.output) {
Ok(x) => Some(x),
Err(e) => {
log::warn!("could not read output: {}", e);
tracing::warn!("could not read output: {}", e);
None
}
}
@ -237,19 +348,19 @@ async fn main() -> Result<()> {
if let Some((threshold, count)) = timestamp_threshold {
log::info!("threshold is {}, {} items", threshold, count);
tracing::info!("threshold is {}, {} items", threshold, count);
}
let backend = get_backend_config(&config.backend).await;
log::info!("connected to inference server");
tracing::info!("connected to inference server");
let (entries_tx, mut entries_rx) = mpsc::channel::<Entry>(32768);
let (buffers_tx, buffers_rx) = mpsc::channel(128);
let (resized_tx, resized_rx) = mpsc::channel(backend.batch);
let (final_write_tx, final_write_rx) = mpsc::channel::<ProcessedEntry>(32768);
let client = Client::builder()
.user_agent(concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")))
.user_agent(format!("{}/{} (contact {})", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"), config.contact_info))
.timeout(Duration::from_secs(30))
.build()?;
@ -278,11 +389,13 @@ async fn main() -> Result<()> {
}
match fetch_file(client, config.clone(), &entry.url).await {
Ok(buf) => {
log::debug!("got {}", &entry.url);
IMAGES_FETCHED_COUNTER.inc();
tracing::debug!("got {}", &entry.url);
buffers_tx.send((entry, buf)).await?;
},
Err(e) => {
log::warn!("{} failed: {}", &entry.url, e)
IMAGES_FAILED_COUNTER.inc();
tracing::debug!("{} failed: {}", &entry.url, e)
}
}
Ok(())
@ -296,8 +409,10 @@ async fn main() -> Result<()> {
_ => Some(tokio::task::spawn({
let stream = ReceiverStream::new(buffers_rx);
let backend = backend.clone();
stream.map(Ok).try_for_each_concurrent(Some(cpus), move |(entry, buffer)| {
stream.map(Ok).try_for_each_concurrent(Some(cpus), move |(entry, (buffer, mime_type, final_url))| {
let backend = backend.clone();
let size = buffer.len();
IMAGE_FILESIZES_HISTOGRAM.with_label_values(&[&mime_type]).observe(size as f64);
let resized_tx = resized_tx.clone();
async move {
let image_result = tokio::task::spawn_blocking(|| {
@ -308,12 +423,20 @@ async fn main() -> Result<()> {
let image = match image_result {
Ok(image) => image,
Err(e) => {
log::warn!("loading {} failed: {}", entry.url, e);
tracing::debug!("loading {} failed: {}", entry.url, e);
return Result::<(), anyhow::Error>::Ok(());
}
};
let dim = (image.width(), image.height());
IMAGE_PIXELS_HISTOGRAM.with_label_values(&[&mime_type]).observe(dim.0 as f64 * dim.1 as f64);
let metadata = OriginalImageMetadata {
mime_type,
original_file_size: size,
dimension: dim,
final_url
};
let resized = resize_for_embed(backend.clone(), image).await?;
resized_tx.send((entry, resized)).await?;
resized_tx.send((entry, resized, metadata)).await?;
Ok(())
}
})
@ -328,7 +451,7 @@ async fn main() -> Result<()> {
let config = config.clone();
// keep multiple embedding requests in flight
stream.map(Ok).try_for_each_concurrent(Some(3), move |batch| {
let (entries, bytes): (Vec<Entry>, Vec<Vec<u8>>) = batch.into_iter().unzip();
let (entries, bytes, batch_dimensions): (Vec<Entry>, Vec<Vec<u8>>, Vec<OriginalImageMetadata>) = batch.into_iter().multiunzip();
let client = client.clone();
let config = config.clone();
let final_write_tx = final_write_tx.clone();
@ -341,17 +464,20 @@ async fn main() -> Result<()> {
images: bytes.into_iter().map(serde_bytes::ByteBuf::from).collect(),
},
).await.context("querying CLIP server")?;
for (vector, entry) in result.into_iter().zip(entries) {
for (vector, entry,
metadata) in itertools::izip!(result.into_iter(), entries, batch_dimensions) {
final_write_tx.send(ProcessedEntry {
url: entry.url,
id: entry.id,
title: entry.title,
subreddit: entry.subreddit.unwrap(),
author: entry.author.unwrap(),
blob: vector.into_vec(),
timestamp: entry.created_utc.to_u64()?
embedding: vector.into_vec(),
timestamp: entry.created_utc.to_u64()?,
metadata
}).await?;
IMAGES_PROCESSED_COUNTER.inc();
}
anyhow::Result::Ok(())
}
@ -365,7 +491,7 @@ async fn main() -> Result<()> {
_ => None
};
log::info!("working...");
tracing::info!("working...");
let mut paths = vec![];
for file in fs::read_dir(&config.input)? {
@ -381,36 +507,26 @@ async fn main() -> Result<()> {
let mut file_readers = JoinSet::new();
match config.mode {
OperatingMode::Count | OperatingMode::Sample(_) => {
let semaphore = Arc::new(Semaphore::new(cpus));
let readers = match config.mode {
OperatingMode::Count | OperatingMode::Sample(_) => cpus,
OperatingMode::FullRun => 1
};
for path in paths {
let semaphore = semaphore.clone();
let permit = semaphore.acquire_owned().await?;
let entries_tx = entries_tx.clone();
let path_ = path.clone();
log::info!("reading {:?}", path);
file_readers.spawn_blocking(move || {
match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
Ok(_) => (),
Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
}
std::mem::drop(permit);
});
let semaphore = Arc::new(Semaphore::new(readers));
for path in paths {
let semaphore = semaphore.clone();
let permit = semaphore.acquire_owned().await?;
let entries_tx = entries_tx.clone();
let path_ = path.clone();
tracing::info!("reading {:?}", path);
file_readers.spawn_blocking(move || {
match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
Ok(_) => (),
Err(e) => tracing::error!("could not parse {:?} {:?}", &path, e)
}
},
OperatingMode::FullRun => {
for path in paths {
let entries_tx = entries_tx.clone();
let path_ = path.clone();
log::info!("reading {:?}", path);
file_readers.spawn_blocking(move || match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
Ok(_) => (),
Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
});
}
}
std::mem::drop(permit);
});
}
while let Some(x) = file_readers.try_join_next() {
@ -419,9 +535,9 @@ async fn main() -> Result<()> {
std::mem::drop(entries_tx);
println!("{:?}", load_task.await?);
if let Some(task) = resize_task { println!("{:?}", task.await?); }
if let Some(task) = embedding_generation_task { println!("{:?}", task.await?) };
if let Some(task) = output_writer_task { println!("{:?}", task.await?) };
if let Some(task) = resize_task { println!("resize: {:?}", task.await?); }
if let Some(task) = embedding_generation_task { println!("embedding: {:?}", task.await?) };
if let Some(task) = output_writer_task { println!("output: {:?}", task.await?) };
Ok(())
}
}
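Throughout the scraper the stages (dump parsing, image fetch, decode and resize, embedding, output writing) are linked by bounded tokio mpsc channels (capacities 32768, 128, backend.batch and 32768 above), so a slow stage stalls its producers instead of letting work accumulate unboundedly in memory. A minimal sketch of that backpressure pattern, with illustrative capacity and payload type:

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // The capacity bounds how many items can be in flight between the stages.
    let (tx, mut rx) = mpsc::channel::<u64>(128);
    let producer = tokio::spawn(async move {
        for i in 0..1000u64 {
            // send().await parks when the channel is full, pushing backpressure upstream.
            if tx.send(i).await.is_err() { break }
        }
    });
    let mut total = 0u64;
    while let Some(i) = rx.recv().await {
        total += i;
    }
    producer.await.unwrap();
    println!("consumed sum = {}", total);
}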