mirror of https://github.com/osmarks/meme-search-engine.git
synced 2025-04-27 21:13:11 +00:00

commit ee23b81444 ("release version")
parent 3852d0078d
.gitignore (vendored, 5 additions)

@@ -22,3 +22,8 @@ index
 queries.txt
 *.zst
 .safetensors
+*/static/*.woff2
+flamegraph.svg
+*.jsonl
+*.safetensors
+perf.data
Cargo.lock (generated, 41 additions)

@@ -416,6 +416,18 @@ version = "2.5.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452"
 
+[[package]]
+name = "bitvec"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
+dependencies = [
+ "funty",
+ "radium",
+ "tap",
+ "wyz",
+]
+
 [[package]]
 name = "block-buffer"
 version = "0.10.4"

@@ -1117,6 +1129,12 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "funty"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
+
 [[package]]
 name = "futures-channel"
 version = "0.3.31"

@@ -2046,6 +2064,7 @@ dependencies = [
  "axum",
  "base64 0.22.1",
  "bitcode",
+ "bitvec",
  "bytemuck",
  "candle-core",
  "chrono",

@@ -2067,6 +2086,7 @@ dependencies = [
  "itertools 0.13.0",
  "json5",
  "lazy_static",
+ "matrixmultiply",
  "maud",
  "memmap2",
  "mimalloc",

@@ -2850,6 +2870,12 @@ dependencies = [
  "proc-macro2",
 ]
 
+[[package]]
+name = "radium"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
+
 [[package]]
 name = "rand"
 version = "0.8.5"

@@ -3873,6 +3899,12 @@ dependencies = [
  "version-compare",
 ]
 
+[[package]]
+name = "tap"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
+
 [[package]]
 name = "target-lexicon"
 version = "0.12.16"

@@ -4732,6 +4764,15 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "wyz"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
+dependencies = [
+ "tap",
+]
+
 [[package]]
 name = "yoke"
 version = "0.7.5"
Cargo.toml

@@ -62,6 +62,8 @@ monoio = "0.2"
 hyper = "1"
 monoio-compat = { version = "0.2", features = ["hyper"] }
 http-body-util = "0.1"
+matrixmultiply = "0.3"
+bitvec = "1"
 
 [[bin]]
 name = "reddit-dump"
(file name not preserved)

@@ -9,15 +9,15 @@
 
 <div class="about">
     <p>
-        Welcome to {util.hardConfig.name} by <a href="https://osmarks.net/">osmarks.net Computational Memetics</a>. {util.hardConfig.name} searches images using semantic image/text embedding models. In general, search by thinking of what caption your desired image might have been given by random people on the internet. The model currently in use can read text fairly well and understands moderately abstract properties of images, but is limited to English and case-insensitive.
+        Welcome to {util.hardConfig.name} by <a href="https://osmarks.net/">osmarks.net</a> Computational Memetics. {util.hardConfig.name} searches images using semantic image/text embedding models (refer to <a href="https://arxiv.org/abs/2303.15343">https://arxiv.org/abs/2303.15343</a>, <a href="https://arxiv.org/abs/2103.00020">https://arxiv.org/abs/2103.00020</a>). In general, search by thinking of what caption your desired image might have been given by random people on the internet. The model currently in use can read text fairly well and understands moderately abstract properties of images, but is limited to English and case-insensitive.
     </p>
     <p>
-        Advanced Mode sliders are generated from PCA on the index. The human-readable labels are generated manually by <a href="https://datasets.osmarks.net/components.html">looking at things</a>.
+        "Useful"/"aesthetic"/"meme" sliders are defined based on an approximate extrapolation of my preferences and may not agree with your own opinions. Results are otherwise decided by the inscrutable whims of a billion-parameter neural network. Please note that large-scale internet data may contain things you do not like.
     </p>
     <p>
-        The code is open-source and available on <a href="https://github.com/osmarks/meme-search-engine/">GitHub.</a>
+        The code is open-source and available on <a href="https://github.com/osmarks/meme-search-engine/">GitHub</a>.
     </p>
-    {#if util.hardConfig.telemetryEndpoint}
+    {#if util.hardConfig.telemetry_endpoint}
     <h2>Privacy</h2>
     <p>
         We do not collect personal information. We do collect usage information (associated with a random ID) to improve the ranking algorithms. You can disable this:
@@ -154,7 +154,6 @@
 <div class="right">
     <NavItem page="advanced">Advanced</NavItem>
     <NavItem page="about">About</NavItem>
-    <NavItem page="refine">Refine</NavItem>
 </div>
 </nav>
 

@@ -185,7 +184,11 @@
 <option>+</option>
 <option>-</option>
 </select>
-<input type="range" min="0" max="2" bind:value={term.weight} step="0.01">
+{#if debugEnabled}
+    <input type="number" bind:value={term.weight} step="0.01">
+{:else}
+    <input type="range" min="0" max="2" bind:value={term.weight} step="0.01">
+{/if}
 {#if term.type === "image"}
 <span>{term.file.name}</span>
 {:else if term.type === "text"}

@@ -200,6 +203,14 @@
 </li>
 {/each}
 </ul>
+{#if showDebugSwitch}
+    <div class="ctrlbar">
+        <input type="checkbox" bind:checked={debugEnabled} id="debug" />
+        <label for="debug">Debug</label>
+        {#if debugEnabled}
+        {/if}
+    </div>
+{/if}
 <div class="ctrlbar">
 <input type="search" placeholder="Text Query" on:keydown={handleKey} on:focus={newTextQuery}>
 <button on:click={pickFile}>Image Query</button>

@@ -242,6 +253,9 @@
 let queryCounter = 0
 let config = {}
 
+const showDebugSwitch = localStorage.getItem("debugEnabled") === "true"
+let debugEnabled = false
+
 const newTextQuery = (content=null) => {
     queryTerms.push({ type: "text", weight: 1, sign: "+", text: typeof content === "string" ? content : "" })
     queryTerms = queryTerms

@@ -253,9 +267,14 @@
 
 const runSearch = async () => {
     if (!resultPromise) {
+        const friendlyModeQueryOrRandom = friendlyModeQuery ? [{ text: friendlyModeQuery, weight: 1, sign: "+" }] : [{ embedding: util.randn(config.d_emb, 1 / (config.d_emb ** 0.5)), weight: 1, sign: "+" }]
+        const terms = friendlyMode ?
+            friendlyModeQueryOrRandom.concat(util.hardConfig.friendly_mode_default_terms ?? []) :
+            queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] }))
         let args = {
-            "terms": friendlyMode ? [{ text: friendlyModeQuery, weight: 1, sign: "+" }] : queryTerms.filter(x => x.text !== "").map(x => ({ image: x.imageData, text: x.text, embedding: x.embedding, predefined_embedding: x.predefinedEmbedding, weight: x.weight * { "+": 1, "-": -1 }[x.sign] })),
-            "include_video": true
+            "terms": terms,
+            "include_video": true,
+            "debug_enabled": debugEnabled
         }
 
         util.sendTelemetry("search", {

@@ -314,7 +333,7 @@
 type: "predefined_embedding",
 predefinedEmbedding: predefinedEmbeddingName,
 sign: "+",
-weight: 0.2
+weight: 1
 })
 }
 queryTerms = queryTerms
(file name not preserved)

@@ -36,26 +36,14 @@
 
 const d_emb = 1152
 
-const vecSum = (xs, ys) => xs.map((x, i) => x + ys[i])
-const vecZero = d => new Array(d).fill(0)
-const vecScale = (xs, s) => xs.map(x => x * s)
-
-const boxMuller = () => {
-    let x = Math.random()
-    let y = Math.random()
-    return Math.sqrt(-2.0 * Math.log(x)) * Math.cos(2.0 * Math.PI * y)
-}
-
-const randn = (d, sigma) => Array.from({ length: d }, () => boxMuller() * sigma)
-
 const K = 2
 let candidates = []
 
 const select = async candidate => {
     candidates = []
-    const direction = randn(d_emb, 1 / d_emb)
+    const direction = util.randn(d_emb, 1 / d_emb)
     for (let i = -K; i <= K; i++) {
-        const newV = vecSum(vecScale(direction, i / K), candidate.vector)
+        const newV = util.vecSum(util.vecScale(direction, i / K), candidate.vector)
         candidates.push({ vector: newV, results: null, i: i + K })
     }
     await Promise.all(candidates.map(async x => {

@@ -67,7 +55,7 @@
 console.log(candidates)
 }
 
-select({ vector: randn(d_emb, 1 / d_emb) })
+select({ vector: util.randn(d_emb, 1 / d_emb) })
 
 const handleKey = ev => {
     const num = parseInt(ev.key)
(file name not preserved)

@@ -28,6 +28,10 @@
 {#key `${queryCounter}${result.file}`}
 <div class="result">
 <ResultImage {result} {results} {updateCounter} {redrawGrid} constrainBy="width" />
+{#if result[5]} <!-- debug info -->
+    <div>{result[0]}</div>
+    <div>{JSON.stringify(result[5])}</div>
+{/if}
 </div>
 {/key}
 {/each}

@@ -72,7 +76,7 @@
 export const redrawGrid = async () => {
     if (refreshLayout) {
         refreshLayout()
-        await recomputeScroll()
+        setTimeout(recomputeScroll, 0)
     }
 }
 

@@ -82,7 +86,7 @@
 }
 
 const handleScroll = () => {
-    if (window.scrollY + window.innerHeight >= heightThreshold && pendingImageLoads === 0) {
+    if (window.scrollY + window.innerHeight >= heightThreshold) {
         recomputeScroll()
         if (window.scrollY + window.innerHeight < heightThreshold) return;
         let init = displayedResults.length
(file name not preserved)

@@ -1,4 +1,11 @@
 @use 'sass:color'
 
+@font-face
+    font-family: 'Iosevka'
+    font-style: normal
+    font-weight: 400
+    src: url(./iosevka.woff2) format('woff2')
+    unicode-range: U+0000-00C0, U+A2-A9, U+AC-AE, U+00D7, U+00F7, U+FEFF, U+FFFD
+
 $palette-primary: #3f9b0b
 $palette-secondary: #033500
clipfront2/src/iosevka.woff2 (new binary file; binary file not shown)
(file name not preserved)

@@ -78,3 +78,15 @@ fetch(config.backend_url).then(x => x.json().then(x => {
     serverConfig.set(x)
     window.serverConfig = x
 }))
+
+export const vecSum = (xs, ys) => xs.map((x, i) => x + ys[i])
+export const vecZero = d => new Array(d).fill(0)
+export const vecScale = (xs, s) => xs.map(x => x * s)
+
+const boxMuller = () => {
+    let x = Math.random()
+    let y = Math.random()
+    return Math.sqrt(-2.0 * Math.log(x)) * Math.cos(2.0 * Math.PI * y)
+}
+
+export const randn = (d, sigma) => Array.from({ length: d }, () => boxMuller() * sigma)
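For reference, the boxMuller helper above is the Box–Muller transform, which turns two independent uniform samples into a standard normal one:

\[ Z = \sqrt{-2 \ln U_1}\,\cos(2 \pi U_2), \qquad U_1, U_2 \sim \mathrm{Uniform}(0,1), \]

so randn(d, sigma) yields d independent draws from N(0, σ²). The search page calls it with sigma = 1 / (config.d_emb ** 0.5), i.e. σ = 1/√d, which gives the random query embedding an expected squared norm of d·σ² = 1 — the scale of a unit-normalized embedding.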
(file name not preserved)

@@ -5,7 +5,7 @@
 <meta name="viewport" content="width=device-width, minimum-scale=1, initial-scale=1, user-scalable=yes">
 <meta name="description" content="Organizing the world's memes.">
 
-<title>Meme Search Engine</title>
+<title>Nooscope</title>
 </style>
 <link rel="stylesheet" href="app.css">
 <link rel="icon" type="image/png" href="logo.png">
config2.json (new file, 10 lines)

@@ -0,0 +1,10 @@
+{
+    "listen_address": "127.0.0.1:5601",
+    "clip_server": "http://100.64.0.10:1708",
+    "descriptor_names": [
+        "Useful",
+        "Meme",
+        "Aesthetic",
+        "Time"
+    ]
+}
diskann/chainq.py (new file, 164 lines)

@@ -0,0 +1,164 @@
+import numpy as np
+import msgpack
+import math
+import torch
+from torch import autograd
+import tqdm
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# https://github.com/una-dinosauria/local-search-quantization/blob/master/src/encodings/encode_chain.jl#L2
+# vectorized somewhat
+def viterbi_encode(
+    out_codes: torch.Tensor, # N x M (ints)
+    vectors: torch.Tensor, # N x D
+    codebooks: torch.Tensor # M x H x D
+):
+    N, D = vectors.shape
+    M, H, D2 = codebooks.shape
+    assert D == D2
+
+    # M x H x N - ||x-c||^2 ignoring x.T @ x component
+    unary_costs = -2 * (codebooks @ vectors.T) + (torch.linalg.norm(codebooks, dim=2) ** 2).unsqueeze(-1)
+    binary_costs = torch.zeros(M - 1, H, H, dtype=torch.float, device=DEVICE)
+    for i in range(M - 1):
+        binary_costs[i] = 2 * codebooks[i] @ codebooks[i + 1].T
+    print("binary costs", binary_costs)
+
+    min_cost = torch.zeros(H, N, dtype=torch.float, device=DEVICE)
+    min_idx = torch.zeros(M, H, N, dtype=torch.int, device=DEVICE)
+
+    cost = torch.zeros(H, N, dtype=torch.float, device=DEVICE)
+
+    # forward pass - propagate optimal costs and indices forward
+    for step in tqdm.trange(M - 1):
+        if step > 0:
+            unary_costs[step] += min_cost
+
+        ucost = unary_costs[step]
+
+        # for all possible costs at this step
+        for j in range(H):
+            bcost = binary_costs[step, j].unsqueeze(-1) # independent of N
+            cost = ucost + bcost
+
+            min_values, min_indices = torch.min(cost, dim=0)
+            min_cost[j] = min_values
+            min_idx[step, j] = min_indices
+
+    unary_costs[-1] += min_cost
+
+    # backward pass - propagate optimal indices backwards
+    out_codes[:, -1] = torch.argmin(unary_costs[-1], dim=0)
+    for i in range(M - 2, -1, -1):
+        out_codes[:, i] = min_idx[i][out_codes[:, i + 1], range(N)]
+
+def dims_for(dim, total_dims, m):
+    dims_per_code = total_dims // m
+    relevant_codebooks = [dim // dims_per_code]
+    if relevant_codebooks[-1] < m - 1:
+        relevant_codebooks.append(relevant_codebooks[-1] + 1)
+    return relevant_codebooks
+
+def update_codebooks(transformed_data, codes, h):
+    n, d = transformed_data.shape
+    n2, m = codes.shape
+    assert n == n2
+
+    new_codebook = torch.zeros(m, h, d, dtype=torch.float, device=DEVICE)
+    for dim in tqdm.trange(d):
+        relevant_codebooks = dims_for(dim, d, m)
+        assignment_matrix = torch.zeros(n, len(relevant_codebooks), h, dtype=torch.float, device=DEVICE)
+        indices = (
+            torch.arange(n, dtype=torch.int, device=DEVICE).repeat(len(relevant_codebooks)),
+            torch.arange(len(relevant_codebooks), dtype=torch.int, device=DEVICE).repeat_interleave(n),
+            codes[:, relevant_codebooks].T.flatten()
+        )
+        assignment_matrix[indices] = 1
+        #print(assignment_matrix, assignment_matrix.shape, transformed_data[:, dim], transformed_data[:, dim].shape)
+        assignment_matrix = assignment_matrix.reshape(n, len(relevant_codebooks) * h)
+        #print(assignment_matrix, assignment_matrix.shape, transformed_data[:, dim], transformed_data[:, dim].shape)
+        #soln = torch.linalg.lstsq(assignment_matrix, transformed_data[:, dim])[0]
+
+        reg = 1e-3 * torch.eye(len(relevant_codebooks) * h, device=DEVICE)
+        A = assignment_matrix.T @ assignment_matrix + reg
+        b = assignment_matrix.T @ transformed_data[:, dim]
+
+        #print("matrix", A)
+
+        usage = assignment_matrix.sum(dim=0)
+        unused = usage < 1
+
+        print(unused.sum().detach().item())
+        soln = torch.linalg.solve(A, b)
+
+        #print("solution", soln.reshape(len(relevant_codebooks), h))
+        if unused.any():
+            soln[unused] = torch.randn_like(soln[unused])
+
+        new_codebook[relevant_codebooks, :, dim] = soln.reshape(len(relevant_codebooks), h)
+
+        if torch.isnan(new_codebook[relevant_codebooks, :, dim]).any():
+            print("oh no", dim, new_codebook, relevant_codebooks, new_codebook[relevant_codebooks, :, dim])
+            print("--- dim ---", dim)
+            print("- sum per column:", assignment_matrix.sum(dim=0)) # Check if any columns are all zero
+            print("- rank:", torch.linalg.matrix_rank(assignment_matrix))
+            print("- condition number:", torch.linalg.cond(assignment_matrix))
+            raise SystemExit
+
+    return new_codebook
+
+BATCH = 8192
+
+def train_chainq(vectors, m, h, transform, codebooks, n_iters):
+    for i in range(n_iters):
+        transformed_data = vectors @ transform.T
+        codes = torch.zeros(vectors.shape[0], m, dtype=torch.int, device=DEVICE)
+        for i in range(0, vectors.shape[0], BATCH):
+            viterbi_encode(codes[i:i+BATCH], transformed_data[i:i+BATCH], codebooks)
+        print("encoded")
+        #codebooks = update_codebooks(transformed_data, codes, h)
+        print("codebooks updated")
+
+        quantized = torch.zeros_like(vectors, dtype=torch.float, device=DEVICE)
+        for j in range(m):
+            quantized[:] += codebooks[j, codes[:, j]]
+        print("quantized")
+        print((quantized - transformed_data).abs().mean(), transformed_data.abs().mean())
+
+        print("comparing")
+        res = transformed_data.T @ quantized
+
+        print("running SVD...")
+        u, s, vt = torch.linalg.svd(res)
+        print("done.")
+        transform = u @ vt
+        print("regenerated transform")
+
+    return codebooks, transform
+
+with open("opq.msgpack", "rb") as f:
+    data = msgpack.unpackb(f.read())
+
+n_dims = 1152
+dataset = torch.tensor(np.random.permutation(np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32))[:BATCH*1], device=DEVICE)
+
+codebooks = torch.zeros(64, 256, n_dims, dtype=torch.float, device=DEVICE)
+
+centroids = torch.tensor(np.array(data["centroids"]).astype(np.float32).reshape(256, n_dims), device=DEVICE)
+for dim in range(n_dims):
+    relevant_codebooks = dim // 64
+    codebooks[relevant_codebooks, :, dim] = centroids[:, dim]
+
+print(centroids)
+#print("codebooks", codebooks.tolist())
+
+codebooks, transform = train_chainq(dataset, 64, 256, torch.tensor(np.array(data["transform"]).astype(np.float32).reshape(n_dims, n_dims), device=DEVICE), codebooks, 100)
+
+with open("chainq.msgpack", "wb") as f:
+    msgpack.pack({
+        "codebooks": codebooks.cpu().numpy().flatten().tolist(),
+        "transform": transform.cpu().numpy().flatten().tolist(),
+        "n_dims": n_dims,
+        "n_dims_per_code": n_dims // 64
+    }, f)
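For context, the encoding step solves a chain-structured assignment problem: given codebooks C_1, …, C_M with H codewords each, choose codes b_1, …, b_M minimizing

\[ \Big\| x - \sum_{m=1}^{M} C_m[b_m] \Big\|^2 = \|x\|^2 + \sum_m \big( \|C_m[b_m]\|^2 - 2 \langle x, C_m[b_m] \rangle \big) + 2 \sum_{m<m'} \langle C_m[b_m], C_{m'}[b_{m'}] \rangle . \]

Because dims_for gives each dimension support in at most two adjacent codebooks, the inner products vanish for |m − m'| > 1, so the cross terms reduce to consecutive pairs — the unary_costs and binary_costs tensors above — and the minimum can be found exactly by Viterbi dynamic programming in O(M·H²) per vector.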
(file name not preserved)

@@ -45,7 +45,7 @@ impl Vector {
 }
 
 // Floats are vaguely annoying and not sortable (trivially), so we mostly represent dot products as integers
-pub const SCALE: f32 = 1099511627776.0;
+pub const SCALE: f32 = 4294967296.0;
 pub const SCALE_F64: f64 = SCALE as f64;
 
 pub fn dot<'a>(x: VectorRef<'a>, y: VectorRef<'a>) -> i64 {
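Both constants are powers of two: the old SCALE was 2^40, the new one is 2^32. A minimal sketch of the fixed-point idea in Python (illustrative only, not the repository's Rust):

SCALE = 2.0 ** 32  # dot products of unit vectors lie in [-1, 1], so scaled values fit comfortably in an i64

def dot_fixed(x, y):
    # represent a float dot product as an integer so results sort exactly and cheaply
    return round(sum(a * b for a, b in zip(x, y)) * SCALE)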
faiss_bench_quantizer.py (new file, 131 lines)

@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import faiss
+import time
+import numpy as np
+
+def eval_codec(q, xb):
+    t0 = time.time()
+    codes = q.compute_codes(xb)
+    t1 = time.time()
+    xb_decoded = q.decode(codes)
+    recons_err = ((xb - xb_decoded) ** 2).sum(axis=1).mean()
+    print(f"\tencode time: {t1 - t0:.3f} reconstruction error: {recons_err:.3f} ")
+
+
+def eval_quantizer(q, xb, xt, variants=None):
+    if variants is None:
+        variants = [(None, None)]
+    t0 = time.time()
+    q.train(xt)
+    t1 = time.time()
+    train_t = t1 - t0
+    print(f'\ttraining time: {train_t:.3f} s')
+    for name, val in variants:
+        if name is not None:
+            print(f"{name}={val}")
+
+            if isinstance(q, faiss.ProductAdditiveQuantizer):
+                for i in range(q.nsplits):
+                    subq = faiss.downcast_Quantizer(q.subquantizer(i))
+                    getattr(subq, name)
+                    setattr(subq, name, val)
+            else:
+                getattr(q, name) # make sure field exists
+                setattr(q, name, val)
+
+        eval_codec(q, xb)
+
+
+todo = sys.argv[1:]
+
+ds = np.fromfile(todo[0], dtype=np.float16).reshape(-1, 1152).astype(np.float32)
+print(ds)
+del todo[0]
+
+if len(todo) > 0:
+    if todo[0].count("x") == 1:
+        M, nbits = [int(x) for x in todo[0].split("x")]
+        del todo[0]
+    elif todo[0].count("x") == 2:
+        nsplits, Msub, nbits = [int(x) for x in todo[0].split("x")]
+        M = nsplits * Msub
+        del todo[0]
+
+maxtrain = max(100 << nbits, 10**5)
+print(f"eval on {M}x{nbits} maxtrain={maxtrain}")
+
+xb, xt = ds, ds
+
+nb, d = xb.shape
+nt, d = xt.shape
+
+
+# fastest to slowest
+
+if 'lsq-gpu' in todo:
+    lsq = faiss.LocalSearchQuantizer(d, M, nbits)
+    ngpus = faiss.get_num_gpus()
+    lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
+    lsq.verbose = True
+    eval_quantizer(lsq, xb, xt, 'lsq-gpu')
+
+if 'pq' in todo:
+    pq = faiss.ProductQuantizer(d, M, nbits)
+    print("===== PQ")
+    eval_quantizer(pq, xb, xt)
+
+if 'opq' in todo:
+    d2 = ((d + M - 1) // M) * M
+    print("OPQ d2=", d2)
+    opq = faiss.OPQMatrix(d, M, d2)
+    opq.train(xt)
+    xb2 = opq.apply(xb)
+    xt2 = opq.apply(xt)
+    pq = faiss.ProductQuantizer(d2, M, nbits)
+    print("===== PQ")
+    eval_quantizer(pq, xb2, xt2)
+
+if 'prq' in todo:
+    print(f"===== PRQ{nsplits}x{Msub}x{nbits}")
+    prq = faiss.ProductResidualQuantizer(d, nsplits, Msub, nbits)
+    variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32)]
+    eval_quantizer(prq, xb, xt, variants=variants)
+
+if 'plsq' in todo:
+    print(f"===== PLSQ{nsplits}x{Msub}x{nbits}")
+    plsq = faiss.ProductLocalSearchQuantizer(d, nsplits, Msub, nbits)
+    variants = [("encode_ils_iters", i) for i in (2, 3, 4, 8, 16)]
+    eval_quantizer(plsq, xb, xt, variants=variants)
+
+if 'rq' in todo:
+    print("===== RQ")
+    rq = faiss.ResidualQuantizer(d, M, nbits, )
+    rq.max_beam_size
+    rq.max_beam_size = 30 # for compatibility with older runs
+    # rq.train_type = faiss.ResidualQuantizer.Train_default
+    # rq.verbose = True
+    variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32)]
+    eval_quantizer(rq, xb, xt, variants=variants)
+
+if 'rq_lut' in todo:
+    print("===== RQ")
+    rq = faiss.ResidualQuantizer(d, M, nbits, )
+    rq.max_beam_size
+    rq.max_beam_size = 30 # for compatibility with older runs
+    rq.use_beam_LUT
+    rq.use_beam_LUT = 1
+    # rq.train_type = faiss.ResidualQuantizer.Train_default
+    # rq.verbose = True
+    variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32, 64)]
+    eval_quantizer(rq, xb, xt, variants=variants)
+
+if 'lsq' in todo:
+    print("===== LSQ")
+    lsq = faiss.LocalSearchQuantizer(d, M, nbits)
+    variants = [("encode_ils_iters", i) for i in (2, 3, 4, 8, 16)]
+    eval_quantizer(lsq, xb, xt, variants=variants)
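Judging from the argument parsing above, an invocation presumably looks like `python faiss_bench_quantizer.py embeddings.bin 64x8 pq opq rq`: a raw float16 embedding file, an MxNBITS (or NSPLITSxMSUBxNBITS) size specification, then the names of the quantizers to benchmark.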
(file name not preserved)

@@ -1,8 +1,12 @@
 {
-    "backend_url": "http://localhost:5601/",
+    "backend_url": "/backend",
     "image_path": "",
     "thumb_path": null,
     "description": "Organizing the world's memes.",
     "name": "Nooscope",
-    "telemetry_endpoint": "/telemetry"
+    "telemetry_endpoint": "/telemetry",
+    "friendly_mode_default_terms": [
+        { "predefined_embedding": "Meme", "weight": 1, "sign": "+" },
+        { "predefined_embedding": "Aesthetic", "weight": 0.2, "sign": "+" }
+    ]
 }
genseahash.py (new file, 4 lines)

@@ -0,0 +1,4 @@
+import seahash, sys
+
+with open(sys.argv[1], "rb") as f:
+    print(seahash.hash(f.read()))
(file name not preserved)

@@ -31,6 +31,7 @@ model.load_state_dict(torch.load(modelc))
 params = sum(p.numel() for p in model.parameters())
 print(f"{params/1e6:.1f}M parameters")
 print(model)
+model.eval()
 
 files = shared.fetch_all_files()
 variance = {}

@@ -56,4 +57,4 @@ with torch.inference_mode():
 
 top = sorted(variance.items(), key=lambda x: -x[1])
 with open("top.json", "w") as f:
-    json.dump(top[:100], f)
+    json.dump(top[:50], f)

(file name not preserved)

@@ -52,7 +52,7 @@ channel = int(sys.argv[2])
 percentile = float(sys.argv[3])
 output_pairs = int(sys.argv[4])
 mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
-top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
+top = sorted(((filename, score) for filename, score in results.items()), key=lambda x: x[1][channel], reverse=True)
 select_from = top[:int(len(top) * percentile)]
 
 out = []

(file name not preserved)

@@ -13,14 +13,14 @@ import sys
 from model import Config, BradleyTerry
 import shared
 
-batch_size = 128
+batch_size = 32
 num_pairs = batch_size * 1024
 device = "cuda"
 
 config = Config(
     d_emb=1152,
     n_hidden=1,
-    n_ensemble=1,
+    n_ensemble=16,
     device=device,
     dtype=torch.float32,
     output_channels=3,

(file name not preserved)

@@ -32,6 +32,7 @@ model.load_state_dict(torch.load(modelc))
 params = sum(p.numel() for p in model.parameters())
 print(f"{params/1e6:.1f}M parameters")
 print(model)
+model.eval()
 
 files = shared.fetch_all_files()
 importance = {}

@@ -41,8 +42,8 @@ buffers = {k: v.detach() for k, v in model.named_buffers()}
 
 # https://pytorch.org/tutorials/intermediate/per_sample_grads.html
 def compute_loss(params, buffers, sample, target):
-    batch = sample.unsqueeze(0)
-    targets = target.unsqueeze(0)
+    batch = sample.unsqueeze(1)
+    targets = target.unsqueeze(1).expand((config.n_ensemble, 1, config.output_channels))
 
     predictions = functional_call(model, (params, buffers), (batch,))
     loss = F.binary_cross_entropy(predictions, targets)

@@ -75,4 +76,4 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
 
 top = sorted(importance.items(), key=lambda x: -x[1])
 with open("top.json", "w") as f:
-    json.dump(top[:256], f)
+    json.dump(top[:50], f)
|
|||||||
print(f"{params/1e6:.1f}M parameters")
|
print(f"{params/1e6:.1f}M parameters")
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
|
torch.random.manual_seed(1)
|
||||||
|
|
||||||
|
for x in model.ensemble.models:
|
||||||
|
x.output.bias.data.fill_(0)
|
||||||
|
|
||||||
out_layers = []
|
out_layers = []
|
||||||
out_bias = []
|
out_bias = []
|
||||||
|
|
||||||
@ -50,7 +55,7 @@ with torch.inference_mode():
|
|||||||
downprojection[:, i*config.d_emb:(i+1)*config.d_emb] = model.ensemble.models[i].output.weight.data.clone()
|
downprojection[:, i*config.d_emb:(i+1)*config.d_emb] = model.ensemble.models[i].output.weight.data.clone()
|
||||||
|
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
input = torch.randn(3, config.d_emb)
|
input = torch.randn(4, config.d_emb)
|
||||||
ground_truth_result = model.ensemble(input.unsqueeze(0).expand((config.n_ensemble, *input.shape))).mean(dim=0).T
|
ground_truth_result = model.ensemble(input.unsqueeze(0).expand((config.n_ensemble, *input.shape))).mean(dim=0).T
|
||||||
r_result = input
|
r_result = input
|
||||||
for (layer, bias) in zip(out_layers, out_bias):
|
for (layer, bias) in zip(out_layers, out_bias):
|
||||||
@ -58,13 +63,14 @@ with torch.inference_mode():
|
|||||||
print(r_result.shape, bias.shape)
|
print(r_result.shape, bias.shape)
|
||||||
r_result = F.silu(r_result)
|
r_result = F.silu(r_result)
|
||||||
r_result = torch.matmul(downprojection, r_result) / config.n_ensemble
|
r_result = torch.matmul(downprojection, r_result) / config.n_ensemble
|
||||||
bias_diff = torch.mean(r_result - ground_truth_result)
|
error = torch.mean(r_result - ground_truth_result)
|
||||||
model_issue_diff = torch.max(torch.abs(r_result - bias_diff - ground_truth_result))
|
print(error)
|
||||||
print(model_issue_diff)
|
assert error.detach().cpu().numpy() < 1e-4
|
||||||
|
|
||||||
print("test vector:")
|
print("test vector:")
|
||||||
print(input.flatten().tolist())
|
#print(input.flatten().tolist())
|
||||||
print("ground truth result:")
|
print("ground truth result:")
|
||||||
|
print(ground_truth_result.shape)
|
||||||
print(ground_truth_result.T.flatten().tolist())
|
print(ground_truth_result.T.flatten().tolist())
|
||||||
|
|
||||||
save_file({
|
save_file({
|
||||||
|
(file name not preserved)

@@ -35,7 +35,7 @@ class Ensemble(nn.Module):
 
     # model batch
     def forward(self, embs):
-        return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
+        return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_channels
 
 class BradleyTerry(nn.Module):
     def __init__(self, config):
(file name not preserved)

@@ -51,19 +51,25 @@ async def index(request):
 <form action="/rate" method="POST">
 <table>
 <tr>
+<td><input type="radio" name="rating-useful" value="1+" id="rq1p"> <label for="rq1p">LHS is much better (useful)</label></td>
 <td><input type="radio" name="rating-useful" value="1" id="rq1"> <label for="rq1">LHS is better (useful)</label></td>
 <td><input type="radio" name="rating-useful" value="eq" id="rqe"> <label for="rqe">Tie</label></td>
 <td><input type="radio" name="rating-useful" value="2" id="rq2"> <label for="rq2">RHS is better (useful)</label></td>
+<td><input type="radio" name="rating-useful" value="2+" id="rq2p"> <label for="rq2p">LHS is much better (useful)</label></td>
 </tr>
 <tr>
+<td><input type="radio" name="rating-meme" value="1+" id="rm1p"> <label for="rm1p">LHS is much better (memetically)</label></td>
 <td><input type="radio" name="rating-meme" value="1" id="rm1"> <label for="rm1">LHS is better (memetically)</label></td>
 <td><input type="radio" name="rating-meme" value="eq" id="rme"> <label for="rme">Tie</label></td>
 <td><input type="radio" name="rating-meme" value="2" id="rm2"> <label for="rm2">RHS is better (memetically)</label></td>
+<td><input type="radio" name="rating-meme" value="2+" id="rm2p"> <label for="rm2p">LHS is much better (memetically)</label></td>
 </tr>
 <tr>
+<td><input type="radio" name="rating-aesthetic" value="1+" id="ra1p"> <label for="ra1p">LHS is much better (aesthetically)</label></td>
 <td><input type="radio" name="rating-aesthetic" value="1" id="ra1"> <label for="ra1">LHS is better (aesthetically)</label></td>
 <td><input type="radio" name="rating-aesthetic" value="eq" id="rae"> <label for="rae">Tie</label></td>
 <td><input type="radio" name="rating-aesthetic" value="2" id="ra2"> <label for="ra2">RHS is better (aesthetically)</label></td>
+<td><input type="radio" name="rating-aesthetic" value="2+" id="ra2p"> <label for="ra2p">LHS is much better (aesthetically)</label></td>
 </td>
 </table>
 

@@ -82,35 +88,29 @@ async def index(request):
 document.querySelector("form").submit()
 }}
 }}
+const keys = {{
+    "q": "rq1p",
+    "w": "rq1",
+    "e": "rqe",
+    "r": "rq2",
+    "t": "rq2p",
+    "a": "rm1p",
+    "s": "rm1",
+    "d": "rme",
+    "f": "rm2",
+    "g": "rm2p",
+    "z": "ra1p",
+    "x": "ra1",
+    "c": "rae",
+    "v": "ra2",
+    "b": "ra2p",
+}}
 document.addEventListener("keypress", function(event) {{
-if (event.key === "q") {{
-document.querySelector("input[name='rating-useful'][value='1']").checked = true
-commitIfReady()
-}} else if (event.key === "w") {{
-document.querySelector("input[name='rating-useful'][value='eq']").checked = true
-commitIfReady()
-}} else if (event.key === "e") {{
-document.querySelector("input[name='rating-useful'][value='2']").checked = true
-commitIfReady()
-}} else if (event.key === "a") {{
-document.querySelector("input[name='rating-meme'][value='1']").checked = true
-commitIfReady()
-}} else if (event.key === "s") {{
-document.querySelector("input[name='rating-meme'][value='eq']").checked = true
-commitIfReady()
-}} else if (event.key === "d") {{
-document.querySelector("input[name='rating-meme'][value='2']").checked = true
-commitIfReady()
-}} else if (event.key === "z") {{
-document.querySelector("input[name='rating-aesthetic'][value='1']").checked = true
-commitIfReady()
-}} else if (event.key === "x") {{
-document.querySelector("input[name='rating-aesthetic'][value='eq']").checked = true
-commitIfReady()
-}} else if (event.key === "c") {{
-document.querySelector("input[name='rating-aesthetic'][value='2']").checked = true
-commitIfReady()
+const key = keys[event.key]
+if (key) {{
+document.getElementById(key).checked = true
 }}
+commitIfReady()
 }});
 </script>
 </body>
(file name not preserved)

@@ -20,27 +20,35 @@ def fetch_embedding(filename):
     csr.close()
     return x.copy() # PyTorch complains otherwise due to bad
 
-def map_rating(rating, uncertainty=0.05):
+def map_rating(rating):
     def map_one(rating):
         match rating:
             case "1": # meme 1 is better
-                return 1 - uncertainty
+                return 0.9
             case "2":
-                return uncertainty
+                return 0.1
+            case "2+" | "2p":
+                return 0.3
+            case "1+" | "1p":
+                return 0.7
             case "eq":
                 return 0.5
             case _: raise ValueError("invalid rating, please fix")
 
     return np.array([map_one(r) for r in rating.split(",")])
 
-def fetch_ratings():
+def fetch_ratings(sets):
     trains = defaultdict(list)
     validations = defaultdict(list)
     csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
+    its = set()
     for meme1, meme2, rating, iteration in csr.fetchall():
+        if iteration not in its:
+            print(iteration)
+            its.add(iteration)
         (validations if is_val_set(meme1, meme2) else trains)[int((iteration and iteration.split("-")[0]) or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
     csr.close()
-    return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))
+    return list(x[1] for x in sorted(trains.items()) if str(x[0]) in sets), list(x[1] for x in sorted(validations.items()) if str(x[0]) in sets)
 
 def generate_random_permutations(x, n):
     out = []
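For context, the values map_rating produces are soft targets for a Bradley–Terry preference model: with latent scores s_1 and s_2 for the two memes, the modelled probability that the first is preferred is

\[ P(1 \succ 2) = \sigma(s_1 - s_2) = \frac{1}{1 + e^{-(s_1 - s_2)}}, \]

and training minimizes binary cross-entropy between this probability and the mapped rating.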
(file name not preserved)

@@ -9,11 +9,12 @@ import time
 from tqdm import tqdm
 import math
 from dataclasses import dataclass, asdict
+import sys
 
 from model import Config as ModelConfig, BradleyTerry
 import shared
 
-trains, validations = shared.fetch_ratings()
+trains, validations = shared.fetch_ratings(sys.argv[1:])
 for train, validation in zip(trains, validations):
     print(len(train), len(validation))
 

@@ -36,11 +37,11 @@ config = TrainConfig(
     n_ensemble=16,
     device=device,
     dtype=torch.float32,
-    dropout=0.1,
+    dropout=0.0,
     output_channels=3
 ),
 lr=3e-4,
-weight_decay=0.2,
+weight_decay=0.0,
 batch_size=1,
 epochs=5,
 compile=False,
(file name not preserved)

@@ -1,10 +1,10 @@
 import torch
-import pyarrow as pa
+import numpy as np
 
 torch.set_float32_matmul_precision("high")
 
-with pa.memory_map("../../sample_1m.arrow", "r") as source:
-    loaded_arrays = pa.ipc.open_file(source).read_all()
+loaded_arrays = np.memmap("embeddings.bin", dtype=np.float16).reshape(-1, 1152)
+loaded_arrays_permutation = np.random.permutation(len(loaded_arrays))
 
 train_split = 0.8
 
(file name not preserved)

@@ -6,9 +6,10 @@ import json
 import time
 from tqdm import tqdm
 from dataclasses import dataclass, asdict
+import math
 
 from model import SAEConfig, SAE
-from shared import train_split, loaded_arrays, ckpt_path
+from shared import train_split, loaded_arrays, ckpt_path, loaded_arrays_permutation
 
 device = "cuda"
 

@@ -24,7 +25,7 @@ class TrainConfig:
 config = TrainConfig(
     model=SAEConfig(
         d_emb=1152,
-        d_hidden=65536,
+        d_hidden=262144,
         top_k=128,
         device=device,
         dtype=torch.float32,

@@ -33,7 +34,7 @@ config = TrainConfig(
     lr=3e-4,
     weight_decay=0.0,
     batch_size=64,
-    epochs=5,
+    epochs=1,
     compile=True,
 )
 

@@ -81,7 +82,7 @@ with open(logfile, "w") as log:
     batch = []
     t = tqdm(range(0, int(len(loaded_arrays) * train_split), config.batch_size))
     for batch_start in t:
-        batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in loaded_arrays["embedding"][batch_start:batch_start + config.batch_size] ])
+        batch = numpy.stack([ embedding for embedding in loaded_arrays[loaded_arrays_permutation[batch_start:batch_start + config.batch_size]] ])
 
         if len(batch) == config.batch_size:
             batch = torch.Tensor(batch).to(device)
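The config describes a top-k sparse autoencoder (d_hidden now 262144 with top_k 128). A generic sketch of that technique's forward pass, assuming conventional encoder/decoder parameters rather than the repository's actual model.py:

import torch

def topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k=128):
    acts = torch.relu(x @ W_enc + b_enc)                     # (batch, d_hidden) dense activations
    vals, idx = torch.topk(acts, k, dim=-1)                  # keep the k largest activations per sample
    sparse = torch.zeros_like(acts).scatter_(-1, idx, vals)  # zero out the rest
    return sparse @ W_dec + b_dec                            # reconstruct the embedding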
slow_dump_parse_script.py (new file, 35 lines)

@@ -0,0 +1,35 @@
+import umsgpack
+import zstandard
+import pyarrow as pa
+
+data = []
+
+with open("sample.zst", "rb") as f:
+    decomp = zstandard.ZstdDecompressor()
+    reader = decomp.stream_reader(f)
+    count = 0
+    while True:
+        try:
+            url, id, title, subreddit, author, timestamp, embedding = umsgpack.unpack(reader)
+            embedding = bytes(embedding)
+            data.append({"url": url, "id": id, "title": title, "subreddit": subreddit, "author": author, "timestamp": timestamp, "embedding": embedding})
+            count += 1
+        except umsgpack.InsufficientDataException:
+            break
+    print(count)
+
+schema = pa.schema([
+    ("url", pa.string()),
+    ("id", pa.string()),
+    ("title", pa.string()),
+    ("subreddit", pa.string()),
+    ("author", pa.string()),
+    ("timestamp", pa.int64()),
+    ("embedding", pa.binary())
+])
+
+table = pa.Table.from_pylist(data, schema=schema)
+
+with pa.OSFile("output.parquet", "wb") as sink:
+    with pa.RecordBatchFileWriter(sink, schema) as writer:
+        writer.write_table(table)
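Note that pa.RecordBatchFileWriter emits the Arrow IPC file format despite the .parquet file name; writing actual Parquet would instead go through pyarrow.parquet, e.g. (reusing the table built above):

import pyarrow.parquet as pq
pq.write_table(table, "output.parquet")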
(file name not preserved)

@@ -183,8 +183,8 @@ pub struct FrontendInit {
 pub type EmbeddingVector = Vec<f32>;
 
 #[derive(Debug, Serialize)]
-pub struct QueryResult {
-    pub matches: Vec<(f32, String, String, u64, Option<(u32, u32)>)>,
+pub struct QueryResult<T> {
+    pub matches: Vec<(f32, String, String, u64, Option<(u32, u32)>, Option<T>)>,
     pub formats: Vec<String>,
     pub extensions: HashMap<String, String>,
 }
|
|||||||
pub terms: Vec<QueryTerm>,
|
pub terms: Vec<QueryTerm>,
|
||||||
pub k: Option<usize>,
|
pub k: Option<usize>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub include_video: bool
|
pub include_video: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
pub debug_enabled: bool
|
||||||
}
|
}
|
||||||
|
|
||||||
lazy_static::lazy_static! {
|
lazy_static::lazy_static! {
|
||||||
@ -238,8 +240,9 @@ pub async fn get_total_embedding<A: Future<Output = Result<Vec<Vec<u8>>>>, B: Fu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Some(name) = &term.predefined_embedding {
|
if let Some(name) = &term.predefined_embedding {
|
||||||
let embedding = predefined_embeddings.get(name).context("name invalid")?;
|
if let Some(embedding) = predefined_embeddings.get(name) {
|
||||||
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
|
total_embedding = total_embedding + embedding * term.weight.unwrap_or(1.0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,6 +69,10 @@ struct CLIArguments {
|
|||||||
gpu: Option<usize>,
|
gpu: Option<usize>,
|
||||||
#[argh(option, description="descriptor CDFs")]
|
#[argh(option, description="descriptor CDFs")]
|
||||||
cdfs: Option<String>,
|
cdfs: Option<String>,
|
||||||
|
#[argh(option, description="postfilter by embedding (late discard if dot product above threshold)")]
|
||||||
|
postfilter: Vec<String>,
|
||||||
|
#[argh(option, description="postfilter by scorer")]
|
||||||
|
postfilter_scorer: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, Debug)]
|
#[derive(Clone, Deserialize, Serialize, Debug)]
|
||||||
@ -162,10 +166,22 @@ fn main() -> Result<()> {
|
|||||||
} else {
|
} else {
|
||||||
(snd, None)
|
(snd, None)
|
||||||
};
|
};
|
||||||
|
let mut post_threshold = None;
|
||||||
|
for x in &args.postfilter {
|
||||||
|
let (tname, snd) = x.split_once(':').context("invalid postfilter argument")?;
|
||||||
|
if tname == name {
|
||||||
|
post_threshold = Some(snd.parse::<f32>().context("parse postfilter threshold")?);
|
||||||
|
}
|
||||||
|
}
|
||||||
let blob = fs::read(path).context("read embedding")?;
|
let blob = fs::read(path).context("read embedding")?;
|
||||||
embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512), threshold));
|
embeddings.push((name.to_string(), common::decode_fp16_buffer(&blob), Histogram::new(-1.0, 1.0, 512), threshold, post_threshold));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let postfilter_scorer = args.postfilter_scorer.iter().map(|x| {
|
||||||
|
let (id, snd) = x.split_once(':').context("invalid postfilter scorer argument")?;
|
||||||
|
Ok((id.parse::<usize>().context("invalid postfilter scorer id")?, snd.parse::<f32>().context("parse postfilter scorer threshold")?))
|
||||||
|
}).collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
let pq_codec = if let Some(pq_codec) = args.pq_codec {
|
let pq_codec = if let Some(pq_codec) = args.pq_codec {
|
||||||
let data = fs::read(pq_codec).context("read pq codec")?;
|
let data = fs::read(pq_codec).context("read pq codec")?;
|
||||||
let pq_codec: ProductQuantizer = rmp_serde::from_read(&data[..]).context("decode pq codec")?;
|
let pq_codec: ProductQuantizer = rmp_serde::from_read(&data[..]).context("decode pq codec")?;
|
||||||
@ -255,7 +271,6 @@ fn main() -> Result<()> {
|
|||||||
files.sort_by_key(|(id, _)| *id);
|
files.sort_by_key(|(id, _)| *id);
|
||||||
shard_id_mappings.sort_by_key(|(id, _)| *id);
|
shard_id_mappings.sort_by_key(|(id, _)| *id);
|
||||||
|
|
||||||
|
|
||||||
let read_out_vertices = move |id: u32| -> Result<(Vec<u32>, Vec<u32>)> {
|
let read_out_vertices = move |id: u32| -> Result<(Vec<u32>, Vec<u32>)> {
|
||||||
let mut out_vertices: Vec<u32> = vec![];
|
let mut out_vertices: Vec<u32> = vec![];
|
||||||
let mut shards: Vec<u32> = vec![];
|
let mut shards: Vec<u32> = vec![];
|
||||||
@ -323,6 +338,8 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let th = std::thread::spawn(move || reader_thread(&args.paths, tx));
|
let th = std::thread::spawn(move || reader_thread(&args.paths, tx));
|
||||||
|
|
||||||
|
let mut postfilter_count = 0;
|
||||||
|
|
||||||
let mut rng2 = rng.fork();
|
let mut rng2 = rng.fork();
|
||||||
let initial_filter = |x: ProcessedEntry| {
|
let initial_filter = |x: ProcessedEntry| {
|
||||||
i += 1;
|
i += 1;
|
||||||
@ -337,7 +354,9 @@ fn main() -> Result<()> {
|
|||||||
latest_timestamp = latest_timestamp.max(timestamp);
|
latest_timestamp = latest_timestamp.max(timestamp);
|
||||||
earliest_timestamp = earliest_timestamp.min(timestamp);
|
earliest_timestamp = earliest_timestamp.min(timestamp);
|
||||||
|
|
||||||
for (_name, vec, histogram, threshold) in &mut embeddings {
|
let mut postfilter = false;
|
||||||
|
|
||||||
|
for (_name, vec, histogram, threshold, postfilter_threshold) in &mut embeddings {
|
||||||
let dot = SpatialSimilarity::dot(&embedding, vec).unwrap() as f32;
|
let dot = SpatialSimilarity::dot(&embedding, vec).unwrap() as f32;
|
||||||
histogram.add(dot);
|
histogram.add(dot);
|
||||||
if let Some(threshold) = threshold {
|
if let Some(threshold) = threshold {
|
||||||
@@ -345,6 +364,12 @@ fn main() -> Result<()> {
                     return None;
                 }
             }
+            if let Some(threshold) = postfilter_threshold {
+                if dot >= *threshold {
+                    postfilter = true;
+                    postfilter_count += 1; // somewhat wrong because could be duplicated
+                }
+            }
         }
 
         // distance thresholding is too costly to do over a long range so just do it badly
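The added block flags an entry for postfiltering when its dot product against any configured postfilter embedding reaches that embedding's threshold. The same decision in isolation, as a sketch (in the commit it runs inline so the dot products also feed the histograms):

// Sketch, not the commit's code: true if `embedding` matches any
// (filter vector, threshold) pair by plain dot product.
fn should_postfilter(embedding: &[f32], filters: &[(Vec<f32>, f32)]) -> bool {
    filters.iter().any(|(filter_vec, threshold)| {
        let dot: f32 = embedding.iter().zip(filter_vec.iter()).map(|(a, b)| a * b).sum();
        dot >= *threshold
    })
}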
@@ -393,7 +418,7 @@ fn main() -> Result<()> {
             println!("{}", data);
         }
 
-        Some((x, embedding))
+        Some((x, embedding, postfilter))
     };
 
     let mut dead_count = 0;
@@ -404,14 +429,14 @@ fn main() -> Result<()> {
             let batch: Vec<_> = batch.collect();
             let batch_len = batch.len();
 
-            for (x, _embedding) in batch.iter() {
+            for (x, _embedding, _postfilter) in batch.iter() {
                 if let Some(ref mut file) = output_file {
                     file.write_all(&x.embedding)?;
                 }
             }
 
             if let Some(shards) = &mut shards_out {
-                for (i, (x, embedding)) in batch.iter().enumerate() {
+                for (i, (x, embedding, _postfilter)) in batch.iter().enumerate() {
                     // closest matches first
                     shards.sort_by_cached_key(|&(ref centroid, _, shard_count, _shard_index)| {
                         let mut dot = SpatialSimilarity::dot(&centroid, &embedding).unwrap();
@@ -439,7 +464,7 @@ fn main() -> Result<()> {
                 let quantizer = pq_codec.as_ref().context("PQ codec needed to output index")?;
 
                 let mut batch_embeddings = Vec::with_capacity(batch.len() * D_EMB as usize);
-                for (_x, embedding) in batch.iter() {
+                for (_x, embedding, _postfilter) in batch.iter() {
                     batch_embeddings.extend_from_slice(&embedding);
                 }
                 let codes = quantizer.quantize_batch(&batch_embeddings);
@@ -448,7 +473,7 @@ fn main() -> Result<()> {
                 let cdfs = cdfs.as_ref().context("score model CDFs needed to output index")?;
                 let scores = score_model.score_batch(&batch_embeddings)?;
 
-                for (i, (x, _embedding)) in batch.into_iter().enumerate() {
+                for (i, (x, _embedding, mut postfilter)) in batch.into_iter().enumerate() {
                     let (vertices, shards) = read_out_vertices(count + i as u32)?; // TODO: could parallelize this given the batching
 
                     let mut entry_scores = scores[(i * score_model.output_channels)..((i + 1) * score_model.output_channels)].to_vec();
@@ -465,6 +490,13 @@ fn main() -> Result<()> {
                         index_output_file.2.write_all(&[cdf_bucket])?;
                     }
 
+                    for (index, score) in postfilter_scorer.iter() {
+                        if entry_scores[*index] < *score {
+                            postfilter = true;
+                            break;
+                        }
+                    }
+
                     let mut entry = PackedIndexEntry {
                         id: count + i as u32,
                         vertices,
@@ -476,9 +508,10 @@ fn main() -> Result<()> {
                         shards
                     };
                     let mut bytes = bitcode::encode(&entry);
-                    if bytes.len() > (RECORD_PAD_SIZE - 2) {
+                    // as an ugly hack for removing entries already in the index shards, kill the URL and make it a graph node only
+                    if bytes.len() > (RECORD_PAD_SIZE - 2) || postfilter {
                         // we do need the records to fit in a fixed size and can't really drop things, so discard URL so it can exist as a graph node only
-                        entry.url = String::new();
+                        entry.url = String::new(); // URL is only input-controlled, arbitrary-length field
                         bytes = bitcode::encode(&entry);
                         dead_count += 1;
                     }
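The RECORD_PAD_SIZE - 2 check suggests each record occupies a fixed-size slot with a two-byte length prefix, which is why oversized (and now postfiltered) entries are shrunk by blanking the URL, the only input-controlled arbitrary-length field, rather than dropped. A hypothetical framing consistent with that check (names and layout are assumptions, not the commit's actual writer):

// Hypothetical: u16 little-endian length prefix, encoded entry, zero padding.
fn pad_record(encoded: &[u8], pad_size: usize) -> Option<Vec<u8>> {
    if encoded.len() > pad_size - 2 {
        return None; // caller must shrink the entry (e.g. blank its URL) and retry
    }
    let mut out = Vec::with_capacity(pad_size);
    out.extend_from_slice(&(encoded.len() as u16).to_le_bytes());
    out.extend_from_slice(encoded);
    out.resize(pad_size, 0);
    Some(out)
}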
@@ -494,11 +527,11 @@ fn main() -> Result<()> {
     }
 
     if args.print_aggregates {
-        println!("earliest={} latest={} count={} read={} deduped={}", earliest_timestamp, latest_timestamp, count, i, deduped_count);
+        println!("earliest={} latest={} count={} read={} deduped={} postfiltered={}", earliest_timestamp, latest_timestamp, count, i, deduped_count, postfilter_count);
     }
     if let Some(histogram_path) = args.histograms {
         let mut file = fs::File::create(histogram_path)?;
-        for (name, _, histogram, _) in &embeddings {
+        for (name, _, histogram, _, _) in &embeddings {
             let width = 800.0;
             let padding = 40.0;
             let bars_height = 300 as f64;
@@ -10,14 +10,16 @@ with open("mse_config.json") as f:
 def get_embedding(req):
     return msgpack.unpackb(requests.post(config["clip_server"], data=msgpack.packb(req)).content)
 
-output, input, *xs = sys.argv[1:]
+mode, output, input = sys.argv[1:]
 
 with open(output, "wb") as f:
-    with open(input, "rb") as g:
-        input_data = g.read()
-    if not xs:
-        result = get_embedding({"images": [input_data]})[0]
+    if mode == "image":
+        with open(input, "rb") as g:
+            input_data = g.read()
+        result = get_embedding({"images": [input_data]})[0]
+    elif mode == "text":
+        result = get_embedding({"text": input})[0]
     else:
-        result = get_embedding({"text": xs})[0]
+        raise Exception("unknown mode")
     f.write(result)
     print(base64.urlsafe_b64encode(result).decode("ascii"))
@@ -892,7 +892,7 @@ async fn build_index(config: Arc<WConfig>) -> Result<IIndex> {
 }
 
 #[instrument(skip(index))]
-async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult> {
+async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bool) -> Result<QueryResult<()>> {
     let result = index.vectors.search(&query, k as usize)?;
 
     let mut seen_videos = HashSet::new();
@@ -916,7 +916,8 @@ async fn query_index(index: &IIndex, query: EmbeddingVector, k: usize, video: bo
                 index.filenames[id].container_filename(),
                 generate_filename_hash(&index.filenames[id as usize]).clone(),
                 index.format_codes[id],
-                index.metadata[id].as_ref().map(|x| (x.width, x.height))
+                index.metadata[id].as_ref().map(|x| (x.width, x.height)),
+                Option::<()>::None
             ))
         })
         .collect();
@@ -7,7 +7,7 @@ use argh::FromArgs;
 use itertools::Itertools;
 use foldhash::{HashSet, HashSetExt};
 use half::f16;
-use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, QueryLUT, scale_dot_result, scale_dot_result_f64}};
+use diskann::{NeighbourBuffer, vector::{fast_dot_noprefetch, QueryLUT, scale_dot_result, scale_dot_result_f64, SCALE_F64}};
 use simsimd::SpatialSimilarity;
 use memmap2::{Mmap, MmapOptions};
 use std::rc::Rc;
@@ -18,13 +18,14 @@ use http_body_util::{BodyExt, Empty, Full};
 use prometheus::{register_int_counter, register_int_counter_vec, register_int_gauge, Encoder, IntCounter, IntGauge, IntCounterVec};
 use std::pin::Pin;
 use std::future::Future;
-use serde::Serialize;
+use serde::{Serialize, Deserialize};
 use std::str::FromStr;
 use std::collections::HashMap;
+use std::io::Write;
 
 mod common;
 
-use common::{resize_for_embed_sync, FrontendInit, IndexHeader, InferenceServerConfig, PackedIndexEntry, QueryRequest, QueryResult};
+use common::{resize_for_embed_sync, QueryTerm, FrontendInit, IndexHeader, InferenceServerConfig, PackedIndexEntry, QueryRequest, QueryResult};
 
 #[derive(FromArgs, Clone)]
 #[argh(description="Query disk index")]
@@ -43,10 +44,18 @@ struct CLIArguments {
     search_list_size: Option<usize>,
     #[argh(switch, description="always use full-precision vectors (slow)")]
     disable_pq: bool,
-    #[argh(option, short='l', description="listen address")]
-    listen_address: Option<String>,
-    #[argh(option, short='c', description="clip server")]
-    clip_server: Option<String>,
+    #[argh(option, short='c', description="server config file")]
+    config_path: Option<String>
+}
+
+#[derive(Deserialize, Clone)]
+struct ServerConfig {
+    listen_address: String,
+    clip_server: String,
+    descriptor_names: Vec<String>,
+    telemetry_file: String,
+    search_list: usize,
+    beam_width: usize
 }
 
 lazy_static! {
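The server now reads its settings from a JSON config file instead of CLI flags. A hypothetical file matching ServerConfig (the field names come from the struct above; every value is invented):

// All values invented; only the keys are fixed by ServerConfig.
const EXAMPLE_CONFIG: &str = r#"{
    "listen_address": "127.0.0.1:5601",
    "clip_server": "http://127.0.0.1:1708",
    "descriptor_names": ["aesthetic", "useful"],
    "telemetry_file": "telemetry.bin",
    "search_list": 1000,
    "beam_width": 3
}"#;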
@@ -82,17 +91,30 @@ fn next_several_unvisited(s: &mut NeighbourBuffer, n: usize) -> Option<Vec<u32>>
     }
 }
 
+const DUPLICATES_THRESHOLD: f32 = 0.95;
+
 fn read_pq_codes(id: u32, index: Rc<Index>, buf: &mut Vec<u8>) {
     let loc = (id as usize) * index.pq_code_size;
     buf.extend(&index.pq_codes[loc..loc+index.pq_code_size])
 }
 
+struct VisitedNode {
+    image_url: String,
+    scores: Vec<f32>,
+    shards: Vec<u32>,
+    id: u32,
+    score: i64,
+    timestamp: u64,
+    dimensions: (u32, u32)
+}
+
 struct Scratch {
     visited_adjacent: HashSet<u32>,
     visited: HashSet<u32>,
     neighbour_buffer: NeighbourBuffer,
     neighbour_pre_buffer: Vec<u32>,
-    visited_list: Vec<(u32, i64, String, Vec<u32>, Vec<f32>)>
+    visited_list: Vec<VisitedNode>,
+    visited_embeddings: Vec<f32>
 }
 
 struct Index {
@@ -140,10 +162,20 @@ async fn greedy_search<'a>(scratch: &mut Scratch, start: u32, query: &[f16], que
             let index = index.clone();
             let node = handle.await?;
             let vector = bytemuck::cast_slice(&node.vector);
-            let distance = fast_dot_noprefetch(query, &vector);
+            let mut score = fast_dot_noprefetch(query, &vector);
+            score += descriptor_product(index.clone(), &descriptor_scales, node.id);
             cmps += 1;
-            if scratch.visited.insert(node.id) {
-                scratch.visited_list.push((node.id, distance, node.url, node.shards, node.scores));
+            if scratch.visited.insert(node.id) && node.url.len() > 0 {
+                scratch.visited_list.push(VisitedNode {
+                    image_url: node.url,
+                    scores: node.scores,
+                    shards: node.shards,
+                    id: node.id,
+                    score,
+                    timestamp: node.timestamp,
+                    dimensions: node.dimensions
+                });
+                scratch.visited_embeddings.extend(bytemuck::cast_slice(&node.vector).iter().map(|x: &f16| x.to_f32()));
             };
             for &neighbour in node.vertices.iter() {
                 if scratch.visited_adjacent.insert(neighbour) {
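Two behavioural changes land here: the traversal score now adds the weighted descriptor contribution on top of the raw query dot product, and nodes whose URL was blanked during index building (oversized or postfiltered entries) are still traversed for connectivity but never collected as results. The collection rule in isolation, as a sketch:

// Sketch; `newly_visited` is the result of the visited-set insert.
fn should_collect(newly_visited: bool, url: &str) -> bool {
    newly_visited && !url.is_empty()
}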
@@ -250,7 +282,8 @@ async fn evaluate(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
         neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
         neighbour_pre_buffer: Vec::new(),
         visited_list: Vec::new(),
-        visited_adjacent: HashSet::new()
+        visited_adjacent: HashSet::new(),
+        visited_embeddings: Vec::new()
     };
 
     let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]);
@@ -265,14 +298,14 @@ async fn evaluate(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
         println!("index scan {}: {:?} cmps", shard, cmps_result);
     }
 
-    scratch.visited_list.sort_by_key(|x| -x.1);
-    for (i, (id, distance, url, shards, scores)) in scratch.visited_list.iter().take(20).enumerate() {
-        let found_id = match matches.binary_search(&(*id, 0)) {
+    scratch.visited_list.sort_by_key(|x| -x.score);
+    for (i, node) in scratch.visited_list.iter().take(20).enumerate() {
+        let found_id = match matches.binary_search(&(node.id, 0)) {
             Ok(pos) => pos,
             Err(pos) => pos
         };
         if args.verbose {
-            println!("index scan: {} {} {} {:?} {:?}; rank {}", id, distance, url, shards, scores, matches[found_id].1 + 1);
+            println!("index scan: {} {} {} {:?} {:?}; rank {}", node.id, node.score, node.image_url, node.shards, node.scores, matches[found_id].1 + 1);
         };
         top_ranks[i] = std::cmp::min(top_ranks[i], matches[found_id].1);
     }
@@ -341,11 +374,23 @@ pub async fn query_clip_server<I, O>(base_url: &str, path: &str, data: Option<I>
     Ok(result)
 }
 
+#[derive(Serialize, Deserialize)]
+struct TelemetryMessage {
+    #[serde(rename="correlationId")]
+    correlation_id: String,
+    data: serde_json::Value,
+    event: String,
+    #[serde(rename="instanceId")]
+    instance_id: String,
+    page: String
+}
+
 #[derive(Clone)]
 struct Service {
     index: Rc<Index>,
     inference_server_config: Rc<InferenceServerConfig>,
-    args: Rc<CLIArguments>
+    config: Rc<ServerConfig>,
+    telemetry_channel: std::sync::mpsc::Sender<TelemetryMessage>
 }
 
 impl hyper::service::Service<Request<Incoming>> for Service {
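With the serde renames, the wire format the frontend posts is camelCase. A hypothetical body that would deserialize into TelemetryMessage (values invented; only the keys are fixed by the struct):

// All values invented.
const EXAMPLE_TELEMETRY: &str = r#"{
    "correlationId": "3f2a",
    "data": {"rank": 1},
    "event": "result_clicked",
    "instanceId": "9c41",
    "page": "/"
}"#;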
@@ -355,15 +400,16 @@ impl hyper::service::Service<Request<Incoming>> for Service {
 
     fn call(&self, req: Request<Incoming>) -> Self::Future {
         let index = self.index.clone();
-        let args = self.args.clone();
+        let config = self.config.clone();
         let inference_server_config = self.inference_server_config.clone();
+        let channel = self.telemetry_channel.clone();
 
         Box::pin(async move {
             let mut body = match (req.method(), req.uri().path()) {
                 (&Method::GET, "/") => Response::new(Full::new(Bytes::from(serde_json::to_vec(&FrontendInit {
                     n_total: (index.header.count - index.header.dead_count) as u64,
                     d_emb: index.header.quantizer.n_dims,
-                    predefined_embedding_names: vec![]
+                    predefined_embedding_names: config.descriptor_names.clone()
                 })?))),
                 (&Method::POST, "/") => {
                     let upper = req.body().size_hint().upper().unwrap_or(u64::MAX);
@@ -381,7 +427,7 @@ impl hyper::service::Service<Request<Incoming>> for Service {
                         &body.terms,
                         &*inference_server_config,
                         |batch, _config| {
-                            query_clip_server(args.clip_server.as_ref().unwrap(), "/", Some(batch))
+                            query_clip_server(config.clip_server.as_str(), "/", Some(batch))
                         },
                         |image, config| async move {
                             let image = image::load_from_memory(&image)?;
@@ -397,27 +443,89 @@ impl hyper::service::Service<Request<Incoming>> for Service {
                    }).unwrap();
                    let selected_start = index.header.shards[selected_shard].1;
 
-                    let beamwidth = 3;
+                    let beamwidth = config.beam_width;
 
                    let mut scratch = Scratch {
                        visited: HashSet::new(),
-                        neighbour_buffer: NeighbourBuffer::new(args.search_list_size.unwrap_or(1000)),
+                        neighbour_buffer: NeighbourBuffer::new(config.search_list),
                        neighbour_pre_buffer: Vec::new(),
                        visited_list: Vec::new(),
-                        visited_adjacent: HashSet::new()
+                        visited_adjacent: HashSet::new(),
+                        visited_embeddings: Vec::new()
                    };
 
-                    let descriptor_scales = DescriptorScales(vec![0.0, 0.0, 0.0, 0.0]);
+                    let mut desc = vec![0.0, 0.0, 0.0, 0.0];
+
+                    for term in &body.terms {
+                        if let Some(name) = &term.predefined_embedding {
+                            if let Some(index) = config.descriptor_names.iter().position(|x| x == name) {
+                                desc[index] = term.weight.unwrap_or(1.0) * 1.0/512.0;
+                            }
+                        }
+                    }
+
+                    let descriptor_scales = DescriptorScales(desc);
 
                    let query_preprocessed = index.header.quantizer.preprocess_query(&query);
 
                    let query = query.iter().map(|x| half::f16::from_f32(*x)).collect::<Vec<f16>>();
 
-                    let cmps_result = greedy_search(&mut scratch, selected_start, &query, &query_preprocessed, &descriptor_scales, index.clone(), args.disable_pq, beamwidth).await?;
+                    let cmps_result = greedy_search(&mut scratch, selected_start, &query, &query_preprocessed, &descriptor_scales, index.clone(), false, beamwidth).await?;
 
-                    scratch.visited_list.sort_by_key(|x| -x.1);
+                    let n_visited = scratch.visited_list.len();
 
-                    let matches = scratch.visited_list.drain(..).map(|(id, score, url, shards, scores)| (score as f32, url, String::new(), 0, None)).collect::<Vec<_>>();
+                    let mut similarities_against_self = vec![0.0f32; n_visited * n_visited];
+
+                    // runtime deduplicate of results list
+                    unsafe {
+                        // vecs @ vecs.T
+                        matrixmultiply::sgemm(
+                            n_visited,
+                            index.header.quantizer.n_dims,
+                            n_visited,
+                            1.0,
+                            scratch.visited_embeddings.as_ptr(),
+                            index.header.quantizer.n_dims as isize,
+                            1,
+                            scratch.visited_embeddings.as_ptr(),
+                            1,
+                            index.header.quantizer.n_dims as isize,
+                            0.0,
+                            similarities_against_self.as_mut_ptr(),
+                            n_visited as isize,
+                            1
+                        );
+                    }
+
+                    // discard anything similar to something already in list
+                    let mut i = 0;
+                    let mut included = bitvec::bitvec![0; n_visited];
+                    scratch.visited_list.retain(|_node| {
+                        let row = &similarities_against_self[(i * n_visited)..((i + 1) * n_visited)];
+                        let old_i = i;
+                        i += 1;
+                        for (other_i, similarity) in row.iter().enumerate() {
+                            if similarity > &DUPLICATES_THRESHOLD && included[other_i] {
+                                return false;
+                            }
+                        }
+                        included.set(old_i, true);
+                        true
+                    });
+
+                    scratch.visited_list.sort_unstable_by_key(|x| -x.score);
+
+                    let matches = scratch.visited_list
+                        .drain(..)
+                        .map(|node| {
+                            let debug = if body.debug_enabled {
+                                Some((node.scores, node.shards, node.timestamp))
+                            } else {
+                                None
+                            };
+                            ((node.score as f64 / SCALE_F64) as f32, node.image_url, String::new(), 0, Some(node.dimensions), debug)
+                        })
+                        .collect::<Vec<_>>();
 
                    let result = QueryResult {
                        formats: vec![],
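The deduplication computes V @ V.T over the n visited embeddings (row-major, one n_dims-long row per result), then greedily keeps a result only if its similarity to every already-kept result stays at or below DUPLICATES_THRESHOLD; assuming roughly unit-norm embeddings, the dot products act as cosine similarities. A naive equivalent of the sgemm call, for clarity:

// Sketch: out[i * n + j] is the dot product of visited embeddings i and j.
fn pairwise_dots(vecs: &[f32], n: usize, d: usize) -> Vec<f32> {
    assert_eq!(vecs.len(), n * d);
    let mut out = vec![0.0f32; n * n];
    for i in 0..n {
        for j in 0..n {
            out[i * n + j] = (0..d).map(|k| vecs[i * d + k] * vecs[j * d + k]).sum();
        }
    }
    out
}

Note that the retain pass runs in visitation order, before the final sort, so among near-duplicates the earliest-visited entry is the one kept.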
@@ -438,6 +546,25 @@ impl hyper::service::Service<Request<Incoming>> for Service {
                         .header(hyper::header::CONTENT_TYPE, "text/plain; version=0.0.4")
                         .body(Full::new(Bytes::from(buffer))).unwrap()
                 },
+                (&Method::POST, "/telemetry") => {
+                    // TODO refactor
+                    let upper = req.body().size_hint().upper().unwrap_or(u64::MAX);
+                    if upper > 1000 {
+                        let mut resp = Response::new(Full::new(Bytes::from("Body too big")));
+                        *resp.status_mut() = hyper::StatusCode::PAYLOAD_TOO_LARGE;
+                        return Ok(resp);
+                    }
+
+                    let whole_body = req.collect().await?.to_bytes();
+
+                    let message = serde_json::from_slice::<TelemetryMessage>(&whole_body)?;
+
+                    channel.send(message)?;
+
+                    Response::builder()
+                        .status(StatusCode::NO_CONTENT)
+                        .body(Full::new(Bytes::from(""))).unwrap()
+                }
                 (&Method::OPTIONS, "/") => {
                     Response::builder()
                         .status(StatusCode::NO_CONTENT)
@@ -459,9 +586,9 @@ impl hyper::service::Service<Request<Incoming>> for Service {
     }
 }
 
-async fn get_backend_config(clip_server: &Option<String>) -> Result<InferenceServerConfig> {
+async fn get_backend_config(clip_server: &String) -> Result<InferenceServerConfig> {
     loop {
-        match query_clip_server(clip_server.as_ref().unwrap(), "/config", Option::<()>::None).await {
+        match query_clip_server(clip_server, "/config", Option::<()>::None).await {
             Ok(config) => return Ok(config),
             Err(err) => {
                 tracing::warn!("waiting for clip server: {}", err);
@@ -471,14 +598,31 @@ async fn get_backend_config(clip_server: &Option<String>) -> Result<InferenceSer
         }
     }
 }
 
+// can't run this as an async task because the monoio File API is positional writes only
+fn telemetry_handler(rx: std::sync::mpsc::Receiver<TelemetryMessage>, config: ServerConfig) -> Result<()> {
+    let mut telemetry_file = std::fs::OpenOptions::new().create(true).append(true).open(&config.telemetry_file)?;
+    while let Ok(message) = rx.recv() {
+        telemetry_file.write_all(rmp_serde::to_vec(&message)?.as_slice())?;
+    }
+    Ok(())
+}
+
 async fn serve(args: &CLIArguments, index: Rc<Index>) -> Result<()> {
+    let config: ServerConfig = serde_json::from_slice(&std::fs::read(args.config_path.as_ref().unwrap())?)?;
+
+    let (telemetry_channel, telemetry_receiver) = std::sync::mpsc::channel();
+
+    let config_ = config.clone();
+    std::thread::spawn(move || telemetry_handler(telemetry_receiver, config_));
+
     let service = Service {
         index,
-        inference_server_config: Rc::new(get_backend_config(&args.clip_server).await?),
-        args: Rc::new(args.clone())
+        inference_server_config: Rc::new(get_backend_config(&config.clip_server).await?),
+        config: Rc::new(config.clone()),
+        telemetry_channel
     };
 
-    let listener = TcpListener::bind(args.listen_address.as_ref().unwrap())?;
+    let listener = TcpListener::bind(config.listen_address)?;
     println!("Listening");
     loop {
         let (stream, _) = listener.accept().await?;
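telemetry_handler appends raw MessagePack values back to back, so the file can be read back by decoding repeatedly from one reader until it runs out. A sketch of such a reader (not part of the commit):

fn read_telemetry(path: &str) -> anyhow::Result<Vec<TelemetryMessage>> {
    let mut reader = std::io::BufReader::new(std::fs::File::open(path)?);
    let mut out = Vec::new();
    loop {
        // each iteration decodes one concatenated TelemetryMessage
        match rmp_serde::from_read::<_, TelemetryMessage>(&mut reader) {
            Ok(message) => out.push(message),
            Err(_) => break, // EOF, or a truncated trailing record
        }
    }
    Ok(out)
}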
@@ -531,7 +675,7 @@ async fn main() -> Result<()> {
         n_descriptors: header.descriptor_cdfs.len(),
     });
 
-    if args.listen_address.is_some() {
+    if args.config_path.is_some() {
         serve(&args, index).await?;
     } else {
         evaluate(&args, index).await?;
File diff suppressed because one or more lines are too long