1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2026-04-28 20:51:22 +00:00

Rewrite entire application (well, backend) in Rust and also Go

I decided I wanted to integrate the experimental OCR thing better, so I rewrote in Go and also integrated the thumbnailer.
However, Go is a bad language and I only used it out of spite.
It turned out to have a very hard-to-fix memory leak due to some unclear interaction between libvips and both sets of bindings I tried, so I had Claude-3 transpile it to Rust then spent a while fixing the several mistakes it made and making tweaks.
The new Rust version works, although I need to actually do something with the OCR data and make the index queryable concurrently.
This commit is contained in:
2024-05-21 00:09:04 +01:00
parent fa863c2075
commit 7cb42e028f
24 changed files with 7192 additions and 51 deletions

892
src/main.rs Normal file
View File

@@ -0,0 +1,892 @@
use std::{collections::HashMap, io::Cursor};
use std::path::Path;
use std::sync::Arc;
use anyhow::{Result, Context};
use axum::body::Body;
use axum::response::Response;
use axum::{
extract::Json,
response::IntoResponse,
routing::{get, post},
Router,
http::StatusCode
};
use image::{imageops::FilterType, io::Reader as ImageReader, DynamicImage, ImageFormat};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sqlx::{sqlite::SqliteConnectOptions, SqlitePool};
use tokio::sync::{broadcast, mpsc};
use tokio::task::JoinHandle;
use walkdir::WalkDir;
use base64::prelude::*;
use faiss::Index;
use futures_util::stream::{StreamExt, TryStreamExt};
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::CorsLayer;
mod ocr;
use crate::ocr::scan_image;
/// Serde default for `Config::ocr_concurrency`: 50 concurrent OCR requests.
fn function_which_returns_50() -> usize {
    50
}
/// Application configuration, deserialized from the JSON file passed as the
/// first command-line argument.
#[derive(Debug, Deserialize, Clone)]
struct Config {
    /// Base URL of the CLIP inference server.
    clip_server: String,
    /// Path of the SQLite database file (created if missing).
    db_path: String,
    /// TCP port the HTTP query server listens on.
    port: u16,
    /// Root directory of the media files to index.
    files: String,
    /// Run OCR over images (off by default).
    #[serde(default)]
    enable_ocr: bool,
    /// Directory generated thumbnails are written to.
    #[serde(default)]
    thumbs_path: String,
    /// Generate thumbnails (off by default).
    #[serde(default)]
    enable_thumbs: bool,
    /// Maximum in-flight OCR requests; defaults to 50.
    #[serde(default="function_which_returns_50")]
    ocr_concurrency: usize,
    /// If true, run one ingest pass and exit instead of serving queries.
    #[serde(default)]
    no_run_server: bool
}
/// In-memory search index: a FAISS vector index plus parallel per-entry
/// metadata. Entry `i` of `filenames`/`format_codes` corresponds to vector `i`.
#[derive(Debug)]
struct IIndex {
    /// FAISS index holding one embedding per indexed file.
    vectors: faiss::index::IndexImpl,
    /// Filename for each vector, in insertion order.
    filenames: Vec<String>,
    /// Per-file bitmask of available thumbnail formats; bit `i` refers to `format_names[i]`.
    format_codes: Vec<u64>,
    /// Thumbnail format names referenced by `format_codes`.
    format_names: Vec<String>,
}
/// SQLite schema, executed idempotently at startup.
///
/// `files` holds one row per media file plus its derived data (embedding,
/// OCR text/segments, thumbnail format list) and the microsecond timestamps
/// at which each was computed. `ocr_fts` is an FTS5 full-text index over the
/// OCR text, kept in sync by triggers.
///
/// Fixes over the previous revision:
/// - the AFTER UPDATE trigger was also named `ocr_fts_del`, so with
///   `IF NOT EXISTS` it was silently never created (the DELETE trigger of the
///   same name already existed); it is now `ocr_fts_upd`.
/// - the update trigger re-inserted into a nonexistent `text` column; the
///   FTS table's column is `ocr`.
const SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS files (
filename TEXT NOT NULL PRIMARY KEY,
embedding_time INTEGER,
ocr_time INTEGER,
thumbnail_time INTEGER,
embedding BLOB,
ocr TEXT,
raw_ocr_segments BLOB,
thumbnails BLOB
);
CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
filename,
ocr,
tokenize='unicode61 remove_diacritics 2',
content='ocr'
);
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_del AFTER DELETE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
END;
CREATE TRIGGER IF NOT EXISTS ocr_fts_upd AFTER UPDATE ON files BEGIN
INSERT INTO ocr_fts (ocr_fts, rowid, filename, ocr) VALUES ('delete', old.rowid, old.filename, COALESCE(old.ocr, ''));
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
"#;
/// One row of the `files` table. `None` in a derived-data column means that
/// datum has not been computed yet (or was invalidated by a newer mtime).
#[derive(Debug, sqlx::FromRow, Clone, Default)]
struct FileRecord {
    /// Path relative to `Config::files`; primary key.
    filename: String,
    /// Microsecond timestamps of when each derived datum was last computed.
    embedding_time: Option<i64>,
    ocr_time: Option<i64>,
    thumbnail_time: Option<i64>,
    /// CLIP embedding as packed little-endian fp16 bytes.
    embedding: Option<Vec<u8>>,
    // Plain-text OCR output; stored for FTS but not otherwise used yet.
    ocr: Option<String>,
    /// msgpack-encoded `Vec<Segment>` from the OCR scan.
    raw_ocr_segments: Option<Vec<u8>>,
    /// msgpack-encoded list of generated thumbnail format names.
    thumbnails: Option<Vec<u8>>,
}
/// Parameters reported by the CLIP inference server's `/config` endpoint.
#[derive(Debug, Deserialize, Clone)]
struct InferenceServerConfig {
    /// Preferred batch size for embedding requests.
    batch: usize,
    /// (width, height) images should be resized to before embedding.
    image_size: (u32, u32),
    /// Dimensionality of the returned embedding vectors.
    embedding_size: usize,
}
/// POST msgpack-encoded `data` to the CLIP inference server at `path`
/// (relative to `config.clip_server`) and decode the msgpack response.
async fn query_clip_server<I, O>(
    client: &Client,
    config: &Config,
    path: &str,
    data: I,
) -> Result<O> where I: Serialize, O: serde::de::DeserializeOwned,
{
    let url = format!("{}{}", config.clip_server, path);
    let payload = rmp_serde::to_vec_named(&data)?;
    let response = client
        .post(&url)
        .header("Content-Type", "application/msgpack")
        .body(payload)
        .send()
        .await?;
    let bytes = response.bytes().await?;
    Ok(rmp_serde::from_slice(&bytes)?)
}
/// A decoded image queued for thumbnailing or OCR.
#[derive(Debug)]
struct LoadedImage {
    /// Shared decoded pixels (one decode feeds several consumers).
    image: Arc<DynamicImage>,
    /// Path relative to `Config::files`.
    filename: String,
    /// On-disk size in bytes; 0 when irrelevant (OCR path).
    original_size: usize,
}
/// A PNG-encoded, pre-resized image queued for the embedding batcher.
#[derive(Debug)]
struct EmbeddingInput {
    /// PNG bytes at the inference server's expected size.
    image: Vec<u8>,
    /// Path relative to `Config::files`.
    filename: String,
}
/// Request body for the CLIP server; `untagged` so it serializes as either
/// `{"images": [...]}` or `{"text": [...]}` with no enum wrapper.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum EmbeddingRequest {
    Images { images: Vec<serde_bytes::ByteBuf> },
    Text { text: Vec<String> }
}
/// Current wall-clock time as microseconds since the Unix epoch.
fn timestamp() -> i64 {
    let now = chrono::Utc::now();
    now.timestamp_micros()
}
/// Parameters for one thumbnail output format.
#[derive(Debug, Clone)]
struct ImageFormatConfig {
    /// Width to resize to (height scales to preserve aspect ratio).
    target_width: u32,
    /// If nonzero, binary-search quality toward this file size instead of using `quality`.
    target_filesize: u32,
    /// Fixed encoder quality, used when `target_filesize` is 0.
    quality: u8,
    /// Output codec.
    format: ImageFormat,
    /// File extension written to disk (no leading dot).
    extension: String,
}
/// FNV-hash `filename` and encode the 8-byte little-endian digest as
/// URL-safe unpadded base64; used to derive short thumbnail filenames.
fn generate_filename_hash(filename: &str) -> String {
    use std::hash::{Hash, Hasher};
    let mut fnv_state = fnv::FnvHasher::default();
    filename.hash(&mut fnv_state);
    let digest = fnv_state.finish().to_le_bytes();
    BASE64_URL_SAFE_NO_PAD.encode(digest)
}
/// Build the on-disk thumbnail name: `<hash(filename)><format>.<ext>`.
fn generate_thumbnail_filename(
    filename: &str,
    format_name: &str,
    format_config: &ImageFormatConfig,
) -> String {
    let hash = generate_filename_hash(filename);
    format!("{}{}.{}", hash, format_name, format_config.extension)
}
/// Open (creating if necessary) the SQLite database at `config.db_path`
/// and apply [`SCHEMA`] idempotently.
async fn initialize_database(config: &Config) -> Result<SqlitePool> {
    let options = SqliteConnectOptions::new()
        .filename(&config.db_path)
        .create_if_missing(true);
    let pool = SqlitePool::connect_with(options).await?;
    sqlx::query(SCHEMA).execute(&pool).await?;
    Ok(pool)
}
fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> {
let mut formats = HashMap::new();
formats.insert(
"jpegl".to_string(),
ImageFormatConfig {
target_width: 800,
target_filesize: 0,
quality: 70,
format: ImageFormat::Jpeg,
extension: "jpg".to_string(),
},
);
formats.insert(
"jpegh".to_string(),
ImageFormatConfig {
target_width: 1600,
target_filesize: 0,
quality: 80,
format: ImageFormat::Jpeg,
extension: "jpg".to_string(),
},
);
formats.insert(
"jpeg256kb".to_string(),
ImageFormatConfig {
target_width: 500,
target_filesize: 256000,
quality: 0,
format: ImageFormat::Jpeg,
extension: "jpg".to_string(),
},
);
formats.insert(
"avifh".to_string(),
ImageFormatConfig {
target_width: 1600,
target_filesize: 0,
quality: 80,
format: ImageFormat::Avif,
extension: "avif".to_string(),
},
);
formats.insert(
"avifl".to_string(),
ImageFormatConfig {
target_width: 800,
target_filesize: 0,
quality: 30,
format: ImageFormat::Avif,
extension: "avif".to_string(),
},
);
formats
}
/// Resize `image` to the inference server's expected dimensions (Lanczos3,
/// aspect-preserving) and encode it as PNG, off the async runtime's threads.
async fn resize_for_embed(backend_config: Arc<InferenceServerConfig>, image: Arc<DynamicImage>) -> Result<Vec<u8>> {
    let (target_w, target_h) = backend_config.image_size;
    let png = tokio::task::spawn_blocking(move || {
        let scaled = image.resize(target_w, target_h, FilterType::Lanczos3);
        let mut buf = Vec::new();
        scaled.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)?;
        Ok::<Vec<u8>, anyhow::Error>(buf)
    })
    .await??;
    Ok(png)
}
/// One full ingest pass: walk `config.files`, (re)compute any missing or
/// stale derived data (CLIP embedding, thumbnails, OCR) for each file, and
/// delete DB rows for files that no longer exist on disk.
///
/// Internally a channel pipeline:
///   walker -> to_process -> image loader -> { to_embed, to_thumbnail, to_ocr }
/// with bounded concurrency at each stage; channel capacities bound memory
/// (decoded images are large).
async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -> Result<()> {
    let pool = initialize_database(&config).await?;
    let client = Client::new();
    let formats = image_formats(&config);
    let (to_process_tx, to_process_rx) = mpsc::channel::<FileRecord>(100);
    let (to_embed_tx, to_embed_rx) = mpsc::channel(backend.batch as usize);
    let (to_thumbnail_tx, to_thumbnail_rx) = mpsc::channel(30);
    let (to_ocr_tx, to_ocr_rx) = mpsc::channel(30);
    let cpus = num_cpus::get();
    // Image loading and preliminary resizing: decode each queued file once,
    // then fan the shared decoded image out to whichever stages need it.
    let image_loading: JoinHandle<Result<()>> = tokio::spawn({
        let config = config.clone();
        let backend = backend.clone();
        let stream = ReceiverStream::new(to_process_rx).map(Ok);
        stream.try_for_each_concurrent(Some(cpus), move |record| {
            let config = config.clone();
            let backend = backend.clone();
            let to_embed_tx = to_embed_tx.clone();
            let to_thumbnail_tx = to_thumbnail_tx.clone();
            let to_ocr_tx = to_ocr_tx.clone();
            async move {
                let path = Path::new(&config.files).join(&record.filename);
                // decode on the current thread; block_in_place keeps the runtime healthy
                let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?)));
                let image = match image {
                    Ok(image) => image,
                    Err(e) => {
                        // unreadable/corrupt files are skipped, not fatal
                        log::error!("Could not read {}: {}", record.filename, e);
                        return Ok(())
                    }
                };
                // a None column means the datum is missing or was invalidated upstream
                if record.embedding.is_none() {
                    let resized = resize_for_embed(backend.clone(), image.clone()).await?;
                    to_embed_tx.send(EmbeddingInput { image: resized, filename: record.filename.clone() }).await?
                }
                if record.thumbnails.is_none() && config.enable_thumbs {
                    to_thumbnail_tx
                        .send(LoadedImage {
                            image: image.clone(),
                            filename: record.filename.clone(),
                            original_size: std::fs::metadata(&path)?.len() as usize,
                        })
                        .await?;
                }
                if record.raw_ocr_segments.is_none() && config.enable_ocr {
                    to_ocr_tx
                        .send(LoadedImage {
                            image,
                            filename: record.filename.clone(),
                            // size is irrelevant for OCR
                            original_size: 0,
                        })
                        .await?;
                }
                Ok(())
            }
        })
    });
    // Thumbnail generation: encode each image in every configured format,
    // keep only outputs smaller than the original, record the kept formats.
    let thumbnail_generation: Option<JoinHandle<Result<()>>> = if config.enable_thumbs {
        let config = config.clone();
        let pool = pool.clone();
        let stream = ReceiverStream::new(to_thumbnail_rx).map(Ok);
        let formats = Arc::new(formats);
        Some(tokio::spawn({
            stream.try_for_each_concurrent(Some(cpus), move |image| {
                use image::codecs::*;
                let formats = formats.clone();
                let config = config.clone();
                let pool = pool.clone();
                async move {
                    let filename = image.filename.clone();
                    log::debug!("thumbnailing {}", filename);
                    let generated_formats = tokio::task::spawn_blocking(move || {
                        let mut generated_formats = Vec::new();
                        let rgb = DynamicImage::from(image.image.to_rgb8());
                        for (format_name, format_config) in &*formats {
                            let resized = if format_config.target_filesize != 0 {
                                // binary search on encoder quality
                                // NOTE(review): the search compares against the
                                // original file size, not `target_filesize`, and
                                // re-resizes every iteration — confirm intended.
                                let mut lb = 1;
                                let mut ub = 100;
                                loop {
                                    let quality = (lb + ub) / 2;
                                    let thumbnail = rgb.resize(
                                        format_config.target_width,
                                        u32::MAX,
                                        FilterType::Lanczos3,
                                    );
                                    let mut buf: Vec<u8> = Vec::new();
                                    let mut csr = Cursor::new(&mut buf);
                                    // this is ugly but I don't actually know how to fix it (cannot factor it out due to issues with dyn Trait)
                                    match format_config.format {
                                        ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, quality)),
                                        ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, quality)),
                                        _ => unimplemented!()
                                    }?;
                                    if buf.len() > image.original_size {
                                        ub = quality;
                                    } else {
                                        lb = quality + 1;
                                    }
                                    if lb >= ub {
                                        break buf;
                                    }
                                }
                            } else {
                                // fixed-quality encode
                                let thumbnail = rgb.resize(
                                    format_config.target_width,
                                    u32::MAX,
                                    FilterType::Lanczos3,
                                );
                                let mut buf: Vec<u8> = Vec::new();
                                let mut csr = Cursor::new(&mut buf);
                                match format_config.format {
                                    ImageFormat::Avif => thumbnail.write_with_encoder(avif::AvifEncoder::new_with_speed_quality(&mut csr, 4, format_config.quality)),
                                    ImageFormat::Jpeg => thumbnail.write_with_encoder(jpeg::JpegEncoder::new_with_quality(&mut csr, format_config.quality)),
                                    ImageFormat::WebP => thumbnail.write_with_encoder(webp::WebPEncoder::new_lossless(&mut csr)),
                                    _ => unimplemented!()
                                }?;
                                buf
                            };
                            // only keep thumbnails that actually save space
                            if resized.len() < image.original_size {
                                generated_formats.push(format_name.clone());
                                let thumbnail_path = Path::new(&config.thumbs_path).join(
                                    generate_thumbnail_filename(
                                        &image.filename,
                                        format_name,
                                        format_config,
                                    ),
                                );
                                std::fs::write(thumbnail_path, resized)?;
                            }
                        }
                        Ok::<Vec<String>, anyhow::Error>(generated_formats)
                    }).await??;
                    let formats_data = rmp_serde::to_vec(&generated_formats)?;
                    let ts = timestamp();
                    sqlx::query!(
                        "UPDATE files SET thumbnails = ?, thumbnail_time = ? WHERE filename = ?",
                        formats_data,
                        ts,
                        filename
                    )
                    .execute(&pool)
                    .await?;
                    Ok(())
                }
            })
        }))
    } else {
        None
    };
    // OCR: scan each image, store plain text plus msgpack-encoded segments.
    let ocr: Option<JoinHandle<Result<()>>> = if config.enable_ocr {
        let client = client.clone();
        let pool = pool.clone();
        let stream = ReceiverStream::new(to_ocr_rx).map(Ok);
        Some(tokio::spawn({
            stream.try_for_each_concurrent(Some(config.ocr_concurrency), move |image| {
                let client = client.clone();
                let pool = pool.clone();
                async move {
                    log::debug!("OCRing {}", image.filename);
                    let scan = match scan_image(&client, &image.image).await {
                        Ok(scan) => scan,
                        Err(e) => {
                            // OCR failures are logged and skipped, not fatal
                            log::error!("OCR failure {}: {}", image.filename, e);
                            return Ok(())
                        }
                    };
                    let ocr_text = scan
                        .iter()
                        .map(|segment| segment.text.clone())
                        .collect::<Vec<_>>()
                        .join("\n");
                    let ocr_data = rmp_serde::to_vec(&scan)?;
                    let ts = timestamp();
                    sqlx::query!(
                        "UPDATE files SET ocr = ?, raw_ocr_segments = ?, ocr_time = ? WHERE filename = ?",
                        ocr_text,
                        ocr_data,
                        ts,
                        image.filename
                    )
                    .execute(&pool)
                    .await?;
                    Ok(())
                }
            })
        }))
    } else {
        None
    };
    // Embedding: batch up to `backend.batch` images per CLIP server request
    // and write results in one DB transaction per batch.
    let embedding_generation: JoinHandle<Result<()>> = tokio::spawn({
        let stream = ReceiverStream::new(to_embed_rx).chunks(backend.batch);
        let client = client.clone();
        let config = config.clone();
        let pool = pool.clone();
        // keep multiple embedding requests in flight
        stream.map(Ok).try_for_each_concurrent(Some(3), move |batch| {
            let client = client.clone();
            let config = config.clone();
            let pool = pool.clone();
            async move {
                let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
                    &client,
                    &config,
                    "",
                    EmbeddingRequest::Images {
                        images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
                    },
                ).await.context("querying CLIP server")?;
                let mut tx = pool.begin().await?;
                let ts = timestamp();
                // server returns one embedding per input, in order
                for (i, vector) in result.into_iter().enumerate() {
                    let vector = vector.into_vec();
                    log::debug!("embedded {}", batch[i].filename);
                    sqlx::query!(
                        "UPDATE files SET embedding_time = ?, embedding = ? WHERE filename = ?",
                        ts,
                        vector,
                        batch[i].filename
                    )
                    .execute(&mut *tx)
                    .await?;
                }
                tx.commit().await?;
                anyhow::Result::Ok(())
            }
        })
    });
    let mut filenames = HashMap::new();
    // blocking OS calls: walk the tree and capture relative path + mtime (µs)
    tokio::task::block_in_place(|| -> anyhow::Result<()> {
        for entry in WalkDir::new(config.files.as_str()) {
            let entry = entry?;
            let path = entry.path();
            if path.is_file() {
                let filename = path.strip_prefix(&config.files)?.to_str().unwrap().to_string();
                let modtime = entry.metadata()?.modified()?.duration_since(std::time::UNIX_EPOCH)?;
                let modtime = modtime.as_micros() as i64;
                filenames.insert(filename.clone(), (path.to_path_buf(), modtime));
            }
        }
        Ok(())
    })?;
    log::debug!("finished reading filenames");
    // Feed the pipeline: queue files that are new or whose mtime is newer
    // than any relevant derived-data timestamp.
    for (filename, (_path, modtime)) in filenames.iter() {
        let modtime = *modtime;
        let record = sqlx::query_as!(FileRecord, "SELECT * FROM files WHERE filename = ?", filename)
            .fetch_optional(&pool)
            .await?;
        let new_record = match record {
            None => Some(FileRecord {
                filename: filename.clone(),
                ..Default::default()
            }),
            Some(r) if modtime > r.embedding_time.unwrap_or(i64::MIN) || (modtime > r.ocr_time.unwrap_or(i64::MIN) && config.enable_ocr) || (modtime > r.thumbnail_time.unwrap_or(i64::MIN) && config.enable_thumbs) => {
                Some(r)
            },
            _ => None
        };
        if let Some(mut record) = new_record {
            log::debug!("processing {}", record.filename);
            sqlx::query!("INSERT OR IGNORE INTO files (filename) VALUES (?)", filename)
                .execute(&pool)
                .await?;
            // clear whichever derived data the newer mtime invalidates,
            // so the loader recomputes exactly those
            if modtime > record.embedding_time.unwrap_or(i64::MIN) {
                record.embedding = None;
            }
            if modtime > record.ocr_time.unwrap_or(i64::MIN) {
                record.raw_ocr_segments = None;
            }
            if modtime > record.thumbnail_time.unwrap_or(i64::MIN) {
                record.thumbnails = None;
            }
            // a send error means a downstream task died; stop feeding so the
            // joins below surface its actual error instead of a send error here
            if !to_process_tx.send(record).await.is_ok() {
                break
            }
        }
    }
    // closing the channel lets the pipeline drain and the tasks finish
    drop(to_process_tx);
    embedding_generation.await?.context("generating embeddings")?;
    if let Some(thumbnail_generation) = thumbnail_generation {
        thumbnail_generation.await?.context("generating thumbnails")?;
    }
    if let Some(ocr) = ocr {
        ocr.await?.context("OCRing")?;
    }
    image_loading.await?.context("loading images")?;
    // prune DB rows for files that vanished from disk
    let stored: Vec<String> = sqlx::query_scalar("SELECT filename FROM files").fetch_all(&pool).await?;
    let mut tx = pool.begin().await?;
    for filename in stored {
        if !filenames.contains_key(&filename) {
            sqlx::query!("DELETE FROM files WHERE filename = ?", filename)
                .execute(&mut *tx)
                .await?;
        }
    }
    tx.commit().await?;
    log::info!("ingest done");
    Result::Ok(())
}
/// Number of vectors to accumulate before adding them to FAISS in one call.
const INDEX_ADD_BATCH: usize = 512;

