From 349fe802f77285c7f4345b623a6f7bb7d6ca2adf Mon Sep 17 00:00:00 2001 From: osmarks Date: Wed, 22 May 2024 16:18:45 +0100 Subject: [PATCH] meme interpretability --- clipfront2/src/App.svelte | 25 +++++++++++++++- meme-rater/pca.py | 63 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 meme-rater/pca.py diff --git a/clipfront2/src/App.svelte b/clipfront2/src/App.svelte index df199e8..b0df369 100644 --- a/clipfront2/src/App.svelte +++ b/clipfront2/src/App.svelte @@ -95,6 +95,9 @@ {:else if term.type === "text"} {/if} + {#if term.type === "embedding"} + [embedding loaded from URL] + {/if} {/each} @@ -148,6 +151,20 @@ let queryTerms = [] let queryCounter = 0 + const decodeFloat16 = uint16 => { + const sign = (uint16 & 0x8000) ? -1 : 1 + const exponent = (uint16 & 0x7C00) >> 10 + const fraction = uint16 & 0x03FF + + if (exponent === 0) { + return sign * Math.pow(2, -14) * (fraction / Math.pow(2, 10)) + } else if (exponent === 0x1F) { + return fraction ? NaN : sign * Infinity + } else { + return sign * Math.pow(2, exponent - 15) * (1 + fraction / Math.pow(2, 10)) + } + } + const focusEl = el => el.focus() const newTextQuery = (content=null) => { queryTerms.push({ type: "text", weight: 1, sign: "+", text: typeof content === "string" ? content : "" }) @@ -183,7 +200,7 @@ let displayedResults = [] const runSearch = async () => { if (!resultPromise) { - let args = {"terms": queryTerms.map(x => ({ image: x.imageData, text: x.text, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))} + let args = {"terms": queryTerms.map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))} queryCounter += 1 resultPromise = util.doQuery(args).then(res => { error = null @@ -252,4 +269,10 @@ newTextQuery(queryStringParams.get("q")) runSearch() } + if (queryStringParams.get("e")) { + const binaryData = atob(queryStringParams.get("e").replace(/\-/g, "+").replace(/_/g, "/")) + const uint16s = new Uint16Array(new Uint8Array(binaryData.split('').map(c => c.charCodeAt(0))).buffer) + queryTerms.push({ type: "embedding", weight: 1, sign: "+", embedding: Array.from(uint16s).map(decodeFloat16) }) + runSearch() + } diff --git a/meme-rater/pca.py b/meme-rater/pca.py new file mode 100644 index 0000000..9c7a248 --- /dev/null +++ b/meme-rater/pca.py @@ -0,0 +1,63 @@ +import sklearn.decomposition +import numpy as np +import sqlite3 +import asyncio +import aiohttp +import base64 + +meme_search_backend = "http://localhost:1707/" +memes_url = "https://i.osmarks.net/memes-or-something/" +meme_search_url = "https://mse.osmarks.net/?e=" +db = sqlite3.connect("/srv/mse/data.sqlite3") +db.row_factory = sqlite3.Row + +def fetch_all_files(): + csr = db.execute("SELECT embedding FROM files WHERE embedding IS NOT NULL") + x = [ np.frombuffer(row[0], dtype="float16").copy() for row in csr.fetchall() ] + csr.close() + return np.array(x) + +embeddings = fetch_all_files() + +print("loaded") +pca = sklearn.decomposition.PCA() +pca.fit(embeddings) +print(pca.explained_variance_ratio_) +print(pca.components_) + +def emb_url(embedding): + return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8") + +async def get_exemplars(): + with open("components.html", "w") as f: + f.write(""" +Embeddings PCA + +

Embeddings PCA

""") + async with aiohttp.ClientSession(): + async def lookup(embedding): + async with aiohttp.request("POST", meme_search_backend, json={ + "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry + "k": 10 + }) as res: + return (await res.json())["matches"] + + for i, (component, explained_variance_ratio) in enumerate(zip(pca.components_, pca.explained_variance_ratio_)): + f.write(f""" +

Component {i}

+

Explained variance {explained_variance_ratio*100:0.2}%

+
+

Max

+""") + for match in await lookup(component): + f.write(f'') + f.write(f'

Min

') + for match in await lookup(-component): + f.write(f'') + f.write("
") + +asyncio.run(get_exemplars()) \ No newline at end of file