# kmeans.py, from meme-search-engine (mirror of https://github.com/osmarks/meme-search-engine.git)

import numpy as np
import msgpack
import math
import torch
from torch import autograd
from torch.nn.modules import distance
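
# This script clusters a raw float16 dump of embedding vectors into n_clusters
# approximately equal-sized clusters and writes the resulting centroids to centroids.bin.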
n_dims = 1152  # dimensionality of each embedding vector
data = np.fromfile("500k_vecs.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32)  # one vector per row
n_clusters = 42
def partition_soft(vectors, k, max_iter=100, batch_size=8192):
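    """
    Soft k-means-style partitioning by gradient descent: centroids and per-cluster
    biases are optimized with Adam against a loss trading off cluster balance,
    small biases and high vector-to-centroid similarity.
    Returns (centroids, biases) as numpy arrays.
    """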
    n_vectors = len(vectors)
    # initialize centroids from a random subset of the input vectors
    perm = torch.randperm(n_vectors)
    centroids = vectors[perm[:k]]
    biases = torch.randn(k, device=vectors.device)
    centroids.requires_grad = True
    biases.requires_grad = True
    opt = torch.optim.Adam([centroids, biases], lr=0.01)  # include biases so the bias penalty below has an effect
    temperature = 1.0
    size_scale = 15.0
    bias_scale = 1.0
    score_scale = 0.1
    desired_size = n_vectors / k
    with autograd.detect_anomaly():
        for it in range(max_iter):
            cluster_sizes = torch.zeros(k, device=vectors.device)
            norm_centroids = torch.nn.functional.normalize(centroids)
            score = torch.zeros(k, device=vectors.device)
            #soft_assignment_entropy = 0
            # accumulate soft cluster sizes and mean similarities over batches
            for i in range(0, n_vectors, batch_size):
                batch = vectors[i:i+batch_size]
                similarities = torch.matmul(batch, norm_centroids.T)
                soft_assignments = ((similarities + biases) / temperature).softmax(dim=1)
                #print(soft_assignments[0])
                #print(soft_assignments.shape, similarities.shape)
                #entropy_by_vector = -soft_assignments.mul(soft_assignments.log2()).sum(dim=1)
                #soft_assignment_entropy += entropy_by_vector.mean()
                cluster_sizes += soft_assignments.sum(dim=0)
                score += similarities.mean(dim=0)
            opt.zero_grad()
            # penalize deviation from perfectly balanced clusters, large biases, and low similarity
            distances_from_ideal_cluster_size = (cluster_sizes - desired_size).pow(2) / (desired_size ** 2)
            size_loss = distances_from_ideal_cluster_size.mean()
            bias_loss = biases.pow(2).mean()
            score_loss = -score.mean()
            loss = size_scale * size_loss + bias_scale * bias_loss + score_scale * score_loss
            loss.backward()
            opt.step()
            print(temperature, size_scale * size_loss.detach().tolist(), bias_scale * bias_loss.detach().tolist(), score_scale * score_loss.detach().tolist(), cluster_sizes.tolist())
            #centroids = new_centroids / cluster_sizes.unsqueeze(1)
            #if torch.allclose(centroids, new_centroids, rtol=1e-4):
            #    break
            if it % 100 == 0:
                # anneal: sharpen assignments and weight cluster balance more heavily over time
                temperature *= 0.999
                size_scale *= 1.1
    return centroids.detach().cpu().numpy(), biases.detach().cpu().numpy()
SPILL_K = 2  # each vector counts towards its top-SPILL_K nearest centroids ("spilling")
def simulated_annealing(vectors, k, max_iter=100, batch_size=31768):
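    """
    Balances clusters by simulated annealing: random Gaussian perturbations of the
    centroids are accepted when they reduce the worst deviation from the ideal
    cluster size (counting each vector towards its top-SPILL_K centroids), with the
    perturbation scale (temperature) gradually decaying and persistently unimproved
    runs triggering a reroll of the worst centroids.
    Returns the normalized centroids as a tensor.
    """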
    n_vectors = len(vectors)
    centroids = torch.randn(k, n_dims, device=vectors.device)
    desired_size = n_vectors / k
    def fitness(centroids):
        # worst absolute deviation of any (spill level, cluster) count from the ideal
        # size, plus the index of the worst cluster at each spill level
        cluster_sizes = torch.zeros(SPILL_K, k, device=vectors.device, dtype=torch.int32)
        norm_centroids = torch.nn.functional.normalize(centroids)
        for i in range(0, n_vectors, batch_size):
            batch = vectors[i:i+batch_size]
            similarities = torch.matmul(batch, norm_centroids.T)
            values, indices = similarities.topk(SPILL_K, dim=1)
            for j in range(SPILL_K):
                batch_counts = torch.bincount(indices[:, j], minlength=k)
                cluster_sizes[j] += batch_counts
        distances_from_ideal_cluster_size = (cluster_sizes - desired_size).abs()
        #print(distances_from_ideal_cluster_size)
        return distances_from_ideal_cluster_size.max(), distances_from_ideal_cluster_size.argmax(dim=1)
    global_best, global_best_result = None, 1000000
    temperature = 1.0
    last_fitness, _ = fitness(centroids)
    last_improvement = 0
    for _ in range(max_iter):
        # propose a Gaussian perturbation of the centroids, scaled by the temperature
        n = centroids + torch.randn_like(centroids) * temperature
        new_fitness, worst_centroid = fitness(n)
        print(last_fitness.tolist(), new_fitness.tolist(), temperature)
        if new_fitness < last_fitness:
            centroids = n
            temperature *= 0.999
            last_fitness = new_fitness
            last_improvement = 0
        else:
            temperature *= 0.9995
            last_improvement += 1
            if last_improvement > 100:
                # stuck: reinitialize the worst centroids and heat back up
                print("rerolling")
                centroids[worst_centroid] = torch.randn_like(centroids[worst_centroid])
                last_improvement = 0
                temperature *= 1.1
                last_fitness = new_fitness
        if last_fitness < desired_size * 0.1:
            break
        temperature = min(1.5, temperature)
        if new_fitness < global_best_result:
            global_best = n
            global_best_result = new_fitness
    return torch.nn.functional.normalize(centroids)
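
# Disabled earlier approach, kept below as a string literal: greedily merge n_clusters**2
# small clusters (from a `partition` function not defined in this file) into n_clusters
# balanced groups by similarity, with a penalty on oversized groups.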
""""
centroids = list(zip(*partition(torch.tensor(data), n_clusters**2, max_iter=100)))
BALANCE_WEIGHT = 3e-2
big_clusters = [ ([x], c) for x, c in centroids[:n_clusters] ]
centroids = centroids[n_clusters:]
while centroids:
avg_size = sum(c for _, c in big_clusters) / len(big_clusters)
for i, (items, count) in enumerate(big_clusters):
def match_score(x):
return 1/len(items) * sum(np.dot(x, y) for y in items)
candidate_index, candidate = max(enumerate(centroids), key=lambda x: match_score(x[1][0]) - BALANCE_WEIGHT * max(0, count + x[1][1] - avg_size))
centroids.pop(candidate_index)
big_clusters[i] = (items + [candidate[0]], count + candidate[1])
print([x[1] for x in big_clusters])
"""
centroids = simulated_annealing(torch.tensor(data, device=torch.device("cuda")), n_clusters, max_iter=80000).detach().cpu().numpy()
centroids.astype(np.float16).tofile("centroids.bin")
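
# A minimal sketch (not part of the original script) of how the saved centroids could be
# used: assign each vector to its SPILL_K most similar centroids by dot product, mirroring
# the fitness() logic above. `loaded` and `assignments` are illustrative names.
#loaded = np.fromfile("centroids.bin", dtype=np.float16).reshape(n_clusters, n_dims).astype(np.float32)
#_, assignments = torch.matmul(torch.tensor(data), torch.tensor(loaded).T).topk(SPILL_K, dim=1)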