mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-01-19 21:52:57 +00:00
improve dump processing and misc performance fixes
This commit is contained in:
parent
c277b49dc1
commit
8097ce8d91
@ -81,3 +81,9 @@ pub async fn query_clip_server<I, O>(client: &Client, base_url: &str, path: &str
|
|||||||
let result: O = rmp_serde::from_slice(&response.bytes().await?)?;
|
let result: O = rmp_serde::from_slice(&response.bytes().await?)?;
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn decode_fp16_buffer(buf: &[u8]) -> Vec<f32> {
|
||||||
|
buf.chunks_exact(2)
|
||||||
|
.map(|chunk| half::f16::from_le_bytes([chunk[0], chunk[1]]).to_f32())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
@ -1,9 +1,15 @@
|
|||||||
use anyhow::{Result, Context};
|
use anyhow::{Result, Context};
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
use std::io::BufReader;
|
use std::io::{BufReader, Write};
|
||||||
use rmp_serde::decode::Error as DecodeError;
|
use rmp_serde::decode::Error as DecodeError;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
|
use base64::Engine;
|
||||||
|
use argh::FromArgs;
|
||||||
|
use chrono::{TimeZone, Utc, DateTime};
|
||||||
|
use std::collections::{VecDeque, HashSet};
|
||||||
|
use std::hash::Hasher;
|
||||||
|
|
||||||
|
mod common;
|
||||||
|
|
||||||
// TODO refactor
|
// TODO refactor
|
||||||
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
|
||||||
@ -22,35 +28,172 @@ struct ProcessedEntry {
|
|||||||
subreddit: String,
|
subreddit: String,
|
||||||
author: String,
|
author: String,
|
||||||
timestamp: u64,
|
timestamp: u64,
|
||||||
#[serde(with = "serde_bytes")]
|
#[serde(with="serde_bytes")]
|
||||||
embedding: Vec<u8>,
|
embedding: Vec<u8>,
|
||||||
metadata: OriginalImageMetadata
|
metadata: OriginalImageMetadata
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(FromArgs)]
|
||||||
|
#[argh(description="Process scraper dump files")]
|
||||||
|
struct CLIArguments {
|
||||||
|
#[argh(option, short='s', description="read subset of records")]
|
||||||
|
sample: Option<f32>,
|
||||||
|
#[argh(switch, short='p', description="print basic information for records")]
|
||||||
|
print_records: bool,
|
||||||
|
#[argh(switch, short='e',description="print embeddings")]
|
||||||
|
print_embeddings: bool,
|
||||||
|
#[argh(switch, short='a', description="print aggregates")]
|
||||||
|
print_aggregates: bool,
|
||||||
|
#[argh(option, short='E', description="x:y - load embedding named x from file y")]
|
||||||
|
embedding: Vec<String>,
|
||||||
|
#[argh(option, short='H', description="path for histograms of dot with embeddings")]
|
||||||
|
histograms: Option<String>,
|
||||||
|
#[argh(switch, short='D', description="enable deduplicator")]
|
||||||
|
deduplicate: bool,
|
||||||
|
#[argh(option, short='T', description="deduplication Hamming distance threshold")]
|
||||||
|
threshold: Option<u64>,
|
||||||
|
#[argh(positional)]
|
||||||
|
paths: Vec<String>
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, Serialize, Debug)]
|
||||||
|
struct Histogram {
|
||||||
|
min: f32,
|
||||||
|
max: f32,
|
||||||
|
buckets: Vec<u64>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Histogram {
|
||||||
|
fn new(min: f32, max: f32, count: usize) -> Self {
|
||||||
|
let buckets = (0..count).map(|_| 0).collect();
|
||||||
|
Self { min, max, buckets }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add(&mut self, x: f32) {
|
||||||
|
let bucket = if x < self.min {
|
||||||
|
0
|
||||||
|
} else if x >= self.max {
|
||||||
|
self.buckets.len() - 1
|
||||||
|
} else {
|
||||||
|
((x - self.min) / (self.max - self.min) * (self.buckets.len() as f32)) as usize
|
||||||
|
};
|
||||||
|
self.buckets[bucket] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn buckets(&self) -> Vec<(f32, u64)> {
|
||||||
|
let step = (self.max - self.min) / (self.buckets.len() as f32);
|
||||||
|
self.buckets.iter().enumerate().map(|(i, x)| (self.min + (i as f32) * step, *x)).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dot product of two vectors, summed in `f32`.
///
/// Elements are paired positionally; if the slices differ in length the extra
/// elements of the longer one are ignored.
fn dot(x: &[f32], y: &[f32]) -> f32 {
    let products = x.iter().zip(y.iter()).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
|
||||||
|
|
||||||
|
/// Packs the signs of `x` into bytes, eight values per byte.
///
/// Bit `k` (least significant first) of output byte `i` is set when
/// `x[i * 8 + k] > 0.0`; zero, negative and NaN values give a cleared bit.
/// Trailing values that do not fill a whole byte are dropped, so the result
/// has `x.len() / 8` bytes.
fn binarize(x: &[f32]) -> Vec<u8> {
    x.chunks_exact(8)
        .map(|group| {
            group
                .iter()
                .enumerate()
                .fold(0u8, |byte, (bit, &value)| byte | (u8::from(value > 0.0) << bit))
        })
        .collect()
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let path = std::env::args().nth(1).context("missing path")?;
|
let args: CLIArguments = argh::from_env();
|
||||||
let stream = zstd::stream::Decoder::new(fs::File::open(path)?)?;
|
let mut rng = fastrand::Rng::new();
|
||||||
let mut stream = BufReader::new(stream);
|
let mut latest_timestamp = DateTime::<Utc>::MIN_UTC;
|
||||||
let mut latest_timestamp = 0;
|
let mut earliest_timestamp = DateTime::<Utc>::MAX_UTC;
|
||||||
let mut earliest_timestamp = u64::MAX;
|
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
loop {
|
let mut deduped_count = 0;
|
||||||
let res: Result<ProcessedEntry, DecodeError> = rmp_serde::from_read(&mut stream);
|
let mut embeddings = Vec::new();
|
||||||
if res.is_ok() {
|
for x in args.embedding {
|
||||||
count += 1;
|
let (name, path) = x.split_once(':').unwrap();
|
||||||
}
|
let blob = std::fs::read(path).context("read embedding")?;
|
||||||
match res {
|
embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512)));
|
||||||
Ok(x) => {
|
}
|
||||||
if x.timestamp > latest_timestamp {
|
|
||||||
//println!("{} {} https://reddit.com/r/{}/comments/{} {} https://mse.osmarks.net/?e={}", x.timestamp, count, x.subreddit, x.id, x.metadata.final_url, URL_SAFE.encode(x.embedding));
|
// TODO ring of vecs probably has bad cache locality
|
||||||
latest_timestamp = x.timestamp;
|
let mut dedupe_ring: VecDeque<Vec<u8>> = VecDeque::with_capacity(2<<10);
|
||||||
}
|
let threshold = args.threshold.unwrap_or(3);
|
||||||
earliest_timestamp = earliest_timestamp.min(x.timestamp);
|
|
||||||
},
|
for path in args.paths {
|
||||||
Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
|
let stream = zstd::stream::Decoder::new(fs::File::open(path).context("read dump file")?)?;
|
||||||
Err(e) => return Err(e).context("decode fail")
|
let mut stream = BufReader::new(stream);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let res: Result<ProcessedEntry, DecodeError> = rmp_serde::from_read(&mut stream);
|
||||||
|
if res.is_ok() {
|
||||||
|
count += 1;
|
||||||
|
}
|
||||||
|
match res {
|
||||||
|
Ok(x) => {
|
||||||
|
if args.sample.is_some() && rng.f32() > args.sample.unwrap() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let timestamp = Utc.timestamp_opt(x.timestamp as i64, 0).unwrap();
|
||||||
|
|
||||||
|
let embedding = common::decode_fp16_buffer(&x.embedding);
|
||||||
|
|
||||||
|
latest_timestamp = latest_timestamp.max(timestamp);
|
||||||
|
earliest_timestamp = earliest_timestamp.min(timestamp);
|
||||||
|
|
||||||
|
if args.deduplicate {
|
||||||
|
let code = binarize(&embedding);
|
||||||
|
if dedupe_ring.len() == dedupe_ring.capacity() {
|
||||||
|
dedupe_ring.pop_front().unwrap();
|
||||||
|
}
|
||||||
|
let has_match = dedupe_ring.iter().any(|x| hamming::distance(x, &code) <= threshold);
|
||||||
|
dedupe_ring.push_back(code);
|
||||||
|
if has_match {
|
||||||
|
deduped_count += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.print_records {
|
||||||
|
println!("{} {} https://reddit.com/r/{}/comments/{} {}", timestamp, x.title, x.subreddit, x.id, x.metadata.final_url);
|
||||||
|
}
|
||||||
|
if args.print_embeddings {
|
||||||
|
println!("https://mse.osmarks.net/?e={}", base64::engine::general_purpose::URL_SAFE.encode(&x.embedding));
|
||||||
|
}
|
||||||
|
for (_name, vec, histogram) in &mut embeddings {
|
||||||
|
let dot = dot(&embedding, vec);
|
||||||
|
histogram.add(dot);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
|
||||||
|
Err(e) => return Err(e).context("decode fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.print_aggregates {
|
||||||
|
println!("earliest={} latest={} count={} deduped={}", earliest_timestamp, latest_timestamp, count, deduped_count);
|
||||||
|
}
|
||||||
|
if let Some(histogram_path) = args.histograms {
|
||||||
|
let mut file = std::fs::File::create(histogram_path)?;
|
||||||
|
for (name, _, histogram) in &embeddings {
|
||||||
|
let width = 800.0;
|
||||||
|
let padding = 40.0;
|
||||||
|
let bars_height = 300 as f64;
|
||||||
|
let buckets = histogram.buckets();
|
||||||
|
let max_count = *buckets.iter().map(|(_max, count)| count).max().unwrap();
|
||||||
|
let bar_width = width / buckets.len() as f64;
|
||||||
|
let plot = maud::html! {
|
||||||
|
h1 { (name) }
|
||||||
|
svg style="border: 1px solid gray;" viewBox=(format!("{} 0 {} {}", -padding * 0.25, width + (padding * 0.75), bars_height + 50.0)) xmlns="http://www.w3.org/2000/svg" width=(format!("{}", width + padding)) height=(format!("{}", bars_height + 50.0)) {
|
||||||
|
@for (i, (min, count)) in buckets.into_iter().enumerate() {
|
||||||
|
@let height = bars_height * (count as f64 / max_count as f64);
|
||||||
|
rect width=(format!("{}", bar_width)) x=(format!("{}", bar_width * i as f64)) height=(format!("{}", height)) y=(format!("{}", bars_height - height)) {
|
||||||
|
title {
|
||||||
|
(format!("{} {}", min, count))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
file.write_all(plot.into_string().as_bytes())?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
println!("{} {} {}", earliest_timestamp, latest_timestamp, count);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
18
src/get_embedding.py
Normal file
18
src/get_embedding.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
"""Request an embedding for one image and save the raw result.

Usage: get_embedding.py OUTPUT INPUT

Reads the file INPUT, sends its bytes to the server named by the
"clip_server" key of mse_config.json (msgpack-encoded request and reply),
and writes the first element of the reply to OUTPUT.
"""
import json
import requests
import base64
import msgpack
import sys

with open("mse_config.json") as f:
    config = json.load(f)


def get_embedding(req):
    """POST `req` as msgpack to the configured server and return the decoded reply."""
    response = requests.post(config["clip_server"], data=msgpack.packb(req))
    # Surface HTTP failures directly instead of trying to unpack an error body.
    response.raise_for_status()
    return msgpack.unpackb(response.content)


def main(argv):
    """Run the script with `argv` = [OUTPUT, INPUT, ...]; extra arguments are ignored."""
    if len(argv) < 2:
        sys.exit("usage: get_embedding.py OUTPUT INPUT")
    output_path, input_path, *_ = argv

    with open(input_path, "rb") as g:
        input_data = g.read()

    embedding = get_embedding({"images": [input_data]})[0]

    # Open the output only once the embedding is in hand, so a failed read or
    # request does not leave behind an empty or truncated file.
    with open(output_path, "wb") as out:
        out.write(embedding)


if __name__ == "__main__":
    main(sys.argv[1:])
|
@ -40,7 +40,7 @@ mod common;
|
|||||||
mod video_reader;
|
mod video_reader;
|
||||||
|
|
||||||
use crate::ocr::scan_image;
|
use crate::ocr::scan_image;
|
||||||
use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server};
|
use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server, decode_fp16_buffer};
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref RELOADS_COUNTER: IntCounter = register_int_counter!("mse_reloads", "reloads executed").unwrap();
|
static ref RELOADS_COUNTER: IntCounter = register_int_counter!("mse_reloads", "reloads executed").unwrap();
|
||||||
@ -893,12 +893,6 @@ async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
|
|||||||
Ok(index)
|
Ok(index)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn decode_fp16_buffer(buf: &[u8]) -> Vec<f32> {
|
|
||||||
buf.chunks_exact(2)
|
|
||||||
.map(|chunk| half::f16::from_le_bytes([chunk[0], chunk[1]]).to_f32())
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
type EmbeddingVector = Vec<f32>;
|
type EmbeddingVector = Vec<f32>;
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use common::resize_for_embed;
|
use common::resize_for_embed;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use std::{collections::HashSet, ffi::OsStr, fs::{self, read_dir}, io::{BufRead, BufReader, BufWriter, Cursor}, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
|
use std::{collections::HashSet, ffi::OsStr, fs::{self, read_dir}, io::{BufRead, BufReader, BufWriter, Cursor}, path::PathBuf, str::FromStr, sync::Arc, time::Duration, hash::Hasher};
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use regex::{bytes, Regex, RegexSet, RegexSetBuilder};
|
use regex::{bytes, Regex, RegexSet, RegexSetBuilder};
|
||||||
@ -148,6 +148,7 @@ lazy_static! {
|
|||||||
static ref IMAGE_FILESIZES_HISTOGRAM: HistogramVec = register_histogram_vec!("mse_scrape_image_filesizes", "filesizes of successfully fetched images", &["format"], prometheus::exponential_buckets(100.0, 1.5, 29).unwrap()).unwrap();
|
static ref IMAGE_FILESIZES_HISTOGRAM: HistogramVec = register_histogram_vec!("mse_scrape_image_filesizes", "filesizes of successfully fetched images", &["format"], prometheus::exponential_buckets(100.0, 1.5, 29).unwrap()).unwrap();
|
||||||
static ref IMAGE_PIXELS_HISTOGRAM: HistogramVec = register_histogram_vec!("mse_scrape_image_pixels", "pixel count of successfully fetched images", &["format"], prometheus::exponential_buckets(100.0, 1.3, 53).unwrap()).unwrap();
|
static ref IMAGE_PIXELS_HISTOGRAM: HistogramVec = register_histogram_vec!("mse_scrape_image_pixels", "pixel count of successfully fetched images", &["format"], prometheus::exponential_buckets(100.0, 1.3, 53).unwrap()).unwrap();
|
||||||
static ref HTML_EXTRACTS_COUNTER: IntCounter = register_int_counter!("mse_scrape_html_extracts", "html extraction operations").unwrap();
|
static ref HTML_EXTRACTS_COUNTER: IntCounter = register_int_counter!("mse_scrape_html_extracts", "html extraction operations").unwrap();
|
||||||
|
static ref DISCARDED_COUNTER: IntCounter = register_int_counter!("mse_scrape_discarded", "images discarded by hash").unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip(tx))]
|
#[instrument(skip(tx))]
|
||||||
@ -209,7 +210,8 @@ struct Config {
|
|||||||
mode: OperatingMode,
|
mode: OperatingMode,
|
||||||
filename_threshold: Option<String>,
|
filename_threshold: Option<String>,
|
||||||
metrics_addr: String,
|
metrics_addr: String,
|
||||||
contact_info: String
|
contact_info: String,
|
||||||
|
discard_hashes: HashSet<u64>
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip(client, config))]
|
#[instrument(skip(client, config))]
|
||||||
@ -239,12 +241,19 @@ async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) ->
|
|||||||
_ => ()
|
_ => ()
|
||||||
}
|
}
|
||||||
let mut buffer = vec![];
|
let mut buffer = vec![];
|
||||||
|
let mut hash = seahash::SeaHasher::new();
|
||||||
while let Some(chunk) = response.chunk().await? {
|
while let Some(chunk) = response.chunk().await? {
|
||||||
|
hash.write(&chunk);
|
||||||
buffer.extend(chunk);
|
buffer.extend(chunk);
|
||||||
if buffer.len() > config.max_content_length {
|
if buffer.len() > config.max_content_length {
|
||||||
return Err(anyhow!("response too large"));
|
return Err(anyhow!("response too large"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
let hash = hash.finish();
|
||||||
|
if config.discard_hashes.contains(&hash) {
|
||||||
|
DISCARDED_COUNTER.inc();
|
||||||
|
return Err(anyhow!("discarded"));
|
||||||
|
}
|
||||||
if let Some(extract_rule) = html_extract_rule {
|
if let Some(extract_rule) = html_extract_rule {
|
||||||
if content_type == "text/html" {
|
if content_type == "text/html" {
|
||||||
let buffer = String::from_utf8_lossy(&buffer).to_string();
|
let buffer = String::from_utf8_lossy(&buffer).to_string();
|
||||||
@ -329,7 +338,7 @@ async fn serve_metrics(config: Arc<Config>) -> Result<()> {
|
|||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
console_subscriber::init();
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
let cpus = num_cpus::get();
|
let cpus = num_cpus::get();
|
||||||
|
|
||||||
@ -341,7 +350,8 @@ async fn main() -> Result<()> {
|
|||||||
mode: OperatingMode::FullRun,
|
mode: OperatingMode::FullRun,
|
||||||
filename_threshold: Some(String::from("RS_2017-08.zst")),
|
filename_threshold: Some(String::from("RS_2017-08.zst")),
|
||||||
metrics_addr: String::from("0.0.0.0:9914"),
|
metrics_addr: String::from("0.0.0.0:9914"),
|
||||||
contact_info: String::from("scraping-ops@osmarks.net")
|
contact_info: String::from("scraping-ops@osmarks.net"),
|
||||||
|
discard_hashes: [4168519401919155623, 4577010157274124110].into_iter().collect()
|
||||||
});
|
});
|
||||||
|
|
||||||
serve_metrics(config.clone()).await?;
|
serve_metrics(config.clone()).await?;
|
||||||
|
Loading…
Reference in New Issue
Block a user