mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-02-22 22:10:10 +00:00
45 lines
1.5 KiB
Python
45 lines
1.5 KiB
Python
![]() |
import numpy as np
|
||
|
import msgpack
|
||
|
import math
|
||
|
import torch
|
||
|
import faiss
|
||
|
import tqdm
|
||
|
|
||
|
# Quantizer layout. Each vector has `n_dims` components, cut into
# `output_code_size` contiguous slices of `n_dims_per_code` components each.
n_dims = 1152
output_code_size = 64
output_code_bits = 8
# Number of codebook rows (2 ** output_code_bits).
output_codebook_size = 1 << output_code_bits
n_dims_per_code = n_dims // output_code_size


def _load_f16_rows(path, limit=100000):
    """Read a raw float16 file as rows of `n_dims` values, widened to float32.

    The whole file is read, reshaped to ``(-1, n_dims)``, and only the first
    `limit` rows are kept.
    """
    raw = np.fromfile(path, dtype=np.float16)
    return raw.reshape(-1, n_dims)[:limit].astype(np.float32)


dataset = _load_f16_rows("embeddings.bin")
queryset = _load_f16_rows("query.bin")
device = "cpu"
|
||
|
|
||
|
def pq_assign(centroids, batch, dims_per_code=None):
    """Replace each subspace slice of `batch` with its best-matching centroid slice.

    The columns are walked in contiguous slices of `dims_per_code`. For every
    slice, each row of `batch` is scored against the same slice of every
    centroid row by dot product, and the highest-scoring centroid's slice is
    copied into the output. Different slices of one output row may therefore
    come from different centroid rows.

    Args:
        centroids: 2-D tensor with one codebook entry per row and the same
            number of columns as `batch`.
        batch: 2-D tensor of vectors to quantize, one vector per row.
        dims_per_code: Width of each subspace slice. When omitted, the
            module-level `n_dims_per_code` is used.

    Returns:
        A new tensor shaped like `batch` holding the selected centroid
        slices; `batch` itself is not modified.

    Raises:
        ValueError: If the slice width is not positive.
    """
    if dims_per_code is None:
        dims_per_code = n_dims_per_code
    if dims_per_code <= 0:
        raise ValueError(f"dims_per_code must be positive, got {dims_per_code}")

    quantized = torch.zeros_like(batch)
    # Take the width from the input rather than a module constant so the
    # function works for any vector size.
    total_dims = batch.shape[1]

    # Assign to nearest centroid in each subspace
    for dmin in range(0, total_dims, dims_per_code):
        dmax = dmin + dims_per_code
        centroid_slice = centroids[:, dmin:dmax]
        # Rows index batch vectors, columns index centroids.
        similarities = torch.matmul(batch[:, dmin:dmax], centroid_slice.T)
        assignments = similarities.argmax(dim=1)
        quantized[:, dmin:dmax] = centroid_slice[assignments]

    return quantized
|
||
|
|
||
|
# Load the quantizer parameters: a codebook and a square transform matrix.
with open("opq.msgpack", "rb") as f:
    data = msgpack.unpack(f)

# Both entries are reshaped explicitly below, so they are presumably stored
# as flat sequences -- TODO confirm against whatever writes opq.msgpack.
centroid_shape = (output_codebook_size, n_dims)
transform_shape = (n_dims, n_dims)
centroids = torch.tensor(data["centroids"], device=device).reshape(*centroid_shape)
projection = torch.tensor(data["transform"], device=device).reshape(*transform_shape)
|
||
|
|
||
|
# Move the loaded arrays onto the selected device.
vectors = torch.tensor(dataset, device=device)
queries = torch.tensor(queryset, device=device)

# Quantize a small prefix of the dataset after applying the transform.
sample_size = 64
sample = vectors[:sample_size]
qsample = pq_assign(centroids, sample @ projection)
print(qsample)
print(sample)

# Score the sample against the first query, exactly and via the quantized
# vectors, then compare the two orderings.
query = queries[0]
exact_results = sample @ query
# NOTE(review): the data side uses `x @ projection` while the query side uses
# `projection @ q`; these only agree when the matrix is symmetric. Confirm the
# stored transform's convention (it may need `q @ projection`).
projected_query = projection @ query
approx_results = qsample @ projected_query

# np.argsort sorts ascending, so the highest-scoring indices are printed last.
print(np.argsort(approx_results))
print(np.argsort(exact_results))
|