1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-02-23 14:30:04 +00:00
2024-12-31 23:05:48 +00:00

45 lines
1.5 KiB
Python

import numpy as np
import msgpack
import math
import torch
import faiss
import tqdm
n_dims = 1152
output_code_size = 64
output_code_bits = 8
output_codebook_size = 2**output_code_bits
n_dims_per_code = n_dims // output_code_size
dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
device = "cpu"
def pq_assign(centroids, batch):
quantized = torch.zeros_like(batch)
# Assign to nearest centroid in each subspace
for dmin in range(0, n_dims, n_dims_per_code):
dmax = dmin + n_dims_per_code
similarities = torch.matmul(batch[:, dmin:dmax], centroids[:, dmin:dmax].T)
assignments = similarities.argmax(dim=1)
quantized[:, dmin:dmax] = centroids[assignments, dmin:dmax]
return quantized
with open("opq.msgpack", "rb") as f:
data = msgpack.unpack(f)
centroids = torch.tensor(data["centroids"], device=device).reshape(2**output_code_bits, n_dims)
projection = torch.tensor(data["transform"], device=device).reshape(n_dims, n_dims)
vectors = torch.tensor(dataset, device=device)
queries = torch.tensor(queryset, device=device)
sample_size = 64
qsample = pq_assign(centroids, vectors[:sample_size] @ projection)
print(qsample)
print(vectors[:sample_size])
exact_results = vectors[:sample_size] @ queries[0]
approx_results = qsample @ (projection @ queries[0])
print(np.argsort(approx_results))
print(np.argsort(exact_results))