mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-01-04 22:40:31 +00:00
Predefined embedding modes in search
This commit is contained in:
parent
14387a61a3
commit
d8c147df52
26
.sqlx/query-d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b.json
generated
Normal file
26
.sqlx/query-d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b.json
generated
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
{
|
||||||
|
"db_name": "SQLite",
|
||||||
|
"query": "SELECT * FROM predefined_embeddings",
|
||||||
|
"describe": {
|
||||||
|
"columns": [
|
||||||
|
{
|
||||||
|
"name": "name",
|
||||||
|
"ordinal": 0,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "embedding",
|
||||||
|
"ordinal": 1,
|
||||||
|
"type_info": "Blob"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"parameters": {
|
||||||
|
"Right": 0
|
||||||
|
},
|
||||||
|
"nullable": [
|
||||||
|
false,
|
||||||
|
false
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"hash": "d56ec6cf262f4f17cd2aaa1a6413bf0fec1d7a35661cfd02feae01e94b807b7b"
|
||||||
|
}
|
@ -60,6 +60,8 @@
|
|||||||
align-items: center
|
align-items: center
|
||||||
> *
|
> *
|
||||||
margin: 0 2px
|
margin: 0 2px
|
||||||
|
.sliders-ctrl
|
||||||
|
width: 5em
|
||||||
|
|
||||||
.result
|
.result
|
||||||
border: 1px solid gray
|
border: 1px solid gray
|
||||||
@ -70,6 +72,9 @@
|
|||||||
</style>
|
</style>
|
||||||
|
|
||||||
<h1>Meme Search Engine</h1>
|
<h1>Meme Search Engine</h1>
|
||||||
|
{#if config.n_total}
|
||||||
|
<p>{config.n_total} items indexed.</p>
|
||||||
|
{/if}
|
||||||
<details>
|
<details>
|
||||||
<summary>Usage tips</summary>
|
<summary>Usage tips</summary>
|
||||||
<ul>
|
<ul>
|
||||||
@ -78,6 +83,7 @@
|
|||||||
<li>In certain circumstances, it may be useful to postfix your query with "meme".</li>
|
<li>In certain circumstances, it may be useful to postfix your query with "meme".</li>
|
||||||
<li>Capitalization is ignored.</li>
|
<li>Capitalization is ignored.</li>
|
||||||
<li>Only English is supported. Other languages might work slightly.</li>
|
<li>Only English is supported. Other languages might work slightly.</li>
|
||||||
|
<li>Sliders are generated from PCA on the index. The human-readable labels are approximate.</li>
|
||||||
</ul>
|
</ul>
|
||||||
</details>
|
</details>
|
||||||
<div class="controls">
|
<div class="controls">
|
||||||
@ -98,12 +104,21 @@
|
|||||||
{#if term.type === "embedding"}
|
{#if term.type === "embedding"}
|
||||||
<span>[embedding loaded from URL]</span>
|
<span>[embedding loaded from URL]</span>
|
||||||
{/if}
|
{/if}
|
||||||
|
{#if term.type === "predefined_embedding"}
|
||||||
|
<span>{term.predefinedEmbedding}</span>
|
||||||
|
{/if}
|
||||||
</li>
|
</li>
|
||||||
{/each}
|
{/each}
|
||||||
</ul>
|
</ul>
|
||||||
<div class="ctrlbar">
|
<div class="ctrlbar">
|
||||||
<input type="search" placeholder="Text Query" on:keydown={handleKey} on:focus={newTextQuery}>
|
<input type="search" placeholder="Text Query" on:keydown={handleKey} on:focus={newTextQuery}>
|
||||||
<button on:click={pickFile}>Image Query</button>
|
<button on:click={pickFile}>Image Query</button>
|
||||||
|
<select bind:value={predefinedEmbeddingName} on:change={setPredefinedEmbedding} class="sliders-ctrl">
|
||||||
|
<option>Sliders</option>
|
||||||
|
{#each config.predefined_embedding_names ?? [] as name}
|
||||||
|
<option>{name}</option>
|
||||||
|
{/each}
|
||||||
|
</select>
|
||||||
<button on:click={runSearch} style="margin-left: auto">Search</button>
|
<button on:click={runSearch} style="margin-left: auto">Search</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@ -151,6 +166,25 @@
|
|||||||
let queryTerms = []
|
let queryTerms = []
|
||||||
let queryCounter = 0
|
let queryCounter = 0
|
||||||
|
|
||||||
|
let config = {}
|
||||||
|
util.serverConfig.subscribe(x => {
|
||||||
|
config = x
|
||||||
|
})
|
||||||
|
let predefinedEmbeddingName = "Sliders"
|
||||||
|
|
||||||
|
const setPredefinedEmbedding = () => {
|
||||||
|
if (predefinedEmbeddingName !== "Sliders") {
|
||||||
|
queryTerms.push({
|
||||||
|
type: "predefined_embedding",
|
||||||
|
predefinedEmbedding: predefinedEmbeddingName,
|
||||||
|
sign: "+",
|
||||||
|
weight: 0.2
|
||||||
|
})
|
||||||
|
}
|
||||||
|
queryTerms = queryTerms
|
||||||
|
predefinedEmbeddingName = "Sliders"
|
||||||
|
}
|
||||||
|
|
||||||
const decodeFloat16 = uint16 => {
|
const decodeFloat16 = uint16 => {
|
||||||
const sign = (uint16 & 0x8000) ? -1 : 1
|
const sign = (uint16 & 0x8000) ? -1 : 1
|
||||||
const exponent = (uint16 & 0x7C00) >> 10
|
const exponent = (uint16 & 0x7C00) >> 10
|
||||||
@ -200,7 +234,7 @@
|
|||||||
let displayedResults = []
|
let displayedResults = []
|
||||||
const runSearch = async () => {
|
const runSearch = async () => {
|
||||||
if (!resultPromise) {
|
if (!resultPromise) {
|
||||||
let args = {"terms": queryTerms.map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
|
let args = {"terms": queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
|
||||||
queryCounter += 1
|
queryCounter += 1
|
||||||
resultPromise = util.doQuery(args).then(res => {
|
resultPromise = util.doQuery(args).then(res => {
|
||||||
error = null
|
error = null
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import * as config from "../../frontend_config.json"
|
import * as config from "../../frontend_config.json"
|
||||||
|
import { writable } from "svelte/store"
|
||||||
|
|
||||||
export const getURL = x => config.image_path + x[1]
|
export const getURL = x => config.image_path + x[1]
|
||||||
|
|
||||||
@ -17,3 +18,9 @@ export const hasFormat = (results, result, format) => {
|
|||||||
export const thumbnailURL = (results, result, format) => {
|
export const thumbnailURL = (results, result, format) => {
|
||||||
return `${config.thumb_path}${result[2]}${format}.${results.extensions[format]}`
|
return `${config.thumb_path}${result[2]}${format}.${results.extensions[format]}`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export let serverConfig = writable({})
|
||||||
|
fetch(config.backend_url).then(x => x.json().then(x => {
|
||||||
|
serverConfig.set(x)
|
||||||
|
window.serverConfig = x
|
||||||
|
}))
|
14
load_embedding.py
Normal file
14
load_embedding.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import numpy as np
|
||||||
|
import sqlite3
|
||||||
|
import base64
|
||||||
|
|
||||||
|
#db = sqlite3.connect("/srv/mse/data.sqlite3")
|
||||||
|
db = sqlite3.connect("data.sqlite3")
|
||||||
|
db.row_factory = sqlite3.Row
|
||||||
|
|
||||||
|
name = input("Name: ")
|
||||||
|
url = input("Embedding search URL: ")
|
||||||
|
data = base64.urlsafe_b64decode(url.removeprefix("https://mse.osmarks.net/?e="))
|
||||||
|
arr = np.frombuffer(data, dtype=np.float16).copy()
|
||||||
|
db.execute("INSERT OR REPLACE INTO predefined_embeddings VALUES (?, ?)", (name, arr.tobytes()))
|
||||||
|
db.commit()
|
49
src/main.rs
49
src/main.rs
@ -27,6 +27,7 @@ use tower_http::cors::CorsLayer;
|
|||||||
use faiss::index::scalar_quantizer;
|
use faiss::index::scalar_quantizer;
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
|
use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
|
||||||
|
use ndarray::ArrayBase;
|
||||||
|
|
||||||
mod ocr;
|
mod ocr;
|
||||||
|
|
||||||
@ -93,6 +94,11 @@ CREATE VIRTUAL TABLE IF NOT EXISTS ocr_fts USING fts5 (
|
|||||||
content='files'
|
content='files'
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS predefined_embeddings (
|
||||||
|
name TEXT NOT NULL PRIMARY KEY,
|
||||||
|
embedding BLOB NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
|
CREATE TRIGGER IF NOT EXISTS ocr_fts_ins AFTER INSERT ON files BEGIN
|
||||||
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
|
INSERT INTO ocr_fts (rowid, filename, ocr) VALUES (new.rowid, new.filename, COALESCE(new.ocr, ''));
|
||||||
END;
|
END;
|
||||||
@ -127,10 +133,11 @@ struct InferenceServerConfig {
|
|||||||
embedding_size: usize,
|
embedding_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct WConfig {
|
struct WConfig {
|
||||||
backend: InferenceServerConfig,
|
backend: InferenceServerConfig,
|
||||||
service: Config
|
service: Config,
|
||||||
|
predefined_embeddings: HashMap<String, ArrayBase<ndarray::OwnedRepr<f32>, ndarray::prelude::Dim<[usize; 1]>>>
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn query_clip_server<I, O>(
|
async fn query_clip_server<I, O>(
|
||||||
@ -699,6 +706,7 @@ struct QueryTerm {
|
|||||||
embedding: Option<EmbeddingVector>,
|
embedding: Option<EmbeddingVector>,
|
||||||
image: Option<String>,
|
image: Option<String>,
|
||||||
text: Option<String>,
|
text: Option<String>,
|
||||||
|
predefined_embedding: Option<String>,
|
||||||
weight: Option<f32>,
|
weight: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -760,6 +768,10 @@ async fn handle_request(config: Arc<WConfig>, client: Arc<Client>, index: &IInde
|
|||||||
total_embedding[i] += value * weight;
|
total_embedding[i] += value * weight;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if let Some(name) = &term.predefined_embedding {
|
||||||
|
let embedding = config.predefined_embeddings.get(name).context("name invalid")?;
|
||||||
|
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut batches = vec![];
|
let mut batches = vec![];
|
||||||
@ -806,6 +818,12 @@ async fn get_backend_config(config: &Config) -> Result<InferenceServerConfig> {
|
|||||||
Ok(rmp_serde::from_slice(&res.bytes().await?)?)
|
Ok(rmp_serde::from_slice(&res.bytes().await?)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct FrontendInit {
|
||||||
|
n_total: u64,
|
||||||
|
predefined_embedding_names: Vec<String>
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
pretty_env_logger::init();
|
pretty_env_logger::init();
|
||||||
@ -826,9 +844,21 @@ async fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut predefined_embeddings = HashMap::new();
|
||||||
|
|
||||||
|
{
|
||||||
|
let db = initialize_database(&config).await?;
|
||||||
|
let result = sqlx::query!("SELECT * FROM predefined_embeddings")
|
||||||
|
.fetch_all(&db).await?;
|
||||||
|
for row in result {
|
||||||
|
predefined_embeddings.insert(row.name, ndarray::Array::from(decode_fp16_buffer(&row.embedding)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let config = Arc::new(WConfig {
|
let config = Arc::new(WConfig {
|
||||||
service: config,
|
service: config,
|
||||||
backend
|
backend,
|
||||||
|
predefined_embeddings
|
||||||
});
|
});
|
||||||
|
|
||||||
if config.service.no_run_server {
|
if config.service.no_run_server {
|
||||||
@ -878,6 +908,8 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
let config_ = config.clone();
|
let config_ = config.clone();
|
||||||
let client = Arc::new(Client::new());
|
let client = Arc::new(Client::new());
|
||||||
|
let index_ = index.clone();
|
||||||
|
let config__ = config.clone();
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/", post(|req| async move {
|
.route("/", post(|req| async move {
|
||||||
let config = config.clone();
|
let config = config.clone();
|
||||||
@ -886,10 +918,13 @@ async fn main() -> Result<()> {
|
|||||||
QUERIES_COUNTER.inc();
|
QUERIES_COUNTER.inc();
|
||||||
handle_request(config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e))
|
handle_request(config, client.clone(), &index, req).await.map_err(|e| format!("{:?}", e))
|
||||||
}))
|
}))
|
||||||
.route("/", get(|_req: axum::http::Request<axum::body::Body>| async move {
|
.route("/", get(|_req: ()| async move {
|
||||||
"OK"
|
Json(FrontendInit {
|
||||||
|
n_total: index_.read().await.vectors.ntotal(),
|
||||||
|
predefined_embedding_names: config__.predefined_embeddings.keys().cloned().collect()
|
||||||
|
})
|
||||||
}))
|
}))
|
||||||
.route("/reload", post(|_req: axum::http::Request<axum::body::Body>| async move {
|
.route("/reload", post(|_req: ()| async move {
|
||||||
log::info!("Requesting index reload");
|
log::info!("Requesting index reload");
|
||||||
let mut done_rx = done_tx.clone().subscribe();
|
let mut done_rx = done_tx.clone().subscribe();
|
||||||
let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full
|
let _ = request_ingest_tx.send(()).await; // ignore possible error, which is presumably because the queue is full
|
||||||
@ -911,7 +946,7 @@ async fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
.route("/metrics", get(|_req: axum::http::Request<axum::body::Body>| async move {
|
.route("/metrics", get(|_req: ()| async move {
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Vec::new();
|
||||||
let encoder = prometheus::TextEncoder::new();
|
let encoder = prometheus::TextEncoder::new();
|
||||||
let metric_families = prometheus::gather();
|
let metric_families = prometheus::gather();
|
||||||
|
Loading…
Reference in New Issue
Block a user