mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-02-22 22:10:10 +00:00
63 lines
1.8 KiB
Python
63 lines
1.8 KiB
Python
![]() |
import torch.nn
|
||
|
import torch.nn.functional as F
|
||
|
import torch
|
||
|
import sqlite3
|
||
|
import random
|
||
|
import numpy
|
||
|
import json
|
||
|
import time
|
||
|
from tqdm import tqdm
|
||
|
import sys
|
||
|
from collections import defaultdict
|
||
|
|
||
|
from model import Config, BradleyTerry
|
||
|
import shared
|
||
|
|
||
|
batch_size = 128
|
||
|
num_pairs = batch_size * 1024
|
||
|
device = "cuda"
|
||
|
|
||
|
config = Config(
|
||
|
d_emb=1152,
|
||
|
n_hidden=1,
|
||
|
n_ensemble=16,
|
||
|
device=device,
|
||
|
dtype=torch.float32,
|
||
|
output_channels=3,
|
||
|
dropout=0.1
|
||
|
)
|
||
|
model = BradleyTerry(config)
|
||
|
modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
|
||
|
model.load_state_dict(torch.load(modelc))
|
||
|
params = sum(p.numel() for p in model.parameters())
|
||
|
print(f"{params/1e6:.1f}M parameters")
|
||
|
print(model)
|
||
|
|
||
|
files = shared.fetch_all_files()
|
||
|
results = {}
|
||
|
model.eval()
|
||
|
|
||
|
with torch.inference_mode():
|
||
|
for bstart in tqdm(range(0, len(files), batch_size)):
|
||
|
batch = files[bstart:bstart + batch_size]
|
||
|
filenames = [ f1 for f1, e1 in batch ]
|
||
|
embs = torch.stack([ torch.Tensor(e1).to(config.dtype) for f1, e1 in batch ])
|
||
|
inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
|
||
|
scores = model.ensemble(inputs).median(dim=0).values.cpu().numpy()
|
||
|
|
||
|
|
||
|
channel = int(sys.argv[2])
|
||
|
percentile = float(sys.argv[3])
|
||
|
output_pairs = int(sys.argv[4])
|
||
|
mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
|
||
|
top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
|
||
|
select_from = top[:int(len(top) * percentile)]
|
||
|
|
||
|
out = []
|
||
|
for _ in range(output_pairs):
|
||
|
# dummy score for compatibility with existing code
|
||
|
out.append(((random.choice(select_from)[0], random.choice(select_from)[0]), 0))
|
||
|
|
||
|
with open("top.json", "w") as f:
|
||
|
json.dump(out, f)
|