mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-04-26 20:53:09 +00:00

release version

This commit is contained in:
osmarks 2025-01-24 09:24:28 +00:00
parent 3852d0078d
commit ee23b81444
34 changed files with 774 additions and 147 deletions

.gitignore vendored

@ -22,3 +22,8 @@ index
queries.txt
*.zst
.safetensors
*/static/*.woff2
flamegraph.svg
*.jsonl
*.safetensors
perf.data

Cargo.lock generated

@ -416,6 +416,18 @@ version = "2.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452"
[[package]]
name = "bitvec"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [
"funty",
"radium",
"tap",
"wyz",
]
[[package]]
name = "block-buffer"
version = "0.10.4"
@ -1117,6 +1129,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "funty"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "futures-channel"
version = "0.3.31"
@ -2046,6 +2064,7 @@ dependencies = [
"axum",
"base64 0.22.1",
"bitcode",
"bitvec",
"bytemuck",
"candle-core",
"chrono",
@ -2067,6 +2086,7 @@ dependencies = [
"itertools 0.13.0",
"json5",
"lazy_static",
"matrixmultiply",
"maud",
"memmap2",
"mimalloc",
@ -2850,6 +2870,12 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "radium"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]]
name = "rand"
version = "0.8.5"
@ -3873,6 +3899,12 @@ dependencies = [
"version-compare",
]
[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "target-lexicon"
version = "0.12.16"
@ -4732,6 +4764,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "wyz"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
dependencies = [
"tap",
]
[[package]]
name = "yoke"
version = "0.7.5"

Cargo.toml

@ -62,6 +62,8 @@ monoio = "0.2"
hyper = "1"
monoio-compat = { version = "0.2", features = ["hyper"] }
http-body-util = "0.1"
matrixmultiply = "0.3"
bitvec = "1"
[[bin]]
name = "reddit-dump"


@ -9,15 +9,15 @@
<div class="about">
<p>
Welcome to {util.hardConfig.name} by <a href="https://osmarks.net/">osmarks.net Computational Memetics</a>. {util.hardConfig.name} searches images using semantic image/text embedding models. In general, search by thinking of what caption your desired image might have been given by random people on the internet. The model currently in use can read text fairly well and understands moderately abstract properties of images, but is limited to English and case-insensitive.
Welcome to {util.hardConfig.name} by <a href="https://osmarks.net/">osmarks.net</a> Computational Memetics. {util.hardConfig.name} searches images using semantic image/text embedding models (refer to <a href="https://arxiv.org/abs/2303.15343">https://arxiv.org/abs/2303.15343</a>, <a href="https://arxiv.org/abs/2103.00020">https://arxiv.org/abs/2103.00020</a>). In general, search by thinking of what caption your desired image might have been given by random people on the internet. The model currently in use can read text fairly well and understands moderately abstract properties of images, but is limited to English and case-insensitive.
</p>
<p>
Advanced Mode sliders are generated from PCA on the index. The human-readable labels are generated manually by <a href="https://datasets.osmarks.net/components.html">looking at things</a>.
"Useful"/"aesthetic"/"meme" sliders are defined based on an approximate extrapolation of my preferences and may not agree with your own opinions. Results are otherwise decided by the inscrutable whims of a billion-parameter neural network. Please note that large-scale internet data may contain things you do not like.
</p>
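(For intuition, slider directions of this kind can be derived in a few lines of NumPy — a minimal sketch, assuming embeddings.bin holds float16 1152-d vectors as elsewhere in this commit:

import numpy as np

embs = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, 1152).astype(np.float32)
embs -= embs.mean(axis=0)                      # centre the index
_, _, vt = np.linalg.svd(embs[:100_000], full_matrices=False)
slider = vt[0]                                 # first principal component
query_contribution = 0.5 * slider              # a slider at weight 0.5 adds this to the query embedding

The human-readable labels then come from inspecting what images score high or low along each component.)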
<p>
The code is open-source and available on <a href="https://github.com/osmarks/meme-search-engine/">GitHub.</a>
The code is open-source and available on <a href="https://github.com/osmarks/meme-search-engine/">GitHub</a>.
</p>
{#if util.hardConfig.telemetryEndpoint}
{#if util.hardConfig.telemetry_endpoint}
<h2>Privacy</h2>
<p>
We do not collect personal information. We do collect usage information (associated with a random ID) to improve the ranking algorithms. You can disable this:


@ -154,7 +154,6 @@
<div class="right">
<NavItem page="advanced">Advanced</NavItem>
<NavItem page="about">About</NavItem>
<NavItem page="refine">Refine</NavItem>
</div>
</nav>
@ -185,7 +184,11 @@
<option>+</option>
<option>-</option>
</select>
<input type="range" min="0" max="2" bind:value={term.weight} step="0.01">
{#if debugEnabled}
<input type="number" bind:value={term.weight} step="0.01">
{:else}
<input type="range" min="0" max="2" bind:value={term.weight} step="0.01">
{/if}
{#if term.type === "image"}
<span>{term.file.name}</span>
{:else if term.type === "text"}
@ -200,6 +203,14 @@
</li>
{/each}
</ul>
{#if showDebugSwitch}
<div class="ctrlbar">
<input type="checkbox" bind:checked={debugEnabled} id="debug" />
<label for="debug">Debug</label>
{#if debugEnabled}
{/if}
</div>
{/if}
<div class="ctrlbar">
<input type="search" placeholder="Text Query" on:keydown={handleKey} on:focus={newTextQuery}>
<button on:click={pickFile}>Image Query</button>
@ -242,6 +253,9 @@
let queryCounter = 0
let config = {}
const showDebugSwitch = localStorage.getItem("debugEnabled") === "true"
let debugEnabled = false
const newTextQuery = (content=null) => {
queryTerms.push({ type: "text", weight: 1, sign: "+", text: typeof content === "string" ? content : "" })
queryTerms = queryTerms
@ -253,9 +267,14 @@
const runSearch = async () => {
if (!resultPromise) {
const friendlyModeQueryOrRandom = friendlyModeQuery ? [{ text: friendlyModeQuery, weight: 1, sign: "+" }] : [{ embedding: util.randn(config.d_emb, 1 / (config.d_emb ** 0.5)), weight: 1, sign: "+" }]
const terms = friendlyMode ?
friendlyModeQueryOrRandom.concat(util.hardConfig.friendly_mode_default_terms ?? []) :
queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))
let args = {
"terms": friendlyMode ? [{ text: friendlyModeQuery, weight: 1, sign: "+" }] : queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] })),
"include_video": true
"terms": terms,
"include_video": true,
"debug_enabled": debugEnabled
}
util.sendTelemetry("search", {
@ -314,7 +333,7 @@
type: "predefined_embedding",
predefinedEmbedding: predefinedEmbeddingName,
sign: "+",
weight: 0.2
weight: 1
})
}
queryTerms = queryTerms


@ -36,26 +36,14 @@
const d_emb = 1152
const vecSum = (xs, ys) => xs.map((x, i) => x + ys[i])
const vecZero = d => new Array(d).fill(0)
const vecScale = (xs, s) => xs.map(x => x * s)
const boxMuller = () => {
let x = Math.random()
let y = Math.random()
return Math.sqrt(-2.0 * Math.log(x)) * Math.cos(2.0 * Math.PI * y)
}
const randn = (d, sigma) => Array.from({ length: d }, () => boxMuller() * sigma)
const K = 2
let candidates = []
const select = async candidate => {
candidates = []
const direction = randn(d_emb, 1 / d_emb)
const direction = util.randn(d_emb, 1 / d_emb)
for (let i = -K; i <= K; i++) {
const newV = vecSum(vecScale(direction, i / K), candidate.vector)
const newV = util.vecSum(util.vecScale(direction, i / K), candidate.vector)
candidates.push({ vector: newV, results: null, i: i + K })
}
await Promise.all(candidates.map(async x => {
@ -67,7 +55,7 @@
console.log(candidates)
}
select({ vector: randn(d_emb, 1 / d_emb) })
select({ vector: util.randn(d_emb, 1 / d_emb) })
const handleKey = ev => {
const num = parseInt(ev.key)
@ -75,4 +63,4 @@
select(candidates[num - 1])
}
}
</script>
</script>


@ -28,6 +28,10 @@
{#key `${queryCounter}${result.file}`}
<div class="result">
<ResultImage {result} {results} {updateCounter} {redrawGrid} constrainBy="width" />
{#if result[5]} <!-- debug info -->
<div>{result[0]}</div>
<div>{JSON.stringify(result[5])}</div>
{/if}
</div>
{/key}
{/each}
@ -72,7 +76,7 @@
export const redrawGrid = async () => {
if (refreshLayout) {
refreshLayout()
await recomputeScroll()
setTimeout(recomputeScroll, 0)
}
}
@ -82,7 +86,7 @@
}
const handleScroll = () => {
if (window.scrollY + window.innerHeight >= heightThreshold && pendingImageLoads === 0) {
if (window.scrollY + window.innerHeight >= heightThreshold) {
recomputeScroll()
if (window.scrollY + window.innerHeight < heightThreshold) return;
let init = displayedResults.length


@ -1,4 +1,11 @@
@use 'sass:color'
@font-face
font-family: 'Iosevka'
font-style: normal
font-weight: 400
src: url(./iosevka.woff2) format('woff2')
unicode-range: U+0000-00C0, U+A2-A9, U+AC-AE, U+00D7, U+00F7, U+FEFF, U+FFFD
$palette-primary: #3f9b0b
$palette-secondary: #033500

Binary file not shown.


@ -78,3 +78,15 @@ fetch(config.backend_url).then(x => x.json().then(x => {
serverConfig.set(x)
window.serverConfig = x
}))
export const vecSum = (xs, ys) => xs.map((x, i) => x + ys[i])
export const vecZero = d => new Array(d).fill(0)
export const vecScale = (xs, s) => xs.map(x => x * s)
const boxMuller = () => {
let x = Math.random()
let y = Math.random()
return Math.sqrt(-2.0 * Math.log(x)) * Math.cos(2.0 * Math.PI * y)
}
export const randn = (d, sigma) => Array.from({ length: d }, () => boxMuller() * sigma)
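Note that each component is drawn with standard deviation sigma, so randn(d, 1 / d ** 0.5) has expected squared norm 1; this is why the search page seeds random browsing queries with util.randn(config.d_emb, 1 / (config.d_emb ** 0.5)).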


@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, minimum-scale=1, initial-scale=1, user-scalable=yes">
<meta name="description" content="Organizing the world's memes.">
<title>Meme Search Engine</title>
<title>Nooscope</title>
</style>
<link rel="stylesheet" href="app.css">
<link rel="icon" type="image/png" href="logo.png">

config2.json Normal file

@ -0,0 +1,10 @@
{
"listen_address": "127.0.0.1:5601",
"clip_server": "http://100.64.0.10:1708",
"descriptor_names": [
"Useful",
"Meme",
"Aesthetic",
"Time"
]
}

diskann/chainq.py Normal file

@ -0,0 +1,164 @@
import numpy as np
import msgpack
import math
import torch
from torch import autograd
import tqdm
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# https://github.com/una-dinosauria/local-search-quantization/blob/master/src/encodings/encode_chain.jl#L2
# vectorized somewhat
def viterbi_encode(
out_codes: torch.Tensor, # N x M (ints)
vectors: torch.Tensor, # N x D
codebooks: torch.Tensor # M x H x D
):
N, D = vectors.shape
M, H, D2 = codebooks.shape
assert D == D2
# M x H x N - ||x-c||^2 ignoring x.T @ x component
unary_costs = -2 * (codebooks @ vectors.T) + (torch.linalg.norm(codebooks, dim=2) ** 2).unsqueeze(-1)
binary_costs = torch.zeros(M - 1, H, H, dtype=torch.float, device=DEVICE)
for i in range(M - 1):
binary_costs[i] = 2 * codebooks[i] @ codebooks[i + 1].T
print("binary costs", binary_costs)
min_cost = torch.zeros(H, N, dtype=torch.float, device=DEVICE)
min_idx = torch.zeros(M, H, N, dtype=torch.int, device=DEVICE)
cost = torch.zeros(H, N, dtype=torch.float, device=DEVICE)
# forward pass - propagate optimal costs and indices forward
for step in tqdm.trange(M - 1):
if step > 0:
unary_costs[step] += min_cost
ucost = unary_costs[step]
# for all possible costs at this step
for j in range(H):
bcost = binary_costs[step, j].unsqueeze(-1) # independent of N
cost = ucost + bcost
min_values, min_indices = torch.min(cost, dim=0)
min_cost[j] = min_values
min_idx[step, j] = min_indices
unary_costs[-1] += min_cost
# backward pass - propagate optimal indices backwards
out_codes[:, -1] = torch.argmin(unary_costs[-1], dim=0)
for i in range(M - 2, -1, -1):
out_codes[:, i] = min_idx[i][out_codes[:, i + 1], range(N)]
def dims_for(dim, total_dims, m):
dims_per_code = total_dims // m
relevant_codebooks = [dim // dims_per_code]
if relevant_codebooks[-1] < m - 1:
relevant_codebooks.append(relevant_codebooks[-1] + 1)
return relevant_codebooks
def update_codebooks(transformed_data, codes, h):
n, d = transformed_data.shape
n2, m = codes.shape
assert n == n2
new_codebook = torch.zeros(m, h, d, dtype=torch.float, device=DEVICE)
for dim in tqdm.trange(d):
relevant_codebooks = dims_for(dim, d, m)
assignment_matrix = torch.zeros(n, len(relevant_codebooks), h, dtype=torch.float, device=DEVICE)
indices = (
torch.arange(n, dtype=torch.int, device=DEVICE).repeat(len(relevant_codebooks)),
torch.arange(len(relevant_codebooks), dtype=torch.int, device=DEVICE).repeat_interleave(n),
codes[:, relevant_codebooks].T.flatten()
)
assignment_matrix[indices] = 1
#print(assignment_matrix, assignment_matrix.shape, transformed_data[:, dim], transformed_data[:, dim].shape)
assignment_matrix = assignment_matrix.reshape(n, len(relevant_codebooks) * h)
#print(assignment_matrix, assignment_matrix.shape, transformed_data[:, dim], transformed_data[:, dim].shape)
#soln = torch.linalg.lstsq(assignment_matrix, transformed_data[:, dim])[0]
reg = 1e-3 * torch.eye(len(relevant_codebooks) * h, device=DEVICE)
A = assignment_matrix.T @ assignment_matrix + reg
b = assignment_matrix.T @ transformed_data[:, dim]
#print("matrix", A)
usage = assignment_matrix.sum(dim=0)
unused = usage < 1
print(unused.sum().detach().item())
soln = torch.linalg.solve(A, b)
#print("solution", soln.reshape(len(relevant_codebooks), h))
if unused.any():
soln[unused] = torch.randn_like(soln[unused])
new_codebook[relevant_codebooks, :, dim] = soln.reshape(len(relevant_codebooks), h)
if torch.isnan(new_codebook[relevant_codebooks, :, dim]).any():
print("oh no", dim, new_codebook, relevant_codebooks, new_codebook[relevant_codebooks, :, dim])
print("--- dim ---", dim)
print("- sum per column:", assignment_matrix.sum(dim=0)) # Check if any columns are all zero
print("- rank:", torch.linalg.matrix_rank(assignment_matrix))
print("- condition number:", torch.linalg.cond(assignment_matrix))
raise SystemExit
return new_codebook
BATCH = 8192
def train_chainq(vectors, m, h, transform, codebooks, n_iters):
for i in range(n_iters):
transformed_data = vectors @ transform.T
codes = torch.zeros(vectors.shape[0], m, dtype=torch.int, device=DEVICE)
for i in range(0, vectors.shape[0], BATCH):
viterbi_encode(codes[i:i+BATCH], transformed_data[i:i+BATCH], codebooks)
print("encoded")
#codebooks = update_codebooks(transformed_data, codes, h)
print("codebooks updated")
quantized = torch.zeros_like(vectors, dtype=torch.float, device=DEVICE)
for j in range(m):
quantized[:] += codebooks[j, codes[:, j]]
print("quantized")
print((quantized - transformed_data).abs().mean(), transformed_data.abs().mean())
print("comparing")
res = transformed_data.T @ quantized
print("running SVD...")
u, s, vt = torch.linalg.svd(res)
print("done.")
transform = u @ vt
print("regenerated transform")
return codebooks, transform
with open("opq.msgpack", "rb") as f:
data = msgpack.unpackb(f.read())
n_dims = 1152
dataset = torch.tensor(np.random.permutation(np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32))[:BATCH*1], device=DEVICE)
codebooks = torch.zeros(64, 256, n_dims, dtype=torch.float, device=DEVICE)
centroids = torch.tensor(np.array(data["centroids"]).astype(np.float32).reshape(256, n_dims), device=DEVICE)
for dim in range(n_dims):
relevant_codebooks = dim // 64
codebooks[relevant_codebooks, :, dim] = centroids[:, dim]
print(centroids)
#print("codebooks", codebooks.tolist())
codebooks, transform = train_chainq(dataset, 64, 256, torch.tensor(np.array(data["transform"]).astype(np.float32).reshape(n_dims, n_dims), device=DEVICE), codebooks, 100)
with open("chainq.msgpack", "wb") as f:
msgpack.pack({
"codebooks": codebooks.cpu().numpy().flatten().tolist(),
"transform": transform.cpu().numpy().flatten().tolist(),
"n_dims": n_dims,
"n_dims_per_code": n_dims // 64
}, f)
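The Viterbi encoder above can be smoke-tested standalone on small shapes — a sketch with hypothetical sizes, assuming the chainq.py module context (tensors must live on the module's DEVICE):

import torch
M, H, D, N = 4, 8, 16, 32                      # codebooks, codewords per book, dims, vectors
codebooks = torch.randn(M, H, D, device=DEVICE)
vectors = torch.randn(N, D, device=DEVICE)
codes = torch.zeros(N, M, dtype=torch.int, device=DEVICE)
viterbi_encode(codes, vectors, codebooks)
# reconstruct from the chosen codewords and inspect the residual
quantized = sum(codebooks[j, codes[:, j]] for j in range(M))
print((quantized - vectors).abs().mean())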


@ -45,7 +45,7 @@ impl Vector {
}
// Floats are vaguely annoying and not sortable (trivially), so we mostly represent dot products as integers
pub const SCALE: f32 = 1099511627776.0;
pub const SCALE: f32 = 4294967296.0;
pub const SCALE_F64: f64 = SCALE as f64;
pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> i64 {
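The change above shrinks the fixed-point scale from 2^40 to 2^32. The convention is easy to mirror — a Python sketch of what the Rust helper scale_dot_result does, assuming dot products roughly in [-1, 1]:

SCALE = 4294967296.0            # 2 ** 32, matching the new constant
def scale_dot_result(x: float) -> int:
    return int(x * SCALE)       # an f32 dot product becomes a sortable integer score

assert scale_dot_result(0.25) < scale_dot_result(0.5)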

faiss_bench_quantizer.py Normal file

@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import faiss
import time
import numpy as np
def eval_codec(q, xb):
t0 = time.time()
codes = q.compute_codes(xb)
t1 = time.time()
xb_decoded = q.decode(codes)
recons_err = ((xb - xb_decoded) ** 2).sum(axis=1).mean()
print(f"\tencode time: {t1 - t0:.3f} reconstruction error: {recons_err:.3f} ")
def eval_quantizer(q, xb, xt, variants=None):
if variants is None:
variants = [(None, None)]
t0 = time.time()
q.train(xt)
t1 = time.time()
train_t = t1 - t0
print(f'\ttraining time: {train_t:.3f} s')
for name, val in variants:
if name is not None:
print(f"{name}={val}")
if isinstance(q, faiss.ProductAdditiveQuantizer):
for i in range(q.nsplits):
subq = faiss.downcast_Quantizer(q.subquantizer(i))
getattr(subq, name)
setattr(subq, name, val)
else:
getattr(q, name) # make sure field exists
setattr(q, name, val)
eval_codec(q, xb)
todo = sys.argv[1:]
ds = np.fromfile(todo[0], dtype=np.float16).reshape(-1, 1152).astype(np.float32)
print(ds)
del todo[0]
if len(todo) > 0:
if todo[0].count("x") == 1:
M, nbits = [int(x) for x in todo[0].split("x")]
del todo[0]
elif todo[0].count("x") == 2:
nsplits, Msub, nbits = [int(x) for x in todo[0].split("x")]
M = nsplits * Msub
del todo[0]
maxtrain = max(100 << nbits, 10**5)
print(f"eval on {M}x{nbits} maxtrain={maxtrain}")
xb, xt = ds, ds
nb, d = xb.shape
nt, d = xt.shape
# fastest to slowest
if 'lsq-gpu' in todo:
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
ngpus = faiss.get_num_gpus()
lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
lsq.verbose = True
eval_quantizer(lsq, xb, xt, 'lsq-gpu')
if 'pq' in todo:
pq = faiss.ProductQuantizer(d, M, nbits)
print("===== PQ")
eval_quantizer(pq, xb, xt)
if 'opq' in todo:
d2 = ((d + M - 1) // M) * M
print("OPQ d2=", d2)
opq = faiss.OPQMatrix(d, M, d2)
opq.train(xt)
xb2 = opq.apply(xb)
xt2 = opq.apply(xt)
pq = faiss.ProductQuantizer(d2, M, nbits)
print("===== PQ")
eval_quantizer(pq, xb2, xt2)
if 'prq' in todo:
print(f"===== PRQ{nsplits}x{Msub}x{nbits}")
prq = faiss.ProductResidualQuantizer(d, nsplits, Msub, nbits)
variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32)]
eval_quantizer(prq, xb, xt, variants=variants)
if 'plsq' in todo:
print(f"===== PLSQ{nsplits}x{Msub}x{nbits}")
plsq = faiss.ProductLocalSearchQuantizer(d, nsplits, Msub, nbits)
variants = [("encode_ils_iters", i) for i in (2, 3, 4, 8, 16)]
eval_quantizer(plsq, xb, xt, variants=variants)
if 'rq' in todo:
print("===== RQ")
rq = faiss.ResidualQuantizer(d, M, nbits, )
rq.max_beam_size
rq.max_beam_size = 30 # for compatibility with older runs
# rq.train_type = faiss.ResidualQuantizer.Train_default
# rq.verbose = True
variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32)]
eval_quantizer(rq, xb, xt, variants=variants)
if 'rq_lut' in todo:
print("===== RQ")
rq = faiss.ResidualQuantizer(d, M, nbits, )
rq.max_beam_size
rq.max_beam_size = 30 # for compatibility with older runs
rq.use_beam_LUT
rq.use_beam_LUT = 1
# rq.train_type = faiss.ResidualQuantizer.Train_default
# rq.verbose = True
variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32, 64)]
eval_quantizer(rq, xb, xt, variants=variants)
if 'lsq' in todo:
print("===== LSQ")
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
variants = [("encode_ils_iters", i) for i in (2, 3, 4, 8, 16)]
eval_quantizer(lsq, xb, xt, variants=variants)
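Going by the argv parsing above, a typical invocation is (assuming embeddings.bin holds float16 1152-d vectors):

python faiss_bench_quantizer.py embeddings.bin 64x8 pq opq

where "64x8" selects M=64 subquantizers at 8 bits each, and a three-part spec such as "8x8x8" (nsplits x Msub x nbits) is needed for the prq/plsq product variants.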


@ -1,8 +1,12 @@
{
"backend_url": "http://localhost:5601/",
"backend_url": "/backend",
"image_path": "",
"thumb_path": null,
"description": "Organizing the world's memes.",
"name": "Nooscope",
"telemetry_endpoint": "/telemetry"
"telemetry_endpoint": "/telemetry",
"friendly_mode_default_terms": [
{ "predefined_embedding": "Meme", "weight": 1, "sign": "+" },
{ "predefined_embedding": "Aesthetic", "weight": 0.2, "sign": "+" }
]
}

genseahash.py Normal file

@ -0,0 +1,4 @@
import seahash, sys
with open(sys.argv[1], "rb") as f:
print(seahash.hash(f.read()))
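Usage is simply python genseahash.py somefile, which prints the 64-bit SeaHash of the file's contents (assuming the seahash module is installed).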


@ -31,6 +31,7 @@ model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
model.eval()
files = shared.fetch_all_files()
variance = {}
@ -56,4 +57,4 @@ with torch.inference_mode():
top = sorted(variance.items(), key=lambda x: -x[1])
with open("top.json", "w") as f:
json.dump(top[:100], f)
json.dump(top[:50], f)


@ -52,7 +52,7 @@ channel = int(sys.argv[2])
percentile = float(sys.argv[3])
output_pairs = int(sys.argv[4])
mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
top = sorted(((filename, score) for filename, score in results.items()), key=lambda x: x[1][channel], reverse=True)
select_from = top[:int(len(top) * percentile)]
out = []


@ -13,14 +13,14 @@ import sys
from model import Config, BradleyTerry
import shared
batch_size = 128
batch_size = 32
num_pairs = batch_size * 1024
device = "cuda"
config = Config(
d_emb=1152,
n_hidden=1,
n_ensemble=1,
n_ensemble=16,
device=device,
dtype=torch.float32,
output_channels=3,
@ -32,6 +32,7 @@ model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
model.eval()
files = shared.fetch_all_files()
importance = {}
@ -41,8 +42,8 @@ buffers = {k: v.detach() for k, v in model.named_buffers()}
# https://pytorch.org/tutorials/intermediate/per_sample_grads.html
def compute_loss(params, buffers, sample, target):
batch = sample.unsqueeze(0)
targets = target.unsqueeze(0)
batch = sample.unsqueeze(1)
targets = target.unsqueeze(1).expand((config.n_ensemble, 1, config.output_channels))
predictions = functional_call(model, (params, buffers), (batch,))
loss = F.binary_cross_entropy(predictions, targets)
@ -75,4 +76,4 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
top = sorted(importance.items(), key=lambda x: -x[1])
with open("top.json", "w") as f:
json.dump(top[:256], f)
json.dump(top[:50], f)


@ -31,6 +31,11 @@ params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
torch.random.manual_seed(1)
for x in model.ensemble.models:
x.output.bias.data.fill_(0)
out_layers = []
out_bias = []
@ -50,7 +55,7 @@ with torch.inference_mode():
downprojection[:, i*config.d_emb:(i+1)*config.d_emb] = model.ensemble.models[i].output.weight.data.clone()
for i in range(10):
input = torch.randn(3, config.d_emb)
input = torch.randn(4, config.d_emb)
ground_truth_result = model.ensemble(input.unsqueeze(0).expand((config.n_ensemble, *input.shape))).mean(dim=0).T
r_result = input
for (layer, bias) in zip(out_layers, out_bias):
@ -58,13 +63,14 @@ with torch.inference_mode():
print(r_result.shape, bias.shape)
r_result = F.silu(r_result)
r_result = torch.matmul(downprojection, r_result) / config.n_ensemble
bias_diff = torch.mean(r_result - ground_truth_result)
model_issue_diff = torch.max(torch.abs(r_result - bias_diff - ground_truth_result))
print(model_issue_diff)
error = torch.mean(r_result - ground_truth_result)
print(error)
assert error.detach().cpu().numpy() < 1e-4
print("test vector:")
print(input.flatten().tolist())
#print(input.flatten().tolist())
print("ground truth result:")
print(ground_truth_result.shape)
print(ground_truth_result.T.flatten().tolist())
save_file({


@ -35,7 +35,7 @@ class Ensemble(nn.Module):
# model batch
def forward(self, embs):
return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_channels
class BradleyTerry(nn.Module):
def __init__(self, config):


@ -51,19 +51,25 @@ async def index(request):
<form action="/rate" method="POST">
<table>
<tr>
<td><input type="radio" name="rating-useful" value="1+" id="rq1p"> <label for="rq1p">LHS is much better (useful)</label></td>
<td><input type="radio" name="rating-useful" value="1" id="rq1"> <label for="rq1">LHS is better (useful)</label></td>
<td><input type="radio" name="rating-useful" value="eq" id="rqe"> <label for="rqe">Tie</label></td>
<td><input type="radio" name="rating-useful" value="2" id="rq2"> <label for="rq2">RHS is better (useful)</label></td>
<td><input type="radio" name="rating-useful" value="2+" id="rq2p"> <label for="rq2p">LHS is much better (useful)</label></td>
</tr>
<tr>
<td><input type="radio" name="rating-meme" value="1+" id="rm1p"> <label for="rm1p">LHS is much better (memetically)</label></td>
<td><input type="radio" name="rating-meme" value="1" id="rm1"> <label for="rm1">LHS is better (memetically)</label></td>
<td><input type="radio" name="rating-meme" value="eq" id="rme"> <label for="rme">Tie</label></td>
<td><input type="radio" name="rating-meme" value="2" id="rm2"> <label for="rm2">RHS is better (memetically)</label></td>
<td><input type="radio" name="rating-meme" value="2+" id="rm2p"> <label for="rm2p">LHS is much better (memetically)</label></td>
</tr>
<tr>
<td><input type="radio" name="rating-aesthetic" value="1+" id="ra1p"> <label for="ra1p">LHS is much better (aesthetically)</label></td>
<td><input type="radio" name="rating-aesthetic" value="1" id="ra1"> <label for="ra1">LHS is better (aesthetically)</label></td>
<td><input type="radio" name="rating-aesthetic" value="eq" id="rae"> <label for="rae">Tie</label></td>
<td><input type="radio" name="rating-aesthetic" value="2" id="ra2"> <label for="ra2">RHS is better (aesthetically)</label></td>
<td><input type="radio" name="rating-aesthetic" value="2+" id="ra2p"> <label for="ra2p">LHS is much better (aesthetically)</label></td>
</td>
</table>
@ -82,35 +88,29 @@ async def index(request):
document.querySelector("form").submit()
}}
}}
const keys = {{
"q": "rq1p",
"w": "rq1",
"e": "rqe",
"r": "rq2",
"t": "rq2p",
"a": "rm1p",
"s": "rm1",
"d": "rme",
"f": "rm2",
"g": "rm2p",
"z": "ra1p",
"x": "ra1",
"c": "rae",
"v": "ra2",
"b": "ra2p",
}}
document.addEventListener("keypress", function(event) {{
if (event.key === "q") {{
document.querySelector("input[name='rating-useful'][value='1']").checked = true
commitIfReady()
}} else if (event.key === "w") {{
document.querySelector("input[name='rating-useful'][value='eq']").checked = true
commitIfReady()
}} else if (event.key === "e") {{
document.querySelector("input[name='rating-useful'][value='2']").checked = true
commitIfReady()
}} else if (event.key === "a") {{
document.querySelector("input[name='rating-meme'][value='1']").checked = true
commitIfReady()
}} else if (event.key === "s") {{
document.querySelector("input[name='rating-meme'][value='eq']").checked = true
commitIfReady()
}} else if (event.key === "d") {{
document.querySelector("input[name='rating-meme'][value='2']").checked = true
commitIfReady()
}} else if (event.key === "z") {{
document.querySelector("input[name='rating-aesthetic'][value='1']").checked = true
commitIfReady()
}} else if (event.key === "x") {{
document.querySelector("input[name='rating-aesthetic'][value='eq']").checked = true
commitIfReady()
}} else if (event.key === "c") {{
document.querySelector("input[name='rating-aesthetic'][value='2']").checked = true
commitIfReady()
const key = keys[event.key]
if (key) {{
document.getElementById(key).checked = true
}}
commitIfReady()
}});
</script>
</body>


@ -20,27 +20,35 @@ def fetch_embedding(filename):
csr.close()
return x.copy() # PyTorch complains otherwise due to bad
def map_rating(rating, uncertainty=0.05):
def map_rating(rating):
def map_one(rating):
match rating:
case "1": # meme 1 is better
return 1 - uncertainty
return 0.9
case "2":
return uncertainty
return 0.1
case "2+" | "2p":
return 0.3
case "1+" | "1p":
return 0.7
case "eq":
return 0.5
case _: raise ValueError("invalid rating, please fix")
return np.array([map_one(r) for r in rating.split(",")])
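The values map_one returns are Bradley–Terry target probabilities that the first meme wins; soft labels (0.9/0.1 instead of 1/0) keep the binary cross-entropy loss bounded and act as mild label smoothing, with ties pinned at exactly 0.5.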
def fetch_ratings():
def fetch_ratings(sets):
trains = defaultdict(list)
validations = defaultdict(list)
csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
its = set()
for meme1, meme2, rating, iteration in csr.fetchall():
if iteration not in its:
print(iteration)
its.add(iteration)
(validations if is_val_set(meme1, meme2) else trains)[int((iteration and iteration.split("-")[0]) or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
csr.close()
return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))
return list(x[1] for x in sorted(trains.items()) if str(x[0]) in sets), list(x[1] for x in sorted(validations.items()) if str(x[0]) in sets)
def generate_random_permutations(x, n):
out = []


@ -9,11 +9,12 @@ import time
from tqdm import tqdm
import math
from dataclasses import dataclass, asdict
import sys
from model import Config as ModelConfig, BradleyTerry
import shared
trains, validations = shared.fetch_ratings()
trains, validations = shared.fetch_ratings(sys.argv[1:])
for train, validation in zip(trains, validations):
print(len(train), len(validation))
@ -36,11 +37,11 @@ config = TrainConfig(
n_ensemble=16,
device=device,
dtype=torch.float32,
dropout=0.1,
dropout=0.0,
output_channels=3
),
lr=3e-4,
weight_decay=0.2,
weight_decay=0.0,
batch_size=1,
epochs=5,
compile=False,


@ -1,12 +1,12 @@
import torch
import pyarrow as pa
import numpy as np
torch.set_float32_matmul_precision("high")
with pa.memory_map("../../sample_1m.arrow", "r") as source:
loaded_arrays = pa.ipc.open_file(source).read_all()
loaded_arrays = np.memmap("embeddings.bin", dtype=np.float16).reshape(-1, 1152)
loaded_arrays_permutation = np.random.permutation(len(loaded_arrays))
train_split = 0.8
def ckpt_path(steps):
return f"ckpt/{steps}.pt", f"ckpt/{steps}.optim.pt"
return f"ckpt/{steps}.pt", f"ckpt/{steps}.optim.pt"


@ -6,9 +6,10 @@ import json
import time
from tqdm import tqdm
from dataclasses import dataclass, asdict
import math
from model import SAEConfig, SAE
from shared import train_split, loaded_arrays, ckpt_path
from shared import train_split, loaded_arrays, ckpt_path, loaded_arrays_permutation
device = "cuda"
@ -24,7 +25,7 @@ class TrainConfig:
config = TrainConfig(
model=SAEConfig(
d_emb=1152,
d_hidden=65536,
d_hidden=262144,
top_k=128,
device=device,
dtype=torch.float32,
@ -33,7 +34,7 @@ config = TrainConfig(
lr=3e-4,
weight_decay=0.0,
batch_size=64,
epochs=5,
epochs=1,
compile=True,
)
@ -81,7 +82,7 @@ with open(logfile, "w") as log:
batch = []
t = tqdm(range(0, int(len(loaded_arrays) * train_split), config.batch_size))
for batch_start in t:
batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in loaded_arrays["embedding"][batch_start:batch_start + config.batch_size] ])
batch = numpy.stack([ embedding for embedding in loaded_arrays[loaded_arrays_permutation[batch_start:batch_start + config.batch_size]] ])
if len(batch) == config.batch_size:
batch = torch.Tensor(batch).to(device)
@ -98,4 +99,4 @@ with open(logfile, "w") as log:
print(ctr)
numpy.save(f"ckpt/{steps}.counters.npy", ctr)
print(logfile)
print(logfile)

slow_dump_parse_script.py Normal file

@ -0,0 +1,35 @@
import umsgpack
import zstandard
import pyarrow as pa
data = []
with open("sample.zst", "rb") as f:
decomp = zstandard.ZstdDecompressor()
reader = decomp.stream_reader(f)
count = 0
while True:
try:
url, id, title, subreddit, author, timestamp, embedding = umsgpack.unpack(reader)
embedding = bytes(embedding)
data.append({"url": url, "id": id, "title": title, "subreddit": subreddit, "author": author, "timestamp": timestamp, "embedding": embedding})
count += 1
except umsgpack.InsufficientDataException:
break
print(count)
schema = pa.schema([
("url", pa.string()),
("id", pa.string()),
("title", pa.string()),
("subreddit", pa.string()),
("author", pa.string()),
("timestamp", pa.int64()),
("embedding", pa.binary())
])
table = pa.Table.from_pylist(data, schema=schema)
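# note: despite the .parquet name below, RecordBatchFileWriter emits the Arrow IPC file format, not Parquet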
with pa.OSFile("output.parquet", "wb") as sink:
with pa.RecordBatchFileWriter(sink, schema) as writer:
writer.write_table(table)


@ -183,8 +183,8 @@ pub struct FrontendInit {
pub type EmbeddingVector = Vec<f32>;
#[derive(Debug, Serialize)]
pub struct QueryResult {
pub matches: Vec<(f32, String, String, u64, Option<(u32, u32)>)>,
pub struct QueryResult<T> {
pub matches: Vec<(f32, String, String, u64, Option<(u32, u32)>, Option<T>)>,
pub formats: Vec<String>,
pub extensions: HashMap<String, String>,
}
@ -203,7 +203,9 @@ pub struct QueryRequest {
pub terms: Vec<QueryTerm>,
pub k: Option<usize>,
#[serde(default)]
pub include_video: bool
pub include_video: bool,
#[serde(default)]
pub debug_enabled: bool
}
lazy_static::lazy_static! {
@ -238,8 +240,9 @@ pub async fn get_total_embedding<A: Future<Output = Result<Vec<Vec<u8>>>>, B: Fu
}
}
if let Some(name) = &term.predefined_embedding {
let embedding = predefined_embeddings.get(name).context("name invalid")?;
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
if let Some(embedding) = predefined_embeddings.get(name) {
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
}
}
}


@ -69,6 +69,10 @@ struct CLIArguments {
gpu: Option<usize>,
#[argh(option, description="descriptor CDFs")]
cdfs: Option<String>,
#[argh(option, description="postfilter by embedding (late discard if dot product above threshold)")]
postfilter: Vec<String>,
#[argh(option, description="postfilter by scorer")]
postfilter_scorer: Vec<String>,
}
#[derive(Clone, Deserialize, Serialize, Debug)]
@ -162,10 +166,22 @@ fn main() -> Result<()> {
} else {
(snd, None)
};
let mut post_threshold = None;
for x in &args.postfilter {
let (tname, snd) = x.split_once(':').context("invalid postfilter argument")?;
if tname == name {
post_threshold = Some(snd.parse::<f32>().context("parse postfilter threshold")?);
}
}
let blob = fs::read(path).context("read embedding")?;
embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512), threshold));
embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512), threshold, post_threshold));
}
let postfilter_scorer = args.postfilter_scorer.iter().map(|x| {
let (id, snd) = x.split_once(':').context("invalid postfilter scorer argument")?;
Ok((id.parse::<usize>().context("invalid postfilter scorer id")?, snd.parse::<f32>().context("parse postfilter scorer threshold")?))
}).collect::<Result<Vec<_>>>()?;
let pq_codec = if let Some(pq_codec) = args.pq_codec {
let data = fs::read(pq_codec).context("read pq codec")?;
let pq_codec: ProductQuantizer = rmp_serde::from_read(&data[..]).context("decode pq codec")?;
@ -255,7 +271,6 @@ fn main() -> Result<()> {
files.sort_by_key(|(id, _)| *id);
shard_id_mappings.sort_by_key(|(id, _)| *id);
let read_out_vertices = move |id: u32| -> Result<(Vec<u32>, Vec<u32>)> {
let mut out_vertices: Vec<u32> = vec![];
let mut shards: Vec<u32> = vec![];
@ -323,6 +338,8 @@ fn main() -> Result<()> {
let th = std::thread::spawn(move || reader_thread(&args.paths, tx));
let mut postfilter_count = 0;
let mut rng2 = rng.fork();
let initial_filter = |x: ProcessedEntry| {
i += 1;
@ -337,7 +354,9 @@ fn main() -> Result<()> {
latest_timestamp = latest_timestamp.max(timestamp);
earliest_timestamp = earliest_timestamp.min(timestamp);
for (_name, vec, histogram, threshold) in &mut embeddings {
let mut postfilter = false;
for (_name, vec, histogram, threshold, postfilter_threshold) in &mut embeddings {
let dot = SpatialSimilarity::dot(&embedding, vec).unwrap() as f32;
histogram.add(dot);
if let Some(threshold) = threshold {
@ -345,6 +364,12 @@ fn main() -> Result<()> {
return None;
}
}
if let Some(threshold) = postfilter_threshold {
if dot >= *threshold {
postfilter = true;
postfilter_count += 1; // somewhat wrong because could be duplicated
}
}
}
// distance thresholding is too costly to do over a long range so just do it badly
@ -393,7 +418,7 @@ fn main() -> Result<()> {
println!("{}", data);
}
Some((x, embedding))
Some((x, embedding, postfilter))
};
let mut dead_count = 0;
@ -404,14 +429,14 @@ fn main() -> Result<()> {
let batch: Vec<_> = batch.collect();
let batch_len = batch.len();
for (x, _embedding) in batch.iter() {
for (x, _embedding, _postfilter) in batch.iter() {
if let Some(ref mut file) = output_file {
file.write_all(&x.embedding)?;
}
}
if let Some(shards) = &mut shards_out {
for (i, (x, embedding)) in batch.iter().enumerate() {
for (i, (x, embedding, _postfilter)) in batch.iter().enumerate() {
// closest matches first
shards.sort_by_cached_key(|&(ref centroid, _, shard_count, _shard_index)| {
let mut dot = SpatialSimilarity::dot(&centroid, &embedding).unwrap();
@ -439,7 +464,7 @@ fn main() -> Result<()> {
let quantizer = pq_codec.as_ref().context("PQ codec needed to output index")?;
let mut batch_embeddings = Vec::with_capacity(batch.len() * D_EMB as usize);
for (_x, embedding) in batch.iter() {
for (_x, embedding, _postfilter) in batch.iter() {
batch_embeddings.extend_from_slice(&embedding);
}
let codes = quantizer.quantize_batch(&batch_embeddings);
@ -448,7 +473,7 @@ fn main() -> Result<()> {
let cdfs = cdfs.as_ref().context("score model CDFs needed to output index")?;
let scores = score_model.score_batch(&batch_embeddings)?;
for (i, (x, _embedding)) in batch.into_iter().enumerate() {
for (i, (x, _embedding, mut postfilter)) in batch.into_iter().enumerate() {
let (vertices, shards) = read_out_vertices(count + i as u32)?; // TODO: could parallelize this given the batching
let mut entry_scores = scores[(i * score_model.output_channels)..((i + 1) * score_model.output_channels)].to_vec();
@ -465,6 +490,13 @@ fn main() -> Result<()> {
index_output_file.2.write_all(&[cdf_bucket])?;
}
for (index, score) in postfilter_scorer.iter() {
if entry_scores[*index] < *score {
postfilter = true;
break;
}
}
let mut entry = PackedIndexEntry {
id: count + i as u32,
vertices,
@ -476,9 +508,10 @@ fn main() -> Result<()> {
shards
};
let mut bytes = bitcode::encode(&entry);
if bytes.len() > (RECORD_PAD_SIZE - 2) {
// as an ugly hack for removing entries already in the index shards, kill the URL and make it a graph node only
if bytes.len() > (RECORD_PAD_SIZE - 2) || postfilter {
// we do need the records to fit in a fixed size and can't really drop things, so discard URL so it can exist as a graph node only
entry.url = String::new();
entry.url = String::new(); // URL is only input-controlled, arbitrary-length field
bytes = bitcode::encode(&entry);
dead_count += 1;
}
@ -494,11 +527,11 @@ fn main() -> Result<()> {
}
if args.print_aggregates {
println!("earliest={} latest={} count={} read={} deduped={}", earliest_timestamp, latest_timestamp, count, i, deduped_count);
println!("earliest={} latest={} count={} read={} deduped={} postfiltered={}", earliest_timestamp, latest_timestamp, count, i, deduped_count, postfilter_count);
}
if let Some(histogram_path) = args.histograms {
let mut file = fs::File::create(histogram_path)?;
for (name, _, histogram, _) in &embeddings {
for (name, _, histogram, _, _) in &embeddings {
let width = 800.0;
let padding = 40.0;
let bars_height = 300 as f64;


@ -10,14 +10,16 @@ with open("mse_config.json") as f:
def get_embedding(req):
return msgpack.unpackb(requests.post(config["clip_server"], data=msgpack.packb(req)).content)
output, input, *xs = sys.argv[1:]
mode, output, input = sys.argv[1:]
with open(output, "wb") as f:
with open(input, "rb") as g:
input_data = g.read()
if not xs:
if mode == "image":
with open(input, "rb") as g:
input_data = g.read()
result = get_embedding({"images": [input_data]})[0]
elif mode == "text":
result = get_embedding({"text": input})[0]
else:
result = get_embedding({"text": xs})[0]
raise Exception("unknown mode")
f.write(result)
print(base64.urlsafe_b64encode(result).decode("ascii"))


@ -892,7 +892,7 @@ async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
}
#[instrument(skip(index))]
async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult> {
async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult<()>> {
let result = index.vectors.search(&query, k as usize)?;
let mut seen_videos = HashSet::new();
@ -916,7 +916,8 @@ async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bo
index.filenames[id].container_filename(),
generate_filename_hash(&index.filenames[id as usize]).clone(),
index.format_codes[id],
index.metadata[id].as_ref().map(|x| (x.width, x.height))
index.metadata[id].as_ref().map(|x| (x.width, x.height)),
Option::<()>::None
))
})
.collect();


@ -7,7 +7,7 @@ use argh::FromArgs;
use itertools::Itertools;
use foldhash::{HashSet, HashSetExt};
use half::f16;
use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, QueryLUT, scale_dot_result, scale_dot_result_f64}};
use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, QueryLUT, scale_dot_result, scale_dot_result_f64, SCALE_F64}};
use simsimd::SpatialSimilarity;
use memmap2::{Mmap, MmapOptions};
use std::rc::Rc;
@ -18,13 +18,14 @@ use http_body_util::{BodyExt, Empty, Full};
use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
use std::pin::Pin;
use std::future::Future;
use serde::Serialize;
use serde::{Serialize, Deserialize};
use std::str::FromStr;
use std::collections::HashMap;
use std::io::Write;
mod common;
use common::{resize_for_embed_sync, FrontendInit, IndexHeader, InferenceServerConfig, PackedIndexEntry, QueryRequest, QueryResult};
use common::{resize_for_embed_sync, QueryTerm, FrontendInit, IndexHeader, InferenceServerConfig, PackedIndexEntry, QueryRequest, QueryResult};
#[derive(FromArgs, Clone)]
#[argh(description="Query disk index")]
@ -43,10 +44,18 @@ struct CLIArguments {
search_list_size: Option<usize>,
#[argh(switch, description="always use full-precision vectors (slow)")]
disable_pq: bool,
#[argh(option, short='l', description="listen address")]
listen_address: Option<String>,
#[argh(option, short='c', description="clip server")]
clip_server: Option<String>,
#[argh(option, short='c', description="server config file")]
config_path: Option<String>
}
#[derive(Deserialize, Clone)]
struct ServerConfig {
listen_address: String,
clip_server: String,
descriptor_names: Vec<String>,
telemetry_file: String,
search_list: usize,
beam_width: usize
}
lazy_static! {
@ -82,17 +91,30 @@ fn next_several_unvisited(s: &mut NeighbourBuffer, n: usize) -> Option<Vec<u32>>
}
}
const DUPLICATES_THRESHOLD: f32 = 0.95;
fn read_pq_codes(id: u32, index: Rc<Index>, buf: &mut Vec<u8>) {
let loc = (id as usize) * index.pq_code_size;
buf.extend(&index.pq_codes[loc..loc+index.pq_code_size])
}
struct VisitedNode {
image_url: String,
scores: Vec<f32>,
shards: Vec<u32>,
id: u32,
score: i64,
timestamp: u64,
dimensions: (u32, u32)
}
struct Scratch {
visited_adjacent: HashSet<u32>,
visited: HashSet<u32>,
neighbour_buffer: NeighbourBuffer,
neighbour_pre_buffer: Vec<u32>,
visited_list: Vec<(u32, i64, String, Vec<u32>, Vec<f32>)>
visited_list: Vec<VisitedNode>,
visited_embeddings: Vec<f32>
}
struct Index {
@ -140,10 +162,20 @@ async fn greedy_search<'a>(scratch: &mut Scratch, start: u32, query: &[f16], que
let index = index.clone();
let node = handle.await?;
let vector = bytemuck::cast_slice(&node.vector);
let distance = fast_dot_noprefetch(query, &vector);
let mut score = fast_dot_noprefetch(query, &vector);
score += descriptor_product(index.clone(), &descriptor_scales, node.id);
cmps += 1;
if scratch.visited.insert(node.id) {
scratch.visited_list.push((node.id, distance, node.url, node.shards, node.scores));
if scratch.visited.insert(node.id) && node.url.len() > 0 {
scratch.visited_list.push(VisitedNode {
image_url: node.url,
scores: node.scores,
shards: node.shards,
id: node.id,
score,
timestamp: node.timestamp,
dimensions: node.dimensions
});
scratch.visited_embeddings.extend(bytemuck::cast_slice(&node.vector).iter().map(|x: &f16| x.to_f32()));
};
for &neighbour in node.vertices.iter() {
if scratch.visited_adjacent.insert(neighbour) {
@ -250,7 +282,8 @@ async fn evaluate(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
neighbour_pre_buffer: Vec::new(),
visited_list: Vec::new(),
visited_adjacent: HashSet::new()
visited_adjacent: HashSet::new(),
visited_embeddings: Vec::new()
};
let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]);
@ -265,14 +298,14 @@ async fn evaluate(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
println!("index scan {}: {:?} cmps", shard, cmps_result);
}
scratch.visited_list.sort_by_key(|x| -x.1);
for (i, (id, distance, url, shards, scores)) in scratch.visited_list.iter().take(20).enumerate() {
let found_id = match matches.binary_search(&(*id, 0)) {
scratch.visited_list.sort_by_key(|x| -x.score);
for (i, node) in scratch.visited_list.iter().take(20).enumerate() {
let found_id = match matches.binary_search(&(node.id, 0)) {
Ok(pos) => pos,
Err(pos) => pos
};
if args.verbose {
println!("index scan: {} {} {} {:?} {:?}; rank {}", id, distance, url, shards, scores, matches[found_id].1 + 1);
println!("index scan: {} {} {} {:?} {:?}; rank {}", node.id, node.score, node.image_url, node.shards, node.scores, matches[found_id].1 + 1);
};
top_ranks[i] = std::cmp::min(top_ranks[i], matches[found_id].1);
}
@ -341,11 +374,23 @@ pub async fn query_clip_server<I, O>(base_url: &str, path: &str, data: Option<I>
Ok(result)
}
#[derive(Serialize, Deserialize)]
struct TelemetryMessage {
#[serde(rename="correlationId")]
correlation_id: String,
data: serde_json::Value,
event: String,
#[serde(rename="instanceId")]
instance_id: String,
page: String
}
#[derive(Clone)]
struct Service {
index: Rc<Index>,
inference_server_config: Rc<InferenceServerConfig>,
args: Rc<CLIArguments>
config: Rc<ServerConfig>,
telemetry_channel: std::sync::mpsc::Sender<TelemetryMessage>
}
impl hyper::service::Service<Request<Incoming>> for Service {
@ -355,15 +400,16 @@ impl hyper::service::Service<Request<Incoming>> for Service {
fn call(&self, req: Request<Incoming>) -> Self::Future {
let index = self.index.clone();
let args = self.args.clone();
let config = self.config.clone();
let inference_server_config = self.inference_server_config.clone();
let channel = self.telemetry_channel.clone();
Box::pin(async move {
let mut body = match (req.method(), req.uri().path()) {
(&Method::GET, "/") => Response::new(Full::new(Bytes::from(serde_json::to_vec(&FrontendInit {
n_total: (index.header.count - index.header.dead_count) as u64,
d_emb: index.header.quantizer.n_dims,
predefined_embedding_names: vec![]
predefined_embedding_names: config.descriptor_names.clone()
})?))),
(&Method::POST, "/") => {
let upper = req.body().size_hint().upper().unwrap_or(u64::MAX);
@ -381,7 +427,7 @@ impl hyper::service::Service<Request<Incoming>> for Service {
&body.terms,
&*inference_server_config,
|batch, _config| {
query_clip_server(args.clip_server.as_ref().unwrap(), "/", Some(batch))
query_clip_server(config.clip_server.as_str(), "/", Some(batch))
},
|image, config| async move {
let image = image::load_from_memory(&image)?;
@ -397,27 +443,89 @@ impl hyper::service::Service<Request<Incoming>> for Service {
}).unwrap();
let selected_start = index.header.shards[selected_shard].1;
let beamwidth = 3;
let beamwidth = config.beam_width;
let mut scratch = Scratch {
visited: HashSet::new(),
neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
neighbour_buffer: NeighbourBuffer::new(config.search_list),
neighbour_pre_buffer: Vec::new(),
visited_list: Vec::new(),
visited_adjacent: HashSet::new()
visited_adjacent: HashSet::new(),
visited_embeddings: Vec::new()
};
let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]);
let mut desc = vec![0.0, 0.0, 0.0, 0.0];
for term in &body.terms {
if let Some(name) = &term.predefined_embedding {
if let Some(index) = config.descriptor_names.iter().position(|x| x == name) {
desc[index] = term.weight.unwrap_or(1.0) * 1.0/512.0;
}
}
}
let descriptor_scales = DescriptorScales(desc);
let query_preprocessed = index.header.quantizer.preprocess_query(&query);
let query = query.iter().map(|x| half::f16::from_f32(*x)).collect::<Vec<f16>>();
let cmps_result = greedy_search(&mut scratch, selected_start, &query, &query_preprocessed, &descriptor_scales, index.clone(), args.disable_pq, beamwidth).await?;
let cmps_result = greedy_search(&mut scratch, selected_start, &query, &query_preprocessed, &descriptor_scales, index.clone(), false, beamwidth).await?;
scratch.visited_list.sort_by_key(|x| -x.1);
let n_visited = scratch.visited_list.len();
let matches = scratch.visited_list.drain(..).map(|(id, score, url, shards, scores)| (score as f32, url, String::new(), 0, None)).collect::<Vec<_>>();
let mut similarities_against_self = vec![0.0f32; n_visited * n_visited];
// runtime deduplicate of results list
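// The sgemm call computes the Gram matrix S = V * V^T of the visited embeddings
// (n_visited x n_dims, row-major), so S[i * n_visited + j] is the dot product between
// results i and j; with near-unit-norm embeddings this approximates cosine similarity,
// and retain() below keeps a result only if no already-kept result exceeds DUPLICATES_THRESHOLD.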
unsafe {
// vecs @ vecs.T
matrixmultiply::sgemm(
n_visited,
index.header.quantizer.n_dims,
n_visited,
1.0,
scratch.visited_embeddings.as_ptr(),
index.header.quantizer.n_dims as isize,
1,
scratch.visited_embeddings.as_ptr(),
1,
index.header.quantizer.n_dims as isize,
0.0,
similarities_against_self.as_mut_ptr(),
n_visited as isize,
1
);
}
// discard anything similar to something already in list
let mut i = 0;
let mut included = bitvec::bitvec![0; n_visited];
scratch.visited_list.retain(|_node| {
let row = &similarities_against_self[(i * n_visited)..((i + 1) * n_visited)];
let old_i = i;
i += 1;
for (other_i, similarity) in row.iter().enumerate() {
if similarity > &DUPLICATES_THRESHOLD && included[other_i] {
return false;
}
}
included.set(old_i, true);
true
});
scratch.visited_list.sort_unstable_by_key(|x| -x.score);
let matches = scratch.visited_list
.drain(..)
.map(|node| {
let debug = if body.debug_enabled {
Some((node.scores, node.shards, node.timestamp))
} else {
None
};
((node.score as f64 / SCALE_F64) as f32, node.image_url, String::new(), 0, Some(node.dimensions), debug)
})
.collect::<Vec<_>>();
let result = QueryResult {
formats: vec![],
@ -438,6 +546,25 @@ impl hyper::service::Service<Request<Incoming>> for Service {
.header(hyper::header::CONTENT_TYPE, "text/plain; version=0.0.4")
.body(Full::new(Bytes::from(buffer))).unwrap()
},
(&Method::POST, "/telemetry") => {
// TODO refactor
let upper = req.body().size_hint().upper().unwrap_or(u64::MAX);
if upper > 1000 {
let mut resp = Response::new(Full::new(Bytes::from("Body too big")));
*resp.status_mut() = hyper::StatusCode::PAYLOAD_TOO_LARGE;
return Ok(resp);
}
let whole_body = req.collect().await?.to_bytes();
let message = serde_json::from_slice::<TelemetryMessage>(&whole_body)?;
channel.send(message)?;
Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Full::new(Bytes::from(""))).unwrap()
}
(&Method::OPTIONS, "/") => {
Response::builder()
.status(StatusCode::NO_CONTENT)
@ -459,9 +586,9 @@ impl hyper::service::Service<Request<Incoming>> for Service {
}
}
async fn get_backend_config(clip_server: &Option<String>) -> Result<InferenceServerConfig> {
async fn get_backend_config(clip_server: &String) -> Result<InferenceServerConfig> {
loop {
match query_clip_server(clip_server.as_ref().unwrap(), "/config", Option::<()>::None).await {
match query_clip_server(clip_server, "/config", Option::<()>::None).await {
Ok(config) => return Ok(config),
Err(err) => {
tracing::warn!("waiting for clip server: {}", err);
@ -471,14 +598,31 @@ async fn get_backend_config(clip_server: &Option<String>) -> Result<InferenceSer
}
}
// can't run this as an async task because monoio File API is positional writes only
fn telemetry_handler(rx: std::sync::mpsc::Receiver<TelemetryMessage>, config: ServerConfig) -> Result<()> {
let mut telemetry_file = std::fs::OpenOptions::new().create(true).append(true).open(&config.telemetry_file)?;
while let Ok(message) = rx.recv() {
telemetry_file.write_all(rmp_serde::to_vec(&message)?.as_slice())?;
}
Ok(())
}
async fn serve(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
let config: ServerConfig = serde_json::from_slice(&std::fs::read(args.config_path.as_ref().unwrap())?)?;
let (telemetry_channel, telemetry_receiver) = std::sync::mpsc::channel();
let config_ = config.clone();
std::thread::spawn(move || telemetry_handler(telemetry_receiver, config_));
let service = Service {
index,
inference_server_config: Rc::new(get_backend_config(&args.clip_server).await?),
args: Rc::new(args.clone())
inference_server_config: Rc::new(get_backend_config(&config.clip_server).await?),
config: Rc::new(config.clone()),
telemetry_channel
};
let listener = TcpListener::bind(args.listen_address.as_ref().unwrap())?;
let listener = TcpListener::bind(config.listen_address)?;
println!("Listening");
loop {
let (stream, _) = listener.accept().await?;
@ -531,7 +675,7 @@ async fn main() -> Result<()> {
n_descriptors: header.descriptor_cdfs.len(),
});
if args.listen_address.is_some() {
if args.config_path.is_some() {
serve(&args, index).await?;
} else {
evaluate(&args, index).await?;

File diff suppressed because one or more lines are too long