Mirror of https://github.com/osmarks/meme-search-engine.git (synced 2024-11-10 22:09:54 +00:00)

commit f8d68d9d54
parent 978aadda6a

    WIP Reddit dump loader
Cargo.lock (generated, 134 lines changed)

@@ -89,6 +89,12 @@ dependencies = [
  "syn 2.0.65",
 ]
 
+[[package]]
+name = "arrayref"
+version = "0.3.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6b4930d2cb77ce62f89ee5d5289b4ac049559b1c45539271f5ed4fdc7db34545"
+
 [[package]]
 name = "arrayvec"
 version = "0.7.4"
@@ -662,6 +668,17 @@ version = "2.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
 
+[[package]]
+name = "faststr"
+version = "0.2.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f375fcf41ec4dac873a8028fba4210dbda5c86bba13d2d741e651b474f7c05a4"
+dependencies = [
+ "bytes",
+ "serde",
+ "simdutf8",
+]
+
 [[package]]
 name = "fdeflate"
 version = "0.3.4"
@@ -1238,7 +1255,7 @@ version = "0.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7507624b29483431c0ba2d82aece8ca6cdba9382bff4ddd0f7490560c056098d"
 dependencies = [
- "winapi",
+ "winapi 0.2.8",
  "winapi-build",
 ]
 
@@ -1280,6 +1297,16 @@ version = "0.2.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
 
+[[package]]
+name = "libmimalloc-sys"
+version = "0.1.38"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0e7bb23d733dfcc8af652a78b7bf232f0e967710d044732185e561e47c0336b6"
+dependencies = [
+ "cc",
+ "libc",
+]
+
 [[package]]
 name = "libsqlite3-sys"
 version = "0.27.0"
@@ -1373,6 +1400,7 @@ dependencies = [
  "base64 0.22.1",
  "chrono",
  "faiss",
+ "fastrand",
  "fnv",
  "futures-util",
  "half",
@@ -1380,6 +1408,7 @@ dependencies = [
  "json5",
  "lazy_static",
  "log",
+ "mimalloc",
  "ndarray",
  "num_cpus",
  "pretty_env_logger",
@@ -1390,12 +1419,24 @@ dependencies = [
  "serde",
  "serde_bytes",
  "serde_json",
+ "sonic-rs",
  "sqlx",
  "tokio",
  "tokio-stream",
  "tower",
  "tower-http",
+ "url",
  "walkdir",
+ "zstd",
 ]
 
+[[package]]
+name = "mimalloc"
+version = "0.1.42"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e9186d86b79b52f4a77af65604b51225e8db1d6ee7e3f41aec1e40829c71a176"
+dependencies = [
+ "libmimalloc-sys",
+]
+
 [[package]]
@@ -1665,6 +1706,16 @@ dependencies = [
  "vcpkg",
 ]
 
+[[package]]
+name = "page_size"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da"
+dependencies = [
+ "libc",
+ "winapi 0.3.9",
+]
+
 [[package]]
 name = "parking_lot"
 version = "0.12.2"
@@ -2216,7 +2267,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d931a44fdaa43b8637009e7632a02adc4f2b2e0733c08caa4cf00e8da4a117a7"
 dependencies = [
  "kernel32-sys",
- "winapi",
+ "winapi 0.2.8",
 ]
 
 [[package]]
@@ -2384,6 +2435,12 @@ dependencies = [
  "quote",
 ]
 
+[[package]]
+name = "simdutf8"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a"
+
 [[package]]
 name = "slab"
 version = "0.4.9"
@@ -2409,6 +2466,27 @@ dependencies = [
  "windows-sys 0.52.0",
 ]
 
+[[package]]
+name = "sonic-rs"
+version = "0.3.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "244d3cdf9dd4e2a5c63991a6ed4ecd768959204d5e4a839181ee997e8c149407"
+dependencies = [
+ "arrayref",
+ "bumpalo",
+ "bytes",
+ "cfg-if",
+ "faststr",
+ "itoa",
+ "page_size",
+ "parking_lot",
+ "ryu",
+ "serde",
+ "simdutf8",
+ "smallvec",
+ "thiserror",
+]
+
 [[package]]
 name = "spin"
 version = "0.5.2"
@@ -3090,7 +3168,7 @@ checksum = "bb08f9e670fab86099470b97cd2b252d6527f0b3cc1401acdb595ffc9dd288ff"
 dependencies = [
  "kernel32-sys",
  "same-file",
- "winapi",
+ "winapi 0.2.8",
 ]
 
 [[package]]
@@ -3212,12 +3290,28 @@ version = "0.2.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "167dc9d6949a9b857f3451275e911c3f44255842c1f7a76f33c55103a909087a"
 
+[[package]]
+name = "winapi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
+dependencies = [
+ "winapi-i686-pc-windows-gnu",
+ "winapi-x86_64-pc-windows-gnu",
+]
+
 [[package]]
 name = "winapi-build"
 version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2d315eee3b34aca4797b2da6b13ed88266e6d612562a0c46390af8299fc699bc"
 
+[[package]]
+name = "winapi-i686-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
+
 [[package]]
 name = "winapi-util"
 version = "0.1.8"
@@ -3227,6 +3321,12 @@ dependencies = [
  "windows-sys 0.52.0",
 ]
 
+[[package]]
+name = "winapi-x86_64-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
+
 [[package]]
 name = "windows-core"
 version = "0.52.0"
@@ -3420,6 +3520,34 @@ version = "1.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
 
+[[package]]
+name = "zstd"
+version = "0.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2d789b1514203a1120ad2429eae43a7bd32b90976a7bb8a05f7ec02fa88cc23a"
+dependencies = [
+ "zstd-safe",
+]
+
+[[package]]
+name = "zstd-safe"
+version = "7.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1cd99b45c6bc03a018c8b8a86025678c87e55526064e38f9df301989dce7ec0a"
+dependencies = [
+ "zstd-sys",
+]
+
+[[package]]
+name = "zstd-sys"
+version = "2.0.10+zstd.1.5.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa"
+dependencies = [
+ "cc",
+ "pkg-config",
+]
+
 [[package]]
 name = "zune-core"
 version = "0.4.12"
Cargo.toml (11 lines changed)

@@ -33,4 +33,13 @@ tower-http = { version = "0.5", features = ["cors"] }
 tower = "0.4"
 json5 = "0.4"
 prometheus = "0.13"
-lazy_static = "1"
\ No newline at end of file
+lazy_static = "1"
+zstd = "0.13"
+url = "2"
+fastrand = "2"
+mimalloc = "0.1"
+sonic-rs = "0.3"
+
+[[bin]]
+name = "reddit-dump"
+path = "src/reddit_dump.rs"
src/common.rs (new file, 64 lines)

@@ -0,0 +1,64 @@
+use serde::{Serialize, Deserialize};
+use std::borrow::Borrow;
+use image::{DynamicImage, imageops::FilterType, ImageFormat};
+use anyhow::Result;
+use std::io::Cursor;
+use reqwest::Client;
+
+#[derive(Debug, Deserialize, Clone)]
+pub struct InferenceServerConfig {
+    pub batch: usize,
+    pub image_size: (u32, u32),
+    pub embedding_size: usize,
+}
+
+pub async fn resize_for_embed<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
+    let resized = tokio::task::spawn_blocking(move || {
+        let new = image.borrow().resize(
+            config.image_size.0,
+            config.image_size.1,
+            FilterType::Lanczos3
+        );
+        let mut buf = Vec::new();
+        let mut csr = Cursor::new(&mut buf);
+        new.write_to(&mut csr, ImageFormat::Png)?;
+        Ok::<Vec<u8>, anyhow::Error>(buf)
+    }).await??;
+    Ok(resized)
+}
+
+#[derive(Debug, Serialize)]
+#[serde(untagged)]
+pub enum EmbeddingRequest {
+    Images { images: Vec<serde_bytes::ByteBuf> },
+    Text { text: Vec<String> }
+}
+
+async fn fetch_backend_config(base_url: &str) -> Result<InferenceServerConfig> {
+    let res = Client::new().get(&format!("{}/config", base_url)).send().await?;
+    Ok(rmp_serde::from_slice(&res.bytes().await?)?)
+}
+
+pub async fn get_backend_config(clip_server: &str) -> InferenceServerConfig {
+    loop {
+        match fetch_backend_config(&clip_server).await {
+            Ok(backend) => break backend,
+            Err(e) => {
+                log::error!("Backend failed (fetch): {}", e);
+                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+            }
+        }
+    }
+}
+
+pub async fn query_clip_server<I, O>(client: &Client, base_url: &str, path: &str, data: I) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned,
+{
+    let response = client
+        .post(&format!("{}{}", base_url, path))
+        .header("Content-Type", "application/msgpack")
+        .body(rmp_serde::to_vec_named(&data)?)
+        .send()
+        .await?;
+    let result: O = rmp_serde::from_slice(&response.bytes().await?)?;
+    Ok(result)
+}
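Editor's note: taken together, these helpers define the client side of the inference-server protocol that both binaries now share: fetch `/config` (retrying until the server is up), then POST msgpack-encoded `EmbeddingRequest` batches. A minimal usage sketch follows; it is not part of the commit, the URL is the default from reddit_dump.rs below, and the root path plus `ByteBuf` response type are inferred from the call sites rather than documented.

async fn embed_texts(texts: Vec<String>) -> anyhow::Result<Vec<serde_bytes::ByteBuf>> {
    let base_url = "http://localhost:1708"; // placeholder; default backend in reddit_dump.rs
    let client = reqwest::Client::new();
    // get_backend_config retries every second until the server answers /config
    let _backend = get_backend_config(base_url).await;
    // the batch endpoint is the server root; one embedding buffer comes back
    // per input (main.rs decodes these as fp16 buffers)
    query_clip_server(&client, base_url, "", EmbeddingRequest::Text { text: texts }).await
}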
src/main.rs (71 lines changed)

@@ -30,8 +30,10 @@ use prometheus::{register_int_counter, register_int_counter_vec, register_int_ga
 use ndarray::ArrayBase;
 
 mod ocr;
+mod common;
 
 use crate::ocr::scan_image;
+use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server};
 
 lazy_static! {
     static ref RELOADS_COUNTER: IntCounter = register_int_counter!("mse_reloads", "reloads executed").unwrap();
@@ -126,13 +128,6 @@ struct FileRecord {
     thumbnails: Option<Vec<u8>>,
 }
 
-#[derive(Debug, Deserialize, Clone)]
-struct InferenceServerConfig {
-    batch: usize,
-    image_size: (u32, u32),
-    embedding_size: usize,
-}
-
 #[derive(Debug, Clone)]
 struct WConfig {
     backend: InferenceServerConfig,
@@ -140,23 +135,6 @@ struct WConfig {
     predefined_embeddings: HashMap<String, ArrayBase<ndarray::OwnedRepr<f32>, ndarray::prelude::Dim<[usize; 1]>>>
 }
 
-async fn query_clip_server<I, O>(
-    client: &Client,
-    config: &Config,
-    path: &str,
-    data: I,
-) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned,
-{
-    let response = client
-        .post(&format!("{}{}", config.clip_server, path))
-        .header("Content-Type", "application/msgpack")
-        .body(rmp_serde::to_vec_named(&data)?)
-        .send()
-        .await?;
-    let result: O = rmp_serde::from_slice(&response.bytes().await?)?;
-    Ok(result)
-}
-
 #[derive(Debug)]
 struct LoadedImage {
     image: Arc<DynamicImage>,
@@ -170,13 +148,6 @@ struct EmbeddingInput {
     filename: String,
 }
 
-#[derive(Debug, Serialize)]
-#[serde(untagged)]
-enum EmbeddingRequest {
-    Images { images: Vec<serde_bytes::ByteBuf> },
-    Text { text: Vec<String> }
-}
-
 fn timestamp() -> i64 {
     chrono::Utc::now().timestamp_micros()
 }
@@ -274,21 +245,6 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> {
     formats
 }
 
-async fn resize_for_embed(config: Arc<WConfig>, image: Arc<DynamicImage>) -> Result<Vec<u8>> {
-    let resized = tokio::task::spawn_blocking(move || {
-        let new = image.resize(
-            config.backend.image_size.0,
-            config.backend.image_size.1,
-            FilterType::Lanczos3
-        );
-        let mut buf = Vec::new();
-        let mut csr = Cursor::new(&mut buf);
-        new.write_to(&mut csr, ImageFormat::Png)?;
-        Ok::<Vec<u8>, anyhow::Error>(buf)
-    }).await??;
-    Ok(resized)
-}
-
 async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
     let pool = initialize_database(&config.service).await?;
     let client = Client::new();
@@ -324,7 +280,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
         };
         IMAGES_LOADED_COUNTER.inc();
         if record.embedding.is_none() {
-            let resized = resize_for_embed(config.clone(), image.clone()).await?;
+            let resized = resize_for_embed(config.backend.clone(), image.clone()).await?;
 
             to_embed_tx.send(EmbeddingInput { image: resized, filename: record.filename.clone() }).await?
         }
@@ -505,7 +461,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
            async move {
                let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
                    &client,
-                    &config.service,
+                    &config.service.clip_server,
                    "",
                    EmbeddingRequest::Images {
                        images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
@@ -753,7 +709,7 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
            TERMS_COUNTER.get_metric_with_label_values(&["image"]).unwrap().inc();
            let bytes = BASE64_STANDARD.decode(image)?;
            let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&bytes))?);
-            image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(config.clone(), image).await?));
+            image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(config.backend.clone(), image).await?));
            image_weights.push(term.weight.unwrap_or(1.0));
        }
        if let Some(text) = &term.text {
@@ -792,7 +748,7 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
    }
 
    for batch in batches {
-        let embs: Vec<Vec<u8>> = query_clip_server(&client, &config.service, "/", batch).await?;
+        let embs: Vec<Vec<u8>> = query_clip_server(&client, &config.service.clip_server, "/", batch).await?;
        for emb in embs {
            total_embedding += &ndarray::Array::from_vec(decode_fp16_buffer(&emb));
        }
@@ -813,11 +769,6 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
    }).into_response())
 }
 
-async fn get_backend_config(config: &Config) -> Result<InferenceServerConfig> {
-    let res = Client::new().get(&format!("{}/config", config.clip_server)).send().await?;
-    Ok(rmp_serde::from_slice(&res.bytes().await?)?)
-}
-
 #[derive(Serialize, Deserialize)]
 struct FrontendInit {
     n_total: u64,
@@ -834,15 +785,7 @@ async fn main() -> Result<()> {
     let pool = initialize_database(&config).await?;
     sqlx::query(SCHEMA).execute(&pool).await?;
 
-    let backend = loop {
-        match get_backend_config(&config).await {
-            Ok(backend) => break backend,
-            Err(e) => {
-                log::error!("Backend failed (fetch): {}", e);
-                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
-            }
-        }
-    };
+    let backend = get_backend_config(&config.clip_server).await;
 
     let mut predefined_embeddings = HashMap::new();
 
src/reddit_dump.rs (new file, 359 lines)

@@ -0,0 +1,359 @@
+use anyhow::{anyhow, Context, Result};
+use common::resize_for_embed;
+use std::{collections::HashSet, fs, io::{BufReader, Cursor, BufRead, BufWriter}, path::PathBuf, time::Duration, sync::Arc, str::FromStr};
+use serde::{Serialize, Deserialize};
+use lazy_static::lazy_static;
+use regex::{RegexSet, bytes};
+use tokio::{sync::{mpsc::{self, Receiver}, Semaphore}, task::{JoinHandle, JoinSet}};
+use tokio_stream::wrappers::ReceiverStream;
+use reqwest::Client;
+use futures_util::stream::{StreamExt, TryStreamExt};
+use image::{DynamicImage, io::Reader as ImageReader};
+use mimalloc::MiMalloc;
+
+#[global_allocator]
+static GLOBAL: MiMalloc = MiMalloc;
+
+mod common;
+
+use crate::common::{get_backend_config, query_clip_server, EmbeddingRequest};
+
+fn function_which_returns_some_na() -> Option<String> { Some(String::from("na")) }
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+#[serde(untagged)]
+enum BadTimestampFormat {
+    Int(u64),
+    String(String)
+}
+
+impl BadTimestampFormat {
+    fn to_u64(&self) -> Result<u64> {
+        match self {
+            BadTimestampFormat::Int(x) => Ok(*x),
+            BadTimestampFormat::String(x) => u64::from_str(&x).context("invalid string")
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+struct Entry {
+    url: String,
+    over_18: bool,
+    title: String,
+    author: Option<String>,
+    selftext: String,
+    subreddit: String,
+    created_utc: BadTimestampFormat,
+    #[serde(default="function_which_returns_some_na")]
+    post_hint: Option<String>,
+    id: String
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+struct ProcessedEntry {
+    url: String,
+    id: String,
+    title: String,
+    subreddit: String,
+    author: String,
+    timestamp: u64,
+    blob: Vec<u8>
+}
+
+lazy_static! {
+    static ref URL_IGNORE: RegexSet = RegexSet::new([
+        r"//reddit\.com",
+        r"\.html?",
+        r"\.php",
+        r"\?articleid=",
+        r"\.aspx?",
+        r"\.xml",
+        r"//youtube\.com",
+        // TODO fill in more things, maybe try and collect thumbnails or something
+    ]).unwrap();
+    static ref ACCEPTABLE_FILETYPES: HashSet<&'static [u8]> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/webp", "image/apng", "image/bmp", "image/tiff"]
+        .into_iter().map(str::as_bytes).collect();
+    static ref OBJECT_HACKY_IGNORE: bytes::RegexSet = bytes::RegexSet::new([
+        r#""author":"\[deleted\]""#,
+        r#""promoted":true"#, // these seem to be ads which are in the data for some reason, and lack some important fields
+        r#""domain":"self.promos""#, // .......
+        r"\x00" // for SOME REASON one of the JSON files contains a lot of null bytes before one particular record, so just ignore that record
+    ]).unwrap();
+}
+
+// Streams one zstd-compressed NDJSON dump, sending posts which look like images down the channel.
+fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>) -> Result<()> {
+    let mut stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
+    stream.window_log_max(31)?;
+    let mut stream = BufReader::new(stream);
+    let mut buf = Vec::new();
+    loop {
+        if stream.read_until(0x0A, &mut buf)? == 0 {
+            break
+        }
+        // we would discard these later, but they have a different format so they can't be deserialized straight into Entries
+        if OBJECT_HACKY_IGNORE.is_match(&buf) {
+            buf.clear();
+            continue;
+        }
+        let entry = match sonic_rs::serde::from_slice::<Entry>(buf.as_slice()) {
+            Ok(x) => x,
+            Err(e) => {
+                log::warn!("parse failed, please validate {:?} {:?}", e, String::from_utf8_lossy(&buf));
+                return Ok(())
+            }
+        };
+        if entry.selftext.is_empty() && !entry.over_18 && entry.author.is_some() {
+            if !URL_IGNORE.is_match(&entry.url) {
+                match &entry.post_hint {
+                    Some(x) if x == "na" || x == "image" => {
+                        tx.blocking_send(entry)?;
+                    },
+                    _ => ()
+                }
+            }
+        }
+        buf.clear();
+    }
+    Ok(())
+}
+
+struct Config {
+    max_content_length: usize,
+    input: String,
+    output: String,
+    backend: String,
+    mode: OperatingMode
+}
+
+async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) -> Result<Vec<u8>> {
+    let mut response = client.get(url).send().await?;
+    if !ACCEPTABLE_FILETYPES.contains(response.headers().get(reqwest::header::CONTENT_TYPE).context("no content type")?.as_bytes()) {
+        return Err(anyhow!("invalid Content-Type"));
+    }
+    match response.content_length() {
+        Some(x) if x > (config.max_content_length as u64) => return Err(anyhow!("response too large")),
+        _ => ()
+    }
+    let mut buffer = vec![];
+    while let Some(chunk) = response.chunk().await? {
+        buffer.extend(chunk);
+        if buffer.len() > config.max_content_length {
+            return Err(anyhow!("response too large"));
+        }
+    }
+    Ok(buffer)
+}
+
+fn write_output(config: Arc<Config>, mut rx: Receiver<ProcessedEntry>) -> Result<()> {
+    let mut out = fs::File::create(&config.output)?;
+    let stream = zstd::Encoder::new(&mut out, 15)?.auto_finish();
+    let mut buf_stream = BufWriter::new(stream);
+    while let Some(x) = rx.blocking_recv() {
+        rmp_serde::encode::write(&mut buf_stream, &x)?;
+    }
+    Ok(())
+}
+
+enum OperatingMode {
+    Count,
+    Sample(f32),
+    FullRun
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    pretty_env_logger::init();
+    let cpus = num_cpus::get();
+
+    let config = Arc::new(Config {
+        max_content_length: 1<<23,
+        input: String::from("./submissions"),
+        output: String::from("./data.zst"),
+        backend: String::from("http://localhost:1708"),
+        mode: OperatingMode::Count
+    });
+
+    let backend = get_backend_config(&config.backend).await;
+
+    log::info!("connected to inference server");
+
+    let (entries_tx, mut entries_rx) = mpsc::channel::<Entry>(32768);
+    let (buffers_tx, buffers_rx) = mpsc::channel(128);
+    let (resized_tx, resized_rx) = mpsc::channel(backend.batch);
+    let (final_write_tx, final_write_rx) = mpsc::channel::<ProcessedEntry>(32768);
+    let client = Client::builder()
+        .user_agent(concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")))
+        .timeout(Duration::from_secs(30))
+        .build()?;
+
+    let load_task: JoinHandle<Result<()>> = match config.mode {
+        OperatingMode::Count => tokio::task::spawn(async move {
+            let mut counter = 0;
+            while let Some(_) = entries_rx.recv().await {
+                counter += 1
+            }
+            println!("{}", counter);
+            Ok(())
+        }),
+        _ => tokio::task::spawn({
+            let client = client.clone();
+            let config = config.clone();
+            let stream = ReceiverStream::new(entries_rx);
+            stream.map(Ok).try_for_each_concurrent(Some(128), move |entry| {
+                let client = client.clone();
+                let config = config.clone();
+                let buffers_tx = buffers_tx.clone();
+                async move {
+                    if let OperatingMode::Sample(rate) = config.mode {
+                        if fastrand::f32() > rate {
+                            return Ok(())
+                        }
+                    }
+                    match fetch_file(client, config.clone(), &entry.url).await {
+                        Ok(buf) => {
+                            buffers_tx.send((entry, buf)).await?;
+                        },
+                        Err(e) => {
+                            log::warn!("{} failed: {}", &entry.url, e)
+                        }
+                    }
+                    Ok(())
+                }
+            })
+        })
+    };
+
+    let resize_task = match config.mode {
+        OperatingMode::Count => None,
+        _ => Some(tokio::task::spawn({
+            let stream = ReceiverStream::new(buffers_rx);
+            let backend = backend.clone();
+            stream.map(Ok).try_for_each_concurrent(Some(cpus), move |(entry, buffer)| {
+                let backend = backend.clone();
+                let resized_tx = resized_tx.clone();
+                async move {
+                    let image_result = tokio::task::spawn_blocking(|| {
+                        let csr = Cursor::new(buffer);
+                        let image = ImageReader::new(csr).decode()?;
+                        Result::<DynamicImage, anyhow::Error>::Ok(image)
+                    }).await?;
+                    let image = match image_result {
+                        Ok(image) => image,
+                        Err(e) => {
+                            log::warn!("loading {} failed: {}", entry.url, e);
+                            return Result::<(), anyhow::Error>::Ok(());
+                        }
+                    };
+                    let resized = resize_for_embed(backend.clone(), image).await?;
+                    resized_tx.send((entry, resized)).await?;
+                    Ok(())
+                }
+            })
+        }))
+    };
+
+    let embedding_generation_task: Option<JoinHandle<Result<()>>> = match config.mode {
+        OperatingMode::Count => None,
+        _ => Some(tokio::spawn({
+            let stream = ReceiverStream::new(resized_rx).chunks(backend.batch);
+            let client = client.clone();
+            let config = config.clone();
+            // keep multiple embedding requests in flight
+            stream.map(Ok).try_for_each_concurrent(Some(3), move |batch| {
+                let (entries, bytes): (Vec<Entry>, Vec<Vec<u8>>) = batch.into_iter().unzip();
+                let client = client.clone();
+                let config = config.clone();
+                let final_write_tx = final_write_tx.clone();
+                async move {
+                    let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
+                        &client,
+                        &config.backend,
+                        "",
+                        EmbeddingRequest::Images {
+                            images: bytes.into_iter().map(serde_bytes::ByteBuf::from).collect(),
+                        },
+                    ).await.context("querying CLIP server")?;
+
+                    for (vector, entry) in result.into_iter().zip(entries) {
+                        println!("{:?}", entry);
+                        final_write_tx.send(ProcessedEntry {
+                            url: entry.url,
+                            id: entry.id,
+                            title: entry.title,
+                            subreddit: entry.subreddit,
+                            author: entry.author.unwrap(),
+                            blob: vector.into_vec(),
+                            timestamp: entry.created_utc.to_u64()?
+                        }).await?;
+                    }
+                    anyhow::Result::Ok(())
+                }
+            })
+        }))
+    };
+
+    let config_ = config.clone();
+    let output_writer_task = match config.mode {
+        OperatingMode::Sample(_) | OperatingMode::FullRun => Some(tokio::task::spawn_blocking(move || write_output(config_, final_write_rx))),
+        _ => None
+    };
+
+    log::info!("working...");
+
+    let mut paths = vec![];
+    for file in fs::read_dir(&config.input)? {
+        paths.push(file?.path());
+    }
+
+    paths.sort();
+
+    match config.mode {
+        OperatingMode::Count | OperatingMode::Sample(_) => {
+            let mut set = JoinSet::new();
+            let semaphore = Arc::new(Semaphore::new(cpus));
+
+            for path in paths {
+                let semaphore = semaphore.clone();
+                let permit = semaphore.acquire_owned().await?;
+                let entries_tx = entries_tx.clone();
+                let path_ = path.clone();
+                log::info!("reading {:?}", path);
+                set.spawn_blocking(move || {
+                    match process_file(path_, entries_tx) {
+                        Ok(_) => (),
+                        Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
+                    }
+                    std::mem::drop(permit);
+                });
+            }
+
+            std::mem::drop(entries_tx);
+
+            while let Some(x) = set.try_join_next() {
+                x?;
+            }
+        },
+        OperatingMode::FullRun => {
+            for path in paths {
+                let entries_tx = entries_tx.clone();
+                let path_ = path.clone();
+                log::info!("reading {:?}", path);
+                let c = tokio::task::spawn_blocking(move || process_file(path_, entries_tx)).await?;
+                match c {
+                    Ok(_) => (),
+                    _ => log::error!("could not parse {:?} {:?}", &path, c)
+                }
+            }
+
+            std::mem::drop(entries_tx);
+
+            load_task.await??;
+        }
+    }
+    if let Some(task) = resize_task { task.await??; }
+    if let Some(task) = embedding_generation_task { task.await?? };
+    if let Some(task) = output_writer_task { task.await?? };
+
+    Ok(())
+}
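Editor's note: the output file written by `write_output` is a single zstd stream of back-to-back msgpack-encoded `ProcessedEntry` records. A minimal reader sketch follows; it is not part of the commit, and treating the first decode error as end-of-stream is an assumption rather than an explicit framing guarantee.

fn read_dump(path: &str) -> anyhow::Result<Vec<ProcessedEntry>> {
    let file = std::fs::File::open(path)?;
    let mut stream = std::io::BufReader::new(zstd::stream::Decoder::new(file)?);
    let mut entries = Vec::new();
    // records are concatenated msgpack values with no length prefix or delimiter,
    // so keep decoding until rmp_serde reports an error (assumed to be EOF)
    while let Ok(entry) = rmp_serde::decode::from_read::<_, ProcessedEntry>(&mut stream) {
        entries.push(entry);
    }
    Ok(entries)
}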