- Advanced Mode sliders are generated from PCA on the index. The human-readable labels are assigned manually by looking at the results.
+ "Useful"/"aesthetic"/"meme" sliders are defined based on an approximate extrapolation of my preferences and may not agree with your own opinions. Results are otherwise decided by the inscrutable whims of a billion-parameter neural network. Please note that large-scale internet data may contain things you do not like.
- {#if util.hardConfig.telemetryEndpoint}
+ {#if util.hardConfig.telemetry_endpoint}
We do not collect personal information. We do collect usage information (associated with a random ID) to improve the ranking algorithms. You can disable this:
diff --git a/clipfront2/src/App.svelte b/clipfront2/src/App.svelte
index b0a7bef..ed5b896 100644
--- a/clipfront2/src/App.svelte
+++ b/clipfront2/src/App.svelte
@@ -154,7 +154,6 @@
+ {/if}
{#if term.type === "image"}
{:else if term.type === "text"}
@@ -200,6 +203,14 @@
{/each}
+ {#if showDebugSwitch}
+
Image Query
@@ -242,6 +253,9 @@
let queryCounter = 0
let config = {}
+ const showDebugSwitch = localStorage.getItem("debugEnabled") === "true"
+ let debugEnabled = false
+
const newTextQuery = (content=null) => {
queryTerms.push({ type: "text", weight: 1, sign: "+", text: typeof content === "string" ? content : "" })
queryTerms = queryTerms
@@ -253,9 +267,14 @@
const runSearch = async () => {
if (!resultPromise) {
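+ // with no query text in friendly mode, fall back to a random, roughly unit-norm embedding (per-dimension sigma 1/sqrt(d_emb))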
+ const friendlyModeQueryOrRandom = friendlyModeQuery ? [{ text: friendlyModeQuery, weight: 1, sign: "+" }] : [{ embedding: util.randn(config.d_emb, 1 / (config.d_emb ** 0.5)), weight: 1, sign: "+" }]
+ const terms = friendlyMode ?
+ friendlyModeQueryOrRandom.concat(util.hardConfig.friendly_mode_default_terms ?? []) :
+ queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))
let args = {
- "terms": friendlyMode ? [{ text: friendlyModeQuery, weight: 1, sign: "+" }] : queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] })),
- "include_video": true
+ "terms": terms,
+ "include_video": true,
+ "debug_enabled": debugEnabled
}
util.sendTelemetry("search", {
@@ -314,7 +333,7 @@
type: "predefined_embedding",
predefinedEmbedding: predefinedEmbeddingName,
sign: "+",
- weight: 0.2
+ weight: 1
})
}
queryTerms = queryTerms
diff --git a/clipfront2/src/QueryRefiner.svelte b/clipfront2/src/QueryRefiner.svelte
index 5055d8d..49e2dcf 100644
--- a/clipfront2/src/QueryRefiner.svelte
+++ b/clipfront2/src/QueryRefiner.svelte
@@ -36,26 +36,14 @@
const d_emb = 1152
- const vecSum = (xs, ys) => xs.map((x, i) => x + ys[i])
- const vecZero = d => new Array(d).fill(0)
- const vecScale = (xs, s) => xs.map(x => x * s)
-
- const boxMuller = () => {
- let x = Math.random()
- let y = Math.random()
- return Math.sqrt(-2.0 * Math.log(x)) * Math.cos(2.0 * Math.PI * y)
- }
-
- const randn = (d, sigma) => Array.from({ length: d }, () => boxMuller() * sigma)
-
const K = 2
let candidates = []
const select = async candidate => {
candidates = []
- const direction = randn(d_emb, 1 / d_emb)
+ const direction = util.randn(d_emb, 1 / d_emb)
for (let i = -K; i <= K; i++) {
- const newV = vecSum(vecScale(direction, i / K), candidate.vector)
+ const newV = util.vecSum(util.vecScale(direction, i / K), candidate.vector)
candidates.push({ vector: newV, results: null, i: i + K })
}
await Promise.all(candidates.map(async x => {
@@ -67,7 +55,7 @@
console.log(candidates)
}
- select({ vector: randn(d_emb, 1 / d_emb) })
+ select({ vector: util.randn(d_emb, 1 / d_emb) })
const handleKey = ev => {
const num = parseInt(ev.key)
@@ -75,4 +63,4 @@
select(candidates[num - 1])
}
}
-
\ No newline at end of file
+
diff --git a/clipfront2/src/SearchResults.svelte b/clipfront2/src/SearchResults.svelte
index 7a83a83..d0bdddf 100644
--- a/clipfront2/src/SearchResults.svelte
+++ b/clipfront2/src/SearchResults.svelte
@@ -28,6 +28,10 @@
{#key `${queryCounter}${result.file}`}
+ {#if result[5]}
+
{result[0]}
+
{JSON.stringify(result[5])}
+ {/if}
{/key}
{/each}
@@ -72,7 +76,7 @@
export const redrawGrid = async () => {
if (refreshLayout) {
refreshLayout()
- await recomputeScroll()
+ setTimeout(recomputeScroll, 0)
}
}
@@ -82,7 +86,7 @@
}
const handleScroll = () => {
- if (window.scrollY + window.innerHeight >= heightThreshold && pendingImageLoads === 0) {
+ if (window.scrollY + window.innerHeight >= heightThreshold) {
recomputeScroll()
if (window.scrollY + window.innerHeight < heightThreshold) return;
let init = displayedResults.length
diff --git a/clipfront2/src/common.sass b/clipfront2/src/common.sass
index e5ff03c..1c0b390 100644
--- a/clipfront2/src/common.sass
+++ b/clipfront2/src/common.sass
@@ -1,4 +1,11 @@
@use 'sass:color'
+@font-face
+ font-family: 'Iosevka'
+ font-style: normal
+ font-weight: 400
+ src: url(./iosevka.woff2) format('woff2')
+ unicode-range: U+0000-00C0, U+A2-A9, U+AC-AE, U+00D7, U+00F7, U+FEFF, U+FFFD
+
$palette-primary: #3f9b0b
$palette-secondary: #033500
diff --git a/clipfront2/src/iosevka.woff2 b/clipfront2/src/iosevka.woff2
new file mode 100644
index 0000000..a4b67c5
Binary files /dev/null and b/clipfront2/src/iosevka.woff2 differ
diff --git a/clipfront2/src/util.js b/clipfront2/src/util.js
index c7fb529..be0a185 100644
--- a/clipfront2/src/util.js
+++ b/clipfront2/src/util.js
@@ -78,3 +78,15 @@ fetch(config.backend_url).then(x => x.json().then(x => {
serverConfig.set(x)
window.serverConfig = x
}))
+
+export const vecSum = (xs, ys) => xs.map((x, i) => x + ys[i])
+export const vecZero = d => new Array(d).fill(0)
+export const vecScale = (xs, s) => xs.map(x => x * s)
+
+const boxMuller = () => {
+ let x = Math.random()
+ let y = Math.random()
+ return Math.sqrt(-2.0 * Math.log(x)) * Math.cos(2.0 * Math.PI * y)
+}
+
+export const randn = (d, sigma) => Array.from({ length: d }, () => boxMuller() * sigma)
diff --git a/clipfront2/static/index.html b/clipfront2/static/index.html
index e0889fa..3da6bd5 100644
--- a/clipfront2/static/index.html
+++ b/clipfront2/static/index.html
@@ -5,7 +5,7 @@
-
Meme Search Engine
+
Nooscope
diff --git a/config2.json b/config2.json
new file mode 100644
index 0000000..7c3a389
--- /dev/null
+++ b/config2.json
@@ -0,0 +1,10 @@
+{
+ "listen_address": "127.0.0.1:5601",
+ "clip_server": "http://100.64.0.10:1708",
+ "descriptor_names": [
+ "Useful",
+ "Meme",
+ "Aesthetic",
+ "Time"
+ ]
+}
diff --git a/diskann/chainq.py b/diskann/chainq.py
new file mode 100644
index 0000000..0161953
--- /dev/null
+++ b/diskann/chainq.py
@@ -0,0 +1,164 @@
+import numpy as np
+import msgpack
+import math
+import torch
+from torch import autograd
+import tqdm
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# https://github.com/una-dinosauria/local-search-quantization/blob/master/src/encodings/encode_chain.jl#L2
+# vectorized somewhat
+def viterbi_encode(
+ out_codes: torch.Tensor, # N x M (ints)
+ vectors: torch.Tensor, # N x D
+ codebooks: torch.Tensor # M x H x D
+):
+ N, D = vectors.shape
+ M, H, D2 = codebooks.shape
+ assert D == D2
+
+ # M x H x N - ||x-c||^2 ignoring x.T @ x component
+ unary_costs = -2 * (codebooks @ vectors.T) + (torch.linalg.norm(codebooks, dim=2) ** 2).unsqueeze(-1)
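+ # M-1 x H x H - cross terms 2*<c_i, c_{i+1}> between adjacent codebooks (the chain structure assumes non-adjacent codebooks don't interact)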
+ binary_costs = torch.zeros(M - 1, H, H, dtype=torch.float, device=DEVICE)
+ for i in range(M - 1):
+ binary_costs[i] = 2 * codebooks[i] @ codebooks[i + 1].T
+ print("binary costs", binary_costs)
+
+ min_cost = torch.zeros(H, N, dtype=torch.float, device=DEVICE)
+ min_idx = torch.zeros(M, H, N, dtype=torch.int, device=DEVICE)
+
+ cost = torch.zeros(H, N, dtype=torch.float, device=DEVICE)
+
+ # forward pass - propagate optimal costs and indices forward
+ for step in tqdm.trange(M - 1):
+ if step > 0:
+ unary_costs[step] += min_cost
+
+ ucost = unary_costs[step]
+
+ # for all possible costs at this step
+ for j in range(H):
+ bcost = binary_costs[step, j].unsqueeze(-1) # independent of N
+ cost = ucost + bcost
+
+ min_values, min_indices = torch.min(cost, dim=0)
+ min_cost[j] = min_values
+ min_idx[step, j] = min_indices
+
+ unary_costs[-1] += min_cost
+
+ # backward pass - propagate optimal indices backwards
+ out_codes[:, -1] = torch.argmin(unary_costs[-1], dim=0)
+ for i in range(M - 2, -1, -1):
+ out_codes[:, i] = min_idx[i][out_codes[:, i + 1], range(N)]
+
+def dims_for(dim, total_dims, m):
+ dims_per_code = total_dims // m
+ relevant_codebooks = [dim // dims_per_code]
+ if relevant_codebooks[-1] < m - 1:
+ relevant_codebooks.append(relevant_codebooks[-1] + 1)
+ return relevant_codebooks
+
+def update_codebooks(transformed_data, codes, h):
+ n, d = transformed_data.shape
+ n2, m = codes.shape
+ assert n == n2
+
+ new_codebook = torch.zeros(m, h, d, dtype=torch.float, device=DEVICE)
+ for dim in tqdm.trange(d):
+ relevant_codebooks = dims_for(dim, d, m)
+ assignment_matrix = torch.zeros(n, len(relevant_codebooks), h, dtype=torch.float, device=DEVICE)
+ indices = (
+ torch.arange(n, dtype=torch.int, device=DEVICE).repeat(len(relevant_codebooks)),
+ torch.arange(len(relevant_codebooks), dtype=torch.int, device=DEVICE).repeat_interleave(n),
+ codes[:, relevant_codebooks].T.flatten()
+ )
+ assignment_matrix[indices] = 1
+ #print(assignment_matrix, assignment_matrix.shape, transformed_data[:, dim], transformed_data[:, dim].shape)
+ assignment_matrix = assignment_matrix.reshape(n, len(relevant_codebooks) * h)
+ #print(assignment_matrix, assignment_matrix.shape, transformed_data[:, dim], transformed_data[:, dim].shape)
+ #soln = torch.linalg.lstsq(assignment_matrix, transformed_data[:, dim])[0]
+
+ reg = 1e-3 * torch.eye(len(relevant_codebooks) * h, device=DEVICE)
+ A = assignment_matrix.T @ assignment_matrix + reg
+ b = assignment_matrix.T @ transformed_data[:, dim]
+
+ #print("matrix", A)
+
+ usage = assignment_matrix.sum(dim=0)
+ unused = usage < 1
+
+ print(unused.sum().detach().item())
+ soln = torch.linalg.solve(A, b)
+
+ #print("solution", soln.reshape(len(relevant_codebooks), h))
+ if unused.any():
+ soln[unused] = torch.randn_like(soln[unused])
+
+ new_codebook[relevant_codebooks, :, dim] = soln.reshape(len(relevant_codebooks), h)
+
+ if torch.isnan(new_codebook[relevant_codebooks, :, dim]).any():
+ print("oh no", dim, new_codebook, relevant_codebooks, new_codebook[relevant_codebooks, :, dim])
+ print("--- dim ---", dim)
+ print("- sum per column:", assignment_matrix.sum(dim=0)) # Check if any columns are all zero
+ print("- rank:", torch.linalg.matrix_rank(assignment_matrix))
+ print("- condition number:", torch.linalg.cond(assignment_matrix))
+ raise SystemExit
+
+ return new_codebook
+
+BATCH = 8192
+
+def train_chainq(vectors, m, h, transform, codebooks, n_iters):
+ for i in range(n_iters):
+ transformed_data = vectors @ transform.T
+ codes = torch.zeros(vectors.shape[0], m, dtype=torch.int, device=DEVICE)
+ for i in range(0, vectors.shape[0], BATCH):
+ viterbi_encode(codes[i:i+BATCH], transformed_data[i:i+BATCH], codebooks)
+ print("encoded")
+ #codebooks = update_codebooks(transformed_data, codes, h)
+ print("codebooks updated")
+
+ quantized = torch.zeros_like(vectors, dtype=torch.float, device=DEVICE)
+ for j in range(m):
+ quantized[:] += codebooks[j, codes[:, j]]
+ print("quantized")
+ print((quantized - transformed_data).abs().mean(), transformed_data.abs().mean())
+
+ print("comparing")
+ res = transformed_data.T @ quantized
+
+ print("running SVD...")
+ u, s, vt = torch.linalg.svd(res)
+ print("done.")
+ transform = u @ vt
+ print("regenerated transform")
+
+ return codebooks, transform
+
+with open("opq.msgpack", "rb") as f:
+ data = msgpack.unpackb(f.read())
+
+n_dims = 1152
+dataset = torch.tensor(np.random.permutation(np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32))[:BATCH*1], device=DEVICE)
+
+codebooks = torch.zeros(64, 256, n_dims, dtype=torch.float, device=DEVICE)
+
+centroids = torch.tensor(np.array(data["centroids"]).astype(np.float32).reshape(256, n_dims), device=DEVICE)
+for dim in range(n_dims):
+ relevant_codebooks = dim // 64
+ codebooks[relevant_codebooks, :, dim] = centroids[:, dim]
+
+print(centroids)
+#print("codebooks", codebooks.tolist())
+
+codebooks, transform = train_chainq(dataset, 64, 256, torch.tensor(np.array(data["transform"]).astype(np.float32).reshape(n_dims, n_dims), device=DEVICE), codebooks, 100)
+
+with open("chainq.msgpack", "wb") as f:
+ msgpack.pack({
+ "codebooks": codebooks.cpu().numpy().flatten().tolist(),
+ "transform": transform.cpu().numpy().flatten().tolist(),
+ "n_dims": n_dims,
+ "n_dims_per_code": n_dims // 64
+ }, f)
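For orientation, a minimal sketch of decoding vectors back out of the chainq.msgpack written above, assuming the 64x256 codebook shape passed to train_chainq and relying on transform staying orthogonal (it is regenerated as u @ vt from an SVD):

    import numpy as np
    import msgpack

    with open("chainq.msgpack", "rb") as f:
        d = msgpack.unpackb(f.read())
    m, h, n_dims = 64, 256, d["n_dims"]
    codebooks = np.array(d["codebooks"], dtype=np.float32).reshape(m, h, n_dims)
    transform = np.array(d["transform"], dtype=np.float32).reshape(n_dims, n_dims)

    def decode(codes):
        # codes: (m,) codeword indices; the reconstruction in the rotated space is the
        # sum of one codeword per codebook, mirroring the `quantized` accumulation above
        xhat = codebooks[np.arange(m), codes].sum(axis=0)
        return transform.T @ xhat  # undo the orthogonal transform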
diff --git a/diskann/src/vector.rs b/diskann/src/vector.rs
index 0ed93e1..df98ace 100644
--- a/diskann/src/vector.rs
+++ b/diskann/src/vector.rs
@@ -45,7 +45,7 @@ impl Vector {
}
// Floats are vaguely annoying and not sortable (trivially), so we mostly represent dot products as integers
-pub const SCALE: f32 = 1099511627776.0;
+pub const SCALE: f32 = 4294967296.0;
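+// 2^32: float dot products are multiplied by this constant when represented as i64, so it effectively sets the fixed-point precision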
pub const SCALE_F64: f64 = SCALE as f64;
pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> i64 {
diff --git a/faiss_bench_quantizer.py b/faiss_bench_quantizer.py
new file mode 100644
index 0000000..53a11ac
--- /dev/null
+++ b/faiss_bench_quantizer.py
@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import faiss
+import time
+import numpy as np
+
+def eval_codec(q, xb):
+ t0 = time.time()
+ codes = q.compute_codes(xb)
+ t1 = time.time()
+ xb_decoded = q.decode(codes)
+ recons_err = ((xb - xb_decoded) ** 2).sum(axis=1).mean()
+ print(f"\tencode time: {t1 - t0:.3f} reconstruction error: {recons_err:.3f} ")
+
+
+def eval_quantizer(q, xb, xt, variants=None):
+ if variants is None:
+ variants = [(None, None)]
+ t0 = time.time()
+ q.train(xt)
+ t1 = time.time()
+ train_t = t1 - t0
+ print(f'\ttraining time: {train_t:.3f} s')
+ for name, val in variants:
+ if name is not None:
+ print(f"{name}={val}")
+
+ if isinstance(q, faiss.ProductAdditiveQuantizer):
+ for i in range(q.nsplits):
+ subq = faiss.downcast_Quantizer(q.subquantizer(i))
+ getattr(subq, name)
+ setattr(subq, name, val)
+ else:
+ getattr(q, name) # make sure field exists
+ setattr(q, name, val)
+
+ eval_codec(q, xb)
+
+
+todo = sys.argv[1:]
+
+ds = np.fromfile(todo[0], dtype=np.float16).reshape(-1, 1152).astype(np.float32)
+print(ds)
+del todo[0]
+
+if len(todo) > 0:
+ if todo[0].count("x") == 1:
+ M, nbits = [int(x) for x in todo[0].split("x")]
+ del todo[0]
+ elif todo[0].count("x") == 2:
+ nsplits, Msub, nbits = [int(x) for x in todo[0].split("x")]
+ M = nsplits * Msub
+ del todo[0]
+
+maxtrain = max(100 << nbits, 10**5)
+print(f"eval on {M}x{nbits} maxtrain={maxtrain}")
+
+xb, xt = ds, ds
+
+nb, d = xb.shape
+nt, d = xt.shape
+
+
+# fastest to slowest
+
+if 'lsq-gpu' in todo:
+ lsq = faiss.LocalSearchQuantizer(d, M, nbits)
+ ngpus = faiss.get_num_gpus()
+ lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
+ lsq.verbose = True
+ eval_quantizer(lsq, xb, xt, 'lsq-gpu')
+
+if 'pq' in todo:
+ pq = faiss.ProductQuantizer(d, M, nbits)
+ print("===== PQ")
+ eval_quantizer(pq, xb, xt)
+
+if 'opq' in todo:
+ d2 = ((d + M - 1) // M) * M
+ print("OPQ d2=", d2)
+ opq = faiss.OPQMatrix(d, M, d2)
+ opq.train(xt)
+ xb2 = opq.apply(xb)
+ xt2 = opq.apply(xt)
+ pq = faiss.ProductQuantizer(d2, M, nbits)
+ print("===== PQ")
+ eval_quantizer(pq, xb2, xt2)
+
+if 'prq' in todo:
+ print(f"===== PRQ{nsplits}x{Msub}x{nbits}")
+ prq = faiss.ProductResidualQuantizer(d, nsplits, Msub, nbits)
+ variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32)]
+ eval_quantizer(prq, xb, xt, variants=variants)
+
+if 'plsq' in todo:
+ print(f"===== PLSQ{nsplits}x{Msub}x{nbits}")
+ plsq = faiss.ProductLocalSearchQuantizer(d, nsplits, Msub, nbits)
+ variants = [("encode_ils_iters", i) for i in (2, 3, 4, 8, 16)]
+ eval_quantizer(plsq, xb, xt, variants=variants)
+
+if 'rq' in todo:
+ print("===== RQ")
+ rq = faiss.ResidualQuantizer(d, M, nbits, )
+ rq.max_beam_size
+ rq.max_beam_size = 30 # for compatibility with older runs
+ # rq.train_type = faiss.ResidualQuantizer.Train_default
+ # rq.verbose = True
+ variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32)]
+ eval_quantizer(rq, xb, xt, variants=variants)
+
+if 'rq_lut' in todo:
+ print("===== RQ")
+ rq = faiss.ResidualQuantizer(d, M, nbits, )
+ rq.max_beam_size
+ rq.max_beam_size = 30 # for compatibility with older runs
+ rq.use_beam_LUT
+ rq.use_beam_LUT = 1
+ # rq.train_type = faiss.ResidualQuantizer.Train_default
+ # rq.verbose = True
+ variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32, 64)]
+ eval_quantizer(rq, xb, xt, variants=variants)
+
+if 'lsq' in todo:
+ print("===== LSQ")
+ lsq = faiss.LocalSearchQuantizer(d, M, nbits)
+ variants = [("encode_ils_iters", i) for i in (2, 3, 4, 8, 16)]
+ eval_quantizer(lsq, xb, xt, variants=variants)
diff --git a/frontend_config.json b/frontend_config.json
index 7f27212..992a413 100644
--- a/frontend_config.json
+++ b/frontend_config.json
@@ -1,8 +1,12 @@
{
- "backend_url": "http://localhost:5601/",
+ "backend_url": "/backend",
"image_path": "",
"thumb_path": null,
"description": "Organizing the world's memes.",
"name": "Nooscope",
- "telemetry_endpoint": "/telemetry"
+ "telemetry_endpoint": "/telemetry",
+ "friendly_mode_default_terms": [
+ { "predefined_embedding": "Meme", "weight": 1, "sign": "+" },
+ { "predefined_embedding": "Aesthetic", "weight": 0.2, "sign": "+" }
+ ]
}
diff --git a/genseahash.py b/genseahash.py
new file mode 100644
index 0000000..5f21fdc
--- /dev/null
+++ b/genseahash.py
@@ -0,0 +1,4 @@
+import seahash, sys
+
+with open(sys.argv[1], "rb") as f:
+ print(seahash.hash(f.read()))
diff --git a/meme-rater/active_learning.py b/meme-rater/active_learning.py
index dac462a..2d8635f 100644
--- a/meme-rater/active_learning.py
+++ b/meme-rater/active_learning.py
@@ -31,6 +31,7 @@ model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
+model.eval()
files = shared.fetch_all_files()
variance = {}
@@ -56,4 +57,4 @@ with torch.inference_mode():
top = sorted(variance.items(), key=lambda x: -x[1])
with open("top.json", "w") as f:
- json.dump(top[:100], f)
+ json.dump(top[:50], f)
diff --git a/meme-rater/active_learning_find_top.py b/meme-rater/active_learning_find_top.py
index b12b435..d1411c5 100644
--- a/meme-rater/active_learning_find_top.py
+++ b/meme-rater/active_learning_find_top.py
@@ -52,7 +52,7 @@ channel = int(sys.argv[2])
percentile = float(sys.argv[3])
output_pairs = int(sys.argv[4])
mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
-top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
+top = sorted(((filename, score) for filename, score in results.items()), key=lambda x: x[1][channel], reverse=True)
select_from = top[:int(len(top) * percentile)]
out = []
diff --git a/meme-rater/active_learning_gradients.py b/meme-rater/active_learning_gradients.py
index 13909df..a349e5f 100644
--- a/meme-rater/active_learning_gradients.py
+++ b/meme-rater/active_learning_gradients.py
@@ -13,14 +13,14 @@ import sys
from model import Config, BradleyTerry
import shared
-batch_size = 128
+batch_size = 32
num_pairs = batch_size * 1024
device = "cuda"
config = Config(
d_emb=1152,
n_hidden=1,
- n_ensemble=1,
+ n_ensemble=16,
device=device,
dtype=torch.float32,
output_channels=3,
@@ -32,6 +32,7 @@ model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
+model.eval()
files = shared.fetch_all_files()
importance = {}
@@ -41,8 +42,8 @@ buffers = {k: v.detach() for k, v in model.named_buffers()}
# https://pytorch.org/tutorials/intermediate/per_sample_grads.html
def compute_loss(params, buffers, sample, target):
- batch = sample.unsqueeze(0)
- targets = target.unsqueeze(0)
+ batch = sample.unsqueeze(1)
+ targets = target.unsqueeze(1).expand((config.n_ensemble, 1, config.output_channels))
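+ # the target is replicated across ensemble members so each member gets its own per-sample loss and gradient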
predictions = functional_call(model, (params, buffers), (batch,))
loss = F.binary_cross_entropy(predictions, targets)
@@ -75,4 +76,4 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
top = sorted(importance.items(), key=lambda x: -x[1])
with open("top.json", "w") as f:
- json.dump(top[:256], f)
+ json.dump(top[:50], f)
diff --git a/meme-rater/ensemble_to_wide_model.py b/meme-rater/ensemble_to_wide_model.py
index b357c19..b492857 100644
--- a/meme-rater/ensemble_to_wide_model.py
+++ b/meme-rater/ensemble_to_wide_model.py
@@ -31,6 +31,11 @@ params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
+torch.random.manual_seed(1)
+
+for x in model.ensemble.models:
+ x.output.bias.data.fill_(0)
+
out_layers = []
out_bias = []
@@ -50,7 +55,7 @@ with torch.inference_mode():
downprojection[:, i*config.d_emb:(i+1)*config.d_emb] = model.ensemble.models[i].output.weight.data.clone()
for i in range(10):
- input = torch.randn(3, config.d_emb)
+ input = torch.randn(4, config.d_emb)
ground_truth_result = model.ensemble(input.unsqueeze(0).expand((config.n_ensemble, *input.shape))).mean(dim=0).T
r_result = input
for (layer, bias) in zip(out_layers, out_bias):
@@ -58,13 +63,14 @@ with torch.inference_mode():
print(r_result.shape, bias.shape)
r_result = F.silu(r_result)
r_result = torch.matmul(downprojection, r_result) / config.n_ensemble
- bias_diff = torch.mean(r_result - ground_truth_result)
- model_issue_diff = torch.max(torch.abs(r_result - bias_diff - ground_truth_result))
- print(model_issue_diff)
+ error = torch.mean(r_result - ground_truth_result)
+ print(error)
+ assert error.detach().cpu().numpy() < 1e-4
print("test vector:")
- print(input.flatten().tolist())
+ #print(input.flatten().tolist())
print("ground truth result:")
+ print(ground_truth_result.shape)
print(ground_truth_result.T.flatten().tolist())
save_file({
diff --git a/meme-rater/model.py b/meme-rater/model.py
index 49cf551..b548949 100644
--- a/meme-rater/model.py
+++ b/meme-rater/model.py
@@ -35,7 +35,7 @@ class Ensemble(nn.Module):
# model batch
def forward(self, embs):
- return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
+ return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_channels
class BradleyTerry(nn.Module):
def __init__(self, config):
diff --git a/meme-rater/rater_server.py b/meme-rater/rater_server.py
index 9ee2080..36a95a6 100644
--- a/meme-rater/rater_server.py
+++ b/meme-rater/rater_server.py
@@ -51,19 +51,25 @@ async def index(request):