# mirror of https://github.com/osmarks/meme-search-engine.git (synced 2025-04-28)
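# Partitions a set of embedding vectors into a fixed number of roughly
# equal-sized clusters and writes the resulting centroids to disk as float16.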
import numpy as np
import torch
from torch import autograd

n_dims = 1152  # dimensionality of each embedding vector
data = np.fromfile("500k_vecs.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32)
n_clusters = 42  # number of partitions to produce

def partition_soft(vectors, k, max_iter=100, batch_size=8192):
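    """Differentiable balanced clustering (intent inferred from the code):
    soft-assign vectors to k centroids via a temperature-scaled softmax over
    dot-product similarity plus a learned per-cluster bias, then minimise a
    weighted sum of a cluster-size balance penalty, a bias regulariser and a
    negated mean-similarity score with Adam."""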
    n_vectors = len(vectors)
    perm = torch.randperm(n_vectors)
    centroids = vectors[perm[:k]]  # initialise centroids from k random vectors
    biases = torch.randn(k, device=vectors.device)
    centroids.requires_grad = True
    biases.requires_grad = True
    # The biases feed into both the assignments and bias_loss, so they must be
    # optimised along with the centroids.
    opt = torch.optim.Adam([centroids, biases], lr=0.01)
    temperature = 1.0
    size_scale = 15.0
    bias_scale = 1.0
    score_scale = 0.1

    desired_size = n_vectors / k

    with autograd.detect_anomaly():
        for it in range(max_iter):
            cluster_sizes = torch.zeros(k, device=vectors.device)
            norm_centroids = torch.nn.functional.normalize(centroids)
            score = torch.zeros(k, device=vectors.device)

            for i in range(0, n_vectors, batch_size):
                batch = vectors[i:i+batch_size]

                similarities = torch.matmul(batch, norm_centroids.T)

                # Soft assignment over biased, temperature-scaled similarities.
                soft_assignments = ((similarities + biases) / temperature).softmax(dim=1)

                cluster_sizes += soft_assignments.sum(dim=0)
                score += similarities.mean(dim=0)

            opt.zero_grad()

            # Relative squared deviation from perfectly balanced clusters.
            distances_from_ideal_cluster_size = (cluster_sizes - desired_size).pow(2) / (desired_size ** 2)
            size_loss = distances_from_ideal_cluster_size.mean()
            bias_loss = biases.pow(2).mean()
            score_loss = -score.mean()
            loss = size_scale * size_loss + bias_scale * bias_loss + score_scale * score_loss
            loss.backward()
            opt.step()

            print(temperature, size_scale * size_loss.detach().tolist(), bias_scale * bias_loss.detach().tolist(), score_scale * score_loss.detach().tolist(), cluster_sizes.tolist())

            # Anneal: cool the softmax and tighten the balance penalty over time.
            if it % 100 == 0:
                temperature *= 0.999
                size_scale *= 1.1

    return centroids.detach().cpu().numpy(), biases.detach().cpu().numpy()

SPILL_K = 2  # each vector counts toward its top-SPILL_K centroids ("spill" assignment)

def simulated_annealing(vectors, k, max_iter=100, batch_size=31768):
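    """Derive k unit-norm centroids by simulated annealing: perturb the
    centroids with Gaussian noise scaled by a temperature, keep the proposal
    when it reduces the worst cluster-size imbalance, and reroll the most
    oversized centroids when progress stalls."""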
    n_vectors = len(vectors)
    centroids = torch.randn(k, n_dims, device=vectors.device)
    desired_size = n_vectors / k

    def fitness(centroids):
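        """Return (largest absolute deviation of any cluster's size from the
        ideal size, index of the worst cluster at each spill rank)."""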
        cluster_sizes = torch.zeros(SPILL_K, k, device=vectors.device, dtype=torch.int32)
        norm_centroids = torch.nn.functional.normalize(centroids)

        for i in range(0, n_vectors, batch_size):
            batch = vectors[i:i+batch_size]

            similarities = torch.matmul(batch, norm_centroids.T)
            values, indices = similarities.topk(SPILL_K, dim=1)

            # Count assignments separately at each spill rank.
            for j in range(SPILL_K):
                batch_counts = torch.bincount(indices[:, j], minlength=k)
                cluster_sizes[j] += batch_counts

        distances_from_ideal_cluster_size = (cluster_sizes - desired_size).abs()
        return distances_from_ideal_cluster_size.max(), distances_from_ideal_cluster_size.argmax(dim=1)

    global_best, global_best_result = None, float("inf")

    temperature = 1.0

    last_fitness, _ = fitness(centroids)
    last_improvement = 0
    for _ in range(max_iter):
        # Propose a Gaussian perturbation scaled by the current temperature.
        n = centroids + torch.randn_like(centroids) * temperature
        new_fitness, worst_centroid = fitness(n)
        print(last_fitness.tolist(), new_fitness.tolist(), temperature)
        if new_fitness < last_fitness:
            centroids = n
            temperature *= 0.999
            last_fitness = new_fitness
            last_improvement = 0
        else:
            temperature *= 0.9995
            last_improvement += 1
            if last_improvement > 100:
                # Stalled: reinitialise the worst centroid at each spill rank
                # and heat the system back up.
                print("rerolling")
                centroids[worst_centroid] = torch.randn_like(centroids[worst_centroid])
                last_improvement = 0
                temperature *= 1.1
                last_fitness = new_fitness
        if last_fitness < desired_size * 0.1:
            break
        temperature = min(1.5, temperature)
        if new_fitness < global_best_result:
            global_best = n
            global_best_result = new_fitness

    # Return the best configuration seen during the search.
    if global_best is not None:
        centroids = global_best
    return torch.nn.functional.normalize(centroids)

""""
|
|
centroids = list(zip(*partition(torch.tensor(data), n_clusters**2, max_iter=100)))
|
|
|
|
BALANCE_WEIGHT = 3e-2
|
|
|
|
big_clusters = [ ([x], c) for x, c in centroids[:n_clusters] ]
|
|
centroids = centroids[n_clusters:]
|
|
|
|
while centroids:
|
|
avg_size = sum(c for _, c in big_clusters) / len(big_clusters)
|
|
|
|
for i, (items, count) in enumerate(big_clusters):
|
|
def match_score(x):
|
|
return 1/len(items) * sum(np.dot(x, y) for y in items)
|
|
|
|
candidate_index, candidate = max(enumerate(centroids), key=lambda x: match_score(x[1][0]) - BALANCE_WEIGHT * max(0, count + x[1][1] - avg_size))
|
|
centroids.pop(candidate_index)
|
|
big_clusters[i] = (items + [candidate[0]], count + candidate[1])
|
|
|
|
print([x[1] for x in big_clusters])
|
|
|
|
"""
# Run the annealing search on the GPU and persist the centroids as float16.
centroids = simulated_annealing(torch.tensor(data, device=torch.device("cuda")), n_clusters, max_iter=80000).detach().cpu().numpy()
centroids.astype(np.float16).tofile("centroids.bin")