From a8329e43fca1b6a4c623f222d6d87eba195a6789 Mon Sep 17 00:00:00 2001
From: osmarks
Date: Mon, 27 May 2024 15:22:28 +0100
Subject: [PATCH] more progress on Reddit

---
 Cargo.lock         |  12 ++--
 Cargo.toml         |   3 +
 src/reddit_dump.rs | 140 +++++++++++++++++++++++++++++++++------------
 3 files changed, 112 insertions(+), 43 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index a6fc4e5..4521135 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1122,8 +1122,7 @@ dependencies = [
 [[package]]
 name = "image"
 version = "0.25.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11"
+source = "git+https://github.com/fintelia/image/?branch=upgrade-zune-jpeg#54ee15fae5f865f9806bda70c2118e9e572e7deb"
 dependencies = [
  "bytemuck",
  "byteorder",
@@ -1141,7 +1140,6 @@ dependencies = [
  "rayon",
  "rgb",
  "tiff",
- "zune-core",
  "zune-jpeg",
 ]
 
@@ -3550,9 +3548,9 @@ dependencies = [
 
 [[package]]
 name = "zune-core"
-version = "0.4.12"
+version = "0.5.0-rc1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3f423a2c17029964870cfaabb1f13dfab7d092a62a29a89264f4d36990ca414a"
+checksum = "e0d1b427373b52a2497c49b0860a5290daab6a0437902ffd8f607367bd5eb7d0"
 
 [[package]]
 name = "zune-inflate"
@@ -3565,9 +3563,9 @@ dependencies = [
 
 [[package]]
 name = "zune-jpeg"
-version = "0.4.11"
+version = "0.5.0-rc1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ec866b44a2a1fd6133d363f073ca1b179f438f99e7e5bfb1e33f7181facfe448"
+checksum = "e4019a3b4e46db2d81faab716e5034dd212f1e105bc55b3f5ca4381dd78736e0"
 dependencies = [
  "zune-core",
 ]

diff --git a/Cargo.toml b/Cargo.toml
index 5b3068e..9c2357f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -40,6 +40,9 @@ fastrand = "2"
 mimalloc = "0.1"
 sonic-rs = "0.3"
 
+[patch.crates-io]
+image = { git = "https://github.com/fintelia/image/", branch = "upgrade-zune-jpeg" }
+
 [[bin]]
 name = "reddit-dump"
 path = "src/reddit_dump.rs"
\ No newline at end of file
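The `[patch.crates-io]` entry overrides every `image` dependency in the graph with fintelia's fork, whose `upgrade-zune-jpeg` branch moves JPEG decoding to the zune-jpeg 0.5 release candidates pinned in `Cargo.lock` above. A minimal, hypothetical smoke test for the patched decoder (not part of this commit; `test.jpg` is a placeholder path and the crate's existing `anyhow` dependency is assumed):

```rust
// Hypothetical check that the patched `image` crate decodes a JPEG end-to-end.
use std::io::Cursor;
use image::{io::Reader as ImageReader, GenericImageView};

fn main() -> anyhow::Result<()> {
    let bytes = std::fs::read("test.jpg")?; // placeholder input file
    // with_guessed_format() sniffs magic bytes instead of trusting extensions,
    // the same entry point reddit_dump.rs uses below.
    let img = ImageReader::new(Cursor::new(bytes))
        .with_guessed_format()?
        .decode()?;
    let (w, h) = img.dimensions();
    println!("decoded {}x{}", w, h);
    Ok(())
}
```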
r"\x00" // for SOME REASON one of the JSON files contains a lot of null bytes before one particular record, so just ignore that record ]).unwrap(); + static ref URL_REPLACEMENT_RULES: Vec<(Regex, &'static str)> = [ + (r"//imgur.com/([A-Za-z0-9]+)", r"//i.imgur.com/$1.jpg"), + (r"^http://", r"https://") + ].into_iter().map(|(r, e)| (Regex::new(r).unwrap(), e)).collect(); } -fn process_file(path: PathBuf, tx: mpsc::Sender) -> Result<()> { +fn process_file(path: PathBuf, tx: mpsc::Sender, timestamp_threshold: Option) -> Result<()> { let mut stream = zstd::stream::Decoder::new(fs::File::open(path)?)?; stream.window_log_max(31)?; let mut stream = BufReader::new(stream); @@ -103,11 +112,23 @@ fn process_file(path: PathBuf, tx: mpsc::Sender) -> Result<()> { return Ok(()) } }; - if entry.selftext.is_empty() && !entry.over_18 && entry.author.is_some() { + if entry.selftext.is_empty() && !entry.over_18 && entry.author.is_some() && entry.subreddit.is_some() { if !URL_IGNORE.is_match(&entry.url) { match &entry.post_hint { Some(x) if x == "na" || x == "image" => { - tx.blocking_send(entry)?; + // Technically this is slightly wrong because we reorder images slightly, but as long as it is not restarted all the time this is "fine". + let after_threshold = match timestamp_threshold { + Some(threshold) => { + let timestamp = match &entry.created_utc { + BadTimestampFormat::Int(x) => *x, + BadTimestampFormat::String(s) => u64::from_str(s).unwrap() + }; + timestamp > threshold + }, + None => true + }; + + if after_threshold { tx.blocking_send(entry)?; } }, _ => () } @@ -123,11 +144,17 @@ struct Config { input: String, output: String, backend: String, - mode: OperatingMode + mode: OperatingMode, + filename_threshold: Option } async fn fetch_file(client: reqwest::Client, config: Arc, url: &str) -> Result> { - let mut response = client.get(url).send().await?; + // inelegant but I can't get it to work using Cows + let mut url = url.to_string(); + for (regex, replacement) in URL_REPLACEMENT_RULES.iter() { + url = regex.replace(&url, *replacement).to_string(); + } + let mut response = client.get(&*url).send().await?; if !ACCEPTABLE_FILETYPES.contains(response.headers().get(reqwest::header::CONTENT_TYPE).context("no contept type")?.as_bytes()) { return Err(anyhow!("invalid Content-Type")); } @@ -146,7 +173,7 @@ async fn fetch_file(client: reqwest::Client, config: Arc, url: &str) -> } fn write_output(config: Arc, mut rx: Receiver) -> Result<()> { - let mut out = fs::File::create(&config.output)?; + let mut out = fs::File::options().append(true).open(&config.output)?; let stream = zstd::Encoder::new(&mut out, 15)?.auto_finish(); let mut buf_stream = BufWriter::new(stream); while let Some(x) = rx.blocking_recv() { @@ -161,6 +188,26 @@ enum OperatingMode { FullRun } +fn readback_output(path: &str) -> Result<(u64, usize)> { + use rmp_serde::decode::Error; + let stream = zstd::stream::Decoder::new(fs::File::open(path)?)?; + let mut stream = BufReader::new(stream); + let mut latest_timestamp = 0; + let mut count = 0; + loop { + let res: Result = rmp_serde::from_read(&mut stream); + if res.is_ok() { + count += 1; + } + match res { + Ok(x) => latest_timestamp = latest_timestamp.max(x.timestamp), + Err(Error::InvalidDataRead(x)) | Err(Error::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e).context("decode fail") + } + } + Ok((latest_timestamp, count)) +} + #[tokio::main] async fn main() -> Result<()> { pretty_env_logger::init(); @@ -169,11 +216,30 @@ async fn main() -> 
@@ -169,11 +216,30 @@ async fn main() -> Result<()> {
     let config = Arc::new(Config {
         max_content_length: 1<<23,
         input: String::from("./submissions"),
-        output: String::from("./data.zst"),
+        output: String::from("./sample.zst"),
         backend: String::from("http://localhost:1708"),
-        mode: OperatingMode::Count
+        mode: OperatingMode::Sample(0.004),
+        filename_threshold: None
     });
 
+    let timestamp_threshold = match config.mode {
+        OperatingMode::Count => None,
+        _ => {
+            match readback_output(&config.output) {
+                Ok(x) => Some(x),
+                Err(e) => {
+                    log::warn!("could not read output: {}", e);
+                    None
+                }
+            }
+        }
+    };
+
+
+    if let Some((threshold, count)) = timestamp_threshold {
+        log::info!("threshold is {}, {} items", threshold, count);
+    }
+
     let backend = get_backend_config(&config.backend).await;
     log::info!("connected to inference server");
 
@@ -200,7 +266,7 @@ async fn main() -> Result<()> {
         let client = client.clone();
         let config = config.clone();
         let stream = ReceiverStream::new(entries_rx);
-        stream.map(Ok).try_for_each_concurrent(Some(128), move |entry| {
+        stream.map(Ok).try_for_each_concurrent(Some(512), move |entry| {
            let client = client.clone();
            let config = config.clone();
            let buffers_tx = buffers_tx.clone();
@@ -212,6 +278,7 @@ async fn main() -> Result<()> {
                 }
                 match fetch_file(client, config.clone(), &entry.url).await {
                     Ok(buf) => {
+                        log::debug!("got {}", &entry.url);
                         buffers_tx.send((entry, buf)).await?;
                     },
                     Err(e) => {
@@ -235,7 +302,7 @@ async fn main() -> Result<()> {
             async move {
                 let image_result = tokio::task::spawn_blocking(|| {
                     let csr = Cursor::new(buffer);
-                    let image = ImageReader::new(csr).decode()?;
+                    let image = ImageReader::new(csr).with_guessed_format()?.decode()?;
                     Result::<DynamicImage, anyhow::Error>::Ok(image)
                 }).await?;
                 let image = match image_result {
@@ -276,12 +343,11 @@ async fn main() -> Result<()> {
             ).await.context("querying CLIP server")?;
 
             for (vector, entry) in result.into_iter().zip(entries) {
-                println!("{:?}", entry);
                 final_write_tx.send(ProcessedEntry {
                     url: entry.url,
                     id: entry.id,
                     title: entry.title,
-                    subreddit: entry.subreddit,
+                    subreddit: entry.subreddit.unwrap(),
                     author: entry.author.unwrap(),
                     blob: vector.into_vec(),
                     timestamp: entry.created_utc.to_u64()?
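The download stage above is an mpsc channel wrapped in `ReceiverStream` and driven by `try_for_each_concurrent`, so the 128 to 512 bump just widens the cap on in-flight fetches without restructuring anything. A stripped-down sketch of that pattern (hypothetical URLs and fetch body; `anyhow` error type assumed):

```rust
// Minimal sketch of the bounded-concurrency fetch stage.
use futures::{StreamExt, TryStreamExt};
use tokio_stream::wrappers::ReceiverStream;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let (tx, rx) = tokio::sync::mpsc::channel::<String>(1024);
    tokio::spawn(async move {
        for url in ["https://i.imgur.com/a.jpg", "https://i.imgur.com/b.jpg"] {
            tx.send(url.to_string()).await.ok();
        }
    });
    ReceiverStream::new(rx)
        .map(Ok::<_, anyhow::Error>) // try_for_each_concurrent wants a TryStream
        .try_for_each_concurrent(Some(512), |url| async move {
            // At most 512 of these futures run at once; a real body would fetch `url`.
            println!("would fetch {}", url);
            Ok(())
        })
        .await
}
```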
@@ -303,14 +369,20 @@ async fn main() -> Result<()> {
     let mut paths = vec![];
     for file in fs::read_dir(&config.input)? {
-        paths.push(file?.path());
+        let path = file?.path();
+        let last_segment = path.file_name().context("invalid file structure")?.to_str().context("non-UTF8 path")?.to_string();
+        match &config.filename_threshold {
+            Some(threshold) if threshold >= &last_segment => (),
+            _ => paths.push(path)
+        }
     }
     paths.sort();
 
+    let mut file_readers = JoinSet::new();
+
     match config.mode {
         OperatingMode::Count | OperatingMode::Sample(_) => {
-            let mut set = JoinSet::new();
             let semaphore = Arc::new(Semaphore::new(cpus));
 
             for path in paths {
@@ -319,41 +391,37 @@ async fn main() -> Result<()> {
                 let entries_tx = entries_tx.clone();
                 let path_ = path.clone();
                 log::info!("reading {:?}", path);
-                set.spawn_blocking(move || {
-                    match process_file(path_, entries_tx) {
+                file_readers.spawn_blocking(move || {
+                    match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
                         Ok(_) => (),
                         Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
                     }
                     std::mem::drop(permit);
                 });
             }
-
-            std::mem::drop(entries_tx);
-
-            while let Some(x) = set.try_join_next() {
-                x?;
-            }
         },
         OperatingMode::FullRun => {
             for path in paths {
                 let entries_tx = entries_tx.clone();
                 let path_ = path.clone();
                 log::info!("reading {:?}", path);
-                let c = tokio::task::spawn_blocking(move || process_file(path_, entries_tx)).await?;
-                match c {
+                file_readers.spawn_blocking(move || match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
                     Ok(_) => (),
-                    _ => log::error!("could not parse {:?} {:?}", &path, c)
-                }
+                    Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
+                });
             }
-
-            std::mem::drop(entries_tx);
-
-            load_task.await??;
         }
     }
 
-    if let Some(task) = resize_task { task.await??; }
-    if let Some(task) = embedding_generation_task { task.await?? };
-    if let Some(task) = output_writer_task { task.await?? };
+
+    while let Some(x) = file_readers.try_join_next() {
+        x?;
+    }
+
+    std::mem::drop(entries_tx);
+    println!("{:?}", load_task.await?);
+    if let Some(task) = resize_task { println!("{:?}", task.await?); }
+    if let Some(task) = embedding_generation_task { println!("{:?}", task.await?) };
+    if let Some(task) = output_writer_task { println!("{:?}", task.await?) };
 
     Ok(())
 }
\ No newline at end of file
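Shutdown is now uniform across modes: both arms push their blocking readers into the shared `file_readers` JoinSet, and the tail of `main` drains it before dropping `entries_tx` and awaiting the pipeline tasks. One tokio API detail worth noting: `try_join_next` only pops tasks that have already finished and returns `None` otherwise, so completion here appears to rely on the sender clones dropping as each reader exits, which in turn lets `load_task` finish. For comparison, a sketch of the waiting variant, `join_next().await` (hypothetical task bodies):

```rust
// Sketch: draining a JoinSet with join_next().await, which waits for each task,
// unlike try_join_next(), which only pops tasks that have already finished.
use tokio::task::JoinSet;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut file_readers: JoinSet<u64> = JoinSet::new();
    for i in 0..4u64 {
        // spawn_blocking keeps work like zstd decompression off the async worker threads
        file_readers.spawn_blocking(move || i * i);
    }
    while let Some(result) = file_readers.join_next().await {
        println!("reader finished: {}", result?); // result? surfaces task panics
    }
    Ok(())
}
```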