mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-12-30 12:00:31 +00:00
Predefined embedding modes in search
This commit is contained in:
parent
14387a61a3
commit
d8c147df52
26
.sqlx/query-d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b.json
generated
Normal file
26
.sqlx/query-d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b.json
generated
Normal file
@ -0,0 +1,26 @@
|
||||
{
|
||||
"db_name": "SQLite",
|
||||
"query": "SELECT * FROM predefined_embeddings",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "name",
|
||||
"ordinal": 0,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "embedding",
|
||||
"ordinal": 1,
|
||||
"type_info": "Blob"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 0
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b"
|
||||
}
|
@ -60,6 +60,8 @@
|
||||
align-items: center
|
||||
> *
|
||||
margin: 0 2px
|
||||
.sliders-ctrl
|
||||
width: 5em
|
||||
|
||||
.result
|
||||
border: 1px solid gray
|
||||
@ -70,6 +72,9 @@
|
||||
</style>
|
||||
|
||||
<h1>Meme Search Engine</h1>
|
||||
{#if config.n_total}
|
||||
<p>{config.n_total} items indexed.</p>
|
||||
{/if}
|
||||
<details>
|
||||
<summary>Usage tips</summary>
|
||||
<ul>
|
||||
@ -78,6 +83,7 @@
|
||||
<li>In certain circumstances, it may be useful to postfix your query with "meme".</li>
|
||||
<li>Capitalization is ignored.</li>
|
||||
<li>Only English is supported. Other languages might work slightly.</li>
|
||||
<li>Sliders are generated from PCA on the index. The human-readable labels are approximate.</li>
|
||||
</ul>
|
||||
</details>
|
||||
<div class="controls">
|
||||
@ -98,12 +104,21 @@
|
||||
{#if term.type === "embedding"}
|
||||
<span>[embedding loaded from URL]</span>
|
||||
{/if}
|
||||
{#if term.type === "predefined_embedding"}
|
||||
<span>{term.predefinedEmbedding}</span>
|
||||
{/if}
|
||||
</li>
|
||||
{/each}
|
||||
</ul>
|
||||
<div class="ctrlbar">
|
||||
<input type="search" placeholder="Text Query" on:keydown={handleKey} on:focus={newTextQuery}>
|
||||
<button on:click={pickFile}>Image Query</button>
|
||||
<select bind:value={predefinedEmbeddingName} on:change={setPredefinedEmbedding} class="sliders-ctrl">
|
||||
<option>Sliders</option>
|
||||
{#each config.predefined_embedding_names ?? [] as name}
|
||||
<option>{name}</option>
|
||||
{/each}
|
||||
</select>
|
||||
<button on:click={runSearch} style="margin-left: auto">Search</button>
|
||||
</div>
|
||||
</div>
|
||||
@ -151,6 +166,25 @@
|
||||
let queryTerms = []
|
||||
let queryCounter = 0
|
||||
|
||||
let config = {}
|
||||
util.serverConfig.subscribe(x => {
|
||||
config = x
|
||||
})
|
||||
let predefinedEmbeddingName = "Sliders"
|
||||
|
||||
const setPredefinedEmbedding = () => {
|
||||
if (predefinedEmbeddingName !== "Sliders") {
|
||||
queryTerms.push({
|
||||
type: "predefined_embedding",
|
||||
predefinedEmbedding: predefinedEmbeddingName,
|
||||
sign: "+",
|
||||
weight: 0.2
|
||||
})
|
||||
}
|
||||
queryTerms = queryTerms
|
||||
predefinedEmbeddingName = "Sliders"
|
||||
}
|
||||
|
||||
const decodeFloat16 = uint16 => {
|
||||
const sign = (uint16 & 0x8000) ? -1 : 1
|
||||
const exponent = (uint16 & 0x7C00) >> 10
|
||||
@ -200,7 +234,7 @@
|
||||
let displayedResults = []
|
||||
const runSearch = async () => {
|
||||
if (!resultPromise) {
|
||||
let args = {"terms": queryTerms.map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
|
||||
let args = {"terms": queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
|
||||
queryCounter += 1
|
||||
resultPromise = util.doQuery(args).then(res => {
|
||||
error = null
|
||||
|
@ -1,4 +1,5 @@
|
||||
import * as config from "../../frontend_config.json"
|
||||
import { writable } from "svelte/store"
|
||||
|
||||
export const getURL = x => config.image_path + x[1]
|
||||
|
||||
@ -16,4 +17,10 @@ export const hasFormat = (results, result, format) => {
|
||||
|
||||
export const thumbnailURL = (results, result, format) => {
|
||||
return `${config.thumb_path}${result[2]}${format}.${results.extensions[format]}`
|
||||
}
|
||||
}
|
||||
|
||||
export let serverConfig = writable({})
|
||||
fetch(config.backend_url).then(x => x.json().then(x => {
|
||||
serverConfig.set(x)
|
||||
window.serverConfig = x
|
||||
}))
|
14
load_embedding.py
Normal file
14
load_embedding.py
Normal file
@ -0,0 +1,14 @@
|
||||
import numpy as np
|
||||
import sqlite3
|
||||
import base64
|
||||
|
||||
#db = sqlite3.connect("/srv/mse/data.sqlite3")
|
||||
db = sqlite3.connect("data.sqlite3")
|
||||
db.row_factory = sqlite3.Row
|
||||
|
||||
name = input("Name: ")
|
||||
url = input("Embedding search URL: ")
|
||||
data = base64.urlsafe_b64decode(url.removeprefix("https://mse.osmarks.net/?e="))
|
||||
arr = np.frombuffer(data, dtype=np.float16).copy()
|
||||
db.execute("INSERT OR REPLACE INTO predefined_embeddings VALUES (?, ?)", (name, arr.tobytes()))
|
||||
db.commit()
|
49
src/main.rs
49
src/main.rs
@ -27,6 +27,7 @@ use tower_http::cors::CorsLayer;
|
||||
use faiss::index::scalar_quantizer;
|
||||
use lazy_static::lazy_static;
|
||||
use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
|
||||
use ndarray::ArrayBase;
|
||||
|
||||
mod ocr;
|
||||
|
||||
@ -93,6 +94,11 @@ CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
|
||||
content='files'
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS predefined_embeddings (
|
||||
name TEXT NOT NULL PRIMARY KEY,
|
||||
embedding BLOB NOT NULL
|
||||
);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
|
||||
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
|
||||
END;
|
||||
@ -127,10 +133,11 @@ struct InferenceServerConfig {
|
||||
embedding_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct WConfig {
|
||||
backend: InferenceServerConfig,
|
||||
service: Config
|
||||
service: Config,
|
||||
predefined_embeddings: HashMap<String, ArrayBase<ndarray::OwnedRepr<f32>, ndarray::prelude::Dim<[usize; 1]>>>
|
||||
}
|
||||
|
||||
async fn query_clip_server<I, O>(
|
||||
@ -699,6 +706,7 @@ struct QueryTerm {
|
||||
embedding: Option<EmbeddingVector>,
|
||||
image: Option<String>,
|
||||
text: Option<String>,
|
||||
predefined_embedding: Option<String>,
|
||||
weight: Option<f32>,
|
||||
}
|
||||
|
||||
@ -760,6 +768,10 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
|
||||
total_embedding[i] += value * weight;
|
||||
}
|
||||
}
|
||||
if let Some(name) = &term.predefined_embedding {
|
||||
let embedding = config.predefined_embeddings.get(name).context("name invalid")?;
|
||||
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
|
||||
}
|
||||
}
|
||||
|
||||
let mut batches = vec![];
|
||||
@ -806,6 +818,12 @@ async fn get_backend_config(config: &Config) -> Result<InferenceServerConfig> {
|
||||
Ok(rmp_serde::from_slice(&res.bytes().await?)?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct FrontendInit {
|
||||
n_total: u64,
|
||||
predefined_embedding_names: Vec<String>
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
pretty_env_logger::init();
|
||||
@ -826,9 +844,21 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
};
|
||||
|
||||
let mut predefined_embeddings = HashMap::new();
|
||||
|
||||
{
|
||||
let db = initialize_database(&config).await?;
|
||||
let result = sqlx::query!("SELECT * FROM predefined_embeddings")
|
||||
.fetch_all(&db).await?;
|
||||
for row in result {
|
||||
predefined_embeddings.insert(row.name, ndarray::Array::from(decode_fp16_buffer(&row.embedding)));
|
||||
}
|
||||
}
|
||||
|
||||
let config = Arc::new(WConfig {
|
||||
service: config,
|
||||
backend
|
||||
backend,
|
||||
predefined_embeddings
|
||||
});
|
||||
|
||||
if config.service.no_run_server {
|
||||
@ -878,6 +908,8 @@ async fn main() -> Result<()> {
|
||||
|
||||
let config_ = config.clone();
|
||||
let client = Arc::new(Client::new());
|
||||
let index_ = index.clone();
|
||||
let config__ = config.clone();
|
||||
let app = Router::new()
|
||||
.route("/", post(|req| async move {
|
||||
let config = config.clone();
|
||||
@ -886,10 +918,13 @@ async fn main() -> Result<()> {
|
||||
QUERIES_COUNTER.inc();
|
||||
handle_request(config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e))
|
||||
}))
|
||||
.route("/", get(|_req: axum::http::Request<axum::body::Body>| async move {
|
||||
"OK"
|
||||
.route("/", get(|_req: ()| async move {
|
||||
Json(FrontendInit {
|
||||
n_total: index_.read().await.vectors.ntotal(),
|
||||
predefined_embedding_names: config__.predefined_embeddings.keys().cloned().collect()
|
||||
})
|
||||
}))
|
||||
.route("/reload", post(|_req: axum::http::Request<axum::body::Body>| async move {
|
||||
.route("/reload", post(|_req: ()| async move {
|
||||
log::info!("Requesting index reload");
|
||||
let mut done_rx = done_tx.clone().subscribe();
|
||||
let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full
|
||||
@ -911,7 +946,7 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
}))
|
||||
.route("/metrics", get(|_req: axum::http::Request<axum::body::Body>| async move {
|
||||
.route("/metrics", get(|_req: ()| async move {
|
||||
let mut buffer = Vec::new();
|
||||
let encoder = prometheus::TextEncoder::new();
|
||||
let metric_families = prometheus::gather();
|
||||
|
Loading…
Reference in New Issue
Block a user