Mirror of https://github.com/osmarks/meme-search-engine.git

release version

Commit ee23b81444 by osmarks, 2025-01-24 09:24:28 +00:00
Parent 3852d0078d
34 changed files with 774 additions and 147 deletions

.gitignore

@@ -22,3 +22,8 @@ index
 queries.txt
 *.zst
 .safetensors
+*/static/*.woff2
+flamegraph.svg
+*.jsonl
+*.safetensors
+perf.data

Cargo.lock

@@ -416,6 +416,18 @@ version = "2.5.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452"
 
+[[package]]
+name = "bitvec"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
+dependencies = [
+ "funty",
+ "radium",
+ "tap",
+ "wyz",
+]
+
 [[package]]
 name = "block-buffer"
 version = "0.10.4"
@@ -1117,6 +1129,12 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "funty"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
+
 [[package]]
 name = "futures-channel"
 version = "0.3.31"
@@ -2046,6 +2064,7 @@ dependencies = [
  "axum",
  "base64 0.22.1",
  "bitcode",
+ "bitvec",
  "bytemuck",
  "candle-core",
  "chrono",
@@ -2067,6 +2086,7 @@ dependencies = [
  "itertools 0.13.0",
  "json5",
  "lazy_static",
+ "matrixmultiply",
  "maud",
  "memmap2",
  "mimalloc",
@@ -2850,6 +2870,12 @@ dependencies = [
  "proc-macro2",
 ]
 
+[[package]]
+name = "radium"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
+
 [[package]]
 name = "rand"
 version = "0.8.5"
@@ -3873,6 +3899,12 @@ dependencies = [
  "version-compare",
 ]
 
+[[package]]
+name = "tap"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
+
 [[package]]
 name = "target-lexicon"
 version = "0.12.16"
@@ -4732,6 +4764,15 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "wyz"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
+dependencies = [
+ "tap",
+]
+
 [[package]]
 name = "yoke"
 version = "0.7.5"

@@ -62,6 +62,8 @@ monoio = "0.2"
 hyper = "1"
 monoio-compat = { version = "0.2", features = ["hyper"] }
 http-body-util = "0.1"
+matrixmultiply = "0.3"
+bitvec = "1"
 
 [[bin]]
 name = "reddit-dump"

@@ -9,15 +9,15 @@
 <div class="about">
     <p>
-        Welcome to {util.hardConfig.name} by <a href="https://osmarks.net/">osmarks.net Computational Memetics</a>. {util.hardConfig.name} searches images using semantic image/text embedding models. In general, search by thinking of what caption your desired image might have been given by random people on the internet. The model currently in use can read text fairly well and understands moderately abstract properties of images, but is limited to English and case-insensitive.
+        Welcome to {util.hardConfig.name} by <a href="https://osmarks.net/">osmarks.net</a> Computational Memetics. {util.hardConfig.name} searches images using semantic image/text embedding models (refer to <a href="https://arxiv.org/abs/2303.15343">https://arxiv.org/abs/2303.15343</a>, <a href="https://arxiv.org/abs/2103.00020">https://arxiv.org/abs/2103.00020</a>). In general, search by thinking of what caption your desired image might have been given by random people on the internet. The model currently in use can read text fairly well and understands moderately abstract properties of images, but is limited to English and case-insensitive.
     </p>
     <p>
+        Advanced Mode sliders are generated from PCA on the index. The human-readable labels are generated manually by <a href="https://datasets.osmarks.net/components.html">looking at things</a>. "Useful"/"aesthetic"/"meme" sliders are defined based on an approximate extrapolation of my preferences and may not agree with your own opinions. Results are otherwise decided by the inscrutable whims of a billion-parameter neural network. Please note that large-scale internet data may contain things you do not like.
     </p>
     <p>
-        The code is open-source and available on <a href="https://github.com/osmarks/meme-search-engine/">GitHub.</a>
+        The code is open-source and available on <a href="https://github.com/osmarks/meme-search-engine/">GitHub</a>.
     </p>
-    {#if util.hardConfig.telemetryEndpoint}
+    {#if util.hardConfig.telemetry_endpoint}
     <h2>Privacy</h2>
     <p>
         We do not collect personal information. We do collect usage information (associated with a random ID) to improve the ranking algorithms. You can disable this:
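The PCA-sliders approach described above can be sketched as follows (an illustrative outline only, not the repo's actual pipeline; the embeddings file and 1152-dimension layout follow the training scripts elsewhere in this commit):

    import numpy as np

    # SigLIP embeddings as used elsewhere in this commit: fp16 on disk, 1152-d.
    embeddings = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, 1152).astype(np.float32)
    centered = embeddings - embeddings.mean(axis=0)

    # Top principal components of the index become candidate slider directions.
    _, _, vt = np.linalg.svd(centered[:100000], full_matrices=False)
    components = vt[:8]  # one direction per slider

    def apply_sliders(query_embedding, slider_values):
        # Each slider adds a weighted principal direction to the query vector.
        return query_embedding + slider_values @ components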

@@ -154,7 +154,6 @@
 <div class="right">
     <NavItem page="advanced">Advanced</NavItem>
     <NavItem page="about">About</NavItem>
-    <NavItem page="refine">Refine</NavItem>
 </div>
 </nav>
@@ -185,7 +184,11 @@
 <option>+</option>
 <option>-</option>
 </select>
-<input type="range" min="0" max="2" bind:value={term.weight} step="0.01">
+{#if debugEnabled}
+    <input type="number" bind:value={term.weight} step="0.01">
+{:else}
+    <input type="range" min="0" max="2" bind:value={term.weight} step="0.01">
+{/if}
 {#if term.type === "image"}
     <span>{term.file.name}</span>
 {:else if term.type === "text"}
@@ -200,6 +203,14 @@
 </li>
 {/each}
 </ul>
+{#if showDebugSwitch}
+    <div class="ctrlbar">
+        <input type="checkbox" bind:checked={debugEnabled} id="debug" />
+        <label for="debug">Debug</label>
+        {#if debugEnabled}
+        {/if}
+    </div>
+{/if}
 <div class="ctrlbar">
 <input type="search" placeholder="Text Query" on:keydown={handleKey} on:focus={newTextQuery}>
 <button on:click={pickFile}>Image Query</button>
@@ -242,6 +253,9 @@
 let queryCounter = 0
 let config = {}
 
+const showDebugSwitch = localStorage.getItem("debugEnabled") === "true"
+let debugEnabled = false
+
 const newTextQuery = (content=null) => {
     queryTerms.push({ type: "text", weight: 1, sign: "+", text: typeof content === "string" ? content : "" })
     queryTerms = queryTerms
@@ -253,9 +267,14 @@
 const runSearch = async () => {
     if (!resultPromise) {
+        const friendlyModeQueryOrRandom = friendlyModeQuery ? [{ text: friendlyModeQuery, weight: 1, sign: "+" }] : [{ embedding: util.randn(config.d_emb, 1 / (config.d_emb ** 0.5)), weight: 1, sign: "+" }]
+        const terms = friendlyMode ?
+            friendlyModeQueryOrRandom.concat(util.hardConfig.friendly_mode_default_terms ?? []) :
+            queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))
         let args = {
-            "terms": friendlyMode ? [{ text: friendlyModeQuery, weight: 1, sign: "+" }] : queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] })),
-            "include_video": true
+            "terms": terms,
+            "include_video": true,
+            "debug_enabled": debugEnabled
         }
 
         util.sendTelemetry("search", {
@@ -314,7 +333,7 @@
     type: "predefined_embedding",
     predefinedEmbedding: predefinedEmbeddingName,
     sign: "+",
-    weight: 0.2
+    weight: 1
 })
 }
 queryTerms = queryTerms

@@ -36,26 +36,14 @@
 const d_emb = 1152
 
-const vecSum = (xs, ys) => xs.map((x, i) => x + ys[i])
-const vecZero = d => new Array(d).fill(0)
-const vecScale = (xs, s) => xs.map(x => x * s)
-
-const boxMuller = () => {
-    let x = Math.random()
-    let y = Math.random()
-    return Math.sqrt(-2.0 * Math.log(x)) * Math.cos(2.0 * Math.PI * y)
-}
-const randn = (d, sigma) => Array.from({ length: d }, () => boxMuller() * sigma)
-
 const K = 2
 let candidates = []
 
 const select = async candidate => {
     candidates = []
-    const direction = randn(d_emb, 1 / d_emb)
+    const direction = util.randn(d_emb, 1 / d_emb)
     for (let i = -K; i <= K; i++) {
-        const newV = vecSum(vecScale(direction, i / K), candidate.vector)
+        const newV = util.vecSum(util.vecScale(direction, i / K), candidate.vector)
         candidates.push({ vector: newV, results: null, i: i + K })
     }
     await Promise.all(candidates.map(async x => {
@@ -67,7 +55,7 @@
     console.log(candidates)
 }
 
-select({ vector: randn(d_emb, 1 / d_emb) })
+select({ vector: util.randn(d_emb, 1 / d_emb) })
 
 const handleKey = ev => {
     const num = parseInt(ev.key)
@@ -75,4 +63,4 @@
         select(candidates[num - 1])
     }
 }
 </script>

@@ -28,6 +28,10 @@
 {#key `${queryCounter}${result.file}`}
 <div class="result">
     <ResultImage {result} {results} {updateCounter} {redrawGrid} constrainBy="width" />
+    {#if result[5]} <!-- debug info -->
+        <div>{result[0]}</div>
+        <div>{JSON.stringify(result[5])}</div>
+    {/if}
 </div>
 {/key}
 {/each}
@@ -72,7 +76,7 @@
 export const redrawGrid = async () => {
     if (refreshLayout) {
         refreshLayout()
-        await recomputeScroll()
+        setTimeout(recomputeScroll, 0)
     }
 }
@@ -82,7 +86,7 @@
 }
 
 const handleScroll = () => {
-    if (window.scrollY + window.innerHeight >= heightThreshold && pendingImageLoads === 0) {
+    if (window.scrollY + window.innerHeight >= heightThreshold) {
         recomputeScroll()
         if (window.scrollY + window.innerHeight < heightThreshold) return;
         let init = displayedResults.length

@@ -1,4 +1,11 @@
 @use 'sass:color'
 
+@font-face
+    font-family: 'Iosevka'
+    font-style: normal
+    font-weight: 400
+    src: url(./iosevka.woff2) format('woff2')
+    unicode-range: U+0000-00C0, U+A2-A9, U+AC-AE, U+00D7, U+00F7, U+FEFF, U+FFFD
+
 $palette-primary: #3f9b0b
 $palette-secondary: #033500

Binary file not shown.

@@ -78,3 +78,15 @@ fetch(config.backend_url).then(x => x.json().then(x => {
     serverConfig.set(x)
     window.serverConfig = x
 }))
+
+export const vecSum = (xs, ys) => xs.map((x, i) => x + ys[i])
+export const vecZero = d => new Array(d).fill(0)
+export const vecScale = (xs, s) => xs.map(x => x * s)
+
+const boxMuller = () => {
+    let x = Math.random()
+    let y = Math.random()
+    return Math.sqrt(-2.0 * Math.log(x)) * Math.cos(2.0 * Math.PI * y)
+}
+export const randn = (d, sigma) => Array.from({ length: d }, () => boxMuller() * sigma)
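The boxMuller helper moved into util.js above is the Box-Muller transform: two uniform samples map to one standard normal sample, which randn then scales by sigma. Note that Math.random() can return exactly 0, where Math.log diverges; robust versions often use 1 - Math.random(). A quick sanity check in Python (illustrative, not part of the repo):

    import numpy as np

    u = 1.0 - np.random.rand(100000)  # uniform on (0, 1], avoiding log(0)
    v = np.random.rand(100000)
    samples = np.sqrt(-2.0 * np.log(u)) * np.cos(2.0 * np.pi * v)
    print(samples.mean(), samples.std())  # approximately 0 and 1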

@@ -5,7 +5,7 @@
 <meta name="viewport" content="width=device-width, minimum-scale=1, initial-scale=1, user-scalable=yes">
 <meta name="description" content="Organizing the world's memes.">
-<title>Meme Search Engine</title>
+<title>Nooscope</title>
 </style>
 <link rel="stylesheet" href="app.css">
 <link rel="icon" type="image/png" href="logo.png">

config2.json (new file)

@@ -0,0 +1,10 @@
+{
+    "listen_address": "127.0.0.1:5601",
+    "clip_server": "http://100.64.0.10:1708",
+    "descriptor_names": [
+        "Useful",
+        "Meme",
+        "Aesthetic",
+        "Time"
+    ]
+}

diskann/chainq.py (new file)

@@ -0,0 +1,164 @@
import numpy as np
import msgpack
import math
import torch
from torch import autograd
import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# https://github.com/una-dinosauria/local-search-quantization/blob/master/src/encodings/encode_chain.jl#L2
# vectorized somewhat
def viterbi_encode(
    out_codes: torch.Tensor, # N x M (ints)
    vectors: torch.Tensor, # N x D
    codebooks: torch.Tensor # M x H x D
):
    N, D = vectors.shape
    M, H, D2 = codebooks.shape
    assert D == D2
    # M x H x N - ||x-c||^2 ignoring x.T @ x component
    unary_costs = -2 * (codebooks @ vectors.T) + (torch.linalg.norm(codebooks, dim=2) ** 2).unsqueeze(-1)
    binary_costs = torch.zeros(M - 1, H, H, dtype=torch.float, device=DEVICE)
    for i in range(M - 1):
        binary_costs[i] = 2 * codebooks[i] @ codebooks[i + 1].T
    print("binary costs", binary_costs)

    min_cost = torch.zeros(H, N, dtype=torch.float, device=DEVICE)
    min_idx = torch.zeros(M, H, N, dtype=torch.int, device=DEVICE)
    cost = torch.zeros(H, N, dtype=torch.float, device=DEVICE)

    # forward pass - propagate optimal costs and indices forward
    for step in tqdm.trange(M - 1):
        if step > 0:
            unary_costs[step] += min_cost
        ucost = unary_costs[step]
        # for all possible costs at this step
        for j in range(H):
            bcost = binary_costs[step, j].unsqueeze(-1) # independent of N
            cost = ucost + bcost
            min_values, min_indices = torch.min(cost, dim=0)
            min_cost[j] = min_values
            min_idx[step, j] = min_indices
    unary_costs[-1] += min_cost

    # backward pass - propagate optimal indices backwards
    out_codes[:, -1] = torch.argmin(unary_costs[-1], dim=0)
    for i in range(M - 2, -1, -1):
        out_codes[:, i] = min_idx[i][out_codes[:, i + 1], range(N)]

def dims_for(dim, total_dims, m):
    dims_per_code = total_dims // m
    relevant_codebooks = [dim // dims_per_code]
    if relevant_codebooks[-1] < m - 1:
        relevant_codebooks.append(relevant_codebooks[-1] + 1)
    return relevant_codebooks

def update_codebooks(transformed_data, codes, h):
    n, d = transformed_data.shape
    n2, m = codes.shape
    assert n == n2
    new_codebook = torch.zeros(m, h, d, dtype=torch.float, device=DEVICE)
    for dim in tqdm.trange(d):
        relevant_codebooks = dims_for(dim, d, m)
        assignment_matrix = torch.zeros(n, len(relevant_codebooks), h, dtype=torch.float, device=DEVICE)
        indices = (
            torch.arange(n, dtype=torch.int, device=DEVICE).repeat(len(relevant_codebooks)),
            torch.arange(len(relevant_codebooks), dtype=torch.int, device=DEVICE).repeat_interleave(n),
            codes[:, relevant_codebooks].T.flatten()
        )
        assignment_matrix[indices] = 1
        #print(assignment_matrix, assignment_matrix.shape, transformed_data[:, dim], transformed_data[:, dim].shape)
        assignment_matrix = assignment_matrix.reshape(n, len(relevant_codebooks) * h)
        #soln = torch.linalg.lstsq(assignment_matrix, transformed_data[:, dim])[0]
        reg = 1e-3 * torch.eye(len(relevant_codebooks) * h, device=DEVICE)
        A = assignment_matrix.T @ assignment_matrix + reg
        b = assignment_matrix.T @ transformed_data[:, dim]
        #print("matrix", A)
        usage = assignment_matrix.sum(dim=0)
        unused = usage < 1
        print(unused.sum().detach().item())
        soln = torch.linalg.solve(A, b)
        #print("solution", soln.reshape(len(relevant_codebooks), h))
        if unused.any():
            soln[unused] = torch.randn_like(soln[unused])
        new_codebook[relevant_codebooks, :, dim] = soln.reshape(len(relevant_codebooks), h)
        if torch.isnan(new_codebook[relevant_codebooks, :, dim]).any():
            print("oh no", dim, new_codebook, relevant_codebooks, new_codebook[relevant_codebooks, :, dim])
            print("--- dim ---", dim)
            print("- sum per column:", assignment_matrix.sum(dim=0)) # Check if any columns are all zero
            print("- rank:", torch.linalg.matrix_rank(assignment_matrix))
            print("- condition number:", torch.linalg.cond(assignment_matrix))
            raise SystemExit
    return new_codebook

BATCH = 8192

def train_chainq(vectors, m, h, transform, codebooks, n_iters):
    for i in range(n_iters):
        transformed_data = vectors @ transform.T

        codes = torch.zeros(vectors.shape[0], m, dtype=torch.int, device=DEVICE)
        for start in range(0, vectors.shape[0], BATCH): # renamed from i to avoid shadowing the iteration counter
            viterbi_encode(codes[start:start+BATCH], transformed_data[start:start+BATCH], codebooks)
        print("encoded")

        #codebooks = update_codebooks(transformed_data, codes, h)
        print("codebooks updated")

        quantized = torch.zeros_like(vectors, dtype=torch.float, device=DEVICE)
        for j in range(m):
            quantized[:] += codebooks[j, codes[:, j]]
        print("quantized")
        print((quantized - transformed_data).abs().mean(), transformed_data.abs().mean())
        print("comparing")

        res = transformed_data.T @ quantized
        print("running SVD...")
        u, s, vt = torch.linalg.svd(res)
        print("done.")
        transform = u @ vt
        print("regenerated transform")

    return codebooks, transform

with open("opq.msgpack", "rb") as f:
    data = msgpack.unpackb(f.read())

n_dims = 1152
dataset = torch.tensor(np.random.permutation(np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32))[:BATCH*1], device=DEVICE)
codebooks = torch.zeros(64, 256, n_dims, dtype=torch.float, device=DEVICE)
centroids = torch.tensor(np.array(data["centroids"]).astype(np.float32).reshape(256, n_dims), device=DEVICE)
for dim in range(n_dims):
    relevant_codebooks = dim // 64
    codebooks[relevant_codebooks, :, dim] = centroids[:, dim]
print(centroids)
#print("codebooks", codebooks.tolist())
codebooks, transform = train_chainq(dataset, 64, 256, torch.tensor(np.array(data["transform"]).astype(np.float32).reshape(n_dims, n_dims), device=DEVICE), codebooks, 100)

with open("chainq.msgpack", "wb") as f:
    msgpack.pack({
        "codebooks": codebooks.cpu().numpy().flatten().tolist(),
        "transform": transform.cpu().numpy().flatten().tolist(),
        "n_dims": n_dims,
        "n_dims_per_code": n_dims // 64
    }, f)
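A rough decode-side check of the chainq.msgpack artifact written above (a sketch following the file's own msgpack layout; the random codes stand in for real viterbi_encode output):

    import numpy as np
    import msgpack

    with open("chainq.msgpack", "rb") as f:
        d = msgpack.unpack(f)

    m, h = 64, 256
    codebooks = np.array(d["codebooks"], dtype=np.float32).reshape(m, h, d["n_dims"])
    codes = np.random.randint(0, h, size=m)  # stand-in for encoder output
    # A chain-quantized vector is the sum of one selected centroid per codebook.
    approx = codebooks[np.arange(m), codes].sum(axis=0)
    print(approx.shape)  # (1152,)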

@@ -45,7 +45,7 @@ impl Vector {
 }
 
 // Floats are vaguely annoying and not sortable (trivially), so we mostly represent dot products as integers
-pub const SCALE: f32 = 1099511627776.0;
+pub const SCALE: f32 = 4294967296.0;
 pub const SCALE_F64: f64 = SCALE as f64;
 
 pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> i64 {
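The scaled-integer representation the comment describes amounts to fixed-point encoding: multiply the dot product by SCALE and truncate, so i64 comparisons order results the same way the floats would. A minimal sketch (illustrative Python; in the Rust code scale_dot_result plays this role):

    SCALE = 4294967296.0  # 2**32, matching the new constant

    def scale_dot_result(dot: float) -> int:
        # i64 fixed-point value; integer comparison preserves float ordering.
        return int(dot * SCALE)

    assert scale_dot_result(0.25) < scale_dot_result(0.5)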

faiss_bench_quantizer.py (new file)

@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
import faiss
import time
import numpy as np

def eval_codec(q, xb):
    t0 = time.time()
    codes = q.compute_codes(xb)
    t1 = time.time()
    xb_decoded = q.decode(codes)
    recons_err = ((xb - xb_decoded) ** 2).sum(axis=1).mean()
    print(f"\tencode time: {t1 - t0:.3f} reconstruction error: {recons_err:.3f} ")

def eval_quantizer(q, xb, xt, variants=None):
    if variants is None:
        variants = [(None, None)]
    t0 = time.time()
    q.train(xt)
    t1 = time.time()
    train_t = t1 - t0
    print(f'\ttraining time: {train_t:.3f} s')
    for name, val in variants:
        if name is not None:
            print(f"{name}={val}")
            if isinstance(q, faiss.ProductAdditiveQuantizer):
                for i in range(q.nsplits):
                    subq = faiss.downcast_Quantizer(q.subquantizer(i))
                    getattr(subq, name)
                    setattr(subq, name, val)
            else:
                getattr(q, name) # make sure field exists
                setattr(q, name, val)
        eval_codec(q, xb)

todo = sys.argv[1:]

ds = np.fromfile(todo[0], dtype=np.float16).reshape(-1, 1152).astype(np.float32)
print(ds)
del todo[0]

if len(todo) > 0:
    if todo[0].count("x") == 1:
        M, nbits = [int(x) for x in todo[0].split("x")]
        del todo[0]
    elif todo[0].count("x") == 2:
        nsplits, Msub, nbits = [int(x) for x in todo[0].split("x")]
        M = nsplits * Msub
        del todo[0]

maxtrain = max(100 << nbits, 10**5)
print(f"eval on {M}x{nbits} maxtrain={maxtrain}")

xb, xt = ds, ds

nb, d = xb.shape
nt, d = xt.shape

# fastest to slowest

if 'lsq-gpu' in todo:
    lsq = faiss.LocalSearchQuantizer(d, M, nbits)
    ngpus = faiss.get_num_gpus()
    lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
    lsq.verbose = True
    eval_quantizer(lsq, xb, xt)

if 'pq' in todo:
    pq = faiss.ProductQuantizer(d, M, nbits)
    print("===== PQ")
    eval_quantizer(pq, xb, xt)

if 'opq' in todo:
    d2 = ((d + M - 1) // M) * M
    print("OPQ d2=", d2)
    opq = faiss.OPQMatrix(d, M, d2)
    opq.train(xt)
    xb2 = opq.apply(xb)
    xt2 = opq.apply(xt)
    pq = faiss.ProductQuantizer(d2, M, nbits)
    print("===== PQ")
    eval_quantizer(pq, xb2, xt2)

if 'prq' in todo:
    print(f"===== PRQ{nsplits}x{Msub}x{nbits}")
    prq = faiss.ProductResidualQuantizer(d, nsplits, Msub, nbits)
    variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32)]
    eval_quantizer(prq, xb, xt, variants=variants)

if 'plsq' in todo:
    print(f"===== PLSQ{nsplits}x{Msub}x{nbits}")
    plsq = faiss.ProductLocalSearchQuantizer(d, nsplits, Msub, nbits)
    variants = [("encode_ils_iters", i) for i in (2, 3, 4, 8, 16)]
    eval_quantizer(plsq, xb, xt, variants=variants)

if 'rq' in todo:
    print("===== RQ")
    rq = faiss.ResidualQuantizer(d, M, nbits, )
    rq.max_beam_size
    rq.max_beam_size = 30 # for compatibility with older runs
    # rq.train_type = faiss.ResidualQuantizer.Train_default
    # rq.verbose = True
    variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32)]
    eval_quantizer(rq, xb, xt, variants=variants)

if 'rq_lut' in todo:
    print("===== RQ")
    rq = faiss.ResidualQuantizer(d, M, nbits, )
    rq.max_beam_size
    rq.max_beam_size = 30 # for compatibility with older runs
    rq.use_beam_LUT
    rq.use_beam_LUT = 1
    # rq.train_type = faiss.ResidualQuantizer.Train_default
    # rq.verbose = True
    variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32, 64)]
    eval_quantizer(rq, xb, xt, variants=variants)

if 'lsq' in todo:
    print("===== LSQ")
    lsq = faiss.LocalSearchQuantizer(d, M, nbits)
    variants = [("encode_ils_iters", i) for i in (2, 3, 4, 8, 16)]
    eval_quantizer(lsq, xb, xt, variants=variants)
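Invocation follows the positional parsing above: an fp16 embeddings file, an optional MxNbits (or splitsxMxNbits) code-size spec, then quantizer names. For example (hypothetical paths; 64 subquantizers at 8 bits each):

    python faiss_bench_quantizer.py embeddings.bin 64x8 pq opq rq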

@@ -1,8 +1,12 @@
 {
-    "backend_url": "http://localhost:5601/",
+    "backend_url": "/backend",
     "image_path": "",
     "thumb_path": null,
     "description": "Organizing the world's memes.",
     "name": "Nooscope",
-    "telemetry_endpoint": "/telemetry"
+    "telemetry_endpoint": "/telemetry",
+    "friendly_mode_default_terms": [
+        { "predefined_embedding": "Meme", "weight": 1, "sign": "+" },
+        { "predefined_embedding": "Aesthetic", "weight": 0.2, "sign": "+" }
+    ]
 }

genseahash.py (new file)

@@ -0,0 +1,4 @@
import seahash, sys

with open(sys.argv[1], "rb") as f:
    print(seahash.hash(f.read()))

@@ -31,6 +31,7 @@ model.load_state_dict(torch.load(modelc))
 params = sum(p.numel() for p in model.parameters())
 print(f"{params/1e6:.1f}M parameters")
 print(model)
+model.eval()
 
 files = shared.fetch_all_files()
 variance = {}
@@ -56,4 +57,4 @@ with torch.inference_mode():
 top = sorted(variance.items(), key=lambda x: -x[1])
 with open("top.json", "w") as f:
-    json.dump(top[:100], f)
+    json.dump(top[:50], f)

@@ -52,7 +52,7 @@ channel = int(sys.argv[2])
 percentile = float(sys.argv[3])
 output_pairs = int(sys.argv[4])
 
 mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
-top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
+top = sorted(((filename, score) for filename, score in results.items()), key=lambda x: x[1][channel], reverse=True)
 select_from = top[:int(len(top) * percentile)]
 
 out = []

@@ -13,14 +13,14 @@ import sys
 from model import Config, BradleyTerry
 import shared
 
-batch_size = 128
+batch_size = 32
 num_pairs = batch_size * 1024
 device = "cuda"
 
 config = Config(
     d_emb=1152,
     n_hidden=1,
-    n_ensemble=1,
+    n_ensemble=16,
     device=device,
     dtype=torch.float32,
     output_channels=3,
@@ -32,6 +32,7 @@ model.load_state_dict(torch.load(modelc))
 params = sum(p.numel() for p in model.parameters())
 print(f"{params/1e6:.1f}M parameters")
 print(model)
+model.eval()
 
 files = shared.fetch_all_files()
 importance = {}
@@ -41,8 +42,8 @@ buffers = {k: v.detach() for k, v in model.named_buffers()}
 
 # https://pytorch.org/tutorials/intermediate/per_sample_grads.html
 def compute_loss(params, buffers, sample, target):
-    batch = sample.unsqueeze(0)
-    targets = target.unsqueeze(0)
+    batch = sample.unsqueeze(1)
+    targets = target.unsqueeze(1).expand((config.n_ensemble, 1, config.output_channels))
 
     predictions = functional_call(model, (params, buffers), (batch,))
     loss = F.binary_cross_entropy(predictions, targets)
@@ -75,4 +76,4 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
 top = sorted(importance.items(), key=lambda x: -x[1])
 with open("top.json", "w") as f:
-    json.dump(top[:256], f)
+    json.dump(top[:50], f)

@@ -31,6 +31,11 @@ params = sum(p.numel() for p in model.parameters())
 print(f"{params/1e6:.1f}M parameters")
 print(model)
 
+torch.random.manual_seed(1)
+
+for x in model.ensemble.models:
+    x.output.bias.data.fill_(0)
+
 out_layers = []
 out_bias = []
@@ -50,7 +55,7 @@ with torch.inference_mode():
     downprojection[:, i*config.d_emb:(i+1)*config.d_emb] = model.ensemble.models[i].output.weight.data.clone()
 
     for i in range(10):
-        input = torch.randn(3, config.d_emb)
+        input = torch.randn(4, config.d_emb)
         ground_truth_result = model.ensemble(input.unsqueeze(0).expand((config.n_ensemble, *input.shape))).mean(dim=0).T
         r_result = input
         for (layer, bias) in zip(out_layers, out_bias):
@@ -58,13 +63,14 @@ with torch.inference_mode():
             print(r_result.shape, bias.shape)
             r_result = F.silu(r_result)
         r_result = torch.matmul(downprojection, r_result) / config.n_ensemble
-        bias_diff = torch.mean(r_result - ground_truth_result)
-        model_issue_diff = torch.max(torch.abs(r_result - bias_diff - ground_truth_result))
-        print(model_issue_diff)
+        error = torch.mean(r_result - ground_truth_result)
+        print(error)
+        assert error.detach().cpu().numpy() < 1e-4
 
     print("test vector:")
-    print(input.flatten().tolist())
+    #print(input.flatten().tolist())
     print("ground truth result:")
+    print(ground_truth_result.shape)
     print(ground_truth_result.T.flatten().tolist())
 
 save_file({

@@ -35,7 +35,7 @@ class Ensemble(nn.Module):
 
     # model batch
     def forward(self, embs):
-        return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
+        return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_channels
 
 class BradleyTerry(nn.Module):
     def __init__(self, config):

@@ -51,19 +51,25 @@ async def index(request):
 <form action="/rate" method="POST">
 <table>
     <tr>
+        <td><input type="radio" name="rating-useful" value="1+" id="rq1p"> <label for="rq1p">LHS is much better (useful)</label></td>
         <td><input type="radio" name="rating-useful" value="1" id="rq1"> <label for="rq1">LHS is better (useful)</label></td>
         <td><input type="radio" name="rating-useful" value="eq" id="rqe"> <label for="rqe">Tie</label></td>
         <td><input type="radio" name="rating-useful" value="2" id="rq2"> <label for="rq2">RHS is better (useful)</label></td>
+        <td><input type="radio" name="rating-useful" value="2+" id="rq2p"> <label for="rq2p">RHS is much better (useful)</label></td>
     </tr>
     <tr>
+        <td><input type="radio" name="rating-meme" value="1+" id="rm1p"> <label for="rm1p">LHS is much better (memetically)</label></td>
         <td><input type="radio" name="rating-meme" value="1" id="rm1"> <label for="rm1">LHS is better (memetically)</label></td>
         <td><input type="radio" name="rating-meme" value="eq" id="rme"> <label for="rme">Tie</label></td>
         <td><input type="radio" name="rating-meme" value="2" id="rm2"> <label for="rm2">RHS is better (memetically)</label></td>
+        <td><input type="radio" name="rating-meme" value="2+" id="rm2p"> <label for="rm2p">RHS is much better (memetically)</label></td>
     </tr>
     <tr>
+        <td><input type="radio" name="rating-aesthetic" value="1+" id="ra1p"> <label for="ra1p">LHS is much better (aesthetically)</label></td>
         <td><input type="radio" name="rating-aesthetic" value="1" id="ra1"> <label for="ra1">LHS is better (aesthetically)</label></td>
         <td><input type="radio" name="rating-aesthetic" value="eq" id="rae"> <label for="rae">Tie</label></td>
         <td><input type="radio" name="rating-aesthetic" value="2" id="ra2"> <label for="ra2">RHS is better (aesthetically)</label></td>
+        <td><input type="radio" name="rating-aesthetic" value="2+" id="ra2p"> <label for="ra2p">RHS is much better (aesthetically)</label></td>
     </td>
 </table>
@@ -82,35 +88,29 @@ async def index(request):
         document.querySelector("form").submit()
     }}
 }}
+const keys = {{
+    "q": "rq1p",
+    "w": "rq1",
+    "e": "rqe",
+    "r": "rq2",
+    "t": "rq2p",
+    "a": "rm1p",
+    "s": "rm1",
+    "d": "rme",
+    "f": "rm2",
+    "g": "rm2p",
+    "z": "ra1p",
+    "x": "ra1",
+    "c": "rae",
+    "v": "ra2",
+    "b": "ra2p",
+}}
 document.addEventListener("keypress", function(event) {{
-    if (event.key === "q") {{
-        document.querySelector("input[name='rating-useful'][value='1']").checked = true
-        commitIfReady()
-    }} else if (event.key === "w") {{
-        document.querySelector("input[name='rating-useful'][value='eq']").checked = true
-        commitIfReady()
-    }} else if (event.key === "e") {{
-        document.querySelector("input[name='rating-useful'][value='2']").checked = true
-        commitIfReady()
-    }} else if (event.key === "a") {{
-        document.querySelector("input[name='rating-meme'][value='1']").checked = true
-        commitIfReady()
-    }} else if (event.key === "s") {{
-        document.querySelector("input[name='rating-meme'][value='eq']").checked = true
-        commitIfReady()
-    }} else if (event.key === "d") {{
-        document.querySelector("input[name='rating-meme'][value='2']").checked = true
-        commitIfReady()
-    }} else if (event.key === "z") {{
-        document.querySelector("input[name='rating-aesthetic'][value='1']").checked = true
-        commitIfReady()
-    }} else if (event.key === "x") {{
-        document.querySelector("input[name='rating-aesthetic'][value='eq']").checked = true
-        commitIfReady()
-    }} else if (event.key === "c") {{
-        document.querySelector("input[name='rating-aesthetic'][value='2']").checked = true
-        commitIfReady()
-    }}
+    const key = keys[event.key]
+    if (key) {{
+        document.getElementById(key).checked = true
+    }}
+    commitIfReady()
 }});
 </script>
 </body>

@@ -20,27 +20,35 @@ def fetch_embedding(filename):
     csr.close()
     return x.copy() # PyTorch complains otherwise due to bad
 
-def map_rating(rating, uncertainty=0.05):
+def map_rating(rating):
     def map_one(rating):
         match rating:
             case "1": # meme 1 is better
-                return 1 - uncertainty
+                return 0.9
             case "2":
-                return uncertainty
+                return 0.1
+            case "2+" | "2p":
+                return 0.3
+            case "1+" | "1p":
+                return 0.7
             case "eq":
                 return 0.5
             case _: raise ValueError("invalid rating, please fix")
     return np.array([map_one(r) for r in rating.split(",")])
 
-def fetch_ratings():
+def fetch_ratings(sets):
     trains = defaultdict(list)
     validations = defaultdict(list)
     csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
+    its = set()
     for meme1, meme2, rating, iteration in csr.fetchall():
+        if iteration not in its:
+            print(iteration)
+            its.add(iteration)
         (validations if is_val_set(meme1, meme2) else trains)[int((iteration and iteration.split("-")[0]) or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
     csr.close()
-    return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))
+    return list(x[1] for x in sorted(trains.items()) if str(x[0]) in sets), list(x[1] for x in sorted(validations.items()) if str(x[0]) in sets)
 
 def generate_random_permutations(x, n):
     out = []

@@ -9,11 +9,12 @@ import time
 from tqdm import tqdm
 import math
 from dataclasses import dataclass, asdict
+import sys
 
 from model import Config as ModelConfig, BradleyTerry
 import shared
 
-trains, validations = shared.fetch_ratings()
+trains, validations = shared.fetch_ratings(sys.argv[1:])
 for train, validation in zip(trains, validations):
     print(len(train), len(validation))
@@ -36,11 +37,11 @@ config = TrainConfig(
         n_ensemble=16,
         device=device,
         dtype=torch.float32,
-        dropout=0.1,
+        dropout=0.0,
         output_channels=3
     ),
     lr=3e-4,
-    weight_decay=0.2,
+    weight_decay=0.0,
     batch_size=1,
     epochs=5,
     compile=False,

@@ -1,12 +1,12 @@
 import torch
-import pyarrow as pa
+import numpy as np
 
 torch.set_float32_matmul_precision("high")
 
-with pa.memory_map("../../sample_1m.arrow", "r") as source:
-    loaded_arrays = pa.ipc.open_file(source).read_all()
+loaded_arrays = np.memmap("embeddings.bin", dtype=np.float16).reshape(-1, 1152)
+loaded_arrays_permutation = np.random.permutation(len(loaded_arrays))
 
 train_split = 0.8
 
 def ckpt_path(steps):
     return f"ckpt/{steps}.pt", f"ckpt/{steps}.optim.pt"

@@ -6,9 +6,10 @@ import json
 import time
 from tqdm import tqdm
 from dataclasses import dataclass, asdict
+import math
 
 from model import SAEConfig, SAE
-from shared import train_split, loaded_arrays, ckpt_path
+from shared import train_split, loaded_arrays, ckpt_path, loaded_arrays_permutation
 
 device = "cuda"
 
@@ -24,7 +25,7 @@ class TrainConfig:
 config = TrainConfig(
     model=SAEConfig(
         d_emb=1152,
-        d_hidden=65536,
+        d_hidden=262144,
         top_k=128,
         device=device,
         dtype=torch.float32,
@@ -33,7 +34,7 @@ config = TrainConfig(
     lr=3e-4,
     weight_decay=0.0,
     batch_size=64,
-    epochs=5,
+    epochs=1,
     compile=True,
 )
@@ -81,7 +82,7 @@ with open(logfile, "w") as log:
     batch = []
     t = tqdm(range(0, int(len(loaded_arrays) * train_split), config.batch_size))
     for batch_start in t:
-        batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in loaded_arrays["embedding"][batch_start:batch_start + config.batch_size] ])
+        batch = numpy.stack([ embedding for embedding in loaded_arrays[loaded_arrays_permutation[batch_start:batch_start + config.batch_size]] ])
 
         if len(batch) == config.batch_size:
             batch = torch.Tensor(batch).to(device)
@@ -98,4 +99,4 @@ with open(logfile, "w") as log:
     print(ctr)
     numpy.save(f"ckpt/{steps}.counters.npy", ctr)
 
 print(logfile)

slow_dump_parse_script.py (new file)

@@ -0,0 +1,35 @@
import umsgpack
import zstandard
import pyarrow as pa

data = []
with open("sample.zst", "rb") as f:
    decomp = zstandard.ZstdDecompressor()
    reader = decomp.stream_reader(f)
    count = 0
    while True:
        try:
            url, id, title, subreddit, author, timestamp, embedding = umsgpack.unpack(reader)
            embedding = bytes(embedding)
            data.append({"url": url, "id": id, "title": title, "subreddit": subreddit, "author": author, "timestamp": timestamp, "embedding": embedding})
            count += 1
        except umsgpack.InsufficientDataException:
            break
    print(count)

schema = pa.schema([
    ("url", pa.string()),
    ("id", pa.string()),
    ("title", pa.string()),
    ("subreddit", pa.string()),
    ("author", pa.string()),
    ("timestamp", pa.int64()),
    ("embedding", pa.binary())
])

table = pa.Table.from_pylist(data, schema=schema)

# Note: RecordBatchFileWriter emits the Arrow IPC file format, despite the .parquet extension.
with pa.OSFile("output.parquet", "wb") as sink:
    with pa.RecordBatchFileWriter(sink, schema) as writer:
        writer.write_table(table)

@@ -183,8 +183,8 @@ pub struct FrontendInit {
 pub type EmbeddingVector = Vec<f32>;
 
 #[derive(Debug, Serialize)]
-pub struct QueryResult {
-    pub matches: Vec<(f32, String, String, u64, Option<(u32, u32)>)>,
+pub struct QueryResult<T> {
+    pub matches: Vec<(f32, String, String, u64, Option<(u32, u32)>, Option<T>)>,
     pub formats: Vec<String>,
     pub extensions: HashMap<String, String>,
 }
@@ -203,7 +203,9 @@ pub struct QueryRequest {
     pub terms: Vec<QueryTerm>,
     pub k: Option<usize>,
     #[serde(default)]
-    pub include_video: bool
+    pub include_video: bool,
+    #[serde(default)]
+    pub debug_enabled: bool
 }
 
 lazy_static::lazy_static! {
@@ -238,8 +240,9 @@ pub async fn get_total_embedding<A: Future<Output = Result<Vec<Vec<u8>>>>, B: Fu
     }
 }
 if let Some(name) = &term.predefined_embedding {
-    let embedding = predefined_embeddings.get(name).context("name invalid")?;
-    total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
+    if let Some(embedding) = predefined_embeddings.get(name) {
+        total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
+    }
 }
 }

@@ -69,6 +69,10 @@ struct CLIArguments {
     gpu: Option<usize>,
     #[argh(option, description="descriptor CDFs")]
     cdfs: Option<String>,
+    #[argh(option, description="postfilter by embedding (late discard if dot product above threshold)")]
+    postfilter: Vec<String>,
+    #[argh(option, description="postfilter by scorer")]
+    postfilter_scorer: Vec<String>,
 }
 
 #[derive(Clone, Deserialize, Serialize, Debug)]
@@ -162,10 +166,22 @@ fn main() -> Result<()> {
         } else {
             (snd, None)
         };
+        let mut post_threshold = None;
+        for x in &args.postfilter {
+            let (tname, snd) = x.split_once(':').context("invalid postfilter argument")?;
+            if tname == name {
+                post_threshold = Some(snd.parse::<f32>().context("parse postfilter threshold")?);
+            }
+        }
         let blob = fs::read(path).context("read embedding")?;
-        embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512), threshold));
+        embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512), threshold, post_threshold));
     }
 
+    let postfilter_scorer = args.postfilter_scorer.iter().map(|x| {
+        let (id, snd) = x.split_once(':').context("invalid postfilter scorer argument")?;
+        Ok((id.parse::<usize>().context("invalid postfilter scorer id")?, snd.parse::<f32>().context("parse postfilter scorer threshold")?))
+    }).collect::<Result<Vec<_>>>()?;
+
     let pq_codec = if let Some(pq_codec) = args.pq_codec {
         let data = fs::read(pq_codec).context("read pq codec")?;
         let pq_codec: ProductQuantizer = rmp_serde::from_read(&data[..]).context("decode pq codec")?;
@@ -255,7 +271,6 @@ fn main() -> Result<()> {
     files.sort_by_key(|(id, _)| *id);
     shard_id_mappings.sort_by_key(|(id, _)| *id);
-
     let read_out_vertices = move |id: u32| -> Result<(Vec<u32>, Vec<u32>)> {
         let mut out_vertices: Vec<u32> = vec![];
         let mut shards: Vec<u32> = vec![];
@@ -323,6 +338,8 @@ fn main() -> Result<()> {
     let th = std::thread::spawn(move || reader_thread(&args.paths, tx));
 
+    let mut postfilter_count = 0;
+
     let mut rng2 = rng.fork();
     let initial_filter = |x: ProcessedEntry| {
         i += 1;
@@ -337,7 +354,9 @@ fn main() -> Result<()> {
         latest_timestamp = latest_timestamp.max(timestamp);
         earliest_timestamp = earliest_timestamp.min(timestamp);
 
-        for (_name, vec, histogram, threshold) in &mut embeddings {
+        let mut postfilter = false;
+        for (_name, vec, histogram, threshold, postfilter_threshold) in &mut embeddings {
             let dot = SpatialSimilarity::dot(&embedding, vec).unwrap() as f32;
             histogram.add(dot);
             if let Some(threshold) = threshold {
@@ -345,6 +364,12 @@ fn main() -> Result<()> {
                     return None;
                 }
             }
+            if let Some(threshold) = postfilter_threshold {
+                if dot >= *threshold {
+                    postfilter = true;
+                    postfilter_count += 1; // somewhat wrong because could be duplicated
+                }
+            }
         }
 
         // distance thresholding is too costly to do over a long range so just do it badly
@@ -393,7 +418,7 @@ fn main() -> Result<()> {
             println!("{}", data);
         }
 
-        Some((x, embedding))
+        Some((x, embedding, postfilter))
     };
 
     let mut dead_count = 0;
@@ -404,14 +429,14 @@ fn main() -> Result<()> {
         let batch: Vec<_> = batch.collect();
         let batch_len = batch.len();
 
-        for (x, _embedding) in batch.iter() {
+        for (x, _embedding, _postfilter) in batch.iter() {
             if let Some(ref mut file) = output_file {
                 file.write_all(&x.embedding)?;
             }
         }
 
         if let Some(shards) = &mut shards_out {
-            for (i, (x, embedding)) in batch.iter().enumerate() {
+            for (i, (x, embedding, _postfilter)) in batch.iter().enumerate() {
                 // closest matches first
                 shards.sort_by_cached_key(|&(ref centroid, _, shard_count, _shard_index)| {
                     let mut dot = SpatialSimilarity::dot(&centroid, &embedding).unwrap();
@@ -439,7 +464,7 @@ fn main() -> Result<()> {
             let quantizer = pq_codec.as_ref().context("PQ codec needed to output index")?;
 
             let mut batch_embeddings = Vec::with_capacity(batch.len() * D_EMB as usize);
-            for (_x, embedding) in batch.iter() {
+            for (_x, embedding, _postfilter) in batch.iter() {
                 batch_embeddings.extend_from_slice(&embedding);
             }
             let codes = quantizer.quantize_batch(&batch_embeddings);
@@ -448,7 +473,7 @@ fn main() -> Result<()> {
             let cdfs = cdfs.as_ref().context("score model CDFs needed to output index")?;
             let scores = score_model.score_batch(&batch_embeddings)?;
 
-            for (i, (x, _embedding)) in batch.into_iter().enumerate() {
+            for (i, (x, _embedding, mut postfilter)) in batch.into_iter().enumerate() {
                 let (vertices, shards) = read_out_vertices(count + i as u32)?; // TODO: could parallelize this given the batching
                 let mut entry_scores = scores[(i * score_model.output_channels)..((i + 1) * score_model.output_channels)].to_vec();
@@ -465,6 +490,13 @@ fn main() -> Result<()> {
                     index_output_file.2.write_all(&[cdf_bucket])?;
                 }
 
+                for (index, score) in postfilter_scorer.iter() {
+                    if entry_scores[*index] < *score {
+                        postfilter = true;
+                        break;
+                    }
+                }
+
                 let mut entry = PackedIndexEntry {
                     id: count + i as u32,
                     vertices,
@@ -476,9 +508,10 @@ fn main() -> Result<()> {
                     shards
                 };
                 let mut bytes = bitcode::encode(&entry);
-                if bytes.len() > (RECORD_PAD_SIZE - 2) {
+                // as an ugly hack for removing entries already in the index shards, kill the URL and make it a graph node only
+                if bytes.len() > (RECORD_PAD_SIZE - 2) || postfilter {
                     // we do need the records to fit in a fixed size and can't really drop things, so discard URL so it can exist as a graph node only
-                    entry.url = String::new();
+                    entry.url = String::new(); // URL is only input-controlled, arbitrary-length field
                     bytes = bitcode::encode(&entry);
                     dead_count += 1;
                 }
@@ -494,11 +527,11 @@ fn main() -> Result<()> {
     }
 
     if args.print_aggregates {
-        println!("earliest={} latest={} count={} read={} deduped={}", earliest_timestamp, latest_timestamp, count, i, deduped_count);
+        println!("earliest={} latest={} count={} read={} deduped={} postfiltered={}", earliest_timestamp, latest_timestamp, count, i, deduped_count, postfilter_count);
    }
    if let Some(histogram_path) = args.histograms {
        let mut file = fs::File::create(histogram_path)?;
-        for (name, _, histogram, _) in &embeddings {
+        for (name, _, histogram, _, _) in &embeddings {
            let width = 800.0;
            let padding = 40.0;
            let bars_height = 300 as f64;
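The net effect of the new postfilter options: an entry whose dot product with a named embedding crosses its threshold, or whose score-model channel falls below a bound, keeps its place in the graph but loses its URL. A rough Python rendering of the rule (illustrative only, not the repo's Rust):

    def apply_postfilter(entry, embedding_filters, scorer_filters):
        # embedding_filters: (vector, threshold) pairs; scorer_filters: (channel, min_score) pairs
        dot = lambda a, b: sum(x * y for x, y in zip(a, b))
        postfilter = any(dot(entry["embedding"], vec) >= t for vec, t in embedding_filters) \
            or any(entry["scores"][ch] < s for ch, s in scorer_filters)
        if postfilter:
            entry["url"] = ""  # survives as a graph node only
        return entry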

@@ -10,14 +10,16 @@ with open("mse_config.json") as f:
 
 def get_embedding(req):
     return msgpack.unpackb(requests.post(config["clip_server"], data=msgpack.packb(req)).content)
 
-output, input, *xs = sys.argv[1:]
+mode, output, input = sys.argv[1:]
 
 with open(output, "wb") as f:
-    with open(input, "rb") as g:
-        input_data = g.read()
-    if not xs:
+    if mode == "image":
+        with open(input, "rb") as g:
+            input_data = g.read()
         result = get_embedding({"images": [input_data]})[0]
+    elif mode == "text":
+        result = get_embedding({"text": input})[0]
     else:
-        result = get_embedding({"text": xs})[0]
+        raise Exception("unknown mode")
     f.write(result)
     print(base64.urlsafe_b64encode(result).decode("ascii"))
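With the new argument order (mode, output, input), usage looks like this (the script's filename is not shown in this diff, so the name below is hypothetical):

    python embed_query.py image meme.bin meme.png
    python embed_query.py text query.bin "a cat"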

@@ -892,7 +892,7 @@ async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
 }
 
 #[instrument(skip(index))]
-async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult> {
+async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult<()>> {
     let result = index.vectors.search(&query, k as usize)?;
 
     let mut seen_videos = HashSet::new();
@@ -916,7 +916,8 @@ async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bo
     index.filenames[id].container_filename(),
     generate_filename_hash(&index.filenames[id as usize]).clone(),
     index.format_codes[id],
-    index.metadata[id].as_ref().map(|x| (x.width, x.height))
+    index.metadata[id].as_ref().map(|x| (x.width, x.height)),
+    Option::<()>::None
 ))
 })
 .collect();

View File

@ -7,7 +7,7 @@ use argh::FromArgs;
use itertools::Itertools; use itertools::Itertools;
use foldhash::{HashSet, HashSetExt}; use foldhash::{HashSet, HashSetExt};
use half::f16; use half::f16;
use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, QueryLUT, scale_dot_result, scale_dot_result_f64}}; use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, QueryLUT, scale_dot_result, scale_dot_result_f64, SCALE_F64}};
use simsimd::SpatialSimilarity; use simsimd::SpatialSimilarity;
use memmap2::{Mmap, MmapOptions}; use memmap2::{Mmap, MmapOptions};
use std::rc::Rc; use std::rc::Rc;
@ -18,13 +18,14 @@ use http_body_util::{BodyExt, Empty, Full};
use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec}; use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
use std::pin::Pin; use std::pin::Pin;
use std::future::Future; use std::future::Future;
use serde::Serialize; use serde::{Serialize, Deserialize};
use std::str::FromStr; use std::str::FromStr;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::Write;
mod common; mod common;
use common::{resize_for_embed_sync, FrontendInit, IndexHeader, InferenceServerConfig, PackedIndexEntry, QueryRequest, QueryResult}; use common::{resize_for_embed_sync, QueryTerm, FrontendInit, IndexHeader, InferenceServerConfig, PackedIndexEntry, QueryRequest, QueryResult};
#[derive(FromArgs, Clone)] #[derive(FromArgs, Clone)]
#[argh(description="Query disk index")] #[argh(description="Query disk index")]
@ -43,10 +44,18 @@ struct CLIArguments {
search_list_size: Option<usize>, search_list_size: Option<usize>,
#[argh(switch, description="always use full-precision vectors (slow)")] #[argh(switch, description="always use full-precision vectors (slow)")]
disable_pq: bool, disable_pq: bool,
#[argh(option, short='l', description="listen address")] #[argh(option, short='c', description="server config file")]
listen_address: Option<String>, config_path: Option<String>
#[argh(option, short='c', description="clip server")] }
clip_server: Option<String>,
#[derive(Deserialize, Clone)]
struct ServerConfig {
listen_address: String,
clip_server: String,
descriptor_names: Vec<String>,
telemetry_file: String,
search_list: usize,
beam_width: usize
} }
lazy_static! { lazy_static! {
@ -82,17 +91,30 @@ fn next_several_unvisited(s: &mut NeighbourBuffer, n: usize) -> Option<Vec<u32>>
} }
} }
const DUPLICATES_THRESHOLD: f32 = 0.95;
fn read_pq_codes(id: u32, index: Rc<Index>, buf: &mut Vec<u8>) { fn read_pq_codes(id: u32, index: Rc<Index>, buf: &mut Vec<u8>) {
let loc = (id as usize) * index.pq_code_size; let loc = (id as usize) * index.pq_code_size;
buf.extend(&index.pq_codes[loc..loc+index.pq_code_size]) buf.extend(&index.pq_codes[loc..loc+index.pq_code_size])
} }
struct VisitedNode {
image_url: String,
scores: Vec<f32>,
shards: Vec<u32>,
id: u32,
score: i64,
timestamp: u64,
dimensions: (u32, u32)
}
struct Scratch { struct Scratch {
visited_adjacent: HashSet<u32>, visited_adjacent: HashSet<u32>,
visited: HashSet<u32>, visited: HashSet<u32>,
neighbour_buffer: NeighbourBuffer, neighbour_buffer: NeighbourBuffer,
neighbour_pre_buffer: Vec<u32>, neighbour_pre_buffer: Vec<u32>,
visited_list: Vec<(u32, i64, String, Vec<u32>, Vec<f32>)> visited_list: Vec<VisitedNode>,
visited_embeddings: Vec<f32>
} }
struct Index { struct Index {
@ -140,10 +162,20 @@ async fn greedy_search<'a>(scratch: &mut Scratch, start: u32, query: &[f16], que
let index = index.clone(); let index = index.clone();
let node = handle.await?; let node = handle.await?;
let vector = bytemuck::cast_slice(&node.vector); let vector = bytemuck::cast_slice(&node.vector);
let distance = fast_dot_noprefetch(query, &vector); let mut score = fast_dot_noprefetch(query, &vector);
score += descriptor_product(index.clone(), &descriptor_scales, node.id);
cmps += 1; cmps += 1;
if scratch.visited.insert(node.id) { if scratch.visited.insert(node.id) && node.url.len() > 0 {
scratch.visited_list.push((node.id, distance, node.url, node.shards, node.scores)); scratch.visited_list.push(VisitedNode {
image_url: node.url,
scores: node.scores,
shards: node.shards,
id: node.id,
score,
timestamp: node.timestamp,
dimensions: node.dimensions
});
scratch.visited_embeddings.extend(bytemuck::cast_slice(&node.vector).iter().map(|x: &f16| x.to_f32()));
}; };
for &neighbour in node.vertices.iter() { for &neighbour in node.vertices.iter() {
if scratch.visited_adjacent.insert(neighbour) { if scratch.visited_adjacent.insert(neighbour) {
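
greedy_search now biases each node's score with descriptor_product, which is not shown in this diff. As a minimal sketch of the shape it plausibly takes, assuming per-node descriptor values live in a hypothetical index.descriptors array (row-major, one f32 per descriptor) and that scale_dot_result converts an f32 into the same fixed-point i64 scale as the dot products:

// Hypothetical sketch only; the field name, storage layout, and helper
// semantics here are assumptions, not the repository's actual code.
fn descriptor_product(index: Rc<Index>, scales: &DescriptorScales, id: u32) -> i64 {
    let n = scales.0.len();
    scales.0.iter().enumerate().map(|(i, scale)| {
        let value = index.descriptors[(id as usize) * n + i]; // assumed layout
        scale_dot_result(*scale * value)
    }).sum()
}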
@@ -250,7 +282,8 @@ async fn evaluate(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)), neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
neighbour_pre_buffer: Vec::new(), neighbour_pre_buffer: Vec::new(),
visited_list: Vec::new(), visited_list: Vec::new(),
visited_adjacent: HashSet::new() visited_adjacent: HashSet::new(),
visited_embeddings: Vec::new()
}; };
let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]); let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]);
@@ -265,14 +298,14 @@ async fn evaluate(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
println!("index scan {}: {:?} cmps", shard, cmps_result); println!("index scan {}: {:?} cmps", shard, cmps_result);
} }
scratch.visited_list.sort_by_key(|x| -x.1); scratch.visited_list.sort_by_key(|x| -x.score);
for (i, (id, distance, url, shards, scores)) in scratch.visited_list.iter().take(20).enumerate() { for (i, node) in scratch.visited_list.iter().take(20).enumerate() {
let found_id = match matches.binary_search(&(*id, 0)) { let found_id = match matches.binary_search(&(node.id, 0)) {
Ok(pos) => pos, Ok(pos) => pos,
Err(pos) => pos Err(pos) => pos
}; };
if args.verbose { if args.verbose {
println!("index scan: {} {} {} {:?} {:?}; rank {}", id, distance, url, shards, scores, matches[found_id].1 + 1); println!("index scan: {} {} {} {:?} {:?}; rank {}", node.id, node.score, node.image_url, node.shards, node.scores, matches[found_id].1 + 1);
}; };
top_ranks[i] = std::cmp::min(top_ranks[i], matches[found_id].1); top_ranks[i] = std::cmp::min(top_ranks[i], matches[found_id].1);
} }
@@ -341,11 +374,23 @@ pub async fn query_clip_server<I, O>(base_url: &str, path: &str, data: Option<I>
Ok(result) Ok(result)
} }
#[derive(Serialize, Deserialize)]
struct TelemetryMessage {
#[serde(rename="correlationId")]
correlation_id: String,
data: serde_json::Value,
event: String,
#[serde(rename="instanceId")]
instance_id: String,
page: String
}
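
Given the serde renames above, a well-formed POST /telemetry body looks like the following (all values illustrative; every field is required, since none are marked optional):

{
    "correlationId": "3f2c1a9e",
    "data": {"action": "click"},
    "event": "search",
    "instanceId": "b4d7",
    "page": "/"
}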
#[derive(Clone)] #[derive(Clone)]
struct Service { struct Service {
index: Rc<Index>, index: Rc<Index>,
inference_server_config: Rc<InferenceServerConfig>, inference_server_config: Rc<InferenceServerConfig>,
args: Rc<CLIArguments> config: Rc<ServerConfig>,
telemetry_channel: std::sync::mpsc::Sender<TelemetryMessage>
} }
impl hyper::service::Service<Request<Incoming>> for Service { impl hyper::service::Service<Request<Incoming>> for Service {
@@ -355,15 +400,16 @@ impl hyper::service::Service<Request<Incoming>> for Service {
fn call(&self, req: Request<Incoming>) -> Self::Future { fn call(&self, req: Request<Incoming>) -> Self::Future {
let index = self.index.clone(); let index = self.index.clone();
let args = self.args.clone(); let config = self.config.clone();
let inference_server_config = self.inference_server_config.clone(); let inference_server_config = self.inference_server_config.clone();
let channel = self.telemetry_channel.clone();
Box::pin(async move { Box::pin(async move {
let mut body = match (req.method(), req.uri().path()) { let mut body = match (req.method(), req.uri().path()) {
(&Method::GET, "/") => Response::new(Full::new(Bytes::from(serde_json::to_vec(&FrontendInit { (&Method::GET, "/") => Response::new(Full::new(Bytes::from(serde_json::to_vec(&FrontendInit {
n_total: (index.header.count - index.header.dead_count) as u64, n_total: (index.header.count - index.header.dead_count) as u64,
d_emb: index.header.quantizer.n_dims, d_emb: index.header.quantizer.n_dims,
predefined_embedding_names: vec![] predefined_embedding_names: config.descriptor_names.clone()
})?))), })?))),
(&Method::POST, "/") => { (&Method::POST, "/") => {
let upper = req.body().size_hint().upper().unwrap_or(u64::MAX); let upper = req.body().size_hint().upper().unwrap_or(u64::MAX);
@@ -381,7 +427,7 @@ impl hyper::service::Service<Request<Incoming>> for Service {
&body.terms, &body.terms,
&*inference_server_config, &*inference_server_config,
|batch, _config| { |batch, _config| {
query_clip_server(args.clip_server.as_ref().unwrap(), "/", Some(batch)) query_clip_server(config.clip_server.as_str(), "/", Some(batch))
}, },
|image, config| async move { |image, config| async move {
let image = image::load_from_memory(&image)?; let image = image::load_from_memory(&image)?;
@@ -397,27 +443,89 @@ impl hyper::service::Service<Request<Incoming>> for Service {
}).unwrap(); }).unwrap();
let selected_start = index.header.shards[selected_shard].1; let selected_start = index.header.shards[selected_shard].1;
let beamwidth = 3; let beamwidth = config.beam_width;
let mut scratch = Scratch { let mut scratch = Scratch {
visited: HashSet::new(), visited: HashSet::new(),
neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)), neighbour_buffer: NeighbourBuffer::new(config.search_list),
neighbour_pre_buffer: Vec::new(), neighbour_pre_buffer: Vec::new(),
visited_list: Vec::new(), visited_list: Vec::new(),
visited_adjacent: HashSet::new() visited_adjacent: HashSet::new(),
visited_embeddings: Vec::new()
}; };
let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]); let mut desc = vec![0.0, 0.0, 0.0, 0.0];
for term in &body.terms {
if let Some(name) = &term.predefined_embedding {
if let Some(index) = config.descriptor_names.iter().position(|x| x == name) {
desc[index] = term.weight.unwrap_or(1.0) * 1.0/512.0;
}
}
}
let descriptor_scales = DescriptorScales(desc);
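
So, assuming QueryTerm serializes these fields under their Rust names, a term like the one below maps to desc[i] = 2.0 / 512 for the matching descriptor; terms naming descriptors absent from descriptor_names are silently ignored:

{"predefined_embedding": "aesthetic", "weight": 2.0}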
let query_preprocessed = index.header.quantizer.preprocess_query(&query); let query_preprocessed = index.header.quantizer.preprocess_query(&query);
let query = query.iter().map(|x| half::f16::from_f32(*x)).collect::<Vec<f16>>(); let query = query.iter().map(|x| half::f16::from_f32(*x)).collect::<Vec<f16>>();
let cmps_result = greedy_search(&mut scratch, selected_start, &query, &query_preprocessed, &descriptor_scales, index.clone(), args.disable_pq, beamwidth).await?; let cmps_result = greedy_search(&mut scratch, selected_start, &query, &query_preprocessed, &descriptor_scales, index.clone(), false, beamwidth).await?;
scratch.visited_list.sort_by_key(|x| -x.1); let n_visited = scratch.visited_list.len();
let matches = scratch.visited_list.drain(..).map(|(id, score, url, shards, scores)| (score as f32, url, String::new(), 0, None)).collect::<Vec<_>>(); let mut similarities_against_self = vec![0.0f32; n_visited * n_visited];
// runtime deduplicate of results list
unsafe {
// vecs @ vecs.T
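// matrixmultiply::sgemm takes (m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc).
// Here m = n = n_visited and k = n_dims; the same buffer is passed as A
// (row stride n_dims, column stride 1) and again with the strides swapped,
// which reads it as A^T, so C[i * n_visited + j] = <v_i, v_j>, row-major.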
matrixmultiply::sgemm(
n_visited,
index.header.quantizer.n_dims,
n_visited,
1.0,
scratch.visited_embeddings.as_ptr(),
index.header.quantizer.n_dims as isize,
1,
scratch.visited_embeddings.as_ptr(),
1,
index.header.quantizer.n_dims as isize,
0.0,
similarities_against_self.as_mut_ptr(),
n_visited as isize,
1
);
}
// discard anything similar to something already in list
let mut i = 0;
let mut included = bitvec::bitvec![0; n_visited];
scratch.visited_list.retain(|_node| {
let row = &similarities_against_self[(i * n_visited)..((i + 1) * n_visited)];
let old_i = i;
i += 1;
for (other_i, similarity) in row.iter().enumerate() {
if similarity > &DUPLICATES_THRESHOLD && included[other_i] {
return false;
}
}
included.set(old_i, true);
true
});
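// Note that retain runs in visit order, before the score sort below, so of
// two near-duplicates the earlier-visited node survives, not necessarily
// the higher-scoring one.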
scratch.visited_list.sort_unstable_by_key(|x| -x.score);
let matches = scratch.visited_list
.drain(..)
.map(|node| {
let debug = if body.debug_enabled {
Some((node.scores, node.shards, node.timestamp))
} else {
None
};
((node.score as f64 / SCALE_F64) as f32, node.image_url, String::new(), 0, Some(node.dimensions), debug)
})
.collect::<Vec<_>>();
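
For reference, the unsafe sgemm call above just computes the Gram matrix of the visited embeddings; a naive safe-Rust equivalent (minus the blocked kernel) would be:

let d = index.header.quantizer.n_dims;
for i in 0..n_visited {
    for j in 0..n_visited {
        let a = &scratch.visited_embeddings[i * d..(i + 1) * d];
        let b = &scratch.visited_embeddings[j * d..(j + 1) * d];
        // entry (i, j) is <v_i, v_j>; if the embeddings are unit-norm, as the
        // 0.95 threshold suggests, this is cosine similarity
        similarities_against_self[i * n_visited + j] = a.iter().zip(b).map(|(x, y)| x * y).sum();
    }
}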
let result = QueryResult { let result = QueryResult {
formats: vec![], formats: vec![],
@@ -438,6 +546,25 @@ impl hyper::service::Service<Request<Incoming>> for Service {
.header(hyper::header::CONTENT_TYPE, "text/plain; version=0.0.4") .header(hyper::header::CONTENT_TYPE, "text/plain; version=0.0.4")
.body(Full::new(Bytes::from(buffer))).unwrap() .body(Full::new(Bytes::from(buffer))).unwrap()
}, },
(&Method::POST, "/telemetry") => {
// TODO refactor
let upper = req.body().size_hint().upper().unwrap_or(u64::MAX);
if upper > 1000 {
let mut resp = Response::new(Full::new(Bytes::from("Body too big")));
*resp.status_mut() = hyper::StatusCode::PAYLOAD_TOO_LARGE;
return Ok(resp);
}
let whole_body = req.collect().await?.to_bytes();
let message = serde_json::from_slice::<TelemetryMessage>(&whole_body)?;
channel.send(message)?;
Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Full::new(Bytes::from(""))).unwrap()
}
(&Method::OPTIONS, "/") => { (&Method::OPTIONS, "/") => {
Response::builder() Response::builder()
.status(StatusCode::NO_CONTENT) .status(StatusCode::NO_CONTENT)
@@ -459,9 +586,9 @@ impl hyper::service::Service<Request<Incoming>> for Service {
} }
} }
async fn get_backend_config(clip_server: &Option<String>) -> Result<InferenceServerConfig> { async fn get_backend_config(clip_server: &String) -> Result<InferenceServerConfig> {
loop { loop {
match query_clip_server(clip_server.as_ref().unwrap(), "/config", Option::<()>::None).await { match query_clip_server(clip_server, "/config", Option::<()>::None).await {
Ok(config) => return Ok(config), Ok(config) => return Ok(config),
Err(err) => { Err(err) => {
tracing::warn!("waiting for clip server: {}", err); tracing::warn!("waiting for clip server: {}", err);
@@ -471,14 +598,31 @@ async fn get_backend_config(clip_server: &Option<String>) -> Result<InferenceSer
} }
} }
// can't run this as an async task because the monoio File API only supports positional writes
fn telemetry_handler(rx: std::sync::mpsc::Receiver<TelemetryMessage>, config: ServerConfig) -> Result<()> {
let mut telemetry_file = std::fs::OpenOptions::new().create(true).append(true).open(&config.telemetry_file)?;
while let Ok(message) = rx.recv() {
telemetry_file.write_all(rmp_serde::to_vec(&message)?.as_slice())?;
}
Ok(())
}
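
Since the log is a raw concatenation of MessagePack values, it can be read back by repeatedly decoding from the same reader until EOF; a sketch, assuming the anyhow-style Result alias used elsewhere in this file:

fn read_telemetry(path: &str) -> Result<Vec<TelemetryMessage>> {
    let mut reader = std::io::BufReader::new(std::fs::File::open(path)?);
    let mut messages = Vec::new();
    // rmp_serde::from_read decodes exactly one value per call and leaves the
    // reader positioned at the start of the next one
    while let Ok(message) = rmp_serde::from_read::<_, TelemetryMessage>(&mut reader) {
        messages.push(message);
    }
    Ok(messages)
}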
async fn serve(args: &CLIArguments, index: Rc<Index>) -> Result<()> { async fn serve(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
let config: ServerConfig = serde_json::from_slice(&std::fs::read(args.config_path.as_ref().unwrap())?)?;
let (telemetry_channel, telemetry_receiver) = std::sync::mpsc::channel();
let config_ = config.clone();
std::thread::spawn(move || telemetry_handler(telemetry_receiver, config_));
let service = Service { let service = Service {
index, index,
inference_server_config: Rc::new(get_backend_config(&args.clip_server).await?), inference_server_config: Rc::new(get_backend_config(&config.clip_server).await?),
args: Rc::new(args.clone()) config: Rc::new(config.clone()),
telemetry_channel
}; };
let listener = TcpListener::bind(args.listen_address.as_ref().unwrap())?; let listener = TcpListener::bind(config.listen_address)?;
println!("Listening"); println!("Listening");
loop { loop {
let (stream, _) = listener.accept().await?; let (stream, _) = listener.accept().await?;
@@ -531,7 +675,7 @@ async fn main() -> Result<()> {
n_descriptors: header.descriptor_cdfs.len(), n_descriptors: header.descriptor_cdfs.len(),
}); });
if args.listen_address.is_some() { if args.config_path.is_some() {
serve(&args, index).await?; serve(&args, index).await?;
} else { } else {
evaluate(&args, index).await?; evaluate(&args, index).await?;
