1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-02-22 05:50:05 +00:00

75 lines
2.1 KiB
Python
Raw Permalink Normal View History

2025-01-18 11:29:03 +00:00
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
import sys
from collections import defaultdict
2025-01-18 17:09:00 +00:00
import msgpack
2025-01-18 11:29:03 +00:00
from model import Config, BradleyTerry
import shared
2025-01-18 17:09:00 +00:00
def fetch_files_with_timestamps():
    """Load every row of the ``files`` table that has an embedding stored.

    Returns:
        A list of ``(filename, embedding, timestamp)`` tuples read through
        ``shared.db``. Each embedding blob is decoded as a float16 numpy
        array.
    """
    csr = shared.db.execute("SELECT filename, embedding, timestamp FROM files WHERE embedding IS NOT NULL")
    try:
        # Copy each decoded array so it owns its data instead of remaining a
        # view over the blob returned by the database.
        return [
            (filename, numpy.frombuffer(blob, dtype="float16").copy(), timestamp)
            for filename, blob, timestamp in csr.fetchall()
        ]
    finally:
        # Release the cursor even if decoding one of the rows raises.
        csr.close()
# Number of files scored per forward pass.
batch_size = 2048
device = "cuda"
config = Config(
    d_emb=1152,
    n_hidden=1,
    n_ensemble=16,
    device=device,
    dtype=torch.float32,
    output_channels=3,
    dropout=0.1
)
model = BradleyTerry(config)
# The checkpoint is selected by the integer passed as the first CLI argument.
modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
# map_location makes loading independent of the device the checkpoint was
# saved from.
model.load_state_dict(torch.load(modelc, map_location=device))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)
for member in model.ensemble.models:
    member.output.bias.data.fill_(0) # hack to match behaviour of cut-down implementation
# Maps channel index -> list of values collected for that channel.
results = defaultdict(list)
# This script only runs inference, so switch to evaluation mode once here.
model.eval()
2025-01-18 17:09:00 +00:00
# Run the ensemble over every stored embedding, batch by batch, and collect
# the averaged outputs per channel in `results`.
files = fetch_files_with_timestamps()
with torch.inference_mode():
    for offset in tqdm(range(0, len(files), batch_size)):
        chunk = files[offset:offset + batch_size]
        chunk_timestamps = []
        chunk_embeddings = []
        for _, embedding, timestamp in chunk:
            chunk_timestamps.append(timestamp)
            chunk_embeddings.append(torch.Tensor(embedding).to(config.dtype))
        stacked = torch.stack(chunk_embeddings)
        # Replicate the batch along a new leading axis of size n_ensemble.
        ensemble_input = stacked.unsqueeze(0).expand((config.n_ensemble, len(chunk), config.d_emb)).to(device)
        # Average over the leading axis and bring the result back as numpy.
        mean_scores = model.ensemble(ensemble_input).mean(dim=0).cpu().numpy()
        # assumes mean_scores is 2-D (files x channels) — TODO confirm
        for channel, column in enumerate(mean_scores.T):
            results[channel].extend(column)
        # add an extra timestamp channel
        results[config.output_channels].extend(chunk_timestamps)
2025-01-18 11:29:03 +00:00
2025-01-18 17:09:00 +00:00
# we want to encode scores in one byte, and 255/0xFF is reserved for "greater than maximum bucket"
cdf_bins = 255
# cdf_bins evenly spaced probabilities from 0 to 1, both endpoints included.
probabilities = numpy.linspace(0, 1, cdf_bins)
cdfs = []
for channel_values in results.values():
    channel_cdf = numpy.quantile(numpy.array(channel_values), probabilities)
    print(channel_cdf)
    cdfs.append(channel_cdf.tolist())

# Write one quantile table per channel, in the order the channels were added
# to `results`.
with open("cdfs.msgpack", "wb") as out_file:
    msgpack.pack(cdfs, out_file)