mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-01-02 21:40:31 +00:00

Sparse autoencoder testing

osmarks 2024-10-05 17:22:44 +01:00
parent fc6d0c9409
commit 1d0ff95955
6 changed files with 340 additions and 0 deletions

sae/compare_plots.py (Normal file, 40 lines added)

@@ -0,0 +1,40 @@
# claude-3
import json
import matplotlib.pyplot as plt
import sys

logs = sys.argv[1:]

def read_log(log):
    # Read data from log.jsonl
    data = []
    with open(log, 'r') as file:
        for line in file:
            data.append(json.loads(line))

    print(log, data[0]) # config

    # Extract steps and loss
    steps = [entry['step'] for entry in data if "loss" in entry]
    loss = [entry['loss'] for entry in data if "loss" in entry]

    # Calculate rolling average for loss
    window_size = 50
    rolling_avg = [sum(loss[i:i+window_size])/window_size for i in range(len(loss)-window_size+1)]
    rolling_steps = steps[window_size-1:]

    return rolling_steps, rolling_avg

# Create the plot
plt.figure(figsize=(10, 6))
#plt.plot(steps, loss, label='Loss')
for i, log in enumerate(logs):
    rolling_steps, rolling_avg = read_log(log)
    plt.plot(rolling_steps, rolling_avg, label=f"{i}")

plt.xlabel('Steps')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
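
The logs this script plots are the JSONL files written by sae/train.py below: a config object on the first line, then one {"loss", "step", "time"} record per training step. As a minimal sketch (not part of the commit), a synthetic log in the same format can be generated for testing the plot; the decay curve and the log-demo.jsonl filename are invented for illustration:

# Hypothetical smoke test for compare_plots.py: writes a fake log in the same JSONL format.
# The loss values are synthetic, not real training output.
import json, math, random, time

with open("log-demo.jsonl", "w") as f:
    f.write(json.dumps({"lr": 3e-4, "batch_size": 64}) + "\n")  # config line; printed by read_log, not plotted
    for step in range(2000):
        fake_loss = math.exp(-step / 500) + random.uniform(0.0, 0.05)  # invented decaying loss
        f.write(json.dumps({"loss": fake_loss, "step": step, "time": time.time()}) + "\n")

# then: python compare_plots.py log-demo.jsonl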

sae/export_features.py (Normal file, 93 lines added)

@@ -0,0 +1,93 @@
import torch.nn
import torch.nn.functional as F
import torch
import numpy
import json
import time
from tqdm import tqdm
from dataclasses import dataclass, asdict
import numpy as np
import base64
import asyncio
import aiohttp
import aioitertools

from model import SAEConfig, SAE
from shared import train_split, loaded_arrays, ckpt_path

model_path, _ = ckpt_path(12991)

model = SAE(SAEConfig(
    d_emb=1152,
    d_hidden=65536,
    top_k=128,
    device="cuda",
    dtype=torch.float32,
    up_proj_bias=False
))

state_dict = torch.load(model_path)

batch_size = 1024
retrieve_batch_size = 512

with torch.inference_mode():
    model.load_state_dict(state_dict)
    model.eval()

    # Run the validation split through the SAE once so the per-feature activation counters are populated.
    validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):]
    for batch_start in tqdm(range(0, len(validation_set), batch_size)):
        batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + 1024] ])
        batch = torch.Tensor(batch).to("cuda")
        reconstructions = model(batch).float()

    feature_frequencies = model.reset_counters()
    features = model.up_proj.weight.cpu().numpy()

meme_search_backend = "http://localhost:1707/"
memes_url = "https://i.osmarks.net/memes-or-something/"
meme_search_url = "https://mse.osmarks.net/?e="

def emb_url(embedding):
    return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8")

# For each feature direction, query the meme search backend for its nearest matches (and, with the sign
# flipped, its most negative matches) and dump them as HTML pages.
async def get_exemplars():
    async with aiohttp.ClientSession():
        for base in tqdm(range(0, len(features), retrieve_batch_size)):
            chunk = features[base:base + retrieve_batch_size]
            with open(f"feature_dumps/features{base}.html", "w") as f:
                f.write("""<!DOCTYPE html>
<title>Embeddings SAE Features</title>
<style>
div img {
    width: 20%
}
</style>
<body><h1>Embeddings SAE Features</h1>""")

                async def lookup(embedding):
                    async with aiohttp.request("POST", meme_search_backend, json={
                        "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
                        "k": 10
                    }) as res:
                        return (await res.json())["matches"]

                exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ])
                negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ])

                # Sort features in this chunk by activation frequency, most frequent first.
                for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]):
                    f.write(f"""
<h2>Feature {offset + base}</h2>
<h3>Frequency {frequency / len(validation_set)}</h3>
<div>
<h4><a href="{emb_url(feature)}">Max</a></h4>
""")
                    for match in exemplars[offset]:
                        f.write(f'<img loading="lazy" src="{memes_url+match[1]}">')
                    f.write(f'<h4><a href="{emb_url(-feature)}">Min</a></h4>')
                    for match in negative_exemplars[offset]:
                        f.write(f'<img loading="lazy" src="{memes_url+match[1]}">')
                    f.write("</div>")

asyncio.run(get_exemplars())

sae/model.py (Normal file, 43 lines added)

@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from functools import partial

@dataclass
class SAEConfig:
    d_emb: int
    d_hidden: int
    top_k: int
    up_proj_bias: bool
    device: str
    dtype: torch.dtype

class SAE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.up_proj = nn.Linear(config.d_emb, config.d_hidden, dtype=config.dtype, device=config.device, bias=config.up_proj_bias)
        self.down_proj = nn.Linear(config.d_hidden, config.d_emb, dtype=config.dtype, device=config.device)
        # Initialize the decoder as the transpose of the encoder.
        self.down_proj.weight = nn.Parameter(self.up_proj.weight.T.clone())
        self.feature_activation_counter = torch.zeros(config.d_hidden, dtype=torch.int32, device=config.device)
        self.reset_counters()

    def reset_counters(self):
        # Return the current per-feature activation counts and zero them.
        old = self.feature_activation_counter.detach().cpu().numpy()
        torch.zero_(self.feature_activation_counter)
        return old

    def forward(self, embs):
        x = self.up_proj(embs)
        x = F.relu(x)
        topk = torch.kthvalue(x, k=(self.config.d_hidden - self.config.top_k), dim=-1)
        thresholds = topk.values.unsqueeze(-1).expand_as(x)
        zero = torch.zeros_like(x)
        # If several values are equal, we do not actually pick exactly k of them. This happens quite
        # easily when many pre-activations are negative and get ReLUed to exactly 0. It should not
        # really happen, but it does. Using strict greater-than (rather than >=) works around this,
        # and the kthvalue call above sets k off by one to compensate.
        mask = x > thresholds
        x = torch.where(mask, x, zero)
        self.feature_activation_counter += mask.sum(0)
        return self.down_proj(x)
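
The kthvalue comment in forward is easier to see on a toy tensor. A minimal sketch (invented numbers, independent of the real model) of the same thresholding: strict greater-than keeps at most top_k values per row, and fewer when several pre-activations tie at zero after the ReLU.

# Standalone illustration of the top-k masking used in SAE.forward; values are invented.
import torch
import torch.nn.functional as F

d_hidden, top_k = 8, 3
x = F.relu(torch.tensor([
    [0.9, -0.2, 0.5, -0.7, 0.1, 0.3, -0.1, 0.8],     # enough positives: exactly top_k survive
    [0.9, -0.2, -0.5, -0.7, -0.1, -0.3, -0.4, 0.8],  # mostly negative: zeros tie at the threshold, fewer survive
]))
topk = torch.kthvalue(x, k=d_hidden - top_k, dim=-1)
thresholds = topk.values.unsqueeze(-1).expand_as(x)
mask = x > thresholds                             # strict >, as in SAE.forward
print(mask.sum(-1))                               # tensor([3, 2]): at most top_k active per row
print(torch.where(mask, x, torch.zeros_like(x)))  # masked activations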

sae/shared.py (Normal file, 12 lines added)

@@ -0,0 +1,12 @@
import torch
import pyarrow as pa

torch.set_float32_matmul_precision("high")

with pa.memory_map("../../sample_1m.arrow", "r") as source:
    loaded_arrays = pa.ipc.open_file(source).read_all()

train_split = 0.8

def ckpt_path(steps):
    return f"ckpt/{steps}.pt", f"ckpt/{steps}.optim.pt"

sae/shuffle.py (Normal file, 51 lines added)

@@ -0,0 +1,51 @@
# claude output
import pyarrow as pa
import numpy as np
import pyarrow.compute as pc

def shuffle_arrow_file(input_file, output_file, chunk_size=100000):
    # Open the input file
    with pa.memory_map(input_file, 'r') as source:
        reader = pa.ipc.open_file(source)

        # Get the schema and create a writer for the output file
        schema = reader.schema
        with pa.OSFile(output_file, 'wb') as sink:
            writer = pa.ipc.new_file(sink, schema)

            # Calculate total number of rows
            total_rows = reader.num_record_batches
            all_batches = [reader.get_batch(i) for i in range(total_rows)]
            total_rows = sum(batch.num_rows for batch in all_batches)

            # Generate shuffled indices
            indices = np.random.permutation(total_rows)

            # Process in chunks
            for i in range(0, total_rows, chunk_size):
                # Get indices for this chunk
                chunk_indices = indices[i:i+chunk_size]

                # Take rows using these indices
                chunk_data = []
                for idx in chunk_indices:
                    batch_idx = 0
                    row_idx = idx
                    while row_idx >= all_batches[batch_idx].num_rows:
                        row_idx -= all_batches[batch_idx].num_rows
                        batch_idx += 1
                    chunk_data.append(all_batches[batch_idx].slice(row_idx, 1))

                chunk = pa.Table.from_batches(chunk_data)

                # Write the chunk
                writer.write_table(chunk)

            # Close the writer
            writer.close()

# Usage
input_file = "../../sample_1m.arrow"
output_file = "shuffled_trainset.arrow"
shuffle_arrow_file(input_file, output_file)

sae/train.py (Normal file, 101 lines added)

@@ -0,0 +1,101 @@
import torch.nn
import torch.nn.functional as F
import torch
import numpy
import json
import math
import time
from tqdm import tqdm
from dataclasses import dataclass, asdict

from model import SAEConfig, SAE
from shared import train_split, loaded_arrays, ckpt_path

device = "cuda"

@dataclass
class TrainConfig:
    model: SAEConfig
    lr: float
    weight_decay: float
    batch_size: int
    epochs: int
    compile: bool

config = TrainConfig(
    model=SAEConfig(
        d_emb=1152,
        d_hidden=65536,
        top_k=128,
        device=device,
        dtype=torch.float32,
        up_proj_bias=False
    ),
    lr=3e-4,
    weight_decay=0.0,
    batch_size=64,
    epochs=5,
    compile=True,
)

def exprange(min, max, n):
    lmin, lmax = math.log(min), math.log(max)
    step = (lmax - lmin) / (n - 1)
    return (math.exp(lmin + step * i) for i in range(n))

model = SAE(config.model)
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

# One optimizer step on a batch of embeddings; the loss is plain MSE between input and reconstruction.
def train_step(model, batch):
    optimizer.zero_grad()
    reconstructions = model(batch).float()
    loss = F.mse_loss(reconstructions, batch)
    loss.backward()
    optimizer.step()
    return loss

if config.compile:
    print("compiling...")
    train_step = torch.compile(train_step)

def save_ckpt(log, steps):
    #print("saving...")
    modelc, optimc = ckpt_path(steps)
    torch.save(optimizer.state_dict(), optimc)
    torch.save({"model": model.state_dict(), "config": config}, modelc)

# torch.dtype values in the config are not JSON-serializable by default.
class JSONEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, torch.dtype):
            return str(o)
        else: return super().default(o)

logfile = f"logs/log-{time.time()}.jsonl"
with open(logfile, "w") as log:
    steps = 0
    log.write(JSONEncoder().encode(asdict(config)) + "\n")
    for epoch in range(config.epochs):
        batch = []
        t = tqdm(range(0, int(len(loaded_arrays) * train_split), config.batch_size))
        for batch_start in t:
            batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in loaded_arrays["embedding"][batch_start:batch_start + config.batch_size] ])
            if len(batch) == config.batch_size:
                batch = torch.Tensor(batch).to(device)
                loss = train_step(model, batch)
                loss = loss.detach().cpu().item()
                t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}")
                log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
            if steps % 5000 == 0: save_ckpt(log, steps)
            steps += 1

    save_ckpt(log, steps)
    print(model.feature_activation_counter.cpu())
    ctr = model.reset_counters()
    print(ctr)
    numpy.save(f"ckpt/{steps}.counters.npy", ctr)

print(logfile)