mirror of https://github.com/osmarks/maghammer.git
synced 2025-10-31 14:03:00 +00:00

Commit: support modernbert & change walkdir
Author: osmarks

Cargo.lock (generated): 1587 changes
File diff suppressed because it is too large.

Cargo.toml:

@@ -8,7 +8,7 @@ tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_j
 tokio = { version = "1", features = ["full", "tracing"]}
 chrono = { version = "0.4", features = ["serde"] }
 anyhow = "1"
-async-walkdir = "2"
+walkdir = "2"
 lazy_static = "1"
 compact_str = { version = "0.8.0-beta", features = ["serde"] }
 seahash = "4"
@@ -17,7 +17,7 @@ serde = { version = "1", features = ["derive"] }
 reqwest = "0.12"
 deadpool-postgres = "0.14"
 pgvector = { version = "0.3", features = ["postgres", "halfvec"] }
-tokenizers = { version = "0.19", features = ["http"] }
+tokenizers = { version = "0.21", features = ["http"] }
 regex = "1"
 futures = "0.3"
 html5gum = "0.5"

config.toml: 10 changes

@@ -1,12 +1,14 @@
-database = "host=localhost port=5432 user=maghammer dbname=maghammer"
+database = "host=localhost port=5432 user=maghammer dbname=maghammer-test"
 concurrency = 8
 
 [semantic]
 backend = "http://100.64.0.10:1706"
 embedding_dim = 1024
-tokenizer = "Snowflake/snowflake-arctic-embed-l"
-max_tokens = 128
-batch_size = 256
+tokenizer = "lightonai/modernbert-embed-large"
+max_tokens = 4096
+batch_size = 12
+document_prefix = "search_document: "
+query_prefix = "search_query: "
 
 [indexers.text_files]
 path = "/data/archive"

Embedding backend (Python):

@@ -14,13 +14,13 @@ from sentence_transformers import SentenceTransformer
 from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
 
 device = torch.device("cuda:0")
-model_name = "./snowflake-arctic-embed-l"
+model_name = "./modernbert-embed-large"
 model = SentenceTransformer(model_name).half().to(device)
 model.eval()
 print("model loaded")
 
-MODELNAME = "sbert-snowflake-arctic-embed-l"
-BS = 256
+MODELNAME = "modernbert-embed-large"
+BS = 64
 
 InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "callback"])
 
@@ -38,7 +38,7 @@ def do_inference(params: InferenceParameters):
             items_ctr.labels(MODELNAME, "text").inc(batch_size)
             with inference_time_hist.labels(MODELNAME, batch_size).time():
                 features = model(text)["sentence_embedding"]
-                features /= features.norm(dim=-1, keepdim=True)
+                features /= features.norm(dim=-1, keepdim=True) + 1e-5
                 features = features.cpu().numpy()
             batch_count_ctr.labels(MODELNAME).inc()
             callback(True, features)
@@ -115,4 +115,4 @@ try:
 except KeyboardInterrupt:
     print("quitting")
     import sys
-    sys.exit(0)
+    sys.exit(0)

EPUB indexer:

@@ -3,17 +3,17 @@ use std::ffi::OsStr;
 use std::path::PathBuf;
 use std::sync::Arc;
 use anyhow::{Context, Result};
-use async_walkdir::{WalkDir, DirEntry};
 use compact_str::{CompactString, ToCompactString};
 use epub::doc::EpubDoc;
 use futures::TryStreamExt;
 use serde::{Deserialize, Serialize};
 use tokio::sync::RwLock;
-use crate::util::{hash_str, parse_html, parse_date, systemtime_to_utc, CONFIG};
+use crate::util::{hash_str, parse_html, parse_date, systemtime_to_utc, CONFIG, async_walkdir};
 use crate::indexer::{delete_nonexistent_files, Ctx, Indexer, TableSpec, ColumnSpec};
 use chrono::prelude::*;
 use std::str::FromStr;
 use tracing::instrument;
+use walkdir::DirEntry;
 
 #[derive(Serialize, Deserialize, Clone)]
 struct Config {
@@ -91,14 +91,15 @@ async fn handle_epub(relpath: CompactString, ctx: Arc<Ctx>, entry: DirEntry, exi
     let epub = entry.path();
 
     let mut conn = ctx.pool.get().await?;
-    let metadata = entry.metadata().await?;
+    let metadata = tokio::fs::metadata(epub).await?;
     let row = conn.query_opt("SELECT timestamp FROM books WHERE id = $1", &[&hash_str(&relpath)]).await?;
     existing_files.write().await.insert(relpath.clone());
     let timestamp: DateTime<Utc> = row.map(|r| r.get(0)).unwrap_or(DateTime::<Utc>::MIN_UTC);
     let modtime = systemtime_to_utc(metadata.modified()?)?;
 
     if modtime > timestamp {
-        let parse_result = match tokio::task::spawn_blocking(move || parse_epub(&epub)).await? {
+        let pathbuf = epub.to_path_buf();
+        let parse_result = match tokio::task::spawn_blocking(move || parse_epub(&pathbuf)).await? {
             Ok(x) => x,
             Err(e) => {
                 tracing::warn!("Failed parse for {}: {}", relpath, e);
@@ -232,7 +233,7 @@ CREATE TABLE chapters (
     async fn run(&self, ctx: Arc<Ctx>) -> Result<()> {
         let existing_files = Arc::new(RwLock::new(HashSet::new()));
 
-        let entries = WalkDir::new(&self.config.path);
+        let entries = async_walkdir(self.config.path.clone().into(), false, |_| true);
         let base_path = &self.config.path;
 
         entries.map_err(|e| anyhow::Error::from(e)).try_for_each_concurrent(Some(CONFIG.concurrency), |entry|
@@ -247,17 +248,17 @@ CREATE TABLE chapters (
                     return Result::Ok(());
                 };
                 let ext = real_path.extension().and_then(OsStr::to_str);
-                if !entry.file_type().await?.is_file() || "epub" != ext.unwrap_or_default() {
+                if !entry.file_type().is_file() || "epub" != ext.unwrap_or_default() {
                     return Ok(());
                 }
                 let conn = ctx.pool.get().await?;
                 existing_files.write().await.insert(CompactString::from(path));
-                let metadata = entry.metadata().await?;
+                let metadata = tokio::fs::metadata(real_path).await?;
                 let row = conn.query_opt("SELECT timestamp FROM books WHERE id = $1", &[&hash_str(path)]).await?;
                 let timestamp: DateTime<Utc> = row.map(|r| r.get(0)).unwrap_or(DateTime::<Utc>::MIN_UTC);
                 let modtime = systemtime_to_utc(metadata.modified()?)?;
                 if modtime > timestamp {
-                    let relpath = entry.path().as_path().strip_prefix(&base_path)?.as_os_str().to_str().context("invalid path")?.to_compact_string();
+                    let relpath = entry.path().strip_prefix(&base_path)?.as_os_str().to_str().context("invalid path")?.to_compact_string();
                     handle_epub(relpath.clone(), ctx, entry, existing_files).await
                 } else {
                     Ok(())
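
A note on the recurring to_path_buf() calls in this commit: walkdir::DirEntry::path() returns a borrowed &Path, while the old async_walkdir entries handed out an owned PathBuf, so anything moved into a 'static context (a spawn_blocking closure, a JoinSet task) now needs an owned copy first. A minimal sketch of the pattern, with a hypothetical parse stand-in:

    use std::path::{Path, PathBuf};

    // Stand-in for parse_epub/parse_media/etc.; any blocking parser works.
    fn parse(path: &Path) -> String { path.display().to_string() }

    async fn process(entry: walkdir::DirEntry) -> anyhow::Result<String> {
        // entry.path() borrows from `entry`; copy to an owned PathBuf so the
        // closure can satisfy spawn_blocking's 'static bound.
        let pathbuf: PathBuf = entry.path().to_path_buf();
        Ok(tokio::task::spawn_blocking(move || parse(&pathbuf)).await?)
    }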

Media indexer:

@@ -9,9 +9,8 @@ use futures::{StreamExt, TryStreamExt};
 use serde::{Deserialize, Serialize};
 use tokio::process::Command;
 use tokio::sync::RwLock;
-use crate::util::{hash_str, parse_date, parse_html, systemtime_to_utc, urlencode, CONFIG};
+use crate::util::{hash_str, parse_date, parse_html, systemtime_to_utc, urlencode, CONFIG, async_walkdir};
 use crate::indexer::{Ctx, Indexer, TableSpec, delete_nonexistent_files, ColumnSpec};
-use async_walkdir::{Filtering, WalkDir};
 use chrono::prelude::*;
 use regex::{RegexSet, Regex};
 use tracing::instrument;
@@ -67,6 +66,7 @@ struct MediaParse {
 struct Chapter {
     start_time: String,
     end_time: String,
+    #[serde(default)]
     tags: HashMap<String, String>
 }
 
@@ -79,6 +79,7 @@ struct Disposition {
 struct Stream {
     index: usize,
     codec_name: Option<String>,
+    #[serde(default)]
     duration: Option<String>,
     codec_type: String,
     #[serde(default)]
@@ -89,7 +90,7 @@ struct Stream {
 
 #[derive(Deserialize, Debug)]
 struct Format {
-    duration: String,
+    duration: Option<String>,
     #[serde(default)]
     tags: HashMap<String, String>
 }
@@ -319,7 +320,7 @@ async fn parse_media(path: &PathBuf, ignore: Arc<RegexSet>) -> Result<MediaParse
         }
     }
 
-    result.duration = f32::from_str(&probe.format.duration)?;
+    result.duration = probe.format.duration.map(|x| f32::from_str(&x)).transpose()?.unwrap_or(0.0);
 
     let mut best_subtitle_track = (0, i8::MIN);
 
@@ -470,31 +471,31 @@ CREATE TABLE media_files (
     }
 
     async fn run(&self, ctx: Arc<Ctx>) -> Result<()> {
-        let entries = WalkDir::new(&self.config.path);
         let ignore = Arc::new(self.ignore_files.clone());
         let base_path = Arc::new(self.config.path.clone());
         let base_path_ = base_path.clone();
         let ignore_metadata = self.ignore_metadata.clone();
 
+        let entries = async_walkdir(self.config.path.clone().into(), true, move |entry| {
+            let ignore = ignore.clone();
+            let base_path = base_path.clone();
+            let path = entry.path();
+            tracing::trace!("filtering {:?}", path);
+            if let Some(path) = path.strip_prefix(&*base_path).ok().and_then(|x| x.to_str()) {
+                if ignore.is_match(path) {
+                    return false;
+                }
+            } else {
+                return false;
+            }
+            true
+        });
+
         let existing_files = Arc::new(RwLock::new(HashSet::new()));
         let existing_files_ = existing_files.clone();
         let ctx_ = ctx.clone();
 
         entries
-            .filter(move |entry| {
-                let ignore = ignore.clone();
-                let base_path = base_path.clone();
-                let path = entry.path();
-                tracing::trace!("filtering {:?}", path);
-                if let Some(path) = path.strip_prefix(&*base_path).ok().and_then(|x| x.to_str()) {
-                    if ignore.is_match(path) {
-                        return std::future::ready(Filtering::IgnoreDir);
-                    }
-                } else {
-                    return std::future::ready(Filtering::IgnoreDir);
-                }
-                std::future::ready(Filtering::Continue)
-            })
             .map_err(|e| anyhow::Error::from(e))
             .filter(|r| {
                 // ignore permissions errors because things apparently break otherwise
@@ -517,18 +518,18 @@ CREATE TABLE media_files (
                     } else {
                         return Result::Ok(());
                     };
-                    if !entry.file_type().await?.is_file() {
+                    if !entry.file_type().is_file() {
                         return Ok(());
                     }
                     let mut conn = ctx.pool.get().await?;
                     existing_files.write().await.insert(CompactString::from(path));
-                    let metadata = entry.metadata().await?;
+                    let metadata = tokio::fs::metadata(real_path).await?;
                     let row = conn.query_opt("SELECT timestamp FROM media_files WHERE id = $1", &[&hash_str(path)]).await?;
                     let timestamp: DateTime<Utc> = row.map(|r| r.get(0)).unwrap_or(DateTime::<Utc>::MIN_UTC);
                     let modtime = systemtime_to_utc(metadata.modified()?)?;
                     tracing::trace!("timestamp {:?}", timestamp);
                     if modtime > timestamp {
-                        match parse_media(&real_path, ignore_metadata).await {
+                        match parse_media(&real_path.to_path_buf(), ignore_metadata).await {
                             Ok(x) => {
                                 let tx = conn.transaction().await?;
                                 tx.execute("DELETE FROM media_files WHERE id = $1", &[&hash_str(path)]).await?;
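
The new duration handling uses the Option/Result transpose idiom: parse only if ffprobe reported a container duration, propagate a parse failure, and default to 0.0 when it is absent. Spelled out as a standalone sketch:

    use std::str::FromStr;

    // Equivalent to the replacement line above, one step at a time.
    fn duration_of(raw: Option<String>) -> anyhow::Result<f32> {
        Ok(raw
            .map(|x| f32::from_str(&x)) // Option<Result<f32, ParseFloatError>>
            .transpose()?               // Result<Option<f32>, _>, unwrapped by ?
            .unwrap_or(0.0))            // no duration reported -> 0.0
    }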

Text files indexer:

@@ -7,9 +7,8 @@ use futures::TryStreamExt;
 use serde::{Deserialize, Serialize};
 use tokio::sync::RwLock;
 use tracing::instrument;
-use crate::util::{hash_str, parse_html, parse_pdf, systemtime_to_utc, urlencode, CONFIG};
+use crate::util::{hash_str, parse_html, parse_pdf, systemtime_to_utc, urlencode, CONFIG, async_walkdir};
 use crate::indexer::{Ctx, Indexer, TableSpec, delete_nonexistent_files, ColumnSpec};
-use async_walkdir::WalkDir;
 use chrono::prelude::*;
 use regex::RegexSet;
 
@@ -85,7 +84,7 @@ CREATE TABLE text_files (
     }
 
     async fn run(&self, ctx: Arc<Ctx>) -> Result<()> {
-        let entries = WalkDir::new(&self.config.path); // TODO
+        let entries = async_walkdir(self.config.path.clone().into(), false, |_| true);
         let ignore = &self.ignore;
         let base_path = &self.config.path;
 
@@ -121,7 +120,7 @@ impl TextFilesIndexer {
     }
 
     #[instrument(skip(ctx, ignore, existing_files, base_path))]
-    async fn process_file(entry: async_walkdir::DirEntry, ctx: Arc<Ctx>, ignore: &RegexSet, existing_files: Arc<RwLock<HashSet<CompactString>>>, base_path: &String) -> Result<()> {
+    async fn process_file(entry: walkdir::DirEntry, ctx: Arc<Ctx>, ignore: &RegexSet, existing_files: Arc<RwLock<HashSet<CompactString>>>, base_path: &String) -> Result<()> {
         let real_path = entry.path();
         let path = if let Some(path) = real_path.strip_prefix(base_path)?.to_str() {
             path
@@ -129,17 +128,17 @@ impl TextFilesIndexer {
             return Result::Ok(());
         };
         let ext = real_path.extension().and_then(OsStr::to_str);
-        if ignore.is_match(path) || !entry.file_type().await?.is_file() || !VALID_EXTENSIONS.contains(ext.unwrap_or_default()) {
+        if ignore.is_match(path) || !entry.file_type().is_file() || !VALID_EXTENSIONS.contains(ext.unwrap_or_default()) {
            return Ok(());
        }
        let mut conn = ctx.pool.get().await?;
        existing_files.write().await.insert(CompactString::from(path));
-        let metadata = entry.metadata().await?;
+        let metadata = tokio::fs::metadata(real_path).await?;
        let row = conn.query_opt("SELECT timestamp FROM text_files WHERE id = $1", &[&hash_str(path)]).await?;
        let timestamp: DateTime<Utc> = row.map(|r| r.get(0)).unwrap_or(DateTime::<Utc>::MIN_UTC);
        let modtime = systemtime_to_utc(metadata.modified()?)?;
        if modtime > timestamp {
-            let parse = TextFilesIndexer::read_file(&real_path, ext).await;
+            let parse = TextFilesIndexer::read_file(&real_path.to_path_buf(), ext).await;
            match parse {
                Ok(None) => (),
                Ok(Some((content, title))) => {

Email indexer:

@@ -2,12 +2,11 @@ use std::collections::HashSet;
 use std::{collections::HashMap, path::PathBuf};
 use std::sync::Arc;
 use anyhow::{Context, Result};
-use async_walkdir::WalkDir;
 use compact_str::ToCompactString;
 use serde::{Deserialize, Serialize};
 use tokio_stream::StreamExt;
 use crate::indexer::{Ctx, Indexer, TableSpec, ColumnSpec};
-use crate::util::{hash_thing, parse_html};
+use crate::util::{hash_thing, parse_html, async_walkdir};
 use chrono::prelude::*;
 use tokio::{fs::File, io::{AsyncBufReadExt, BufReader}};
 use mail_parser::MessageParser;
@@ -177,7 +176,7 @@ CREATE TABLE IF NOT EXISTS emails (
 
     async fn run(&self, ctx: Arc<Ctx>) -> Result<()> {
         let mut js: tokio::task::JoinSet<Result<()>> = tokio::task::JoinSet::new();
-        let mut entries = WalkDir::new(&self.config.mboxes_path);
+        let mut entries = async_walkdir(self.config.mboxes_path.clone().into(), false, |_| true);
         let config = self.config.clone();
         while let Some(entry) = entries.try_next().await? {
             let path = entry.path();
@@ -189,7 +188,7 @@ CREATE TABLE IF NOT EXISTS emails (
                 if let None = ext {
                     if !self.config.ignore_mboxes.contains(mbox.as_str()) {
                         let ctx = ctx.clone();
-                        js.spawn(EmailIndexer::process_mbox(ctx.clone(), path.clone(), mbox, folder, account.clone()));
+                        js.spawn(EmailIndexer::process_mbox(ctx.clone(), path.to_path_buf(), mbox, folder, account.clone()));
                     }
                 }
             }
@@ -237,10 +236,10 @@ impl EmailIndexer {
                     &mail.raw,
                     &account.as_str(),
                     &mbox.as_str(),
-                    &mail.from,
-                    &mail.from_address,
-                    &mail.subject,
-                    &body
+                    &mail.from.replace("\0", ""),
+                    &mail.from_address.replace("\0", ""),
+                    &mail.subject.replace("\0", ""),
+                    &body.replace("\0", "")
                 ]).await?;
                 Ok(())
             }
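
The replace("\0", "") calls are presumably there because Postgres text columns reject embedded NUL bytes (the insert fails with an invalid-byte-sequence error), and raw mail headers and bodies occasionally contain them. A minimal helper along these lines (hypothetical name):

    // Postgres TEXT values cannot contain NUL; strip it from untrusted strings
    // before binding them as query parameters.
    fn strip_nul(s: &str) -> String {
        s.replace('\0', "")
    }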

src/main.rs: 16 changes

@@ -94,7 +94,7 @@ fn page(title: &str, body: Markup) -> web::HttpResponse {
 fn search_bar(ctx: &ServerState, value: &SearchQuery) -> Markup {
     html! {
         form.search-bar action="/search" {
-            input type="search" placeholder="Query" value=(value.q) name="q";
+            input type="search" placeholder="Query" value=(value.q) name="q" autofocus;
             select name="src_mode" {
                 option selected[value.src_mode == SearchSourceMode::Mix] { "Mix" }
                 option selected[value.src_mode == SearchSourceMode::Titles] { "Titles" }
@@ -108,7 +108,7 @@ fn search_bar(ctx: &ServerState, value: &SearchQuery) -> Markup {
                     }
                 }
             }
-            input type="submit" value="Search";
+            input type="submit" value="Search" autofocus;
         }
     }
 }
@@ -651,7 +651,7 @@ async fn query_one_table(table: &TableSpec, indexer: &'static str, col: &ColumnS
             let matchq_int = (*matchq * (i32::MAX as f64)) as i32;
             results.push(SearchResult {
                 docid: doc,
-                indexer: indexer,
+                indexer,
                 table: table.name,
                 column: col.name,
                 title: title.clone(),
@@ -667,9 +667,9 @@
 #[web::get("/search")]
 async fn fts_page(state: web::types::State<ServerState>, query: web::types::Query<SearchQuery>) -> impl web::Responder {
     let state = (*state).clone();
-    let (prefixed, unprefixed) = semantic::embed_query(&query.q, state.semantic.clone()).await?;
-    let prefixed = Arc::new(prefixed);
-    let unprefixed = Arc::new(unprefixed);
+    let (q_prefix, d_prefix) = semantic::embed_query(&query.q, state.semantic.clone()).await?;
+    let q_prefix = Arc::new(q_prefix);
+    let d_prefix = Arc::new(d_prefix);
     let mut results = HashMap::new();
 
     let mut set = tokio::task::JoinSet::new();
@@ -688,9 +688,9 @@ async fn fts_page(state: web::types::State<ServerState>, query: web::types::Quer
                 // Some columns are not like this (their spec says so) so we use the model for symmetric search.
                 // This does result in a different distribution of dot products.
                 let (embedding_choice, count) = if col.fts_short {
-                    (unprefixed.clone(), if query.src_mode == SearchSourceMode::Mix { 1 } else { 20 })
+                    (q_prefix.clone(), if query.src_mode == SearchSourceMode::Mix { 1 } else { 20 })
                 } else {
-                    (prefixed.clone(), 20)
+                    (q_prefix.clone(), 20)
                 };
                 set.spawn(query_one_table(table, ix.name(), col, state.clone(), count, embedding_choice));
             }

Semantic search module:

@@ -26,7 +26,9 @@ pub struct SemanticSearchConfig {
     backend: String,
     pub embedding_dim: u32,
     max_tokens: usize,
-    batch_size: usize
+    batch_size: usize,
+    document_prefix: String,
+    query_prefix: String
 }
 
 fn convert_tokenizer_error(e: Error) -> anyhow::Error {
@@ -102,6 +104,10 @@ fn decode_fp16_buffer(buf: &[u8]) -> Vec<f16> {
         .collect()
 }
 
+fn contains_nan(buf: &[f16]) -> bool {
+    buf.iter().any(|x| x.is_nan())
+}
+
 #[instrument(skip(client))]
 async fn send_batch(client: &Client, batch: Vec<&str>) -> Result<Vec<Vec<f16>>> {
     let res = client.post(&CONFIG.semantic.backend)
@@ -158,10 +164,11 @@ async fn insert_fts_chunks(id: i64, chunks: Vec<(Chunk, Vec<f16>)>, table: &Tabl
 }
 
 pub async fn embed_query(q: &str, ctx: Arc<SemanticCtx>) -> Result<(HalfVector, HalfVector)> {
-    let prefixed = format!("Represent this sentence for searching relevant passages: {}", q);
+    let query_prefixed = format!("{}{}", CONFIG.semantic.query_prefix, q);
+    let doc_prefixed = format!("{}{}", CONFIG.semantic.document_prefix, q);
     let mut result = send_batch(&ctx.client, vec![
-        &prefixed,
-        q
+        &query_prefixed,
+        &doc_prefixed
     ]).await?.into_iter();
     Ok((HalfVector::from(result.next().unwrap()), HalfVector::from(result.next().unwrap())))
 }
@@ -213,7 +220,7 @@ pub async fn fts_for_indexer(i: &Box<dyn Indexer>, ctx: Arc<SemanticCtx>) -> Res
                                             col: col.name,
                                             start: chunk.0 as i32,
                                             length: (chunk.1 - chunk.0) as i32,
-                                            text: chunk.2
+                                            text: format!("{}{}", CONFIG.semantic.document_prefix, chunk.2)
                                         });
                                     }
                                 }
@@ -242,6 +249,12 @@ pub async fn fts_for_indexer(i: &Box<dyn Indexer>, ctx: Arc<SemanticCtx>) -> Res
                         let mut pending = pending.write().await;
 
                         for (embedding, chunk) in embs.into_iter().zip(batch.into_iter()) {
+                            // ugly hack
+                            if contains_nan(&embedding) {
+                                // write no entries
+                                tokio::task::spawn(insert_fts_chunks(chunk.id, vec![], table, ctx.clone()));
+                            }
+
                             let record = pending.get_mut(&chunk.id).unwrap();
                             record.1 -= 1;
                             let id = chunk.id;
@@ -261,8 +274,8 @@ pub async fn fts_for_indexer(i: &Box<dyn Indexer>, ctx: Arc<SemanticCtx>) -> Res
 
             let make_embeddings: tokio::task::JoinHandle<Result<()>> = tokio::task::spawn(make_embeddings);
 
-            get_inputs.await??;
             make_embeddings.await??;
+            get_inputs.await??;
         }
     }
     Ok(())
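
The contains_nan guard pairs with the backend's new "+ 1e-5": an all-zero (or fp16-overflowed) embedding has norm zero, so the normalization divide yields NaNs, which would otherwise be written into the index. A minimal reproduction of the arithmetic, using f32 in place of the wire-format f16:

    fn normalize(v: &mut [f32], eps: f32) {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in v.iter_mut() {
            *x /= norm + eps; // eps = 0.0 on a zero vector: 0.0/0.0 = NaN
        }
    }

    fn contains_nan(v: &[f32]) -> bool {
        v.iter().any(|x| x.is_nan())
    }

    fn main() {
        let mut v = [0.0f32; 4];
        normalize(&mut v, 0.0);
        assert!(contains_nan(&v)); // unguarded: NaNs propagate
        let mut w = [0.0f32; 4];
        normalize(&mut w, 1e-5);
        assert!(!contains_nan(&w)); // guarded: stays zero
    }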

src/util.rs: 37 changes

@@ -7,6 +7,8 @@ use serde::{Serialize, Deserialize};
 use tokio_postgres::Row;
 use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
 use tracing::instrument;
+use futures::stream::Stream;
+use std::path::Path;
 
 const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`');
 
@@ -179,3 +181,38 @@ pub fn get_column_string(row: &Row, index: usize, spec: &ColumnSpec) -> Option<S
 pub fn urlencode(s: &str) -> String {
     utf8_percent_encode(s, FRAGMENT).to_string()
 }
+
+// TODO: check the joinhandle somewhere
+struct AsyncWalkdirStream {
+    rx: tokio::sync::mpsc::Receiver<Result<walkdir::DirEntry>>,
+    handle: tokio::task::JoinHandle<Result<()>>
+}
+
+impl Stream for AsyncWalkdirStream {
+    type Item = Result<walkdir::DirEntry>;
+
+    fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
+        match self.rx.poll_recv(cx) {
+            std::task::Poll::Ready(x) => std::task::Poll::Ready(x),
+            std::task::Poll::Pending => std::task::Poll::Pending
+        }
+    }
+}
+
+impl Drop for AsyncWalkdirStream {
+    fn drop(&mut self) {
+        self.handle.abort();
+    }
+}
+
+pub fn async_walkdir<F: Fn(&walkdir::DirEntry) -> bool + Send + 'static>(path: PathBuf, follow_symlinks: bool, filter: F) -> impl Stream<Item = Result<walkdir::DirEntry>> {
+    let (tx, rx) = tokio::sync::mpsc::channel(128);
+    let handle = tokio::task::spawn_blocking(move || {
+        let walker = walkdir::WalkDir::new(path).follow_links(follow_symlinks).into_iter();
+        for entry in walker.filter_entry(|e| filter(&e)) {
+            tx.blocking_send(entry.map_err(|e| anyhow::Error::from(e)))?;
+        }
+        Ok::<(), anyhow::Error>(())
+    });
+    AsyncWalkdirStream { rx, handle }
+}
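
For reference, a minimal standalone consumer of the new helper might look like this (hypothetical main and path; the indexers above drive the stream with try_for_each_concurrent or try_next instead):

    use futures::TryStreamExt;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // false = do not follow symlinks; the filter runs on the blocking
        // walker thread, and returning false for a directory prunes its
        // entire subtree (walkdir's filter_entry semantics).
        let mut entries = async_walkdir("/tmp".into(), false, |e| {
            !e.file_name().to_string_lossy().starts_with('.')
        });
        while let Some(entry) = entries.try_next().await? {
            println!("{}", entry.path().display());
        }
        Ok(())
    }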