1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-11-13 07:19:54 +00:00

refactor configuration

This commit is contained in:
osmarks 2024-05-22 19:02:34 +01:00
parent ffc3d648a6
commit 14387a61a3
2 changed files with 53 additions and 51 deletions

View File

@ -15,7 +15,7 @@ They say a picture is worth a thousand words. Unfortunately, many (most?) sets o
## Setup ## Setup
This is untested. It might work. This is untested. It might work. The new Rust version simplifies some steps (it integrates its own thumbnailing).
* Serve your meme library from a static webserver. * Serve your meme library from a static webserver.
* I use nginx. If you're in a hurry, you can use `python -m http.server`. * I use nginx. If you're in a hurry, you can use `python -m http.server`.
@ -30,19 +30,20 @@ This is untested. It might work.
* `model_name` is the name of the model for metrics purposes. * `model_name` is the name of the model for metrics purposes.
* `max_batch_size` controls the maximum allowed batch size. Higher values generally result in somewhat better performance (the bottleneck in most cases is elsewhere right now though) at the cost of higher VRAM use. * `max_batch_size` controls the maximum allowed batch size. Higher values generally result in somewhat better performance (the bottleneck in most cases is elsewhere right now though) at the cost of higher VRAM use.
* `port` is the port to run the HTTP server on. * `port` is the port to run the HTTP server on.
* Run `mse.py` (also as a background service). * Build and run `meme-search-engine` (Rust) (also as a background service).
* This needs to be exposed somewhere the frontend can reach it. Configure your reverse proxy appropriately. * This needs to be exposed somewhere the frontend can reach it. Configure your reverse proxy appropriately.
* It has a JSON config file as well. * It has a JSON config file as well.
* `clip_server` is the full URL for the backend server. * `clip_server` is the full URL for the backend server.
* `db_path` is the path for the SQLite database of images and embedding vectors. * `db_path` is the path for the SQLite database of images and embedding vectors.
* `files` is where meme files will be read from. Subdirectories are indexed. * `files` is where meme files will be read from. Subdirectories are indexed.
* `port` is the port to serve HTTP on. * `port` is the port to serve HTTP on.
* If you are deploying to the public set `enable_thumbs` to `true` to serve compressed images.
* Build clipfront2, host on your favourite static webserver. * Build clipfront2, host on your favourite static webserver.
* `npm install`, `node src/build.js`. * `npm install`, `node src/build.js`.
* You will need to rebuild it whenever you edit `frontend_config.json`. * You will need to rebuild it whenever you edit `frontend_config.json`.
* `image_path` is the base URL of your meme webserver (with trailing slash). * `image_path` is the base URL of your meme webserver (with trailing slash).
* `backend_url` is the URL `mse.py` is exposed on (trailing slash probably optional). * `backend_url` is the URL `mse.py` is exposed on (trailing slash probably optional).
* If you want, configure Prometheus to monitor `mse.py` and `clip_server.py`. * If you want, configure Prometheus to monitor `clip_server.py`.
## MemeThresher ## MemeThresher

View File

