mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-10 22:09:54 +00:00
53 lines
1.8 KiB
Python
53 lines
1.8 KiB
Python
import sqlite3
|
|
import hashlib
|
|
from collections import defaultdict
|
|
import numpy
|
|
import random
|
|
|
|
# Shared SQLite connection used by every query helper in this module.
db = sqlite3.connect("data.sqlite3")
# Return rows as sqlite3.Row objects instead of plain tuples.
db.row_factory = sqlite3.Row

# Target share of rating pairs routed to the validation split (see is_val_set).
val_fraction = 0.2
|
|
def is_val_set(meme1, meme2, fraction=None):
    """Decide whether a rated pair belongs to the validation split.

    Each meme identifier is hashed with SHA-256 and the first digest byte,
    scaled by 1/255, is compared against half of the validation fraction.
    The pair counts as validation when *either* meme passes, so the result
    is deterministic and depends only on the two identifiers.

    Args:
        meme1: Identifier string of the first meme.
        meme2: Identifier string of the second meme.
        fraction: Target validation fraction. When omitted (``None``), the
            module-level ``val_fraction`` is used, matching the previous
            behaviour.

    Returns:
        True if the pair should go to the validation set, otherwise False.
    """
    if fraction is None:
        fraction = val_fraction
    # Halving the per-meme threshold only approximates the requested pair
    # fraction (not strictly correct but good enough).  The formula is kept
    # as-is so existing train/validation assignments do not shift.
    threshold = fraction / 2

    def is_one_val(meme):
        first_byte = hashlib.sha256(meme.encode("utf-8")).digest()[0]
        return first_byte / 255 < threshold

    return is_one_val(meme1) or is_one_val(meme2)
|
|
|
|
def fetch_embedding(filename):
    """Load the stored embedding vector for a single file.

    Args:
        filename: Value matched against the ``filename`` column of the
            ``files`` table.

    Returns:
        A writable one-dimensional ``float16`` NumPy array decoded from the
        row's ``embedding_vector`` blob.

    Raises:
        LookupError: If the ``files`` table has no row for ``filename``.
    """
    csr = db.execute("SELECT embedding_vector FROM files WHERE filename = ?", (filename,))
    try:
        row = csr.fetchone()
    finally:
        # Release the cursor even when the fetch itself fails.
        csr.close()
    if row is None:
        # fetchone() returns None for an empty result; fail with a clear
        # message instead of an opaque "'NoneType' is not subscriptable".
        raise LookupError(f"no embedding stored for file {filename!r}")
    # numpy.frombuffer returns a read-only view over the blob's bytes; copy it
    # so the array is writable (PyTorch complains about read-only arrays).
    return numpy.frombuffer(row[0], dtype="float16").copy()
|
|
|
|
def map_rating(rating, uncertainty=0.05):
    """Translate a stored rating label into a soft target value.

    Args:
        rating: ``"1"`` when meme 1 was judged better, ``"2"`` otherwise.
        uncertainty: Amount by which the target is pulled away from a hard
            0/1 label.

    Returns:
        ``1 - uncertainty`` for rating ``"1"`` and ``uncertainty`` for
        rating ``"2"``.

    Raises:
        ValueError: If ``rating`` is neither ``"1"`` nor ``"2"``.
    """
    if rating == "1":  # meme 1 is better
        return 1 - uncertainty
    if rating == "2":
        return uncertainty
    raise ValueError("invalid rating, please fix")
|
|
|
|
def fetch_ratings():
    """Load every rating and split it into training and validation samples.

    Each row of the ``ratings`` table becomes a tuple
    ``(embedding_of_meme1, embedding_of_meme2, target)`` and is grouped by its
    ``iteration`` value (a missing/empty iteration counts as 0).

    Returns:
        A pair ``(train, validation)``.  Each element is a list of per-iteration
        sample lists, ordered by ascending iteration number.
    """
    train_buckets = defaultdict(list)
    val_buckets = defaultdict(list)

    cursor = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
    rows = cursor.fetchall()
    for meme1, meme2, rating, iteration in rows:
        # Pick the split first, then the iteration bucket, then build the sample.
        target_buckets = val_buckets if is_val_set(meme1, meme2) else train_buckets
        bucket = target_buckets[int(iteration or "0")]
        sample = (fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating))
        bucket.append(sample)
    cursor.close()

    train = [train_buckets[key] for key in sorted(train_buckets)]
    validation = [val_buckets[key] for key in sorted(val_buckets)]
    return train, validation
|
|
|
|
def generate_random_permutations(x, n):
    """Produce ``n`` random permutations of ``x`` without modifying ``x``.

    Every permutation is obtained by shuffling a copy of the previous one
    (starting from ``x``), so for a given state of the ``random`` module the
    returned permutations are the same as shuffling ``x`` in place repeatedly,
    but the caller's sequence keeps its original order.

    Args:
        x: Mutable sequence providing ``copy()`` (e.g. a list).
        n: Number of permutations to generate.

    Returns:
        A list of ``n`` independent shuffled copies of ``x``.
    """
    out = []
    current = x
    for _ in range(n):
        # Copy before shuffling so neither the input nor an earlier
        # permutation is ever reordered in place.
        current = current.copy()
        random.shuffle(current)
        out.append(current)
    return out
|
|
|
|
def fetch_all_files():
    """Load every file together with its embedding.

    Returns:
        A list of ``(filename, embedding)`` tuples, one per row of the
        ``files`` table, where ``embedding`` is a ``float16`` NumPy array
        copied out of the stored ``embedding_vector`` blob.
    """
    cursor = db.execute("SELECT filename, embedding_vector FROM files")
    files = []
    for row in cursor.fetchall():
        # Copy so the array owns its data instead of viewing the blob bytes.
        embedding = numpy.frombuffer(row[1], dtype="float16").copy()
        files.append((row[0], embedding))
    cursor.close()
    return files
|
|
|
|
def checkpoint_for(steps):
    """Return the checkpoint paths for a given training step count.

    Args:
        steps: Step identifier interpolated into the file names.

    Returns:
        A ``(model_path, optimizer_path)`` tuple of strings under ``./ckpt/``.
    """
    model_path = f"./ckpt/model-{steps}.pt"
    optimizer_path = f"./ckpt/optim-{steps}.pt"
    return model_path, optimizer_path