mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-14 07:44:49 +00:00
76 lines
2.4 KiB
Python
76 lines
2.4 KiB
Python
import torch.nn
|
|
import torch.nn.functional as F
|
|
import torch
|
|
import sqlite3
|
|
import random
|
|
import numpy
|
|
import json
|
|
import time
|
|
from tqdm import tqdm
|
|
from torch.func import functional_call, vmap, grad
|
|
|
|
from model import Config, BradleyTerry
|
|
import shared
|
|
|
|
# Step count of the run; presumably identifies the checkpoint to analyse
# (the same value is what gets handed to shared.checkpoint_for).
steps = 855

# Number of pairs scored per forward/backward pass.
batch_size = 128

# Total number of random pairs to score: 1024 full batches.
num_pairs = 1024 * batch_size

# Torch device on which the model and all batches live.
device = "cuda"
|
|
|
|
# Model hyperparameters; presumably these have to match the checkpoint that is
# loaded below -- verify against the training script.
config = Config(
    d_emb=1152,
    n_hidden=1,
    n_ensemble=1,
    device=device,
    dtype=torch.bfloat16,
)
model = BradleyTerry(config)

# Look the checkpoint up via the `steps` constant rather than repeating the
# literal, so the two cannot drift apart.
modelc, _ = shared.checkpoint_for(steps)
# map_location puts the stored tensors on the configured device regardless of
# which device they were saved from.
model.load_state_dict(torch.load(modelc, map_location=device))

# Kept under its own name: `params` is used later in the file for the
# name -> tensor mapping, not for this count.
param_count = sum(p.numel() for p in model.parameters())
print(f"{param_count/1e6:.1f}M parameters")
print(model)
|
|
|
|
# Everything known to the index; entries are later unpacked as
# (filename, embedding) pairs by the batching loop.
files = shared.fetch_all_files()

# Maps a scored pair to its importance value; filled in by the scoring loop.
importance = dict()

# Detached views of the model state, in the name -> tensor form that
# functional_call() expects.
params = {name: tensor.detach() for name, tensor in model.named_parameters()}
buffers = {name: tensor.detach() for name, tensor in model.named_buffers()}
|
|
|
|
# Per-sample gradient recipe, see
# https://pytorch.org/tutorials/intermediate/per_sample_grads.html
def compute_loss(params, buffers, sample, target):
    """Return the binary cross-entropy loss of the model on one sample.

    The module-level ``model`` is evaluated statelessly with the supplied
    ``params`` and ``buffers`` so that the loss can be differentiated with
    respect to ``params``.

    Args:
        params: mapping of parameter name to tensor.
        buffers: mapping of buffer name to tensor.
        sample: a single input without a batch dimension; a leading
            dimension of size 1 is added before the model is called.
        target: the target for ``sample``, also without a batch dimension.

    Returns:
        The loss produced by ``F.binary_cross_entropy``.
    """
    predicted = functional_call(
        model,
        (params, buffers),
        (sample.unsqueeze(0),),
    )
    return F.binary_cross_entropy(predicted, target.unsqueeze(0))
|
|
|
|
# Gradient of the single-sample loss; grad() differentiates with respect to
# the first positional argument, i.e. the parameter mapping.
ft_compute_grad = grad(compute_loss)

# Vectorise that gradient over a batch: params and buffers are shared
# (None), while samples and targets are both split along dimension 1.
ft_compute_sample_grad = vmap(
    ft_compute_grad,
    in_dims=(None, None, 1, 1),
)
|
|
|
|
# Draw num_pairs random pairs of two distinct entries from `files`.
pairs = [tuple(random.sample(files, 2)) for _ in range(num_pairs)]
|
|
|
|
# Score every pair by the size of the gradient it induces on the model.
# As the original note says: don't take variance, do a backwards pass and
# compute the gradient norm.

# Every pair is given the same target of 0.95, so the target tensor is built
# and moved to the device once instead of on every iteration.
all_targets = torch.full((1, batch_size), 0.95).to(device)

for bstart in tqdm(range(0, len(pairs), batch_size)):
    batch = pairs[bstart:bstart + batch_size]
    # Size everything by the actual batch length so that a short final batch
    # (len(pairs) not a multiple of batch_size) still has consistent shapes.
    n = len(batch)

    # Each pair is ((filename1, embedding1), (filename2, embedding2)).
    filenames = [(f1, f2) for ((f1, _e1), (f2, _e2)) in batch]
    embs = torch.stack([
        torch.stack((
            torch.Tensor(e1).to(config.dtype),
            torch.Tensor(e2).to(config.dtype),
        ))
        for ((_f1, e1), (_f2, e2)) in batch
    ])
    # Add the leading ensemble dimension: (n_ensemble, n, 2, d_emb).
    inputs = embs.unsqueeze(0).expand((config.n_ensemble, n, 2, config.d_emb)).to(device)

    # Mapping of parameter name -> per-sample gradients; dimension 0 of each
    # value indexes the pair within the batch.
    grads = ft_compute_sample_grad(params, buffers, inputs, all_targets[:, :n])

    # Accumulate, per pair, the sum of the gradient norms of each parameter
    # tensor (norm taken over all non-batch dimensions).
    total_grad_norms = torch.zeros(n, device=device)
    for v in grads.values():
        param_dims = tuple(range(1, len(v.shape)))
        total_grad_norms += torch.linalg.vector_norm(v, dim=param_dims)
    tgn = total_grad_norms.cpu().numpy()

    # A pair drawn more than once simply keeps its most recent score.
    for filename, tg in zip(filenames, tgn):
        importance[filename] = float(tg)
|
|
|
|
# Rank the scored pairs from highest to lowest importance and save the first
# 256 entries to top.json.
top = sorted(importance.items(), key=lambda entry: -entry[1])
best = top[:256]
with open("top.json", "w") as out:
    json.dump(best, out)