1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-01-04 22:40:31 +00:00

Predefined embedding modes in search

This commit is contained in:
osmarks 2024-05-22 20:17:13 +01:00
parent 14387a61a3
commit d8c147df52
5 changed files with 125 additions and 9 deletions

View File

@ -0,0 +1,26 @@
{
"db_name": "SQLite",
"query": "SELECT * FROM predefined_embeddings",
"describe": {
"columns": [
{
"name": "name",
"ordinal": 0,
"type_info": "Text"
},
{
"name": "embedding",
"ordinal": 1,
"type_info": "Blob"
}
],
"parameters": {
"Right": 0
},
"nullable": [
false,
false
]
},
"hash": "d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b"
}

View File

@ -60,6 +60,8 @@
align-items: center align-items: center
> * > *
margin: 0 2px margin: 0 2px
.sliders-ctrl
width: 5em
.result .result
border: 1px solid gray border: 1px solid gray
@ -70,6 +72,9 @@
</style> </style>
<h1>Meme Search Engine</h1> <h1>Meme Search Engine</h1>
{#if config.n_total}
<p>{config.n_total} items indexed.</p>
{/if}
<details> <details>
<summary>Usage tips</summary> <summary>Usage tips</summary>
<ul> <ul>
@ -78,6 +83,7 @@
<li>In certain circumstances, it may be useful to postfix your query with "meme".</li> <li>In certain circumstances, it may be useful to postfix your query with "meme".</li>
<li>Capitalization is ignored.</li> <li>Capitalization is ignored.</li>
<li>Only English is supported. Other languages might work slightly.</li> <li>Only English is supported. Other languages might work slightly.</li>
<li>Sliders are generated from PCA on the index. The human-readable labels are approximate.</li>
</ul> </ul>
</details> </details>
<div class="controls"> <div class="controls">
@ -98,12 +104,21 @@
{#if term.type === "embedding"} {#if term.type === "embedding"}
<span>[embedding loaded from URL]</span> <span>[embedding loaded from URL]</span>
{/if} {/if}
{#if term.type === "predefined_embedding"}
<span>{term.predefinedEmbedding}</span>
{/if}
</li> </li>
{/each} {/each}
</ul> </ul>
<div class="ctrlbar"> <div class="ctrlbar">
<input type="search" placeholder="Text Query" on:keydown={handleKey} on:focus={newTextQuery}> <input type="search" placeholder="Text Query" on:keydown={handleKey} on:focus={newTextQuery}>
<button on:click={pickFile}>Image Query</button> <button on:click={pickFile}>Image Query</button>
<select bind:value={predefinedEmbeddingName} on:change={setPredefinedEmbedding} class="sliders-ctrl">
<option>Sliders</option>
{#each config.predefined_embedding_names ?? [] as name}
<option>{name}</option>
{/each}
</select>
<button on:click={runSearch} style="margin-left: auto">Search</button> <button on:click={runSearch} style="margin-left: auto">Search</button>
</div> </div>
</div> </div>
@ -151,6 +166,25 @@
let queryTerms = [] let queryTerms = []
let queryCounter = 0 let queryCounter = 0
let config = {}
util.serverConfig.subscribe(x => {
config = x
})
let predefinedEmbeddingName = "Sliders"
const setPredefinedEmbedding = () => {
if (predefinedEmbeddingName !== "Sliders") {
queryTerms.push({
type: "predefined_embedding",
predefinedEmbedding: predefinedEmbeddingName,
sign: "+",
weight: 0.2
})
}
queryTerms = queryTerms
predefinedEmbeddingName = "Sliders"
}
const decodeFloat16 = uint16 => { const decodeFloat16 = uint16 => {
const sign = (uint16 & 0x8000) ? -1 : 1 const sign = (uint16 & 0x8000) ? -1 : 1
const exponent = (uint16 & 0x7C00) >> 10 const exponent = (uint16 & 0x7C00) >> 10
@ -200,7 +234,7 @@
let displayedResults = [] let displayedResults = []
const runSearch = async () => { const runSearch = async () => {
if (!resultPromise) { if (!resultPromise) {
let args = {"terms": queryTerms.map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))} let args = {"terms": queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
queryCounter += 1 queryCounter += 1
resultPromise = util.doQuery(args).then(res => { resultPromise = util.doQuery(args).then(res => {
error = null error = null

View File

@ -1,4 +1,5 @@
import * as config from "../../frontend_config.json" import * as config from "../../frontend_config.json"
import { writable } from "svelte/store"
export const getURL = x => config.image_path + x[1] export const getURL = x => config.image_path + x[1]
@ -17,3 +18,9 @@ export const hasFormat = (results, result, format) => {
export const thumbnailURL = (results, result, format) => { export const thumbnailURL = (results, result, format) => {
return `${config.thumb_path}${result[2]}${format}.${results.extensions[format]}` return `${config.thumb_path}${result[2]}${format}.${results.extensions[format]}`
} }
export let serverConfig = writable({})
fetch(config.backend_url).then(x => x.json().then(x => {
serverConfig.set(x)
window.serverConfig = x
}))

14
load_embedding.py Normal file
View File

@ -0,0 +1,14 @@
import numpy as np
import sqlite3
import base64
#db = sqlite3.connect("/srv/mse/data.sqlite3")
db = sqlite3.connect("data.sqlite3")
db.row_factory = sqlite3.Row
name = input("Name: ")
url = input("Embedding search URL: ")
data = base64.urlsafe_b64decode(url.removeprefix("https://mse.osmarks.net/?e="))
arr = np.frombuffer(data, dtype=np.float16).copy()
db.execute("INSERT OR REPLACE INTO predefined_embeddings VALUES (?, ?)", (name, arr.tobytes()))
db.commit()

View File

@ -27,6 +27,7 @@ use tower_http::cors::CorsLayer;
use faiss::index::scalar_quantizer; use faiss::index::scalar_quantizer;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec}; use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
use ndarray::ArrayBase;
mod ocr; mod ocr;
@ -93,6 +94,11 @@ CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
content='files' content='files'
); );
CREATE TABLE IF NOT EXISTS predefined_embeddings (
name TEXT NOT NULL PRIMARY KEY,
embedding BLOB NOT NULL
);
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, '')); INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
END; END;
@ -127,10 +133,11 @@ struct InferenceServerConfig {
embedding_size: usize, embedding_size: usize,
} }
#[derive(Debug, Deserialize, Clone)] #[derive(Debug, Clone)]
struct WConfig { struct WConfig {
backend: InferenceServerConfig, backend: InferenceServerConfig,
service: Config service: Config,
predefined_embeddings: HashMap<String, ArrayBase<ndarray::OwnedRepr<f32>, ndarray::prelude::Dim<[usize; 1]>>>
} }
async fn query_clip_server<I, O>( async fn query_clip_server<I, O>(
@ -699,6 +706,7 @@ struct QueryTerm {
embedding: Option<EmbeddingVector>, embedding: Option<EmbeddingVector>,
image: Option<String>, image: Option<String>,
text: Option<String>, text: Option<String>,
predefined_embedding: Option<String>,
weight: Option<f32>, weight: Option<f32>,
} }
@ -760,6 +768,10 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
total_embedding[i] += value * weight; total_embedding[i] += value * weight;
} }
} }
if let Some(name) = &term.predefined_embedding {
let embedding = config.predefined_embeddings.get(name).context("name invalid")?;
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
}
} }
let mut batches = vec![]; let mut batches = vec![];
@ -806,6 +818,12 @@ async fn get_backend_config(config: &Config) -> Result<InferenceServerConfig> {
Ok(rmp_serde::from_slice(&res.bytes().await?)?) Ok(rmp_serde::from_slice(&res.bytes().await?)?)
} }
#[derive(Serialize, Deserialize)]
struct FrontendInit {
n_total: u64,
predefined_embedding_names: Vec<String>
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
pretty_env_logger::init(); pretty_env_logger::init();
@ -826,9 +844,21 @@ async fn main() -> Result<()> {
} }
}; };
let mut predefined_embeddings = HashMap::new();
{
let db = initialize_database(&config).await?;
let result = sqlx::query!("SELECT * FROM predefined_embeddings")
.fetch_all(&db).await?;
for row in result {
predefined_embeddings.insert(row.name, ndarray::Array::from(decode_fp16_buffer(&row.embedding)));
}
}
let config = Arc::new(WConfig { let config = Arc::new(WConfig {
service: config, service: config,
backend backend,
predefined_embeddings
}); });
if config.service.no_run_server { if config.service.no_run_server {
@ -878,6 +908,8 @@ async fn main() -> Result<()> {
let config_ = config.clone(); let config_ = config.clone();
let client = Arc::new(Client::new()); let client = Arc::new(Client::new());
let index_ = index.clone();
let config__ = config.clone();
let app = Router::new() let app = Router::new()
.route("/", post(|req| async move { .route("/", post(|req| async move {
let config = config.clone(); let config = config.clone();
@ -886,10 +918,13 @@ async fn main() -> Result<()> {
QUERIES_COUNTER.inc(); QUERIES_COUNTER.inc();
handle_request(config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e)) handle_request(config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e))
})) }))
.route("/", get(|_req: axum::http::Request<axum::body::Body>| async move { .route("/", get(|_req: ()| async move {
"OK" Json(FrontendInit {
n_total: index_.read().await.vectors.ntotal(),
predefined_embedding_names: config__.predefined_embeddings.keys().cloned().collect()
})
})) }))
.route("/reload", post(|_req: axum::http::Request<axum::body::Body>| async move { .route("/reload", post(|_req: ()| async move {
log::info!("Requesting index reload"); log::info!("Requesting index reload");
let mut done_rx = done_tx.clone().subscribe(); let mut done_rx = done_tx.clone().subscribe();
let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full
@ -911,7 +946,7 @@ async fn main() -> Result<()> {
} }
} }
})) }))
.route("/metrics", get(|_req: axum::http::Request<axum::body::Body>| async move { .route("/metrics", get(|_req: ()| async move {
let mut buffer = Vec::new(); let mut buffer = Vec::new();
let encoder = prometheus::TextEncoder::new(); let encoder = prometheus::TextEncoder::new();
let metric_families = prometheus::gather(); let metric_families = prometheus::gather();