mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-13 07:19:54 +00:00
meme interpretability
This commit is contained in:
parent
bd426a30ba
commit
349fe802f7
@ -95,6 +95,9 @@
|
|||||||
{:else if term.type === "text"}
|
{:else if term.type === "text"}
|
||||||
<input type="search" use:focusEl on:keydown={handleKey} bind:value={term.text} />
|
<input type="search" use:focusEl on:keydown={handleKey} bind:value={term.text} />
|
||||||
{/if}
|
{/if}
|
||||||
|
{#if term.type === "embedding"}
|
||||||
|
<span>[embedding loaded from URL]</span>
|
||||||
|
{/if}
|
||||||
</li>
|
</li>
|
||||||
{/each}
|
{/each}
|
||||||
</ul>
|
</ul>
|
||||||
@ -148,6 +151,20 @@
|
|||||||
let queryTerms = []
|
let queryTerms = []
|
||||||
let queryCounter = 0
|
let queryCounter = 0
|
||||||
|
|
||||||
|
const decodeFloat16 = uint16 => {
|
||||||
|
const sign = (uint16 & 0x8000) ? -1 : 1
|
||||||
|
const exponent = (uint16 & 0x7C00) >> 10
|
||||||
|
const fraction = uint16 & 0x03FF
|
||||||
|
|
||||||
|
if (exponent === 0) {
|
||||||
|
return sign * Math.pow(2, -14) * (fraction / Math.pow(2, 10))
|
||||||
|
} else if (exponent === 0x1F) {
|
||||||
|
return fraction ? NaN : sign * Infinity
|
||||||
|
} else {
|
||||||
|
return sign * Math.pow(2, exponent - 15) * (1 + fraction / Math.pow(2, 10))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const focusEl = el => el.focus()
|
const focusEl = el => el.focus()
|
||||||
const newTextQuery = (content=null) => {
|
const newTextQuery = (content=null) => {
|
||||||
queryTerms.push({ type: "text", weight: 1, sign: "+", text: typeof content === "string" ? content : "" })
|
queryTerms.push({ type: "text", weight: 1, sign: "+", text: typeof content === "string" ? content : "" })
|
||||||
@ -183,7 +200,7 @@
|
|||||||
let displayedResults = []
|
let displayedResults = []
|
||||||
const runSearch = async () => {
|
const runSearch = async () => {
|
||||||
if (!resultPromise) {
|
if (!resultPromise) {
|
||||||
let args = {"terms": queryTerms.map(x => ({ image: x.imageData, text: x.text, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
|
let args = {"terms": queryTerms.map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))}
|
||||||
queryCounter += 1
|
queryCounter += 1
|
||||||
resultPromise = util.doQuery(args).then(res => {
|
resultPromise = util.doQuery(args).then(res => {
|
||||||
error = null
|
error = null
|
||||||
@ -252,4 +269,10 @@
|
|||||||
newTextQuery(queryStringParams.get("q"))
|
newTextQuery(queryStringParams.get("q"))
|
||||||
runSearch()
|
runSearch()
|
||||||
}
|
}
|
||||||
|
if (queryStringParams.get("e")) {
|
||||||
|
const binaryData = atob(queryStringParams.get("e").replace(/\-/g, "+").replace(/_/g, "/"))
|
||||||
|
const uint16s = new Uint16Array(new Uint8Array(binaryData.split('').map(c => c.charCodeAt(0))).buffer)
|
||||||
|
queryTerms.push({ type: "embedding", weight: 1, sign: "+", embedding: Array.from(uint16s).map(decodeFloat16) })
|
||||||
|
runSearch()
|
||||||
|
}
|
||||||
</script>
|
</script>
|
||||||
|
63
meme-rater/pca.py
Normal file
63
meme-rater/pca.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import sklearn.decomposition
|
||||||
|
import numpy as np
|
||||||
|
import sqlite3
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import base64
|
||||||
|
|
||||||
|
meme_search_backend = "http://localhost:1707/"
|
||||||
|
memes_url = "https://i.osmarks.net/memes-or-something/"
|
||||||
|
meme_search_url = "https://mse.osmarks.net/?e="
|
||||||
|
db = sqlite3.connect("/srv/mse/data.sqlite3")
|
||||||
|
db.row_factory = sqlite3.Row
|
||||||
|
|
||||||
|
def fetch_all_files():
|
||||||
|
csr = db.execute("SELECT embedding FROM files WHERE embedding IS NOT NULL")
|
||||||
|
x = [ np.frombuffer(row[0], dtype="float16").copy() for row in csr.fetchall() ]
|
||||||
|
csr.close()
|
||||||
|
return np.array(x)
|
||||||
|
|
||||||
|
embeddings = fetch_all_files()
|
||||||
|
|
||||||
|
print("loaded")
|
||||||
|
pca = sklearn.decomposition.PCA()
|
||||||
|
pca.fit(embeddings)
|
||||||
|
print(pca.explained_variance_ratio_)
|
||||||
|
print(pca.components_)
|
||||||
|
|
||||||
|
def emb_url(embedding):
|
||||||
|
return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8")
|
||||||
|
|
||||||
|
async def get_exemplars():
|
||||||
|
with open("components.html", "w") as f:
|
||||||
|
f.write("""<!DOCTYPE html>
|
||||||
|
<title>Embeddings PCA</title>
|
||||||
|
<style>
|
||||||
|
div img {
|
||||||
|
width: 20%
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
<body><h1>Embeddings PCA</h1>""")
|
||||||
|
async with aiohttp.ClientSession():
|
||||||
|
async def lookup(embedding):
|
||||||
|
async with aiohttp.request("POST", meme_search_backend, json={
|
||||||
|
"terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
|
||||||
|
"k": 10
|
||||||
|
}) as res:
|
||||||
|
return (await res.json())["matches"]
|
||||||
|
|
||||||
|
for i, (component, explained_variance_ratio) in enumerate(zip(pca.components_, pca.explained_variance_ratio_)):
|
||||||
|
f.write(f"""
|
||||||
|
<h2>Component {i}</h2>
|
||||||
|
<h3>Explained variance {explained_variance_ratio*100:0.2}%</h3>
|
||||||
|
<div>
|
||||||
|
<h4><a href="{emb_url(component)}">Max</a></h4>
|
||||||
|
""")
|
||||||
|
for match in await lookup(component):
|
||||||
|
f.write(f'<img loading="lazy" src="{memes_url+match[1]}">')
|
||||||
|
f.write(f'<h4><a href="{emb_url(-component)}">Min</a></h4>')
|
||||||
|
for match in await lookup(-component):
|
||||||
|
f.write(f'<img loading="lazy" src="{memes_url+match[1]}">')
|
||||||
|
f.write("</div>")
|
||||||
|
|
||||||
|
asyncio.run(get_exemplars())
|
Loading…
Reference in New Issue
Block a user