@ -127,6 +127,12 @@ struct InferenceServerConfig {
embedding_size: usize, embedding_size: usize,
} }
#[derive(Debug, Deserialize, Clone)]
struct WConfig {
backend: InferenceServerConfig,
service: Config
}
async fn query_clip_server<I, O>( async fn query_clip_server<I, O>(
client: &Client, client: &Client,
config: &Config, config: &Config,
@ -261,11 +267,11 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> {
formats formats
} }
async fn resize_for_embed(backend_config: Arc<InferenceServerConfig>, image: Arc<DynamicImage>) -> Result<Vec<u8>> { async fn resize_for_embed(config: Arc<WConfig>, image: Arc<DynamicImage>) -> Result<Vec<u8>> {
let resized = tokio::task::spawn_blocking(move || { let resized = tokio::task::spawn_blocking(move || {
let new = image.resize( let new = image.resize(
backend_config.image_size.0, config.backend.image_size.0,
backend_config.image_size.1, config.backend.image_size.1,
FilterType::Lanczos3 FilterType::Lanczos3
); );
let mut buf = Vec::new(); let mut buf = Vec::new();
@ -276,14 +282,14 @@ async fn resize_for_embed(backend_config: Arc<InferenceServerConfig>, image: Arc
Ok(resized) Ok(resized)
} }
async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -> Result<()> { async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
let pool = initialize_database(&config).await?; let pool = initialize_database(&config.service).await?;
let client = Client::new(); let client = Client::new();
let formats = image_formats(&config); let formats = image_formats(&config.service);
let (to_process_tx, to_process_rx) = mpsc::channel::<FileRecord>(100); let (to_process_tx, to_process_rx) = mpsc::channel::<FileRecord>(100);
let (to_embed_tx, to_embed_rx) = mpsc::channel(backend.batch as usize); let (to_embed_tx, to_embed_rx) = mpsc::channel(config.backend.batch as usize);
let (to_thumbnail_tx, to_thumbnail_rx) = mpsc::channel(30); let (to_thumbnail_tx, to_thumbnail_rx) = mpsc::channel(30);
let (to_ocr_tx, to_ocr_rx) = mpsc::channel(30); let (to_ocr_tx, to_ocr_rx) = mpsc::channel(30);
@ -292,16 +298,14 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
// Image loading and preliminary resizing // Image loading and preliminary resizing
let image_loading: JoinHandle<Result<()>> = tokio::spawn({ let image_loading: JoinHandle<Result<()>> = tokio::spawn({
let config = config.clone(); let config = config.clone();
let backend = backend.clone();
let stream = ReceiverStream::new(to_process_rx).map(Ok); let stream = ReceiverStream::new(to_process_rx).map(Ok);
stream.try_for_each_concurrent(Some(cpus), move |record| { stream.try_for_each_concurrent(Some(cpus), move |record| {
let config = config.clone(); let config = config.clone();
let backend = backend.clone();
let to_embed_tx = to_embed_tx.clone(); let to_embed_tx = to_embed_tx.clone();
let to_thumbnail_tx = to_thumbnail_tx.clone(); let to_thumbnail_tx = to_thumbnail_tx.clone();
let to_ocr_tx = to_ocr_tx.clone(); let to_ocr_tx = to_ocr_tx.clone();
async move { async move {
let path = Path::new(&config.files).join(&record.filename); let path = Path::new(&config.service.files).join(&record.filename);
let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?))); let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?)));
let image = match image { let image = match image {
Ok(image) => image, Ok(image) => image,
@ -313,11 +317,11 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
}; };
IMAGES_LOADED_COUNTER.inc(); IMAGES_LOADED_COUNTER.inc();
if record.embedding.is_none() { if record.embedding.is_none() {
let resized = resize_for_embed(backend.clone(), image.clone()).await?; let resized = resize_for_embed(config.clone(), image.clone()).await?;
to_embed_tx.send(EmbeddingInput { image: resized, filename: record.filename.clone() }).await? to_embed_tx.send(EmbeddingInput { image: resized, filename: record.filename.clone() }).await?
} }
if record.thumbnails.is_none() && config.enable_thumbs { if record.thumbnails.is_none() && config.service.enable_thumbs {
to_thumbnail_tx to_thumbnail_tx
.send(LoadedImage { .send(LoadedImage {
image: image.clone(), image: image.clone(),
@ -326,7 +330,7 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
}) })
.await?; .await?;
} }
if record.raw_ocr_segments.is_none() && config.enable_ocr { if record.raw_ocr_segments.is_none() && config.service.enable_ocr {
to_ocr_tx to_ocr_tx
.send(LoadedImage { .send(LoadedImage {
image, image,
@ -341,7 +345,7 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
}); });
// Thumbnail generation // Thumbnail generation
let thumbnail_generation: Option<JoinHandle<Result<()>>> = if config.enable_thumbs { let thumbnail_generation: Option<JoinHandle<Result<()>>> = if config.service.enable_thumbs {
let config = config.clone(); let config = config.clone();
let pool = pool.clone(); let pool = pool.clone();
let stream = ReceiverStream::new(to_thumbnail_rx).map(Ok); let stream = ReceiverStream::new(to_thumbnail_rx).map(Ok);
@ -405,7 +409,7 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
}; };
if resized.len() < image.original_size { if resized.len() < image.original_size {
generated_formats.push(format_name.clone()); generated_formats.push(format_name.clone());
let thumbnail_path = Path::new(&config.thumbs_path).join( let thumbnail_path = Path::new(&config.service.thumbs_path).join(
generate_thumbnail_filename( generate_thumbnail_filename(
&image.filename, &image.filename,
format_name, format_name,
@ -438,12 +442,12 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
}; };
// OCR // OCR
let ocr: Option<JoinHandle<Result<()>>> = if config.enable_ocr { let ocr: Option<JoinHandle<Result<()>>> = if config.service.enable_ocr {
let client = client.clone(); let client = client.clone();
let pool = pool.clone(); let pool = pool.clone();
let stream = ReceiverStream::new(to_ocr_rx).map(Ok); let stream = ReceiverStream::new(to_ocr_rx).map(Ok);
Some(tokio::spawn({ Some(tokio::spawn({
stream.try_for_each_concurrent(Some(config.ocr_concurrency), move |image| { stream.try_for_each_concurrent(Some(config.service.ocr_concurrency), move |image| {
let client = client.clone(); let client = client.clone();
let pool = pool.clone(); let pool = pool.clone();
async move { async move {
@ -482,7 +486,7 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
}; };
let embedding_generation: JoinHandle<Result<()>> = tokio::spawn({ let embedding_generation: JoinHandle<Result<()>> = tokio::spawn({
let stream = ReceiverStream::new(to_embed_rx).chunks(backend.batch); let stream = ReceiverStream::new(to_embed_rx).chunks(config.backend.batch);
let client = client.clone(); let client = client.clone();
let config = config.clone(); let config = config.clone();
let pool = pool.clone(); let pool = pool.clone();
@ -494,7 +498,7 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
async move { async move {
let result: Vec<serde_bytes::ByteBuf> = query_clip_server( let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
&client, &client,
&config, &config.service,
"", "",
EmbeddingRequest::Images { EmbeddingRequest::Images {
images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(), images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
@ -526,11 +530,11 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
// blocking OS calls // blocking OS calls
tokio::task::block_in_place(|| -> anyhow::Result<()> { tokio::task::block_in_place(|| -> anyhow::Result<()> {
for entry in WalkDir::new(config.files.as_str()) { for entry in WalkDir::new(config.service.files.as_str()) {
let entry = entry?; let entry = entry?;
let path = entry.path(); let path = entry.path();
if path.is_file() { if path.is_file() {
let filename = path.strip_prefix(&config.files)?.to_str().unwrap().to_string(); let filename = path.strip_prefix(&config.service.files)?.to_str().unwrap().to_string();
let modtime = entry.metadata()?.modified()?.duration_since(std::time::UNIX_EPOCH)?; let modtime = entry.metadata()?.modified()?.duration_since(std::time::UNIX_EPOCH)?;
let modtime = modtime.as_micros() as i64; let modtime = modtime.as_micros() as i64;
filenames.insert(filename.clone(), (path.to_path_buf(), modtime)); filenames.insert(filename.clone(), (path.to_path_buf(), modtime));
@ -552,7 +556,7 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
filename: filename.clone(), filename: filename.clone(),
..Default::default() ..Default::default()
}), }),
Some(r) if modtime > r.embedding_time.unwrap_or(i64::MIN) || (modtime > r.ocr_time.unwrap_or(i64::MIN) && config.enable_ocr) || (modtime > r.thumbnail_time.unwrap_or(i64::MIN) && config.enable_thumbs) => { Some(r) if modtime > r.embedding_time.unwrap_or(i64::MIN) || (modtime > r.ocr_time.unwrap_or(i64::MIN) && config.service.enable_ocr) || (modtime > r.thumbnail_time.unwrap_or(i64::MIN) && config.service.enable_thumbs) => {
Some(r) Some(r)
}, },
_ => None _ => None
@ -610,11 +614,11 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
const INDEX_ADD_BATCH: usize = 512; const INDEX_ADD_BATCH: usize = 512;
async fn build_index(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -> Result<IIndex> { async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
let pool = initialize_database(&config).await?; let pool = initialize_database(&config.service).await?;
let mut index = IIndex { let mut index = IIndex {
vectors: scalar_quantizer::ScalarQuantizerIndexImpl::new(backend.embedding_size as u32, scalar_quantizer::QuantizerType::QT_fp16, faiss::MetricType::InnerProduct)?, vectors: scalar_quantizer::ScalarQuantizerIndexImpl::new(config.backend.embedding_size as u32, scalar_quantizer::QuantizerType::QT_fp16, faiss::MetricType::InnerProduct)?,
filenames: Vec::new(), filenames: Vec::new(),
format_codes: Vec::new(), format_codes: Vec::new(),
format_names: Vec::new(), format_names: Vec::new(),
@ -626,7 +630,7 @@ async fn build_index(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -
index.filenames = Vec::with_capacity(count as usize); index.filenames = Vec::with_capacity(count as usize);
index.format_codes = Vec::with_capacity(count as usize); index.format_codes = Vec::with_capacity(count as usize);
let mut buffer = Vec::with_capacity(INDEX_ADD_BATCH * backend.embedding_size as usize); let mut buffer = Vec::with_capacity(INDEX_ADD_BATCH * config.backend.embedding_size as usize);
index.format_names = Vec::with_capacity(5); index.format_names = Vec::with_capacity(5);
let mut rows = sqlx::query_as::<_, FileRecord>("SELECT * FROM files").fetch(&pool); let mut rows = sqlx::query_as::<_, FileRecord>("SELECT * FROM files").fetch(&pool);
@ -728,14 +732,8 @@ async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize) -> Result
}) })
} }
async fn handle_request( async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IIndex, req: Json<QueryRequest>) -> Result<Response<Body>> {
config: &Config, let mut total_embedding = ndarray::Array::from(vec![0.0; config.backend.embedding_size]);
backend_config: Arc<InferenceServerConfig>,
client: Arc<Client>,
index: &IIndex,
req: Json<QueryRequest>,
) -> Result<Response<Body>> {
let mut total_embedding = ndarray::Array::from(vec![0.0; backend_config.embedding_size]);
let mut image_batch = Vec::new(); let mut image_batch = Vec::new();
let mut image_weights = Vec::new(); let mut image_weights = Vec::new();
@ -747,7 +745,7 @@ async fn handle_request(
TERMS_COUNTER.get_metric_with_label_values(&["image"]).unwrap().inc(); TERMS_COUNTER.get_metric_with_label_values(&["image"]).unwrap().inc();
let bytes = BASE64_STANDARD.decode(image)?; let bytes = BASE64_STANDARD.decode(image)?;
let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&bytes))?); let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&bytes))?);
image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(backend_config.clone(), image).await?)); image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(config.clone(), image).await?));
image_weights.push(term.weight.unwrap_or(1.0)); image_weights.push(term.weight.unwrap_or(1.0));
} }
if let Some(text) = &term.text { if let Some(text) = &term.text {
@ -782,7 +780,7 @@ async fn handle_request(
} }
for batch in batches { for batch in batches {
let embs: Vec<Vec<u8>> = query_clip_server(&client, config, "/", batch).await?; let embs: Vec<Vec<u8>> = query_clip_server(&client, &config.service, "/", batch).await?;
for emb in embs { for emb in embs {
total_embedding += &ndarray::Array::from_vec(decode_fp16_buffer(&emb)); total_embedding += &ndarray::Array::from_vec(decode_fp16_buffer(&emb));
} }
@ -792,7 +790,7 @@ async fn handle_request(
let qres = query_index(index, total_embedding.to_vec(), k).await?; let qres = query_index(index, total_embedding.to_vec(), k).await?;
let mut extensions = HashMap::new(); let mut extensions = HashMap::new();
for (k, v) in image_formats(config) { for (k, v) in image_formats(&config.service) {
extensions.insert(k, v.extension); extensions.insert(k, v.extension);
} }
@ -813,12 +811,12 @@ async fn main() -> Result<()> {
pretty_env_logger::init(); pretty_env_logger::init();
let config_path = std::env::args().nth(1).expect("Missing config file path"); let config_path = std::env::args().nth(1).expect("Missing config file path");
let config: Arc<Config> = Arc::new(serde_json::from_slice(&std::fs::read(config_path)?)?); let config = serde_json::from_slice(&std::fs::read(config_path)?)?;
let pool = initialize_database(&config).await?; let pool = initialize_database(&config).await?;
sqlx::query(SCHEMA).execute(&pool).await?; sqlx::query(SCHEMA).execute(&pool).await?;
let backend = Arc::new(loop { let backend = loop {
match get_backend_config(&config).await { match get_backend_config(&config).await {
Ok(backend) => break backend, Ok(backend) => break backend,
Err(e) => { Err(e) => {
@ -826,30 +824,34 @@ async fn main() -> Result<()> {
tokio::time::sleep(std::time::Duration::from_secs(1)).await; tokio::time::sleep(std::time::Duration::from_secs(1)).await;
} }
} }
};
let config = Arc::new(WConfig {
service: config,
backend
}); });
if config.no_run_server { if config.service.no_run_server {
ingest_files(config.clone(), backend.clone()).await?; ingest_files(config.clone()).await?;
return Ok(()) return Ok(())
} }
let (request_ingest_tx, mut request_ingest_rx) = mpsc::channel(1); let (request_ingest_tx, mut request_ingest_rx) = mpsc::channel(1);
let index = Arc::new(tokio::sync::RwLock::new(build_index(config.clone(), backend.clone()).await?)); let index = Arc::new(tokio::sync::RwLock::new(build_index(config.clone()).await?));
let (ingest_done_tx, _ingest_done_rx) = broadcast::channel(1); let (ingest_done_tx, _ingest_done_rx) = broadcast::channel(1);
let done_tx = Arc::new(ingest_done_tx.clone()); let done_tx = Arc::new(ingest_done_tx.clone());
let _ingest_task = tokio::spawn({ let _ingest_task = tokio::spawn({
let config = config.clone(); let config = config.clone();
let backend = backend.clone();
let index = index.clone(); let index = index.clone();
async move { async move {
loop { loop {
log::info!("Ingest running"); log::info!("Ingest running");
match ingest_files(config.clone(), backend.clone()).await { match ingest_files(config.clone()).await {
Ok(_) => { Ok(_) => {
match build_index(config.clone(), backend.clone()).await { match build_index(config.clone()).await {
Ok(new_index) => { Ok(new_index) => {
LAST_INDEX_SIZE.set(new_index.vectors.ntotal() as i64); LAST_INDEX_SIZE.set(new_index.vectors.ntotal() as i64);
*index.write().await = new_index; *index.write().await = new_index;
@ -879,11 +881,10 @@ async fn main() -> Result<()> {
let app = Router::new() let app = Router::new()
.route("/", post(|req| async move { .route("/", post(|req| async move {
let config = config.clone(); let config = config.clone();
let backend_config = backend.clone();
let index = index.read().await; // TODO: use ConcurrentIndex here let index = index.read().await; // TODO: use ConcurrentIndex here
let client = client.clone(); let client = client.clone();
QUERIES_COUNTER.inc(); QUERIES_COUNTER.inc();
handle_request(&config, backend_config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e)) handle_request(config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e))
})) }))
.route("/", get(|_req: axum::http::Request<axum::body::Body>| async move { .route("/", get(|_req: axum::http::Request<axum::body::Body>| async move {
"OK" "OK"
@ -919,7 +920,7 @@ async fn main() -> Result<()> {
})) }))
.layer(cors); .layer(cors);
let addr = format!("0.0.0.0:{}", config_.port); let addr = format!("0.0.0.0:{}", config_.service.port);
log::info!("Starting server on {}", addr); log::info!("Starting server on {}", addr);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app).await?; axum::serve(listener, app).await?;