diff --git a/clip_server.py b/clip_server.py index 02b1ac2..bc3bc04 100644 --- a/clip_server.py +++ b/clip_server.py @@ -20,7 +20,7 @@ with open(sys.argv[1], "r") as config_file: CONFIG = json.load(config_file) device = torch.device(CONFIG["device"]) -model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=CONFIG.get("model_path", dict(open_clip.list_pretrained())[CONFIG["model"]], precision="fp16")) +model, _, preprocess = open_clip.create_model_and_transforms(CONFIG["model"], device=device, pretrained=CONFIG.get("model_path", dict(open_clip.list_pretrained())[CONFIG["model"]]), precision="fp16") model.eval() tokenizer = open_clip.get_tokenizer(CONFIG["model"]) print("Model loaded") diff --git a/src/main.rs b/src/main.rs index 46873ca..6c728d1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,6 +33,7 @@ use lazy_static::lazy_static; use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec}; use tracing::instrument; use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine}; +use mimalloc::MiMalloc; mod ocr; mod common; @@ -41,6 +42,9 @@ mod video_reader; use crate::ocr::scan_image; use crate::common::{InferenceServerConfig, resize_for_embed, EmbeddingRequest, get_backend_config, query_clip_server, decode_fp16_buffer, QueryRequest, QueryResult, EmbeddingVector}; +#[global_allocator] +static GLOBAL: MiMalloc = MiMalloc; + lazy_static! { static ref RELOADS_COUNTER: IntCounter = register_int_counter!("mse_reloads", "reloads executed").unwrap(); static ref QUERIES_COUNTER: IntCounter = register_int_counter!("mse_queries", "queries executed").unwrap(); @@ -808,7 +812,7 @@ async fn ingest_files(config: Arc) -> Result<()> { Result::Ok(()) } -const INDEX_ADD_BATCH: usize = 512; +const INDEX_ADD_BATCH: usize = 1024; #[instrument] async fn build_index(config: Arc) -> Result { @@ -962,7 +966,7 @@ async fn handle_request(config: Arc, client: Arc, index: &IInde #[tokio::main] async fn main() -> Result<()> { - console_subscriber::init(); + tracing_subscriber::fmt().init(); let config_path = std::env::args().nth(1).expect("Missing config file path"); let config: Config = serde_json::from_slice(&std::fs::read(config_path)?)?;