diff --git a/.sqlx/query-d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b.json b/.sqlx/query-d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b.json
new file mode 100644
index 0000000..d7d6101
--- /dev/null
+++ b/.sqlx/query-d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b.json
@@ -0,0 +1,26 @@
+{
+  "db_name": "SQLite",
+  "query": "SELECT * FROM predefined_embeddings",
+  "describe": {
+    "columns": [
+      {
+        "name": "name",
+        "ordinal": 0,
+        "type_info": "Text"
+      },
+      {
+        "name": "embedding",
+        "ordinal": 1,
+        "type_info": "Blob"
+      }
+    ],
+    "parameters": {
+      "Right": 0
+    },
+    "nullable": [
+      false,
+      false
+    ]
+  },
+  "hash": "d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b"
+}
diff --git a/clipfront2/src/App.svelte b/clipfront2/src/App.svelte
index b0df369..a2458ba 100644
--- a/clipfront2/src/App.svelte
+++ b/clipfront2/src/App.svelte
@@ -60,6 +60,8 @@
         align-items: center
         > *
             margin: 0 2px
+    .sliders-ctrl
+        width: 5em
     .result
         border: 1px solid gray
 
@@ -70,6 +72,9 @@

 Meme Search Engine
 
+{#if config.n_total}
+    {config.n_total} items indexed.
+{/if}
 
 Usage tips
@@ -98,12 +104,21 @@
     {#if term.type === "embedding"}
         [embedding loaded from URL]
     {/if}
+    {#if term.type === "predefined_embedding"}
+        {term.predefinedEmbedding}
+    {/if}
 {/each}
+
@@ -151,6 +166,25 @@
     let queryTerms = []
     let queryCounter = 0
 
+    let config = {}
+    util.serverConfig.subscribe(x => {
+        config = x
+    })
+    let predefinedEmbeddingName = "Sliders"
+
+    const setPredefinedEmbedding = () => {
+        if (predefinedEmbeddingName !== "Sliders") {
+            queryTerms.push({
+                type: "predefined_embedding",
+                predefinedEmbedding: predefinedEmbeddingName,
+                sign: "+",
+                weight: 0.2
+            })
+        }
+        queryTerms = queryTerms
+        predefinedEmbeddingName = "Sliders"
+    }
+
     const decodeFloat16 = uint16 => {
         const sign = (uint16 & 0x8000) ? -1 : 1
         const exponent = (uint16 & 0x7C00) >> 10
@@ -200,7 +234,7 @@
     let displayedResults = []
     const runSearch = async () => {
         if (!resultPromise) {
-            let args = {"terms": queryTerms.map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
+            let args = {"terms": queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
             queryCounter += 1
             resultPromise = util.doQuery(args).then(res => {
                 error = null
diff --git a/clipfront2/src/util.js b/clipfront2/src/util.js
index 515ca5e..324d08d 100644
--- a/clipfront2/src/util.js
+++ b/clipfront2/src/util.js
@@ -1,4 +1,5 @@
 import * as config from "../../frontend_config.json"
+import { writable } from "svelte/store"
 
 export const getURL = x => config.image_path + x[1]
 
@@ -16,4 +17,10 @@ export const hasFormat = (results, result, format) => {
 
 export const thumbnailURL = (results, result, format) => {
     return `${config.thumb_path}${result[2]}${format}.${results.extensions[format]}`
-}
\ No newline at end of file
+}
+
+export let serverConfig = writable({})
+fetch(config.backend_url).then(x => x.json().then(x => {
+    serverConfig.set(x)
+    window.serverConfig = x
+}))
\ No newline at end of file
diff --git a/load_embedding.py b/load_embedding.py
new file mode 100644
index 0000000..b96fc1a
--- /dev/null
+++ b/load_embedding.py
@@ -0,0 +1,14 @@
+import numpy as np
+import sqlite3
+import base64
+
+#db = sqlite3.connect("/srv/mse/data.sqlite3")
+db = sqlite3.connect("data.sqlite3")
+db.row_factory = sqlite3.Row
+
+name = input("Name: ")
+url = input("Embedding search URL: ")
+data = base64.urlsafe_b64decode(url.removeprefix("https://mse.osmarks.net/?e="))
+arr = np.frombuffer(data, dtype=np.float16).copy()
+db.execute("INSERT OR REPLACE INTO predefined_embeddings VALUES (?, ?)", (name, arr.tobytes()))
+db.commit()
\ No newline at end of file
diff --git a/src/main.rs b/src/main.rs
index d5de134..d4379e3 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -27,6 +27,7 @@ use tower_http::cors::CorsLayer;
 use faiss::index::scalar_quantizer;
 use lazy_static::lazy_static;
 use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
+use ndarray::ArrayBase;
 
 mod ocr;
 
@@ -93,6 +94,11 @@ CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
     content='files'
 );
 
+CREATE TABLE IF NOT EXISTS predefined_embeddings (
+    name TEXT NOT NULL PRIMARY KEY,
+    embedding BLOB NOT NULL
+);
+
 CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
     INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
 END;
@@ -127,10 +133,11 @@ struct InferenceServerConfig {
     embedding_size: usize,
 }
 
-#[derive(Debug, Deserialize, Clone)]
+#[derive(Debug, Clone)]
 struct WConfig {
     backend: InferenceServerConfig,
-    service: Config
+    service: Config,
+    predefined_embeddings: HashMap<String, ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::prelude::Dim<[usize; 1]>>>
 }
 
 async fn query_clip_server(
@@ -699,6 +706,7 @@ struct QueryTerm {
     embedding: Option<Vec<f32>>,
     image: Option<String>,
     text: Option<String>,
+    predefined_embedding: Option<String>,
     weight: Option<f32>,
 }
 
@@ -760,6 +768,10 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
                 total_embedding[i] += value * weight;
             }
         }
+        if let Some(name) = &term.predefined_embedding {
+            let embedding = config.predefined_embeddings.get(name).context("name invalid")?;
+            total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
+        }
     }
 
     let mut batches = vec![];
@@ -806,6 +818,12 @@ async fn get_backend_config(config: &Config) -> Result<InferenceServerConfig> {
     Ok(rmp_serde::from_slice(&res.bytes().await?)?)
 }
 
+#[derive(Serialize, Deserialize)]
+struct FrontendInit {
+    n_total: u64,
+    predefined_embedding_names: Vec<String>
+}
+
 #[tokio::main]
 async fn main() -> Result<()> {
     pretty_env_logger::init();
@@ -826,9 +844,21 @@ async fn main() -> Result<()> {
         }
     };
 
+    let mut predefined_embeddings = HashMap::new();
+
+    {
+        let db = initialize_database(&config).await?;
+        let result = sqlx::query!("SELECT * FROM predefined_embeddings")
+            .fetch_all(&db).await?;
+        for row in result {
+            predefined_embeddings.insert(row.name, ndarray::Array::from(decode_fp16_buffer(&row.embedding)));
+        }
+    }
+
     let config = Arc::new(WConfig {
         service: config,
-        backend
+        backend,
+        predefined_embeddings
     });
 
     if config.service.no_run_server {
@@ -878,6 +908,8 @@
     let config_ = config.clone();
     let client = Arc::new(Client::new());
 
+    let index_ = index.clone();
+    let config__ = config.clone();
     let app = Router::new()
         .route("/", post(|req| async move {
             let config = config.clone();
@@ -886,10 +918,13 @@
             QUERIES_COUNTER.inc();
             handle_request(config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e))
         }))
-        .route("/", get(|_req: axum::http::Request| async move {
-            "OK"
+        .route("/", get(|_req: ()| async move {
+            Json(FrontendInit {
+                n_total: index_.read().await.vectors.ntotal(),
+                predefined_embedding_names: config__.predefined_embeddings.keys().cloned().collect()
+            })
         }))
-        .route("/reload", post(|_req: axum::http::Request| async move {
+        .route("/reload", post(|_req: ()| async move {
             log::info!("Requesting index reload");
             let mut done_rx = done_tx.clone().subscribe();
             let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full
@@ -911,7 +946,7 @@
                 }
             }
         }))
-        .route("/metrics", get(|_req: axum::http::Request| async move {
+        .route("/metrics", get(|_req: ()| async move {
             let mut buffer = Vec::new();
             let encoder = prometheus::TextEncoder::new();
             let metric_families = prometheus::gather();
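
For reference, a minimal sketch (not part of the diff) of how a client could exercise the two endpoints this change touches: GET /, which now returns FrontendInit instead of "OK", and POST /, which now accepts a predefined_embedding query term. The base URL and the example term values are assumptions; the frontend reads the real URL from frontend_config.json's backend_url, and the available names are whatever load_embedding.py has inserted.

import requests

BACKEND = "http://localhost:1707"  # assumed; use the backend_url from frontend_config.json

# GET / now returns FrontendInit: the number of indexed items and the names of
# the predefined embeddings loaded from SQLite at startup.
init = requests.get(BACKEND).json()
print(init["n_total"], init["predefined_embedding_names"])

# POST / accepts a term that references a stored embedding by name; handle_request
# adds weight * embedding to the query vector alongside ordinary text/image terms.
if init["predefined_embedding_names"]:
    args = {
        "terms": [
            {"text": "cat", "weight": 1.0},
            {"predefined_embedding": init["predefined_embedding_names"][0], "weight": 0.2},
        ]
    }
    print(requests.post(BACKEND, json=args).json())

Note that main() reads the predefined_embeddings table once before the server starts, so embeddings added with load_embedding.py only show up after a restart; the /reload route re-ingests files but does not re-read this table.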