Mirror of https://github.com/osmarks/meme-search-engine.git (synced 2025-04-09 04:06:39 +00:00)

release WIP DiskANN index orchestration code

This commit is contained in: parent 35df1201e2, commit f1283137d6
@@ -94,3 +94,76 @@ pub fn decode_fp16_buffer(buf: &[u8]) -> Vec<f32> {
.map(|chunk| half::f16::from_le_bytes([chunk[0], chunk[1]]).to_f32())
.collect()
}

pub fn chunk_fp16_buffer(buf: &[u8]) -> Vec<half::f16> {
buf.chunks_exact(2)
.map(|chunk| half::f16::from_le_bytes([chunk[0], chunk[1]]))
.collect()
}

#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
pub struct OriginalImageMetadata {
pub mime_type: String,
pub original_file_size: usize,
pub dimension: (u32, u32),
pub final_url: String
}

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct ProcessedEntry {
pub url: String,
pub id: String,
pub title: String,
pub subreddit: String,
pub author: String,
pub timestamp: u64,
#[serde(with="serde_bytes")]
pub embedding: Vec<u8>,
pub metadata: OriginalImageMetadata
}

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct ShardInputHeader {
pub id: u32,
pub centroid: Vec<f32>,
pub max_query_id: usize
}

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct ShardedRecord {
pub id: u32,
#[serde(with="serde_bytes")]
pub vector: Vec<u8>, // FP16
pub query_knns: Vec<u32>
}

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct ShardHeader {
pub id: u32,
pub max: u32,
pub centroid: Vec<f32>,
pub medioid: u32,
pub offsets: Vec<u64>,
pub mapping: Vec<u32>
}

#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
pub struct PackedIndexEntry {
pub vector: Vec<u16>, // FP16 values cast to u16 for storage
pub vertices: Vec<u32>,
pub id: u32,
pub timestamp: u64,
pub dimensions: (u32, u32),
pub score: f32,
pub url: String,
pub shards: Vec<u32>
}

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct IndexHeader {
pub shards: Vec<(Vec<f32>, u32)>,
pub count: u32,
pub dead_count: u32,
pub record_pad_size: usize,
pub quantizer: diskann::vector::ProductQuantizer
}
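// Rough data flow between these shared types, as used by the tools below: the dump
// processor writes a ShardInputHeader followed by ShardedRecords into each
// {shard}.shard.msgpack; generate_index_shard builds a graph per shard and emits
// {shard}.shard.bin plus a ShardHeader; the final packing pass writes PackedIndexEntry
// records (padded to record_pad_size) into index.bin with an IndexHeader in index.msgpack.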
@@ -1,42 +1,30 @@
use anyhow::{Result, Context};
use anyhow::{bail, Context, Result};
use serde::{Serialize, Deserialize};
use std::io::{BufReader, Write};
use std::io::{BufReader, Read, Seek, SeekFrom, Write, BufWriter};
use std::path::PathBuf;
use rmp_serde::decode::Error as DecodeError;
use std::fs;
use base64::Engine;
use argh::FromArgs;
use chrono::{TimeZone, Utc, DateTime};
use std::collections::{VecDeque, HashSet};
use std::collections::VecDeque;
use faiss::Index;
use std::sync::mpsc::{sync_channel, SyncSender};
use itertools::Itertools;
use simsimd::SpatialSimilarity;
use std::hash::Hasher;
use foldhash::{HashSet, HashSetExt};

use diskann::vector::{scale_dot_result_f64, ProductQuantizer};

mod common;

// TODO refactor
#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
struct OriginalImageMetadata {
mime_type: String,
original_file_size: usize,
dimension: (u32, u32),
final_url: String
}

#[derive(Clone, Deserialize, Serialize, Debug)]
struct ProcessedEntry {
url: String,
id: String,
title: String,
subreddit: String,
author: String,
timestamp: u64,
#[serde(with="serde_bytes")]
embedding: Vec<u8>,
metadata: OriginalImageMetadata
}
use common::{ProcessedEntry, ShardInputHeader, ShardedRecord, ShardHeader, PackedIndexEntry, IndexHeader};

#[derive(FromArgs)]
#[argh(description="Process scraper dump files")]
struct CLIArguments {
#[argh(option, short='s', description="read subset of records")]
#[argh(option, short='s', description="randomly select fraction of records")]
sample: Option<f32>,
#[argh(switch, short='p', description="print basic information for records")]
print_records: bool,
@@ -44,16 +32,36 @@ struct CLIArguments {
print_embeddings: bool,
#[argh(switch, short='a', description="print aggregates")]
print_aggregates: bool,
#[argh(option, short='E', description="x:y - load embedding named x from file y")]
#[argh(option, short='E', description="x:y[:f] - load embedding named x from file y, discard record if dot product >= filter threshold f")]
embedding: Vec<String>,
#[argh(option, short='H', description="path for histograms of dot with embeddings")]
histograms: Option<String>,
#[argh(switch, short='D', description="enable deduplicator")]
#[argh(switch, short='D', description="enable deduplication")]
deduplicate: bool,
#[argh(option, short='T', description="deduplication Hamming distance threshold")]
threshold: Option<u64>,
#[argh(positional)]
paths: Vec<String>
paths: Vec<String>,
#[argh(option, short='o', description="output embeddings to file")]
output_embeddings: Option<String>,
#[argh(option, short='C', description="split input into shards using these centroids")]
centroids: Option<String>,
#[argh(option, short='S', description="index shard directory")]
shards_dir: Option<String>,
#[argh(option, short='Q', description="query vectors file")]
queries: Option<String>,
#[argh(option, short='d', description="random seed")]
seed: Option<u64>,
#[argh(option, short='i', description="index output directory")]
index_output: Option<String>,
#[argh(switch, short='t', description="print titles")]
titles: bool,
#[argh(option, description="truncate centroids list")]
clip_centroids: Option<usize>,
#[argh(switch, description="print original linked URL")]
original_url: bool,
#[argh(option, short='q', description="product quantization codec path")]
pq_codec: Option<String>,
#[argh(switch, short='j', description="JSON output")]
json: bool
}

#[derive(Clone, Deserialize, Serialize, Debug)]
@@ -70,13 +78,14 @@ impl Histogram {
}

fn add(&mut self, x: f32) {
let bucket = if x < self.min {
let mut bucket = if x < self.min {
0
} else if x >= self.max {
self.buckets.len() - 1
} else {
((x - self.min) / (self.max - self.min) * (self.buckets.len() as f32)) as usize
};
bucket = bucket.max(0).min(self.buckets.len() - 1);
self.buckets[bucket] += 1;
}

@@ -86,93 +95,356 @@ impl Histogram {
}
}

fn dot(x: &[f32], y: &[f32]) -> f32 {
x.iter().zip(y).map(|(a, b)| a * b).sum::<f32>()
}

fn binarize(x: &[f32]) -> Vec<u8> {
let mut buf = vec![0; x.len() / 8];
fn binarize(x: &[f32]) -> u64 {
let mut hasher = seahash::SeaHasher::new();
for i in 0..(x.len() / 8) {
buf[i] = ((x[i * 8] > 0.0) as u8) + (((x[i * 8 + 1] > 0.0) as u8) << 1) + (((x[i * 8 + 2] > 0.0) as u8) << 2) + (((x[i * 8 + 3] > 0.0) as u8) << 3) + (((x[i * 8 + 4] > 0.0) as u8) << 4) + (((x[i * 8 + 5] > 0.0) as u8) << 5) + (((x[i * 8 + 6] > 0.0) as u8) << 6) + (((x[i * 8 + 7] > 0.0) as u8) << 7);
hasher.write_u8(((x[i * 8] > 0.0) as u8) + (((x[i * 8 + 1] > 0.0) as u8) << 1) + (((x[i * 8 + 2] > 0.0) as u8) << 2) + (((x[i * 8 + 3] > 0.0) as u8) << 3) + (((x[i * 8 + 4] > 0.0) as u8) << 4) + (((x[i * 8 + 5] > 0.0) as u8) << 5) + (((x[i * 8 + 6] > 0.0) as u8) << 6) + (((x[i * 8 + 7] > 0.0) as u8) << 7));
}
buf
hasher.finish()
}
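// Note: the new binarize() reduces a vector to a single 64-bit SeaHash of its sign bits,
// so deduplication below becomes an exact hash match rather than the previous
// Hamming-distance comparison over byte signatures.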

fn main() -> Result<()> {
let args: CLIArguments = argh::from_env();
let mut rng = fastrand::Rng::new();
let mut latest_timestamp = DateTime::<Utc>::MIN_UTC;
let mut earliest_timestamp = DateTime::<Utc>::MAX_UTC;
let mut count = 0;
let mut deduped_count = 0;
let mut embeddings = Vec::new();
for x in args.embedding {
let (name, path) = x.split_once(':').unwrap();
let blob = std::fs::read(path).context("read embedding")?;
embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512)));
}

// TODO ring of vecs probably has bad cache locality
let mut dedupe_ring: VecDeque<Vec<u8>> = VecDeque::with_capacity(2<<10);
let threshold = args.threshold.unwrap_or(3);

for path in args.paths {
fn reader_thread(paths: &Vec<String>, tx: SyncSender<ProcessedEntry>) -> Result<()> {
for path in paths {
let stream = zstd::stream::Decoder::new(fs::File::open(path).context("read dump file")?)?;
let mut stream = BufReader::new(stream);

loop {
let res: Result<ProcessedEntry, DecodeError> = rmp_serde::from_read(&mut stream);
if res.is_ok() {
count += 1;
}
match res {
Ok(x) => {
if args.sample.is_some() && rng.f32() > args.sample.unwrap() {
continue;
}
let timestamp = Utc.timestamp_opt(x.timestamp as i64, 0).unwrap();

let embedding = common::decode_fp16_buffer(&x.embedding);

latest_timestamp = latest_timestamp.max(timestamp);
earliest_timestamp = earliest_timestamp.min(timestamp);

if args.deduplicate {
let code = binarize(&embedding);
if dedupe_ring.len() == dedupe_ring.capacity() {
dedupe_ring.pop_front().unwrap();
}
let has_match = dedupe_ring.iter().any(|x| hamming::distance(x, &code) <= threshold);
dedupe_ring.push_back(code);
if has_match {
deduped_count += 1;
continue;
}
}

if args.print_records {
println!("{} {} https://reddit.com/r/{}/comments/{} {}", timestamp, x.title, x.subreddit, x.id, x.metadata.final_url);
}
if args.print_embeddings {
println!("https://mse.osmarks.net/?e={}", base64::engine::general_purpose::URL_SAFE.encode(&x.embedding));
}
for (_name, vec, histogram) in &mut embeddings {
let dot = dot(&embedding, vec);
histogram.add(dot);
}
},
Ok(x) => tx.send(x)?,
Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(e).context("decode fail")
}
}
}
Ok(())
}

const SHARD_SPILL: usize = 2;
const RECORD_PAD_SIZE: usize = 4096; // NVMe disk sector size
const D_EMB: u32 = 1152;
const EMPTY_LOOKUP: (u32, u64, u32) = (u32::MAX, 0, 0);
const KNN_K: usize = 30;
const BALANCE_WEIGHT: f64 = 0.2;
const BATCH_SIZE: usize = 128;
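// SHARD_SPILL: each record is written to its SHARD_SPILL closest shard centroids.
// RECORD_PAD_SIZE: index.bin records are padded to the NVMe sector size so each can be
// fetched with one aligned read. D_EMB: embedding dimension. EMPTY_LOOKUP: sentinel for
// an unfilled (shard, offset, len) slot. KNN_K: query nearest neighbours stored per
// record. BALANCE_WEIGHT: penalty applied to a shard's score in proportion to its
// current size, to keep shard sizes roughly even. BATCH_SIZE: records processed per batch.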

fn main() -> Result<()> {
let args: CLIArguments = argh::from_env();
let mut rng = fastrand::Rng::with_seed(args.seed.unwrap_or(0));
let mut latest_timestamp = DateTime::<Utc>::MIN_UTC;
let mut earliest_timestamp = DateTime::<Utc>::MAX_UTC;
let mut count = 0;
let mut deduped_count = 0;

// load specified embeddings from files
let mut embeddings = Vec::new();
for x in args.embedding {
let (name, snd) = x.split_once(':').unwrap();
let (path, threshold) = if let Some((path, threshold)) = snd.split_once(':') {
(path, Some(threshold.parse::<f32>().context("parse threshold")?))
} else {
(snd, None)
};
let blob = fs::read(path).context("read embedding")?;
embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512), threshold));
}

let pq_codec = if let Some(pq_codec) = args.pq_codec {
let data = fs::read(pq_codec).context("read pq codec")?;
let pq_codec: ProductQuantizer = rmp_serde::from_read(&data[..]).context("decode pq codec")?;
Some(pq_codec)
} else {
None
};

// construct FAISS index over query vectors for kNNs
let (mut queries_index, max_query_id) = if let Some(queries_file) = args.queries {
println!("constructing index");
// not memory-efficient but this is small
let data = fs::read(queries_file).context("read queries file")?;
//let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?;
let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?;
//let mut index = faiss::index_factory(D_EMB, "IVF4096,SQfp16", faiss::MetricType::InnerProduct)?;
let unpacked = common::decode_fp16_buffer(&data);
index.train(&unpacked)?;
index.add(&unpacked)?;
println!("done");
(Some(index), unpacked.len() / D_EMB as usize)
} else {
(None, 0)
};

// if sufficient config to split index exists, set up output files
let mut shards_out = if let (Some(shards_dir), Some(centroids)) = (&args.shards_dir, &args.centroids) {
let mut shards = Vec::new();
let centroids_data = fs::read(centroids).context("read centroids file")?;
let mut centroids_data = common::decode_fp16_buffer(&centroids_data);

if let Some(clip) = args.clip_centroids {
centroids_data.truncate(clip * D_EMB as usize);
}

for i in 0..(centroids_data.len() / (D_EMB as usize)) {
let centroid = centroids_data[i * (D_EMB as usize)..(i + 1) * (D_EMB as usize)].to_vec();
let mut file = fs::File::create(PathBuf::from(shards_dir).join(format!("{}.shard.msgpack", i))).context("create shard file")?;
rmp_serde::encode::write(&mut file, &ShardInputHeader { id: i as u32, centroid: centroid.clone(), max_query_id })?;
shards.push((centroid, file, 0, i));
}

Some(shards)
} else {
None
};

// we can't fit all generated shards into RAM or they wouldn't be sharded anyway; keep file handles and locations lookup table
let (mut read_out_vertices, shard_specs) = if let (Some(shards_dir), Some(_index_output)) = (&args.shards_dir, &args.index_output) {
let mut original_ids_to_shards = Vec::new(); // locations in shard files of graph vertices: [(shard, offset, len)]
let mut shard_id_mappings = Vec::new();
let mut files = Vec::new();
let mut shard_specs = Vec::new();

// open shard files and build lookup from their header files
for file in fs::read_dir(shards_dir)? {
let file = file?;
let path = file.path();
let filename = path.file_name().unwrap().to_str().unwrap();
let (fst, snd) = filename.split_once(".").unwrap();
if snd == "shard-header.msgpack" {
let header: ShardHeader = rmp_serde::from_read(BufReader::new(fs::File::open(path)?))?;
if original_ids_to_shards.len() < (header.max as usize + 1) {
// probably somewhat inefficient, oh well
original_ids_to_shards.resize(header.max as usize + 1, [EMPTY_LOOKUP; SHARD_SPILL]);
}
for (i, &id) in header.mapping.iter().enumerate() {
let len = header.offsets[i + 1] - header.offsets[i]; // always valid, as we have a dummy entry at the end
let mut did_write = false;
// write location to next empty slot
//println!("{} {} {} {:?}", id, header.offsets[i], header.max, original_ids_to_shards[id as usize]);
for rec in original_ids_to_shards[id as usize].iter_mut() {
if *rec == EMPTY_LOOKUP {
*rec = (header.id, header.offsets[i], len as u32);
did_write = true;
break;
}
}
// each record should be in exactly SHARD_SPILL shards
if !did_write {
bail!("shard processing inconsistency");
}
}

shard_specs.push((header.centroid.clone(), header.mapping[header.medioid as usize]));

shard_id_mappings.push((header.id, header.mapping));
} else if snd == "shard.bin" {
let file = fs::File::open(&path).context("open shard file")?;
let id: u32 = str::parse(fst)?;
files.push((id, file));
}
}

files.sort_by_key(|(id, _)| *id);
shard_id_mappings.sort_by_key(|(id, _)| *id);

let read_out_vertices = move |id: u32| -> Result<(Vec<u32>, Vec<u32>)> {
let mut out_vertices: Vec<u32> = vec![];
let mut shards: Vec<u32> = vec![];
// look up each location in shard files
for &(shard, offset, len) in original_ids_to_shards[id as usize].iter() {
shards.push(shard);
let shard = shard as usize;
// this random access is almost certainly rather slow
// parallelize?
files[shard].1.seek(SeekFrom::Start(offset))?;
let mut buf = vec![0; len as usize];
files[shard].1.read_exact(&mut buf)?;
let s: &mut [u32] = bytemuck::cast_slice_mut(&mut *buf);
for within_shard_id in s.iter_mut() {
*within_shard_id = shard_id_mappings[shard].1[*within_shard_id as usize];
}
out_vertices.extend(s.iter().unique());
}

Ok((out_vertices, shards))
};

(Some(read_out_vertices), Some(shard_specs))
} else {
(None, None)
};

let mut index_output_file = if let Some(index_output) = &args.index_output {
let main_output = BufWriter::new(fs::File::create(PathBuf::from(index_output).join("index.bin")).context("create index file")?);
let pq_codes = BufWriter::new(fs::File::create(PathBuf::from(index_output).join("index.pq-codes.bin")).context("create index file")?);
Some((main_output, pq_codes))
} else {
None
};

let mut output_file = args.output_embeddings.map(|x| fs::File::create(x).context("create output file")).transpose()?;

let mut i: u64 = 0;

let mut dedupe_ring: VecDeque<u64> = VecDeque::with_capacity(2<<20);
let mut dedupe_hashset: HashSet<u64> = HashSet::with_capacity(2<<21);
let mut dedupe_url_ring: VecDeque<u64> = VecDeque::with_capacity(2<<20);
let mut dedupe_url_hashset: HashSet<u64> = HashSet::with_capacity(2<<21);
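// Deduplication state: the rings cap how many recent codes are retained, but evicted
// codes are not currently removed from the hashsets, so the membership test below
// effectively deduplicates against every code seen so far.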

let (tx, rx) = sync_channel(1024);

let th = std::thread::spawn(move || reader_thread(&args.paths, tx));

let mut rng2 = rng.fork();
let initial_filter = |x: ProcessedEntry| {
i += 1;

if args.sample.is_some() && rng2.f32() > args.sample.unwrap() {
return None;
}
let timestamp = Utc.timestamp_opt(x.timestamp as i64, 0).unwrap();

let embedding = common::decode_fp16_buffer(&x.embedding);

latest_timestamp = latest_timestamp.max(timestamp);
earliest_timestamp = earliest_timestamp.min(timestamp);

for (_name, vec, histogram, threshold) in &mut embeddings {
let dot = SpatialSimilarity::dot(&embedding, vec).unwrap() as f32;
histogram.add(dot);
if let Some(threshold) = threshold {
if dot >= *threshold {
return None;
}
}
}

// distance thresholding is too costly to do over a long range so just do it badly
if args.deduplicate {
let code = binarize(&embedding);
let mut hasher = seahash::SeaHasher::new();
hasher.write(&x.metadata.final_url.as_bytes());
let url_code = hasher.finish();
if dedupe_ring.len() == dedupe_ring.capacity() {
dedupe_ring.pop_front().unwrap();
dedupe_url_ring.pop_front().unwrap();
}
dedupe_ring.push_back(code);
dedupe_url_ring.push_back(url_code);
if dedupe_hashset.insert(code) == false || dedupe_url_hashset.insert(url_code) == false {
deduped_count += 1;
return None;
}
}

if args.print_records {
println!("{} {} https://reddit.com/r/{}/comments/{} {}", timestamp, x.title, x.subreddit, x.id, x.metadata.final_url);
}
if args.original_url {
println!("{}", x.url);
}
if args.titles {
println!("{}", x.title);
}
if args.print_embeddings {
println!("https://mse.osmarks.net/?e={}", base64::engine::general_purpose::URL_SAFE.encode(&x.embedding));
}

Some((x, embedding))
};

let mut dead_count = 0;

let mut bal_count = 1;

for batch in &rx.iter().filter_map(initial_filter).chunks(BATCH_SIZE) {
let batch: Vec<_> = batch.collect();
let batch_len = batch.len();

for (x, _embedding) in batch.iter() {
if let Some(ref mut file) = output_file {
file.write_all(&x.embedding)?;
}
}

if let Some(shards) = &mut shards_out {
let mut knn_query = vec![];
for (_, embedding) in batch.iter() {
knn_query.extend(embedding);
}

let index = queries_index.as_mut().context("need queries")?;
let knn_result = index.search(&knn_query, KNN_K)?;

for (i, (x, embedding)) in batch.iter().enumerate() {
// closest matches first
shards.sort_by_cached_key(|&(ref centroid, _, shard_count, _shard_index)| {
let mut dot = SpatialSimilarity::dot(&centroid, &embedding).unwrap();
dot -= BALANCE_WEIGHT * (shard_count as f64 / bal_count as f64);
-scale_dot_result_f64(dot)
});

let entry = ShardedRecord {
id: count + i as u32,
vector: x.embedding.clone(),
query_knns: knn_result.labels[i * KNN_K..(i + 1)*KNN_K].into_iter().map(|x| x.get().unwrap() as u32).collect()
};
let data = rmp_serde::to_vec(&entry)?;
for (_, file, shard_count, _shard_index) in shards[0..SHARD_SPILL].iter_mut() {
file.write_all(&data)?;
*shard_count += 1;
}

bal_count += 1;
// it is possible that using the count which is updated at the end of the batch leads to confusing numerics issues
// also, this one starts at 1, so we avoid a division by zero on the first one
}
}

if let (Some(read_out_vertices), Some(index_output_file)) = (&mut read_out_vertices, &mut index_output_file) {
let quantizer = pq_codec.as_ref().unwrap();

let mut batch_embeddings = Vec::with_capacity(batch.len() * D_EMB as usize);
for (_x, embedding) in batch.iter() {
batch_embeddings.extend_from_slice(&embedding);
}
let codes = quantizer.quantize_batch(&batch_embeddings);

for (i, (x, _embedding)) in batch.into_iter().enumerate() {
let (vertices, shards) = read_out_vertices(count)?; // TODO: could parallelize this given the batching
let mut entry = PackedIndexEntry {
id: count + i as u32,
vertices,
vector: x.embedding.chunks_exact(2).map(|x| u16::from_le_bytes([x[0], x[1]])).collect(),
timestamp: x.timestamp,
dimensions: x.metadata.dimension,
score: 0.5, // TODO
url: x.metadata.final_url,
shards
};
let mut bytes = bitcode::encode(&entry);
if bytes.len() > (RECORD_PAD_SIZE - 2) {
// we do need the records to fit in a fixed size and can't really drop things, so discard URL so it can exist as a graph node only
entry.url = String::new();
bytes = bitcode::encode(&entry);
dead_count += 1;
}
let len = bytes.len() as u16;
bytes.resize(RECORD_PAD_SIZE - 2, 0);
index_output_file.0.write_all(&u16::to_le_bytes(len))?;
index_output_file.0.write_all(&bytes)?;
}
index_output_file.1.write_all(&codes)?;
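// On-disk layout produced here: index.bin holds one record per id, each a little-endian
// u16 length followed by the bitcode-encoded PackedIndexEntry, zero-padded so every
// record occupies RECORD_PAD_SIZE bytes; index.pq-codes.bin holds the product-quantized
// codes for the same ids in the same order.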
}

count += batch_len as u32;
}

if args.print_aggregates {
println!("earliest={} latest={} count={} deduped={}", earliest_timestamp, latest_timestamp, count, deduped_count);
println!("earliest={} latest={} count={} read={} deduped={}", earliest_timestamp, latest_timestamp, count, i, deduped_count);
}
if let Some(histogram_path) = args.histograms {
let mut file = std::fs::File::create(histogram_path)?;
for (name, _, histogram) in &embeddings {
let mut file = fs::File::create(histogram_path)?;
for (name, _, histogram, _) in &embeddings {
let width = 800.0;
let padding = 40.0;
let bars_height = 300 as f64;
@@ -195,5 +467,26 @@ fn main() -> Result<()> {
file.write_all(plot.into_string().as_bytes())?;
}
}

if let Some(index_output) = &args.index_output {
let mut file = fs::File::create(PathBuf::from(index_output).join("index.msgpack"))?;
let header = IndexHeader {
shards: shard_specs.unwrap(),
count: count as u32,
record_pad_size: RECORD_PAD_SIZE,
dead_count,
quantizer: pq_codec.unwrap()
};
file.write_all(rmp_serde::to_vec_named(&header)?.as_slice())?;
}

if let Some(shards) = &mut shards_out {
for (_centroid, _file, count, index) in shards.iter_mut() {
println!("shard {}: {} records", index, count);
}
}

th.join().unwrap()?;

Ok(())
}

src/generate_index_shard.rs (new file, 133 lines)
@@ -0,0 +1,133 @@
use anyhow::{Result, Context};
use itertools::Itertools;
use std::io::{BufReader, Write, BufWriter};
use rmp_serde::decode::Error as DecodeError;
use std::fs;
use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer};
use half::f16;

mod common;

use common::{ShardInputHeader, ShardedRecord, ShardHeader};

const D_EMB: usize = 1152;

fn main() -> Result<()> {
let mut rng = fastrand::Rng::new();

let mut stream = BufReader::new(fs::File::open(std::env::args().nth(1).unwrap()).context("read dump file")?);

let mut original_ids = vec![];
let mut vector_data = vec![];
let mut query_knns = vec![];

let header: ShardInputHeader = rmp_serde::from_read(&mut stream)?;
let centroid_fp16 = header.centroid.iter().map(|x| f16::from_f32(*x)).collect::<Vec<_>>();

{
let _timer = Timer::new("read shard");
loop {
let res: Result<ShardedRecord, DecodeError> = rmp_serde::from_read(&mut stream);
match res {
Ok(x) => {
original_ids.push(x.id);
vector_data.extend(bytemuck::cast_slice(&x.vector));
query_knns.push(x.query_knns);
},
Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(e).context("decode fail")
}
}
}

let mut config = IndexBuildConfig {
r: 64,
r_cap: 80,
l: 256,
maxc: 750,
alpha: 65536
};
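// Presumably Vamana/DiskANN-style build parameters for the author's diskann crate:
// r looks like the target out-degree, r_cap the per-node capacity, l the search list
// size and maxc the pruning candidate cap; alpha appears to be a fixed-point pruning
// factor (65536 ~ 1.0, raised to 80000 ~ 1.22 for the second pass below).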

let vecs = VectorList {
data: vector_data,
d_emb: D_EMB,
length: original_ids.len()
};

let mut graph = IndexGraph::empty(original_ids.len(), config.r_cap);

{
//let _timer = Timer::new("project bipartite");
//project_bipartite(&mut rng, &mut graph, &query_knns, &query_knns_bwd, config, &vecs);
}

{
let _timer = Timer::new("random fill");
random_fill_graph(&mut rng, &mut graph, config.r);
}

let medioid = vecs.iter().position_max_by_key(|&v| {
dot(v, &centroid_fp16)
}).unwrap() as u32;

{
let _timer = Timer::new("first pass");
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
}

{
let _timer = Timer::new("second pass");
config.alpha = 80000;
build_graph(&mut rng, &mut graph, medioid, &vecs, config);
}

std::mem::drop(vecs);

let mut query_knns_bwd = vec![vec![]; header.max_query_id];

{
let _timer = Timer::new("compute backward edges");
for (record_id, knns) in query_knns.iter().enumerate() {
for &k in knns {
query_knns_bwd[k as usize].push(record_id as u32);
}
}
}

{
let _timer = Timer::new("augment bipartite");
augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config);
}

let len = original_ids.len();

{
let _timer = Timer::new("write shard");
let mut graph_data = BufWriter::new(fs::File::create(&format!("{}.shard.bin", header.id))?);

let mut offsets = Vec::with_capacity(original_ids.len());
let mut offset = 0;
for out_neighbours in graph.graph.iter() {
let out_neighbours = out_neighbours.read().unwrap();
offsets.push(offset);
let s: &[u8] = bytemuck::cast_slice(&*out_neighbours);
offset += s.len() as u64;
graph_data.write_all(s)?;
}
offsets.push(offset); // dummy entry for convenience

let mut header_f = fs::File::create(&format!("{}.shard-header.msgpack", header.id))?;
header_f.write_all(&rmp_serde::to_vec(&ShardHeader {
id: header.id,
max: *original_ids.iter().max().unwrap(),
centroid: header.centroid,
medioid,
offsets,
mapping: original_ids
})?)?;
}
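// Shard output format: {id}.shard.bin is the concatenated adjacency lists (little-endian
// u32 within-shard ids), and {id}.shard-header.msgpack records the byte offset of each
// list plus the mapping from within-shard ids back to original record ids, which the
// dump processor uses when assembling the full index.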

println!("{} vectors", len);

Ok(())
}

@@ -15,4 +15,9 @@ output, input, *xs = sys.argv[1:]
with open(output, "wb") as f:
    with open(input, "rb") as g:
        input_data = g.read()
    f.write(get_embedding({"images": [input_data]})[0])
    if not xs:
        result = get_embedding({"images": [input_data]})[0]
    else:
        result = get_embedding({"text": xs})[0]
    f.write(result)
    print(base64.urlsafe_b64encode(result).decode("ascii"))
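# With this change the script embeds either an image (no extra arguments) or text
# (the remaining arguments), writes the raw embedding to the output file, and also
# prints it as URL-safe base64, presumably so it can be pasted into tools that accept
# a base64-encoded query vector.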

src/query_disk_index.rs (new file, 173 lines)

@@ -0,0 +1,173 @@
use anyhow::{bail, Context, Result};
use diskann::vector::scale_dot_result_f64;
use serde::{Serialize, Deserialize};
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
use std::path::PathBuf;
use std::fs;
use base64::Engine;
use argh::FromArgs;
use chrono::{TimeZone, Utc, DateTime};
use std::collections::VecDeque;
use itertools::Itertools;
use foldhash::{HashSet, HashSetExt};
use half::f16;
use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, ProductQuantizer, DistanceLUT, scale_dot_result}};
use simsimd::SpatialSimilarity;
use memmap2::{Mmap, MmapOptions};

mod common;

use common::{PackedIndexEntry, IndexHeader};

#[derive(FromArgs)]
#[argh(description="Query disk index")]
struct CLIArguments {
#[argh(positional)]
query_vector: String,
#[argh(positional)]
index_path: String
}

fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result<PackedIndexEntry> {
let offset = id as usize * header.record_pad_size;
data_file.seek(SeekFrom::Start(offset as u64))?;
let mut buf = vec![0; header.record_pad_size as usize];
data_file.read_exact(&mut buf)?;
let len = u16::from_le_bytes(buf[0..2].try_into().unwrap()) as usize;
Ok(bitcode::decode(&buf[2..len+2])?)
}

fn read_pq_codes(id: u32, codes: &Mmap, buf: &mut Vec<u8>, pq_code_size: usize) {
let loc = (id as usize) * pq_code_size;
buf.extend(&codes[loc..loc+pq_code_size])
}

struct Scratch {
visited: HashSet<u32>,
neighbour_buffer: NeighbourBuffer,
neighbour_pre_buffer: Vec<u32>,
visited_list: Vec<(u32, i64, String, Vec<u32>)>
}

struct IndexRef<'a> {
data_file: &'a mut fs::File,
pq_codes: &'a Mmap,
header: &'a IndexHeader,
pq_code_size: usize
}

fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &DistanceLUT, index: IndexRef) -> Result<(usize, usize)> {
scratch.visited.clear();
scratch.neighbour_buffer.clear();
scratch.visited_list.clear();

let mut cmps = 0;
let mut pq_cmps = 0;

let node = read_node(start, index.data_file, index.header)?;
let vector = bytemuck::cast_slice(&node.vector);
scratch.neighbour_buffer.insert(start, fast_dot_noprefetch(query, &vector));
scratch.visited.insert(start);

while let Some(pt) = scratch.neighbour_buffer.next_unvisited() {
//println!("pt {} {:?}", pt, graph.out_neighbours(pt));
scratch.neighbour_pre_buffer.clear();
let node = read_node(pt, index.data_file, index.header)?;
let vector = bytemuck::cast_slice(&node.vector);
let distance = fast_dot_noprefetch(query, &vector);
cmps += 1;
scratch.visited_list.push((pt, distance, node.url, node.shards));
for &neighbour in node.vertices.iter() {
if scratch.visited.insert(neighbour) {
scratch.neighbour_pre_buffer.push(neighbour);
}
}
let mut pq_codes = Vec::with_capacity(index.pq_code_size * scratch.neighbour_pre_buffer.len());
for &neighbour in scratch.neighbour_pre_buffer.iter() {
read_pq_codes(neighbour, index.pq_codes, &mut pq_codes, index.pq_code_size);
}
let approx_scores = index.header.quantizer.asymmetric_dot_product(&query_preprocessed, &pq_codes);
for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() {
//let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO
//let node = read_node(neighbour, index.data_file, index.header)?;
//let vector = bytemuck::cast_slice(&node.vector);
//let distance = fast_dot_noprefetch(query, &vector);
pq_cmps += 1;
scratch.neighbour_buffer.insert(neighbour, approx_scores[i]);
//scratch.neighbour_buffer.insert(neighbour, distance);
}
}

Ok((cmps, pq_cmps))
}
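// greedy_search broadly follows the DiskANN beam-search pattern: each node popped from
// the neighbour buffer is re-scored with an exact fp16 dot product read from index.bin,
// while its out-neighbours are scored cheaply via the product-quantization lookup table
// to decide what to visit next; the returned pair counts (full distance computations,
// PQ distance computations).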

fn main() -> Result<()> {
let args: CLIArguments = argh::from_env();

let query_vector: Vec<f16> = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(args.query_vector.as_bytes()).context("invalid base64")?);
let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::<Vec<f32>>();

let index_path = PathBuf::from(&args.index_path);
let header: IndexHeader = rmp_serde::from_read(BufReader::new(fs::File::open(index_path.join("index.msgpack"))?))?;
let mut data_file = fs::File::open(index_path.join("index.bin"))?;
let pq_codes_file = fs::File::open(index_path.join("index.pq-codes.bin"))?;
let pq_codes = unsafe {
// This is unsafe because other processes could in principle edit the mmap'd file.
// It would be annoying to do anything about this possibility, so ignore it.
MmapOptions::new().populate().map(&pq_codes_file)?
};

let query_preprocessed = header.quantizer.preprocess_query(&query_vector_fp32);

println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len());

// TODO slightly dubious
let selected_shard = header.shards.iter().position_max_by_key(|x| {
scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query_vector_fp32).unwrap())
}).unwrap();

println!("best shard is {}", selected_shard);

for shard in 0..header.shards.len() {
let selected_start = header.shards[shard].1;

let mut scratch = Scratch {
visited: HashSet::new(),
neighbour_buffer: NeighbourBuffer::new(5000),
neighbour_pre_buffer: Vec::new(),
visited_list: Vec::new()
};

//let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng);
let cmps = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef {
data_file: &mut data_file,
header: &header,
pq_codes: &pq_codes,
pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code,
})?;

println!("index scan {}: {:?} cmps", shard, cmps);

scratch.visited_list.sort_by_key(|x| -x.1);
for (id, distance, url, shards) in scratch.visited_list.iter().take(20) {
println!("index scan: {} {} {} {:?}", id, distance, url, shards);
}
println!("");
}

let mut matches = vec![];
// brute force scan
for i in 0..header.count {
let node = read_node(i, &mut data_file, &header)?;
//println!("{} {}", i, node.url);
let vector = bytemuck::cast_slice(&node.vector);
matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards));
}

matches.sort_by_key(|x| -x.1);
for (id, distance, url, shards) in matches.iter().take(20) {
println!("brute force: {} {} {} {:?}", id, distance, url, shards);
}

Ok(())
}
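// Hypothetical invocation (binary name assumed from the source file name):
//   query_disk_index <urlsafe-base64-fp16-query-vector> <index-directory>
// where the index directory is the one containing index.msgpack, index.bin and
// index.pq-codes.bin produced by the dump processor above. The tool then runs the
// graph search per shard and a brute-force scan for comparison.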

@@ -19,7 +19,7 @@ static GLOBAL: MiMalloc = MiMalloc;

mod common;

use crate::common::{get_backend_config, query_clip_server, EmbeddingRequest};
use crate::common::{get_backend_config, query_clip_server, EmbeddingRequest, OriginalImageMetadata, ProcessedEntry};

fn function_which_returns_some_na() -> Option<String> { Some(String::from("na")) }

@@ -27,14 +27,16 @@ fn function_which_returns_some_na() -> Option<String> { Some(String::from("na"))
#[serde(untagged)]
enum BadTimestampFormat {
Int(u64),
String(String)
String(String),
Float(f64) // *what* are they doing?
}

impl BadTimestampFormat {
fn to_u64(&self) -> Result<u64> {
match self {
BadTimestampFormat::Int(x) => Ok(*x),
BadTimestampFormat::String(x) => u64::from_str(&x).context("invalid string")
BadTimestampFormat::String(x) => u64::from_str(&x).context("invalid string"),
BadTimestampFormat::Float(x) => Ok(*x as u64)
}
}
}
@@ -53,31 +55,9 @@ struct Entry {
id: String
}

#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
struct OriginalImageMetadata {
mime_type: String,
original_file_size: usize,
dimension: (u32, u32),
final_url: String
}

#[derive(Clone, Deserialize, Serialize, Debug)]
struct ProcessedEntry {
url: String,
id: String,
title: String,
subreddit: String,
author: String,
timestamp: u64,
#[serde(with = "serde_bytes")]
embedding: Vec<u8>,
metadata: OriginalImageMetadata
}

lazy_static! {
// we do exclude galleries doing this but there don't seem to be any in the dataset
static ref URL_IGNORE: RegexSet = RegexSet::new([
r"//reddit\.com",
r"//reddit\.com/[^g]",
r"\.html?",
r"\.php",
r"\?articleid=",
@@ -85,7 +65,7 @@ lazy_static! {
r"\.xml",
r"/rss/",
r"//vimeo\.com",
r"//www\.reddit\.com",
r"//www\.reddit\.com/[^g]",
r"//v\.redd\.it",
r"\.gifv$",
r"youtube\.com/user/"
@@ -113,6 +93,7 @@ lazy_static! {
"/media",
r"youtu\.be",
r"youtube\.com",
"reddit.com/gallery/"
]).case_insensitive(true).build().unwrap();
static ref ACCEPTABLE_FILETYPES: HashSet<&'static str> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/webp", "image/apng", "image/bmp", "image/tiff"]
.into_iter().collect();
@@ -139,6 +120,7 @@ lazy_static! {
static ref HTML_EXTRACTION_RULES: Vec<(Regex, Regex)> = [
(r"//imgur\.com/a/[A-Za-z0-9]+", r#"<meta name="twitter:image" data-react-helmet="true" content="([^"]+)">"#),
(r"//imgur\.com/gallery/[A-Za-z0-9]+", r#"<meta name="twitter:image" data-react-helmet="true" content="([^"]+)">"#),
(r"reddit.com/gallery/[A-Za-z0-9_-]+", r#"<li style="left:0px" class="_28TEYBuEdOuE3kN6UyoKMa"><figure class="_3BxRNDoASi9FbGX01ewiLg _3o5Vzct5tn9PE7e-emdDmf"><a href="([^"]+)" rel="noopener noreferrer" target="_blank""#) // lazy Reddit gallery extraction; hopefully they don't change the HTML
].into_iter().map(|(r, e)| (Regex::new(r).unwrap(), Regex::new(e).unwrap())).collect();

static ref IMAGES_FETCHED_COUNTER: IntCounter = register_int_counter!("mse_scrape_images_fetched", "images fetched").unwrap();
@@ -181,10 +163,7 @@ fn process_file(path: PathBuf, tx: mpsc::Sender<Entry>, timestamp_threshold: Opt
// Technically this is slightly wrong because we reorder images slightly, but as long as it is not restarted all the time this is "fine".
let after_threshold = match timestamp_threshold {
Some(threshold) => {
let timestamp = match &entry.created_utc {
BadTimestampFormat::Int(x) => *x,
BadTimestampFormat::String(s) => u64::from_str(s).unwrap()
};
let timestamp = entry.created_utc.to_u64().unwrap();
timestamp > threshold
},
None => true
@@ -219,7 +198,7 @@ struct Config {
async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) -> Result<(Vec<u8>, String, String)> {
let mut url = url.to_string();
for (regex, replacement) in URL_REPLACEMENT_RULES.iter() {
url = regex.replace(&url, *replacement).to_string();
url = regex.replace_all(&url, *replacement).to_string();
}

let mut html_extract_rule = None;
@@ -233,7 +212,7 @@ async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) ->

let mut response = client.get(&*url).send().await?;
let content_type = std::str::from_utf8(&response.headers().get(reqwest::header::CONTENT_TYPE).context("no content type")?.as_bytes())?.to_owned();
if !(ACCEPTABLE_FILETYPES.contains(&content_type[..]) || (html_extract_rule.is_some() && content_type == "text/html")) {
if !(ACCEPTABLE_FILETYPES.contains(&content_type[..]) || (html_extract_rule.is_some() && content_type.starts_with("text/html"))) {
return Err(anyhow!("invalid Content-Type"));
}
match response.content_length() {
@@ -255,7 +234,7 @@ async fn fetch_file(client: reqwest::Client, config: Arc<Config>, url: &str) ->
return Err(anyhow!("discarded"));
}
if let Some(extract_rule) = html_extract_rule {
if content_type == "text/html" {
if content_type.starts_with("text/html") {
let buffer = String::from_utf8_lossy(&buffer).to_string();
if let Some(mat) = extract_rule.captures(&buffer) {
let new_url = mat.get(1).unwrap().as_str();
@@ -344,11 +323,11 @@ async fn main() -> Result<()> {

let config = Arc::new(Config {
max_content_length: 1<<24,
input: String::from("./reddit_subs_202212/"),
input: String::from("/srv/scratch/reddit_subs_202312/"),
output: String::from("."),
backend: String::from("http://localhost:1708"),
mode: OperatingMode::FullRun,
filename_threshold: Some(String::from("RS_2019-07.zst")),
filename_threshold: None,
metrics_addr: String::from("0.0.0.0:9914"),
contact_info: String::from("scraping-ops@osmarks.net"),
discard_hashes: [4168519401919155623, 4577010157274124110].into_iter().collect()