/// Rebuild the in-memory [`IIndex`] from every `files` row that has an
/// embedding: decode the stored fp16 bytes to f32, batch them into the FAISS
/// index, and compute each file's thumbnail-format bitmask.
async fn build_index(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -> Result<IIndex> {
    let pool = initialize_database(&config).await?;
    let mut index = IIndex {
        // scalar-quantized fp16 flat index, inner-product metric
        vectors: faiss::index_factory(backend.embedding_size as u32, "SQfp16", faiss::MetricType::InnerProduct)?,
        filenames: Vec::new(),
        format_codes: Vec::new(),
        format_names: Vec::new(),
    };
    let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM files")
        .fetch_one(&pool)
        .await?;
    index.filenames = Vec::with_capacity(count as usize);
    index.format_codes = Vec::with_capacity(count as usize);
    index.format_names = Vec::with_capacity(5);
    // Flush to FAISS once the buffer holds this many f32s. Do NOT compare
    // against `buffer.capacity()`: `Vec::with_capacity` only guarantees *at
    // least* the requested capacity, so an exact-equality check on capacity
    // might never fire and the whole dataset would accumulate in memory.
    let batch_len = INDEX_ADD_BATCH * backend.embedding_size;
    let mut buffer = Vec::with_capacity(batch_len);
    let mut rows = sqlx::query_as::<_, FileRecord>("SELECT * FROM files").fetch(&pool);
    while let Some(record) = rows.try_next().await? {
        if let Some(emb) = record.embedding {
            index.filenames.push(record.filename);
            // embeddings are stored as packed little-endian fp16; widen to f32
            for chunk in emb.chunks_exact(2) {
                buffer.push(half::f16::from_le_bytes([chunk[0], chunk[1]]).to_f32());
            }
            if buffer.len() >= batch_len {
                index.vectors.add(&buffer)?;
                buffer.clear();
            }
            let mut formats: Vec<String> = Vec::new();
            if let Some(t) = record.thumbnails {
                formats = rmp_serde::from_slice(&t)?;
            }
            // build a bitmask over format_names, interning new names on the fly
            let mut format_code = 0;
            for format_string in &formats {
                let mut found = false;
                for (i, name) in index.format_names.iter().enumerate() {
                    if name == format_string {
                        format_code |= 1 << i;
                        found = true;
                        break;
                    }
                }
                if !found {
                    let new_index = index.format_names.len();
                    format_code |= 1 << new_index;
                    index.format_names.push(format_string.clone());
                }
            }
            index.format_codes.push(format_code);
        }
    }
    // flush the final partial batch
    if !buffer.is_empty() {
        index.vectors.add(&buffer)?;
    }
    Ok(index)
}
/// Decode a packed little-endian fp16 byte buffer into f32s.
/// A trailing odd byte, if any, is ignored.
fn decode_fp16_buffer(buf: &[u8]) -> Vec<f32> {
    let mut out = Vec::with_capacity(buf.len() / 2);
    for pair in buf.chunks_exact(2) {
        out.push(half::f16::from_le_bytes([pair[0], pair[1]]).to_f32());
    }
    out
}
/// A dense query/embedding vector in f32.
type EmbeddingVector = Vec<f32>;
/// Response body for `POST /`: scored matches plus thumbnail metadata.
#[derive(Debug, Serialize)]
struct QueryResult {
    /// (score, filename, filename-hash, thumbnail-format bitmask) per hit.
    matches: Vec<(f32, String, String, u64)>,
    /// Format names the bitmask bits refer to.
    formats: Vec<String>,
    /// Format name -> file extension, for building thumbnail URLs.
    extensions: HashMap<String, String>,
}
/// One term of a query; exactly the set fields among `embedding`/`image`/
/// `text` are used, each scaled by `weight` (default 1.0).
#[derive(Debug, Deserialize)]
struct QueryTerm {
    /// Raw embedding vector to add directly.
    embedding: Option<EmbeddingVector>,
    /// Base64-encoded image to embed server-side.
    image: Option<String>,
    /// Text to embed server-side.
    text: Option<String>,
    /// Term weight; defaults to 1.0 when absent.
    weight: Option<f32>,
}
/// Request body for `POST /`.
#[derive(Debug, Deserialize)]
struct QueryRequest {
    terms: Vec<QueryTerm>,
    /// Maximum number of results; defaults to 1000.
    k: Option<usize>,
}
/// Search the FAISS index for the `k` nearest neighbours of `query` and
/// attach per-hit metadata (filename, filename hash, thumbnail bitmask).
///
/// Cleanups over the previous revision: `k` was already `usize` (no cast),
/// `generate_filename_hash` returns a fresh owned `String` (the extra
/// `.clone()` was a pointless allocation), and the filename is looked up once.
async fn query_index(index: &mut IIndex, query: EmbeddingVector, k: usize) -> Result<QueryResult> {
    let result = index.vectors.search(&query, k)?;
    let items = result.distances
        .into_iter()
        .zip(result.labels)
        .filter_map(|(distance, id)| {
            // FAISS pads short result lists with invalid labels; get() drops those
            let id = id.get()? as usize;
            let filename = index.filenames[id].clone();
            let hash = generate_filename_hash(&filename);
            Some((distance, filename, hash, index.format_codes[id]))
        })
        .collect();
    Ok(QueryResult {
        matches: items,
        formats: index.format_names.clone(),
        extensions: HashMap::new(),
    })
}
/// Handle `POST /`: combine all query terms into one embedding (embedding
/// terms are added locally; image/text terms are embedded by the CLIP server
/// and summed), then search the index and return the matches as JSON.
async fn handle_request(
    config: &Config,
    backend_config: Arc<InferenceServerConfig>,
    client: Arc<Client>,
    index: &mut IIndex,
    req: Json<QueryRequest>,
) -> Result<Response<Body>> {
    let mut total_embedding = ndarray::Array::from(vec![0.0; backend_config.embedding_size]);
    let mut image_batch = Vec::new();
    let mut image_weights = Vec::new();
    let mut text_batch = Vec::new();
    let mut text_weights = Vec::new();
    for term in &req.terms {
        if let Some(image) = &term.image {
            // decode + resize client-supplied image for server-side embedding
            let bytes = BASE64_STANDARD.decode(image)?;
            let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&bytes))?);
            image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(backend_config.clone(), image).await?));
            image_weights.push(term.weight.unwrap_or(1.0));
        }
        if let Some(text) = &term.text {
            text_batch.push(text.clone());
            text_weights.push(term.weight.unwrap_or(1.0));
        }
        if let Some(embedding) = &term.embedding {
            let weight = term.weight.unwrap_or(1.0);
            // NOTE(review): panics if the supplied vector is longer than
            // embedding_size — confirm whether input length should be validated
            for (i, value) in embedding.iter().enumerate() {
                total_embedding[i] += value * weight;
            }
        }
    }
    let mut batches = vec![];
    if !image_batch.is_empty() {
        batches.push(
            EmbeddingRequest::Images {
                images: image_batch
            }
        );
    }
    if !text_batch.is_empty() {
        batches.push(
            EmbeddingRequest::Text {
                text: text_batch,
            }
        );
    }
    // NOTE(review): image_weights/text_weights are collected but never applied
    // below — server-embedded terms are summed unweighted; confirm intent.
    for batch in batches {
        let embs: Vec<Vec<u8>> = query_clip_server(&client, config, "/", batch).await?;
        for emb in embs {
            total_embedding += &ndarray::Array::from_vec(decode_fp16_buffer(&emb));
        }
    }
    let k = req.k.unwrap_or(1000);
    let qres = query_index(index, total_embedding.to_vec(), k).await?;
    // map format names to file extensions for the client's thumbnail URLs
    let mut extensions = HashMap::new();
    for (k, v) in image_formats(config) {
        extensions.insert(k, v.extension);
    }
    Ok(Json(QueryResult {
        matches: qres.matches,
        formats: qres.formats,
        extensions,
    }).into_response())
}
/// Fetch the inference server's parameters from its `/config` endpoint
/// (msgpack-encoded).
async fn get_backend_config(config: &Config) -> Result<InferenceServerConfig> {
    let url = format!("{}/config", config.clip_server);
    let res = Client::new().get(&url).send().await?;
    let bytes = res.bytes().await?;
    Ok(rmp_serde::from_slice(&bytes)?)
}
/// Entry point: load config, wait for the inference server, then either run
/// a single ingest pass (`no_run_server`) or start the HTTP server with a
/// background ingest/reindex loop triggered via `POST /reload`.
#[tokio::main]
async fn main() -> Result<()> {
    pretty_env_logger::init();
    let config_path = std::env::args().nth(1).expect("Missing config file path");
    let config: Arc<Config> = Arc::new(serde_json::from_slice(&std::fs::read(config_path)?)?);
    let pool = initialize_database(&config).await?;
    sqlx::query(SCHEMA).execute(&pool).await?;
    // retry until the inference server is reachable; its config is required
    let backend = Arc::new(loop {
        match get_backend_config(&config).await {
            Ok(backend) => break backend,
            Err(e) => {
                log::error!("Backend failed (fetch): {}", e);
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
            }
        }
    });
    if config.no_run_server {
        // one-shot ingest mode
        ingest_files(config.clone(), backend.clone()).await?;
        return Ok(())
    }
    // capacity 1: at most one pending reload request is remembered
    let (request_ingest_tx, mut request_ingest_rx) = mpsc::channel(1);
    let index = Arc::new(tokio::sync::Mutex::new(build_index(config.clone(), backend.clone()).await?));
    // broadcast (success, message) to any /reload handlers awaiting completion;
    // _ingest_done_rx is kept alive so sends don't fail with no receivers
    let (ingest_done_tx, _ingest_done_rx) = broadcast::channel(1);
    let done_tx = Arc::new(ingest_done_tx.clone());
    // background loop: ingest, rebuild + swap the index, then wait for the
    // next reload request
    let _ingest_task = tokio::spawn({
        let config = config.clone();
        let backend = backend.clone();
        let index = index.clone();
        async move {
            loop {
                log::info!("Ingest running");
                match ingest_files(config.clone(), backend.clone()).await {
                    Ok(_) => {
                        match build_index(config.clone(), backend.clone()).await {
                            Ok(new_index) => {
                                *index.lock().await = new_index;
                            }
                            Err(e) => {
                                log::error!("Index build failed: {:?}", e);
                                ingest_done_tx.send((false, format!("{:?}", e))).unwrap();
                            }
                        }
                    }
                    Err(e) => {
                        log::error!("Ingest failed: {:?}", e);
                        ingest_done_tx.send((false, format!("{:?}", e))).unwrap();
                    }
                }
                // NOTE(review): this OK is broadcast even after a failure was
                // already broadcast above — confirm whether that is intended
                ingest_done_tx.send((true, format!("OK"))).unwrap();
                request_ingest_rx.recv().await;
            }
        }
    });
    let cors = CorsLayer::permissive();
    let config_ = config.clone();
    let client = Arc::new(Client::new());
    let app = Router::new()
        // POST /: run a similarity query against the current index
        .route("/", post(|req| async move {
            let config = config.clone();
            let backend_config = backend.clone();
            let mut index = index.lock().await; // TODO: use ConcurrentIndex here
            let client = client.clone();
            handle_request(&config, backend_config, client.clone(), &mut index, req).await.map_err(|e| format!("{:?}", e))
        }))
        // GET /: health check
        .route("/", get(|_req: axum::http::Request<axum::body::Body>| async move {
            "OK"
        }))
        // POST /reload: trigger an ingest pass and wait for its result
        .route("/reload", post(|_req: axum::http::Request<axum::body::Body>| async move {
            log::info!("Requesting index reload");
            let mut done_rx = done_tx.clone().subscribe();
            let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full
            match done_rx.recv().await {
                Ok((true, status)) => {
                    let mut res = status.into_response();
                    *res.status_mut() = StatusCode::OK;
                    res
                },
                Ok((false, status)) => {
                    let mut res = status.into_response();
                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                    res
                },
                Err(_) => {
                    // lagged/closed broadcast channel
                    let mut res = "internal error".into_response();
                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                    res
                }
            }
        }))
        .layer(cors);
    let addr = format!("0.0.0.0:{}", config_.port);
    log::info!("Starting server on {}", addr);
    let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
    axum::serve(listener, app).await?;
    Ok(())
}

