Mirror of https://github.com/osmarks/meme-search-engine.git, synced 2024-11-10 22:09:54 +00:00

refactor configuration

osmarks 2024-05-22 19:02:34 +01:00
parent ffc3d648a6
commit 14387a61a3
2 changed files with 53 additions and 51 deletions

View File

@@ -15,7 +15,7 @@ They say a picture is worth a thousand words. Unfortunately, many (most?) sets o
## Setup
This is untested. It might work.
This is untested. It might work. The new Rust version simplifies some steps (it integrates its own thumbnailing).
* Serve your meme library from a static webserver.
* I use nginx. If you're in a hurry, you can use `python -m http.server`.
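For example (port and directory are placeholders):

```
python -m http.server 8000 --directory /srv/memes
```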
@@ -30,19 +30,20 @@ This is untested. It might work.
* `model_name` is the name of the model for metrics purposes.
* `max_batch_size` controls the maximum allowed batch size. Higher values generally result in somewhat better performance (the bottleneck in most cases is elsewhere right now though) at the cost of higher VRAM use.
* `port` is the port to run the HTTP server on.
* Run `mse.py` (also as a background service).
* Build and run `meme-search-engine` (Rust) (also as a background service).
* This needs to be exposed somewhere the frontend can reach it. Configure your reverse proxy appropriately.
* It has a JSON config file as well; a minimal example appears after this list.
* `clip_server` is the full URL for the backend server.
* `db_path` is the path for the SQLite database of images and embedding vectors.
* `files` is where meme files will be read from. Subdirectories are indexed.
* `port` is the port to serve HTTP on.
* If you are deploying to the public, set `enable_thumbs` to `true` to serve compressed images.
* Build clipfront2, host on your favourite static webserver.
* `npm install`, `node src/build.js`.
* You will need to rebuild it whenever you edit `frontend_config.json`.
* `image_path` is the base URL of your meme webserver (with trailing slash).
* `backend_url` is the URL `mse.py` is exposed on (trailing slash probably optional).
* If you want, configure Prometheus to monitor `mse.py` and `clip_server.py`.
* If you want, configure Prometheus to monitor `clip_server.py`.
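For concreteness, a minimal sketch of the Rust service's JSON config (all values here are placeholders; the diff below also shows further optional fields such as `thumbs_path`, `enable_ocr`, `ocr_concurrency`, and `no_run_server`):

```json
{
    "clip_server": "http://localhost:1708/",
    "db_path": "data.sqlite3",
    "files": "/srv/memes/",
    "port": 1707,
    "enable_thumbs": true
}
```

And a sketch of `frontend_config.json`, again with placeholder URLs:

```json
{
    "image_path": "https://memes.example.com/",
    "backend_url": "https://mse-backend.example.com/"
}
```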
## MemeThresher

View File

@@ -127,6 +127,12 @@ struct InferenceServerConfig {
embedding_size: usize,
}
#[derive(Debug, Deserialize, Clone)]
struct WConfig {
backend: InferenceServerConfig,
service: Config
}
async fn query_clip_server<I, O>(
client: &Client,
config: &Config,
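For reference, a minimal standalone sketch of the wrapper pattern this hunk introduces: the service `Config` parsed from JSON and the `InferenceServerConfig` fetched from the backend get bundled so one `Arc<WConfig>` can be threaded through every task (structs reduced to one field each; all values are placeholders):

```rust
use std::sync::Arc;

// Stand-ins for the real structs, which carry many more fields.
#[derive(Debug, Clone)]
struct Config { port: u16 }                            // service settings from the JSON file
#[derive(Debug, Clone)]
struct InferenceServerConfig { embedding_size: usize } // reported by the CLIP backend at startup

#[derive(Debug, Clone)]
struct WConfig {
    backend: InferenceServerConfig,
    service: Config,
}

fn main() {
    // Assembled once in main(); a single Arc<WConfig> then replaces the
    // separate Arc<Config> and Arc<InferenceServerConfig> handles.
    let config = Arc::new(WConfig {
        service: Config { port: 1707 },                         // placeholder value
        backend: InferenceServerConfig { embedding_size: 512 }, // placeholder value
    });
    // Call sites pick the half they need: config.service.* or config.backend.*.
    println!("port={} dim={}", config.service.port, config.backend.embedding_size);
}
```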
@@ -261,11 +267,11 @@ fn image_formats(_config: &Config) -> HashMap<String, ImageFormatConfig> {
formats
}
async fn resize_for_embed(backend_config: Arc<InferenceServerConfig>, image: Arc<DynamicImage>) -> Result<Vec<u8>> {
async fn resize_for_embed(config: Arc<WConfig>, image: Arc<DynamicImage>) -> Result<Vec<u8>> {
let resized = tokio::task::spawn_blocking(move || {
let new = image.resize(
backend_config.image_size.0,
backend_config.image_size.1,
config.backend.image_size.0,
config.backend.image_size.1,
FilterType::Lanczos3
);
let mut buf = Vec::new();
@@ -276,14 +282,14 @@ async fn resize_for_embed(backend_config: Arc<InferenceServerConfig>, image: Arc
Ok(resized)
}
async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -> Result<()> {
let pool = initialize_database(&config).await?;
async fn ingest_files(config: Arc<WConfig>) -> Result<()> {
let pool = initialize_database(&config.service).await?;
let client = Client::new();
let formats = image_formats(&config);
let formats = image_formats(&config.service);
let (to_process_tx, to_process_rx) = mpsc::channel::<FileRecord>(100);
let (to_embed_tx, to_embed_rx) = mpsc::channel(backend.batch as usize);
let (to_embed_tx, to_embed_rx) = mpsc::channel(config.backend.batch as usize);
let (to_thumbnail_tx, to_thumbnail_rx) = mpsc::channel(30);
let (to_ocr_tx, to_ocr_rx) = mpsc::channel(30);
@@ -292,16 +298,14 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
// Image loading and preliminary resizing
let image_loading: JoinHandle<Result<()>> = tokio::spawn({
let config = config.clone();
let backend = backend.clone();
let stream = ReceiverStream::new(to_process_rx).map(Ok);
stream.try_for_each_concurrent(Some(cpus), move |record| {
let config = config.clone();
let backend = backend.clone();
let to_embed_tx = to_embed_tx.clone();
let to_thumbnail_tx = to_thumbnail_tx.clone();
let to_ocr_tx = to_ocr_tx.clone();
async move {
let path = Path::new(&config.files).join(&record.filename);
let path = Path::new(&config.service.files).join(&record.filename);
let image: Result<Arc<DynamicImage>> = tokio::task::block_in_place(|| Ok(Arc::new(ImageReader::open(&path)?.with_guessed_format()?.decode()?)));
let image = match image {
Ok(image) => image,
@@ -313,11 +317,11 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
};
IMAGES_LOADED_COUNTER.inc();
if record.embedding.is_none() {
let resized = resize_for_embed(backend.clone(), image.clone()).await?;
let resized = resize_for_embed(config.clone(), image.clone()).await?;
to_embed_tx.send(EmbeddingInput { image: resized, filename: record.filename.clone() }).await?
}
if record.thumbnails.is_none() && config.enable_thumbs {
if record.thumbnails.is_none() && config.service.enable_thumbs {
to_thumbnail_tx
.send(LoadedImage {
image: image.clone(),
@@ -326,7 +330,7 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
})
.await?;
}
if record.raw_ocr_segments.is_none() && config.enable_ocr {
if record.raw_ocr_segments.is_none() && config.service.enable_ocr {
to_ocr_tx
.send(LoadedImage {
image,
@@ -341,7 +345,7 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
});
// Thumbnail generation
let thumbnail_generation: Option<JoinHandle<Result<()>>> = if config.enable_thumbs {
let thumbnail_generation: Option<JoinHandle<Result<()>>> = if config.service.enable_thumbs {
let config = config.clone();
let pool = pool.clone();
let stream = ReceiverStream::new(to_thumbnail_rx).map(Ok);
@@ -405,7 +409,7 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
};
if resized.len() < image.original_size {
generated_formats.push(format_name.clone());
let thumbnail_path = Path::new(&config.thumbs_path).join(
let thumbnail_path = Path::new(&config.service.thumbs_path).join(
generate_thumbnail_filename(
&image.filename,
format_name,
@@ -438,12 +442,12 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
};
// OCR
let ocr: Option<JoinHandle<Result<()>>> = if config.enable_ocr {
let ocr: Option<JoinHandle<Result<()>>> = if config.service.enable_ocr {
let client = client.clone();
let pool = pool.clone();
let stream = ReceiverStream::new(to_ocr_rx).map(Ok);
Some(tokio::spawn({
stream.try_for_each_concurrent(Some(config.ocr_concurrency), move |image| {
stream.try_for_each_concurrent(Some(config.service.ocr_concurrency), move |image| {
let client = client.clone();
let pool = pool.clone();
async move {
@@ -482,7 +486,7 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
};
let embedding_generation: JoinHandle<Result<()>> = tokio::spawn({
let stream = ReceiverStream::new(to_embed_rx).chunks(backend.batch);
let stream = ReceiverStream::new(to_embed_rx).chunks(config.backend.batch);
let client = client.clone();
let config = config.clone();
let pool = pool.clone();
@@ -494,7 +498,7 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
async move {
let result: Vec<serde_bytes::ByteBuf> = query_clip_server(
&client,
&config,
&config.service,
"",
EmbeddingRequest::Images {
images: batch.iter().map(|input| serde_bytes::ByteBuf::from(input.image.clone())).collect(),
@@ -526,11 +530,11 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
// blocking OS calls
tokio::task::block_in_place(|| -> anyhow::Result<()> {
for entry in WalkDir::new(config.files.as_str()) {
for entry in WalkDir::new(config.service.files.as_str()) {
let entry = entry?;
let path = entry.path();
if path.is_file() {
let filename = path.strip_prefix(&config.files)?.to_str().unwrap().to_string();
let filename = path.strip_prefix(&config.service.files)?.to_str().unwrap().to_string();
let modtime = entry.metadata()?.modified()?.duration_since(std::time::UNIX_EPOCH)?;
let modtime = modtime.as_micros() as i64;
filenames.insert(filename.clone(), (path.to_path_buf(), modtime));
@@ -552,7 +556,7 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
filename: filename.clone(),
..Default::default()
}),
Some(r) if modtime > r.embedding_time.unwrap_or(i64::MIN) || (modtime > r.ocr_time.unwrap_or(i64::MIN) && config.enable_ocr) || (modtime > r.thumbnail_time.unwrap_or(i64::MIN) && config.enable_thumbs) => {
Some(r) if modtime > r.embedding_time.unwrap_or(i64::MIN) || (modtime > r.ocr_time.unwrap_or(i64::MIN) && config.service.enable_ocr) || (modtime > r.thumbnail_time.unwrap_or(i64::MIN) && config.service.enable_thumbs) => {
Some(r)
},
_ => None
@@ -610,11 +614,11 @@ async fn ingest_files(config: Arc<Config>, backend: Arc<InferenceServerConfig>)
const INDEX_ADD_BATCH: usize = 512;
async fn build_index(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -> Result<IIndex> {
let pool = initialize_database(&config).await?;
async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
let pool = initialize_database(&config.service).await?;
let mut index = IIndex {
vectors: scalar_quantizer::ScalarQuantizerIndexImpl::new(backend.embedding_size as u32, scalar_quantizer::QuantizerType::QT_fp16, faiss::MetricType::InnerProduct)?,
vectors: scalar_quantizer::ScalarQuantizerIndexImpl::new(config.backend.embedding_size as u32, scalar_quantizer::QuantizerType::QT_fp16, faiss::MetricType::InnerProduct)?,
filenames: Vec::new(),
format_codes: Vec::new(),
format_names: Vec::new(),
@@ -626,7 +630,7 @@ async fn build_index(config: Arc<Config>, backend: Arc<InferenceServerConfig>) -
index.filenames = Vec::with_capacity(count as usize);
index.format_codes = Vec::with_capacity(count as usize);
let mut buffer = Vec::with_capacity(INDEX_ADD_BATCH * backend.embedding_size as usize);
let mut buffer = Vec::with_capacity(INDEX_ADD_BATCH * config.backend.embedding_size as usize);
index.format_names = Vec::with_capacity(5);
let mut rows = sqlx::query_as::<_, FileRecord>("SELECT * FROM files").fetch(&pool);
@@ -728,14 +732,8 @@ async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize) -> Result
})
}
async fn handle_request(
config: &Config,
backend_config: Arc<InferenceServerConfig>,
client: Arc<Client>,
index: &IIndex,
req: Json<QueryRequest>,
) -> Result<Response<Body>> {
let mut total_embedding = ndarray::Array::from(vec![0.0; backend_config.embedding_size]);
async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IIndex, req: Json<QueryRequest>) -> Result<Response<Body>> {
let mut total_embedding = ndarray::Array::from(vec![0.0; config.backend.embedding_size]);
let mut image_batch = Vec::new();
let mut image_weights = Vec::new();
@@ -747,7 +745,7 @@
TERMS_COUNTER.get_metric_with_label_values(&["image"]).unwrap().inc();
let bytes = BASE64_STANDARD.decode(image)?;
let image = Arc::new(tokio::task::block_in_place(|| image::load_from_memory(&bytes))?);
image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(backend_config.clone(), image).await?));
image_batch.push(serde_bytes::ByteBuf::from(resize_for_embed(config.clone(), image).await?));
image_weights.push(term.weight.unwrap_or(1.0));
}
if let Some(text) = &term.text {
@@ -782,7 +780,7 @@
}
for batch in batches {
let embs: Vec<Vec<u8>> = query_clip_server(&client, config, "/", batch).await?;
let embs: Vec<Vec<u8>> = query_clip_server(&client, &config.service, "/", batch).await?;
for emb in embs {
total_embedding += &ndarray::Array::from_vec(decode_fp16_buffer(&emb));
}
@@ -792,7 +790,7 @@
let qres = query_index(index, total_embedding.to_vec(), k).await?;
let mut extensions = HashMap::new();
for (k, v) in image_formats(config) {
for (k, v) in image_formats(&config.service) {
extensions.insert(k, v.extension);
}
@@ -813,12 +811,12 @@ async fn main() -> Result<()> {
pretty_env_logger::init();
let config_path = std::env::args().nth(1).expect("Missing config file path");
let config: Arc<Config> = Arc::new(serde_json::from_slice(&std::fs::read(config_path)?)?);
let config = serde_json::from_slice(&std::fs::read(config_path)?)?;
let pool = initialize_database(&config).await?;
sqlx::query(SCHEMA).execute(&pool).await?;
let backend = Arc::new(loop {
let backend = loop {
match get_backend_config(&config).await {
Ok(backend) => break backend,
Err(e) => {
@@ -826,30 +824,34 @@ async fn main() -> Result<()> {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
};
let config = Arc::new(WConfig {
service: config,
backend
});
if config.no_run_server {
ingest_files(config.clone(), backend.clone()).await?;
if config.service.no_run_server {
ingest_files(config.clone()).await?;
return Ok(())
}
let (request_ingest_tx, mut request_ingest_rx) = mpsc::channel(1);
let index = Arc::new(tokio::sync::RwLock::new(build_index(config.clone(), backend.clone()).await?));
let index = Arc::new(tokio::sync::RwLock::new(build_index(config.clone()).await?));
let (ingest_done_tx, _ingest_done_rx) = broadcast::channel(1);
let done_tx = Arc::new(ingest_done_tx.clone());
let _ingest_task = tokio::spawn({
let config = config.clone();
let backend = backend.clone();
let index = index.clone();
async move {
loop {
log::info!("Ingest running");
match ingest_files(config.clone(), backend.clone()).await {
match ingest_files(config.clone()).await {
Ok(_) => {
match build_index(config.clone(), backend.clone()).await {
match build_index(config.clone()).await {
Ok(new_index) => {
LAST_INDEX_SIZE.set(new_index.vectors.ntotal() as i64);
*index.write().await = new_index;
@@ -879,11 +881,10 @@ async fn main() -> Result<()> {
let app = Router::new()
.route("/", post(|req| async move {
let config = config.clone();
let backend_config = backend.clone();
let index = index.read().await; // TODO: use ConcurrentIndex here
let client = client.clone();
QUERIES_COUNTER.inc();
handle_request(&config, backend_config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e))
handle_request(config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e))
}))
.route("/", get(|_req: axum::http::Request<axum::body::Body>| async move {
"OK"
@@ -919,7 +920,7 @@ async fn main() -> Result<()> {
}))
.layer(cors);
let addr = format!("0.0.0.0:{}", config_.port);
let addr = format!("0.0.0.0:{}", config_.service.port);
log::info!("Starting server on {}", addr);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app).await?;
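As the `main()` hunks above show, the service reads its config file path from the first command-line argument, so an invocation might look like this (the config file name is a placeholder):

```
cargo run --release -- mse_config.json
```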