mirror of https://github.com/osmarks/maghammer.git synced 2025-09-07 13:17:55 +00:00

support modernbert & change walkdir

This commit is contained in:
osmarks
2025-04-11 21:57:01 +01:00
parent a11bc0b22d
commit a940dc9b46
11 changed files with 1143 additions and 636 deletions

Cargo.lock (generated, 1587 lines changed): file diff suppressed because it is too large.


@@ -8,7 +8,7 @@ tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_j
tokio = { version = "1", features = ["full", "tracing"]}
chrono = { version = "0.4", features = ["serde"] }
anyhow = "1"
async-walkdir = "2"
walkdir = "2"
lazy_static = "1"
compact_str = { version = "0.8.0-beta", features = ["serde"] }
seahash = "4"
@@ -17,7 +17,7 @@ serde = { version = "1", features = ["derive"] }
reqwest = "0.12"
deadpool-postgres = "0.14"
pgvector = { version = "0.3", features = ["postgres", "halfvec"] }
tokenizers = { version = "0.19", features = ["http"] }
tokenizers = { version = "0.21", features = ["http"] }
regex = "1"
futures = "0.3"
html5gum = "0.5"


@@ -1,12 +1,14 @@
database = "host=localhost port=5432 user=maghammer dbname=maghammer"
database = "host=localhost port=5432 user=maghammer dbname=maghammer-test"
concurrency = 8
[semantic]
backend = "http://100.64.0.10:1706"
embedding_dim = 1024
tokenizer = "Snowflake/snowflake-arctic-embed-l"
max_tokens = 128
batch_size = 256
tokenizer = "lightonai/modernbert-embed-large"
max_tokens = 4096
batch_size = 12
document_prefix = "search_document: "
query_prefix = "search_query: "
[indexers.text_files]
path = "/data/archive"


@@ -14,13 +14,13 @@ from sentence_transformers import SentenceTransformer
from prometheus_client import Counter, Histogram, REGISTRY, generate_latest
device = torch.device("cuda:0")
model_name = "./snowflake-arctic-embed-l"
model_name = "./modernbert-embed-large"
model = SentenceTransformer(model_name).half().to(device)
model.eval()
print("model loaded")
MODELNAME = "sbert-snowflake-arctic-embed-l"
BS = 256
MODELNAME = "modernbert-embed-large"
BS = 64
InferenceParameters = collections.namedtuple("InferenceParameters", ["text", "callback"])
@@ -38,7 +38,7 @@ def do_inference(params: InferenceParameters):
items_ctr.labels(MODELNAME, "text").inc(batch_size)
with inference_time_hist.labels(MODELNAME, batch_size).time():
features = model(text)["sentence_embedding"]
features /= features.norm(dim=-1, keepdim=True)
features /= features.norm(dim=-1, keepdim=True) + 1e-5
features = features.cpu().numpy()
batch_count_ctr.labels(MODELNAME).inc()
callback(True, features)
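The added + 1e-5 guards the normalisation against a zero (or half-precision-underflowed) norm, which would otherwise produce NaN embeddings; the contains_nan check added in semantic.rs further down is the matching downstream guard. A minimal sketch of the same idea, written in Rust rather than the server's Python, with illustrative names rather than anything from the repo:

// Sketch only: L2-normalise with a small epsilon so an all-zero vector stays
// finite instead of becoming NaN via 0.0 / 0.0.
fn normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    for x in v.iter_mut() {
        *x /= norm + 1e-5;
    }
}

fn main() {
    let mut zero = vec![0.0f32; 4];
    normalize(&mut zero);
    assert!(zero.iter().all(|x| x.is_finite()));
}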


@@ -3,17 +3,17 @@ use std::ffi::OsStr;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{Context, Result};
use async_walkdir::{WalkDir, DirEntry};
use compact_str::{CompactString, ToCompactString};
use epub::doc::EpubDoc;
use futures::TryStreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use crate::util::{hash_str, parse_html, parse_date, systemtime_to_utc, CONFIG};
use crate::util::{hash_str, parse_html, parse_date, systemtime_to_utc, CONFIG, async_walkdir};
use crate::indexer::{delete_nonexistent_files, Ctx, Indexer, TableSpec, ColumnSpec};
use chrono::prelude::*;
use std::str::FromStr;
use tracing::instrument;
use walkdir::DirEntry;
#[derive(Serialize, Deserialize, Clone)]
struct Config {
@@ -91,14 +91,15 @@ async fn handle_epub(relpath: CompactString, ctx: Arc<Ctx>, entry: DirEntry, exi
let epub = entry.path();
let mut conn = ctx.pool.get().await?;
let metadata = entry.metadata().await?;
let metadata = tokio::fs::metadata(epub).await?;
let row = conn.query_opt("SELECT timestamp FROM books WHERE id = $1", &[&hash_str(&relpath)]).await?;
existing_files.write().await.insert(relpath.clone());
let timestamp: DateTime<Utc> = row.map(|r| r.get(0)).unwrap_or(DateTime::<Utc>::MIN_UTC);
let modtime = systemtime_to_utc(metadata.modified()?)?;
if modtime > timestamp {
let parse_result = match tokio::task::spawn_blocking(move || parse_epub(&epub)).await? {
let pathbuf = epub.to_path_buf();
let parse_result = match tokio::task::spawn_blocking(move || parse_epub(&pathbuf)).await? {
Ok(x) => x,
Err(e) => {
tracing::warn!("Failed parse for {}: {}", relpath, e);
@@ -232,7 +233,7 @@ CREATE TABLE chapters (
async fn run(&self, ctx: Arc<Ctx>) -> Result<()> {
let existing_files = Arc::new(RwLock::new(HashSet::new()));
let entries = WalkDir::new(&self.config.path);
let entries = async_walkdir(self.config.path.clone().into(), false, |_| true);
let base_path = &self.config.path;
entries.map_err(|e| anyhow::Error::from(e)).try_for_each_concurrent(Some(CONFIG.concurrency), |entry|
@@ -247,17 +248,17 @@ CREATE TABLE chapters (
return Result::Ok(());
};
let ext = real_path.extension().and_then(OsStr::to_str);
if !entry.file_type().await?.is_file() || "epub" != ext.unwrap_or_default() {
if !entry.file_type().is_file() || "epub" != ext.unwrap_or_default() {
return Ok(());
}
let conn = ctx.pool.get().await?;
existing_files.write().await.insert(CompactString::from(path));
let metadata = entry.metadata().await?;
let metadata = tokio::fs::metadata(real_path).await?;
let row = conn.query_opt("SELECT timestamp FROM books WHERE id = $1", &[&hash_str(path)]).await?;
let timestamp: DateTime<Utc> = row.map(|r| r.get(0)).unwrap_or(DateTime::<Utc>::MIN_UTC);
let modtime = systemtime_to_utc(metadata.modified()?)?;
if modtime > timestamp {
let relpath = entry.path().as_path().strip_prefix(&base_path)?.as_os_str().to_str().context("invalid path")?.to_compact_string();
let relpath = entry.path().strip_prefix(&base_path)?.as_os_str().to_str().context("invalid path")?.to_compact_string();
handle_epub(relpath.clone(), ctx, entry, existing_files).await
} else {
Ok(())
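Several of the .to_path_buf() and tokio::fs::metadata changes in this commit follow from swapping async_walkdir::DirEntry for the synchronous walkdir::DirEntry: path() now returns a borrowed &Path, file_type() is no longer a future, and tokio::task::spawn_blocking wants an owned 'static closure, so the path gets cloned before it is moved in. A sketch of the pattern, with a hypothetical parse_epub stub standing in for the real parser:

use std::path::{Path, PathBuf};

// Hypothetical stand-in for the real blocking EPUB parser.
fn parse_epub(path: &Path) -> anyhow::Result<String> {
    Ok(format!("parsed {}", path.display()))
}

async fn handle(entry: walkdir::DirEntry) -> anyhow::Result<()> {
    // walkdir's DirEntry is synchronous: file_type() is a plain call, path() borrows.
    if !entry.file_type().is_file() {
        return Ok(());
    }
    let path: PathBuf = entry.path().to_path_buf();
    // Metadata goes through tokio so the executor thread is not blocked on disk I/O.
    let metadata = tokio::fs::metadata(&path).await?;
    let _modtime = metadata.modified()?; // the indexers compare this with the stored timestamp
    // spawn_blocking needs an owned 'static closure, hence the PathBuf clone above.
    let parsed = tokio::task::spawn_blocking(move || parse_epub(&path)).await??;
    println!("{parsed}");
    Ok(())
}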


@@ -9,9 +9,8 @@ use futures::{StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tokio::sync::RwLock;
use crate::util::{hash_str, parse_date, parse_html, systemtime_to_utc, urlencode, CONFIG};
use crate::util::{hash_str, parse_date, parse_html, systemtime_to_utc, urlencode, CONFIG, async_walkdir};
use crate::indexer::{Ctx, Indexer, TableSpec, delete_nonexistent_files, ColumnSpec};
use async_walkdir::{Filtering, WalkDir};
use chrono::prelude::*;
use regex::{RegexSet, Regex};
use tracing::instrument;
@@ -67,6 +66,7 @@ struct MediaParse {
struct Chapter {
start_time: String,
end_time: String,
#[serde(default)]
tags: HashMap<String, String>
}
@@ -79,6 +79,7 @@ struct Disposition {
struct Stream {
index: usize,
codec_name: Option<String>,
#[serde(default)]
duration: Option<String>,
codec_type: String,
#[serde(default)]
@@ -89,7 +90,7 @@ struct Stream {
#[derive(Deserialize, Debug)]
struct Format {
duration: String,
duration: Option<String>,
#[serde(default)]
tags: HashMap<String, String>
}
@@ -319,7 +320,7 @@ async fn parse_media(path: &PathBuf, ignore: Arc<RegexSet>) -> Result<MediaParse
}
}
result.duration = f32::from_str(&probe.format.duration)?;
result.duration = probe.format.duration.map(|x| f32::from_str(&x)).transpose()?.unwrap_or(0.0);
let mut best_subtitle_track = (0, i8::MIN);
@@ -470,31 +471,31 @@ CREATE TABLE media_files (
}
async fn run(&self, ctx: Arc<Ctx>) -> Result<()> {
let entries = WalkDir::new(&self.config.path);
let ignore = Arc::new(self.ignore_files.clone());
let base_path = Arc::new(self.config.path.clone());
let base_path_ = base_path.clone();
let ignore_metadata = self.ignore_metadata.clone();
let existing_files = Arc::new(RwLock::new(HashSet::new()));
let existing_files_ = existing_files.clone();
let ctx_ = ctx.clone();
entries
.filter(move |entry| {
let entries = async_walkdir(self.config.path.clone().into(), true, move |entry| {
let ignore = ignore.clone();
let base_path = base_path.clone();
let path = entry.path();
tracing::trace!("filtering {:?}", path);
if let Some(path) = path.strip_prefix(&*base_path).ok().and_then(|x| x.to_str()) {
if ignore.is_match(path) {
return std::future::ready(Filtering::IgnoreDir);
return false;
}
} else {
return std::future::ready(Filtering::IgnoreDir);
return false;
}
std::future::ready(Filtering::Continue)
})
true
});
let existing_files = Arc::new(RwLock::new(HashSet::new()));
let existing_files_ = existing_files.clone();
let ctx_ = ctx.clone();
entries
.map_err(|e| anyhow::Error::from(e))
.filter(|r| {
// ignore permissions errors because things apparently break otherwise
@@ -517,18 +518,18 @@ CREATE TABLE media_files (
} else {
return Result::Ok(());
};
if !entry.file_type().await?.is_file() {
if !entry.file_type().is_file() {
return Ok(());
}
let mut conn = ctx.pool.get().await?;
existing_files.write().await.insert(CompactString::from(path));
let metadata = entry.metadata().await?;
let metadata = tokio::fs::metadata(real_path).await?;
let row = conn.query_opt("SELECT timestamp FROM media_files WHERE id = $1", &[&hash_str(path)]).await?;
let timestamp: DateTime<Utc> = row.map(|r| r.get(0)).unwrap_or(DateTime::<Utc>::MIN_UTC);
let modtime = systemtime_to_utc(metadata.modified()?)?;
tracing::trace!("timestamp {:?}", timestamp);
if modtime > timestamp {
match parse_media(&real_path, ignore_metadata).await {
match parse_media(&real_path.to_path_buf(), ignore_metadata).await {
Ok(x) => {
let tx = conn.transaction().await?;
tx.execute("DELETE FROM media_files WHERE id = $1", &[&hash_str(path)]).await?;


@@ -7,9 +7,8 @@ use futures::TryStreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::instrument;
use crate::util::{hash_str, parse_html, parse_pdf, systemtime_to_utc, urlencode, CONFIG};
use crate::util::{hash_str, parse_html, parse_pdf, systemtime_to_utc, urlencode, CONFIG, async_walkdir};
use crate::indexer::{Ctx, Indexer, TableSpec, delete_nonexistent_files, ColumnSpec};
use async_walkdir::WalkDir;
use chrono::prelude::*;
use regex::RegexSet;
@@ -85,7 +84,7 @@ CREATE TABLE text_files (
}
async fn run(&self, ctx: Arc<Ctx>) -> Result<()> {
let entries = WalkDir::new(&self.config.path); // TODO
let entries = async_walkdir(self.config.path.clone().into(), false, |_| true);
let ignore = &self.ignore;
let base_path = &self.config.path;
@@ -121,7 +120,7 @@ impl TextFilesIndexer {
}
#[instrument(skip(ctx, ignore, existing_files, base_path))]
async fn process_file(entry: async_walkdir::DirEntry, ctx: Arc<Ctx>, ignore: &RegexSet, existing_files: Arc<RwLock<HashSet<CompactString>>>, base_path: &String) -> Result<()> {
async fn process_file(entry: walkdir::DirEntry, ctx: Arc<Ctx>, ignore: &RegexSet, existing_files: Arc<RwLock<HashSet<CompactString>>>, base_path: &String) -> Result<()> {
let real_path = entry.path();
let path = if let Some(path) = real_path.strip_prefix(base_path)?.to_str() {
path
@@ -129,17 +128,17 @@ impl TextFilesIndexer {
return Result::Ok(());
};
let ext = real_path.extension().and_then(OsStr::to_str);
if ignore.is_match(path) || !entry.file_type().await?.is_file() || !VALID_EXTENSIONS.contains(ext.unwrap_or_default()) {
if ignore.is_match(path) || !entry.file_type().is_file() || !VALID_EXTENSIONS.contains(ext.unwrap_or_default()) {
return Ok(());
}
let mut conn = ctx.pool.get().await?;
existing_files.write().await.insert(CompactString::from(path));
let metadata = entry.metadata().await?;
let metadata = tokio::fs::metadata(real_path).await?;
let row = conn.query_opt("SELECT timestamp FROM text_files WHERE id = $1", &[&hash_str(path)]).await?;
let timestamp: DateTime<Utc> = row.map(|r| r.get(0)).unwrap_or(DateTime::<Utc>::MIN_UTC);
let modtime = systemtime_to_utc(metadata.modified()?)?;
if modtime > timestamp {
let parse = TextFilesIndexer::read_file(&real_path, ext).await;
let parse = TextFilesIndexer::read_file(&real_path.to_path_buf(), ext).await;
match parse {
Ok(None) => (),
Ok(Some((content, title))) => {


@@ -2,12 +2,11 @@ use std::collections::HashSet;
use std::{collections::HashMap, path::PathBuf};
use std::sync::Arc;
use anyhow::{Context, Result};
use async_walkdir::WalkDir;
use compact_str::ToCompactString;
use serde::{Deserialize, Serialize};
use tokio_stream::StreamExt;
use crate::indexer::{Ctx, Indexer, TableSpec, ColumnSpec};
use crate::util::{hash_thing, parse_html};
use crate::util::{hash_thing, parse_html, async_walkdir};
use chrono::prelude::*;
use tokio::{fs::File, io::{AsyncBufReadExt, BufReader}};
use mail_parser::MessageParser;
@@ -177,7 +176,7 @@ CREATE TABLE IF NOT EXISTS emails (
async fn run(&self, ctx: Arc<Ctx>) -> Result<()> {
let mut js: tokio::task::JoinSet<Result<()>> = tokio::task::JoinSet::new();
let mut entries = WalkDir::new(&self.config.mboxes_path);
let mut entries = async_walkdir(self.config.mboxes_path.clone().into(), false, |_| true);
let config = self.config.clone();
while let Some(entry) = entries.try_next().await? {
let path = entry.path();
@@ -189,7 +188,7 @@ CREATE TABLE IF NOT EXISTS emails (
if let None = ext {
if !self.config.ignore_mboxes.contains(mbox.as_str()) {
let ctx = ctx.clone();
js.spawn(EmailIndexer::process_mbox(ctx.clone(), path.clone(), mbox, folder, account.clone()));
js.spawn(EmailIndexer::process_mbox(ctx.clone(), path.to_path_buf(), mbox, folder, account.clone()));
}
}
}
@@ -237,10 +236,10 @@ impl EmailIndexer {
&mail.raw,
&account.as_str(),
&mbox.as_str(),
&mail.from,
&mail.from_address,
&mail.subject,
&body
&mail.from.replace("\0", ""),
&mail.from_address.replace("\0", ""),
&mail.subject.replace("\0", ""),
&body.replace("\0", "")
]).await?;
Ok(())
}


@@ -94,7 +94,7 @@ fn page(title: &str, body: Markup) -> web::HttpResponse {
fn search_bar(ctx: &ServerState, value: &SearchQuery) -> Markup {
html! {
form.search-bar action="/search" {
input type="search" placeholder="Query" value=(value.q) name="q";
input type="search" placeholder="Query" value=(value.q) name="q" autofocus;
select name="src_mode" {
option selected[value.src_mode == SearchSourceMode::Mix] { "Mix" }
option selected[value.src_mode == SearchSourceMode::Titles] { "Titles" }
@@ -108,7 +108,7 @@ fn search_bar(ctx: &ServerState, value: &SearchQuery) -> Markup {
}
}
}
input type="submit" value="Search";
input type="submit" value="Search" autofocus;
}
}
}
@@ -651,7 +651,7 @@ async fn query_one_table(table: &TableSpec, indexer: &'static str, col: &ColumnS
let matchq_int = (*matchq * (i32::MAX as f64)) as i32;
results.push(SearchResult {
docid: doc,
indexer: indexer,
indexer,
table: table.name,
column: col.name,
title: title.clone(),
@@ -667,9 +667,9 @@ async fn query_one_table(table: &TableSpec, indexer: &'static str, col: &ColumnS
#[web::get("/search")]
async fn fts_page(state: web::types::State<ServerState>, query: web::types::Query<SearchQuery>) -> impl web::Responder {
let state = (*state).clone();
let (prefixed, unprefixed) = semantic::embed_query(&query.q, state.semantic.clone()).await?;
let prefixed = Arc::new(prefixed);
let unprefixed = Arc::new(unprefixed);
let (q_prefix, d_prefix) = semantic::embed_query(&query.q, state.semantic.clone()).await?;
let q_prefix = Arc::new(q_prefix);
let d_prefix = Arc::new(d_prefix);
let mut results = HashMap::new();
let mut set = tokio::task::JoinSet::new();
@@ -688,9 +688,9 @@ async fn fts_page(state: web::types::State<ServerState>, query: web::types::Quer
// Some columns are not like this (their spec says so) so we use the model for symmetric search.
// This does result in a different distribution of dot products.
let (embedding_choice, count) = if col.fts_short {
(unprefixed.clone(), if query.src_mode == SearchSourceMode::Mix { 1 } else { 20 })
(q_prefix.clone(), if query.src_mode == SearchSourceMode::Mix { 1 } else { 20 })
} else {
(prefixed.clone(), 20)
(q_prefix.clone(), 20)
};
set.spawn(query_one_table(table, ix.name(), col, state.clone(), count, embedding_choice));
}


@@ -26,7 +26,9 @@ pub struct SemanticSearchConfig {
backend: String,
pub embedding_dim: u32,
max_tokens: usize,
batch_size: usize
batch_size: usize,
document_prefix: String,
query_prefix: String
}
fn convert_tokenizer_error(e: Error) -> anyhow::Error {
@@ -102,6 +104,10 @@ fn decode_fp16_buffer(buf: &[u8]) -> Vec<f16> {
.collect()
}
fn contains_nan(buf: &[f16]) -> bool {
buf.iter().any(|x| x.is_nan())
}
#[instrument(skip(client))]
async fn send_batch(client: &Client, batch: Vec<&str>) -> Result<Vec<Vec<f16>>> {
let res = client.post(&CONFIG.semantic.backend)
@@ -158,10 +164,11 @@ async fn insert_fts_chunks(id: i64, chunks: Vec<(Chunk, Vec<f16>)>, table: &Tabl
}
pub async fn embed_query(q: &str, ctx: Arc<SemanticCtx>) -> Result<(HalfVector, HalfVector)> {
let prefixed = format!("Represent this sentence for searching relevant passages: {}", q);
let query_prefixed = format!("{}{}", CONFIG.semantic.query_prefix, q);
let doc_prefixed = format!("{}{}", CONFIG.semantic.document_prefix, q);
let mut result = send_batch(&ctx.client, vec![
&prefixed,
q
&query_prefixed,
&doc_prefixed
]).await?.into_iter();
Ok((HalfVector::from(result.next().unwrap()), HalfVector::from(result.next().unwrap())))
}
@@ -213,7 +220,7 @@ pub async fn fts_for_indexer(i: &Box<dyn Indexer>, ctx: Arc<SemanticCtx>) -> Res
col: col.name,
start: chunk.0 as i32,
length: (chunk.1 - chunk.0) as i32,
text: chunk.2
text: format!("{}{}", CONFIG.semantic.document_prefix, chunk.2)
});
}
}
@@ -242,6 +249,12 @@ pub async fn fts_for_indexer(i: &Box<dyn Indexer>, ctx: Arc<SemanticCtx>) -> Res
let mut pending = pending.write().await;
for (embedding, chunk) in embs.into_iter().zip(batch.into_iter()) {
// ugly hack
if contains_nan(&embedding) {
// write no entries
tokio::task::spawn(insert_fts_chunks(chunk.id, vec![], table, ctx.clone()));
}
let record = pending.get_mut(&chunk.id).unwrap();
record.1 -= 1;
let id = chunk.id;
@@ -261,8 +274,8 @@ pub async fn fts_for_indexer(i: &Box<dyn Indexer>, ctx: Arc<SemanticCtx>) -> Res
let make_embeddings: tokio::task::JoinHandle<Result<()>> = tokio::task::spawn(make_embeddings);
get_inputs.await??;
make_embeddings.await??;
get_inputs.await??;
}
}
Ok(())


@@ -7,6 +7,8 @@ use serde::{Serialize, Deserialize};
use tokio_postgres::Row;
use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
use tracing::instrument;
use futures::stream::Stream;
use std::path::Path;
const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`');
@@ -179,3 +181,38 @@ pub fn get_column_string(row: &Row, index: usize, spec: &ColumnSpec) -> Option<S
pub fn urlencode(s: &str) -> String {
utf8_percent_encode(s, FRAGMENT).to_string()
}
// TODO: check the joinhandle somewhere
struct AsyncWalkdirStream {
rx: tokio::sync::mpsc::Receiver<Result<walkdir::DirEntry>>,
handle: tokio::task::JoinHandle<Result<()>>
}
impl Stream for AsyncWalkdirStream {
type Item = Result<walkdir::DirEntry>;
fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
match self.rx.poll_recv(cx) {
std::task::Poll::Ready(x) => std::task::Poll::Ready(x),
std::task::Poll::Pending => std::task::Poll::Pending
}
}
}
impl Drop for AsyncWalkdirStream {
fn drop(&mut self) {
self.handle.abort();
}
}
pub fn async_walkdir<F: Fn(&walkdir::DirEntry) -> bool + Send + 'static>(path: PathBuf, follow_symlinks: bool, filter: F) -> impl Stream<Item = Result<walkdir::DirEntry>> {
let (tx, rx) = tokio::sync::mpsc::channel(128);
let handle = tokio::task::spawn_blocking(move || {
let walker = walkdir::WalkDir::new(path).follow_links(follow_symlinks).into_iter();
for entry in walker.filter_entry(|e| filter(&e)) {
tx.blocking_send(entry.map_err(|e| anyhow::Error::from(e)))?;
}
Ok::<(), anyhow::Error>(())
});
AsyncWalkdirStream { rx, handle }
}
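For reference, the indexers above consume this helper as an ordinary futures stream (try_for_each_concurrent in the file indexers, try_next in the email indexer). A minimal usage sketch along the same lines; the /data/archive root matches the example config and the dotfile filter is purely illustrative:

use futures::TryStreamExt;
use std::path::PathBuf;

async fn walk_example() -> anyhow::Result<()> {
    // Prune dotfiles (and dot-directories) at walk time via the filter_entry callback.
    let mut entries = async_walkdir(
        PathBuf::from("/data/archive"),
        false, // do not follow symlinks
        |e| !e.file_name().to_string_lossy().starts_with('.'),
    );
    while let Some(entry) = entries.try_next().await? {
        if entry.file_type().is_file() {
            println!("{}", entry.path().display());
        }
    }
    Ok(())
}

Dropping the stream aborts the background task via the Drop impl, and once the receiver side is gone blocking_send returns an error, so the spawned walk ends early rather than leaking.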