Commit 1d0ff95955 (parent fc6d0c9409): Sparse autoencoder testing
sae/compare_plots.py (new file)

# claude-3
import json
import matplotlib.pyplot as plt
import sys

logs = sys.argv[1:]

def read_log(log):
    # Read data from log.jsonl
    data = []
    with open(log, 'r') as file:
        for line in file:
            data.append(json.loads(line))

    print(log, data[0]) # config

    # Extract steps and loss
    steps = [entry['step'] for entry in data if "loss" in entry]
    loss = [entry['loss'] for entry in data if "loss" in entry]

    # Calculate rolling average for loss
    window_size = 50
    rolling_avg = [sum(loss[i:i+window_size])/window_size for i in range(len(loss)-window_size+1)]
    rolling_steps = steps[window_size-1:]

    return rolling_steps, rolling_avg

# Create the plot
plt.figure(figsize=(10, 6))
#plt.plot(steps, loss, label='Loss')
for i, log in enumerate(logs):
    rolling_steps, rolling_avg = read_log(log)
    plt.plot(rolling_steps, rolling_avg, label=f"{i}")

plt.xlabel('Steps')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
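For reference, a minimal sketch of the log format this script expects (the file name here is hypothetical): the first line is the run config, followed by one JSON object per training step, as sae/train.py writes them. Pass one or more such files as arguments.

import json

# Hypothetical toy log matching the format train.py writes.
with open("log-demo.jsonl", "w") as f:
    f.write(json.dumps({"lr": 3e-4, "batch_size": 64}) + "\n")  # config line
    for step in range(200):
        f.write(json.dumps({"loss": 1.0 / (step + 1), "step": step, "time": 0.0}) + "\n")

# then: python compare_plots.py log-demo.jsonl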
sae/export_features.py (new file)

import torch.nn
import torch.nn.functional as F
import torch
import numpy
import json
import time
from tqdm import tqdm
from dataclasses import dataclass, asdict
import numpy as np
import base64
import asyncio
import aiohttp
import aioitertools

from model import SAEConfig, SAE
from shared import train_split, loaded_arrays, ckpt_path

model_path, _ = ckpt_path(12991)

model = SAE(SAEConfig(
    d_emb=1152,
    d_hidden=65536,
    top_k=128,
    device="cuda",
    dtype=torch.float32,
    up_proj_bias=False
))

state_dict = torch.load(model_path)

batch_size = 1024
retrieve_batch_size = 512

with torch.inference_mode():
    # checkpoints from train.py store the weights under the "model" key
    model.load_state_dict(state_dict["model"])
    model.eval()

    validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):]

    for batch_start in tqdm(range(0, len(validation_set), batch_size)):
        batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + batch_size] ])
        batch = torch.Tensor(batch).to("cuda")
        # the forward pass is run only to populate the feature activation counters
        reconstructions = model(batch).float()

    feature_frequencies = model.reset_counters()
    # each row of up_proj.weight is one feature direction in embedding space
    features = model.up_proj.weight.cpu().numpy()

meme_search_backend = "http://localhost:1707/"
memes_url = "https://i.osmarks.net/memes-or-something/"
meme_search_url = "https://mse.osmarks.net/?e="

def emb_url(embedding):
    return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8")

async def get_exemplars():
    async with aiohttp.ClientSession():
        for base in tqdm(range(0, len(features), retrieve_batch_size)):
            chunk = features[base:base + retrieve_batch_size]
            with open(f"feature_dumps/features{base}.html", "w") as f:
                f.write("""<!DOCTYPE html>
<title>Embeddings SAE Features</title>
<style>
div img {
width: 20%
}
</style>
<body><h1>Embeddings SAE Features</h1>""")

                async def lookup(embedding):
                    async with aiohttp.request("POST", meme_search_backend, json={
                        "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
                        "k": 10
                    }) as res:
                        return (await res.json())["matches"]

                exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ])
                negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ])

                # render features in descending order of activation frequency
                for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]):
                    f.write(f"""
<h2>Feature {offset + base}</h2>
<h3>Frequency {frequency / len(validation_set)}</h3>
<div>
<h4><a href="{emb_url(feature)}">Max</a></h4>
""")
                    for match in exemplars[offset]:
                        f.write(f'<img loading="lazy" src="{memes_url+match[1]}">')
                    f.write(f'<h4><a href="{emb_url(-feature)}">Min</a></h4>')
                    for match in negative_exemplars[offset]:
                        f.write(f'<img loading="lazy" src="{memes_url+match[1]}">')
                    f.write("</div>")

asyncio.run(get_exemplars())
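As a quick illustration of the ?e= payload format: emb_url serializes the embedding as raw float16 bytes, then URL-safe base64. A hypothetical inverse helper (decode_emb_url is not part of the repo):

import base64
import numpy as np

def decode_emb_url(url, prefix="https://mse.osmarks.net/?e="):
    # Hypothetical inverse of emb_url: recover the float16 vector from the URL.
    payload = url.removeprefix(prefix)
    return np.frombuffer(base64.urlsafe_b64decode(payload), dtype=np.float16)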
sae/model.py (new file)

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from functools import partial

@dataclass
class SAEConfig:
    d_emb: int
    d_hidden: int
    top_k: int
    up_proj_bias: bool
    device: str
    dtype: torch.dtype

class SAE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.up_proj = nn.Linear(config.d_emb, config.d_hidden, dtype=config.dtype, device=config.device, bias=config.up_proj_bias)
        self.down_proj = nn.Linear(config.d_hidden, config.d_emb, dtype=config.dtype, device=config.device)
        # initialize the decoder as the transpose of the encoder
        self.down_proj.weight = nn.Parameter(self.up_proj.weight.T.clone())
        self.feature_activation_counter = torch.zeros(config.d_hidden, dtype=torch.int32, device=config.device)
        self.reset_counters()

    def reset_counters(self):
        old = self.feature_activation_counter.detach().cpu().numpy()
        torch.zero_(self.feature_activation_counter)
        return old

    def forward(self, embs):
        x = self.up_proj(embs)
        x = F.relu(x)
        topk = torch.kthvalue(x, k=(self.config.d_hidden - self.config.top_k), dim=-1)
        thresholds = topk.values.unsqueeze(-1).expand_as(x)
        zero = torch.zeros_like(x)
        # If multiple values are the same, we don't actually pick exactly k values. This can
        # happen quite easily if a lot of values are negative and thus get ReLUed to 0.
        # This should not really happen, but it does.
        # The mask uses greater-than rather than greater-than-or-equal to work around this;
        # we compensate by setting k off by one in the kthvalue call.
        mask = x > thresholds
        x = torch.where(mask, x, zero)
        self.feature_activation_counter += mask.sum(0)
        return self.down_proj(x)
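A toy check of the masking trick (a standalone sketch, not part of the repo): with d_hidden=6 and top_k=2, the kthvalue threshold is the 4th-smallest activation, and the strict greater-than means the values tied at zero can never pass.

import torch

x = torch.tensor([[0.0, 0.0, 3.0, 1.0, 0.0, 2.0]])  # post-ReLU activations, ties at 0
threshold = torch.kthvalue(x, k=6 - 2, dim=-1).values.unsqueeze(-1)  # 4th smallest = 1.0
mask = x > threshold  # strict >, so exactly the top 2 survive here
print(torch.where(mask, x, torch.zeros_like(x)))  # tensor([[0., 0., 3., 0., 0., 2.]])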
sae/shared.py (new file)

import torch
import pyarrow as pa

torch.set_float32_matmul_precision("high")

with pa.memory_map("../../sample_1m.arrow", "r") as source:
    loaded_arrays = pa.ipc.open_file(source).read_all()

train_split = 0.8

def ckpt_path(steps):
    # (model checkpoint path, optimizer state path) for a given step count
    return f"ckpt/{steps}.pt", f"ckpt/{steps}.optim.pt"
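How the other scripts consume loaded_arrays (a sketch, assuming sample_1m.arrow is present at the path above): each entry of the "embedding" column is a binary blob holding one float16 vector, d_emb=1152 in the configs in this commit.

import numpy
from shared import loaded_arrays

first = numpy.frombuffer(loaded_arrays["embedding"][0].as_py(), dtype=numpy.float16)
print(first.shape)  # expected: (1152,)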
sae/shuffle.py (new file)

# claude output

import pyarrow as pa
import numpy as np
import pyarrow.compute as pc

def shuffle_arrow_file(input_file, output_file, chunk_size=100000):
    # Open the input file
    with pa.memory_map(input_file, 'r') as source:
        reader = pa.ipc.open_file(source)

        # Get the schema and create a writer for the output file
        schema = reader.schema
        with pa.OSFile(output_file, 'wb') as sink:
            writer = pa.ipc.new_file(sink, schema)

            # Calculate total number of rows
            num_batches = reader.num_record_batches
            all_batches = [reader.get_batch(i) for i in range(num_batches)]
            total_rows = sum(batch.num_rows for batch in all_batches)

            # Generate shuffled indices
            indices = np.random.permutation(total_rows)

            # Process in chunks
            for i in range(0, total_rows, chunk_size):
                # Get indices for this chunk
                chunk_indices = indices[i:i+chunk_size]

                # Take rows using these indices
                chunk_data = []
                for idx in chunk_indices:
                    # Walk the batch list to find which batch this global index falls in
                    batch_idx = 0
                    row_idx = idx
                    while row_idx >= all_batches[batch_idx].num_rows:
                        row_idx -= all_batches[batch_idx].num_rows
                        batch_idx += 1
                    chunk_data.append(all_batches[batch_idx].slice(row_idx, 1))

                chunk = pa.Table.from_batches(chunk_data)

                # Write the chunk
                writer.write_table(chunk)

            # Close the writer
            writer.close()

# Usage
input_file = "../../sample_1m.arrow"
output_file = "shuffled_trainset.arrow"
shuffle_arrow_file(input_file, output_file)
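The row-by-row slicing above is slow for large files; a simpler alternative sketch (assuming the whole table fits in memory, and that a single global permutation is acceptable) uses pyarrow's Table.take:

import pyarrow as pa
import numpy as np

def shuffle_arrow_file_in_memory(input_file, output_file):
    # Load everything, permute the row indices once, write one shuffled table.
    with pa.memory_map(input_file, "r") as source:
        table = pa.ipc.open_file(source).read_all()
    shuffled = table.take(np.random.permutation(len(table)))
    with pa.OSFile(output_file, "wb") as sink:
        with pa.ipc.new_file(sink, table.schema) as writer:
            writer.write_table(shuffled)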
sae/train.py (new file)

import torch.nn
import torch.nn.functional as F
import torch
import numpy
import json
import math
import time
from tqdm import tqdm
from dataclasses import dataclass, asdict

from model import SAEConfig, SAE
from shared import train_split, loaded_arrays, ckpt_path

device = "cuda"

@dataclass
class TrainConfig:
    model: SAEConfig
    lr: float
    weight_decay: float
    batch_size: int
    epochs: int
    compile: bool

config = TrainConfig(
    model=SAEConfig(
        d_emb=1152,
        d_hidden=65536,
        top_k=128,
        device=device,
        dtype=torch.float32,
        up_proj_bias=False
    ),
    lr=3e-4,
    weight_decay=0.0,
    batch_size=64,
    epochs=5,
    compile=True,
)

def exprange(min, max, n):
    # n log-spaced values from min to max (currently unused)
    lmin, lmax = math.log(min), math.log(max)
    step = (lmax - lmin) / (n - 1)
    return (math.exp(lmin + step * i) for i in range(n))

model = SAE(config.model)
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

def train_step(model, batch):
    optimizer.zero_grad()
    reconstructions = model(batch).float()
    loss = F.mse_loss(reconstructions, batch)
    loss.backward()
    optimizer.step()
    return loss

if config.compile:
    print("compiling...")
    train_step = torch.compile(train_step)

def save_ckpt(log, steps):
    #print("saving...")
    modelc, optimc = ckpt_path(steps)
    torch.save(optimizer.state_dict(), optimc)
    torch.save({"model": model.state_dict(), "config": config}, modelc)

class JSONEncoder(json.JSONEncoder):
    # torch.dtype values in the config are not JSON-serializable by default
    def default(self, o):
        if isinstance(o, torch.dtype):
            return str(o)
        else: return super().default(o)

logfile = f"logs/log-{time.time()}.jsonl"
with open(logfile, "w") as log:
    steps = 0
    log.write(JSONEncoder().encode(asdict(config)) + "\n")
    for epoch in range(config.epochs):
        batch = []
        t = tqdm(range(0, int(len(loaded_arrays) * train_split), config.batch_size))
        for batch_start in t:
            batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in loaded_arrays["embedding"][batch_start:batch_start + config.batch_size] ])

            if len(batch) == config.batch_size:
                batch = torch.Tensor(batch).to(device)
                loss = train_step(model, batch)
                loss = loss.detach().cpu().item()
                t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}")
                log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
                if steps % 5000 == 0: save_ckpt(log, steps)
            steps += 1

    save_ckpt(log, steps)
    print(model.feature_activation_counter.cpu())
    ctr = model.reset_counters()
    print(ctr)
    numpy.save(f"ckpt/{steps}.counters.npy", ctr)

print(logfile)
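As a sanity check on the parameter count the script prints: with d_emb=1152, d_hidden=65536, and the up_proj bias disabled, the two projection matrices dominate.

d_emb, d_hidden = 1152, 65536
params = d_emb * d_hidden                # up_proj weight (bias disabled)
params += d_hidden * d_emb + d_emb       # down_proj weight + bias
print(f"{params/1e6:.1f}M parameters")   # 151.0M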