mirror of https://github.com/osmarks/meme-search-engine.git (synced 2024-11-10 22:09:54 +00:00)
more progress on Reddit

commit a8329e43fc (parent f8d68d9d54)
Cargo.lock (generated, 12 changed lines)
@@ -1122,8 +1122,7 @@ dependencies = [
 [[package]]
 name = "image"
 version = "0.25.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11"
+source = "git+https://github.com/fintelia/image/?branch=upgrade-zune-jpeg#54ee15fae5f865f9806bda70c2118e9e572e7deb"
 dependencies = [
  "bytemuck",
  "byteorder",
@@ -1141,7 +1140,6 @@ dependencies = [
  "rayon",
  "rgb",
  "tiff",
- "zune-core",
  "zune-jpeg",
 ]
 
@@ -3550,9 +3548,9 @@ dependencies = [
 
 [[package]]
 name = "zune-core"
-version = "0.4.12"
+version = "0.5.0-rc1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3f423a2c17029964870cfaabb1f13dfab7d092a62a29a89264f4d36990ca414a"
+checksum = "e0d1b427373b52a2497c49b0860a5290daab6a0437902ffd8f607367bd5eb7d0"
 
 [[package]]
 name = "zune-inflate"
@@ -3565,9 +3563,9 @@ dependencies = [
 
 [[package]]
 name = "zune-jpeg"
-version = "0.4.11"
+version = "0.5.0-rc1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ec866b44a2a1fd6133d363f073ca1b179f438f99e7e5bfb1e33f7181facfe448"
+checksum = "e4019a3b4e46db2d81faab716e5034dd212f1e105bc55b3f5ca4381dd78736e0"
 dependencies = [
  "zune-core",
 ]
Cargo.toml

@@ -40,6 +40,9 @@ fastrand = "2"
 mimalloc = "0.1"
 sonic-rs = "0.3"
 
+[patch.crates-io]
+image = { git = "https://github.com/fintelia/image/", branch = "upgrade-zune-jpeg" }
+
 [[bin]]
 name = "reddit-dump"
 path = "src/reddit_dump.rs"
src/reddit_dump.rs

@@ -1,9 +1,9 @@
 use anyhow::{anyhow, Context, Result};
 use common::resize_for_embed;
-use std::{collections::HashSet, fs, io::{BufReader, Cursor, BufRead, BufWriter}, path::PathBuf, time::Duration, sync::Arc, str::FromStr};
+use std::{collections::HashSet, fs, io::{BufReader, Cursor, BufRead, BufWriter}, time::Duration, sync::Arc, str::FromStr, path::PathBuf};
 use serde::{Serialize, Deserialize};
 use lazy_static::lazy_static;
-use regex::{RegexSet, bytes};
+use regex::{RegexSet, bytes, Regex};
 use tokio::{sync::{mpsc::{self, Receiver}, Semaphore}, task::{JoinHandle, JoinSet}};
 use tokio_stream::wrappers::ReceiverStream;
 use reqwest::Client;
@@ -43,7 +43,7 @@ struct Entry {
     title: String,
     author: Option<String>,
     selftext: String,
-    subreddit: String,
+    subreddit: Option<String>,
     created_utc: BadTimestampFormat,
     #[serde(default="function_which_returns_some_na")]
     post_hint: Option<String>,
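Note: BadTimestampFormat itself is outside this diff's context. Going only by the matches on BadTimestampFormat::Int / BadTimestampFormat::String further down and the created_utc.to_u64() call near the end, it is presumably an untagged serde enum over the two timestamp encodings found in the dumps; a minimal sketch, not the repository's actual definition:

    use serde::Deserialize;
    use std::str::FromStr;

    // Sketch: reddit dumps store created_utc as either an integer or a string.
    #[derive(Deserialize, Debug)]
    #[serde(untagged)]
    enum BadTimestampFormat {
        Int(u64),
        String(String),
    }

    impl BadTimestampFormat {
        // Hypothetical helper mirroring the to_u64() call used later in main().
        fn to_u64(&self) -> anyhow::Result<u64> {
            match self {
                BadTimestampFormat::Int(x) => Ok(*x),
                BadTimestampFormat::String(s) => Ok(u64::from_str(s)?),
            }
        }
    }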
@@ -70,6 +70,11 @@ lazy_static! {
         r"\.aspx?",
         r"\.xml",
         r"//youtube\.com",
+        r"/rss/",
+        r"//vimeo\.com",
+        r"//www\.youtube\.com",
+        r"//youtu\.be",
+        r"//www\.reddit\.com",
         // TODO fill in more things, maybe try and collect thumbnails or something
     ]).unwrap();
     static ref ACCEPTABLE_FILETYPES: HashSet<&'static [u8]> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/webp", "image/apng", "image/bmp", "image/tiff"]
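For context, RegexSet (imported at the top of the file) compiles every ignore pattern into one automaton, so each URL is tested against all of them in a single scan. A small self-contained illustration with an abridged pattern list:

    // Abridged demo of the URL_IGNORE mechanism.
    fn main() {
        let ignore = regex::RegexSet::new([
            r"//www\.youtube\.com",
            r"//youtu\.be",
            r"/rss/",
        ]).unwrap();
        assert!(ignore.is_match("https://www.youtube.com/watch?v=abc"));
        assert!(!ignore.is_match("https://i.imgur.com/abc123.jpg"));
    }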
@@ -80,9 +85,13 @@ lazy_static! {
         r#""domain":"self.promos""#, // .......
         r"\x00" // for SOME REASON one of the JSON files contains a lot of null bytes before one particular record, so just ignore that record
     ]).unwrap();
+    static ref URL_REPLACEMENT_RULES: Vec<(Regex, &'static str)> = [
+        (r"//imgur.com/([A-Za-z0-9]+)", r"//i.imgur.com/$1.jpg"),
+        (r"^http://", r"https://")
+    ].into_iter().map(|(r, e)| (Regex::new(r).unwrap(), e)).collect();
 }
 
-fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>) -> Result<()> {
+fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Option<u64>) -> Result<()> {
     let mut stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
     stream.window_log_max(31)?;
     let mut stream = BufReader::new(stream);
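The two rules rewrite imgur page links to direct i.imgur.com image URLs and upgrade plain HTTP to HTTPS; fetch_file below applies them in order. The same transformation as a standalone sketch (rewrite_url is a hypothetical name, not part of the diff):

    // Apply each (pattern, replacement) pair once, in order.
    fn rewrite_url(mut url: String, rules: &[(regex::Regex, &'static str)]) -> String {
        for (regex, replacement) in rules {
            url = regex.replace(&url, *replacement).to_string();
        }
        url
    }

    // "http://imgur.com/abc123"  -> "http://i.imgur.com/abc123.jpg"   (rule 1)
    //                            -> "https://i.imgur.com/abc123.jpg"  (rule 2)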
@@ -103,11 +112,23 @@ fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>) -> Result<()> {
                 return Ok(())
             }
         };
-        if entry.selftext.is_empty() && !entry.over_18 && entry.author.is_some() {
+        if entry.selftext.is_empty() && !entry.over_18 && entry.author.is_some() && entry.subreddit.is_some() {
             if !URL_IGNORE.is_match(&entry.url) {
                 match &entry.post_hint {
                     Some(x) if x == "na" || x == "image" => {
-                        tx.blocking_send(entry)?;
+                        // Technically this is slightly wrong because we reorder images slightly, but as long as it is not restarted all the time this is "fine".
+                        let after_threshold = match timestamp_threshold {
+                            Some(threshold) => {
+                                let timestamp = match &entry.created_utc {
+                                    BadTimestampFormat::Int(x) => *x,
+                                    BadTimestampFormat::String(s) => u64::from_str(s).unwrap()
+                                };
+                                timestamp > threshold
+                            },
+                            None => true
+                        };
+
+                        if after_threshold { tx.blocking_send(entry)?; }
                     },
                     _ => ()
                 }
@@ -123,11 +144,17 @@ struct Config {
     input: String,
     output: String,
     backend: String,
-    mode: OperatingMode
+    mode: OperatingMode,
+    filename_threshold: Option<String>
 }
 
 async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) -> Result<Vec<u8>> {
-    let mut response = client.get(url).send().await?;
+    // inelegant but I can't get it to work using Cows
+    let mut url = url.to_string();
+    for (regex, replacement) in URL_REPLACEMENT_RULES.iter() {
+        url = regex.replace(&url, *replacement).to_string();
+    }
+    let mut response = client.get(&*url).send().await?;
     if !ACCEPTABLE_FILETYPES.contains(response.headers().get(reqwest::header::CONTENT_TYPE).context("no contept type")?.as_bytes()) {
         return Err(anyhow!("invalid Content-Type"));
     }
@@ -146,7 +173,7 @@ async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) ->
 }
 
 fn write_output(config: Arc<Config>, mut rx: Receiver<ProcessedEntry>) -> Result<()> {
-    let mut out = fs::File::create(&config.output)?;
+    let mut out = fs::File::options().append(true).open(&config.output)?;
     let stream = zstd::Encoder::new(&mut out, 15)?.auto_finish();
     let mut buf_stream = BufWriter::new(stream);
     while let Some(x) = rx.blocking_recv() {
@@ -161,6 +188,26 @@ enum OperatingMode {
     FullRun
 }
 
+fn readback_output(path: &str) -> Result<(u64, usize)> {
+    use rmp_serde::decode::Error;
+    let stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
+    let mut stream = BufReader::new(stream);
+    let mut latest_timestamp = 0;
+    let mut count = 0;
+    loop {
+        let res: Result<ProcessedEntry, Error> = rmp_serde::from_read(&mut stream);
+        if res.is_ok() {
+            count += 1;
+        }
+        match res {
+            Ok(x) => latest_timestamp = latest_timestamp.max(x.timestamp),
+            Err(Error::InvalidDataRead(x)) | Err(Error::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
+            Err(e) => return Err(e).context("decode fail")
+        }
+    }
+    Ok((latest_timestamp, count))
+}
+
 #[tokio::main]
 async fn main() -> Result<()> {
     pretty_env_logger::init();
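The append-then-readback scheme works because a .zst file may hold any number of concatenated frames and the zstd crate's Decoder consumes frames back to back by default, while rmp_serde surfaces a clean end-of-stream as InvalidMarkerRead/InvalidDataRead with UnexpectedEof, which readback_output treats as its stop condition. A minimal sketch of the append side (append_frame is a hypothetical helper, assuming the zstd and anyhow crates):

    use std::io::Write;

    // Each run appends one more compressed frame; decoders see the
    // concatenation as a single continuous stream of msgpack records.
    fn append_frame(path: &str, payload: &[u8]) -> anyhow::Result<()> {
        let mut out = std::fs::File::options().append(true).create(true).open(path)?;
        let mut encoder = zstd::Encoder::new(&mut out, 15)?.auto_finish();
        encoder.write_all(payload)?;
        Ok(())
    }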
@@ -169,11 +216,30 @@ async fn main() -> Result<()> {
     let config = Arc::new(Config {
         max_content_length: 1<<23,
         input: String::from("./submissions"),
-        output: String::from("./data.zst"),
+        output: String::from("./sample.zst"),
         backend: String::from("http://localhost:1708"),
-        mode: OperatingMode::Count
+        mode: OperatingMode::Sample(0.004),
+        filename_threshold: None
     });
 
+    let timestamp_threshold = match config.mode {
+        OperatingMode::Count => None,
+        _ => {
+            match readback_output(&config.output) {
+                Ok(x) => Some(x),
+                Err(e) => {
+                    log::warn!("could not read output: {}", e);
+                    None
+                }
+            }
+        }
+    };
+
+    if let Some((threshold, count)) = timestamp_threshold {
+        log::info!("threshold is {}, {} items", threshold, count);
+    }
+
     let backend = get_backend_config(&config.backend).await;
 
     log::info!("connected to inference server");
@@ -200,7 +266,7 @@ async fn main() -> Result<()> {
             let client = client.clone();
             let config = config.clone();
             let stream = ReceiverStream::new(entries_rx);
-            stream.map(Ok).try_for_each_concurrent(Some(128), move |entry| {
+            stream.map(Ok).try_for_each_concurrent(Some(512), move |entry| {
                 let client = client.clone();
                 let config = config.clone();
                 let buffers_tx = buffers_tx.clone();
@@ -212,6 +278,7 @@ async fn main() -> Result<()> {
                 }
                 match fetch_file(client, config.clone(), &entry.url).await {
                     Ok(buf) => {
+                        log::debug!("got {}", &entry.url);
                         buffers_tx.send((entry, buf)).await?;
                     },
                     Err(e) => {
@@ -235,7 +302,7 @@ async fn main() -> Result<()> {
             async move {
                 let image_result = tokio::task::spawn_blocking(|| {
                     let csr = Cursor::new(buffer);
-                    let image = ImageReader::new(csr).decode()?;
+                    let image = ImageReader::new(csr).with_guessed_format()?.decode()?;
                     Result::<DynamicImage, anyhow::Error>::Ok(image)
                 }).await?;
                 let image = match image_result {
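with_guessed_format matters because the bytes arrive in a Cursor with no file name, so decode() alone has no extension to infer a codec from; the call sniffs the magic bytes instead. Roughly, as a standalone sketch against the image crate API:

    use image::{DynamicImage, ImageReader};
    use std::io::Cursor;

    // Pick the decoder from the buffer's magic bytes, not from a file name.
    fn decode_bytes(buffer: Vec<u8>) -> anyhow::Result<DynamicImage> {
        Ok(ImageReader::new(Cursor::new(buffer))
            .with_guessed_format()? // reads the header to identify PNG/JPEG/WebP/...
            .decode()?)
    }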
@@ -276,12 +343,11 @@ async fn main() -> Result<()> {
             ).await.context("querying CLIP server")?;
 
             for (vector, entry) in result.into_iter().zip(entries) {
-                println!("{:?}", entry);
                 final_write_tx.send(ProcessedEntry {
                     url: entry.url,
                     id: entry.id,
                     title: entry.title,
-                    subreddit: entry.subreddit,
+                    subreddit: entry.subreddit.unwrap(),
                     author: entry.author.unwrap(),
                     blob: vector.into_vec(),
                     timestamp: entry.created_utc.to_u64()?
@@ -303,14 +369,20 @@ async fn main() -> Result<()> {
 
     let mut paths = vec![];
     for file in fs::read_dir(&config.input)? {
-        paths.push(file?.path());
+        let path = file?.path();
+        let last_segment = path.file_name().context("invalid file structure")?.to_str().context("non-UTF8 path")?.to_string();
+        match &config.filename_threshold {
+            Some(threshold) if threshold >= &last_segment => (),
+            _ => paths.push(path)
+        }
     }
 
     paths.sort();
 
+    let mut file_readers = JoinSet::new();
+
     match config.mode {
         OperatingMode::Count | OperatingMode::Sample(_) => {
-            let mut set = JoinSet::new();
             let semaphore = Arc::new(Semaphore::new(cpus));
 
             for path in paths {
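The filename_threshold check relies on the dump files sorting lexicographically in time order, so one string comparison on the final path segment skips files that an earlier run already ingested. For example (hypothetical RS_YYYY-MM names in the pushshift style; the real threshold comes from Config):

    fn main() {
        let filename_threshold = Some(String::from("RS_2015-06.zst"));
        for name in ["RS_2015-05.zst", "RS_2015-06.zst", "RS_2015-07.zst"] {
            // Mirrors the match above: at or before the threshold means skip.
            let skip = matches!(&filename_threshold, Some(t) if t.as_str() >= name);
            println!("{}: {}", name, if skip { "skip" } else { "process" });
        }
    }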
@@ -319,41 +391,37 @@ async fn main() -> Result<()> {
                 let entries_tx = entries_tx.clone();
                 let path_ = path.clone();
                 log::info!("reading {:?}", path);
-                set.spawn_blocking(move || {
-                    match process_file(path_, entries_tx) {
+                file_readers.spawn_blocking(move || {
+                    match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
                         Ok(_) => (),
                         Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
                     }
                     std::mem::drop(permit);
                 });
             }
-
-            std::mem::drop(entries_tx);
-
-            while let Some(x) = set.try_join_next() {
-                x?;
-            }
         },
         OperatingMode::FullRun => {
             for path in paths {
                 let entries_tx = entries_tx.clone();
                 let path_ = path.clone();
                 log::info!("reading {:?}", path);
-                let c = tokio::task::spawn_blocking(move || process_file(path_, entries_tx)).await?;
-                match c {
+                file_readers.spawn_blocking(move || match process_file(path_, entries_tx, timestamp_threshold.map(|(x, _)| x)) {
                     Ok(_) => (),
-                    _ => log::error!("could not parse {:?} {:?}", &path, c)
-                }
+                    Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
+                });
             }
-
-            std::mem::drop(entries_tx);
-
-            load_task.await??;
         }
     }
-    if let Some(task) = resize_task { task.await??; }
-    if let Some(task) = embedding_generation_task { task.await?? };
-    if let Some(task) = output_writer_task { task.await?? };
+
+    while let Some(x) = file_readers.try_join_next() {
+        x?;
+    }
+
+    std::mem::drop(entries_tx);
+    println!("{:?}", load_task.await?);
+    if let Some(task) = resize_task { println!("{:?}", task.await?); }
+    if let Some(task) = embedding_generation_task { println!("{:?}", task.await?) };
+    if let Some(task) = output_writer_task { println!("{:?}", task.await?) };
 
     Ok(())
 }