import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

@dataclass
class SAEConfig:
    d_emb: int          # dimensionality of the input embeddings
    d_hidden: int       # number of hidden (dictionary) features
    top_k: int          # how many features are kept active per input
    up_proj_bias: bool  # whether the encoder Linear layer has a bias
    device: str
    dtype: torch.dtype
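
# Example instantiation with hypothetical values (not taken from this
# repository's training setup); d_hidden is typically much larger than d_emb:
#   config = SAEConfig(d_emb=1152, d_hidden=65536, top_k=128,
#                      up_proj_bias=False, device="cpu", dtype=torch.float32)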

class SAE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.up_proj = nn.Linear(config.d_emb, config.d_hidden, dtype=config.dtype, device=config.device, bias=config.up_proj_bias)
        self.down_proj = nn.Linear(config.d_hidden, config.d_emb, dtype=config.dtype, device=config.device)
        # Initialize the decoder as the transpose of the encoder (tied
        # initialization; the two matrices train independently afterwards).
        self.down_proj.weight = nn.Parameter(self.up_proj.weight.T.clone())
        # Per-feature count of how often each hidden unit fires, for
        # dead-feature tracking.
        self.feature_activation_counter = torch.zeros(config.d_hidden, dtype=torch.int32, device=config.device)
        self.reset_counters()

    def reset_counters(self):
        # Return the accumulated per-feature activation counts as a numpy
        # array and reset them to zero.
        old = self.feature_activation_counter.detach().cpu().numpy()
        self.feature_activation_counter.zero_()
        return old
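
    # Usage sketch (assumed, not from this file): callers can snapshot the
    # counts periodically during training to find features that never fire:
    #   dead = sae.reset_counters() == 0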

    def forward(self, embs):
        x = self.up_proj(embs)
        x = F.relu(x)
        # The (d_hidden - top_k)-th smallest activation in each row is the
        # threshold: everything strictly above it is in the top k.
        topk = torch.kthvalue(x, k=(self.config.d_hidden - self.config.top_k), dim=-1)
        thresholds = topk.values.unsqueeze(-1).expand_as(x)
        zero = torch.zeros_like(x)
        # If several values tie at the threshold, we don't actually pick exactly k values.
        # This can happen quite easily if a lot of pre-activations are negative and thus
        # get ReLUed to exactly 0. This should not really happen, but it does.
        # To work around it, we use greater-than rather than greater-than-or-equal-to,
        # and compensate by offsetting k by one in the kthvalue call above.
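        # Illustrative worked example (not from the original source): with
        # d_hidden=8, top_k=2 and a row [0, 0, 0, 0, 0, 1, 2, 3], kthvalue(k=6)
        # returns 1, and the strict comparison keeps {2, 3}: exactly top_k
        # values despite the ties at 0, whereas >= would also pass the 1.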
        mask = x > thresholds
        x = torch.where(mask, x, zero)
        # Record how many samples in this batch activated each feature.
        self.feature_activation_counter += mask.sum(0)
        return self.down_proj(x)
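
# Minimal end-to-end usage sketch under assumed shapes and hyperparameters
# (nothing below comes from this repository's training code):
if __name__ == "__main__":
    config = SAEConfig(d_emb=512, d_hidden=4096, top_k=32,
                       up_proj_bias=False, device="cpu", dtype=torch.float32)
    sae = SAE(config)
    embs = torch.randn(8, config.d_emb)  # a batch of 8 fake embeddings
    recon = sae(embs)                    # (8, d_emb) reconstruction
    loss = F.mse_loss(recon, embs)       # standard SAE reconstruction loss
    print(loss.item(), (sae.reset_counters() == 0).sum(), "dead features")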