mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-04-27 04:53:11 +00:00
68 lines
2.3 KiB
Python
68 lines
2.3 KiB
Python
import sqlite3
|
|
import hashlib
|
|
from collections import defaultdict
|
|
import numpy
|
|
import random
|
|
import numpy as np
|
|
|
|
db = sqlite3.connect("data.sqlite3")
|
|
db.row_factory = sqlite3.Row
|
|
|
|
val_fraction = 0.2
|
|
def is_val_set(meme1, meme2):
|
|
def is_one_val(meme):
|
|
return hashlib.sha256(meme.encode("utf-8")).digest()[0] / 255 < (val_fraction / 2) # not strictly correct but good enough
|
|
return is_one_val(meme1) or is_one_val(meme2)
|
|
|
|
def fetch_embedding(filename):
|
|
csr = db.execute("SELECT embedding FROM files WHERE filename = ?", (filename,))
|
|
x = numpy.frombuffer(csr.fetchone()[0], dtype="float16")
|
|
csr.close()
|
|
return x.copy() # PyTorch complains otherwise due to bad
|
|
|
|
def map_rating(rating):
|
|
def map_one(rating):
|
|
match rating:
|
|
case "1": # meme 1 is better
|
|
return 0.9
|
|
case "2":
|
|
return 0.1
|
|
case "2+" | "2p":
|
|
return 0.3
|
|
case "1+" | "1p":
|
|
return 0.7
|
|
case "eq":
|
|
return 0.5
|
|
case _: raise ValueError("invalid rating, please fix")
|
|
|
|
return np.array([map_one(r) for r in rating.split(",")])
|
|
|
|
def fetch_ratings(sets):
|
|
trains = defaultdict(list)
|
|
validations = defaultdict(list)
|
|
csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
|
|
its = set()
|
|
for meme1, meme2, rating, iteration in csr.fetchall():
|
|
if iteration not in its:
|
|
print(iteration)
|
|
its.add(iteration)
|
|
(validations if is_val_set(meme1, meme2) else trains)[int((iteration and iteration.split("-")[0]) or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
|
|
csr.close()
|
|
return list(x[1] for x in sorted(trains.items()) if str(x[0]) in sets), list(x[1] for x in sorted(validations.items()) if str(x[0]) in sets)
|
|
|
|
def generate_random_permutations(x, n):
|
|
out = []
|
|
for _ in range(n):
|
|
random.shuffle(x)
|
|
out.append(x.copy())
|
|
return out
|
|
|
|
def fetch_all_files():
|
|
csr = db.execute("SELECT filename, embedding FROM files WHERE embedding IS NOT NULL")
|
|
x = [ (row[0], numpy.frombuffer(row[1], dtype="float16").copy()) for row in csr.fetchall() ]
|
|
csr.close()
|
|
return x
|
|
|
|
def checkpoint_for(steps):
|
|
return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt"
|