mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-11-10 22:09:54 +00:00

WIP Reddit dump loader

osmarks 2024-05-24 17:47:18 +01:00
parent 978aadda6a
commit f8d68d9d54
5 changed files with 571 additions and 68 deletions

Cargo.lock (generated)

@@ -89,6 +89,12 @@ dependencies = [
"syn 2.0.65",
]
[[package]]
name = "arrayref"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b4930d2cb77ce62f89ee5d5289b4ac049559b1c45539271f5ed4fdc7db34545"
[[package]]
name = "arrayvec"
version = "0.7.4"
@@ -662,6 +668,17 @@ version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
[[package]]
name = "faststr"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f375fcf41ec4dac873a8028fba4210dbda5c86bba13d2d741e651b474f7c05a4"
dependencies = [
"bytes",
"serde",
"simdutf8",
]
[[package]]
name = "fdeflate"
version = "0.3.4"
@@ -1238,7 +1255,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7507624b29483431c0ba2d82aece8ca6cdba9382bff4ddd0f7490560c056098d"
dependencies = [
"winapi",
"winapi 0.2.8",
"winapi-build",
]
@@ -1280,6 +1297,16 @@ version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "libmimalloc-sys"
version = "0.1.38"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7bb23d733dfcc8af652a78b7bf232f0e967710d044732185e561e47c0336b6"
dependencies = [
"cc",
"libc",
]
[[package]]
name = "libsqlite3-sys"
version = "0.27.0"
@@ -1373,6 +1400,7 @@ dependencies = [
"base64 0.22.1",
"chrono",
"faiss",
"fastrand",
"fnv",
"futures-util",
"half",
@@ -1380,6 +1408,7 @@ dependencies = [
"json5",
"lazy_static",
"log",
"mimalloc",
"ndarray",
"num_cpus",
"pretty_env_logger",
@@ -1390,12 +1419,24 @@ dependencies = [
"serde",
"serde_bytes",
"serde_json",
"sonic-rs",
"sqlx",
"tokio",
"tokio-stream",
"tower",
"tower-http",
"url",
"walkdir",
"zstd",
]
[[package]]
name = "mimalloc"
version = "0.1.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9186d86b79b52f4a77af65604b51225e8db1d6ee7e3f41aec1e40829c71a176"
dependencies = [
"libmimalloc-sys",
]
[[package]]
@@ -1665,6 +1706,16 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "page_size"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da"
dependencies = [
"libc",
"winapi 0.3.9",
]
[[package]]
name = "parking_lot"
version = "0.12.2"
@@ -2216,7 +2267,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d931a44fdaa43b8637009e7632a02adc4f2b2e0733c08caa4cf00e8da4a117a7"
dependencies = [
"kernel32-sys",
"winapi",
"winapi 0.2.8",
]
[[package]]
@@ -2384,6 +2435,12 @@ dependencies = [
"quote",
]
[[package]]
name = "simdutf8"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a"
[[package]]
name = "slab"
version = "0.4.9"
@@ -2409,6 +2466,27 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "sonic-rs"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "244d3cdf9dd4e2a5c63991a6ed4ecd768959204d5e4a839181ee997e8c149407"
dependencies = [
"arrayref",
"bumpalo",
"bytes",
"cfg-if",
"faststr",
"itoa",
"page_size",
"parking_lot",
"ryu",
"serde",
"simdutf8",
"smallvec",
"thiserror",
]
[[package]]
name = "spin"
version = "0.5.2"
@@ -3090,7 +3168,7 @@ checksum = "bb08f9e670fab86099470b97cd2b252d6527f0b3cc1401acdb595ffc9dd288ff"
dependencies = [
"kernel32-sys",
"same-file",
"winapi",
"winapi 0.2.8",
]
[[package]]
@@ -3212,12 +3290,28 @@ version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "167dc9d6949a9b857f3451275e911c3f44255842c1f7a76f33c55103a909087a"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-build"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d315eee3b34aca4797b2da6b13ed88266e6d612562a0c46390af8299fc699bc"
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.8"
@@ -3227,6 +3321,12 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-core"
version = "0.52.0"
@@ -3420,6 +3520,34 @@ version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
[[package]]
name = "zstd"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d789b1514203a1120ad2429eae43a7bd32b90976a7bb8a05f7ec02fa88cc23a"
dependencies = [
"zstd-safe",
]
[[package]]
name = "zstd-safe"
version = "7.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cd99b45c6bc03a018c8b8a86025678c87e55526064e38f9df301989dce7ec0a"
dependencies = [
"zstd-sys",
]
[[package]]
name = "zstd-sys"
version = "2.0.10+zstd.1.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa"
dependencies = [
"cc",
"pkg-config",
]
[[package]]
name = "zune-core"
version = "0.4.12"

Cargo.toml

@@ -33,4 +33,13 @@ tower-http = { version = "0.5", features = ["cors"] }
 tower = "0.4"
 json5 = "0.4"
 prometheus = "0.13"
-lazy_static = "1"
+lazy_static = "1"
+zstd = "0.13"
+url = "2"
+fastrand = "2"
+mimalloc = "0.1"
+sonic-rs = "0.3"
+
+[[bin]]
+name = "reddit-dump"
+path = "src/reddit_dump.rs"
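
The new dependencies line up with the loader's pipeline: zstd reads the compressed submission dumps and compresses the output, sonic-rs does the bulk JSON parsing, fastrand drives sampling, and mimalloc is swapped in as the global allocator. The [[bin]] table registers a second binary target alongside the default one, so the loader would presumably be started with something like cargo run --release --bin reddit-dump.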

src/common.rs (new file)

@@ -0,0 +1,64 @@
use serde::{Serialize, Deserialize};
use std::borrow::Borrow;
use image::{DynamicImage, imageops::FilterType, ImageFormat};
use anyhow::Result;
use std::io::Cursor;
use reqwest::Client;
#[derive(Debug, Deserialize, Clone)]
pub struct InferenceServerConfig {
pub batch: usize,
pub image_size: (u32, u32),
pub embedding_size: usize,
}
pub async fn resize_for_embed<T: Borrow<DynamicImage> + Send + 'static>(config: InferenceServerConfig, image: T) -> Result<Vec<u8>> {
let resized = tokio::task::spawn_blocking(move || {
let new = image.borrow().resize(
config.image_size.0,
config.image_size.1,
FilterType::Lanczos3
);
let mut buf = Vec::new();
let mut csr = Cursor::new(&mut buf);
new.write_to(&mut csr, ImageFormat::Png)?;
Ok::<Vec<u8>, anyhow::Error>(buf)
}).await??;
Ok(resized)
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum EmbeddingRequest {
Images { images: Vec<serde_bytes::ByteBuf> },
Text { text: Vec<String> }
}
async fn fetch_backend_config(base_url: &str) -> Result<InferenceServerConfig> {
let res = Client::new().get(&format!("{}/config", base_url)).send().await?;
Ok(rmp_serde::from_slice(&res.bytes().await?)?)
}
pub async fn get_backend_config(clip_server: &str) -> InferenceServerConfig {
loop {
match fetch_backend_config(&clip_server).await {
Ok(backend) => break backend,
Err(e) => {
log::error!("Backend failed (fetch): {}", e);
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
}
}
pub async fn query_clip_server<I, O>(client: &Client, base_url: &str, path: &str, data: I) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned,
{
let response = client
.post(&format!("{}{}", base_url, path))
.header("Content-Type", "application/msgpack")
.body(rmp_serde::to_vec_named(&data)?)
.send()
.await?;
let result: O = rmp_serde::from_slice(&response.bytes().await?)?;
Ok(result)
}
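
For orientation, a minimal sketch (not part of the commit) of how these shared helpers compose, written as if it were another binary next to reddit_dump.rs; the server address and the fp16 wire format are assumptions drawn from the surrounding code, not anything this commit pins down:

// hypothetical caller for the helpers in common.rs
use anyhow::Result;
use reqwest::Client;

mod common;
use crate::common::{get_backend_config, query_clip_server, EmbeddingRequest};

#[tokio::main]
async fn main() -> Result<()> {
    let base = "http://localhost:1708"; // assumed inference server address
    // get_backend_config retries once per second until /config answers
    let backend = get_backend_config(base).await;
    let client = Client::new();
    let embeddings: Vec<serde_bytes::ByteBuf> = query_clip_server(
        &client,
        base,
        "", // the batch endpoint sits at the server root
        EmbeddingRequest::Text { text: vec![String::from("test query")] },
    ).await?;
    // if the server returns fp16 vectors, each buffer should hold
    // backend.embedding_size elements at 2 bytes apiece
    println!("{} bytes for embedding_size {}", embeddings[0].len(), backend.embedding_size);
    Ok(())
}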

src/main.rs

@@ -30,8 +30,10 @@ use prometheus::{register_int_counter, register_int_counter_vec, register_int_ga
 use ndarray::ArrayBase;
 mod ocr;
+mod common;
 use crate::ocr::scan_image;
+use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server};
 lazy_static! {
     static ref RELOADS_COUNTER: IntCounter = register_int_counter!("mse_reloads", "reloads executed").unwrap();
@@ -126,13 +128,6 @@ struct FileRecord {
     thumbnails: Option<Vec<u8>>,
 }
-#[derive(Debug, Deserialize, Clone)]
-struct InferenceServerConfig {
-    batch: usize,
-    image_size: (u32, u32),
-    embedding_size: usize,
-}
 #[derive(Debug, Clone)]
 struct WConfig {
     backend: InferenceServerConfig,
@@ -140,23 +135,6 @@ struct WConfig {
     predefined_embeddings: HashMap<String, ArrayBase<ndarray::OwnedRepr<f32>, ndarray::prelude::Dim<[usize; 1]>>>
 }
-async fn query_clip_server<I, O>(
-    client: &Client,
-    config: &Config,
-    path: &str,
-    data: I,
-) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned,
-{
-    let response = client
-        .post(&format!("{}{}", config.clip_server, path))
-        .header("Content-Type", "application/msgpack")
-        .body(rmp_serde::to_vec_named(&data)?)
-        .send()
-        .await?;
-    let result: O = rmp_serde::from_slice(&response.bytes().await?)?;
-    Ok(result)
-}
 #[derive(Debug)]
 struct LoadedImage {
     image: Arc<DynamicImage>,
@@ -170,13 +148,6 @@ struct EmbeddingInput {
     filename: String,
 }
-#[derive(Debug, Serialize)]
-#[serde(untagged)]
-enum EmbeddingRequest {
-    Images { images: Vec<serde_bytes::ByteBuf> },
-    Text { text: Vec<String> }
-}
 fn timestamp() -> i64 {
     chrono::Utc::now().timestamp_micros()
 }
@@ -274,21 +245,6 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> {
     formats
 }
-async fn resize_for_embed(config: Arc<WConfig>, image: Arc<DynamicImage>) -> Result<Vec<u8>> {
-    let resized = tokio::task::spawn_blocking(move || {
-        let new = image.resize(
-            config.backend.image_size.0,
-            config.backend.image_size.1,
-            FilterType::Lanczos3
-        );
-        let mut buf = Vec::new();
-        let mut csr = Cursor::new(&mut buf);
-        new.write_to(&mut csr, ImageFormat::Png)?;
-        Ok::<Vec<u8>, anyhow::Error>(buf)
-    }).await??;
-    Ok(resized)
-}
 async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
     let pool = initialize_database(&config.service).await?;
     let client = Client::new();
@@ -324,7 +280,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
         };
         IMAGES_LOADED_COUNTER.inc();
         if record.embedding.is_none() {
-            let resized = resize_for_embed(config.clone(), image.clone()).await?;
+            let resized = resize_for_embed(config.backend.clone(), image.clone()).await?;
             to_embed_tx.send(EmbeddingInput { image: resized, filename: record.filename.clone() }).await?
         }
@@ -505,7 +461,7 @@ async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
             async move {
                 let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
                     &client,
-                    &config.service,
+                    &config.service.clip_server,
                     "",
                     EmbeddingRequest::Images {
                         images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
@@ -753,7 +709,7 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
             TERMS_COUNTER.get_metric_with_label_values(&["image"]).unwrap().inc();
             let bytes = BASE64_STANDARD.decode(image)?;
             let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&bytes))?);
-            image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(config.clone(), image).await?));
+            image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(config.backend.clone(), image).await?));
             image_weights.push(term.weight.unwrap_or(1.0));
         }
         if let Some(text) = &term.text {
@@ -792,7 +748,7 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
     }
     for batch in batches {
-        let embs: Vec<Vec<u8>> = query_clip_server(&client, &config.service, "/", batch).await?;
+        let embs: Vec<Vec<u8>> = query_clip_server(&client, &config.service.clip_server, "/", batch).await?;
         for emb in embs {
             total_embedding += &ndarray::Array::from_vec(decode_fp16_buffer(&emb));
         }
@@ -813,11 +769,6 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
     }).into_response())
 }
-async fn get_backend_config(config: &Config) -> Result<InferenceServerConfig> {
-    let res = Client::new().get(&format!("{}/config", config.clip_server)).send().await?;
-    Ok(rmp_serde::from_slice(&res.bytes().await?)?)
-}
 #[derive(Serialize, Deserialize)]
 struct FrontendInit {
     n_total: u64,
@@ -834,15 +785,7 @@ async fn main() -> Result<()> {
     let pool = initialize_database(&config).await?;
     sqlx::query(SCHEMA).execute(&pool).await?;
-    let backend = loop {
-        match get_backend_config(&config).await {
-            Ok(backend) => break backend,
-            Err(e) => {
-                log::error!("Backend failed (fetch): {}", e);
-                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
-            }
-        }
-    };
+    let backend = get_backend_config(&config.clip_server).await;
 
     let mut predefined_embeddings = HashMap::new();

src/reddit_dump.rs (new file)

@@ -0,0 +1,359 @@
use anyhow::{anyhow, Context, Result};
use common::resize_for_embed;
use std::{collections::HashSet, fs, io::{BufReader, Cursor, BufRead, BufWriter}, path::PathBuf, time::Duration, sync::Arc, str::FromStr};
use serde::{Serialize, Deserialize};
use lazy_static::lazy_static;
use regex::{RegexSet, bytes};
use tokio::{sync::{mpsc::{self, Receiver}, Semaphore}, task::{JoinHandle, JoinSet}};
use tokio_stream::wrappers::ReceiverStream;
use reqwest::Client;
use futures_util::stream::{StreamExt, TryStreamExt};
use image::{DynamicImage, io::Reader as ImageReader};
use mimalloc::MiMalloc;
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
mod common;
use crate::common::{get_backend_config, query_clip_server, EmbeddingRequest};
fn function_which_returns_some_na() -> Option<String> { Some(String::from("na")) }
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(untagged)]
enum BadTimestampFormat {
Int(u64),
String(String)
}
impl BadTimestampFormat {
fn to_u64(&self) -> Result<u64> {
match self {
BadTimestampFormat::Int(x) => Ok(*x),
BadTimestampFormat::String(x) => u64::from_str(&x).context("invalid string")
}
}
}
#[derive(Clone, Deserialize, Serialize, Debug)]
struct Entry {
url: String,
over_18: bool,
title: String,
author: Option<String>,
selftext: String,
subreddit: String,
created_utc: BadTimestampFormat,
#[serde(default="function_which_returns_some_na")]
post_hint: Option<String>,
id: String
}
#[derive(Clone, Deserialize, Serialize, Debug)]
struct ProcessedEntry {
url: String,
id: String,
title: String,
subreddit: String,
author: String,
timestamp: u64,
blob: Vec<u8>
}
lazy_static! {
static ref URL_IGNORE: RegexSet = RegexSet::new([
r"//reddit\.com",
r"\.html?",
r"\.php",
r"\?articleid=",
r"\.aspx?",
r"\.xml",
r"//youtube\.com",
// TODO fill in more things, maybe try and collect thumbnails or something
]).unwrap();
static ref ACCEPTABLE_FILETYPES: HashSet<&'static [u8]> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/apng", "image/bmp", "image/tiff"]
.into_iter().map(str::as_bytes).collect();
static ref OBJECT_HACKY_IGNORE: bytes::RegexSet = bytes::RegexSet::new([
r#""author":"\[deleted\]""#,
r#""promoted":true"#, // these seem to be ads which are in the data for some reason, and lack some important fields
r#""domain":"self.promos""#, // .......
r"\x00" // for SOME REASON one of the JSON files contains a lot of null bytes before one particular record, so just ignore that record
]).unwrap();
}
fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>) -> Result<()> {
let mut stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
stream.window_log_max(31)?;
let mut stream = BufReader::new(stream);
let mut buf = Vec::new();
loop {
if stream.read_until(0x0A, &mut buf)? == 0 {
break
}
// we would discard these later, but they have a different format so they can't be deserialized straight into Entries
if OBJECT_HACKY_IGNORE.is_match(&buf) {
buf.clear();
continue;
}
let entry = match sonic_rs::serde::from_slice::<Entry>(buf.as_slice()) {
Ok(x) => x,
Err(e) => {
log::warn!("parse failed, please validate {:?} {:?}", e, String::from_utf8_lossy(&buf));
return Ok(())
}
};
if entry.selftext.is_empty() && !entry.over_18 && entry.author.is_some() {
if !URL_IGNORE.is_match(&entry.url) {
match &entry.post_hint {
Some(x) if x == "na" || x == "image" => {
tx.blocking_send(entry)?;
},
_ => ()
}
}
}
buf.clear();
}
Ok(())
}
struct Config {
max_content_length: usize,
input: String,
output: String,
backend: String,
mode: OperatingMode
}
async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) -> Result<Vec<u8>> {
let mut response = client.get(url).send().await?;
if !ACCEPTABLE_FILETYPES.contains(response.headers().get(reqwest::header::CONTENT_TYPE).context("no content type")?.as_bytes()) {
return Err(anyhow!("invalid Content-Type"));
}
match response.content_length() {
Some(x) if x > (config.max_content_length as u64) => return Err(anyhow!("response too large")),
_ => ()
}
let mut buffer = vec![];
while let Some(chunk) = response.chunk().await? {
buffer.extend(chunk);
if buffer.len() > config.max_content_length {
return Err(anyhow!("response too large"));
}
}
Ok(buffer)
}
fn write_output(config: Arc<Config>, mut rx: Receiver<ProcessedEntry>) -> Result<()> {
let mut out = fs::File::create(&config.output)?;
let stream = zstd::Encoder::new(&mut out, 15)?.auto_finish();
let mut buf_stream = BufWriter::new(stream);
while let Some(x) = rx.blocking_recv() {
rmp_serde::encode::write(&mut buf_stream, &x)?;
}
Ok(())
}
enum OperatingMode {
Count,
Sample(f32),
FullRun
}
#[tokio::main]
async fn main() -> Result<()> {
pretty_env_logger::init();
let cpus = num_cpus::get();
let config = Arc::new(Config {
max_content_length: 1<<23,
input: String::from("./submissions"),
output: String::from("./data.zst"),
backend: String::from("http://localhost:1708"),
mode: OperatingMode::Count
});
let backend = get_backend_config(&config.backend).await;
log::info!("connected to inference server");
let (entries_tx, mut entries_rx) = mpsc::channel::<Entry>(32768);
let (buffers_tx, buffers_rx) = mpsc::channel(128);
let (resized_tx, resized_rx) = mpsc::channel(backend.batch);
let (final_write_tx, final_write_rx) = mpsc::channel::<ProcessedEntry>(32768);
let client = Client::builder()
.user_agent(concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")))
.timeout(Duration::from_secs(30))
.build()?;
let load_task: JoinHandle<Result<()>> = match config.mode {
OperatingMode::Count => tokio::task::spawn(async move {
let mut counter = 0;
while let Some(_) = entries_rx.recv().await {
counter += 1
}
println!("{}", counter);
Ok(())
}),
_ => tokio::task::spawn({
let client = client.clone();
let config = config.clone();
let stream = ReceiverStream::new(entries_rx);
stream.map(Ok).try_for_each_concurrent(Some(128), move |entry| {
let client = client.clone();
let config = config.clone();
let buffers_tx = buffers_tx.clone();
async move {
if let OperatingMode::Sample(rate) = config.mode {
if fastrand::f32() > rate {
return Ok(())
}
}
match fetch_file(client, config.clone(), &entry.url).await {
Ok(buf) => {
buffers_tx.send((entry, buf)).await?;
},
Err(e) => {
log::warn!("{} failed: {}", &entry.url, e)
}
}
Ok(())
}
}
)})
};
let resize_task = match config.mode {
OperatingMode::Count => None,
_ => Some(tokio::task::spawn({
let stream = ReceiverStream::new(buffers_rx);
let backend = backend.clone();
stream.map(Ok).try_for_each_concurrent(Some(cpus), move |(entry, buffer)| {
let backend = backend.clone();
let resized_tx = resized_tx.clone();
async move {
let image_result = tokio::task::spawn_blocking(|| {
let csr = Cursor::new(buffer);
let image = ImageReader::new(csr).decode()?;
Result::<DynamicImage, anyhow::Error>::Ok(image)
}).await?;
let image = match image_result {
Ok(image) => image,
Err(e) => {
log::warn!("loading {} failed: {}", entry.url, e);
return Result::<(), anyhow::Error>::Ok(());
}
};
let resized = resize_for_embed(backend.clone(), image).await?;
resized_tx.send((entry, resized)).await?;
Ok(())
}
})
}))
};
let embedding_generation_task: Option<JoinHandle<Result<()>>> = match config.mode {
OperatingMode::Count => None,
_ => Some(tokio::spawn({
let stream = ReceiverStream::new(resized_rx).chunks(backend.batch);
let client = client.clone();
let config = config.clone();
// keep multiple embedding requests in flight
stream.map(Ok).try_for_each_concurrent(Some(3), move |batch| {
let (entries, bytes): (Vec<Entry>, Vec<Vec<u8>>) = batch.into_iter().unzip();
let client = client.clone();
let config = config.clone();
let final_write_tx = final_write_tx.clone();
async move {
let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
&client,
&config.backend,
"",
EmbeddingRequest::Images {
images: bytes.into_iter().map(serde_bytes::ByteBuf::from).collect(),
},
).await.context("querying CLIP server")?;
for (vector, entry) in result.into_iter().zip(entries) {
println!("{:?}", entry);
final_write_tx.send(ProcessedEntry {
url: entry.url,
id: entry.id,
title: entry.title,
subreddit: entry.subreddit,
author: entry.author.unwrap(),
blob: vector.into_vec(),
timestamp: entry.created_utc.to_u64()?
}).await?;
}
anyhow::Result::Ok(())
}
})
}))
};
let config_ = config.clone();
let output_writer_task = match config.mode {
OperatingMode::Sample(_) | OperatingMode::FullRun => Some(tokio::task::spawn_blocking(move || write_output(config_, final_write_rx))),
_ => None
};
log::info!("working...");
let mut paths = vec![];
for file in fs::read_dir(&config.input)? {
paths.push(file?.path());
}
paths.sort();
match config.mode {
OperatingMode::Count | OperatingMode::Sample(_) => {
let mut set = JoinSet::new();
let semaphore = Arc::new(Semaphore::new(cpus));
for path in paths {
let semaphore = semaphore.clone();
let permit = semaphore.acquire_owned().await?;
let entries_tx = entries_tx.clone();
let path_ = path.clone();
log::info!("reading {:?}", path);
set.spawn_blocking(move || {
match process_file(path_, entries_tx) {
Ok(_) => (),
Err(e) => log::error!("could not parse {:?} {:?}", &path, e)
}
std::mem::drop(permit);
});
}
std::mem::drop(entries_tx);
// try_join_next returns None while tasks are still in flight; join_next waits for them
while let Some(x) = set.join_next().await {
x?;
}
load_task.await??;
},
OperatingMode::FullRun => {
for path in paths {
let entries_tx = entries_tx.clone();
let path_ = path.clone();
log::info!("reading {:?}", path);
let c = tokio::task::spawn_blocking(move || process_file(path_, entries_tx)).await?;
match c {
Ok(_) => (),
_ => log::error!("could not parse {:?} {:?}", &path, c)
}
}
std::mem::drop(entries_tx);
load_task.await??;
}
}
if let Some(task) = resize_task { task.await??; }
if let Some(task) = embedding_generation_task { task.await?? };
if let Some(task) = output_writer_task { task.await?? };
Ok(())
}
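
Since write_output just packs consecutive rmp_serde-encoded ProcessedEntry records into a single zstd stream, reading ./data.zst back is a loop of decodes until the stream runs out. A hypothetical reader sketch (not part of the commit; it copies the ProcessedEntry definition locally since the original is private to this binary, and it treats the first decode error as end-of-stream):

// hypothetical reader for the output of write_output above
use std::{fs, io::BufReader};
use anyhow::Result;
use serde::Deserialize;

// mirrors ProcessedEntry from reddit_dump.rs; field order must match
#[derive(Debug, Deserialize)]
struct ProcessedEntry {
    url: String,
    id: String,
    title: String,
    subreddit: String,
    author: String,
    timestamp: u64,
    blob: Vec<u8>,
}

fn main() -> Result<()> {
    let decoder = zstd::stream::Decoder::new(fs::File::open("./data.zst")?)?;
    let mut stream = BufReader::new(decoder);
    let mut count = 0usize;
    // records sit back to back in the stream, so decode until it ends;
    // a robust reader would distinguish EOF from a genuine decode error
    while let Ok(entry) = rmp_serde::decode::from_read::<_, ProcessedEntry>(&mut stream) {
        count += 1;
        if count <= 5 {
            println!("{} r/{}: {} ({} embedding bytes)", entry.id, entry.subreddit, entry.title, entry.blob.len());
        }
    }
    println!("{} entries total", count);
    Ok(())
}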