# Balanced clustering of precomputed vectors: loads float16 vectors from disk,
# fits n_clusters roughly equal-sized clusters, and writes the centroids to
# centroids.bin for downstream use.
import numpy as np
import msgpack
import math
import torch
from torch import autograd
from torch.nn.modules import distance

n_dims = 1152
# float16 vectors packed back-to-back; reshaped to (n_vectors, n_dims) and upcast for math.
data = np.fromfile("500k_vecs.bin", dtype=np.float16).reshape(-1, n_dims).astype(np.float32)
n_clusters = 42

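# Note (added, not in the original file): both routines below score vectors against
# L2-normalized centroids with a plain dot product, so the similarity scale (and hence
# the effective softmax temperature in partition_soft) depends on the norms of the
# input vectors. A quick, optional check that the embeddings are roughly unit-norm:
#
#   norms = np.linalg.norm(data, axis=1)
#   print(norms.min(), norms.mean(), norms.max())        # values near 1.0 mean already normalized
#   # data = data / np.clip(norms[:, None], 1e-12, None)  # uncomment to normalize if needed
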
def partition_soft(vectors, k, max_iter=100, batch_size=8192):
    """Soft, size-balanced clustering by gradient descent on a differentiable objective."""
    n_vectors = len(vectors)
    # Initialize centroids from a random sample of the data, plus a learnable
    # per-cluster bias which lets undersized clusters attract more vectors.
    perm = torch.randperm(n_vectors)
    centroids = vectors[perm[:k]]
    biases = torch.randn(k, device=vectors.device)
    centroids.requires_grad = True
    biases.requires_grad = True
    # Optimize both the centroid directions and the per-cluster biases.
    opt = torch.optim.Adam([centroids, biases], lr=0.01)
    temperature = 1.0
    size_scale = 15.0
    bias_scale = 1.0
    score_scale = 0.1

    desired_size = n_vectors / k

    with autograd.detect_anomaly():
        for it in range(max_iter):
            cluster_sizes = torch.zeros(k, device=vectors.device)
            norm_centroids = torch.nn.functional.normalize(centroids)
            score = torch.zeros(k, device=vectors.device)
            #soft_assignment_entropy = 0

            for i in range(0, n_vectors, batch_size):
                batch = vectors[i:i+batch_size]

                # Similarity of each vector in the batch to each (unit-norm) centroid.
                similarities = torch.matmul(batch, norm_centroids.T)

                # Soft (probabilistic) assignment of each vector across the k clusters.
                soft_assignments = ((similarities + biases) / temperature).softmax(dim=1)
                #print(soft_assignments[0])

                #print(soft_assignments.shape, similarities.shape)

                #entropy_by_vector = -soft_assignments.mul(soft_assignments.log2()).sum(dim=1)
                #soft_assignment_entropy += entropy_by_vector.mean()
                cluster_sizes += soft_assignments.sum(dim=0)
                score += similarities.mean(dim=0)

            opt.zero_grad()

            # Penalize clusters whose soft size deviates from n_vectors / k, keep the
            # biases small, and reward high similarity between vectors and centroids.
            distances_from_ideal_cluster_size = (cluster_sizes - desired_size).pow(2) / (desired_size ** 2)
            size_loss = distances_from_ideal_cluster_size.mean()
            bias_loss = biases.pow(2).mean()
            score_loss = -score.mean()
            loss = size_scale * size_loss + bias_scale * bias_loss + score_scale * score_loss
            loss.backward()
            opt.step()

            print(temperature, size_scale * size_loss.detach().tolist(), bias_scale * bias_loss.detach().tolist(), score_scale * score_loss.detach().tolist(), cluster_sizes.tolist())

            #centroids = new_centroids / cluster_sizes.unsqueeze(1)

            #if torch.allclose(centroids, new_centroids, rtol=1e-4):
            #    break

            # Annealing schedule: sharpen assignments and weight balance more over time.
            if it % 100 == 0:
                temperature *= 0.999
                size_scale *= 1.1

        return centroids.detach().cpu().numpy(), biases.detach().cpu().numpy()

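# Usage sketch (added; partition_soft is not invoked in the live pipeline below, so this
# is illustrative only and assumes a CUDA device):
#
#   vecs = torch.tensor(data, device=torch.device("cuda"))
#   cents, cbiases = partition_soft(vecs, n_clusters, max_iter=1000)
#   cents_t = torch.nn.functional.normalize(torch.tensor(cents, device=vecs.device))
#   hard_assignment = (vecs @ cents_t.T + torch.tensor(cbiases, device=vecs.device)).argmax(dim=1)
#   print(torch.bincount(hard_assignment, minlength=n_clusters))
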
SPILL_K = 2
def simulated_annealing(vectors, k, max_iter=100, batch_size=31768):
    """Randomly perturb centroids to minimize the worst deviation of top-SPILL_K
    assignment counts from the ideal cluster size."""
    n_vectors = len(vectors)
    centroids = torch.randn(k, n_dims, device=vectors.device)
    desired_size = n_vectors / k

    def fitness(centroids):
        # cluster_sizes[j, c] counts vectors whose (j+1)-th closest centroid is c.
        cluster_sizes = torch.zeros(SPILL_K, k, device=vectors.device, dtype=torch.int32)
        norm_centroids = torch.nn.functional.normalize(centroids)

        for i in range(0, n_vectors, batch_size):
            batch = vectors[i:i+batch_size]

            similarities = torch.matmul(batch, norm_centroids.T)
            values, indices = similarities.topk(SPILL_K, dim=1)

            for j in range(SPILL_K):
                batch_counts = torch.bincount(indices[:, j], minlength=k)
                cluster_sizes[j] += batch_counts

        # Lower is better: return the worst deviation from the ideal cluster size, plus
        # the index of the worst cluster at each spill level (used for rerolling).
        distances_from_ideal_cluster_size = (cluster_sizes - desired_size).abs()
        #print(distances_from_ideal_cluster_size)
        return distances_from_ideal_cluster_size.max(), distances_from_ideal_cluster_size.argmax(dim=1)

    global_best, global_best_result = None, 1000000

    temperature = 1.0

    last_fitness, _ = fitness(centroids)
    last_improvement = 0
    for _ in range(max_iter):
        # Propose a perturbed set of centroids; the step size scales with temperature.
        n = centroids + torch.randn_like(centroids) * temperature
        new_fitness, worst_centroid = fitness(n)
        print(last_fitness.tolist(), new_fitness.tolist(), temperature)
        if new_fitness < last_fitness:
            # Accept the improvement and cool down.
            centroids = n
            temperature *= 0.999
            last_fitness = new_fitness
            last_improvement = 0
        else:
            temperature *= 0.9995
            last_improvement += 1
        if last_improvement > 100:
            # Stuck: reinitialize the worst cluster(s) and heat back up.
            print("rerolling")
            centroids[worst_centroid] = torch.randn_like(centroids[worst_centroid])
            last_improvement = 0
            temperature *= 1.1
            last_fitness = new_fitness
        if last_fitness < desired_size * 0.1:
            break
        temperature = min(1.5, temperature)
        # Track the best configuration seen so far (not currently returned).
        if new_fitness < global_best_result:
            global_best = n
            global_best_result = new_fitness

    return torch.nn.functional.normalize(centroids)
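# Assignment sketch (added; how the centroids are consumed is assumed, not shown in this
# file). Given the normalized centroids returned above, each vector could be routed to
# its top-SPILL_K clusters like this:
#
#   cents = simulated_annealing(vecs, n_clusters)            # vecs: (n, n_dims) CUDA tensor
#   sims = vecs @ cents.T                                     # dot-product similarity per cluster
#   spill_assignments = sims.topk(SPILL_K, dim=1).indices     # (n, SPILL_K) cluster ids
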

# Disabled experiment: build n_clusters**2 fine-grained clusters and then greedily merge
# them into n_clusters balanced groups (BALANCE_WEIGHT trades match quality against size
# balance). Note that it calls `partition`, which is not defined in this file.
"""
centroids = list(zip(*partition(torch.tensor(data), n_clusters**2, max_iter=100)))

BALANCE_WEIGHT = 3e-2

big_clusters = [ ([x], c) for x, c in centroids[:n_clusters] ]
centroids = centroids[n_clusters:]

while centroids:
    avg_size = sum(c for _, c in big_clusters) / len(big_clusters)

    for i, (items, count) in enumerate(big_clusters):
        def match_score(x):
            return 1/len(items) * sum(np.dot(x, y) for y in items)

        candidate_index, candidate = max(enumerate(centroids), key=lambda x: match_score(x[1][0]) - BALANCE_WEIGHT * max(0, count + x[1][1] - avg_size))
        centroids.pop(candidate_index)
        big_clusters[i] = (items + [candidate[0]], count + candidate[1])

print([x[1] for x in big_clusters])

"""

# Entry point: run the annealing search on the GPU and write the centroids to disk
# as float16, matching the format of the input vectors.
centroids = simulated_annealing(torch.tensor(data, device=torch.device("cuda")), n_clusters, max_iter=80000).detach().cpu().numpy()
centroids.astype(np.float16).tofile("centroids.bin")
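# Consumption sketch (added; how centroids.bin is read downstream is an assumption, not
# shown here). The file can be loaded back the same way the input vectors are:
#
#   cents = np.fromfile("centroids.bin", dtype=np.float16).reshape(n_clusters, n_dims).astype(np.float32)
#   assignment = (data @ cents.T).argmax(axis=1)   # nearest-centroid bucket per vector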
