Mirror of https://github.com/osmarks/meme-search-engine.git, synced 2025-10-25 12:37:40 +00:00.
			
		
		
		
	
		
			
				
	
	
		
			68 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			68 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import sqlite3
 | |
| import hashlib
 | |
| from collections import defaultdict
 | |
| import numpy
 | |
| import random
 | |
| import numpy as np
 | |
| 
 | |
# Module-wide connection to the local database. Rows come back as
# sqlite3.Row objects, so columns can be read by index or by name.
db = sqlite3.connect(database="data.sqlite3")
db.row_factory = sqlite3.Row

# Share of rated pairs to hold out for validation. is_val_set() applies half
# of it per meme, so the realised share of validation pairs is approximate.
val_fraction = 0.2
 | |
def is_val_set(meme1, meme2):
    """Decide whether the rated pair (meme1, meme2) belongs to the validation split.

    Each meme is placed in the validation bucket based on the first byte of the
    SHA-256 digest of its UTF-8 encoded name. The threshold is val_fraction / 2
    per meme. A pair is held out if EITHER meme is in the bucket, so a
    validation meme never appears in a training pair. The realised fractions
    only approximate val_fraction (the byte is scaled by 255 and pairs overlap);
    this is accepted as good enough.

    Returns True for validation pairs, False for training pairs.
    """
    threshold = val_fraction / 2

    def in_val_bucket(name):
        first_byte = hashlib.sha256(name.encode("utf-8")).digest()[0]
        return first_byte / 255 < threshold

    # any() short-circuits exactly like the `or` of two checks.
    return any(in_val_bucket(name) for name in (meme1, meme2))
 | |
| 
 | |
def fetch_embedding(filename):
    """Load the stored embedding for ``filename`` from the ``files`` table.

    The ``embedding`` column is decoded as raw float16 bytes into a 1-D numpy
    array. The result is a copy, so it is writable and owns its memory.

    Raises:
        KeyError: if there is no row for ``filename`` or its embedding is NULL.
    """
    csr = db.execute("SELECT embedding FROM files WHERE filename = ?", (filename,))
    try:
        row = csr.fetchone()
    finally:
        # Close the cursor even if fetching fails.
        csr.close()
    if row is None or row[0] is None:
        # Previously this surfaced as an opaque TypeError from subscripting None.
        raise KeyError(f"no embedding stored for {filename!r}")
    # numpy.frombuffer over an immutable bytes object yields a read-only array.
    # Copy it so consumers get a writable array; presumably this avoids PyTorch's
    # warning about non-writable NumPy arrays.
    return numpy.frombuffer(row[0], dtype="float16").copy()
 | |
| 
 | |
# Score for each rating code. Higher values favour meme 1 and 0.5 means no
# preference; the scale is symmetric around 0.5.
_RATING_SCORES = {
    "1": 0.9,   # meme 1 is better
    "1+": 0.7,
    "1p": 0.7,
    "eq": 0.5,  # no preference
    "2p": 0.3,
    "2+": 0.3,
    "2": 0.1,   # meme 2 is better
}


def map_rating(rating):
    """Convert a comma-separated rating string into an array of scores.

    Each comma-separated code (see ``_RATING_SCORES``) maps to one float, in
    order. For example, ``"1,eq"`` becomes ``array([0.9, 0.5])``.

    Raises:
        ValueError: if any code is not recognised. The message names the
            offending code.
    """
    scores = []
    for code in rating.split(","):
        try:
            scores.append(_RATING_SCORES[code])
        except KeyError:
            raise ValueError(f"invalid rating {code!r}, please fix") from None
    return np.array(scores)
 | |
| 
 | |
def _iteration_key(iteration):
    """Return the integer prefix before the first "-" of ``iteration``; NULL or empty gives 0."""
    return int((iteration and iteration.split("-")[0]) or "0")


def fetch_ratings(sets):
    """Load rated meme pairs, split them into train/validation, and group them by round.

    Each row of the ``ratings`` table has a round number, taken from the
    integer prefix of its ``iteration`` label (see ``_iteration_key``). A row
    is kept only if ``str(round) in sets``. ``sets`` is presumably a
    collection of round-number strings such as ``{"0", "1"}``; confirm with
    callers. Each distinct ``iteration`` label is printed once, in row order,
    as rows are read.

    Returns:
        ``(train_groups, val_groups)``. Each is a list ordered by ascending
        round number. Each entry is a list of ``(embedding1, embedding2, scores)``
        tuples built with ``fetch_embedding`` and ``map_rating``.
    """
    trains = defaultdict(list)
    validations = defaultdict(list)
    seen_iterations = set()

    csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
    try:
        rows = csr.fetchall()
    finally:
        csr.close()

    for meme1, meme2, rating, iteration in rows:
        if iteration not in seen_iterations:
            print(iteration)
            seen_iterations.add(iteration)
        key = _iteration_key(iteration)
        # Filter before the two embedding lookups and rating parsing. Excluded
        # rounds were previously loaded in full and then thrown away.
        if str(key) not in sets:
            continue
        bucket = validations if is_val_set(meme1, meme2) else trains
        bucket[key].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))

    return (
        [group for _, group in sorted(trains.items())],
        [group for _, group in sorted(validations.items())],
    )
 | |
| 
 | |
def generate_random_permutations(x, n):
    """Return ``n`` random permutations of ``x`` without modifying ``x``.

    A private copy of ``x`` is shuffled in place ``n`` times, and a snapshot is
    taken after each shuffle. For a given ``random`` module state, the output
    is identical to repeatedly shuffling ``x`` itself, minus the side effect.

    Args:
        x: a mutable sequence with a ``.copy()`` method (e.g. a list).
        n: number of permutations to produce; 0 gives ``[]``.

    Returns:
        A list of ``n`` independent copies, each a permutation of ``x``.
    """
    pool = x.copy()  # shuffle a copy so the caller's sequence is left untouched
    out = []
    for _ in range(n):
        random.shuffle(pool)
        out.append(pool.copy())
    return out
 | |
| 
 | |
def fetch_all_files():
    """List every file that has an embedding, paired with its decoded vector.

    Returns a list of ``(filename, embedding)`` tuples in query order. Each
    embedding is the stored float16 bytes decoded into a 1-D numpy array and
    copied, so it is writable and owns its memory.
    """
    cursor = db.execute("SELECT filename, embedding FROM files WHERE embedding IS NOT NULL")
    pairs = []
    for name, blob in cursor.fetchall():
        vector = numpy.frombuffer(blob, dtype="float16").copy()
        pairs.append((name, vector))
    cursor.close()
    return pairs
 | |
| 
 | |
def checkpoint_for(steps):
    """Return the ``(model_path, optimizer_path)`` checkpoint files for training step ``steps``."""
    model_path = f"./ckpt/model-{steps}.pt"
    optim_path = f"./ckpt/optim-{steps}.pt"
    return model_path, optim_path
 | 
