diff --git a/.sqlx/query-d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b.json b/.sqlx/query-d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b.json
new file mode 100644
index 0000000..d7d6101
--- /dev/null
+++ b/.sqlx/query-d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b.json
@@ -0,0 +1,26 @@
+{
+ "db_name": "SQLite",
+ "query": "SELECT * FROM predefined_embeddings",
+ "describe": {
+ "columns": [
+ {
+ "name": "name",
+ "ordinal": 0,
+ "type_info": "Text"
+ },
+ {
+ "name": "embedding",
+ "ordinal": 1,
+ "type_info": "Blob"
+ }
+ ],
+ "parameters": {
+ "Right": 0
+ },
+ "nullable": [
+ false,
+ false
+ ]
+ },
+ "hash": "d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b"
+}
diff --git a/clipfront2/src/App.svelte b/clipfront2/src/App.svelte
index b0df369..a2458ba 100644
--- a/clipfront2/src/App.svelte
+++ b/clipfront2/src/App.svelte
@@ -60,6 +60,8 @@
align-items: center
> *
margin: 0 2px
+ .sliders-ctrl
+ width: 5em
.result
border: 1px solid gray
@@ -70,6 +72,9 @@
Meme Search Engine
+{#if config.n_total}
+ {config.n_total} items indexed.
+{/if}
Usage tips
@@ -78,6 +83,7 @@
- In certain circumstances, it may be useful to postfix your query with "meme".
- Capitalization is ignored.
- Only English is supported. Other languages might work slightly.
+ - Sliders are generated from PCA on the index. The human-readable labels are approximate.
@@ -98,12 +104,21 @@
{#if term.type === "embedding"}
[embedding loaded from URL]
{/if}
+ {#if term.type === "predefined_embedding"}
+
+ {term.predefinedEmbedding}
+ {/if}
{/each}
+
@@ -151,6 +166,25 @@
let queryTerms = []
let queryCounter = 0
+ let config = {}
+ util.serverConfig.subscribe(x => {
+ config = x
+ })
+ let predefinedEmbeddingName = "Sliders"
+
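+ // "Sliders" is the placeholder entry; choosing any other name adds that embedding as a query term.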
+ const setPredefinedEmbedding = () => {
+ if (predefinedEmbeddingName !== "Sliders") {
+ queryTerms.push({
+ type: "predefined_embedding",
+ predefinedEmbedding: predefinedEmbeddingName,
+ sign: "+",
+ weight: 0.2
+ })
+ }
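+ // Self-assignment tells Svelte the array changed so the new term is rendered.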
+ queryTerms = queryTerms
+ predefinedEmbeddingName = "Sliders"
+ }
+
const decodeFloat16 = uint16 => {
const sign = (uint16 & 0x8000) ? -1 : 1
const exponent = (uint16 & 0x7C00) >> 10
@@ -200,7 +234,7 @@
let displayedResults = []
const runSearch = async () => {
if (!resultPromise) {
- let args = {"terms": queryTerms.map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
+ let args = {"terms": queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
queryCounter += 1
resultPromise = util.doQuery(args).then(res => {
error = null
diff --git a/clipfront2/src/util.js b/clipfront2/src/util.js
index 515ca5e..324d08d 100644
--- a/clipfront2/src/util.js
+++ b/clipfront2/src/util.js
@@ -1,4 +1,5 @@
import * as config from "../../frontend_config.json"
+import { writable } from "svelte/store"
export const getURL = x => config.image_path + x[1]
@@ -16,4 +17,10 @@ export const hasFormat = (results, result, format) => {
export const thumbnailURL = (results, result, format) => {
return `${config.thumb_path}${result[2]}${format}.${results.extensions[format]}`
-}
\ No newline at end of file
+}
+
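+// The backend's GET / now returns FrontendInit (n_total, predefined_embedding_names); expose it as a store.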
+export let serverConfig = writable({})
+fetch(config.backend_url).then(x => x.json().then(x => {
+ serverConfig.set(x)
+ window.serverConfig = x
+}))
\ No newline at end of file
diff --git a/load_embedding.py b/load_embedding.py
new file mode 100644
index 0000000..b96fc1a
--- /dev/null
+++ b/load_embedding.py
@@ -0,0 +1,14 @@
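+# Register a named "predefined embedding" (e.g. a PCA slider direction) in the index database.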
+import numpy as np
+import sqlite3
+import base64
+
+#db = sqlite3.connect("/srv/mse/data.sqlite3")
+db = sqlite3.connect("data.sqlite3")
+db.row_factory = sqlite3.Row
+
+name = input("Name: ")
+url = input("Embedding search URL: ")
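+# The ?e= parameter of a share URL is a urlsafe-base64-encoded float16 embedding vector.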
+data = base64.urlsafe_b64decode(url.removeprefix("https://mse.osmarks.net/?e="))
+arr = np.frombuffer(data, dtype=np.float16).copy()
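+# Keep the raw float16 bytes; the server decodes them with decode_fp16_buffer at startup.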
+db.execute("INSERT OR REPLACE INTO predefined_embeddings VALUES (?, ?)", (name, arr.tobytes()))
+db.commit()
\ No newline at end of file
diff --git a/src/main.rs b/src/main.rs
index d5de134..d4379e3 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -27,6 +27,7 @@ use tower_http::cors::CorsLayer;
use faiss::index::scalar_quantizer;
use lazy_static::lazy_static;
use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
+use ndarray::ArrayBase;
mod ocr;
@@ -93,6 +94,11 @@ CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
content='files'
);
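+-- Embeddings added out-of-band (see load_embedding.py), exposed to the frontend by name.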
+CREATE TABLE IF NOT EXISTS predefined_embeddings (
+ name TEXT NOT NULL PRIMARY KEY,
+ embedding BLOB NOT NULL
+);
+
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END;
@@ -127,10 +133,11 @@ struct InferenceServerConfig {
embedding_size: usize,
}
-#[derive(Debug, Deserialize, Clone)]
+#[derive(Debug, Clone)]
struct WConfig {
backend: InferenceServerConfig,
- service: Config
+ service: Config,
+ predefined_embeddings: HashMap<String, ArrayBase<ndarray::OwnedRepr<f32>, ndarray::prelude::Dim<[usize; 1]>>>
}
async fn query_clip_server(
@@ -699,6 +706,7 @@ struct QueryTerm {
 embedding: Option<Vec<f32>>,
 image: Option<String>,
 text: Option<String>,
+ predefined_embedding: Option<String>,
 weight: Option<f32>,
}
@@ -760,6 +768,10 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
total_embedding[i] += value * weight;
}
}
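+ // Add any named predefined embedding directly to the query vector, scaled by the term weight.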
+ if let Some(name) = &term.predefined_embedding {
+ let embedding = config.predefined_embeddings.get(name).context("name invalid")?;
+ total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
+ }
}
let mut batches = vec![];
@@ -806,6 +818,12 @@ async fn get_backend_config(config: &Config) -> Result<InferenceServerConfig> {
Ok(rmp_serde::from_slice(&res.bytes().await?)?)
}
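+// Returned from GET /; consumed by the frontend's serverConfig store in util.js.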
+#[derive(Serialize, Deserialize)]
+struct FrontendInit {
+ n_total: u64,
+ predefined_embedding_names: Vec<String>
+}
+
#[tokio::main]
async fn main() -> Result<()> {
pretty_env_logger::init();
@@ -826,9 +844,21 @@ async fn main() -> Result<()> {
}
};
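+ // Load predefined embeddings (written by load_embedding.py) from the database at startup.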
+ let mut predefined_embeddings = HashMap::new();
+
+ {
+ let db = initialize_database(&config).await?;
+ let result = sqlx::query!("SELECT * FROM predefined_embeddings")
+ .fetch_all(&db).await?;
+ for row in result {
+ predefined_embeddings.insert(row.name, ndarray::Array::from(decode_fp16_buffer(&row.embedding)));
+ }
+ }
+
let config = Arc::new(WConfig {
service: config,
- backend
+ backend,
+ predefined_embeddings
});
if config.service.no_run_server {
@@ -878,6 +908,8 @@ async fn main() -> Result<()> {
let config_ = config.clone();
let client = Arc::new(Client::new());
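+ // Extra handles for the new GET / handler, which reports index size and predefined embedding names.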
+ let index_ = index.clone();
+ let config__ = config.clone();
let app = Router::new()
.route("/", post(|req| async move {
let config = config.clone();
@@ -886,10 +918,13 @@ async fn main() -> Result<()> {
QUERIES_COUNTER.inc();
handle_request(config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e))
}))
- .route("/", get(|_req: axum::http::Request| async move {
- "OK"
+ .route("/", get(|_req: ()| async move {
+ Json(FrontendInit {
+ n_total: index_.read().await.vectors.ntotal(),
+ predefined_embedding_names: config__.predefined_embeddings.keys().cloned().collect()
+ })
}))
- .route("/reload", post(|_req: axum::http::Request| async move {
+ .route("/reload", post(|_req: ()| async move {
log::info!("Requesting index reload");
let mut done_rx = done_tx.clone().subscribe();
let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full
@@ -911,7 +946,7 @@ async fn main() -> Result<()> {
}
}
}))
- .route("/metrics", get(|_req: axum::http::Request| async move {
+ .route("/metrics", get(|_req: ()| async move {
let mut buffer = Vec::new();
let encoder = prometheus::TextEncoder::new();
let metric_families = prometheus::gather();