mirror of https://github.com/osmarks/meme-search-engine.git
synced 2025-04-09 12:16:40 +00:00
new PQ training code
commit 2cebce1b73 (parent 4dd97631df)
@@ -11,26 +11,9 @@ output_code_size = 64
 output_code_bits = 8
 output_codebook_size = 2**output_code_bits
 n_dims_per_code = n_dims // output_code_size
-dataset = np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
-queryset = np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims)[:100000].astype(np.float32)
-device = "cpu"
-
-index = faiss.index_factory(n_dims, "HNSW32,SQfp16", faiss.METRIC_INNER_PRODUCT)
-index.train(queryset)
-index.add(queryset)
-print("index ready")
-
-T = 64
-
-nearby_query_indices = torch.zeros((dataset.shape[0], T), dtype=torch.int32)
-
-SEARCH_BATCH_SIZE = 1024
-
-for i in range(0, len(dataset), SEARCH_BATCH_SIZE):
-    res = index.search(dataset[i:i+SEARCH_BATCH_SIZE], T)
-    nearby_query_indices[i:i+SEARCH_BATCH_SIZE] = torch.tensor(res[1])
-
-print("query indices ready")
+dataset = np.random.permutation(np.fromfile("embeddings.bin", dtype=np.float16).reshape(-1, n_dims)).astype(np.float32)
+queryset = np.random.permutation(np.fromfile("query.bin", dtype=np.float16).reshape(-1, n_dims))[:100000].astype(np.float32)
+device = "cuda"
 
 def pq_assign(centroids, batch):
     quantized = torch.zeros_like(batch)
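Aside: the hunk above shows only the first line of pq_assign's body. A minimal sketch of a compatible implementation — hypothetical, assuming nearest-centroid assignment per n_dims_per_code-wide subvector; the real body in the repo may differ — using the script's module-level constants:

import torch

def pq_assign(centroids, batch):
    # centroids: (output_codebook_size, n_dims); batch: (B, n_dims).
    # Assign each subvector to its nearest codebook slice (L2 assumed here);
    # indexing into `codebook` keeps the result differentiable wrt centroids,
    # which is what lets loss.backward() below move them.
    quantized = torch.zeros_like(batch)
    for c in range(output_code_size):
        lo, hi = c * n_dims_per_code, (c + 1) * n_dims_per_code
        sub = batch[:, lo:hi]                       # (B, n_dims_per_code)
        codebook = centroids[:, lo:hi]              # (K, n_dims_per_code)
        codes = torch.cdist(sub, codebook).argmin(dim=1)
        quantized[:, lo:hi] = codebook[codes]
    return quantized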
@@ -47,28 +30,27 @@ def pq_assign(centroids, batch):
 # OOD-DiskANN (https://arxiv.org/abs/2211.12850) uses a more complicated scheme because it uses L2 norm
 # We only care about inner product, so our quantization error (wrt a query) is just abs(dot(query, centroid - vector))
 # Directly optimize for this (wrt top queries; it might actually be better to use a random sample instead?)
-def partition(vectors, centroids, projection, opt, queries, nearby_query_indices, k, max_iter=100, batch_size=4096):
+def partition(vectors, centroids, projection, opt, queries, k, max_iter=100, batch_size=4096, query_batch_size=2048):
     n_vectors = len(vectors)
-    perm = torch.randperm(n_vectors, device=device)
+    #perm = torch.randperm(n_vectors, device=device)
 
     t = tqdm.trange(max_iter)
     for iter in t:
         total_loss = 0
         opt.zero_grad(set_to_none=True)
 
+        # randomly sample queries (with replacement, probably fine)
+        queries_for_iteration = queries[torch.randint(0, len(queries), (query_batch_size,), device=device)]
+
         for i in range(0, n_vectors, batch_size):
             loss = torch.tensor(0.0, device=device)
             batch = vectors[i:i+batch_size] @ projection
             quantized = pq_assign(centroids, batch)
             residuals = batch - quantized
 
-            # for each index in our set of nearby queries
-            for j in range(0, nearby_query_indices.shape[1]):
-                queries_for_batch_j = queries[nearby_query_indices[i:i+batch_size, j]]
-                # minimize quantization error in direction of query, i.e. mean abs(dot(query, centroid - vector))
-                # PyTorch won't do batched dot products cleanly, to spite me. Do componentwise multiplication and reduce.
-                sg_errs = (queries_for_batch_j * residuals).sum(dim=-1)
-                loss += torch.mean(torch.abs(sg_errs))
+            batch_error = queries_for_iteration @ residuals.T
+
+            loss += torch.mean(torch.pow(batch_error, 2))
 
             total_loss += loss.detach().item()
             loss.backward()
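The replacement loss above needs no precomputed nearby_query_indices: a single matmul, queries_for_iteration @ residuals.T, yields every query-residual inner product at once, and the mean of their squares substitutes for the old per-column abs(dot(query, residual)) terms. A standalone toy check of that equivalence (sizes are illustrative):

import torch

Q, B, d = 4, 3, 8                              # hypothetical sizes for illustration
queries_for_iteration = torch.randn(Q, d)
residuals = torch.randn(B, d)

# one matmul yields all Q*B inner products dot(query, residual)
batch_error = queries_for_iteration @ residuals.T          # (Q, B)
loss = torch.mean(torch.pow(batch_error, 2))

# the same quantity via the elementwise multiply-and-reduce the old code used
per_query = torch.stack([(q * residuals).sum(dim=-1) for q in queries_for_iteration])
assert torch.allclose(loss, per_query.pow(2).mean())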
@@ -90,10 +72,10 @@ queries = torch.tensor(queryset, device=device)
 perm = torch.randperm(len(vectors), device=device)
 centroids = vectors[perm[:output_codebook_size]]
 centroids.requires_grad = True
-opt = torch.optim.Adam([centroids], lr=0.001)
+opt = torch.optim.Adam([centroids], lr=0.0005)
 for i in range(30):
     # update centroids to minimize query-aware quantization loss
-    partition(vectors, centroids, projection, opt, queries, nearby_query_indices, output_codebook_size, max_iter=8)
+    partition(vectors, centroids, projection, opt, queries, output_codebook_size, max_iter=300)
     # compute new projection as R = VU^T from XY^T = USV^T (SVD)
     # where X is dataset vectors, Y is quantized dataset vectors
     with torch.no_grad():
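The loop body continued in the next hunk is an OPQ-style rotation refit: per the comments above, the orthogonal R comes from the SVD of the cross-covariance between the dataset and its quantization. A standalone sketch mirroring the script's computation on toy tensors (names and sizes are illustrative, not from the repo):

import torch

X = torch.randn(1000, 64)                      # stand-in for the dataset vectors
Y = torch.randn(1000, 64)                      # stand-in for their quantization
u, s, vt = torch.linalg.svd(X.T @ Y)
R = vt.T @ u.T                                 # the projection the script computes
# R is orthogonal (a product of orthogonal factors), so applying it rotates
# the vectors without distorting inner products
assert torch.allclose(R.T @ R, torch.eye(64), atol=1e-4)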
@@ -102,12 +84,12 @@ for i in range(30):
         u, s, vt = torch.linalg.svd(vectors.T @ y)
         projection = vt.T @ u.T
 
-print("done")
-with open("opq.msgpack", "wb") as f:
-    msgpack.pack({
-        "centroids": centroids.detach().cpu().numpy().flatten().tolist(),
-        "transform": projection.cpu().numpy().flatten().tolist(),
-        "n_dims_per_code": n_dims_per_code,
-        "n_dims": n_dims
-    }, f)
+with open("opq.msgpack", "wb") as f:
+    msgpack.pack({
+        "centroids": centroids.detach().cpu().numpy().flatten().tolist(),
+        "transform": projection.cpu().numpy().flatten().tolist(),
+        "n_dims_per_code": n_dims_per_code,
+        "n_dims": n_dims
+    }, f)
+print("done")