mirror of https://github.com/osmarks/meme-search-engine.git

more progress on Reddit

osmarks 2024-05-27 15:22:28 +01:00
parent f8d68d9d54
commit a8329e43fc
3 changed files with 112 additions and 43 deletions

Cargo.lock (generated)

@@ -1122,8 +1122,7 @@ dependencies = [
 [[package]]
 name = "image"
 version = "0.25.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11"
+source = "git+https://github.com/fintelia/image/?branch=upgrade-zune-jpeg#54ee15fae5f865f9806bda70c2118e9e572e7deb"
 dependencies = [
  "bytemuck",
  "byteorder",
@@ -1141,7 +1140,6 @@ dependencies = [
  "rayon",
  "rgb",
  "tiff",
- "zune-core",
  "zune-jpeg",
 ]
@@ -3550,9 +3548,9 @@ dependencies = [
 [[package]]
 name = "zune-core"
-version = "0.4.12"
+version = "0.5.0-rc1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3f423a2c17029964870cfaabb1f13dfab7d092a62a29a89264f4d36990ca414a"
+checksum = "e0d1b427373b52a2497c49b0860a5290daab6a0437902ffd8f607367bd5eb7d0"

 [[package]]
 name = "zune-inflate"
@@ -3565,9 +3563,9 @@ dependencies = [
 [[package]]
 name = "zune-jpeg"
-version = "0.4.11"
+version = "0.5.0-rc1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ec866b44a2a1fd6133d363f073ca1b179f438f99e7e5bfb1e33f7181facfe448"
+checksum = "e4019a3b4e46db2d81faab716e5034dd212f1e105bc55b3f5ca4381dd78736e0"
 dependencies = [
  "zune-core",
 ]

Cargo.toml

@@ -40,6 +40,9 @@ fastrand = "2"
 mimalloc = "0.1"
 sonic-rs = "0.3"

+[patch.crates-io]
+image = { git = "https://github.com/fintelia/image/", branch = "upgrade-zune-jpeg" }
+
 [[bin]]
 name = "reddit-dump"
 path = "src/reddit_dump.rs"

src/reddit_dump.rs

@@ -1,9 +1,9 @@
 use anyhow::{anyhow, Context, Result};
 use common::resize_for_embed;
-use std::{collections::HashSet, fs, io::{BufReader, Cursor, BufRead, BufWriter}, path::PathBuf, time::Duration, sync::Arc, str::FromStr};
+use std::{collections::HashSet, fs, io::{BufReader, Cursor, BufRead, BufWriter}, time::Duration, sync::Arc, str::FromStr, path::PathBuf};
 use serde::{Serialize, Deserialize};
 use lazy_static::lazy_static;
-use regex::{RegexSet, bytes};
+use regex::{RegexSet, bytes, Regex};
 use tokio::{sync::{mpsc::{self, Receiver}, Semaphore}, task::{JoinHandle, JoinSet}};
 use tokio_stream::wrappers::ReceiverStream;
 use reqwest::Client;
@@ -43,7 +43,7 @@ struct Entry {
     title: String,
     author: Option<String>,
     selftext: String,
-    subreddit: String,
+    subreddit: Option<String>,
     created_utc: BadTimestampFormat,
     #[serde(default="function_which_returns_some_na")]
     post_hint: Option<String>,
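Note: created_utc is typed as BadTimestampFormat, whose definition lies outside this diff. Judging by the Int/String match arms added further down, it is presumably an untagged serde enum along these lines (a sketch, not the repository's actual definition):

    use serde::Deserialize;

    // Assumed shape: the dumps serialize created_utc either as a bare integer
    // or as a quoted string, so an untagged enum tries each representation.
    #[derive(Debug, Deserialize)]
    #[serde(untagged)]
    enum BadTimestampFormat {
        Int(u64),
        String(String),
    }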
@@ -70,6 +70,11 @@
         r"\.aspx?",
         r"\.xml",
         r"//youtube\.com",
+        r"/rss/",
+        r"//vimeo\.com",
+        r"//www\.youtube\.com",
+        r"//youtu\.be",
+        r"//www\.reddit\.com",
         // TODO fill in more things, maybe try and collect thumbnails or something
     ]).unwrap();
     static ref ACCEPTABLE_FILETYPES: HashSet<&'static [u8]> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/webp", "image/apng", "image/bmp", "image/tiff"]
@@ -80,9 +85,13 @@
         r#""domain":"self.promos""#, // .......
         r"\x00" // for SOME REASON one of the JSON files contains a lot of null bytes before one particular record, so just ignore that record
     ]).unwrap();
+    static ref URL_REPLACEMENT_RULES: Vec<(Regex, &'static str)> = [
+        (r"//imgur.com/([A-Za-z0-9]+)", r"//i.imgur.com/$1.jpg"),
+        (r"^http://", r"https://")
+    ].into_iter().map(|(r, e)| (Regex::new(r).unwrap(), e)).collect();
 }

-fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>) -> Result<()> {
+fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Option<u64>) -> Result<()> {
     let mut stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
     stream.window_log_max(31)?;
     let mut stream = BufReader::new(stream);
@@ -103,11 +112,23 @@
                     return Ok(())
                 }
             };
-            if entry.selftext.is_empty() && !entry.over_18 && entry.author.is_some() {
+            if entry.selftext.is_empty() && !entry.over_18 && entry.author.is_some() && entry.subreddit.is_some() {
                 if !URL_IGNORE.is_match(&entry.url) {
                     match &entry.post_hint {
                         Some(x) if x == "na" || x == "image" => {
-                            tx.blocking_send(entry)?;
+                            // Technically this is slightly wrong because we reorder images slightly, but as long as it is not restarted all the time this is "fine".
+                            let after_threshold = match timestamp_threshold {
+                                Some(threshold) => {
+                                    let timestamp = match &entry.created_utc {
+                                        BadTimestampFormat::Int(x) => *x,
+                                        BadTimestampFormat::String(s) => u64::from_str(s).unwrap()
+                                    };
+                                    timestamp > threshold
+                                },
+                                None => true
+                            };
+                            if after_threshold { tx.blocking_send(entry)?; }
                         },
                         _ => ()
                     }
@@ -123,11 +144,17 @@ struct Config {
     input: String,
     output: String,
     backend: String,
-    mode: OperatingMode
+    mode: OperatingMode,
+    filename_threshold: Option<String>
 }

 async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) -> Result<Vec<u8>> {
-    let mut response = client.get(url).send().await?;
+    // inelegant but I can't get it to work using Cows
+    let mut url = url.to_string();
+    for (regex, replacement) in URL_REPLACEMENT_RULES.iter() {
+        url = regex.replace(&url, *replacement).to_string();
+    }
+    let mut response = client.get(&*url).send().await?;
     if !ACCEPTABLE_FILETYPES.contains(response.headers().get(reqwest::header::CONTENT_TYPE).context("no contept type")?.as_bytes()) {
         return Err(anyhow!("invalid Content-Type"));
    }
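Note: the URL_REPLACEMENT_RULES added above are applied in fetch_file before any request is made, rewriting imgur page links to direct image URLs and upgrading plain HTTP. A self-contained sketch of the same rewriting, runnable outside the repository:

    use regex::Regex;

    fn main() {
        // Same rules as the diff: imgur pages become direct i.imgur.com JPEG
        // links, and http:// is upgraded to https://.
        let rules = [
            (r"//imgur.com/([A-Za-z0-9]+)", r"//i.imgur.com/$1.jpg"),
            (r"^http://", r"https://"),
        ];
        let mut url = String::from("http://imgur.com/AbC123");
        for (pattern, replacement) in rules {
            // replace substitutes the first match; $1 expands the capture group.
            url = Regex::new(pattern).unwrap().replace(&url, replacement).to_string();
        }
        assert_eq!(url, "https://i.imgur.com/AbC123.jpg");
    }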
@@ -146,7 +173,7 @@ async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) ->
 }

 fn write_output(config: Arc<Config>, mut rx: Receiver<ProcessedEntry>) -> Result<()> {
-    let mut out = fs::File::create(&config.output)?;
+    let mut out = fs::File::options().append(true).open(&config.output)?;
     let stream = zstd::Encoder::new(&mut out, 15)?.auto_finish();
     let mut buf_stream = BufWriter::new(stream);
     while let Some(x) = rx.blocking_recv() {
@@ -161,6 +188,26 @@ enum OperatingMode {
     FullRun
 }

+fn readback_output(path: &str) -> Result<(u64, usize)> {
+    use rmp_serde::decode::Error;
+    let stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
+    let mut stream = BufReader::new(stream);
+    let mut latest_timestamp = 0;
+    let mut count = 0;
+    loop {
+        let res: Result<ProcessedEntry, Error> = rmp_serde::from_read(&mut stream);
+        if res.is_ok() {
+            count += 1;
+        }
+        match res {
+            Ok(x) => latest_timestamp = latest_timestamp.max(x.timestamp),
+            Err(Error::InvalidDataRead(x)) | Err(Error::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
+            Err(e) => return Err(e).context("decode fail")
+        }
+    }
+    Ok((latest_timestamp, count))
+}
+
 #[tokio::main]
 async fn main() -> Result<()> {
     pretty_env_logger::init();
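Note: readback_output plus the append-mode change to write_output is what makes interrupted runs resumable: the existing output is scanned for its newest timestamp, and process_file then skips anything at or before it. The subtle part is EOF handling, since msgpack streams have no framing; a reduced, self-contained sketch of that pattern (hypothetical Record type, in-memory buffer standing in for the zstd file):

    use serde::{Serialize, Deserialize};

    #[derive(Serialize, Deserialize)]
    struct Record { timestamp: u64 }

    fn main() -> anyhow::Result<()> {
        // Two back-to-back msgpack records, as write_output would append them.
        let mut buf = Vec::new();
        rmp_serde::encode::write(&mut buf, &Record { timestamp: 1 })?;
        rmp_serde::encode::write(&mut buf, &Record { timestamp: 7 })?;

        let mut cursor = std::io::Cursor::new(buf);
        let mut latest = 0;
        loop {
            use rmp_serde::decode::Error;
            match rmp_serde::from_read::<_, Record>(&mut cursor) {
                Ok(r) => latest = latest.max(r.timestamp),
                // No framing: a clean EOF between records is the normal exit.
                Err(Error::InvalidMarkerRead(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
                Err(e) => return Err(e.into()),
            }
        }
        assert_eq!(latest, 7);
        Ok(())
    }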
@@ -169,11 +216,30 @@ async fn main() -> Result<()> {
     let config = Arc::new(Config {
         max_content_length: 1<<23,
         input: String::from("./submissions"),
-        output: String::from("./data.zst"),
+        output: String::from("./sample.zst"),
         backend: String::from("http://localhost:1708"),
-        mode: OperatingMode::Count
+        mode: OperatingMode::Sample(0.004),
+        filename_threshold: None
     });

+    let timestamp_threshold = match config.mode {
+        OperatingMode::Count => None,
+        _ => {
+            match readback_output(&config.output) {
+                Ok(x) => Some(x),
+                Err(e) => {
+                    log::warn!("could not read output: {}", e);
+                    None
+                }
+            }
+        }
+    };
+
+    if let Some((threshold, count)) = timestamp_threshold {
+        log::info!("threshold is {}, {} items", threshold, count);
+    }
+
     let backend = get_backend_config(&config.backend).await;
     log::info!("connected to inference server");
@@ -200,7 +266,7 @@
         let client = client.clone();
         let config = config.clone();
         let stream = ReceiverStream::new(entries_rx);
-        stream.map(Ok).try_for_each_concurrent(Some(128), move |entry| {
+        stream.map(Ok).try_for_each_concurrent(Some(512), move |entry| {
             let client = client.clone();
             let config = config.clone();
             let buffers_tx = buffers_tx.clone();
@@ -212,6 +278,7 @@
             }
             match fetch_file(client, config.clone(), &entry.url).await {
                 Ok(buf) => {
+                    log::debug!("got {}", &entry.url);
                     buffers_tx.send((entry, buf)).await?;
                 },
                 Err(e) => {
@@ -235,7 +302,7 @@
             async move {
                 let image_result = tokio::task::spawn_blocking(|| {
                     let csr = Cursor::new(buffer);
-                    let image = ImageReader::new(csr).decode()?;
+                    let image = ImageReader::new(csr).with_guessed_format()?.decode()?;
                     Result::<DynamicImage, anyhow::Error>::Ok(image)
                 }).await?;
                 let image = match image_result {
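Note: with_guessed_format makes the decoder sniff the buffer's magic bytes instead of trusting any declared format, which matters here because servers frequently mislabel images. A minimal sketch of the same call pattern, assuming the image crate's ImageReader as imported by this file:

    use std::io::Cursor;
    use image::{DynamicImage, ImageReader};

    // Decode an in-memory buffer; the format is detected from the bytes themselves.
    // with_guessed_format returns io::Result, but reading from a Cursor cannot
    // fail, so the ? only propagates the (impossible) I/O error.
    fn decode(buffer: Vec<u8>) -> anyhow::Result<DynamicImage> {
        Ok(ImageReader::new(Cursor::new(buffer)).with_guessed_format()?.decode()?)
    }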
@@ -276,12 +343,11 @@
                 ).await.context("querying CLIP server")?;
                 for (vector, entry) in result.into_iter().zip(entries) {
-                    println!("{:?}", entry);
                     final_write_tx.send(ProcessedEntry {
                         url: entry.url,
                         id: entry.id,
                         title: entry.title,
-                        subreddit: entry.subreddit,
+                        subreddit: entry.subreddit.unwrap(),
                         author: entry.author.unwrap(),
                         blob: vector.into_vec(),
                         timestamp: entry.created_utc.to_u64()?
@@ -303,14 +369,20 @@
     let mut paths = vec![];
     for file in fs::read_dir(&config.input)? {
-        paths.push(file?.path());
+        let path = file?.path();
+        let last_segment = path.file_name().context("invalid file structure")?.to_str().context("non-UTF8 path")?.to_string();
+        match &config.filename_threshold {
+            Some(threshold) if threshold >= &last_segment => (),
+            _ => paths.push(path)
+        }
     }

     paths.sort();

+    let mut file_readers = JoinSet::new();
+
     match config.mode {
         OperatingMode::Count | OperatingMode::Sample(_) => {
-            let mut set = JoinSet::new();
             let semaphore = Arc::new(Semaphore::new(cpus));

             for path in paths {
@@ -319,41 +391,37 @@
                 let entries_tx = entries_tx.clone();
                 let path_ = path.clone();
                 log::info!("reading {:?}", path);
-                set.spawn_blocking(move || {
-                    match process_file(path_, entries_tx) {
+                file_readers.spawn_blocking(move || {
+                    match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
                         Ok(_) => (),
                         Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
                     }
                     std::mem::drop(permit);
                 });
             }
-            std::mem::drop(entries_tx);
-            while let Some(x) = set.try_join_next() {
-                x?;
-            }
         },
         OperatingMode::FullRun => {
             for path in paths {
                 let entries_tx = entries_tx.clone();
                 let path_ = path.clone();
                 log::info!("reading {:?}", path);
-                let c = tokio::task::spawn_blocking(move || process_file(path_, entries_tx)).await?;
-                match c {
+                file_readers.spawn_blocking(move || match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
                     Ok(_) => (),
-                    _ => log::error!("could not parse {:?} {:?}", &path, c)
-                }
+                    Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
+                });
             }
-            std::mem::drop(entries_tx);
-            load_task.await??;
         }
     }

-    if let Some(task) = resize_task { task.await??; }
-    if let Some(task) = embedding_generation_task { task.await?? };
-    if let Some(task) = output_writer_task { task.await?? };
+    while let Some(x) = file_readers.try_join_next() {
+        x?;
+    }
+    std::mem::drop(entries_tx);
+    println!("{:?}", load_task.await?);
+    if let Some(task) = resize_task { println!("{:?}", task.await?); }
+    if let Some(task) = embedding_generation_task { println!("{:?}", task.await?) };
+    if let Some(task) = output_writer_task { println!("{:?}", task.await?) };

     Ok(())
 }