1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-04-10 20:56:39 +00:00

integrate rating model

This commit is contained in:
osmarks 2025-01-18 11:29:03 +00:00
parent d3fcedda09
commit 63caba2746
11 changed files with 762 additions and 17 deletions

2
.gitignore vendored
View File

@ -20,3 +20,5 @@ diskann/target
shards
index
queries.txt
*.zst
.safetensors

439
Cargo.lock generated
View File

@ -83,6 +83,9 @@ name = "arbitrary"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110"
dependencies = [
"derive_arbitrary",
]
[[package]]
name = "arg_enum_proc_macro"
@ -467,6 +470,28 @@ version = "1.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3"
[[package]]
name = "candle-core"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "855dfedff437d2681d68e1f34ae559d88b0dd84aa5a6b63f2c8e75ebdd875bbf"
dependencies = [
"byteorder",
"gemm",
"half",
"memmap2",
"num-traits",
"num_cpus",
"rand",
"rand_distr",
"rayon",
"safetensors",
"thiserror",
"ug",
"yoke",
"zip",
]
[[package]]
name = "castaway"
version = "0.2.3"
@ -779,6 +804,17 @@ dependencies = [
"zeroize",
]
[[package]]
name = "derive_arbitrary"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "digest"
version = "0.10.7"
@ -810,6 +846,17 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "document-features"
version = "0.2.10"
@ -825,6 +872,16 @@ version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
[[package]]
name = "dyn-stack"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b"
dependencies = [
"bytemuck",
"reborrow",
]
[[package]]
name = "either"
version = "1.13.0"
@ -843,6 +900,18 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "enum-as-inner"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "equivalent"
version = "1.0.1"
@ -1121,6 +1190,124 @@ dependencies = [
"slab",
]
[[package]]
name = "gemm"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ab24cc62135b40090e31a76a9b2766a501979f3070fa27f689c27ec04377d32"
dependencies = [
"dyn-stack",
"gemm-c32",
"gemm-c64",
"gemm-common",
"gemm-f16",
"gemm-f32",
"gemm-f64",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "gemm-c32"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9c030d0b983d1e34a546b86e08f600c11696fde16199f971cd46c12e67512c0"
dependencies = [
"dyn-stack",
"gemm-common",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "gemm-c64"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbb5f2e79fefb9693d18e1066a557b4546cd334b226beadc68b11a8f9431852a"
dependencies = [
"dyn-stack",
"gemm-common",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "gemm-common"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8"
dependencies = [
"bytemuck",
"dyn-stack",
"half",
"num-complex",
"num-traits",
"once_cell",
"paste",
"pulp",
"raw-cpuid",
"rayon",
"seq-macro",
"sysctl",
]
[[package]]
name = "gemm-f16"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ca4c06b9b11952071d317604acb332e924e817bd891bec8dfb494168c7cedd4"
dependencies = [
"dyn-stack",
"gemm-common",
"gemm-f32",
"half",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"rayon",
"seq-macro",
]
[[package]]
name = "gemm-f32"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9a69f51aaefbd9cf12d18faf273d3e982d9d711f60775645ed5c8047b4ae113"
dependencies = [
"dyn-stack",
"gemm-common",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "gemm-f64"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa397a48544fadf0b81ec8741e5c0fba0043008113f71f2034def1935645d2b0"
dependencies = [
"dyn-stack",
"gemm-common",
"num-complex",
"num-traits",
"paste",
"raw-cpuid",
"seq-macro",
]
[[package]]
name = "generic-array"
version = "0.14.7"
@ -1198,6 +1385,9 @@ dependencies = [
"bytemuck",
"cfg-if",
"crunchy",
"num-traits",
"rand",
"rand_distr",
]
[[package]]
@ -1766,7 +1956,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
dependencies = [
"autocfg",
"num_cpus",
"once_cell",
"rawpointer",
"thread-tree",
]
[[package]]
@ -1827,6 +2020,7 @@ dependencies = [
"base64 0.22.1",
"bitcode",
"bytemuck",
"candle-core",
"chrono",
"compact_str",
"console-subscriber",
@ -1844,6 +2038,7 @@ dependencies = [
"itertools 0.13.0",
"json5",
"lazy_static",
"matrixmultiply",
"maud",
"memmap2",
"mimalloc",
@ -1867,7 +2062,7 @@ dependencies = [
"tracing",
"tracing-subscriber",
"url",
"walkdir",
"walkdir 1.0.7",
"zstd",
]
@ -1878,6 +2073,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f"
dependencies = [
"libc",
"stable_deref_trait",
]
[[package]]
@ -2039,6 +2235,20 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "num"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23"
dependencies = [
"num-bigint",
"num-complex",
"num-integer",
"num-iter",
"num-rational",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.4.6"
@ -2072,6 +2282,7 @@ version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"bytemuck",
"num-traits",
]
@ -2137,6 +2348,27 @@ dependencies = [
"libc",
]
[[package]]
name = "num_enum"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179"
dependencies = [
"num_enum_derive",
]
[[package]]
name = "num_enum_derive"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56"
dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "object"
version = "0.36.5"
@ -2372,6 +2604,15 @@ dependencies = [
"zerocopy",
]
[[package]]
name = "proc-macro-crate"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b"
dependencies = [
"toml_edit",
]
[[package]]
name = "proc-macro-error"
version = "1.0.4"
@ -2476,6 +2717,18 @@ version = "2.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94"
[[package]]
name = "pulp"
version = "0.18.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6"
dependencies = [
"bytemuck",
"libm",
"num-complex",
"reborrow",
]
[[package]]
name = "qoi"
version = "0.4.1"
@ -2530,6 +2783,16 @@ dependencies = [
"getrandom",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
dependencies = [
"num-traits",
"rand",
]
[[package]]
name = "rav1e"
version = "0.7.1"
@ -2581,6 +2844,15 @@ dependencies = [
"rgb",
]
[[package]]
name = "raw-cpuid"
version = "10.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332"
dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
@ -2607,6 +2879,12 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "reborrow"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "redox_syscall"
version = "0.5.7"
@ -2843,6 +3121,16 @@ version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "safetensors"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "same-file"
version = "0.1.3"
@ -2853,6 +3141,15 @@ dependencies = [
"winapi 0.2.8",
]
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.26"
@ -2897,6 +3194,12 @@ dependencies = [
"libc",
]
[[package]]
name = "seq-macro"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4"
[[package]]
name = "serde"
version = "1.0.210"
@ -3342,6 +3645,12 @@ dependencies = [
"urlencoding",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "static_assertions"
version = "1.1.0"
@ -3402,6 +3711,31 @@ dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "sysctl"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea"
dependencies = [
"bitflags 2.6.0",
"byteorder",
"enum-as-inner",
"libc",
"thiserror",
"walkdir 2.5.0",
]
[[package]]
name = "system-configuration"
version = "0.6.1"
@ -3475,6 +3809,15 @@ dependencies = [
"syn 2.0.79",
]
[[package]]
name = "thread-tree"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffbd370cb847953a25954d9f63e14824a36113f8c72eecf6eccef5dc4b45d630"
dependencies = [
"crossbeam-channel",
]
[[package]]
name = "thread_local"
version = "1.1.8"
@ -3805,6 +4148,19 @@ version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971"
[[package]]
name = "ug"
version = "0.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4eef2ebfc18c67a6dbcacd9d8a4d85e0568cc58c82515552382312c2730ea13"
dependencies = [
"half",
"num",
"serde",
"serde_json",
"thiserror",
]
[[package]]
name = "unicase"
version = "2.8.0"
@ -3915,10 +4271,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb08f9e670fab86099470b97cd2b252d6527f0b3cc1401acdb595ffc9dd288ff"
dependencies = [
"kernel32-sys",
"same-file",
"same-file 0.1.3",
"winapi 0.2.8",
]
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file 1.0.6",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"
@ -4061,6 +4427,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
@ -4263,6 +4638,30 @@ dependencies = [
"memchr",
]
[[package]]
name = "yoke"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40"
dependencies = [
"serde",
"stable_deref_trait",
"yoke-derive",
"zerofrom",
]
[[package]]
name = "yoke-derive"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"synstructure",
]
[[package]]
name = "zerocopy"
version = "0.7.35"
@ -4284,12 +4683,48 @@ dependencies = [
"syn 2.0.79",
]
[[package]]
name = "zerofrom"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e"
dependencies = [
"zerofrom-derive",
]
[[package]]
name = "zerofrom-derive"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"synstructure",
]
[[package]]
name = "zeroize"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
[[package]]
name = "zip"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cc23c04387f4da0374be4533ad1208cbb091d5c11d070dfef13676ad6497164"
dependencies = [
"arbitrary",
"crc32fast",
"crossbeam-utils",
"displaydoc",
"indexmap 2.6.0",
"num_enum",
"thiserror",
]
[[package]]
name = "zstd"
version = "0.13.2"

View File

@ -57,6 +57,8 @@ bitcode = "0.6"
simsimd = "6"
foldhash = "0.1"
memmap2 = "0.9"
matrixmultiply = { version = "0.3", features = ["threading"] }
candle-core = "0.8"
[[bin]]
name = "reddit-dump"

62
meme-rater/compute_cdf.py Normal file
View File

@ -0,0 +1,62 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
import sys
from collections import defaultdict
from model import Config, BradleyTerry
import shared
batch_size = 128
num_pairs = batch_size * 1024
device = "cuda"
config = Config(
d_emb=1152,
n_hidden=1,
n_ensemble=16,
device=device,
dtype=torch.float32,
output_channels=3,
dropout=0.1
)
model = BradleyTerry(config)
modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
files = shared.fetch_all_files()
results = {}
model.eval()
with torch.inference_mode():
for bstart in tqdm(range(0, len(files), batch_size)):
batch = files[bstart:bstart + batch_size]
filenames = [ f1 for f1, e1 in batch ]
embs = torch.stack([ torch.Tensor(e1).to(config.dtype) for f1, e1 in batch ])
inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
scores = model.ensemble(inputs).median(dim=0).values.cpu().numpy()
channel = int(sys.argv[2])
percentile = float(sys.argv[3])
output_pairs = int(sys.argv[4])
mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
select_from = top[:int(len(top) * percentile)]
out = []
for _ in range(output_pairs):
# dummy score for compatibility with existing code
out.append(((random.choice(select_from)[0], random.choice(select_from)[0]), 0))
with open("top.json", "w") as f:
json.dump(out, f)

View File

@ -0,0 +1,74 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import msgpack
import sys
from safetensors.torch import save_file
from model import Config, BradleyTerry
import shared
device = "cpu"
config = Config(
d_emb=1152,
n_hidden=1,
n_ensemble=16,
device=device,
dtype=torch.float32,
output_channels=3,
dropout=0.1
)
model = BradleyTerry(config)
model.eval()
modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
out_layers = []
out_bias = []
with torch.inference_mode():
# TODO: I don't think this actually works for more than 1 hidden layer
for layer in range(config.n_hidden):
big_layer = torch.zeros(config.n_ensemble * config.d_emb, config.d_emb)
big_bias = torch.zeros(config.n_ensemble * config.d_emb)
for i in range(config.n_ensemble):
big_layer[i*config.d_emb:(i+1)*config.d_emb] = model.ensemble.models[i].hidden[layer].weight.data.clone()
big_bias[i*config.d_emb:(i+1)*config.d_emb] = model.ensemble.models[i].hidden[layer].bias.data.clone()
out_layers.append(big_layer)
out_bias.append(big_bias)
# we do not need to preserve the bias on the downprojection as the win probability calculation is shift-invariant
downprojection = torch.zeros(config.output_channels, config.n_ensemble * config.d_emb)
for i in range(config.n_ensemble):
downprojection[:, i*config.d_emb:(i+1)*config.d_emb] = model.ensemble.models[i].output.weight.data.clone()
for i in range(10):
input = torch.randn(3, config.d_emb)
ground_truth_result = model.ensemble(input.unsqueeze(0).expand((config.n_ensemble, *input.shape))).mean(dim=0).T
r_result = input
for (layer, bias) in zip(out_layers, out_bias):
r_result = torch.matmul(layer, r_result.T) + bias.unsqueeze(-1).expand(config.n_ensemble * config.d_emb, input.shape[0])
print(r_result.shape, bias.shape)
r_result = F.silu(r_result)
r_result = torch.matmul(downprojection, r_result) / config.n_ensemble
bias_diff = torch.mean(r_result - ground_truth_result)
model_issue_diff = torch.max(torch.abs(r_result - bias_diff - ground_truth_result))
print(model_issue_diff)
print("test vector:")
print(input.flatten().tolist())
print("ground truth result:")
print(ground_truth_result.T.flatten().tolist())
save_file({
"up_proj": out_layers[0],
"bias": out_bias[0],
"down_proj": downprojection
}, "model.safetensors")

View File

@ -119,7 +119,7 @@ with open(logfile, "w") as log:
print(steps, loss)
log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
if steps % 10 == 0:
if steps % 100 == 0: save_ckpt(log, steps)
if steps % 50 == 0: save_ckpt(log, steps)
loss = evaluate(steps)
#print(loss)
#best = min(loss, best)

View File

@ -152,7 +152,7 @@ pub struct PackedIndexEntry {
pub id: u32,
pub timestamp: u64,
pub dimensions: (u32, u32),
pub score: f32,
pub scores: Vec<f32>,
pub url: String,
pub shards: Vec<u32>
}

View File

@ -18,6 +18,7 @@ use std::os::unix::prelude::FileExt;
use diskann::vector::{scale_dot_result_f64, ProductQuantizer};
mod common;
mod score_model;
use common::{ProcessedEntry, ShardInputHeader, ShardedRecord, ShardHeader, PackedIndexEntry, IndexHeader};
@ -62,6 +63,10 @@ struct CLIArguments {
json: bool,
#[argh(option, short='f', description="k-means balance fudge factor", default="0.2")]
balance_fudge: f64,
#[argh(option, short='M', description="score model path")]
score_model: Option<String>,
#[argh(option, short='G', description="GPU (CUDA) device to use")]
gpu: Option<usize>
}
#[derive(Clone, Deserialize, Serialize, Debug)]
@ -124,7 +129,7 @@ const SHARD_SPILL: usize = 2;
const RECORD_PAD_SIZE: usize = 4096; // NVMe disk sector size
const D_EMB: u32 = 1152;
const EMPTY_LOOKUP: (u32, u64, u32) = (u32::MAX, 0, 0);
const BATCH_SIZE: usize = 1024;
const BATCH_SIZE: usize = 2048;
#[derive(Clone, Serialize, Debug)]
pub struct JsonEntry<'a> {
@ -149,7 +154,7 @@ fn main() -> Result<()> {
// load specified embeddings from files
let mut embeddings = Vec::new();
for x in args.embedding {
let (name, snd) = x.split_once(':').unwrap();
let (name, snd) = x.split_once(':').context("invalid embedding argument")?;
let (path, threshold) = if let Some((path, threshold)) = snd.split_once(':') {
(path, Some(threshold.parse::<f32>().context("parse threshold")?))
} else {
@ -203,9 +208,9 @@ fn main() -> Result<()> {
let file = file?;
let path = file.path();
let filename = path.file_name().unwrap().to_str().unwrap();
let (fst, snd) = filename.split_once(".").unwrap();
let (fst, snd) = filename.split_once(".").context("shard filename wrong")?;
let id: u32 = str::parse(fst)?;
let id: u32 = str::parse(fst).context("shard filename wrong")?;
if let Some(clip) = args.clip_shards {
if id >= (clip as u32) {
continue;
@ -283,8 +288,15 @@ fn main() -> Result<()> {
let mut index_output_file = if let Some(index_output) = &args.index_output {
let main_output = BufWriter::new(fs::File::create(PathBuf::from(index_output).join("index.bin")).context("create index file")?);
let pq_codes =BufWriter::new(fs::File::create(PathBuf::from(index_output).join("index.pq-codes.bin")).context("create index file")?);
Some((main_output, pq_codes))
let pq_codes = BufWriter::new(fs::File::create(PathBuf::from(index_output).join("index.pq-codes.bin")).context("create index file")?);
let descriptor_codes = BufWriter::new(fs::File::create(PathBuf::from(index_output).join("index.descriptor-codes.bin")).context("create index file")?);
Some((main_output, pq_codes, descriptor_codes))
} else {
None
};
let score_model = if let Some(score_model) = &args.score_model {
Some(score_model::ScoreModel::load(score_model, args.gpu).context("load score model")?)
} else {
None
};
@ -415,7 +427,7 @@ fn main() -> Result<()> {
}
if let (Some(read_out_vertices), Some(index_output_file)) = (&mut read_out_vertices, &mut index_output_file) {
let quantizer = pq_codec.as_ref().unwrap();
let quantizer = pq_codec.as_ref().context("PQ codec needed to output index")?;
let mut batch_embeddings = Vec::with_capacity(batch.len() * D_EMB as usize);
for (_x, embedding) in batch.iter() {
@ -423,6 +435,9 @@ fn main() -> Result<()> {
}
let codes = quantizer.quantize_batch(&batch_embeddings);
let score_model = score_model.as_ref().context("score model needed to output index")?;
let scores = score_model.score_batch(&batch_embeddings)?;
for (i, (x, _embedding)) in batch.into_iter().enumerate() {
let (vertices, shards) = read_out_vertices(count + i as u32)?; // TODO: could parallelize this given the batching
let mut entry = PackedIndexEntry {
@ -431,7 +446,7 @@ fn main() -> Result<()> {
vector: x.embedding.chunks_exact(2).map(|x| u16::from_le_bytes([x[0], x[1]])).collect(),
timestamp: x.timestamp,
dimensions: x.metadata.dimension,
score: 0.5, // TODO
scores: scores[i..(i + score_model.output_channels)].to_vec(),
url: x.metadata.final_url,
shards
};

92
src/old_score.rs Normal file

File diff suppressed because one or more lines are too long

View File

@ -55,7 +55,7 @@ struct Scratch {
visited: HashSet<u32>,
neighbour_buffer: NeighbourBuffer,
neighbour_pre_buffer: Vec<u32>,
visited_list: Vec<(u32, i64, String, Vec<u32>)>
visited_list: Vec<(u32, i64, String, Vec<u32>, Vec<f32>)>
}
struct IndexRef<'a> {
@ -85,7 +85,7 @@ fn greedy_search(scratch: &mut Scratch, start: u32, query: &[f16], query_preproc
let vector = bytemuck::cast_slice(&node.vector);
let distance = fast_dot_noprefetch(query, &vector);
cmps += 1;
scratch.visited_list.push((pt, distance, node.url, node.shards));
scratch.visited_list.push((pt, distance, node.url, node.shards, node.scores));
for &neighbour in node.vertices.iter() {
if scratch.visited.insert(neighbour) {
scratch.neighbour_pre_buffer.push(neighbour);
@ -215,17 +215,17 @@ fn main() -> Result<()> {
cmps.push(cmps_result.0);
if args.verbose {
println!("index scan {}: {:?} cmps", shard, cmps);
println!("index scan {}: {:?} cmps", shard, cmps_result);
}
scratch.visited_list.sort_by_key(|x| -x.1);
for (i, (id, distance, url, shards)) in scratch.visited_list.iter().take(20).enumerate() {
for (i, (id, distance, url, shards, scores)) in scratch.visited_list.iter().take(20).enumerate() {
let found_id = match matches.binary_search(&(*id, 0)) {
Ok(pos) => pos,
Err(pos) => pos
};
if args.verbose {
println!("index scan: {} {} {} {:?}; rank {}", id, distance, url, shards, matches[found_id].1 + 1);
println!("index scan: {} {} {} {:?} {:?}; rank {}", id, distance, url, shards, scores, matches[found_id].1 + 1);
};
top_ranks[i] = std::cmp::min(top_ranks[i], matches[found_id].1);
}

63
src/score_model.rs Normal file

File diff suppressed because one or more lines are too long