mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-04-26 20:53:09 +00:00
release version
This commit is contained in:
parent
3852d0078d
commit
ee23b81444
5
.gitignore
vendored
5
.gitignore
vendored
@ -22,3 +22,8 @@ index
|
||||
queries.txt
|
||||
*.zst
|
||||
.safetensors
|
||||
*/static/*.woff2
|
||||
flamegraph.svg
|
||||
*.jsonl
|
||||
*.safetensors
|
||||
perf.data
|
||||
|
41
Cargo.lock
generated
41
Cargo.lock
generated
@ -416,6 +416,18 @@ version = "2.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452"
|
||||
|
||||
[[package]]
|
||||
name = "bitvec"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
|
||||
dependencies = [
|
||||
"funty",
|
||||
"radium",
|
||||
"tap",
|
||||
"wyz",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.10.4"
|
||||
@ -1117,6 +1129,12 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "funty"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.31"
|
||||
@ -2046,6 +2064,7 @@ dependencies = [
|
||||
"axum",
|
||||
"base64 0.22.1",
|
||||
"bitcode",
|
||||
"bitvec",
|
||||
"bytemuck",
|
||||
"candle-core",
|
||||
"chrono",
|
||||
@ -2067,6 +2086,7 @@ dependencies = [
|
||||
"itertools 0.13.0",
|
||||
"json5",
|
||||
"lazy_static",
|
||||
"matrixmultiply",
|
||||
"maud",
|
||||
"memmap2",
|
||||
"mimalloc",
|
||||
@ -2850,6 +2870,12 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "radium"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.5"
|
||||
@ -3873,6 +3899,12 @@ dependencies = [
|
||||
"version-compare",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tap"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
|
||||
|
||||
[[package]]
|
||||
name = "target-lexicon"
|
||||
version = "0.12.16"
|
||||
@ -4732,6 +4764,15 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wyz"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
|
||||
dependencies = [
|
||||
"tap",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yoke"
|
||||
version = "0.7.5"
|
||||
|
@ -62,6 +62,8 @@ monoio = "0.2"
|
||||
hyper = "1"
|
||||
monoio-compat = { version = "0.2", features = ["hyper"] }
|
||||
http-body-util = "0.1"
|
||||
matrixmultiply = "0.3"
|
||||
bitvec = "1"
|
||||
|
||||
[[bin]]
|
||||
name = "reddit-dump"
|
||||
|
@ -9,15 +9,15 @@
|
||||
|
||||
<div class="about">
|
||||
<p>
|
||||
Welcome to {util.hardConfig.name} by <a href="https://osmarks.net/">osmarks.net Computational Memetics</a>. {util.hardConfig.name} searches images using semantic image/text embedding models. In general, search by thinking of what caption your desired image might have been given by random people on the internet. The model currently in use can read text fairly well and understands moderately abstract properties of images, but is limited to English and case-insensitive.
|
||||
Welcome to {util.hardConfig.name} by <a href="https://osmarks.net/">osmarks.net</a> Computational Memetics. {util.hardConfig.name} searches images using semantic image/text embedding models (refer to <a href="https://arxiv.org/abs/2303.15343">https://arxiv.org/abs/2303.15343</a>, <a href="https://arxiv.org/abs/2103.00020">https://arxiv.org/abs/2103.00020</a>). In general, search by thinking of what caption your desired image might have been given by random people on the internet. The model currently in use can read text fairly well and understands moderately abstract properties of images, but is limited to English and case-insensitive.
|
||||
</p>
|
||||
<p>
|
||||
Advanced Mode sliders are generated from PCA on the index. The human-readable labels are generated manually by <a href="https://datasets.osmarks.net/components.html">looking at things</a>.
|
||||
"Useful"/"aesthetic"/"meme" sliders are defined based on an approximate extrapolation of my preferences and may not agree with your own opinions. Results are otherwise decided by the inscrutable whims of a billion-parameter neural network. Please note that large-scale internet data may contain things you do not like.
|
||||
</p>
|
||||
<p>
|
||||
The code is open-source and available on <a href="https://github.com/osmarks/meme-search-engine/">GitHub.</a>
|
||||
The code is open-source and available on <a href="https://github.com/osmarks/meme-search-engine/">GitHub</a>.
|
||||
</p>
|
||||
{#if util.hardConfig.telemetryEndpoint}
|
||||
{#if util.hardConfig.telemetry_endpoint}
|
||||
<h2>Privacy</h2>
|
||||
<p>
|
||||
We do not collect personal information. We do collect usage information (associated with a random ID) to improve the ranking algorithms. You can disable this:
|
||||
|
@ -154,7 +154,6 @@
|
||||
<div class="right">
|
||||
<NavItem page="advanced">Advanced</NavItem>
|
||||
<NavItem page="about">About</NavItem>
|
||||
<NavItem page="refine">Refine</NavItem>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
@ -185,7 +184,11 @@
|
||||
<option>+</option>
|
||||
<option>-</option>
|
||||
</select>
|
||||
<input type="range" min="0" max="2" bind:value={term.weight} step="0.01">
|
||||
{#if debugEnabled}
|
||||
<input type="number" bind:value={term.weight} step="0.01">
|
||||
{:else}
|
||||
<input type="range" min="0" max="2" bind:value={term.weight} step="0.01">
|
||||
{/if}
|
||||
{#if term.type === "image"}
|
||||
<span>{term.file.name}</span>
|
||||
{:else if term.type === "text"}
|
||||
@ -200,6 +203,14 @@
|
||||
</li>
|
||||
{/each}
|
||||
</ul>
|
||||
{#if showDebugSwitch}
|
||||
<div class="ctrlbar">
|
||||
<input type="checkbox" bind:checked={debugEnabled} id="debug" />
|
||||
<label for="debug">Debug</label>
|
||||
{#if debugEnabled}
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="ctrlbar">
|
||||
<input type="search" placeholder="Text Query" on:keydown={handleKey} on:focus={newTextQuery}>
|
||||
<button on:click={pickFile}>Image Query</button>
|
||||
@ -242,6 +253,9 @@
|
||||
let queryCounter = 0
|
||||
let config = {}
|
||||
|
||||
const showDebugSwitch = localStorage.getItem("debugEnabled") === "true"
|
||||
let debugEnabled = false
|
||||
|
||||
const newTextQuery = (content=null) => {
|
||||
queryTerms.push({ type: "text", weight: 1, sign: "+", text: typeof content === "string" ? content : "" })
|
||||
queryTerms = queryTerms
|
||||
@ -253,9 +267,14 @@
|
||||
|
||||
const runSearch = async () => {
|
||||
if (!resultPromise) {
|
||||
const friendlyModeQueryOrRandom = friendlyModeQuery ? [{ text: friendlyModeQuery, weight: 1, sign: "+" }] : [{ embedding: util.randn(config.d_emb, 1 / (config.d_emb ** 0.5)), weight: 1, sign: "+" }]
|
||||
const terms = friendlyMode ?
|
||||
friendlyModeQueryOrRandom.concat(util.hardConfig.friendly_mode_default_terms ?? []) :
|
||||
queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))
|
||||
let args = {
|
||||
"terms": friendlyMode ? [{ text: friendlyModeQuery, weight: 1, sign: "+" }] : queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] })),
|
||||
"include_video": true
|
||||
"terms": terms,
|
||||
"include_video": true,
|
||||
"debug_enabled": debugEnabled
|
||||
}
|
||||
|
||||
util.sendTelemetry("search", {
|
||||
@ -314,7 +333,7 @@
|
||||
type: "predefined_embedding",
|
||||
predefinedEmbedding: predefinedEmbeddingName,
|
||||
sign: "+",
|
||||
weight: 0.2
|
||||
weight: 1
|
||||
})
|
||||
}
|
||||
queryTerms = queryTerms
|
||||
|
@ -36,26 +36,14 @@
|
||||
|
||||
const d_emb = 1152
|
||||
|
||||
const vecSum = (xs, ys) => xs.map((x, i) => x + ys[i])
|
||||
const vecZero = d => new Array(d).fill(0)
|
||||
const vecScale = (xs, s) => xs.map(x => x * s)
|
||||
|
||||
const boxMuller = () => {
|
||||
let x = Math.random()
|
||||
let y = Math.random()
|
||||
return Math.sqrt(-2.0 * Math.log(x)) * Math.cos(2.0 * Math.PI * y)
|
||||
}
|
||||
|
||||
const randn = (d, sigma) => Array.from({ length: d }, () => boxMuller() * sigma)
|
||||
|
||||
const K = 2
|
||||
let candidates = []
|
||||
|
||||
const select = async candidate => {
|
||||
candidates = []
|
||||
const direction = randn(d_emb, 1 / d_emb)
|
||||
const direction = util.randn(d_emb, 1 / d_emb)
|
||||
for (let i = -K; i <= K; i++) {
|
||||
const newV = vecSum(vecScale(direction, i / K), candidate.vector)
|
||||
const newV = util.vecSum(util.vecScale(direction, i / K), candidate.vector)
|
||||
candidates.push({ vector: newV, results: null, i: i + K })
|
||||
}
|
||||
await Promise.all(candidates.map(async x => {
|
||||
@ -67,7 +55,7 @@
|
||||
console.log(candidates)
|
||||
}
|
||||
|
||||
select({ vector: randn(d_emb, 1 / d_emb) })
|
||||
select({ vector: util.randn(d_emb, 1 / d_emb) })
|
||||
|
||||
const handleKey = ev => {
|
||||
const num = parseInt(ev.key)
|
||||
@ -75,4 +63,4 @@
|
||||
select(candidates[num - 1])
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</script>
|
||||
|
@ -28,6 +28,10 @@
|
||||
{#key `${queryCounter}${result.file}`}
|
||||
<div class="result">
|
||||
<ResultImage {result} {results} {updateCounter} {redrawGrid} constrainBy="width" />
|
||||
{#if result[5]} <!-- debug info -->
|
||||
<div>{result[0]}</div>
|
||||
<div>{JSON.stringify(result[5])}</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/key}
|
||||
{/each}
|
||||
@ -72,7 +76,7 @@
|
||||
export const redrawGrid = async () => {
|
||||
if (refreshLayout) {
|
||||
refreshLayout()
|
||||
await recomputeScroll()
|
||||
setTimeout(recomputeScroll, 0)
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,7 +86,7 @@
|
||||
}
|
||||
|
||||
const handleScroll = () => {
|
||||
if (window.scrollY + window.innerHeight >= heightThreshold && pendingImageLoads === 0) {
|
||||
if (window.scrollY + window.innerHeight >= heightThreshold) {
|
||||
recomputeScroll()
|
||||
if (window.scrollY + window.innerHeight < heightThreshold) return;
|
||||
let init = displayedResults.length
|
||||
|
@ -1,4 +1,11 @@
|
||||
@use 'sass:color'
|
||||
|
||||
@font-face
|
||||
font-family: 'Iosevka'
|
||||
font-style: normal
|
||||
font-weight: 400
|
||||
src: url(./iosevka.woff2) format('woff2')
|
||||
unicode-range: U+0000-00C0, U+A2-A9, U+AC-AE, U+00D7, U+00F7, U+FEFF, U+FFFD
|
||||
|
||||
$palette-primary: #3f9b0b
|
||||
$palette-secondary: #033500
|
||||
|
BIN
clipfront2/src/iosevka.woff2
Normal file
BIN
clipfront2/src/iosevka.woff2
Normal file
Binary file not shown.
@ -78,3 +78,15 @@ fetch(config.backend_url).then(x => x.json().then(x => {
|
||||
serverConfig.set(x)
|
||||
window.serverConfig = x
|
||||
}))
|
||||
|
||||
export const vecSum = (xs, ys) => xs.map((x, i) => x + ys[i])
|
||||
export const vecZero = d => new Array(d).fill(0)
|
||||
export const vecScale = (xs, s) => xs.map(x => x * s)
|
||||
|
||||
const boxMuller = () => {
|
||||
let x = Math.random()
|
||||
let y = Math.random()
|
||||
return Math.sqrt(-2.0 * Math.log(x)) * Math.cos(2.0 * Math.PI * y)
|
||||
}
|
||||
|
||||
export const randn = (d, sigma) => Array.from({ length: d }, () => boxMuller() * sigma)
|
||||
|
@ -5,7 +5,7 @@
|
||||
<meta name="viewport" content="width=device-width, minimum-scale=1, initial-scale=1, user-scalable=yes">
|
||||
<meta name="description" content="Organizing the world's memes.">
|
||||
|
||||
<title>Meme Search Engine</title>
|
||||
<title>Nooscope</title>
|
||||
</style>
|
||||
<link rel="stylesheet" href="app.css">
|
||||
<link rel="icon" type="image/png" href="logo.png">
|
||||
|
10
config2.json
Normal file
10
config2.json
Normal file
@ -0,0 +1,10 @@
|
||||
{
|
||||
"listen_address": "127.0.0.1:5601",
|
||||
"clip_server": "http://100.64.0.10:1708",
|
||||
"descriptor_names": [
|
||||
"Useful",
|
||||
"Meme",
|
||||
"Aesthetic",
|
||||
"Time"
|
||||
]
|
||||
}
|
164
diskann/chainq.py
Normal file
164
diskann/chainq.py
Normal file
@ -0,0 +1,164 @@
|
||||
import numpy as np
|
||||
import msgpack
|
||||
import math
|
||||
import torch
|
||||
from torch import autograd
|
||||
import tqdm
|
||||
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# https://github.com/una-dinosauria/local-search-quantization/blob/master/src/encodings/encode_chain.jl#L2
|
||||
# vectorized somewhat
|
||||
def viterbi_encode(
|
||||
out_codes: torch.Tensor, # N x M (ints)
|
||||
vectors: torch.Tensor, # N x D
|
||||
codebooks: torch.Tensor # M x H x D
|
||||
):
|
||||
N, D = vectors.shape
|
||||
M, H, D2 = codebooks.shape
|
||||
assert D == D2
|
||||
|
||||
# M x H x N - ||x-c||^2 ignoring x.T @ x component
|
||||
unary_costs = -2 * (codebooks @ vectors.T) + (torch.linalg.norm(codebooks, dim=2) ** 2).unsqueeze(-1)
|
||||
binary_costs = torch.zeros(M - 1, H, H, dtype=torch.float, device=DEVICE)
|
||||
for i in range(M - 1):
|
||||
binary_costs[i] = 2 * codebooks[i] @ codebooks[i + 1].T
|
||||
print("binary costs", binary_costs)
|
||||
|
||||
min_cost = torch.zeros(H, N, dtype=torch.float, device=DEVICE)
|
||||
min_idx = torch.zeros(M, H, N, dtype=torch.int, device=DEVICE)
|
||||
|
||||
cost = torch.zeros(H, N, dtype=torch.float, device=DEVICE)
|
||||
|
||||
# forward pass - propagate optimal costs and indices forward
|
||||
for step in tqdm.trange(M - 1):
|
||||
if step > 0:
|
||||
unary_costs[step] += min_cost
|
||||
|
||||
ucost = unary_costs[step]
|
||||
|
||||
# for all possible costs at this step
|
||||
for j in range(H):
|
||||
bcost = binary_costs[step, j].unsqueeze(-1) # independent of N
|
||||
cost = ucost + bcost
|
||||
|
||||
min_values, min_indices = torch.min(cost, dim=0)
|
||||
min_cost[j] = min_values
|
||||
min_idx[step, j] = min_indices
|
||||
|
||||
unary_costs[-1] += min_cost
|
||||
|
||||
# backward pass - propagate optimal indices backwards
|
||||
out_codes[:, -1] = torch.argmin(unary_costs[-1], dim=0)
|
||||
for i in range(M - 2, -1, -1):
|
||||
out_codes[:, i] = min_idx[i][out_codes[:, i + 1], range(N)]
|
||||
|
||||
def dims_for(dim, total_dims, m):
|
||||
dims_per_code = total_dims // m
|
||||
relevant_codebooks = [dim // dims_per_code]
|
||||
if relevant_codebooks[-1] < m - 1:
|
||||
relevant_codebooks.append(relevant_codebooks[-1] + 1)
|
||||
return relevant_codebooks
|
||||
|
||||
def update_codebooks(transformed_data, codes, h):
|
||||
n, d = transformed_data.shape
|
||||
n2, m = codes.shape
|
||||
assert n == n2
|
||||
|
||||
new_codebook = torch.zeros(m, h, d, dtype=torch.float, device=DEVICE)
|
||||
for dim in tqdm.trange(d):
|
||||
relevant_codebooks = dims_for(dim, d, m)
|
||||
assignment_matrix = torch.zeros(n, len(relevant_codebooks), h, dtype=torch.float, device=DEVICE)
|
||||
indices = (
|
||||
torch.arange(n, dtype=torch.int, device=DEVICE).repeat(len(relevant_codebooks)),
|
||||
torch.arange(len(relevant_codebooks), dtype=torch.int, device=DEVICE).repeat_interleave(n),
|
||||
codes[:, relevant_codebooks].T.flatten()
|
||||
)
|
||||
assignment_matrix[indices] = 1
|
||||
#print(assignment_matrix, assignment_matrix.shape, transformed_data[:, dim], transformed_data[:, dim].shape)
|
||||
assignment_matrix = assignment_matrix.reshape(n, len(relevant_codebooks) * h)
|
||||
#print(assignment_matrix, assignment_matrix.shape, transformed_data[:, dim], transformed_data[:, dim].shape)
|
||||
#soln = torch.linalg.lstsq(assignment_matrix, transformed_data[:, dim])[0]
|
||||
|
||||
reg = 1e-3 * torch.eye(len(relevant_codebooks) * h, device=DEVICE)
|
||||
A = assignment_matrix.T @ assignment_matrix + reg
|
||||
b = assignment_matrix.T @ transformed_data[:, dim]
|
||||
|
||||
#print("matrix", A)
|
||||
|
||||
usage = assignment_matrix.sum(dim=0)
|
||||
unused = usage < 1
|
||||
|
||||
print(unused.sum().detach().item())
|
||||
soln = torch.linalg.solve(A, b)
|
||||
|
||||
#print("solution", soln.reshape(len(relevant_codebooks), h))
|
||||
if unused.any():
|
||||
soln[unused] = torch.randn_like(soln[unused])
|
||||
|
||||
new_codebook[relevant_codebooks, :, dim] = soln.reshape(len(relevant_codebooks), h)
|
||||
|
||||
if torch.isnan(new_codebook[relevant_codebooks, :, dim]).any():
|
||||
print("oh no", dim, new_codebook, relevant_codebooks, new_codebook[relevant_codebooks, :, dim])
|
||||
print("--- dim ---", dim)
|
||||
print("- sum per column:", assignment_matrix.sum(dim=0)) # Check if any columns are all zero
|
||||
print("- rank:", torch.linalg.matrix_rank(assignment_matrix))
|
||||
print("- condition number:", torch.linalg.cond(assignment_matrix))
|
||||
raise SystemExit
|
||||
|
||||
return new_codebook
|
||||
|
||||
BATCH = 8192
|
||||
|
||||
def train_chainq(vectors, m, h, transform, codebooks, n_iters):
|
||||
for i in range(n_iters):
|
||||
transformed_data = vectors @ transform.T
|
||||
codes = torch.zeros(vectors.shape[0], m, dtype=torch.int, device=DEVICE)
|
||||
for i in range(0, vectors.shape[0], BATCH):
|
||||
viterbi_encode(codes[i:i+BATCH], transformed_data[i:i+BATCH], codebooks)
|
||||
print("encoded")
|
||||
#codebooks = update_codebooks(transformed_data, codes, h)
|
||||
print("codebooks updated")
|
||||
|
||||
quantized = torch.zeros_like(vectors, dtype=torch.float, device=DEVICE)
|
||||
for j in range(m):
|
||||
quantized[:] += codebooks[j, codes[:, j]]
|
||||
print("quantized")
|
||||
print((quantized - transformed_data).abs().mean(), transformed_data.abs().mean())
|
||||
|
||||
print("comparing")
|
||||
res = transformed_data.T @ quantized
|
||||
|
||||
print("running SVD...")
|
||||
u, s, vt = torch.linalg.svd(res)
|
||||
print("done.")
|
||||
transform = u @ vt
|
||||
print("regenerated transform")
|
||||
|
||||
return codebooks, transform
|
||||
|
||||
with open("opq.msgpack", "rb") as f:
|
||||
data = msgpack.unpackb(f.read())
|
||||
|
||||
n_dims = 1152
|
||||
dataset = torch.tensor(np.random.permutation(np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32))[:BATCH*1], device=DEVICE)
|
||||
|
||||
codebooks = torch.zeros(64, 256, n_dims, dtype=torch.float, device=DEVICE)
|
||||
|
||||
centroids = torch.tensor(np.array(data["centroids"]).astype(np.float32).reshape(256, n_dims), device=DEVICE)
|
||||
for dim in range(n_dims):
|
||||
relevant_codebooks = dim // 64
|
||||
codebooks[relevant_codebooks, :, dim] = centroids[:, dim]
|
||||
|
||||
print(centroids)
|
||||
#print("codebooks", codebooks.tolist())
|
||||
|
||||
codebooks, transform = train_chainq(dataset, 64, 256, torch.tensor(np.array(data["transform"]).astype(np.float32).reshape(n_dims, n_dims), device=DEVICE), codebooks, 100)
|
||||
|
||||
with open("chainq.msgpack", "wb") as f:
|
||||
msgpack.pack({
|
||||
"codebooks": codebooks.cpu().numpy().flatten().tolist(),
|
||||
"transform": transform.cpu().numpy().flatten().tolist(),
|
||||
"n_dims": n_dims,
|
||||
"n_dims_per_code": n_dims // 64
|
||||
}, f)
|
@ -45,7 +45,7 @@ impl Vector {
|
||||
}
|
||||
|
||||
// Floats are vaguely annoying and not sortable (trivially), so we mostly represent dot products as integers
|
||||
pub const SCALE: f32 = 1099511627776.0;
|
||||
pub const SCALE: f32 = 4294967296.0;
|
||||
pub const SCALE_F64: f64 = SCALE as f64;
|
||||
|
||||
pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> i64 {
|
||||
|
131
faiss_bench_quantizer.py
Normal file
131
faiss_bench_quantizer.py
Normal file
@ -0,0 +1,131 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import sys
|
||||
import faiss
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
def eval_codec(q, xb):
|
||||
t0 = time.time()
|
||||
codes = q.compute_codes(xb)
|
||||
t1 = time.time()
|
||||
xb_decoded = q.decode(codes)
|
||||
recons_err = ((xb - xb_decoded) ** 2).sum(axis=1).mean()
|
||||
print(f"\tencode time: {t1 - t0:.3f} reconstruction error: {recons_err:.3f} ")
|
||||
|
||||
|
||||
def eval_quantizer(q, xb, xt, variants=None):
|
||||
if variants is None:
|
||||
variants = [(None, None)]
|
||||
t0 = time.time()
|
||||
q.train(xt)
|
||||
t1 = time.time()
|
||||
train_t = t1 - t0
|
||||
print(f'\ttraining time: {train_t:.3f} s')
|
||||
for name, val in variants:
|
||||
if name is not None:
|
||||
print(f"{name}={val}")
|
||||
|
||||
if isinstance(q, faiss.ProductAdditiveQuantizer):
|
||||
for i in range(q.nsplits):
|
||||
subq = faiss.downcast_Quantizer(q.subquantizer(i))
|
||||
getattr(subq, name)
|
||||
setattr(subq, name, val)
|
||||
else:
|
||||
getattr(q, name) # make sure field exists
|
||||
setattr(q, name, val)
|
||||
|
||||
eval_codec(q, xb)
|
||||
|
||||
|
||||
todo = sys.argv[1:]
|
||||
|
||||
ds = np.fromfile(todo[0], dtype=np.float16).reshape(-1, 1152).astype(np.float32)
|
||||
print(ds)
|
||||
del todo[0]
|
||||
|
||||
if len(todo) > 0:
|
||||
if todo[0].count("x") == 1:
|
||||
M, nbits = [int(x) for x in todo[0].split("x")]
|
||||
del todo[0]
|
||||
elif todo[0].count("x") == 2:
|
||||
nsplits, Msub, nbits = [int(x) for x in todo[0].split("x")]
|
||||
M = nsplits * Msub
|
||||
del todo[0]
|
||||
|
||||
maxtrain = max(100 << nbits, 10**5)
|
||||
print(f"eval on {M}x{nbits} maxtrain={maxtrain}")
|
||||
|
||||
xb, xt = ds, ds
|
||||
|
||||
nb, d = xb.shape
|
||||
nt, d = xt.shape
|
||||
|
||||
|
||||
# fastest to slowest
|
||||
|
||||
if 'lsq-gpu' in todo:
|
||||
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
|
||||
ngpus = faiss.get_num_gpus()
|
||||
lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
|
||||
lsq.verbose = True
|
||||
eval_quantizer(lsq, xb, xt, 'lsq-gpu')
|
||||
|
||||
if 'pq' in todo:
|
||||
pq = faiss.ProductQuantizer(d, M, nbits)
|
||||
print("===== PQ")
|
||||
eval_quantizer(pq, xb, xt)
|
||||
|
||||
if 'opq' in todo:
|
||||
d2 = ((d + M - 1) // M) * M
|
||||
print("OPQ d2=", d2)
|
||||
opq = faiss.OPQMatrix(d, M, d2)
|
||||
opq.train(xt)
|
||||
xb2 = opq.apply(xb)
|
||||
xt2 = opq.apply(xt)
|
||||
pq = faiss.ProductQuantizer(d2, M, nbits)
|
||||
print("===== PQ")
|
||||
eval_quantizer(pq, xb2, xt2)
|
||||
|
||||
if 'prq' in todo:
|
||||
print(f"===== PRQ{nsplits}x{Msub}x{nbits}")
|
||||
prq = faiss.ProductResidualQuantizer(d, nsplits, Msub, nbits)
|
||||
variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32)]
|
||||
eval_quantizer(prq, xb, xt, variants=variants)
|
||||
|
||||
if 'plsq' in todo:
|
||||
print(f"===== PLSQ{nsplits}x{Msub}x{nbits}")
|
||||
plsq = faiss.ProductLocalSearchQuantizer(d, nsplits, Msub, nbits)
|
||||
variants = [("encode_ils_iters", i) for i in (2, 3, 4, 8, 16)]
|
||||
eval_quantizer(plsq, xb, xt, variants=variants)
|
||||
|
||||
if 'rq' in todo:
|
||||
print("===== RQ")
|
||||
rq = faiss.ResidualQuantizer(d, M, nbits, )
|
||||
rq.max_beam_size
|
||||
rq.max_beam_size = 30 # for compatibility with older runs
|
||||
# rq.train_type = faiss.ResidualQuantizer.Train_default
|
||||
# rq.verbose = True
|
||||
variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32)]
|
||||
eval_quantizer(rq, xb, xt, variants=variants)
|
||||
|
||||
if 'rq_lut' in todo:
|
||||
print("===== RQ")
|
||||
rq = faiss.ResidualQuantizer(d, M, nbits, )
|
||||
rq.max_beam_size
|
||||
rq.max_beam_size = 30 # for compatibility with older runs
|
||||
rq.use_beam_LUT
|
||||
rq.use_beam_LUT = 1
|
||||
# rq.train_type = faiss.ResidualQuantizer.Train_default
|
||||
# rq.verbose = True
|
||||
variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32, 64)]
|
||||
eval_quantizer(rq, xb, xt, variants=variants)
|
||||
|
||||
if 'lsq' in todo:
|
||||
print("===== LSQ")
|
||||
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
|
||||
variants = [("encode_ils_iters", i) for i in (2, 3, 4, 8, 16)]
|
||||
eval_quantizer(lsq, xb, xt, variants=variants)
|
@ -1,8 +1,12 @@
|
||||
{
|
||||
"backend_url": "http://localhost:5601/",
|
||||
"backend_url": "/backend",
|
||||
"image_path": "",
|
||||
"thumb_path": null,
|
||||
"description": "Organizing the world's memes.",
|
||||
"name": "Nooscope",
|
||||
"telemetry_endpoint": "/telemetry"
|
||||
"telemetry_endpoint": "/telemetry",
|
||||
"friendly_mode_default_terms": [
|
||||
{ "predefined_embedding": "Meme", "weight": 1, "sign": "+" },
|
||||
{ "predefined_embedding": "Aesthetic", "weight": 0.2, "sign": "+" }
|
||||
]
|
||||
}
|
||||
|
4
genseahash.py
Normal file
4
genseahash.py
Normal file
@ -0,0 +1,4 @@
|
||||
import seahash, sys
|
||||
|
||||
with open(sys.argv[1], "rb") as f:
|
||||
print(seahash.hash(f.read()))
|
@ -31,6 +31,7 @@ model.load_state_dict(torch.load(modelc))
|
||||
params = sum(p.numel() for p in model.parameters())
|
||||
print(f"{params/1e6:.1f}M parameters")
|
||||
print(model)
|
||||
model.eval()
|
||||
|
||||
files = shared.fetch_all_files()
|
||||
variance = {}
|
||||
@ -56,4 +57,4 @@ with torch.inference_mode():
|
||||
|
||||
top = sorted(variance.items(), key=lambda x: -x[1])
|
||||
with open("top.json", "w") as f:
|
||||
json.dump(top[:100], f)
|
||||
json.dump(top[:50], f)
|
||||
|
@ -52,7 +52,7 @@ channel = int(sys.argv[2])
|
||||
percentile = float(sys.argv[3])
|
||||
output_pairs = int(sys.argv[4])
|
||||
mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
|
||||
top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
|
||||
top = sorted(((filename, score) for filename, score in results.items()), key=lambda x: x[1][channel], reverse=True)
|
||||
select_from = top[:int(len(top) * percentile)]
|
||||
|
||||
out = []
|
||||
|
@ -13,14 +13,14 @@ import sys
|
||||
from model import Config, BradleyTerry
|
||||
import shared
|
||||
|
||||
batch_size = 128
|
||||
batch_size = 32
|
||||
num_pairs = batch_size * 1024
|
||||
device = "cuda"
|
||||
|
||||
config = Config(
|
||||
d_emb=1152,
|
||||
n_hidden=1,
|
||||
n_ensemble=1,
|
||||
n_ensemble=16,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
output_channels=3,
|
||||
@ -32,6 +32,7 @@ model.load_state_dict(torch.load(modelc))
|
||||
params = sum(p.numel() for p in model.parameters())
|
||||
print(f"{params/1e6:.1f}M parameters")
|
||||
print(model)
|
||||
model.eval()
|
||||
|
||||
files = shared.fetch_all_files()
|
||||
importance = {}
|
||||
@ -41,8 +42,8 @@ buffers = {k: v.detach() for k, v in model.named_buffers()}
|
||||
|
||||
# https://pytorch.org/tutorials/intermediate/per_sample_grads.html
|
||||
def compute_loss(params, buffers, sample, target):
|
||||
batch = sample.unsqueeze(0)
|
||||
targets = target.unsqueeze(0)
|
||||
batch = sample.unsqueeze(1)
|
||||
targets = target.unsqueeze(1).expand((config.n_ensemble, 1, config.output_channels))
|
||||
|
||||
predictions = functional_call(model, (params, buffers), (batch,))
|
||||
loss = F.binary_cross_entropy(predictions, targets)
|
||||
@ -75,4 +76,4 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
|
||||
|
||||
top = sorted(importance.items(), key=lambda x: -x[1])
|
||||
with open("top.json", "w") as f:
|
||||
json.dump(top[:256], f)
|
||||
json.dump(top[:50], f)
|
||||
|
@ -31,6 +31,11 @@ params = sum(p.numel() for p in model.parameters())
|
||||
print(f"{params/1e6:.1f}M parameters")
|
||||
print(model)
|
||||
|
||||
torch.random.manual_seed(1)
|
||||
|
||||
for x in model.ensemble.models:
|
||||
x.output.bias.data.fill_(0)
|
||||
|
||||
out_layers = []
|
||||
out_bias = []
|
||||
|
||||
@ -50,7 +55,7 @@ with torch.inference_mode():
|
||||
downprojection[:, i*config.d_emb:(i+1)*config.d_emb] = model.ensemble.models[i].output.weight.data.clone()
|
||||
|
||||
for i in range(10):
|
||||
input = torch.randn(3, config.d_emb)
|
||||
input = torch.randn(4, config.d_emb)
|
||||
ground_truth_result = model.ensemble(input.unsqueeze(0).expand((config.n_ensemble, *input.shape))).mean(dim=0).T
|
||||
r_result = input
|
||||
for (layer, bias) in zip(out_layers, out_bias):
|
||||
@ -58,13 +63,14 @@ with torch.inference_mode():
|
||||
print(r_result.shape, bias.shape)
|
||||
r_result = F.silu(r_result)
|
||||
r_result = torch.matmul(downprojection, r_result) / config.n_ensemble
|
||||
bias_diff = torch.mean(r_result - ground_truth_result)
|
||||
model_issue_diff = torch.max(torch.abs(r_result - bias_diff - ground_truth_result))
|
||||
print(model_issue_diff)
|
||||
error = torch.mean(r_result - ground_truth_result)
|
||||
print(error)
|
||||
assert error.detach().cpu().numpy() < 1e-4
|
||||
|
||||
print("test vector:")
|
||||
print(input.flatten().tolist())
|
||||
#print(input.flatten().tolist())
|
||||
print("ground truth result:")
|
||||
print(ground_truth_result.shape)
|
||||
print(ground_truth_result.T.flatten().tolist())
|
||||
|
||||
save_file({
|
||||
|
@ -35,7 +35,7 @@ class Ensemble(nn.Module):
|
||||
|
||||
# model batch
|
||||
def forward(self, embs):
|
||||
return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
|
||||
return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_channels
|
||||
|
||||
class BradleyTerry(nn.Module):
|
||||
def __init__(self, config):
|
||||
|
@ -51,19 +51,25 @@ async def index(request):
|
||||
<form action="/rate" method="POST">
|
||||
<table>
|
||||
<tr>
|
||||
<td><input type="radio" name="rating-useful" value="1+" id="rq1p"> <label for="rq1p">LHS is much better (useful)</label></td>
|
||||
<td><input type="radio" name="rating-useful" value="1" id="rq1"> <label for="rq1">LHS is better (useful)</label></td>
|
||||
<td><input type="radio" name="rating-useful" value="eq" id="rqe"> <label for="rqe">Tie</label></td>
|
||||
<td><input type="radio" name="rating-useful" value="2" id="rq2"> <label for="rq2">RHS is better (useful)</label></td>
|
||||
<td><input type="radio" name="rating-useful" value="2+" id="rq2p"> <label for="rq2p">LHS is much better (useful)</label></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><input type="radio" name="rating-meme" value="1+" id="rm1p"> <label for="rm1p">LHS is much better (memetically)</label></td>
|
||||
<td><input type="radio" name="rating-meme" value="1" id="rm1"> <label for="rm1">LHS is better (memetically)</label></td>
|
||||
<td><input type="radio" name="rating-meme" value="eq" id="rme"> <label for="rme">Tie</label></td>
|
||||
<td><input type="radio" name="rating-meme" value="2" id="rm2"> <label for="rm2">RHS is better (memetically)</label></td>
|
||||
<td><input type="radio" name="rating-meme" value="2+" id="rm2p"> <label for="rm2p">LHS is much better (memetically)</label></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><input type="radio" name="rating-aesthetic" value="1+" id="ra1p"> <label for="ra1p">LHS is much better (aesthetically)</label></td>
|
||||
<td><input type="radio" name="rating-aesthetic" value="1" id="ra1"> <label for="ra1">LHS is better (aesthetically)</label></td>
|
||||
<td><input type="radio" name="rating-aesthetic" value="eq" id="rae"> <label for="rae">Tie</label></td>
|
||||
<td><input type="radio" name="rating-aesthetic" value="2" id="ra2"> <label for="ra2">RHS is better (aesthetically)</label></td>
|
||||
<td><input type="radio" name="rating-aesthetic" value="2+" id="ra2p"> <label for="ra2p">LHS is much better (aesthetically)</label></td>
|
||||
</td>
|
||||
</table>
|
||||
|
||||
@ -82,35 +88,29 @@ async def index(request):
|
||||
document.querySelector("form").submit()
|
||||
}}
|
||||
}}
|
||||
const keys = {{
|
||||
"q": "rq1p",
|
||||
"w": "rq1",
|
||||
"e": "rqe",
|
||||
"r": "rq2",
|
||||
"t": "rq2p",
|
||||
"a": "rm1p",
|
||||
"s": "rm1",
|
||||
"d": "rme",
|
||||
"f": "rm2",
|
||||
"g": "rm2p",
|
||||
"z": "ra1p",
|
||||
"x": "ra1",
|
||||
"c": "rae",
|
||||
"v": "ra2",
|
||||
"b": "ra2p",
|
||||
}}
|
||||
document.addEventListener("keypress", function(event) {{
|
||||
if (event.key === "q") {{
|
||||
document.querySelector("input[name='rating-useful'][value='1']").checked = true
|
||||
commitIfReady()
|
||||
}} else if (event.key === "w") {{
|
||||
document.querySelector("input[name='rating-useful'][value='eq']").checked = true
|
||||
commitIfReady()
|
||||
}} else if (event.key === "e") {{
|
||||
document.querySelector("input[name='rating-useful'][value='2']").checked = true
|
||||
commitIfReady()
|
||||
}} else if (event.key === "a") {{
|
||||
document.querySelector("input[name='rating-meme'][value='1']").checked = true
|
||||
commitIfReady()
|
||||
}} else if (event.key === "s") {{
|
||||
document.querySelector("input[name='rating-meme'][value='eq']").checked = true
|
||||
commitIfReady()
|
||||
}} else if (event.key === "d") {{
|
||||
document.querySelector("input[name='rating-meme'][value='2']").checked = true
|
||||
commitIfReady()
|
||||
}} else if (event.key === "z") {{
|
||||
document.querySelector("input[name='rating-aesthetic'][value='1']").checked = true
|
||||
commitIfReady()
|
||||
}} else if (event.key === "x") {{
|
||||
document.querySelector("input[name='rating-aesthetic'][value='eq']").checked = true
|
||||
commitIfReady()
|
||||
}} else if (event.key === "c") {{
|
||||
document.querySelector("input[name='rating-aesthetic'][value='2']").checked = true
|
||||
commitIfReady()
|
||||
const key = keys[event.key]
|
||||
if (key) {{
|
||||
document.getElementById(key).checked = true
|
||||
}}
|
||||
commitIfReady()
|
||||
}});
|
||||
</script>
|
||||
</body>
|
||||
|
@ -20,27 +20,35 @@ def fetch_embedding(filename):
|
||||
csr.close()
|
||||
return x.copy() # PyTorch complains otherwise due to bad
|
||||
|
||||
def map_rating(rating, uncertainty=0.05):
|
||||
def map_rating(rating):
|
||||
def map_one(rating):
|
||||
match rating:
|
||||
case "1": # meme 1 is better
|
||||
return 1 - uncertainty
|
||||
return 0.9
|
||||
case "2":
|
||||
return uncertainty
|
||||
return 0.1
|
||||
case "2+" | "2p":
|
||||
return 0.3
|
||||
case "1+" | "1p":
|
||||
return 0.7
|
||||
case "eq":
|
||||
return 0.5
|
||||
case _: raise ValueError("invalid rating, please fix")
|
||||
|
||||
return np.array([map_one(r) for r in rating.split(",")])
|
||||
|
||||
def fetch_ratings():
|
||||
def fetch_ratings(sets):
|
||||
trains = defaultdict(list)
|
||||
validations = defaultdict(list)
|
||||
csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
|
||||
its = set()
|
||||
for meme1, meme2, rating, iteration in csr.fetchall():
|
||||
if iteration not in its:
|
||||
print(iteration)
|
||||
its.add(iteration)
|
||||
(validations if is_val_set(meme1, meme2) else trains)[int((iteration and iteration.split("-")[0]) or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
|
||||
csr.close()
|
||||
return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))
|
||||
return list(x[1] for x in sorted(trains.items()) if str(x[0]) in sets), list(x[1] for x in sorted(validations.items()) if str(x[0]) in sets)
|
||||
|
||||
def generate_random_permutations(x, n):
|
||||
out = []
|
||||
|
@ -9,11 +9,12 @@ import time
|
||||
from tqdm import tqdm
|
||||
import math
|
||||
from dataclasses import dataclass, asdict
|
||||
import sys
|
||||
|
||||
from model import Config as ModelConfig, BradleyTerry
|
||||
import shared
|
||||
|
||||
trains, validations = shared.fetch_ratings()
|
||||
trains, validations = shared.fetch_ratings(sys.argv[1:])
|
||||
for train, validation in zip(trains, validations):
|
||||
print(len(train), len(validation))
|
||||
|
||||
@ -36,11 +37,11 @@ config = TrainConfig(
|
||||
n_ensemble=16,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
dropout=0.1,
|
||||
dropout=0.0,
|
||||
output_channels=3
|
||||
),
|
||||
lr=3e-4,
|
||||
weight_decay=0.2,
|
||||
weight_decay=0.0,
|
||||
batch_size=1,
|
||||
epochs=5,
|
||||
compile=False,
|
||||
|
@ -1,12 +1,12 @@
|
||||
import torch
|
||||
import pyarrow as pa
|
||||
import numpy as np
|
||||
|
||||
torch.set_float32_matmul_precision("high")
|
||||
|
||||
with pa.memory_map("../../sample_1m.arrow", "r") as source:
|
||||
loaded_arrays = pa.ipc.open_file(source).read_all()
|
||||
loaded_arrays = np.memmap("embeddings.bin", dtype=np.float16).reshape(-1, 1152)
|
||||
loaded_arrays_permutation = np.random.permutation(len(loaded_arrays))
|
||||
|
||||
train_split = 0.8
|
||||
|
||||
def ckpt_path(steps):
|
||||
return f"ckpt/{steps}.pt", f"ckpt/{steps}.optim.pt"
|
||||
return f"ckpt/{steps}.pt", f"ckpt/{steps}.optim.pt"
|
||||
|
11
sae/train.py
11
sae/train.py
@ -6,9 +6,10 @@ import json
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
from dataclasses import dataclass, asdict
|
||||
import math
|
||||
|
||||
from model import SAEConfig, SAE
|
||||
from shared import train_split, loaded_arrays, ckpt_path
|
||||
from shared import train_split, loaded_arrays, ckpt_path, loaded_arrays_permutation
|
||||
|
||||
device = "cuda"
|
||||
|
||||
@ -24,7 +25,7 @@ class TrainConfig:
|
||||
config = TrainConfig(
|
||||
model=SAEConfig(
|
||||
d_emb=1152,
|
||||
d_hidden=65536,
|
||||
d_hidden=262144,
|
||||
top_k=128,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
@ -33,7 +34,7 @@ config = TrainConfig(
|
||||
lr=3e-4,
|
||||
weight_decay=0.0,
|
||||
batch_size=64,
|
||||
epochs=5,
|
||||
epochs=1,
|
||||
compile=True,
|
||||
)
|
||||
|
||||
@ -81,7 +82,7 @@ with open(logfile, "w") as log:
|
||||
batch = []
|
||||
t = tqdm(range(0, int(len(loaded_arrays) * train_split), config.batch_size))
|
||||
for batch_start in t:
|
||||
batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in loaded_arrays["embedding"][batch_start:batch_start + config.batch_size] ])
|
||||
batch = numpy.stack([ embedding for embedding in loaded_arrays[loaded_arrays_permutation[batch_start:batch_start + config.batch_size]] ])
|
||||
|
||||
if len(batch) == config.batch_size:
|
||||
batch = torch.Tensor(batch).to(device)
|
||||
@ -98,4 +99,4 @@ with open(logfile, "w") as log:
|
||||
print(ctr)
|
||||
numpy.save(f"ckpt/{steps}.counters.npy", ctr)
|
||||
|
||||
print(logfile)
|
||||
print(logfile)
|
||||
|
35
slow_dump_parse_script.py
Normal file
35
slow_dump_parse_script.py
Normal file
@ -0,0 +1,35 @@
|
||||
import umsgpack
|
||||
import zstandard
|
||||
import pyarrow as pa
|
||||
|
||||
data = []
|
||||
|
||||
with open("sample.zst", "rb") as f:
|
||||
decomp = zstandard.ZstdDecompressor()
|
||||
reader = decomp.stream_reader(f)
|
||||
count = 0
|
||||
while True:
|
||||
try:
|
||||
url, id, title, subreddit, author, timestamp, embedding = umsgpack.unpack(reader)
|
||||
embedding = bytes(embedding)
|
||||
data.append({"url": url, "id": id, "title": title, "subreddit": subreddit, "author": author, "timestamp": timestamp, "embedding": embedding})
|
||||
count += 1
|
||||
except umsgpack.InsufficientDataException:
|
||||
break
|
||||
print(count)
|
||||
|
||||
schema = pa.schema([
|
||||
("url", pa.string()),
|
||||
("id", pa.string()),
|
||||
("title", pa.string()),
|
||||
("subreddit", pa.string()),
|
||||
("author", pa.string()),
|
||||
("timestamp", pa.int64()),
|
||||
("embedding", pa.binary())
|
||||
])
|
||||
|
||||
table = pa.Table.from_pylist(data, schema=schema)
|
||||
|
||||
with pa.OSFile("output.parquet", "wb") as sink:
|
||||
with pa.RecordBatchFileWriter(sink, schema) as writer:
|
||||
writer.write_table(table)
|
@ -183,8 +183,8 @@ pub struct FrontendInit {
|
||||
pub type EmbeddingVector = Vec<f32>;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct QueryResult {
|
||||
pub matches: Vec<(f32, String, String, u64, Option<(u32, u32)>)>,
|
||||
pub struct QueryResult<T> {
|
||||
pub matches: Vec<(f32, String, String, u64, Option<(u32, u32)>, Option<T>)>,
|
||||
pub formats: Vec<String>,
|
||||
pub extensions: HashMap<String, String>,
|
||||
}
|
||||
@ -203,7 +203,9 @@ pub struct QueryRequest {
|
||||
pub terms: Vec<QueryTerm>,
|
||||
pub k: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub include_video: bool
|
||||
pub include_video: bool,
|
||||
#[serde(default)]
|
||||
pub debug_enabled: bool
|
||||
}
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
@ -238,8 +240,9 @@ pub async fn get_total_embedding<A: Future<Output = Result<Vec<Vec<u8>>>>, B: Fu
|
||||
}
|
||||
}
|
||||
if let Some(name) = &term.predefined_embedding {
|
||||
let embedding = predefined_embeddings.get(name).context("name invalid")?;
|
||||
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
|
||||
if let Some(embedding) = predefined_embeddings.get(name) {
|
||||
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -69,6 +69,10 @@ struct CLIArguments {
|
||||
gpu: Option<usize>,
|
||||
#[argh(option, description="descriptor CDFs")]
|
||||
cdfs: Option<String>,
|
||||
#[argh(option, description="postfilter by embedding (late discard if dot product above threshold)")]
|
||||
postfilter: Vec<String>,
|
||||
#[argh(option, description="postfilter by scorer")]
|
||||
postfilter_scorer: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, Debug)]
|
||||
@ -162,10 +166,22 @@ fn main() -> Result<()> {
|
||||
} else {
|
||||
(snd, None)
|
||||
};
|
||||
let mut post_threshold = None;
|
||||
for x in &args.postfilter {
|
||||
let (tname, snd) = x.split_once(':').context("invalid postfilter argument")?;
|
||||
if tname == name {
|
||||
post_threshold = Some(snd.parse::<f32>().context("parse postfilter threshold")?);
|
||||
}
|
||||
}
|
||||
let blob = fs::read(path).context("read embedding")?;
|
||||
embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512), threshold));
|
||||
embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512), threshold, post_threshold));
|
||||
}
|
||||
|
||||
let postfilter_scorer = args.postfilter_scorer.iter().map(|x| {
|
||||
let (id, snd) = x.split_once(':').context("invalid postfilter scorer argument")?;
|
||||
Ok((id.parse::<usize>().context("invalid postfilter scorer id")?, snd.parse::<f32>().context("parse postfilter scorer threshold")?))
|
||||
}).collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let pq_codec = if let Some(pq_codec) = args.pq_codec {
|
||||
let data = fs::read(pq_codec).context("read pq codec")?;
|
||||
let pq_codec: ProductQuantizer = rmp_serde::from_read(&data[..]).context("decode pq codec")?;
|
||||
@ -255,7 +271,6 @@ fn main() -> Result<()> {
|
||||
files.sort_by_key(|(id, _)| *id);
|
||||
shard_id_mappings.sort_by_key(|(id, _)| *id);
|
||||
|
||||
|
||||
let read_out_vertices = move |id: u32| -> Result<(Vec<u32>, Vec<u32>)> {
|
||||
let mut out_vertices: Vec<u32> = vec![];
|
||||
let mut shards: Vec<u32> = vec![];
|
||||
@ -323,6 +338,8 @@ fn main() -> Result<()> {
|
||||
|
||||
let th = std::thread::spawn(move || reader_thread(&args.paths, tx));
|
||||
|
||||
let mut postfilter_count = 0;
|
||||
|
||||
let mut rng2 = rng.fork();
|
||||
let initial_filter = |x: ProcessedEntry| {
|
||||
i += 1;
|
||||
@ -337,7 +354,9 @@ fn main() -> Result<()> {
|
||||
latest_timestamp = latest_timestamp.max(timestamp);
|
||||
earliest_timestamp = earliest_timestamp.min(timestamp);
|
||||
|
||||
for (_name, vec, histogram, threshold) in &mut embeddings {
|
||||
let mut postfilter = false;
|
||||
|
||||
for (_name, vec, histogram, threshold, postfilter_threshold) in &mut embeddings {
|
||||
let dot = SpatialSimilarity::dot(&embedding, vec).unwrap() as f32;
|
||||
histogram.add(dot);
|
||||
if let Some(threshold) = threshold {
|
||||
@ -345,6 +364,12 @@ fn main() -> Result<()> {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
if let Some(threshold) = postfilter_threshold {
|
||||
if dot >= *threshold {
|
||||
postfilter = true;
|
||||
postfilter_count += 1; // somewhat wrong because could be duplicated
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// distance thresholding is too costly to do over a long range so just do it badly
|
||||
@ -393,7 +418,7 @@ fn main() -> Result<()> {
|
||||
println!("{}", data);
|
||||
}
|
||||
|
||||
Some((x, embedding))
|
||||
Some((x, embedding, postfilter))
|
||||
};
|
||||
|
||||
let mut dead_count = 0;
|
||||
@ -404,14 +429,14 @@ fn main() -> Result<()> {
|
||||
let batch: Vec<_> = batch.collect();
|
||||
let batch_len = batch.len();
|
||||
|
||||
for (x, _embedding) in batch.iter() {
|
||||
for (x, _embedding, _postfilter) in batch.iter() {
|
||||
if let Some(ref mut file) = output_file {
|
||||
file.write_all(&x.embedding)?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(shards) = &mut shards_out {
|
||||
for (i, (x, embedding)) in batch.iter().enumerate() {
|
||||
for (i, (x, embedding, _postfilter)) in batch.iter().enumerate() {
|
||||
// closest matches first
|
||||
shards.sort_by_cached_key(|&(ref centroid, _, shard_count, _shard_index)| {
|
||||
let mut dot = SpatialSimilarity::dot(¢roid, &embedding).unwrap();
|
||||
@ -439,7 +464,7 @@ fn main() -> Result<()> {
|
||||
let quantizer = pq_codec.as_ref().context("PQ codec needed to output index")?;
|
||||
|
||||
let mut batch_embeddings = Vec::with_capacity(batch.len() * D_EMB as usize);
|
||||
for (_x, embedding) in batch.iter() {
|
||||
for (_x, embedding, _postfilter) in batch.iter() {
|
||||
batch_embeddings.extend_from_slice(&embedding);
|
||||
}
|
||||
let codes = quantizer.quantize_batch(&batch_embeddings);
|
||||
@ -448,7 +473,7 @@ fn main() -> Result<()> {
|
||||
let cdfs = cdfs.as_ref().context("score model CDFs needed to output index")?;
|
||||
let scores = score_model.score_batch(&batch_embeddings)?;
|
||||
|
||||
for (i, (x, _embedding)) in batch.into_iter().enumerate() {
|
||||
for (i, (x, _embedding, mut postfilter)) in batch.into_iter().enumerate() {
|
||||
let (vertices, shards) = read_out_vertices(count + i as u32)?; // TODO: could parallelize this given the batching
|
||||
|
||||
let mut entry_scores = scores[(i * score_model.output_channels)..((i + 1) * score_model.output_channels)].to_vec();
|
||||
@ -465,6 +490,13 @@ fn main() -> Result<()> {
|
||||
index_output_file.2.write_all(&[cdf_bucket])?;
|
||||
}
|
||||
|
||||
for (index, score) in postfilter_scorer.iter() {
|
||||
if entry_scores[*index] < *score {
|
||||
postfilter = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut entry = PackedIndexEntry {
|
||||
id: count + i as u32,
|
||||
vertices,
|
||||
@ -476,9 +508,10 @@ fn main() -> Result<()> {
|
||||
shards
|
||||
};
|
||||
let mut bytes = bitcode::encode(&entry);
|
||||
if bytes.len() > (RECORD_PAD_SIZE - 2) {
|
||||
// as an ugly hack for removing entries already in the index shards, kill the URL and make it a graph node only
|
||||
if bytes.len() > (RECORD_PAD_SIZE - 2) || postfilter {
|
||||
// we do need the records to fit in a fixed size and can't really drop things, so discard URL so it can exist as a graph node only
|
||||
entry.url = String::new();
|
||||
entry.url = String::new(); // URL is only input-controlled, arbitrary-length field
|
||||
bytes = bitcode::encode(&entry);
|
||||
dead_count += 1;
|
||||
}
|
||||
@ -494,11 +527,11 @@ fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
if args.print_aggregates {
|
||||
println!("earliest={} latest={} count={} read={} deduped={}", earliest_timestamp, latest_timestamp, count, i, deduped_count);
|
||||
println!("earliest={} latest={} count={} read={} deduped={} postfiltered={}", earliest_timestamp, latest_timestamp, count, i, deduped_count, postfilter_count);
|
||||
}
|
||||
if let Some(histogram_path) = args.histograms {
|
||||
let mut file = fs::File::create(histogram_path)?;
|
||||
for (name, _, histogram, _) in &embeddings {
|
||||
for (name, _, histogram, _, _) in &embeddings {
|
||||
let width = 800.0;
|
||||
let padding = 40.0;
|
||||
let bars_height = 300 as f64;
|
||||
|
@ -10,14 +10,16 @@ with open("mse_config.json") as f:
|
||||
def get_embedding(req):
|
||||
return msgpack.unpackb(requests.post(config["clip_server"], data=msgpack.packb(req)).content)
|
||||
|
||||
output, input, *xs = sys.argv[1:]
|
||||
mode, output, input = sys.argv[1:]
|
||||
|
||||
with open(output, "wb") as f:
|
||||
with open(input, "rb") as g:
|
||||
input_data = g.read()
|
||||
if not xs:
|
||||
if mode == "image":
|
||||
with open(input, "rb") as g:
|
||||
input_data = g.read()
|
||||
result = get_embedding({"images": [input_data]})[0]
|
||||
elif mode == "text":
|
||||
result = get_embedding({"text": input})[0]
|
||||
else:
|
||||
result = get_embedding({"text": xs})[0]
|
||||
raise Exception("unknown mode")
|
||||
f.write(result)
|
||||
print(base64.urlsafe_b64encode(result).decode("ascii"))
|
||||
|
@ -892,7 +892,7 @@ async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
|
||||
}
|
||||
|
||||
#[instrument(skip(index))]
|
||||
async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult> {
|
||||
async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult<()>> {
|
||||
let result = index.vectors.search(&query, k as usize)?;
|
||||
|
||||
let mut seen_videos = HashSet::new();
|
||||
@ -916,7 +916,8 @@ async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bo
|
||||
index.filenames[id].container_filename(),
|
||||
generate_filename_hash(&index.filenames[id as usize]).clone(),
|
||||
index.format_codes[id],
|
||||
index.metadata[id].as_ref().map(|x| (x.width, x.height))
|
||||
index.metadata[id].as_ref().map(|x| (x.width, x.height)),
|
||||
Option::<()>::None
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
@ -7,7 +7,7 @@ use argh::FromArgs;
|
||||
use itertools::Itertools;
|
||||
use foldhash::{HashSet, HashSetExt};
|
||||
use half::f16;
|
||||
use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, QueryLUT, scale_dot_result, scale_dot_result_f64}};
|
||||
use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, QueryLUT, scale_dot_result, scale_dot_result_f64, SCALE_F64}};
|
||||
use simsimd::SpatialSimilarity;
|
||||
use memmap2::{Mmap, MmapOptions};
|
||||
use std::rc::Rc;
|
||||
@ -18,13 +18,14 @@ use http_body_util::{BodyExt, Empty, Full};
|
||||
use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
|
||||
use std::pin::Pin;
|
||||
use std::future::Future;
|
||||
use serde::Serialize;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use std::str::FromStr;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Write;
|
||||
|
||||
mod common;
|
||||
|
||||
use common::{resize_for_embed_sync, FrontendInit, IndexHeader, InferenceServerConfig, PackedIndexEntry, QueryRequest, QueryResult};
|
||||
use common::{resize_for_embed_sync, QueryTerm, FrontendInit, IndexHeader, InferenceServerConfig, PackedIndexEntry, QueryRequest, QueryResult};
|
||||
|
||||
#[derive(FromArgs, Clone)]
|
||||
#[argh(description="Query disk index")]
|
||||
@ -43,10 +44,18 @@ struct CLIArguments {
|
||||
search_list_size: Option<usize>,
|
||||
#[argh(switch, description="always use full-precision vectors (slow)")]
|
||||
disable_pq: bool,
|
||||
#[argh(option, short='l', description="listen address")]
|
||||
listen_address: Option<String>,
|
||||
#[argh(option, short='c', description="clip server")]
|
||||
clip_server: Option<String>,
|
||||
#[argh(option, short='c', description="server config file")]
|
||||
config_path: Option<String>
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone)]
|
||||
struct ServerConfig {
|
||||
listen_address: String,
|
||||
clip_server: String,
|
||||
descriptor_names: Vec<String>,
|
||||
telemetry_file: String,
|
||||
search_list: usize,
|
||||
beam_width: usize
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
@ -82,17 +91,30 @@ fn next_several_unvisited(s: &mut NeighbourBuffer, n: usize) -> Option<Vec<u32>>
|
||||
}
|
||||
}
|
||||
|
||||
const DUPLICATES_THRESHOLD: f32 = 0.95;
|
||||
|
||||
fn read_pq_codes(id: u32, index: Rc<Index>, buf: &mut Vec<u8>) {
|
||||
let loc = (id as usize) * index.pq_code_size;
|
||||
buf.extend(&index.pq_codes[loc..loc+index.pq_code_size])
|
||||
}
|
||||
|
||||
struct VisitedNode {
|
||||
image_url: String,
|
||||
scores: Vec<f32>,
|
||||
shards: Vec<u32>,
|
||||
id: u32,
|
||||
score: i64,
|
||||
timestamp: u64,
|
||||
dimensions: (u32, u32)
|
||||
}
|
||||
|
||||
struct Scratch {
|
||||
visited_adjacent: HashSet<u32>,
|
||||
visited: HashSet<u32>,
|
||||
neighbour_buffer: NeighbourBuffer,
|
||||
neighbour_pre_buffer: Vec<u32>,
|
||||
visited_list: Vec<(u32, i64, String, Vec<u32>, Vec<f32>)>
|
||||
visited_list: Vec<VisitedNode>,
|
||||
visited_embeddings: Vec<f32>
|
||||
}
|
||||
|
||||
struct Index {
|
||||
@ -140,10 +162,20 @@ async fn greedy_search<'a>(scratch: &mut Scratch, start: u32, query: &[f16], que
|
||||
let index = index.clone();
|
||||
let node = handle.await?;
|
||||
let vector = bytemuck::cast_slice(&node.vector);
|
||||
let distance = fast_dot_noprefetch(query, &vector);
|
||||
let mut score = fast_dot_noprefetch(query, &vector);
|
||||
score += descriptor_product(index.clone(), &descriptor_scales, node.id);
|
||||
cmps += 1;
|
||||
if scratch.visited.insert(node.id) {
|
||||
scratch.visited_list.push((node.id, distance, node.url, node.shards, node.scores));
|
||||
if scratch.visited.insert(node.id) && node.url.len() > 0 {
|
||||
scratch.visited_list.push(VisitedNode {
|
||||
image_url: node.url,
|
||||
scores: node.scores,
|
||||
shards: node.shards,
|
||||
id: node.id,
|
||||
score,
|
||||
timestamp: node.timestamp,
|
||||
dimensions: node.dimensions
|
||||
});
|
||||
scratch.visited_embeddings.extend(bytemuck::cast_slice(&node.vector).iter().map(|x: &f16| x.to_f32()));
|
||||
};
|
||||
for &neighbour in node.vertices.iter() {
|
||||
if scratch.visited_adjacent.insert(neighbour) {
|
||||
@ -250,7 +282,8 @@ async fn evaluate(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
|
||||
neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
|
||||
neighbour_pre_buffer: Vec::new(),
|
||||
visited_list: Vec::new(),
|
||||
visited_adjacent: HashSet::new()
|
||||
visited_adjacent: HashSet::new(),
|
||||
visited_embeddings: Vec::new()
|
||||
};
|
||||
|
||||
let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]);
|
||||
@ -265,14 +298,14 @@ async fn evaluate(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
|
||||
println!("index scan {}: {:?} cmps", shard, cmps_result);
|
||||
}
|
||||
|
||||
scratch.visited_list.sort_by_key(|x| -x.1);
|
||||
for (i, (id, distance, url, shards, scores)) in scratch.visited_list.iter().take(20).enumerate() {
|
||||
let found_id = match matches.binary_search(&(*id, 0)) {
|
||||
scratch.visited_list.sort_by_key(|x| -x.score);
|
||||
for (i, node) in scratch.visited_list.iter().take(20).enumerate() {
|
||||
let found_id = match matches.binary_search(&(node.id, 0)) {
|
||||
Ok(pos) => pos,
|
||||
Err(pos) => pos
|
||||
};
|
||||
if args.verbose {
|
||||
println!("index scan: {} {} {} {:?} {:?}; rank {}", id, distance, url, shards, scores, matches[found_id].1 + 1);
|
||||
println!("index scan: {} {} {} {:?} {:?}; rank {}", node.id, node.score, node.image_url, node.shards, node.scores, matches[found_id].1 + 1);
|
||||
};
|
||||
top_ranks[i] = std::cmp::min(top_ranks[i], matches[found_id].1);
|
||||
}
|
||||
@ -341,11 +374,23 @@ pub async fn query_clip_server<I, O>(base_url: &str, path: &str, data: Option<I>
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct TelemetryMessage {
|
||||
#[serde(rename="correlationId")]
|
||||
correlation_id: String,
|
||||
data: serde_json::Value,
|
||||
event: String,
|
||||
#[serde(rename="instanceId")]
|
||||
instance_id: String,
|
||||
page: String
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Service {
|
||||
index: Rc<Index>,
|
||||
inference_server_config: Rc<InferenceServerConfig>,
|
||||
args: Rc<CLIArguments>
|
||||
config: Rc<ServerConfig>,
|
||||
telemetry_channel: std::sync::mpsc::Sender<TelemetryMessage>
|
||||
}
|
||||
|
||||
impl hyper::service::Service<Request<Incoming>> for Service {
|
||||
@ -355,15 +400,16 @@ impl hyper::service::Service<Request<Incoming>> for Service {
|
||||
|
||||
fn call(&self, req: Request<Incoming>) -> Self::Future {
|
||||
let index = self.index.clone();
|
||||
let args = self.args.clone();
|
||||
let config = self.config.clone();
|
||||
let inference_server_config = self.inference_server_config.clone();
|
||||
let channel = self.telemetry_channel.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let mut body = match (req.method(), req.uri().path()) {
|
||||
(&Method::GET, "/") => Response::new(Full::new(Bytes::from(serde_json::to_vec(&FrontendInit {
|
||||
n_total: (index.header.count - index.header.dead_count) as u64,
|
||||
d_emb: index.header.quantizer.n_dims,
|
||||
predefined_embedding_names: vec![]
|
||||
predefined_embedding_names: config.descriptor_names.clone()
|
||||
})?))),
|
||||
(&Method::POST, "/") => {
|
||||
let upper = req.body().size_hint().upper().unwrap_or(u64::MAX);
|
||||
@ -381,7 +427,7 @@ impl hyper::service::Service<Request<Incoming>> for Service {
|
||||
&body.terms,
|
||||
&*inference_server_config,
|
||||
|batch, _config| {
|
||||
query_clip_server(args.clip_server.as_ref().unwrap(), "/", Some(batch))
|
||||
query_clip_server(config.clip_server.as_str(), "/", Some(batch))
|
||||
},
|
||||
|image, config| async move {
|
||||
let image = image::load_from_memory(&image)?;
|
||||
@ -397,27 +443,89 @@ impl hyper::service::Service<Request<Incoming>> for Service {
|
||||
}).unwrap();
|
||||
let selected_start = index.header.shards[selected_shard].1;
|
||||
|
||||
let beamwidth = 3;
|
||||
let beamwidth = config.beam_width;
|
||||
|
||||
let mut scratch = Scratch {
|
||||
visited: HashSet::new(),
|
||||
neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
|
||||
neighbour_buffer: NeighbourBuffer::new(config.search_list),
|
||||
neighbour_pre_buffer: Vec::new(),
|
||||
visited_list: Vec::new(),
|
||||
visited_adjacent: HashSet::new()
|
||||
visited_adjacent: HashSet::new(),
|
||||
visited_embeddings: Vec::new()
|
||||
};
|
||||
|
||||
let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]);
|
||||
let mut desc = vec![0.0, 0.0, 0.0, 0.0];
|
||||
|
||||
for term in &body.terms {
|
||||
if let Some(name) = &term.predefined_embedding {
|
||||
if let Some(index) = config.descriptor_names.iter().position(|x| x == name) {
|
||||
desc[index] = term.weight.unwrap_or(1.0) * 1.0/512.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let descriptor_scales = DescriptorScales(desc);
|
||||
|
||||
let query_preprocessed = index.header.quantizer.preprocess_query(&query);
|
||||
|
||||
let query = query.iter().map(|x| half::f16::from_f32(*x)).collect::<Vec<f16>>();
|
||||
|
||||
let cmps_result = greedy_search(&mut scratch, selected_start, &query, &query_preprocessed, &descriptor_scales, index.clone(), args.disable_pq, beamwidth).await?;
|
||||
let cmps_result = greedy_search(&mut scratch, selected_start, &query, &query_preprocessed, &descriptor_scales, index.clone(), false, beamwidth).await?;
|
||||
|
||||
scratch.visited_list.sort_by_key(|x| -x.1);
|
||||
let n_visited = scratch.visited_list.len();
|
||||
|
||||
let matches = scratch.visited_list.drain(..).map(|(id, score, url, shards, scores)| (score as f32, url, String::new(), 0, None)).collect::<Vec<_>>();
|
||||
let mut similarities_against_self = vec![0.0f32; n_visited * n_visited];
|
||||
|
||||
// runtime deduplicate of results list
|
||||
unsafe {
|
||||
// vecs @ vecs.T
|
||||
matrixmultiply::sgemm(
|
||||
n_visited,
|
||||
index.header.quantizer.n_dims,
|
||||
n_visited,
|
||||
1.0,
|
||||
scratch.visited_embeddings.as_ptr(),
|
||||
index.header.quantizer.n_dims as isize,
|
||||
1,
|
||||
scratch.visited_embeddings.as_ptr(),
|
||||
1,
|
||||
index.header.quantizer.n_dims as isize,
|
||||
0.0,
|
||||
similarities_against_self.as_mut_ptr(),
|
||||
n_visited as isize,
|
||||
1
|
||||
);
|
||||
}
|
||||
|
||||
// discard anything similar to something already in list
|
||||
let mut i = 0;
|
||||
let mut included = bitvec::bitvec![0; n_visited];
|
||||
scratch.visited_list.retain(|_node| {
|
||||
let row = &similarities_against_self[(i * n_visited)..((i + 1) * n_visited)];
|
||||
let old_i = i;
|
||||
i += 1;
|
||||
for (other_i, similarity) in row.iter().enumerate() {
|
||||
if similarity > &DUPLICATES_THRESHOLD && included[other_i] {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
included.set(old_i, true);
|
||||
true
|
||||
});
|
||||
|
||||
scratch.visited_list.sort_unstable_by_key(|x| -x.score);
|
||||
|
||||
let matches = scratch.visited_list
|
||||
.drain(..)
|
||||
.map(|node| {
|
||||
let debug = if body.debug_enabled {
|
||||
Some((node.scores, node.shards, node.timestamp))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
((node.score as f64 / SCALE_F64) as f32, node.image_url, String::new(), 0, Some(node.dimensions), debug)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let result = QueryResult {
|
||||
formats: vec![],
|
||||
@ -438,6 +546,25 @@ impl hyper::service::Service<Request<Incoming>> for Service {
|
||||
.header(hyper::header::CONTENT_TYPE, "text/plain; version=0.0.4")
|
||||
.body(Full::new(Bytes::from(buffer))).unwrap()
|
||||
},
|
||||
(&Method::POST, "/telemetry") => {
|
||||
// TODO refactor
|
||||
let upper = req.body().size_hint().upper().unwrap_or(u64::MAX);
|
||||
if upper > 1000 {
|
||||
let mut resp = Response::new(Full::new(Bytes::from("Body too big")));
|
||||
*resp.status_mut() = hyper::StatusCode::PAYLOAD_TOO_LARGE;
|
||||
return Ok(resp);
|
||||
}
|
||||
|
||||
let whole_body = req.collect().await?.to_bytes();
|
||||
|
||||
let message = serde_json::from_slice::<TelemetryMessage>(&whole_body)?;
|
||||
|
||||
channel.send(message)?;
|
||||
|
||||
Response::builder()
|
||||
.status(StatusCode::NO_CONTENT)
|
||||
.body(Full::new(Bytes::from(""))).unwrap()
|
||||
}
|
||||
(&Method::OPTIONS, "/") => {
|
||||
Response::builder()
|
||||
.status(StatusCode::NO_CONTENT)
|
||||
@ -459,9 +586,9 @@ impl hyper::service::Service<Request<Incoming>> for Service {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_backend_config(clip_server: &Option<String>) -> Result<InferenceServerConfig> {
|
||||
async fn get_backend_config(clip_server: &String) -> Result<InferenceServerConfig> {
|
||||
loop {
|
||||
match query_clip_server(clip_server.as_ref().unwrap(), "/config", Option::<()>::None).await {
|
||||
match query_clip_server(clip_server, "/config", Option::<()>::None).await {
|
||||
Ok(config) => return Ok(config),
|
||||
Err(err) => {
|
||||
tracing::warn!("waiting for clip server: {}", err);
|
||||
@ -471,14 +598,31 @@ async fn get_backend_config(clip_server: &Option<String>) -> Result<InferenceSer
|
||||
}
|
||||
}
|
||||
|
||||
// can't run this as an async task because monoio File API is positional writes only
|
||||
fn telemetry_handler(rx: std::sync::mpsc::Receiver<TelemetryMessage>, config: ServerConfig) -> Result<()> {
|
||||
let mut telemetry_file = std::fs::OpenOptions::new().create(true).create(true).append(true).open(&config.telemetry_file)?;
|
||||
while let Ok(message) = rx.recv() {
|
||||
telemetry_file.write_all(rmp_serde::to_vec(&message)?.as_slice())?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn serve(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
|
||||
let config: ServerConfig = serde_json::from_slice(&std::fs::read(args.config_path.as_ref().unwrap())?)?;
|
||||
|
||||
let (telemetry_channel, telemetry_receiver) = std::sync::mpsc::channel();
|
||||
|
||||
let config_ = config.clone();
|
||||
std::thread::spawn(move || telemetry_handler(telemetry_receiver, config_));
|
||||
|
||||
let service = Service {
|
||||
index,
|
||||
inference_server_config: Rc::new(get_backend_config(&args.clip_server).await?),
|
||||
args: Rc::new(args.clone())
|
||||
inference_server_config: Rc::new(get_backend_config(&config.clip_server).await?),
|
||||
config: Rc::new(config.clone()),
|
||||
telemetry_channel
|
||||
};
|
||||
|
||||
let listener = TcpListener::bind(args.listen_address.as_ref().unwrap())?;
|
||||
let listener = TcpListener::bind(config.listen_address)?;
|
||||
println!("Listening");
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
@ -531,7 +675,7 @@ async fn main() -> Result<()> {
|
||||
n_descriptors: header.descriptor_cdfs.len(),
|
||||
});
|
||||
|
||||
if args.listen_address.is_some() {
|
||||
if args.config_path.is_some() {
|
||||
serve(&args, index).await?;
|
||||
} else {
|
||||
evaluate(&args, index).await?;
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user