mirror of https://github.com/osmarks/maghammer.git synced 2025-09-10 22:36:04 +00:00

support modernbert & change walkdir

osmarks
2025-04-11 21:57:01 +01:00
parent a11bc0b22d
commit a940dc9b46
11 changed files with 1143 additions and 636 deletions

Cargo.lock (generated): 1587 changed lines. File diff suppressed because it is too large.

View File

@@ -8,7 +8,7 @@ tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_j
tokio = { version = "1", features = ["full", "tracing"]}
chrono = { version = "0.4", features = ["serde"] }
anyhow = "1"
-async-walkdir = "2"
+walkdir = "2"
lazy_static = "1"
compact_str = { version = "0.8.0-beta", features = ["serde"] }
seahash = "4"
@@ -17,7 +17,7 @@ serde = { version = "1", features = ["derive"] }
reqwest = "0.12"
deadpool-postgres = "0.14"
pgvector = { version = "0.3", features = ["postgres", "halfvec"] }
-tokenizers = { version = "0.19", features = ["http"] }
+tokenizers = { version = "0.21", features = ["http"] }
regex = "1"
futures = "0.3"
html5gum = "0.5"

View File

@@ -1,12 +1,14 @@
-database = "host=localhost port=5432 user=maghammer dbname=maghammer"
+database = "host=localhost port=5432 user=maghammer dbname=maghammer-test"
concurrency = 8
[semantic]
backend = "http://100.64.0.10:1706"
embedding_dim = 1024
-tokenizer = "Snowflake/snowflake-arctic-embed-l"
-max_tokens = 128
-batch_size = 256
+tokenizer = "lightonai/modernbert-embed-large"
+max_tokens = 4096
+batch_size = 12
+document_prefix = "search_document: "
+query_prefix = "search_query: "
[indexers.text_files]
path = "/data/archive"
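The new document_prefix/query_prefix settings appear to be the instruction prefixes this embedding model family expects for asymmetric retrieval. A minimal sketch (not project code; hypothetical input strings) of how they are applied, mirroring the format!() calls added to semantic.rs further down:

// Sketch only: prefixing as configured above.
fn main() {
    let query_prefix = "search_query: ";        // [semantic] query_prefix
    let document_prefix = "search_document: ";  // [semantic] document_prefix
    let q = "boat maintenance";                 // hypothetical user query
    let query_prefixed = format!("{}{}", query_prefix, q);
    let doc_prefixed = format!("{}{}", document_prefix, "chunk text to be indexed");
    assert_eq!(query_prefixed, "search_query: boat maintenance");
    println!("{query_prefixed}\n{doc_prefixed}");
}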

View File

@@ -14,13 +14,13 @@ from sentence_transformers import SentenceTransformer
from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
device = torch.device("cuda:0")
-model_name = "./snowflake-arctic-embed-l"
+model_name = "./modernbert-embed-large"
model = SentenceTransformer(model_name).half().to(device)
model.eval()
print("model loaded")
-MODELNAME = "sbert-snowflake-arctic-embed-l"
-BS = 256
+MODELNAME = "modernbert-embed-large"
+BS = 64
InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "callback"])
@@ -38,7 +38,7 @@ def do_inference(params: InferenceParameters):
items_ctr.labels(MODELNAME, "text").inc(batch_size)
with inference_time_hist.labels(MODELNAME, batch_size).time():
features = model(text)["sentence_embedding"]
-features /= features.norm(dim=-1, keepdim=True)
+features /= features.norm(dim=-1, keepdim=True) + 1e-5
features = features.cpu().numpy()
batch_count_ctr.labels(MODELNAME).inc()
callback(True, features)

View File

@@ -3,17 +3,17 @@ use std::ffi::OsStr;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{Context, Result};
-use async_walkdir::{WalkDir, DirEntry};
use compact_str::{CompactString, ToCompactString};
use epub::doc::EpubDoc;
use futures::TryStreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
-use crate::util::{hash_str, parse_html, parse_date, systemtime_to_utc, CONFIG};
+use crate::util::{hash_str, parse_html, parse_date, systemtime_to_utc, CONFIG, async_walkdir};
use crate::indexer::{delete_nonexistent_files, Ctx, Indexer, TableSpec, ColumnSpec};
use chrono::prelude::*;
use std::str::FromStr;
use tracing::instrument;
+use walkdir::DirEntry;
#[derive(Serialize, Deserialize, Clone)]
struct Config {
@@ -91,14 +91,15 @@ async fn handle_epub(relpath: CompactString, ctx: Arc<Ctx>, entry: DirEntry, exi
let epub = entry.path();
let mut conn = ctx.pool.get().await?;
-let metadata = entry.metadata().await?;
+let metadata = tokio::fs::metadata(epub).await?;
let row = conn.query_opt("SELECT timestamp FROM books WHERE id = $1", &[&hash_str(&relpath)]).await?;
existing_files.write().await.insert(relpath.clone());
let timestamp: DateTime<Utc> = row.map(|r| r.get(0)).unwrap_or(DateTime::<Utc>::MIN_UTC);
let modtime = systemtime_to_utc(metadata.modified()?)?;
if modtime > timestamp {
-let parse_result = match tokio::task::spawn_blocking(move || parse_epub(&epub)).await? {
+let pathbuf = epub.to_path_buf();
+let parse_result = match tokio::task::spawn_blocking(move || parse_epub(&pathbuf)).await? {
Ok(x) => x,
Err(e) => {
tracing::warn!("Failed parse for {}: {}", relpath, e);
@@ -232,7 +233,7 @@ CREATE TABLE chapters (
async fn run(&self, ctx: Arc<Ctx>) -> Result<()> {
let existing_files = Arc::new(RwLock::new(HashSet::new()));
-let entries = WalkDir::new(&self.config.path);
+let entries = async_walkdir(self.config.path.clone().into(), false, |_| true);
let base_path = &self.config.path;
entries.map_err(|e| anyhow::Error::from(e)).try_for_each_concurrent(Some(CONFIG.concurrency), |entry|
@@ -247,17 +248,17 @@ CREATE TABLE chapters (
return Result::Ok(());
};
let ext = real_path.extension().and_then(OsStr::to_str);
-if !entry.file_type().await?.is_file() || "epub" != ext.unwrap_or_default() {
+if !entry.file_type().is_file() || "epub" != ext.unwrap_or_default() {
return Ok(());
}
let conn = ctx.pool.get().await?;
existing_files.write().await.insert(CompactString::from(path));
-let metadata = entry.metadata().await?;
+let metadata = tokio::fs::metadata(real_path).await?;
let row = conn.query_opt("SELECT timestamp FROM books WHERE id = $1", &[&hash_str(path)]).await?;
let timestamp: DateTime<Utc> = row.map(|r| r.get(0)).unwrap_or(DateTime::<Utc>::MIN_UTC);
let modtime = systemtime_to_utc(metadata.modified()?)?;
if modtime > timestamp {
-let relpath = entry.path().as_path().strip_prefix(&base_path)?.as_os_str().to_str().context("invalid path")?.to_compact_string();
+let relpath = entry.path().strip_prefix(&base_path)?.as_os_str().to_str().context("invalid path")?.to_compact_string();
handle_epub(relpath.clone(), ctx, entry, existing_files).await
} else {
Ok(())

View File

@@ -9,9 +9,8 @@ use futures::{StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tokio::sync::RwLock;
-use crate::util::{hash_str, parse_date, parse_html, systemtime_to_utc, urlencode, CONFIG};
+use crate::util::{hash_str, parse_date, parse_html, systemtime_to_utc, urlencode, CONFIG, async_walkdir};
use crate::indexer::{Ctx, Indexer, TableSpec, delete_nonexistent_files, ColumnSpec};
-use async_walkdir::{Filtering, WalkDir};
use chrono::prelude::*;
use regex::{RegexSet, Regex};
use tracing::instrument;
@@ -67,6 +66,7 @@ struct MediaParse {
struct Chapter {
start_time: String,
end_time: String,
+#[serde(default)]
tags: HashMap<String, String>
}
@@ -79,6 +79,7 @@ struct Disposition {
struct Stream {
index: usize,
codec_name: Option<String>,
+#[serde(default)]
duration: Option<String>,
codec_type: String,
#[serde(default)]
@@ -89,7 +90,7 @@ struct Stream {
#[derive(Deserialize, Debug)]
struct Format {
-duration: String,
+duration: Option<String>,
#[serde(default)]
tags: HashMap<String, String>
}
@@ -319,7 +320,7 @@ async fn parse_media(path: &PathBuf, ignore: Arc<RegexSet>) -> Result<MediaParse
}
}
-result.duration = f32::from_str(&probe.format.duration)?;
+result.duration = probe.format.duration.map(|x| f32::from_str(&x)).transpose()?.unwrap_or(0.0);
let mut best_subtitle_track = (0, i8::MIN);
@@ -470,31 +471,31 @@ CREATE TABLE media_files (
}
async fn run(&self, ctx: Arc<Ctx>) -> Result<()> {
-let entries = WalkDir::new(&self.config.path);
let ignore = Arc::new(self.ignore_files.clone());
let base_path = Arc::new(self.config.path.clone());
let base_path_ = base_path.clone();
let ignore_metadata = self.ignore_metadata.clone();
-let existing_files = Arc::new(RwLock::new(HashSet::new()));
-let existing_files_ = existing_files.clone();
-let ctx_ = ctx.clone();
-entries
-.filter(move |entry| {
+let entries = async_walkdir(self.config.path.clone().into(), true, move |entry| {
let ignore = ignore.clone();
let base_path = base_path.clone();
let path = entry.path();
tracing::trace!("filtering {:?}", path);
if let Some(path) = path.strip_prefix(&*base_path).ok().and_then(|x| x.to_str()) {
if ignore.is_match(path) {
-return std::future::ready(Filtering::IgnoreDir);
+return false;
}
} else {
-return std::future::ready(Filtering::IgnoreDir);
+return false;
}
-std::future::ready(Filtering::Continue)
-})
+true
+});
+let existing_files = Arc::new(RwLock::new(HashSet::new()));
+let existing_files_ = existing_files.clone();
+let ctx_ = ctx.clone();
+entries
.map_err(|e| anyhow::Error::from(e))
.filter(|r| {
// ignore permissions errors because things apparently break otherwise
@@ -517,18 +518,18 @@ CREATE TABLE media_files (
} else {
return Result::Ok(());
};
-if !entry.file_type().await?.is_file() {
+if !entry.file_type().is_file() {
return Ok(());
}
let mut conn = ctx.pool.get().await?;
existing_files.write().await.insert(CompactString::from(path));
-let metadata = entry.metadata().await?;
+let metadata = tokio::fs::metadata(real_path).await?;
let row = conn.query_opt("SELECT timestamp FROM media_files WHERE id = $1", &[&hash_str(path)]).await?;
let timestamp: DateTime<Utc> = row.map(|r| r.get(0)).unwrap_or(DateTime::<Utc>::MIN_UTC);
let modtime = systemtime_to_utc(metadata.modified()?)?;
tracing::trace!("timestamp {:?}", timestamp);
if modtime > timestamp {
-match parse_media(&real_path, ignore_metadata).await {
+match parse_media(&real_path.to_path_buf(), ignore_metadata).await {
Ok(x) => {
let tx = conn.transaction().await?;
tx.execute("DELETE FROM media_files WHERE id = $1", &[&hash_str(path)]).await?;

View File

@@ -7,9 +7,8 @@ use futures::TryStreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::instrument;
-use crate::util::{hash_str, parse_html, parse_pdf, systemtime_to_utc, urlencode, CONFIG};
+use crate::util::{hash_str, parse_html, parse_pdf, systemtime_to_utc, urlencode, CONFIG, async_walkdir};
use crate::indexer::{Ctx, Indexer, TableSpec, delete_nonexistent_files, ColumnSpec};
-use async_walkdir::WalkDir;
use chrono::prelude::*;
use regex::RegexSet;
@@ -85,7 +84,7 @@ CREATE TABLE text_files (
}
async fn run(&self, ctx: Arc<Ctx>) -> Result<()> {
-let entries = WalkDir::new(&self.config.path); // TODO
+let entries = async_walkdir(self.config.path.clone().into(), false, |_| true);
let ignore = &self.ignore;
let base_path = &self.config.path;
@@ -121,7 +120,7 @@ impl TextFilesIndexer {
}
#[instrument(skip(ctx, ignore, existing_files, base_path))]
-async fn process_file(entry: async_walkdir::DirEntry, ctx: Arc<Ctx>, ignore: &RegexSet, existing_files: Arc<RwLock<HashSet<CompactString>>>, base_path: &String) -> Result<()> {
+async fn process_file(entry: walkdir::DirEntry, ctx: Arc<Ctx>, ignore: &RegexSet, existing_files: Arc<RwLock<HashSet<CompactString>>>, base_path: &String) -> Result<()> {
let real_path = entry.path();
let path = if let Some(path) = real_path.strip_prefix(base_path)?.to_str() {
path
@@ -129,17 +128,17 @@ impl TextFilesIndexer {
return Result::Ok(());
};
let ext = real_path.extension().and_then(OsStr::to_str);
-if ignore.is_match(path) || !entry.file_type().await?.is_file() || !VALID_EXTENSIONS.contains(ext.unwrap_or_default()) {
+if ignore.is_match(path) || !entry.file_type().is_file() || !VALID_EXTENSIONS.contains(ext.unwrap_or_default()) {
return Ok(());
}
let mut conn = ctx.pool.get().await?;
existing_files.write().await.insert(CompactString::from(path));
-let metadata = entry.metadata().await?;
+let metadata = tokio::fs::metadata(real_path).await?;
let row = conn.query_opt("SELECT timestamp FROM text_files WHERE id = $1", &[&hash_str(path)]).await?;
let timestamp: DateTime<Utc> = row.map(|r| r.get(0)).unwrap_or(DateTime::<Utc>::MIN_UTC);
let modtime = systemtime_to_utc(metadata.modified()?)?;
if modtime > timestamp {
-let parse = TextFilesIndexer::read_file(&real_path, ext).await;
+let parse = TextFilesIndexer::read_file(&real_path.to_path_buf(), ext).await;
match parse {
Ok(None) => (),
Ok(Some((content, title))) => {

View File

@@ -2,12 +2,11 @@ use std::collections::HashSet;
use std::{collections::HashMap, path::PathBuf};
use std::sync::Arc;
use anyhow::{Context, Result};
-use async_walkdir::WalkDir;
use compact_str::ToCompactString;
use serde::{Deserialize, Serialize};
use tokio_stream::StreamExt;
use crate::indexer::{Ctx, Indexer, TableSpec, ColumnSpec};
-use crate::util::{hash_thing, parse_html};
+use crate::util::{hash_thing, parse_html, async_walkdir};
use chrono::prelude::*;
use tokio::{fs::File, io::{AsyncBufReadExt, BufReader}};
use mail_parser::MessageParser;
@@ -177,7 +176,7 @@ CREATE TABLE IF NOT EXISTS emails (
async fn run(&self, ctx: Arc<Ctx>) -> Result<()> {
let mut js: tokio::task::JoinSet<Result<()>> = tokio::task::JoinSet::new();
-let mut entries = WalkDir::new(&self.config.mboxes_path);
+let mut entries = async_walkdir(self.config.mboxes_path.clone().into(), false, |_| true);
let config = self.config.clone();
while let Some(entry) = entries.try_next().await? {
let path = entry.path();
@@ -189,7 +188,7 @@ CREATE TABLE IF NOT EXISTS emails (
if let None = ext {
if !self.config.ignore_mboxes.contains(mbox.as_str()) {
let ctx = ctx.clone();
-js.spawn(EmailIndexer::process_mbox(ctx.clone(), path.clone(), mbox, folder, account.clone()));
+js.spawn(EmailIndexer::process_mbox(ctx.clone(), path.to_path_buf(), mbox, folder, account.clone()));
}
}
}
@@ -237,10 +236,10 @@ impl EmailIndexer {
&mail.raw,
&account.as_str(),
&mbox.as_str(),
-&mail.from,
-&mail.from_address,
-&mail.subject,
-&body
+&mail.from.replace("\0", ""),
+&mail.from_address.replace("\0", ""),
+&mail.subject.replace("\0", ""),
+&body.replace("\0", "")
]).await?;
Ok(())
}
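The new replace("\0", "") calls are presumably needed because PostgreSQL text columns cannot store NUL bytes, which occasionally appear in mail headers and bodies. A minimal sketch (not project code) of the same sanitisation:

// Sketch only: drop NUL bytes before binding strings as Postgres text parameters.
fn strip_nul(s: &str) -> String {
    s.replace('\0', "")
}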

View File

@@ -94,7 +94,7 @@ fn page(title: &str, body: Markup) -> web::HttpResponse {
fn search_bar(ctx: &ServerState, value: &SearchQuery) -> Markup {
html! {
form.search-bar action="/search" {
-input type="search" placeholder="Query" value=(value.q) name="q";
+input type="search" placeholder="Query" value=(value.q) name="q" autofocus;
select name="src_mode" {
option selected[value.src_mode == SearchSourceMode::Mix] { "Mix" }
option selected[value.src_mode == SearchSourceMode::Titles] { "Titles" }
@@ -108,7 +108,7 @@ fn search_bar(ctx: &ServerState, value: &SearchQuery) -> Markup {
}
}
}
-input type="submit" value="Search";
+input type="submit" value="Search" autofocus;
}
}
}
@@ -651,7 +651,7 @@ async fn query_one_table(table: &TableSpec, indexer: &'static str, col: &ColumnS
let matchq_int = (*matchq * (i32::MAX as f64)) as i32;
results.push(SearchResult {
docid: doc,
-indexer: indexer,
+indexer,
table: table.name,
column: col.name,
title: title.clone(),
@@ -667,9 +667,9 @@ async fn query_one_table(table: &TableSpec, indexer: &'static str, col: &ColumnS
#[web::get("/search")]
async fn fts_page(state: web::types::State<ServerState>, query: web::types::Query<SearchQuery>) -> impl web::Responder {
let state = (*state).clone();
-let (prefixed, unprefixed) = semantic::embed_query(&query.q, state.semantic.clone()).await?;
-let prefixed = Arc::new(prefixed);
-let unprefixed = Arc::new(unprefixed);
+let (q_prefix, d_prefix) = semantic::embed_query(&query.q, state.semantic.clone()).await?;
+let q_prefix = Arc::new(q_prefix);
+let d_prefix = Arc::new(d_prefix);
let mut results = HashMap::new();
let mut set = tokio::task::JoinSet::new();
@@ -688,9 +688,9 @@ async fn fts_page(state: web::types::State<ServerState>, query: web::types::Quer
// Some columns are not like this (their spec says so) so we use the model for symmetric search.
// This does result in a different distribution of dot products.
let (embedding_choice, count) = if col.fts_short {
-(unprefixed.clone(), if query.src_mode == SearchSourceMode::Mix { 1 } else { 20 })
+(q_prefix.clone(), if query.src_mode == SearchSourceMode::Mix { 1 } else { 20 })
} else {
-(prefixed.clone(), 20)
+(q_prefix.clone(), 20)
};
set.spawn(query_one_table(table, ix.name(), col, state.clone(), count, embedding_choice));

View File

@@ -26,7 +26,9 @@ pub struct SemanticSearchConfig {
backend: String,
pub embedding_dim: u32,
max_tokens: usize,
-batch_size: usize
+batch_size: usize,
+document_prefix: String,
+query_prefix: String
}
fn convert_tokenizer_error(e: Error) -> anyhow::Error {
@@ -102,6 +104,10 @@ fn decode_fp16_buffer(buf: &[u8]) -> Vec<f16> {
.collect()
}
+fn contains_nan(buf: &[f16]) -> bool {
+buf.iter().any(|x| x.is_nan())
+}
#[instrument(skip(client))]
async fn send_batch(client: &Client, batch: Vec<&str>) -> Result<Vec<Vec<f16>>> {
let res = client.post(&CONFIG.semantic.backend)
@@ -158,10 +164,11 @@ async fn insert_fts_chunks(id: i64, chunks: Vec<(Chunk, Vec<f16>)>, table: &Tabl
}
pub async fn embed_query(q: &str, ctx: Arc<SemanticCtx>) -> Result<(HalfVector, HalfVector)> {
-let prefixed = format!("Represent this sentence for searching relevant passages: {}", q);
+let query_prefixed = format!("{}{}", CONFIG.semantic.query_prefix, q);
+let doc_prefixed = format!("{}{}", CONFIG.semantic.document_prefix, q);
let mut result = send_batch(&ctx.client, vec![
-&prefixed,
-q
+&query_prefixed,
+&doc_prefixed
]).await?.into_iter();
Ok((HalfVector::from(result.next().unwrap()), HalfVector::from(result.next().unwrap())))
}
@@ -213,7 +220,7 @@ pub async fn fts_for_indexer(i: &Box<dyn Indexer>, ctx: Arc<SemanticCtx>) -> Res
col: col.name,
start: chunk.0 as i32,
length: (chunk.1 - chunk.0) as i32,
-text: chunk.2
+text: format!("{}{}", CONFIG.semantic.document_prefix, chunk.2)
});
}
}
@@ -242,6 +249,12 @@ pub async fn fts_for_indexer(i: &Box<dyn Indexer>, ctx: Arc<SemanticCtx>) -> Res
let mut pending = pending.write().await;
for (embedding, chunk) in embs.into_iter().zip(batch.into_iter()) {
+// ugly hack
+if contains_nan(&embedding) {
+// write no entries
+tokio::task::spawn(insert_fts_chunks(chunk.id, vec![], table, ctx.clone()));
+}
let record = pending.get_mut(&chunk.id).unwrap();
record.1 -= 1;
let id = chunk.id;
@@ -261,8 +274,8 @@ pub async fn fts_for_indexer(i: &Box<dyn Indexer>, ctx: Arc<SemanticCtx>) -> Res
let make_embeddings: tokio::task::JoinHandle<Result<()>> = tokio::task::spawn(make_embeddings);
-get_inputs.await??;
make_embeddings.await??;
+get_inputs.await??;
}
}
Ok(())

View File

@@ -7,6 +7,8 @@ use serde::{Serialize, Deserialize};
use tokio_postgres::Row;
use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
use tracing::instrument;
+use futures::stream::Stream;
+use std::path::Path;
const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`');
@@ -179,3 +181,38 @@ pub fn get_column_string(row: &Row, index: usize, spec: &ColumnSpec) -> Option<S
pub fn urlencode(s: &str) -> String {
utf8_percent_encode(s, FRAGMENT).to_string()
}
+// TODO: check the joinhandle somewhere
+struct AsyncWalkdirStream {
+rx: tokio::sync::mpsc::Receiver<Result<walkdir::DirEntry>>,
+handle: tokio::task::JoinHandle<Result<()>>
+}
+impl Stream for AsyncWalkdirStream {
+type Item = Result<walkdir::DirEntry>;
+fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
+match self.rx.poll_recv(cx) {
+std::task::Poll::Ready(x) => std::task::Poll::Ready(x),
+std::task::Poll::Pending => std::task::Poll::Pending
+}
+}
+}
+impl Drop for AsyncWalkdirStream {
+fn drop(&mut self) {
+self.handle.abort();
+}
+}
+pub fn async_walkdir<F: Fn(&walkdir::DirEntry) -> bool + Send + 'static>(path: PathBuf, follow_symlinks: bool, filter: F) -> impl Stream<Item = Result<walkdir::DirEntry>> {
+let (tx, rx) = tokio::sync::mpsc::channel(128);
+let handle = tokio::task::spawn_blocking(move || {
+let walker = walkdir::WalkDir::new(path).follow_links(follow_symlinks).into_iter();
+for entry in walker.filter_entry(|e| filter(&e)) {
+tx.blocking_send(entry.map_err(|e| anyhow::Error::from(e)))?;
+}
+Ok::<(), anyhow::Error>(())
+});
+AsyncWalkdirStream { rx, handle }
+}
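For reference, a minimal usage sketch of the new async_walkdir helper (not part of the commit; hypothetical path, assuming the helper and futures::TryStreamExt are in scope as in the indexers above):

// Sketch only: consume the stream the way the epub/media/text indexers do.
use futures::TryStreamExt;

async fn walk_example() -> anyhow::Result<()> {
    // Walk /data/archive without following symlinks, keeping every entry.
    let entries = async_walkdir("/data/archive".into(), false, |_| true);
    entries
        .try_for_each_concurrent(Some(8), |entry| async move {
            if entry.file_type().is_file() {
                println!("{}", entry.path().display());
            }
            Ok(())
        })
        .await
}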