From f8d68d9d540dc720ae235e723206f6aa2fae2260 Mon Sep 17 00:00:00 2001
From: osmarks
Date: Fri, 24 May 2024 17:47:18 +0100
Subject: [PATCH] WIP Reddit dump loader

---
 Cargo.lock         | 134 ++++++++++++++++-
 Cargo.toml         |  11 +-
 src/common.rs      |  64 ++++++++
 src/main.rs        |  71 +--------
 src/reddit_dump.rs | 359 +++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 571 insertions(+), 68 deletions(-)
 create mode 100644 src/common.rs
 create mode 100644 src/reddit_dump.rs

diff --git a/Cargo.lock b/Cargo.lock
index 1399cdc..a6fc4e5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -89,6 +89,12 @@ dependencies = [
  "syn 2.0.65",
 ]
 
+[[package]]
+name = "arrayref"
+version = "0.3.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6b4930d2cb77ce62f89ee5d5289b4ac049559b1c45539271f5ed4fdc7db34545"
+
 [[package]]
 name = "arrayvec"
 version = "0.7.4"
@@ -662,6 +668,17 @@ version = "2.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
 
+[[package]]
+name = "faststr"
+version = "0.2.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f375fcf41ec4dac873a8028fba4210dbda5c86bba13d2d741e651b474f7c05a4"
+dependencies = [
+ "bytes",
+ "serde",
+ "simdutf8",
+]
+
 [[package]]
 name = "fdeflate"
 version = "0.3.4"
@@ -1238,7 +1255,7 @@ version = "0.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7507624b29483431c0ba2d82aece8ca6cdba9382bff4ddd0f7490560c056098d"
 dependencies = [
- "winapi",
+ "winapi 0.2.8",
  "winapi-build",
 ]
 
@@ -1280,6 +1297,16 @@ version = "0.2.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
 
+[[package]]
+name = "libmimalloc-sys"
+version = "0.1.38"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0e7bb23d733dfcc8af652a78b7bf232f0e967710d044732185e561e47c0336b6"
+dependencies = [
+ "cc",
+ "libc",
+]
+
 [[package]]
 name = "libsqlite3-sys"
 version = "0.27.0"
@@ -1373,6 +1400,7 @@ dependencies = [
  "base64 0.22.1",
  "chrono",
  "faiss",
+ "fastrand",
  "fnv",
  "futures-util",
  "half",
@@ -1380,6 +1408,7 @@ dependencies = [
  "json5",
  "lazy_static",
  "log",
+ "mimalloc",
  "ndarray",
  "num_cpus",
  "pretty_env_logger",
@@ -1390,12 +1419,24 @@ dependencies = [
  "serde",
  "serde_bytes",
  "serde_json",
+ "sonic-rs",
  "sqlx",
  "tokio",
  "tokio-stream",
  "tower",
  "tower-http",
+ "url",
  "walkdir",
+ "zstd",
+]
+
+[[package]]
+name = "mimalloc"
+version = "0.1.42"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e9186d86b79b52f4a77af65604b51225e8db1d6ee7e3f41aec1e40829c71a176"
+dependencies = [
+ "libmimalloc-sys",
 ]
 
 [[package]]
@@ -1665,6 +1706,16 @@ dependencies = [
  "vcpkg",
 ]
 
+[[package]]
+name = "page_size"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da"
+dependencies = [
+ "libc",
+ "winapi 0.3.9",
+]
+
 [[package]]
 name = "parking_lot"
 version = "0.12.2"
@@ -2216,7 +2267,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d931a44fdaa43b8637009e7632a02adc4f2b2e0733c08caa4cf00e8da4a117a7"
 dependencies = [
  "kernel32-sys",
- "winapi",
+ "winapi 0.2.8",
 ]
 
 [[package]]
@@ -2384,6 +2435,12 @@ dependencies = [
  "quote",
 ]
 
+[[package]]
+name = "simdutf8"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
"f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" + [[package]] name = "slab" version = "0.4.9" @@ -2409,6 +2466,27 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "sonic-rs" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "244d3cdf9dd4e2a5c63991a6ed4ecd768959204d5e4a839181ee997e8c149407" +dependencies = [ + "arrayref", + "bumpalo", + "bytes", + "cfg-if", + "faststr", + "itoa", + "page_size", + "parking_lot", + "ryu", + "serde", + "simdutf8", + "smallvec", + "thiserror", +] + [[package]] name = "spin" version = "0.5.2" @@ -3090,7 +3168,7 @@ checksum = "bb08f9e670fab86099470b97cd2b252d6527f0b3cc1401acdb595ffc9dd288ff" dependencies = [ "kernel32-sys", "same-file", - "winapi", + "winapi 0.2.8", ] [[package]] @@ -3212,12 +3290,28 @@ version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "167dc9d6949a9b857f3451275e911c3f44255842c1f7a76f33c55103a909087a" +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + [[package]] name = "winapi-build" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d315eee3b34aca4797b2da6b13ed88266e6d612562a0c46390af8299fc699bc" +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.8" @@ -3227,6 +3321,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.52.0" @@ -3420,6 +3520,34 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" +[[package]] +name = "zstd" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d789b1514203a1120ad2429eae43a7bd32b90976a7bb8a05f7ec02fa88cc23a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cd99b45c6bc03a018c8b8a86025678c87e55526064e38f9df301989dce7ec0a" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.10+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "zune-core" version = "0.4.12" diff --git a/Cargo.toml b/Cargo.toml index 10311e2..5b3068e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,4 +33,13 @@ tower-http = { version = "0.5", features = ["cors"] } tower = "0.4" json5 = "0.4" prometheus = "0.13" -lazy_static = "1" \ No newline at end of file +lazy_static = "1" +zstd = "0.13" +url = "2" +fastrand = "2" +mimalloc = "0.1" +sonic-rs = "0.3" + +[[bin]] +name = "reddit-dump" +path = "src/reddit_dump.rs" \ No newline at end of file diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 
--- /dev/null
+++ b/src/common.rs
@@ -0,0 +1,64 @@
+use serde::{Serialize, Deserialize};
+use std::borrow::Borrow;
+use image::{DynamicImage, imageops::FilterType, ImageFormat};
+use anyhow::Result;
+use std::io::Cursor;
+use reqwest::Client;
+
+#[derive(Debug, Deserialize, Clone)]
+pub struct InferenceServerConfig {
+    pub batch: usize,
+    pub image_size: (u32, u32),
+    pub embedding_size: usize,
+}
+
+pub async fn resize_for_embed<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
+    let resized = tokio::task::spawn_blocking(move || {
+        let new = image.borrow().resize(
+            config.image_size.0,
+            config.image_size.1,
+            FilterType::Lanczos3
+        );
+        let mut buf = Vec::new();
+        let mut csr = Cursor::new(&mut buf);
+        new.write_to(&mut csr, ImageFormat::Png)?;
+        Ok::<Vec<u8>, anyhow::Error>(buf)
+    }).await??;
+    Ok(resized)
+}
+
+#[derive(Debug, Serialize)]
+#[serde(untagged)]
+pub enum EmbeddingRequest {
+    Images { images: Vec<serde_bytes::ByteBuf> },
+    Text { text: Vec<String> }
+}
+
+async fn fetch_backend_config(base_url: &str) -> Result<InferenceServerConfig> {
+    let res = Client::new().get(&format!("{}/config", base_url)).send().await?;
+    Ok(rmp_serde::from_slice(&res.bytes().await?)?)
+}
+
+pub async fn get_backend_config(clip_server: &str) -> InferenceServerConfig {
+    loop {
+        match fetch_backend_config(clip_server).await {
+            Ok(backend) => break backend,
+            Err(e) => {
+                log::error!("Backend failed (fetch): {}", e);
+                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+            }
+        }
+    }
+}
+
+pub async fn query_clip_server<I, O>(client: &Client, base_url: &str, path: &str, data: I) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned,
+{
+    let response = client
+        .post(&format!("{}{}", base_url, path))
+        .header("Content-Type", "application/msgpack")
+        .body(rmp_serde::to_vec_named(&data)?)
+        .send()
+        .await?;
+    let result: O = rmp_serde::from_slice(&response.bytes().await?)?;
+    Ok(result)
+}
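+
+// Editor's sketch (not in the original patch): how the helpers above compose,
+// mirroring the call sites in src/main.rs — serialize an EmbeddingRequest as
+// msgpack, POST it to the CLIP server, and get back one embedding buffer per
+// input. The helper name is hypothetical.
+#[allow(dead_code)]
+pub async fn embed_text(client: &Client, base_url: &str, text: Vec<String>) -> Result<Vec<serde_bytes::ByteBuf>> {
+    query_clip_server(client, base_url, "", EmbeddingRequest::Text { text }).await
+}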
diff --git a/src/main.rs b/src/main.rs
index d4379e3..cbd2437 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -30,8 +30,10 @@ use prometheus::{register_int_counter, register_int_counter_vec, register_int_ga
 use ndarray::ArrayBase;
 
 mod ocr;
+mod common;
 
 use crate::ocr::scan_image;
+use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server};
 
 lazy_static! {
     static ref RELOADS_COUNTER: IntCounter = register_int_counter!("mse_reloads", "reloads executed").unwrap();
@@ -126,13 +128,6 @@ struct FileRecord {
     thumbnails: Option<Vec<String>>,
 }
 
-#[derive(Debug, Deserialize, Clone)]
-struct InferenceServerConfig {
-    batch: usize,
-    image_size: (u32, u32),
-    embedding_size: usize,
-}
-
 #[derive(Debug, Clone)]
 struct WConfig {
     backend: InferenceServerConfig,
@@ -140,23 +135,6 @@ service: Config,
     predefined_embeddings: HashMap<String, ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::prelude::Dim<[usize; 1]>>>
 }
 
-async fn query_clip_server<I, O>(
-    client: &Client,
-    config: &Config,
-    path: &str,
-    data: I,
-) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned,
-{
-    let response = client
-        .post(&format!("{}{}", config.clip_server, path))
-        .header("Content-Type", "application/msgpack")
-        .body(rmp_serde::to_vec_named(&data)?)
-        .send()
-        .await?;
-    let result: O = rmp_serde::from_slice(&response.bytes().await?)?;
-    Ok(result)
-}
-
 #[derive(Debug)]
 struct LoadedImage {
     image: Arc<DynamicImage>,
@@ -170,13 +148,6 @@ struct EmbeddingInput {
     image: Vec<u8>,
     filename: String,
 }
-#[derive(Debug, Serialize)]
-#[serde(untagged)]
-enum EmbeddingRequest {
-    Images { images: Vec<serde_bytes::ByteBuf> },
-    Text { text: Vec<String> }
-}
-
 fn timestamp() -> i64 {
     chrono::Utc::now().timestamp_micros()
 }
@@ -274,21 +245,6 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> {
     formats
 }
 
-async fn resize_for_embed(config: Arc<WConfig>, image: Arc<DynamicImage>) -> Result<Vec<u8>> {
-    let resized = tokio::task::spawn_blocking(move || {
-        let new = image.resize(
-            config.backend.image_size.0,
-            config.backend.image_size.1,
-            FilterType::Lanczos3
-        );
-        let mut buf = Vec::new();
-        let mut csr = Cursor::new(&mut buf);
-        new.write_to(&mut csr, ImageFormat::Png)?;
-        Ok::<Vec<u8>, anyhow::Error>(buf)
-    }).await??;
-    Ok(resized)
-}
-
 async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
     let pool = initialize_database(&config.service).await?;
     let client = Client::new();
@@ -324,7 +280,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
         };
         IMAGES_LOADED_COUNTER.inc();
         if record.embedding.is_none() {
-            let resized = resize_for_embed(config.clone(), image.clone()).await?;
+            let resized = resize_for_embed(config.backend.clone(), image.clone()).await?;
             to_embed_tx.send(EmbeddingInput { image: resized, filename: record.filename.clone() }).await?
         }
     }
@@ -505,7 +461,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
             async move {
                 let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
                     &client,
-                    &config.service,
+                    &config.service.clip_server,
                     "",
                     EmbeddingRequest::Images {
                         images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
@@ -753,7 +709,7 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
             TERMS_COUNTER.get_metric_with_label_values(&["image"]).unwrap().inc();
             let bytes = BASE64_STANDARD.decode(image)?;
             let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&bytes))?);
-            image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(config.clone(), image).await?));
+            image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(config.backend.clone(), image).await?));
             image_weights.push(term.weight.unwrap_or(1.0));
         }
         if let Some(text) = &term.text {
@@ -792,7 +748,7 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
     }
 
     for batch in batches {
-        let embs: Vec<Vec<u8>> = query_clip_server(&client, &config.service, "/", batch).await?;
+        let embs: Vec<Vec<u8>> = query_clip_server(&client, &config.service.clip_server, "/", batch).await?;
         for emb in embs {
             total_embedding += &ndarray::Array::from_vec(decode_fp16_buffer(&emb));
         }
@@ -813,11 +769,6 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
     }).into_response())
 }
 
-async fn get_backend_config(config: &Config) -> Result<InferenceServerConfig> {
-    let res = Client::new().get(&format!("{}/config", config.clip_server)).send().await?;
-    Ok(rmp_serde::from_slice(&res.bytes().await?)?)
-}
-
 #[derive(Serialize, Deserialize)]
 struct FrontendInit {
     n_total: u64,
@@ -834,15 +785,7 @@ async fn main() -> Result<()> {
     let pool = initialize_database(&config).await?;
     sqlx::query(SCHEMA).execute(&pool).await?;
 
-    let backend = loop {
-        match get_backend_config(&config).await {
-            Ok(backend) => break backend,
-            Err(e) => {
-                log::error!("Backend failed (fetch): {}", e);
-                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
-            }
-        }
-    };
+    let backend = get_backend_config(&config.clip_server).await;
 
     let mut predefined_embeddings = HashMap::new();
diff --git a/src/reddit_dump.rs b/src/reddit_dump.rs
new file mode 100644
index 0000000..32af576
--- /dev/null
+++ b/src/reddit_dump.rs
@@ -0,0 +1,359 @@
+use anyhow::{anyhow, Context, Result};
+use common::resize_for_embed;
+use std::{collections::HashSet, fs, io::{BufReader, Cursor, BufRead, BufWriter}, path::PathBuf, time::Duration, sync::Arc, str::FromStr};
+use serde::{Serialize, Deserialize};
+use lazy_static::lazy_static;
+use regex::{RegexSet, bytes};
+use tokio::{sync::{mpsc::{self, Receiver}, Semaphore}, task::{JoinHandle, JoinSet}};
+use tokio_stream::wrappers::ReceiverStream;
+use reqwest::Client;
+use futures_util::stream::{StreamExt, TryStreamExt};
+use image::{DynamicImage, io::Reader as ImageReader};
+use mimalloc::MiMalloc;
+
+#[global_allocator]
+static GLOBAL: MiMalloc = MiMalloc;
+
+mod common;
+
+use crate::common::{get_backend_config, query_clip_server, EmbeddingRequest};
+
+fn function_which_returns_some_na() -> Option<String> { Some(String::from("na")) }
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+#[serde(untagged)]
+enum BadTimestampFormat {
+    Int(u64),
+    String(String)
+}
+
+impl BadTimestampFormat {
+    fn to_u64(&self) -> Result<u64> {
+        match self {
+            BadTimestampFormat::Int(x) => Ok(*x),
+            BadTimestampFormat::String(x) => u64::from_str(&x).context("invalid string")
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+struct Entry {
+    url: String,
+    over_18: bool,
+    title: String,
+    author: Option<String>,
+    selftext: String,
+    subreddit: String,
+    created_utc: BadTimestampFormat,
+    #[serde(default="function_which_returns_some_na")]
+    post_hint: Option<String>,
+    id: String
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+struct ProcessedEntry {
+    url: String,
+    id: String,
+    title: String,
+    subreddit: String,
+    author: String,
+    timestamp: u64,
+    blob: Vec<u8>
+}
+
+lazy_static! {
+    static ref URL_IGNORE: RegexSet = RegexSet::new([
+        r"//reddit\.com",
+        r"\.html?",
+        r"\.php",
+        r"\?articleid=",
+        r"\.aspx?",
+        r"\.xml",
+        r"//youtube\.com",
+        // TODO fill in more things, maybe try and collect thumbnails or something
+    ]).unwrap();
+    static ref ACCEPTABLE_FILETYPES: HashSet<&'static [u8]> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/apng", "image/bmp", "image/tiff"]
+        .into_iter().map(str::as_bytes).collect();
+    static ref OBJECT_HACKY_IGNORE: bytes::RegexSet = bytes::RegexSet::new([
+        r#""author":"\[deleted\]""#,
+        r#""promoted":true"#, // these seem to be ads which are in the data for some reason, and lack some important fields
+        r#""domain":"self.promos""#, // .......
+        r"\x00" // for SOME REASON one of the JSON files contains a lot of null bytes before one particular record, so just ignore that record
+    ]).unwrap();
+}
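+
+// Editor's sketch (not in the original patch): a sanity check for the
+// deserialization rules above. Older dumps predate post_hint entirely, which
+// is what the "na" default stands in for; the sample line is hypothetical.
+#[cfg(test)]
+mod entry_tests {
+    use super::*;
+
+    #[test]
+    fn missing_post_hint_defaults_to_na() {
+        let line = r#"{"url":"https://example.com/a.png","over_18":false,"title":"t","author":"someone","selftext":"","subreddit":"pics","created_utc":"1375318320","id":"1aaaaa"}"#;
+        let entry: Entry = sonic_rs::serde::from_slice(line.as_bytes()).unwrap();
+        assert_eq!(entry.post_hint.as_deref(), Some("na"));
+        assert_eq!(entry.created_utc.to_u64().unwrap(), 1375318320);
+    }
+}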
+ r"\x00" // for SOME REASON one of the JSON files contains a lot of null bytes before one particular record, so just ignore that record + ]).unwrap(); +} + +fn process_file(path: PathBuf, tx: mpsc::Sender) -> Result<()> { + let mut stream = zstd::stream::Decoder::new(fs::File::open(path)?)?; + stream.window_log_max(31)?; + let mut stream = BufReader::new(stream); + let mut buf = Vec::new(); + loop { + if stream.read_until(0x0A, &mut buf)? == 0 { + break + } + // we would discard these later, but they have a different format so they can't be deserialized straight into Entries + if OBJECT_HACKY_IGNORE.is_match(&buf) { + buf.clear(); + continue; + } + let entry = match sonic_rs::serde::from_slice::(buf.as_slice()) { + Ok(x) => x, + Err(e) => { + log::warn!("parse failed, please validate {:?} {:?}", e, String::from_utf8_lossy(&buf)); + return Ok(()) + } + }; + if entry.selftext.is_empty() && !entry.over_18 && entry.author.is_some() { + if !URL_IGNORE.is_match(&entry.url) { + match &entry.post_hint { + Some(x) if x == "na" || x == "image" => { + tx.blocking_send(entry)?; + }, + _ => () + } + } + } + buf.clear(); + } + Ok(()) +} + +struct Config { + max_content_length: usize, + input: String, + output: String, + backend: String, + mode: OperatingMode +} + +async fn fetch_file(client: reqwest::Client, config: Arc, url: &str) -> Result> { + let mut response = client.get(url).send().await?; + if !ACCEPTABLE_FILETYPES.contains(response.headers().get(reqwest::header::CONTENT_TYPE).context("no contept type")?.as_bytes()) { + return Err(anyhow!("invalid Content-Type")); + } + match response.content_length() { + Some(x) if x > (config.max_content_length as u64) => return Err(anyhow!("response too large")), + _ => () + } + let mut buffer = vec![]; + while let Some(chunk) = response.chunk().await? 
+
+fn write_output(config: Arc<Config>, mut rx: Receiver<ProcessedEntry>) -> Result<()> {
+    let mut out = fs::File::create(&config.output)?;
+    let stream = zstd::Encoder::new(&mut out, 15)?.auto_finish();
+    let mut buf_stream = BufWriter::new(stream);
+    while let Some(x) = rx.blocking_recv() {
+        rmp_serde::encode::write(&mut buf_stream, &x)?;
+    }
+    Ok(())
+}
+
+enum OperatingMode {
+    Count,
+    Sample(f32),
+    FullRun
+}
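+
+// Editor's sketch (not in the original patch): the read-side counterpart of
+// write_output, for whatever later consumes ./data.zst. It assumes the framing
+// above — consecutive msgpack-encoded ProcessedEntry values inside one zstd
+// stream, read until clean EOF. `blob` presumably holds the raw fp16 bytes the
+// server returned (cf. decode_fp16_buffer in src/main.rs).
+#[allow(dead_code)]
+fn read_output(path: &str) -> Result<Vec<ProcessedEntry>> {
+    let stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
+    let mut stream = BufReader::new(stream);
+    let mut entries = Vec::new();
+    loop {
+        match rmp_serde::decode::from_read::<_, ProcessedEntry>(&mut stream) {
+            Ok(entry) => entries.push(entry),
+            // rmp-serde reports end-of-stream as an IO error while reading the next marker
+            Err(rmp_serde::decode::Error::InvalidMarkerRead(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
+            Err(e) => return Err(e.into()),
+        }
+    }
+    Ok(entries)
+}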
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    pretty_env_logger::init();
+    let cpus = num_cpus::get();
+
+    let config = Arc::new(Config {
+        max_content_length: 1<<23,
+        input: String::from("./submissions"),
+        output: String::from("./data.zst"),
+        backend: String::from("http://localhost:1708"),
+        mode: OperatingMode::Count
+    });
+
+    let backend = get_backend_config(&config.backend).await;
+
+    log::info!("connected to inference server");
+
+    let (entries_tx, mut entries_rx) = mpsc::channel::<Entry>(32768);
+    let (buffers_tx, buffers_rx) = mpsc::channel(128);
+    let (resized_tx, resized_rx) = mpsc::channel(backend.batch);
+    let (final_write_tx, final_write_rx) = mpsc::channel::<ProcessedEntry>(32768);
+    let client = Client::builder()
+        .user_agent(concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")))
+        .timeout(Duration::from_secs(30))
+        .build()?;
+
+    let load_task: JoinHandle<Result<()>> = match config.mode {
+        OperatingMode::Count => tokio::task::spawn(async move {
+            let mut counter = 0;
+            while let Some(_) = entries_rx.recv().await {
+                counter += 1
+            }
+            println!("{}", counter);
+            Ok(())
+        }),
+        _ => tokio::task::spawn({
+            let client = client.clone();
+            let config = config.clone();
+            let stream = ReceiverStream::new(entries_rx);
+            stream.map(Ok).try_for_each_concurrent(Some(128), move |entry| {
+                let client = client.clone();
+                let config = config.clone();
+                let buffers_tx = buffers_tx.clone();
+                async move {
+                    if let OperatingMode::Sample(rate) = config.mode {
+                        if fastrand::f32() > rate {
+                            return Ok(())
+                        }
+                    }
+                    match fetch_file(client, config.clone(), &entry.url).await {
+                        Ok(buf) => {
+                            buffers_tx.send((entry, buf)).await?;
+                        },
+                        Err(e) => {
+                            log::warn!("{} failed: {}", &entry.url, e)
+                        }
+                    }
+                    Ok(())
+                }
+            })
+        })
+    };
+
+    let resize_task = match config.mode {
+        OperatingMode::Count => None,
+        _ => Some(tokio::task::spawn({
+            let stream = ReceiverStream::new(buffers_rx);
+            let backend = backend.clone();
+            stream.map(Ok).try_for_each_concurrent(Some(cpus), move |(entry, buffer)| {
+                let backend = backend.clone();
+                let resized_tx = resized_tx.clone();
+                async move {
+                    let image_result = tokio::task::spawn_blocking(|| {
+                        let csr = Cursor::new(buffer);
+                        // sniff the format from the bytes; URL extensions are unreliable
+                        let image = ImageReader::new(csr).with_guessed_format()?.decode()?;
+                        Result::<DynamicImage, anyhow::Error>::Ok(image)
+                    }).await?;
+                    let image = match image_result {
+                        Ok(image) => image,
+                        Err(e) => {
+                            log::warn!("loading {} failed: {}", entry.url, e);
+                            return Result::<(), anyhow::Error>::Ok(());
+                        }
+                    };
+                    let resized = resize_for_embed(backend.clone(), image).await?;
+                    resized_tx.send((entry, resized)).await?;
+                    Ok(())
+                }
+            })
+        }))
+    };
+
+    let embedding_generation_task: Option<JoinHandle<Result<()>>> = match config.mode {
+        OperatingMode::Count => None,
+        _ => Some(tokio::spawn({
+            let stream = ReceiverStream::new(resized_rx).chunks(backend.batch);
+            let client = client.clone();
+            let config = config.clone();
+            // keep multiple embedding requests in flight
+            stream.map(Ok).try_for_each_concurrent(Some(3), move |batch| {
+                let (entries, bytes): (Vec<Entry>, Vec<Vec<u8>>) = batch.into_iter().unzip();
+                let client = client.clone();
+                let config = config.clone();
+                let final_write_tx = final_write_tx.clone();
+                async move {
+                    let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
+                        &client,
+                        &config.backend,
+                        "",
+                        EmbeddingRequest::Images {
+                            images: bytes.into_iter().map(serde_bytes::ByteBuf::from).collect(),
+                        },
+                    ).await.context("querying CLIP server")?;
+
+                    for (vector, entry) in result.into_iter().zip(entries) {
+                        final_write_tx.send(ProcessedEntry {
+                            url: entry.url,
+                            id: entry.id,
+                            title: entry.title,
+                            subreddit: entry.subreddit,
+                            author: entry.author.unwrap(),
+                            blob: vector.into_vec(),
+                            timestamp: entry.created_utc.to_u64()?
+                        }).await?;
+                    }
+                    anyhow::Result::Ok(())
+                }
+            })
+        }))
+    };
+
+    let config_ = config.clone();
+    let output_writer_task = match config.mode {
+        OperatingMode::Sample(_) | OperatingMode::FullRun => Some(tokio::task::spawn_blocking(move || write_output(config_, final_write_rx))),
+        _ => None
+    };
+
+    log::info!("working...");
+
+    let mut paths = vec![];
+    for file in fs::read_dir(&config.input)? {
+        paths.push(file?.path());
+    }
+
+    paths.sort();
+
+    match config.mode {
+        OperatingMode::Count | OperatingMode::Sample(_) => {
+            let mut set = JoinSet::new();
+            let semaphore = Arc::new(Semaphore::new(cpus));
+
+            for path in paths {
+                let semaphore = semaphore.clone();
+                let permit = semaphore.acquire_owned().await?;
+                let entries_tx = entries_tx.clone();
+                let path_ = path.clone();
+                log::info!("reading {:?}", path);
+                set.spawn_blocking(move || {
+                    match process_file(path_, entries_tx) {
+                        Ok(_) => (),
+                        Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
+                    }
+                    std::mem::drop(permit);
+                });
+            }
+
+            std::mem::drop(entries_tx);
+
+            while let Some(x) = set.join_next().await {
+                x?;
+            }
+        },
+        OperatingMode::FullRun => {
+            for path in paths {
+                let entries_tx = entries_tx.clone();
+                let path_ = path.clone();
+                log::info!("reading {:?}", path);
+                let c = tokio::task::spawn_blocking(move || process_file(path_, entries_tx)).await?;
+                match c {
+                    Ok(_) => (),
+                    _ => log::error!("could not parse {:?} {:?}", &path, c)
+                }
+            }
+
+            std::mem::drop(entries_tx);
+        }
+    }
+
+    // drain the pipeline in order; load_task only finishes once entries_tx is dropped
+    load_task.await??;
+    if let Some(task) = resize_task { task.await??; }
+    if let Some(task) = embedding_generation_task { task.await??; }
+    if let Some(task) = output_writer_task { task.await??; }
+
+    Ok(())
+}
\ No newline at end of file