173
src/ocr.rs Normal file
View File

@@ -0,0 +1,173 @@
use anyhow::{anyhow, Result};
use image::{DynamicImage, GenericImageView, ImageFormat};
use regex::Regex;
use reqwest::{
header::{HeaderMap, HeaderValue},
multipart::{Form, Part},
Client,
};
use serde_json::Value;
use std::{io::Cursor, time::{SystemTime, UNIX_EPOCH}};
use serde::{Deserialize, Serialize};
/// Extracts the inline `AF_initDataCallback({key: 'ds:1'...})` JSON blob from
/// the Google Lens HTML response.
const CALLBACK_REGEX: &str = r">AF_initDataCallback\((\{key: 'ds:1'.*?\})\);</script>";
/// Maximum width (and per-chunk height) sent to Lens; larger images are
/// downscaled and/or split into vertical chunks.
const MAX_DIM: u32 = 1024;
/// Pixel-space bounding box of one recognized text segment
/// (top-left corner plus width/height).
#[derive(Debug, Serialize, Deserialize)]
pub struct SegmentCoords {
    pub x: i32,
    pub y: i32,
    pub w: i32,
    pub h: i32,
}
/// One recognized text segment: its bounding box and its text.
#[derive(Debug, Deserialize, Serialize)]
pub struct Segment {
    pub coords: SegmentCoords,
    pub text: String,
}
/// Full OCR output for an image: all recognized segments.
pub type ScanResult = Vec<Segment>;
/// Convert a Lens fractional box (center position and size as fractions of
/// the image dimensions) into absolute pixel coordinates (top-left + size).
fn rationalize_coords_format1(
    image_w: f64,
    image_h: f64,
    center_x_fraction: f64,
    center_y_fraction: f64,
    width_fraction: f64,
    height_fraction: f64,
) -> SegmentCoords {
    let left = (center_x_fraction - width_fraction / 2.0) * image_w;
    let top = (center_y_fraction - height_fraction / 2.0) * image_h;
    SegmentCoords {
        x: left.round() as i32,
        y: top.round() as i32,
        w: (width_fraction * image_w).round() as i32,
        h: (height_fraction * image_h).round() as i32,
    }
}
/// OCR one PNG-encoded image chunk by uploading it to Google Lens and
/// scraping the text segments and their regions out of the embedded
/// `AF_initDataCallback` JSON in the HTML response.
///
/// NOTE(review): the response parsing below indexes deep into an
/// undocumented, reverse-engineered structure with `unwrap()` throughout —
/// any Lens format change will panic rather than return Err.
async fn scan_image_chunk(
    client: &Client,
    image: &[u8],
    image_width: u32,
    image_height: u32,
) -> Result<ScanResult> {
    // timestamp doubles as a cache-buster in the filename, cookie and URL
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_micros();
    let part = Part::bytes(image.to_vec())
        .file_name(format!("ocr{}.png", timestamp))
        .mime_str("image/png")?;
    let form = Form::new().part("encoded_image", part);
    // mimic a real mobile browser; Lens rejects unknown clients
    let mut headers = HeaderMap::new();
    headers.insert(
        "User-Agent",
        HeaderValue::from_static("Mozilla/5.0 (Linux; Android 13; RMX3771) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.6167.144 Mobile Safari/537.36"),
    );
    headers.insert("Cookie", HeaderValue::from_str(&format!("SOCS=CAESEwgDEgk0ODE3Nzk3MjQaAmVuIAEaBgiA_LyaBg; stcs={}", timestamp))?);
    let response = client
        .post(&format!("https://lens.google.com/v3/upload?stcs={}", timestamp))
        .multipart(form)
        .headers(headers)
        .send()
        .await?;
    let body = response.text().await?;
    // pull the 'ds:1' data blob out of the returned HTML
    let re = Regex::new(CALLBACK_REGEX)?;
    let captures = re
        .captures(&body)
        .ok_or_else(|| anyhow!("invalid API response"))?;
    let match_str = captures.get(1).unwrap().as_str();
    // the blob is JS object literal syntax (unquoted keys), hence json5
    let lens_object: Value = json5::from_str(match_str)?;
    if lens_object.get("errorHasStatus").is_some() {
        return Err(anyhow!("lens failed"));
    }
    let root = lens_object["data"].as_array().unwrap();
    let mut text_segments = Vec::new();
    let mut text_regions = Vec::new();
    // magic indices reverse-engineered from the Lens response layout
    let text_segments_raw = root[3][4][0][0]
        .as_array()
        .ok_or_else(|| anyhow!("invalid text segments"))?;
    let text_regions_raw = root[2][3][0]
        .as_array()
        .ok_or_else(|| anyhow!("invalid text regions"))?;
    for region in text_regions_raw {
        let region_data = region.as_array().unwrap();
        // only keep regions tagged as text (others are objects, etc.)
        if region_data[11].as_str().unwrap().starts_with("text:") {
            let raw_coords = region_data[1].as_array().unwrap();
            let coords = rationalize_coords_format1(
                image_width as f64,
                image_height as f64,
                raw_coords[0].as_f64().unwrap(),
                raw_coords[1].as_f64().unwrap(),
                raw_coords[2].as_f64().unwrap(),
                raw_coords[3].as_f64().unwrap(),
            );
            text_regions.push(coords);
        }
    }
    for segment in text_segments_raw {
        let text_segment = segment.as_str().unwrap().to_string();
        text_segments.push(text_segment);
    }
    // segments and regions are assumed to be parallel lists; zip truncates
    // to the shorter if they ever disagree
    Ok(text_segments
        .into_iter()
        .zip(text_regions.into_iter())
        .map(|(text, coords)| Segment { text, coords })
        .collect())
}
/// OCR a whole image via Google Lens: downscale to `MAX_DIM` wide if needed,
/// split into vertical chunks of at most `MAX_DIM` rows, scan each chunk,
/// and merge the results with y-offsets translated back to whole-image space.
///
/// NOTE(review): coordinates are in the (possibly downscaled) working image's
/// pixel space, not the original resolution — confirm callers expect that.
pub async fn scan_image(client: &Client, image: &DynamicImage) -> Result<ScanResult> {
    let mut result = ScanResult::new();
    let (width, height) = image.dimensions();
    // downscale overly wide images, preserving aspect ratio; Cow avoids
    // copying when no resize is needed
    let (width, height, image) = if width > MAX_DIM {
        let height = ((height as f64) * (MAX_DIM as f64) / (width as f64)).round() as u32;
        let new_image = tokio::task::block_in_place(|| image.resize_exact(MAX_DIM, height, image::imageops::FilterType::Lanczos3));
        (MAX_DIM, height, std::borrow::Cow::Owned(new_image))
    } else {
        (width, height, std::borrow::Cow::Borrowed(image))
    };
    // scan in horizontal bands no taller than MAX_DIM
    let mut y = 0;
    while y < height {
        let chunk_height = (height - y).min(MAX_DIM);
        let chunk = tokio::task::block_in_place(|| {
            let chunk = image.view(0, y, width, chunk_height).to_image();
            let mut buf = Vec::new();
            let mut csr = Cursor::new(&mut buf);
            chunk.write_to(&mut csr, ImageFormat::Png)?;
            Ok::<Vec<u8>, anyhow::Error>(buf)
        })?;
        let res = scan_image_chunk(client, &chunk, width, chunk_height).await?;
        // translate chunk-local y back into whole-image coordinates
        for segment in res {
            result.push(Segment {
                text: segment.text,
                coords: SegmentCoords {
                    y: segment.coords.y + y as i32,
                    x: segment.coords.x,
                    w: segment.coords.w,
                    h: segment.coords.h,
                },
            });
        }
        y += chunk_height;
    }
    Ok(result)
}