import torch.nn.functional as F
import torch
import sys
from safetensors.torch import save_file

from model import Config, BradleyTerry
import shared

device = "cpu"
|
|
|
|
config = Config(
|
|
d_emb=1152,
|
|
n_hidden=1,
|
|
n_ensemble=16,
|
|
device=device,
|
|
dtype=torch.float32,
|
|
output_channels=3,
|
|
dropout=0.1
|
|
)
|
|
model = BradleyTerry(config)
|
|
model.eval()
|
|
modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
|
|
model.load_state_dict(torch.load(modelc))
|
|
params = sum(p.numel() for p in model.parameters())
|
|
print(f"{params/1e6:.1f}M parameters")
|
|
print(model)
|
|
|
|
torch.random.manual_seed(1)

for x in model.ensemble.models:
    x.output.bias.data.fill_(0)
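# zeroing the output biases above is safe: Bradley-Terry win probabilities depend only
# on score differences (sigmoid(s_A - s_B) is unchanged by s -> s + c), and it makes
# the ensemble match the fused export below, which carries no output bias at all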

out_layers = []
out_bias = []

with torch.inference_mode():
    # TODO: I don't think this actually works for more than 1 hidden layer
    # (layers past the first would need a block-diagonal
    # (n_ensemble*d_emb, n_ensemble*d_emb) fused weight, not this vertical stack)
    for layer in range(config.n_hidden):
        # stack each member's (d_emb, d_emb) hidden weight vertically so a single
        # matmul evaluates the whole ensemble's hidden layer at once
        big_layer = torch.zeros(config.n_ensemble * config.d_emb, config.d_emb)
        big_bias = torch.zeros(config.n_ensemble * config.d_emb)
        for i in range(config.n_ensemble):
            big_layer[i*config.d_emb:(i+1)*config.d_emb] = model.ensemble.models[i].hidden[layer].weight.data.clone()
            big_bias[i*config.d_emb:(i+1)*config.d_emb] = model.ensemble.models[i].hidden[layer].bias.data.clone()
        out_layers.append(big_layer)
        out_bias.append(big_bias)
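    # layout sketch (illustrative numbers, not the real config): with n_ensemble=2 and
    # d_emb=3 the fused weight is the (6, 3) vertical stack [W_0; W_1] and the fused
    # bias the concatenation [b_0; b_1], so silu(W @ x + b) yields both members'
    # hidden activations in one vector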

    # we do not need to preserve the bias on the downprojection as the win probability calculation is shift-invariant
    downprojection = torch.zeros(config.output_channels, config.n_ensemble * config.d_emb)
    for i in range(config.n_ensemble):
        downprojection[:, i*config.d_emb:(i+1)*config.d_emb] = model.ensemble.models[i].output.weight.data.clone()
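    # each member's (output_channels, d_emb) output weight sits side by side here, so
    # downprojection @ h sums every member's output; dividing by n_ensemble in the
    # check below (and in any consumer) turns that sum into the ensemble mean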

    # check the fused computation against the full ensemble on random inputs
    for i in range(10):
        input = torch.randn(4, config.d_emb)
        ground_truth_result = model.ensemble(input.unsqueeze(0).expand((config.n_ensemble, *input.shape))).mean(dim=0).T
        r_result = input
        for (layer, bias) in zip(out_layers, out_bias):
            r_result = torch.matmul(layer, r_result.T) + bias.unsqueeze(-1).expand(config.n_ensemble * config.d_emb, input.shape[0])
            print(r_result.shape, bias.shape)
            r_result = F.silu(r_result)
        r_result = torch.matmul(downprojection, r_result) / config.n_ensemble
        # use mean absolute error; a signed mean could hide large offsetting errors
        error = torch.mean(torch.abs(r_result - ground_truth_result))
        print(error)
        assert error.item() < 1e-4

    print("test vector:")
    #print(input.flatten().tolist())
    print("ground truth result:")
    print(ground_truth_result.shape)
    print(ground_truth_result.T.flatten().tolist())

# n_hidden=1, so only the first fused layer is exported (see TODO above)
save_file({
    "up_proj": out_layers[0],
    "bias": out_bias[0],
    "down_proj": downprojection
}, "model.safetensors")
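
# a minimal consumer sketch (assumptions: load_file is safetensors' loading
# counterpart to save_file; emb is a random stand-in for a batch of real embeddings):
from safetensors.torch import load_file
tensors = load_file("model.safetensors")
emb = torch.randn(4, config.d_emb)
h = F.silu(torch.matmul(tensors["up_proj"], emb.T) + tensors["bias"].unsqueeze(-1))
scores = torch.matmul(tensors["down_proj"], h) / config.n_ensemble # ensemble mean
print("consumer sketch scores shape:", scores.shape) # (output_channels, 4)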