From 8097ce8d9188d7b43895e992a12298edafc925b4 Mon Sep 17 00:00:00 2001
From: osmarks
Date: Mon, 11 Nov 2024 19:43:07 +0000
Subject: [PATCH] improve dump processing and misc performance fixes

---
 src/common.rs         |   6 ++
 src/dump_processor.rs | 191 ++++++++++++++++++++++++++++++++++++------
 src/get_embedding.py  |  18 ++++
 src/main.rs           |   8 +-
 src/reddit_dump.rs    |  18 +++-
 5 files changed, 206 insertions(+), 35 deletions(-)
 create mode 100644 src/get_embedding.py

diff --git a/src/common.rs b/src/common.rs
index 3fba89e..a089f6e 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -81,3 +81,9 @@ pub async fn query_clip_server(client: &Client, base_url: &str, path: &str
     let result: O = rmp_serde::from_slice(&response.bytes().await?)?;
     Ok(result)
 }
+
+pub fn decode_fp16_buffer(buf: &[u8]) -> Vec<f32> {
+    buf.chunks_exact(2)
+        .map(|chunk| half::f16::from_le_bytes([chunk[0], chunk[1]]).to_f32())
+        .collect()
+}
diff --git a/src/dump_processor.rs b/src/dump_processor.rs
index ea345dc..c4747a9 100644
--- a/src/dump_processor.rs
+++ b/src/dump_processor.rs
@@ -1,9 +1,15 @@
 use anyhow::{Result, Context};
 use serde::{Serialize, Deserialize};
-use std::io::BufReader;
+use std::io::{BufReader, Write};
 use rmp_serde::decode::Error as DecodeError;
 use std::fs;
-use base64::{engine::general_purpose::URL_SAFE, Engine as _};
+use base64::Engine;
+use argh::FromArgs;
+use chrono::{TimeZone, Utc, DateTime};
+use std::collections::{VecDeque, HashSet};
+use std::hash::Hasher;
+
+mod common;
 
 // TODO refactor
 #[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
@@ -22,35 +28,172 @@ struct ProcessedEntry {
     subreddit: String,
     author: String,
     timestamp: u64,
-    #[serde(with = "serde_bytes")]
+    #[serde(with="serde_bytes")]
     embedding: Vec<u8>,
     metadata: OriginalImageMetadata
 }
 
+#[derive(FromArgs)]
+#[argh(description="Process scraper dump files")]
+struct CLIArguments {
+    #[argh(option, short='s', description="read subset of records")]
+    sample: Option<f32>,
+    #[argh(switch, short='p', description="print basic information for records")]
+    print_records: bool,
+    #[argh(switch, short='e', description="print embeddings")]
+    print_embeddings: bool,
+    #[argh(switch, short='a', description="print aggregates")]
+    print_aggregates: bool,
+    #[argh(option, short='E', description="x:y - load embedding named x from file y")]
+    embedding: Vec<String>,
+    #[argh(option, short='H', description="path for histograms of dot with embeddings")]
+    histograms: Option<String>,
+    #[argh(switch, short='D', description="enable deduplicator")]
+    deduplicate: bool,
+    #[argh(option, short='T', description="deduplication Hamming distance threshold")]
+    threshold: Option<u64>,
+    #[argh(positional)]
+    paths: Vec<String>
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+struct Histogram {
+    min: f32,
+    max: f32,
+    buckets: Vec<u64>
+}
+
+impl Histogram {
+    fn new(min: f32, max: f32, count: usize) -> Self {
+        let buckets = (0..count).map(|_| 0).collect();
+        Self { min, max, buckets }
+    }
+
+    fn add(&mut self, x: f32) {
+        let bucket = if x < self.min {
+            0
+        } else if x >= self.max {
+            self.buckets.len() - 1
+        } else {
+            ((x - self.min) / (self.max - self.min) * (self.buckets.len() as f32)) as usize
+        };
+        self.buckets[bucket] += 1;
+    }
+
+    fn buckets(&self) -> Vec<(f32, u64)> {
+        let step = (self.max - self.min) / (self.buckets.len() as f32);
+        self.buckets.iter().enumerate().map(|(i, x)| (self.min + (i as f32) * step, *x)).collect()
+    }
+}
+
+fn dot(x: &[f32], y: &[f32]) -> f32 {
+    x.iter().zip(y).map(|(a, b)| a * b).sum::<f32>()
+}
+
+fn binarize(x: &[f32]) -> Vec<u8> {
+    let mut buf = vec![0; x.len() / 8];
+    for i in 0..(x.len() / 8) {
+        buf[i] = ((x[i * 8] > 0.0) as u8) + (((x[i * 8 + 1] > 0.0) as u8) << 1) + (((x[i * 8 + 2] > 0.0) as u8) << 2) + (((x[i * 8 + 3] > 0.0) as u8) << 3) + (((x[i * 8 + 4] > 0.0) as u8) << 4) + (((x[i * 8 + 5] > 0.0) as u8) << 5) + (((x[i * 8 + 6] > 0.0) as u8) << 6) + (((x[i * 8 + 7] > 0.0) as u8) << 7);
+    }
+    buf
+}
+
 fn main() -> Result<()> {
-    let path = std::env::args().nth(1).context("missing path")?;
-    let stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
-    let mut stream = BufReader::new(stream);
-    let mut latest_timestamp = 0;
-    let mut earliest_timestamp = u64::MAX;
+    let args: CLIArguments = argh::from_env();
+    let mut rng = fastrand::Rng::new();
+    let mut latest_timestamp = DateTime::<Utc>::MIN_UTC;
+    let mut earliest_timestamp = DateTime::<Utc>::MAX_UTC;
     let mut count = 0;
-    loop {
-        let res: Result<ProcessedEntry, DecodeError> = rmp_serde::from_read(&mut stream);
-        if res.is_ok() {
-            count += 1;
-        }
-        match res {
-            Ok(x) => {
-                if x.timestamp > latest_timestamp {
-                    //println!("{} {} https://reddit.com/r/{}/comments/{} {} https://mse.osmarks.net/?e={}", x.timestamp, count, x.subreddit, x.id, x.metadata.final_url, URL_SAFE.encode(x.embedding));
-                    latest_timestamp = x.timestamp;
-                }
-                earliest_timestamp = earliest_timestamp.min(x.timestamp);
-            },
-            Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
-            Err(e) => return Err(e).context("decode fail")
+    let mut deduped_count = 0;
+    let mut embeddings = Vec::new();
+    for x in args.embedding {
+        let (name, path) = x.split_once(':').unwrap();
+        let blob = std::fs::read(path).context("read embedding")?;
+        embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512)));
+    }
+
+    // TODO ring of vecs probably has bad cache locality
+    let mut dedupe_ring: VecDeque<Vec<u8>> = VecDeque::with_capacity(2<<10);
+    let threshold = args.threshold.unwrap_or(3);
+
+    for path in args.paths {
+        let stream = zstd::stream::Decoder::new(fs::File::open(path).context("read dump file")?)?;
+        let mut stream = BufReader::new(stream);
+
+        loop {
+            let res: Result<ProcessedEntry, DecodeError> = rmp_serde::from_read(&mut stream);
+            if res.is_ok() {
+                count += 1;
+            }
+            match res {
+                Ok(x) => {
+                    if args.sample.is_some() && rng.f32() > args.sample.unwrap() {
+                        continue;
+                    }
+                    let timestamp = Utc.timestamp_opt(x.timestamp as i64, 0).unwrap();
+
+                    let embedding = common::decode_fp16_buffer(&x.embedding);
+
+                    latest_timestamp = latest_timestamp.max(timestamp);
+                    earliest_timestamp = earliest_timestamp.min(timestamp);
+
+                    if args.deduplicate {
+                        let code = binarize(&embedding);
+                        if dedupe_ring.len() == dedupe_ring.capacity() {
+                            dedupe_ring.pop_front().unwrap();
+                        }
+                        let has_match = dedupe_ring.iter().any(|x| hamming::distance(x, &code) <= threshold);
+                        dedupe_ring.push_back(code);
+                        if has_match {
+                            deduped_count += 1;
+                            continue;
+                        }
+                    }
+
+                    if args.print_records {
+                        println!("{} {} https://reddit.com/r/{}/comments/{} {}", timestamp, x.title, x.subreddit, x.id, x.metadata.final_url);
+                    }
+                    if args.print_embeddings {
+                        println!("https://mse.osmarks.net/?e={}", base64::engine::general_purpose::URL_SAFE.encode(&x.embedding));
+                    }
+                    for (_name, vec, histogram) in &mut embeddings {
+                        let dot = dot(&embedding, vec);
+                        histogram.add(dot);
+                    }
+                },
+                Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
+                Err(e) => return Err(e).context("decode fail")
+            }
+        }
+    }
+
+    if args.print_aggregates {
+        println!("earliest={} latest={} count={} deduped={}", earliest_timestamp, latest_timestamp, count, deduped_count);
+    }
+    if let Some(histogram_path) = args.histograms {
+        let mut file = std::fs::File::create(histogram_path)?;
+        for (name, _, histogram) in &embeddings {
+            let width = 800.0;
+            let padding = 40.0;
+            let bars_height = 300 as f64;
+            let buckets = histogram.buckets();
+            let max_count = *buckets.iter().map(|(_max, count)| count).max().unwrap();
+            let bar_width = width / buckets.len() as f64;
+            let plot = maud::html! {
+                h1 { (name) }
+                svg style="border: 1px solid gray;" viewBox=(format!("{} 0 {} {}", -padding * 0.25, width + (padding * 0.75), bars_height + 50.0)) xmlns="http://www.w3.org/2000/svg" width=(format!("{}", width + padding)) height=(format!("{}", bars_height + 50.0)) {
+                    @for (i, (min, count)) in buckets.into_iter().enumerate() {
+                        @let height = bars_height * (count as f64 / max_count as f64);
+                        rect width=(format!("{}", bar_width)) x=(format!("{}", bar_width * i as f64)) height=(format!("{}", height)) y=(format!("{}", bars_height - height)) {
+                            title {
+                                (format!("{} {}", min, count))
+                            }
+                        }
+                    }
+                }
+            };
+            file.write_all(plot.into_string().as_bytes())?;
         }
     }
-    println!("{} {} {}", earliest_timestamp, latest_timestamp, count);
     Ok(())
 }
diff --git a/src/get_embedding.py b/src/get_embedding.py
new file mode 100644
index 0000000..2c6695b
--- /dev/null
+++ b/src/get_embedding.py
@@ -0,0 +1,18 @@
+import json
+import requests
+import base64
+import msgpack
+import sys
+
+with open("mse_config.json") as f:
+    config = json.load(f)
+
+def get_embedding(req):
+    return msgpack.unpackb(requests.post(config["clip_server"], data=msgpack.packb(req)).content)
+
+output, input, *xs = sys.argv[1:]
+
+with open(output, "wb") as f:
+    with open(input, "rb") as g:
+        input_data = g.read()
+    f.write(get_embedding({"images": [input_data]})[0])
diff --git a/src/main.rs b/src/main.rs
index 83a81d1..348f108 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -40,7 +40,7 @@ mod common;
 mod video_reader;
 
 use crate::ocr::scan_image;
-use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server};
+use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server, decode_fp16_buffer};
 
 lazy_static! {
     static ref RELOADS_COUNTER: IntCounter = register_int_counter!("mse_reloads", "reloads executed").unwrap();
@@ -893,12 +893,6 @@ async fn build_index(config: Arc<Config>) -> Result {
     Ok(index)
 }
 
-fn decode_fp16_buffer(buf: &[u8]) -> Vec<f32> {
-    buf.chunks_exact(2)
-        .map(|chunk| half::f16::from_le_bytes([chunk[0], chunk[1]]).to_f32())
-        .collect()
-}
-
 type EmbeddingVector = Vec<f32>;
 
 #[derive(Debug, Serialize)]
diff --git a/src/reddit_dump.rs b/src/reddit_dump.rs
index 1c7cde1..62b415e 100644
--- a/src/reddit_dump.rs
+++ b/src/reddit_dump.rs
@@ -1,7 +1,7 @@
 use anyhow::{anyhow, Context, Result};
 use common::resize_for_embed;
 use itertools::Itertools;
-use std::{collections::HashSet, ffi::OsStr, fs::{self, read_dir}, io::{BufRead, BufReader, BufWriter, Cursor}, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
+use std::{collections::HashSet, ffi::OsStr, fs::{self, read_dir}, io::{BufRead, BufReader, BufWriter, Cursor}, path::PathBuf, str::FromStr, sync::Arc, time::Duration, hash::Hasher};
 use serde::{Serialize, Deserialize};
 use lazy_static::lazy_static;
 use regex::{bytes, Regex, RegexSet, RegexSetBuilder};
@@ -148,6 +148,7 @@ lazy_static! {
     static ref IMAGE_FILESIZES_HISTOGRAM: HistogramVec = register_histogram_vec!("mse_scrape_image_filesizes", "filesizes of successfully fetched images", &["format"], prometheus::exponential_buckets(100.0, 1.5, 29).unwrap()).unwrap();
     static ref IMAGE_PIXELS_HISTOGRAM: HistogramVec = register_histogram_vec!("mse_scrape_image_pixels", "pixel count of successfully fetched images", &["format"], prometheus::exponential_buckets(100.0, 1.3, 53).unwrap()).unwrap();
     static ref HTML_EXTRACTS_COUNTER: IntCounter = register_int_counter!("mse_scrape_html_extracts", "html extraction operations").unwrap();
+    static ref DISCARDED_COUNTER: IntCounter = register_int_counter!("mse_scrape_discarded", "images discarded by hash").unwrap();
 }
 
 #[instrument(skip(tx))]
@@ -209,7 +210,8 @@ struct Config {
     mode: OperatingMode,
     filename_threshold: Option<String>,
     metrics_addr: String,
-    contact_info: String
+    contact_info: String,
+    discard_hashes: HashSet<u64>
 }
 
 #[instrument(skip(client, config))]
@@ -239,12 +241,19 @@ async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) ->
         _ => ()
     }
     let mut buffer = vec![];
+    let mut hash = seahash::SeaHasher::new();
     while let Some(chunk) = response.chunk().await? {
+        hash.write(&chunk);
         buffer.extend(chunk);
         if buffer.len() > config.max_content_length {
             return Err(anyhow!("response too large"));
         }
     }
+    let hash = hash.finish();
+    if config.discard_hashes.contains(&hash) {
+        DISCARDED_COUNTER.inc();
+        return Err(anyhow!("discarded"));
+    }
     if let Some(extract_rule) = html_extract_rule {
         if content_type == "text/html" {
             let buffer = String::from_utf8_lossy(&buffer).to_string();
@@ -329,7 +338,7 @@ async fn serve_metrics(config: Arc<Config>) -> Result<()> {
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    console_subscriber::init();
+    tracing_subscriber::fmt::init();
 
     let cpus = num_cpus::get();
 
@@ -341,7 +350,8 @@ async fn main() -> Result<()> {
         mode: OperatingMode::FullRun,
         filename_threshold: Some(String::from("RS_2017-08.zst")),
        metrics_addr: String::from("0.0.0.0:9914"),
-        contact_info: String::from("scraping-ops@osmarks.net")
+        contact_info: String::from("scraping-ops@osmarks.net"),
+        discard_hashes: [4168519401919155623, 4577010157274124110].into_iter().collect()
     });
 
     serve_metrics(config.clone()).await?;
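
Note on the discard_hashes values above (not part of the patch): they are whole-response SeaHash digests, matching the hashing fetch_file now performs chunk by chunk. A minimal sketch of how such a value could be produced for a locally saved copy of an unwanted image follows; the helper name and file path are illustrative, and it assumes SeaHasher's streaming digest over chunks equals the digest of the concatenated bytes, which is how fetch_file uses it.

// Illustrative only: compute the value fetch_file compares against
// Config::discard_hashes for an image file on disk.
use std::hash::Hasher;

fn discard_hash(bytes: &[u8]) -> u64 {
    let mut hash = seahash::SeaHasher::new();
    // fetch_file feeds the response body to the hasher chunk by chunk;
    // hashing the whole buffer at once is assumed to give the same digest.
    hash.write(bytes);
    hash.finish()
}

fn main() -> std::io::Result<()> {
    let bytes = std::fs::read("unwanted_image.png")?; // hypothetical path
    println!("{}", discard_hash(&bytes));
    Ok(())
}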