diff --git a/src/common.rs b/src/common.rs
index 051cdd0..1060b6f 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -94,3 +94,76 @@ pub fn decode_fp16_buffer(buf: &[u8]) -> Vec<f32> {
         .map(|chunk| half::f16::from_le_bytes([chunk[0], chunk[1]]).to_f32())
         .collect()
 }
+
+pub fn chunk_fp16_buffer(buf: &[u8]) -> Vec<half::f16> {
+    buf.chunks_exact(2)
+        .map(|chunk| half::f16::from_le_bytes([chunk[0], chunk[1]]))
+        .collect()
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
+pub struct OriginalImageMetadata {
+    pub mime_type: String,
+    pub original_file_size: usize,
+    pub dimension: (u32, u32),
+    pub final_url: String
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+pub struct ProcessedEntry {
+    pub url: String,
+    pub id: String,
+    pub title: String,
+    pub subreddit: String,
+    pub author: String,
+    pub timestamp: u64,
+    #[serde(with="serde_bytes")]
+    pub embedding: Vec<u8>,
+    pub metadata: OriginalImageMetadata
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+pub struct ShardInputHeader {
+    pub id: u32,
+    pub centroid: Vec<f32>,
+    pub max_query_id: usize
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+pub struct ShardedRecord {
+    pub id: u32,
+    #[serde(with="serde_bytes")]
+    pub vector: Vec<u8>, // FP16
+    pub query_knns: Vec<u32>
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+pub struct ShardHeader {
+    pub id: u32,
+    pub max: u32,
+    pub centroid: Vec<f32>,
+    pub medioid: u32,
+    pub offsets: Vec<u64>,
+    pub mapping: Vec<u32>
+}
+
+#[derive(Clone, Debug, bitcode::Encode, bitcode::Decode)]
+pub struct PackedIndexEntry {
+    pub vector: Vec<u16>, // FP16 values cast to u16 for storage
+    pub vertices: Vec<u32>,
+    pub id: u32,
+    pub timestamp: u64,
+    pub dimensions: (u32, u32),
+    pub score: f32,
+    pub url: String,
+    pub shards: Vec<u32>
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+pub struct IndexHeader {
+    pub shards: Vec<(Vec<f32>, u32)>,
+    pub count: u32,
+    pub dead_count: u32,
+    pub record_pad_size: usize,
+    pub quantizer: diskann::vector::ProductQuantizer
+}
diff --git a/src/dump_processor.rs b/src/dump_processor.rs
index c4747a9..464f0b6 100644
--- a/src/dump_processor.rs
+++ b/src/dump_processor.rs
@@ -1,42 +1,30 @@
-use anyhow::{Result, Context};
+use anyhow::{bail, Context, Result};
 use serde::{Serialize, Deserialize};
-use std::io::{BufReader, Write};
+use std::io::{BufReader, Read, Seek, SeekFrom, Write, BufWriter};
+use std::path::PathBuf;
 use rmp_serde::decode::Error as DecodeError;
 use std::fs;
 use base64::Engine;
 use argh::FromArgs;
 use chrono::{TimeZone, Utc, DateTime};
-use std::collections::{VecDeque, HashSet};
+use std::collections::VecDeque;
+use faiss::Index;
+use std::sync::mpsc::{sync_channel, SyncSender};
+use itertools::Itertools;
+use simsimd::SpatialSimilarity;
 use std::hash::Hasher;
+use foldhash::{HashSet, HashSetExt};
+
+use diskann::vector::{scale_dot_result_f64, ProductQuantizer};
 
 mod common;
 
-// TODO refactor
-#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)]
-struct OriginalImageMetadata {
-    mime_type: String,
-    original_file_size: usize,
-    dimension: (u32, u32),
-    final_url: String
-}
-
-#[derive(Clone, Deserialize, Serialize, Debug)]
-struct ProcessedEntry {
-    url: String,
-    id: String,
-    title: String,
-    subreddit: String,
-    author: String,
-    timestamp: u64,
-    #[serde(with="serde_bytes")]
-    embedding: Vec<u8>,
-    metadata: OriginalImageMetadata
-}
+use common::{ProcessedEntry, ShardInputHeader, ShardedRecord, ShardHeader, PackedIndexEntry, IndexHeader};
 
 #[derive(FromArgs)]
 #[argh(description="Process scraper dump files")]
 struct CLIArguments {
-    #[argh(option, short='s',
description="read subset of records")] + #[argh(option, short='s', description="randomly select fraction of records")] sample: Option, #[argh(switch, short='p', description="print basic information for records")] print_records: bool, @@ -44,16 +32,36 @@ struct CLIArguments { print_embeddings: bool, #[argh(switch, short='a', description="print aggregates")] print_aggregates: bool, - #[argh(option, short='E', description="x:y - load embedding named x from file y")] + #[argh(option, short='E', description="x:y[:f] - load embedding named x from file y, discard record if dot product >= filter threshold f")] embedding: Vec, #[argh(option, short='H', description="path for histograms of dot with embeddings")] histograms: Option, - #[argh(switch, short='D', description="enable deduplicator")] + #[argh(switch, short='D', description="enable deduplication")] deduplicate: bool, - #[argh(option, short='T', description="deduplication Hamming distance threshold")] - threshold: Option, #[argh(positional)] - paths: Vec + paths: Vec, + #[argh(option, short='o', description="output embeddings to file")] + output_embeddings: Option, + #[argh(option, short='C', description="split input into shards using these centroids")] + centroids: Option, + #[argh(option, short='S', description="index shard directory")] + shards_dir: Option, + #[argh(option, short='Q', description="query vectors file")] + queries: Option, + #[argh(option, short='d', description="random seed")] + seed: Option, + #[argh(option, short='i', description="index output directory")] + index_output: Option, + #[argh(switch, short='t', description="print titles")] + titles: bool, + #[argh(option, description="truncate centroids list")] + clip_centroids: Option, + #[argh(switch, description="print original linked URL")] + original_url: bool, + #[argh(option, short='q', description="product quantization codec path")] + pq_codec: Option, + #[argh(switch, short='j', description="JSON output")] + json: bool } #[derive(Clone, Deserialize, Serialize, Debug)] @@ -70,13 +78,14 @@ impl Histogram { } fn add(&mut self, x: f32) { - let bucket = if x < self.min { + let mut bucket = if x < self.min { 0 } else if x >= self.max { self.buckets.len() - 1 } else { ((x - self.min) / (self.max - self.min) * (self.buckets.len() as f32)) as usize }; + bucket = bucket.max(0).min(self.buckets.len() - 1); self.buckets[bucket] += 1; } @@ -86,93 +95,356 @@ impl Histogram { } } -fn dot(x: &[f32], y: &[f32]) -> f32 { - x.iter().zip(y).map(|(a, b)| a * b).sum::() -} - -fn binarize(x: &[f32]) -> Vec { - let mut buf = vec![0; x.len() / 8]; +fn binarize(x: &[f32]) -> u64 { + let mut hasher = seahash::SeaHasher::new(); for i in 0..(x.len() / 8) { - buf[i] = ((x[i * 8] > 0.0) as u8) + (((x[i * 8 + 1] > 0.0) as u8) << 1) + (((x[i * 8 + 2] > 0.0) as u8) << 2) + (((x[i * 8 + 3] > 0.0) as u8) << 3) + (((x[i * 8 + 4] > 0.0) as u8) << 4) + (((x[i * 8 + 5] > 0.0) as u8) << 5) + (((x[i * 8 + 6] > 0.0) as u8) << 6) + (((x[i * 8 + 7] > 0.0) as u8) << 7); + hasher.write_u8(((x[i * 8] > 0.0) as u8) + (((x[i * 8 + 1] > 0.0) as u8) << 1) + (((x[i * 8 + 2] > 0.0) as u8) << 2) + (((x[i * 8 + 3] > 0.0) as u8) << 3) + (((x[i * 8 + 4] > 0.0) as u8) << 4) + (((x[i * 8 + 5] > 0.0) as u8) << 5) + (((x[i * 8 + 6] > 0.0) as u8) << 6) + (((x[i * 8 + 7] > 0.0) as u8) << 7)); } - buf + hasher.finish() } -fn main() -> Result<()> { - let args: CLIArguments = argh::from_env(); - let mut rng = fastrand::Rng::new(); - let mut latest_timestamp = DateTime::::MIN_UTC; - let mut earliest_timestamp = DateTime::::MAX_UTC; 
- let mut count = 0; - let mut deduped_count = 0; - let mut embeddings = Vec::new(); - for x in args.embedding { - let (name, path) = x.split_once(':').unwrap(); - let blob = std::fs::read(path).context("read embedding")?; - embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512))); - } - - // TODO ring of vecs probably has bad cache locality - let mut dedupe_ring: VecDeque> = VecDeque::with_capacity(2<<10); - let threshold = args.threshold.unwrap_or(3); - - for path in args.paths { +fn reader_thread(paths: &Vec, tx: SyncSender) -> Result<()> { + for path in paths { let stream = zstd::stream::Decoder::new(fs::File::open(path).context("read dump file")?)?; let mut stream = BufReader::new(stream); loop { let res: Result = rmp_serde::from_read(&mut stream); - if res.is_ok() { - count += 1; - } match res { - Ok(x) => { - if args.sample.is_some() && rng.f32() > args.sample.unwrap() { - continue; - } - let timestamp = Utc.timestamp_opt(x.timestamp as i64, 0).unwrap(); - - let embedding = common::decode_fp16_buffer(&x.embedding); - - latest_timestamp = latest_timestamp.max(timestamp); - earliest_timestamp = earliest_timestamp.min(timestamp); - - if args.deduplicate { - let code = binarize(&embedding); - if dedupe_ring.len() == dedupe_ring.capacity() { - dedupe_ring.pop_front().unwrap(); - } - let has_match = dedupe_ring.iter().any(|x| hamming::distance(x, &code) <= threshold); - dedupe_ring.push_back(code); - if has_match { - deduped_count += 1; - continue; - } - } - - if args.print_records { - println!("{} {} https://reddit.com/r/{}/comments/{} {}", timestamp, x.title, x.subreddit, x.id, x.metadata.final_url); - } - if args.print_embeddings { - println!("https://mse.osmarks.net/?e={}", base64::engine::general_purpose::URL_SAFE.encode(&x.embedding)); - } - for (_name, vec, histogram) in &mut embeddings { - let dot = dot(&embedding, vec); - histogram.add(dot); - } - }, + Ok(x) => tx.send(x)?, Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break, Err(e) => return Err(e).context("decode fail") } } } + Ok(()) +} + +const SHARD_SPILL: usize = 2; +const RECORD_PAD_SIZE: usize = 4096; // NVMe disk sector size +const D_EMB: u32 = 1152; +const EMPTY_LOOKUP: (u32, u64, u32) = (u32::MAX, 0, 0); +const KNN_K: usize = 30; +const BALANCE_WEIGHT: f64 = 0.2; +const BATCH_SIZE: usize = 128; + +fn main() -> Result<()> { + let args: CLIArguments = argh::from_env(); + let mut rng = fastrand::Rng::with_seed(args.seed.unwrap_or(0)); + let mut latest_timestamp = DateTime::::MIN_UTC; + let mut earliest_timestamp = DateTime::::MAX_UTC; + let mut count = 0; + let mut deduped_count = 0; + + // load specified embeddings from files + let mut embeddings = Vec::new(); + for x in args.embedding { + let (name, snd) = x.split_once(':').unwrap(); + let (path, threshold) = if let Some((path, threshold)) = snd.split_once(':') { + (path, Some(threshold.parse::().context("parse threshold")?)) + } else { + (snd, None) + }; + let blob = fs::read(path).context("read embedding")?; + embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512), threshold)); + } + + let pq_codec = if let Some(pq_codec) = args.pq_codec { + let data = fs::read(pq_codec).context("read pq codec")?; + let pq_codec: ProductQuantizer = rmp_serde::from_read(&data[..]).context("decode pq codec")?; + Some(pq_codec) + } else { + None + }; + + // construct FAISS index over query vectors for kNNs + let 
(mut queries_index, max_query_id) = if let Some(queries_file) = args.queries { + println!("constructing index"); + // not memory-efficient but this is small + let data = fs::read(queries_file).context("read queries file")?; + //let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?; + let mut index = faiss::index_factory(D_EMB, "HNSW32,SQfp16", faiss::MetricType::InnerProduct)?; + //let mut index = faiss::index_factory(D_EMB, "IVF4096,SQfp16", faiss::MetricType::InnerProduct)?; + let unpacked = common::decode_fp16_buffer(&data); + index.train(&unpacked)?; + index.add(&unpacked)?; + println!("done"); + (Some(index), unpacked.len() / D_EMB as usize) + } else { + (None, 0) + }; + + // if sufficient config to split index exists, set up output files + let mut shards_out = if let (Some(shards_dir), Some(centroids)) = (&args.shards_dir, &args.centroids) { + let mut shards = Vec::new(); + let centroids_data = fs::read(centroids).context("read centroids file")?; + let mut centroids_data = common::decode_fp16_buffer(¢roids_data); + + if let Some(clip) = args.clip_centroids { + centroids_data.truncate(clip * D_EMB as usize); + } + + for i in 0..(centroids_data.len() / (D_EMB as usize)) { + let centroid = centroids_data[i * (D_EMB as usize)..(i + 1) * (D_EMB as usize)].to_vec(); + let mut file = fs::File::create(PathBuf::from(shards_dir).join(format!("{}.shard.msgpack", i))).context("create shard file")?; + rmp_serde::encode::write(&mut file, &ShardInputHeader { id: i as u32, centroid: centroid.clone(), max_query_id })?; + shards.push((centroid, file, 0, i)); + } + + Some(shards) + } else { + None + }; + + // we can't fit all generated shards into RAM or they wouldn't be sharded anyway; keep file handles and locations lookup table + let (mut read_out_vertices, shard_specs) = if let (Some(shards_dir), Some(_index_output)) = (&args.shards_dir, &args.index_output) { + let mut original_ids_to_shards = Vec::new(); // locations in shard files of graph vertices: [(shard, offset, len)] + let mut shard_id_mappings = Vec::new(); + let mut files = Vec::new(); + let mut shard_specs = Vec::new(); + + // open shard files and build lookup from their header files + for file in fs::read_dir(shards_dir)? 
{ + let file = file?; + let path = file.path(); + let filename = path.file_name().unwrap().to_str().unwrap(); + let (fst, snd) = filename.split_once(".").unwrap(); + if snd == "shard-header.msgpack" { + let header: ShardHeader = rmp_serde::from_read(BufReader::new(fs::File::open(path)?))?; + if original_ids_to_shards.len() < (header.max as usize + 1) { + // probably somewhat inefficient, oh well + original_ids_to_shards.resize(header.max as usize + 1, [EMPTY_LOOKUP; SHARD_SPILL]); + } + for (i, &id) in header.mapping.iter().enumerate() { + let len = header.offsets[i + 1] - header.offsets[i]; // always valid, as we have a dummy entry at the end + let mut did_write = false; + // write location to next empty slot + //println!("{} {} {} {:?}", id, header.offsets[i], header.max, original_ids_to_shards[id as usize]); + for rec in original_ids_to_shards[id as usize].iter_mut() { + if *rec == EMPTY_LOOKUP { + *rec = (header.id, header.offsets[i], len as u32); + did_write = true; + break; + } + } + // each record should be in exactly SHARD_SPILL shards + if !did_write { + bail!("shard processing inconsistency"); + } + } + + shard_specs.push((header.centroid.clone(), header.mapping[header.medioid as usize])); + + shard_id_mappings.push((header.id, header.mapping)); + } else if snd == "shard.bin" { + let file = fs::File::open(&path).context("open shard file")?; + let id: u32 = str::parse(fst)?; + files.push((id, file)); + } + } + + files.sort_by_key(|(id, _)| *id); + shard_id_mappings.sort_by_key(|(id, _)| *id); + + let read_out_vertices =move |id: u32| -> Result<(Vec, Vec)> { + let mut out_vertices: Vec = vec![]; + let mut shards: Vec = vec![]; + // look up each location in shard files + for &(shard, offset, len) in original_ids_to_shards[id as usize].iter() { + shards.push(shard); + let shard = shard as usize; + // this random access is almost certainly rather slow + // parallelize? 
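+                // Each (shard, offset, len) lookup entry points at this record's adjacency
+                // list in that shard's .shard.bin file, stored as raw u32 within-shard IDs.
+                // Read it, translate the IDs back to original record IDs through the shard's
+                // mapping table, and append them to the merged neighbour list.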
+ files[shard].1.seek(SeekFrom::Start(offset))?; + let mut buf = vec![0; len as usize]; + files[shard].1.read_exact(&mut buf)?; + let s: &mut [u32] = bytemuck::cast_slice_mut(&mut *buf); + for within_shard_id in s.iter_mut() { + *within_shard_id = shard_id_mappings[shard].1[*within_shard_id as usize]; + } + out_vertices.extend(s.iter().unique()); + } + + Ok((out_vertices, shards)) + }; + + (Some(read_out_vertices), Some(shard_specs)) + } else { + (None, None) + }; + + let mut index_output_file = if let Some(index_output) = &args.index_output { + let main_output = BufWriter::new(fs::File::create(PathBuf::from(index_output).join("index.bin")).context("create index file")?); + let pq_codes =BufWriter::new(fs::File::create(PathBuf::from(index_output).join("index.pq-codes.bin")).context("create index file")?); + Some((main_output, pq_codes)) + } else { + None + }; + + let mut output_file = args.output_embeddings.map(|x| fs::File::create(x).context("create output file")).transpose()?; + + let mut i: u64 = 0; + + let mut dedupe_ring: VecDeque = VecDeque::with_capacity(2<<20); + let mut dedupe_hashset: HashSet = HashSet::with_capacity(2<<21); + let mut dedupe_url_ring: VecDeque = VecDeque::with_capacity(2<<20); + let mut dedupe_url_hashset: HashSet = HashSet::with_capacity(2<<21); + + let (tx, rx) = sync_channel(1024); + + let th = std::thread::spawn(move || reader_thread(&args.paths, tx)); + + let mut rng2 = rng.fork(); + let initial_filter = |x: ProcessedEntry| { + i += 1; + + if args.sample.is_some() && rng2.f32() > args.sample.unwrap() { + return None; + } + let timestamp = Utc.timestamp_opt(x.timestamp as i64, 0).unwrap(); + + let embedding = common::decode_fp16_buffer(&x.embedding); + + latest_timestamp = latest_timestamp.max(timestamp); + earliest_timestamp = earliest_timestamp.min(timestamp); + + for (_name, vec, histogram, threshold) in &mut embeddings { + let dot = SpatialSimilarity::dot(&embedding, vec).unwrap() as f32; + histogram.add(dot); + if let Some(threshold) = threshold { + if dot >= *threshold { + return None; + } + } + } + + // distance thresholding is too costly to do over a long range so just do it badly + if args.deduplicate { + let code = binarize(&embedding); + let mut hasher = seahash::SeaHasher::new(); + hasher.write(&x.metadata.final_url.as_bytes()); + let url_code = hasher.finish(); + if dedupe_ring.len() == dedupe_ring.capacity() { + dedupe_ring.pop_front().unwrap(); + dedupe_url_ring.pop_front().unwrap(); + } + dedupe_ring.push_back(code); + dedupe_url_ring.push_back(url_code); + if dedupe_hashset.insert(code) == false || dedupe_url_hashset.insert(url_code) == false { + deduped_count += 1; + return None; + } + } + + if args.print_records { + println!("{} {} https://reddit.com/r/{}/comments/{} {}", timestamp, x.title, x.subreddit, x.id, x.metadata.final_url); + } + if args.original_url { + println!("{}", x.url); + } + if args.titles { + println!("{}", x.title); + } + if args.print_embeddings { + println!("https://mse.osmarks.net/?e={}", base64::engine::general_purpose::URL_SAFE.encode(&x.embedding)); + } + + Some((x, embedding)) + }; + + let mut dead_count = 0; + + let mut bal_count = 1; + + for batch in &rx.iter().filter_map(initial_filter).chunks(BATCH_SIZE) { + let batch: Vec<_> = batch.collect(); + let batch_len = batch.len(); + + for (x, _embedding) in batch.iter() { + if let Some(ref mut file) = output_file { + file.write_all(&x.embedding)?; + } + } + + if let Some(shards) = &mut shards_out { + let mut knn_query = vec![]; + for (_, embedding) in batch.iter() { 
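+                    // append this record's f32 embedding to a flat buffer; the whole batch is
+                    // submitted below as a single batched FAISS kNN search against the query
+                    // vectors index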
+ knn_query.extend(embedding); + } + + let index = queries_index.as_mut().context("need queries")?; + let knn_result = index.search(&knn_query, KNN_K)?; + + for (i, (x, embedding)) in batch.iter().enumerate() { + // closest matches first + shards.sort_by_cached_key(|&(ref centroid, _, shard_count, _shard_index)| { + let mut dot = SpatialSimilarity::dot(¢roid, &embedding).unwrap(); + dot -= BALANCE_WEIGHT * (shard_count as f64 / bal_count as f64); + -scale_dot_result_f64(dot) + }); + + let entry = ShardedRecord { + id: count + i as u32, + vector: x.embedding.clone(), + query_knns: knn_result.labels[i * KNN_K..(i + 1)*KNN_K].into_iter().map(|x| x.get().unwrap() as u32).collect() + }; + let data = rmp_serde::to_vec(&entry)?; + for (_, file, shard_count, _shard_index) in shards[0..SHARD_SPILL].iter_mut() { + file.write_all(&data)?; + *shard_count += 1; + } + + bal_count += 1; + // it is possible that using the count which is updated at the end of the batch leads to confusing numerics issues + // also, this one starts at 1, so we avoid a division by zero on the first one + } + } + + if let (Some(read_out_vertices), Some(index_output_file)) = (&mut read_out_vertices, &mut index_output_file) { + let quantizer = pq_codec.as_ref().unwrap(); + + let mut batch_embeddings = Vec::with_capacity(batch.len() * D_EMB as usize); + for (_x, embedding) in batch.iter() { + batch_embeddings.extend_from_slice(&embedding); + } + let codes = quantizer.quantize_batch(&batch_embeddings); + + for (i, (x, _embedding)) in batch.into_iter().enumerate() { + let (vertices, shards) = read_out_vertices(count)?; // TODO: could parallelize this given the batching + let mut entry = PackedIndexEntry { + id: count + i as u32, + vertices, + vector: x.embedding.chunks_exact(2).map(|x| u16::from_le_bytes([x[0], x[1]])).collect(), + timestamp: x.timestamp, + dimensions: x.metadata.dimension, + score: 0.5, // TODO + url: x.metadata.final_url, + shards + }; + let mut bytes = bitcode::encode(&entry); + if bytes.len() > (RECORD_PAD_SIZE - 2) { + // we do need the records to fit in a fixed size and can't really drop things, so discard URL so it can exist as a graph node only + entry.url = String::new(); + bytes = bitcode::encode(&entry); + dead_count += 1; + } + let len = bytes.len() as u16; + bytes.resize(RECORD_PAD_SIZE - 2, 0); + index_output_file.0.write_all(&u16::to_le_bytes(len))?; + index_output_file.0.write_all(&bytes)?; + } + index_output_file.1.write_all(&codes)?; + } + + count += batch_len as u32; + } if args.print_aggregates { - println!("earliest={} latest={} count={} deduped={}", earliest_timestamp, latest_timestamp, count, deduped_count); + println!("earliest={} latest={} count={} read={} deduped={}", earliest_timestamp, latest_timestamp, count, i, deduped_count); } if let Some(histogram_path) = args.histograms { - let mut file = std::fs::File::create(histogram_path)?; - for (name, _, histogram) in &embeddings { + let mut file = fs::File::create(histogram_path)?; + for (name, _, histogram, _) in &embeddings { let width = 800.0; let padding = 40.0; let bars_height = 300 as f64; @@ -195,5 +467,26 @@ fn main() -> Result<()> { file.write_all(plot.into_string().as_bytes())?; } } + + if let Some(index_output) = &args.index_output { + let mut file = fs::File::create(PathBuf::from(index_output).join("index.msgpack"))?; + let header = IndexHeader { + shards: shard_specs.unwrap(), + count: count as u32, + record_pad_size: RECORD_PAD_SIZE, + dead_count, + quantizer: pq_codec.unwrap() + }; + 
file.write_all(rmp_serde::to_vec_named(&header)?.as_slice())?; + } + + if let Some(shards) = &mut shards_out { + for (_centroid, _file, count, index) in shards.iter_mut() { + println!("shard {}: {} records", index, count); + } + } + + th.join().unwrap()?; + Ok(()) } diff --git a/src/generate_index_shard.rs b/src/generate_index_shard.rs new file mode 100644 index 0000000..d1227d6 --- /dev/null +++ b/src/generate_index_shard.rs @@ -0,0 +1,133 @@ +use anyhow::{Result, Context}; +use itertools::Itertools; +use std::io::{BufReader, Write, BufWriter}; +use rmp_serde::decode::Error as DecodeError; +use std::fs; +use diskann::{augment_bipartite, build_graph, project_bipartite, random_fill_graph, vector::{dot, VectorList}, IndexBuildConfig, IndexGraph, Timer}; +use half::f16; + +mod common; + +use common::{ShardInputHeader, ShardedRecord, ShardHeader}; + +const D_EMB: usize = 1152; + +fn main() -> Result<()> { + let mut rng = fastrand::Rng::new(); + + let mut stream = BufReader::new(fs::File::open(std::env::args().nth(1).unwrap()).context("read dump file")?); + + let mut original_ids = vec![]; + let mut vector_data = vec![]; + let mut query_knns = vec![]; + + let header: ShardInputHeader = rmp_serde::from_read(&mut stream)?; + let centroid_fp16 = header.centroid.iter().map(|x| f16::from_f32(*x)).collect::>(); + + { + let _timer = Timer::new("read shard"); + loop { + let res: Result = rmp_serde::from_read(&mut stream); + match res { + Ok(x) => { + original_ids.push(x.id); + vector_data.extend(bytemuck::cast_slice(&x.vector)); + query_knns.push(x.query_knns); + }, + Err(DecodeError::InvalidDataRead(x)) | Err(DecodeError::InvalidMarkerRead(x)) if x.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e).context("decode fail") + } + } + } + + let mut config = IndexBuildConfig { + r: 64, + r_cap: 80, + l: 256, + maxc: 750, + alpha: 65536 + }; + + let vecs = VectorList { + data: vector_data, + d_emb: D_EMB, + length: original_ids.len() + }; + + let mut graph = IndexGraph::empty(original_ids.len(), config.r_cap); + + { + //let _timer = Timer::new("project bipartite"); + //project_bipartite(&mut rng, &mut graph, &query_knns, &query_knns_bwd, config, &vecs); + } + + { + let _timer = Timer::new("random fill"); + random_fill_graph(&mut rng, &mut graph, config.r); + } + + let medioid = vecs.iter().position_max_by_key(|&v| { + dot(v, ¢roid_fp16) + }).unwrap() as u32; + + { + let _timer = Timer::new("first pass"); + build_graph(&mut rng, &mut graph, medioid, &vecs, config); + } + + { + let _timer = Timer::new("second pass"); + config.alpha = 80000; + build_graph(&mut rng, &mut graph, medioid, &vecs, config); + } + + std::mem::drop(vecs); + + let mut query_knns_bwd = vec![vec![]; header.max_query_id]; + + { + let _timer = Timer::new("compute backward edges"); + for (record_id, knns) in query_knns.iter().enumerate() { + for &k in knns { + query_knns_bwd[k as usize].push(record_id as u32); + } + } + } + + { + let _timer = Timer::new("augment bipartite"); + augment_bipartite(&mut rng, &mut graph, query_knns, query_knns_bwd, config); + } + + let len = original_ids.len(); + + { + let _timer = Timer::new("write shard"); + let mut graph_data = BufWriter::new(fs::File::create(&format!("{}.shard.bin", header.id))?); + + let mut offsets = Vec::with_capacity(original_ids.len()); + let mut offset = 0; + for out_neighbours in graph.graph.iter() { + let out_neighbours = out_neighbours.read().unwrap(); + offsets.push(offset); + let s: &[u8] = bytemuck::cast_slice(&*out_neighbours); + offset += s.len() as 
u64; + graph_data.write_all(s)?; + } + offsets.push(offset); // dummy entry for convenience + + let mut header_f = fs::File::create(&format!("{}.shard-header.msgpack", header.id))?; + header_f.write_all(&rmp_serde::to_vec(&ShardHeader { + id: header.id, + max: *original_ids.iter().max().unwrap(), + centroid: header.centroid, + medioid, + offsets, + mapping: original_ids + })?)?; + } + + println!("{} vectors", len); + + Ok(()) +} diff --git a/src/get_embedding.py b/src/get_embedding.py index 2c6695b..54bf9b3 100644 --- a/src/get_embedding.py +++ b/src/get_embedding.py @@ -15,4 +15,9 @@ output, input, *xs = sys.argv[1:] with open(output, "wb") as f: with open(input, "rb") as g: input_data = g.read() - f.write(get_embedding({"images": [input_data]})[0]) + if not xs: + result = get_embedding({"images": [input_data]})[0] + else: + result = get_embedding({"text": xs})[0] + f.write(result) + print(base64.urlsafe_b64encode(result).decode("ascii")) diff --git a/src/query_disk_index.rs b/src/query_disk_index.rs new file mode 100644 index 0000000..c8c7f3d --- /dev/null +++ b/src/query_disk_index.rs @@ -0,0 +1,173 @@ +use anyhow::{bail, Context, Result}; +use diskann::vector::scale_dot_result_f64; +use serde::{Serialize, Deserialize}; +use std::io::{BufReader, Read, Seek, SeekFrom, Write}; +use std::path::PathBuf; +use std::fs; +use base64::Engine; +use argh::FromArgs; +use chrono::{TimeZone, Utc, DateTime}; +use std::collections::VecDeque; +use itertools::Itertools; +use foldhash::{HashSet, HashSetExt}; +use half::f16; +use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, ProductQuantizer, DistanceLUT, scale_dot_result}}; +use simsimd::SpatialSimilarity; +use memmap2::{Mmap, MmapOptions}; + +mod common; + +use common::{PackedIndexEntry, IndexHeader}; + +#[derive(FromArgs)] +#[argh(description="Query disk index")] +struct CLIArguments { + #[argh(positional)] + query_vector: String, + #[argh(positional)] + index_path: String +} + +fn read_node(id: u32, data_file: &mut fs::File, header: &IndexHeader) -> Result { + let offset = id as usize * header.record_pad_size; + data_file.seek(SeekFrom::Start(offset as u64))?; + let mut buf = vec![0; header.record_pad_size as usize]; + data_file.read_exact(&mut buf)?; + let len = u16::from_le_bytes(buf[0..2].try_into().unwrap()) as usize; + Ok(bitcode::decode(&buf[2..len+2])?) 
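+    // each record occupies a fixed slot of header.record_pad_size bytes: a little-endian
+    // u16 length prefix, then the bitcode-encoded PackedIndexEntry, zero-padded to the slot size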
+} + +fn read_pq_codes(id: u32, codes: &Mmap, buf: &mut Vec, pq_code_size: usize) { + let loc = (id as usize) * pq_code_size; + buf.extend(&codes[loc..loc+pq_code_size]) +} + +struct Scratch { + visited: HashSet, + neighbour_buffer: NeighbourBuffer, + neighbour_pre_buffer: Vec, + visited_list: Vec<(u32, i64, String, Vec)> +} + +struct IndexRef<'a> { + data_file: &'a mut fs::File, + pq_codes: &'a Mmap, + header: &'a IndexHeader, + pq_code_size: usize +} + +fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preprocessed: &DistanceLUT, index: IndexRef) -> Result<(usize, usize)> { + scratch.visited.clear(); + scratch.neighbour_buffer.clear(); + scratch.visited_list.clear(); + + let mut cmps = 0; + let mut pq_cmps = 0; + + let node = read_node(start, index.data_file, index.header)?; + let vector = bytemuck::cast_slice(&node.vector); + scratch.neighbour_buffer.insert(start, fast_dot_noprefetch(query, &vector)); + scratch.visited.insert(start); + + while let Some(pt) = scratch.neighbour_buffer.next_unvisited() { + //println!("pt {} {:?}", pt, graph.out_neighbours(pt)); + scratch.neighbour_pre_buffer.clear(); + let node = read_node(pt, index.data_file, index.header)?; + let vector = bytemuck::cast_slice(&node.vector); + let distance = fast_dot_noprefetch(query, &vector); + cmps += 1; + scratch.visited_list.push((pt, distance, node.url, node.shards)); + for &neighbour in node.vertices.iter() { + if scratch.visited.insert(neighbour) { + scratch.neighbour_pre_buffer.push(neighbour); + } + } + let mut pq_codes = Vec::with_capacity(index.pq_code_size * scratch.neighbour_pre_buffer.len()); + for &neighbour in scratch.neighbour_pre_buffer.iter() { + read_pq_codes(neighbour, index.pq_codes, &mut pq_codes, index.pq_code_size); + } + let approx_scores = index.header.quantizer.asymmetric_dot_product(&query_preprocessed, &pq_codes); + for (i, &neighbour) in scratch.neighbour_pre_buffer.iter().enumerate() { + //let next_neighbour = scratch.neighbour_pre_buffer[(i + 1) % scratch.neighbour_pre_buffer.len()]; // TODO + //let node = read_node(neighbour, index.data_file, index.header)?; + //let vector = bytemuck::cast_slice(&node.vector); + //let distance = fast_dot_noprefetch(query, &vector); + pq_cmps += 1; + scratch.neighbour_buffer.insert(neighbour, approx_scores[i]); + //scratch.neighbour_buffer.insert(neighbour, distance); + } + } + + Ok((cmps, pq_cmps)) +} + +fn main() -> Result<()> { + let args: CLIArguments = argh::from_env(); + + let query_vector: Vec = common::chunk_fp16_buffer(&base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(args.query_vector.as_bytes()).context("invalid base64")?); + let query_vector_fp32 = query_vector.iter().map(|x| x.to_f32()).collect::>(); + + let index_path = PathBuf::from(&args.index_path); + let header: IndexHeader = rmp_serde::from_read(BufReader::new(fs::File::open(index_path.join("index.msgpack"))?))?; + let mut data_file = fs::File::open(index_path.join("index.bin"))?; + let pq_codes_file = fs::File::open(index_path.join("index.pq-codes.bin"))?; + let pq_codes = unsafe { + // This is unsafe because other processes could in principle edit the mmap'd file. + // It would be annoying to do anything about this possibility, so ignore it. + MmapOptions::new().populate().map(&pq_codes_file)? 
+ }; + + let query_preprocessed = header.quantizer.preprocess_query(&query_vector_fp32); + + println!("{} items {} dead {} shards", header.count, header.dead_count, header.shards.len()); + + // TODO slightly dubious + let selected_shard = header.shards.iter().position_max_by_key(|x| { + scale_dot_result_f64(SpatialSimilarity::dot(&x.0, &query_vector_fp32).unwrap()) + }).unwrap(); + + println!("best shard is {}", selected_shard); + + for shard in 0..header.shards.len() { + let selected_start = header.shards[shard].1; + + let mut scratch = Scratch { + visited: HashSet::new(), + neighbour_buffer: NeighbourBuffer::new(5000), + neighbour_pre_buffer: Vec::new(), + visited_list: Vec::new() + }; + + //let query_vector = diskann::vector::quantize(&query_vector, &header.quantizer, &mut rng); + let cmps = greedy_search(&mut scratch, selected_start, &query_vector, &query_preprocessed, IndexRef { + data_file: &mut data_file, + header: &header, + pq_codes: &pq_codes, + pq_code_size: header.quantizer.n_dims / header.quantizer.n_dims_per_code, + })?; + + println!("index scan {}: {:?} cmps", shard, cmps); + + scratch.visited_list.sort_by_key(|x| -x.1); + for (id, distance, url, shards) in scratch.visited_list.iter().take(20) { + println!("index scan: {} {} {} {:?}", id, distance, url, shards); + } + println!(""); + } + + let mut matches = vec![]; + // brute force scan + for i in 0..header.count { + let node = read_node(i, &mut data_file, &header)?; + //println!("{} {}", i, node.url); + let vector = bytemuck::cast_slice(&node.vector); + matches.push((i, fast_dot_noprefetch(&query_vector, &vector), node.url, node.shards)); + } + + matches.sort_by_key(|x| -x.1); + for (id, distance, url, shards) in matches.iter().take(20) { + println!("brute force: {} {} {} {:?}", id, distance, url, shards); + } + + Ok(()) +} diff --git a/src/reddit_dump.rs b/src/reddit_dump.rs index ec42a15..b29c548 100644 --- a/src/reddit_dump.rs +++ b/src/reddit_dump.rs @@ -19,7 +19,7 @@ static GLOBAL: MiMalloc = MiMalloc; mod common; -use crate::common::{get_backend_config, query_clip_server, EmbeddingRequest}; +use crate::common::{get_backend_config, query_clip_server, EmbeddingRequest, OriginalImageMetadata, ProcessedEntry}; fn function_which_returns_some_na() -> Option { Some(String::from("na")) } @@ -27,14 +27,16 @@ fn function_which_returns_some_na() -> Option { Some(String::from("na")) #[serde(untagged)] enum BadTimestampFormat { Int(u64), - String(String) + String(String), + Float(f64) // *what* are they doing? } impl BadTimestampFormat { fn to_u64(&self) -> Result { match self { BadTimestampFormat::Int(x) => Ok(*x), - BadTimestampFormat::String(x) => u64::from_str(&x).context("invalid string") + BadTimestampFormat::String(x) => u64::from_str(&x).context("invalid string"), + BadTimestampFormat::Float(x) => Ok(*x as u64) } } } @@ -53,31 +55,9 @@ struct Entry { id: String } -#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)] -struct OriginalImageMetadata { - mime_type: String, - original_file_size: usize, - dimension: (u32, u32), - final_url: String -} - -#[derive(Clone, Deserialize, Serialize, Debug)] -struct ProcessedEntry { - url: String, - id: String, - title: String, - subreddit: String, - author: String, - timestamp: u64, - #[serde(with = "serde_bytes")] - embedding: Vec, - metadata: OriginalImageMetadata -} - lazy_static! 
{ - // we do exclude galleries doing this but there don't seem to be any in the dataset static ref URL_IGNORE: RegexSet = RegexSet::new([ - r"//reddit\.com", + r"//reddit\.com/[^g]", r"\.html?", r"\.php", r"\?articleid=", @@ -85,7 +65,7 @@ lazy_static! { r"\.xml", r"/rss/", r"//vimeo\.com", - r"//www\.reddit\.com", + r"//www\.reddit\.com/[^g]", r"//v\.redd\.it", r"\.gifv$", r"youtube\.com/user/" @@ -113,6 +93,7 @@ lazy_static! { "/media", r"youtu\.be", r"youtube\.com", + "reddit.com/gallery/" ]).case_insensitive(true).build().unwrap(); static ref ACCEPTABLE_FILETYPES: HashSet<&'static str> = ["image/png", "image/webp", "image/avif", "image/jpeg", "image/gif", "image/webp", "image/apng", "image/bmp", "image/tiff"] .into_iter().collect(); @@ -139,6 +120,7 @@ lazy_static! { static ref HTML_EXTRACTION_RULES: Vec<(Regex, Regex)> = [ (r"//imgur\.com/a/[A-Za-z0-9]+", r#""#), (r"//imgur\.com/gallery/[A-Za-z0-9]+", r#""#), + (r"reddit.com/gallery/[A-Za-z0-9_-]+", r#"
  • , timestamp_threshold: Opt // Technically this is slightly wrong because we reorder images slightly, but as long as it is not restarted all the time this is "fine". let after_threshold = match timestamp_threshold { Some(threshold) => { - let timestamp = match &entry.created_utc { - BadTimestampFormat::Int(x) => *x, - BadTimestampFormat::String(s) => u64::from_str(s).unwrap() - }; + let timestamp = entry.created_utc.to_u64().unwrap(); timestamp > threshold }, None => true @@ -219,7 +198,7 @@ struct Config { async fn fetch_file(client: reqwest::Client, config: Arc, url: &str) -> Result<(Vec, String, String)> { let mut url = url.to_string(); for (regex, replacement) in URL_REPLACEMENT_RULES.iter() { - url = regex.replace(&url, *replacement).to_string(); + url = regex.replace_all(&url, *replacement).to_string(); } let mut html_extract_rule = None; @@ -233,7 +212,7 @@ async fn fetch_file(client: reqwest::Client, config: Arc, url: &str) -> let mut response = client.get(&*url).send().await?; let content_type = std::str::from_utf8(&response.headers().get(reqwest::header::CONTENT_TYPE).context("no content type")?.as_bytes())?.to_owned(); - if !(ACCEPTABLE_FILETYPES.contains(&content_type[..]) || (html_extract_rule.is_some() && content_type == "text/html")) { + if !(ACCEPTABLE_FILETYPES.contains(&content_type[..]) || (html_extract_rule.is_some() && content_type.starts_with("text/html"))) { return Err(anyhow!("invalid Content-Type")); } match response.content_length() { @@ -255,7 +234,7 @@ async fn fetch_file(client: reqwest::Client, config: Arc, url: &str) -> return Err(anyhow!("discarded")); } if let Some(extract_rule) = html_extract_rule { - if content_type == "text/html" { + if content_type.starts_with("text/html") { let buffer = String::from_utf8_lossy(&buffer).to_string(); if let Some(mat) = extract_rule.captures(&buffer) { let new_url = mat.get(1).unwrap().as_str(); @@ -344,11 +323,11 @@ async fn main() -> Result<()> { let config = Arc::new(Config { max_content_length: 1<<24, - input: String::from("./reddit_subs_202212/"), + input: String::from("/srv/scratch/reddit_subs_202312/"), output: String::from("."), backend: String::from("http://localhost:1708"), mode: OperatingMode::FullRun, - filename_threshold: Some(String::from("RS_2019-07.zst")), + filename_threshold: None, metrics_addr: String::from("0.0.0.0:9914"), contact_info: String::from("scraping-ops@osmarks.net"), discard_hashes: [4168519401919155623, 4577010157274124110].into_iter().collect()