1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2024-11-11 06:19:54 +00:00
meme-search-engine/meme-rater/model.py

52 lines
1.7 KiB
Python
Raw Normal View History

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from functools import partial
import math
@dataclass
class Config:
d_emb: int
n_hidden: int
n_ensemble: int
device: str
dtype: torch.dtype
dropout: float
class Model(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ])
self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ])
self.output = nn.Linear(config.d_emb, 1, dtype=config.dtype, device=config.device)
def forward(self, embs):
x = embs
for (layer, dropout) in zip(self.hidden, self.dropout):
x = F.silu(layer(dropout(x)))
return self.output(x)
class Ensemble(nn.Module):
def __init__(self, config):
super().__init__()
self.models = nn.ModuleList([ Model(config) for i in range(config.n_ensemble) ])
# model batch
def forward(self, embs):
xs = torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
return xs.squeeze(-1)
class BradleyTerry(nn.Module):
def __init__(self, config):
super().__init__()
self.ensemble = Ensemble(config)
def forward(self, embs): # model batch input=2 d_emb
scores1 = self.ensemble(embs[:, :, 0]).float() # model batch
scores2 = self.ensemble(embs[:, :, 1]).float()
# win probabilities
#print(scores1, scores2)
probs = torch.sigmoid(scores1 - scores2) # model batch
#print(probs)
return probs