mirror of https://github.com/osmarks/meme-search-engine.git
Sparse autoencoder testing
commit 1d0ff95955
parent fc6d0c9409
40 sae/compare_plots.py Normal file
@@ -0,0 +1,40 @@
# claude-3

import json
import matplotlib.pyplot as plt
import sys

logs = sys.argv[1:]

def read_log(log):
    # Read data from log.jsonl
    data = []
    with open(log, 'r') as file:
        for line in file:
            data.append(json.loads(line))

    print(log, data[0]) # config

    # Extract steps and loss (entries without a "loss" key are config records)
    steps = [entry['step'] for entry in data if "loss" in entry]
    loss = [entry['loss'] for entry in data if "loss" in entry]

    # Calculate rolling average for loss
    window_size = 50
    rolling_avg = [sum(loss[i:i+window_size])/window_size for i in range(len(loss)-window_size+1)]
    rolling_steps = steps[window_size-1:]

    return rolling_steps, rolling_avg

# Create the plot
plt.figure(figsize=(10, 6))
#plt.plot(steps, loss, label='Loss')
for i, log in enumerate(logs):
    rolling_steps, rolling_avg = read_log(log)
    plt.plot(rolling_steps, rolling_avg, label=f"{i}")

plt.xlabel('Steps')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
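Invocation is one JSONL training log per argument, and each curve is labelled by its argument index. A typical run against logs produced by train.py (filenames illustrative):

python compare_plots.py logs/log-1719000000.0.jsonl logs/log-1719012345.6.jsonl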
93 sae/export_features.py Normal file
@@ -0,0 +1,93 @@
import torch.nn
import torch.nn.functional as F
import torch
import numpy
import json
import time
from tqdm import tqdm
from dataclasses import dataclass, asdict
import numpy as np
import base64
import asyncio
import aiohttp
import aioitertools

from model import SAEConfig, SAE
from shared import train_split, loaded_arrays, ckpt_path

model_path, _ = ckpt_path(12991)

model = SAE(SAEConfig(
    d_emb=1152,
    d_hidden=65536,
    top_k=128,
    device="cuda",
    dtype=torch.float32,
    up_proj_bias=False
))

state_dict = torch.load(model_path)

batch_size = 1024
retrieve_batch_size = 512

with torch.inference_mode():
    # train.py saves checkpoints as {"model": ..., "config": ...}, so unwrap the weights
    model.load_state_dict(state_dict["model"])
    model.eval()

    validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):]

    # run the validation set through the SAE purely to populate the feature activation counters
    for batch_start in tqdm(range(0, len(validation_set), batch_size)):
        batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + batch_size] ])
        batch = torch.Tensor(batch).to("cuda")
        reconstructions = model(batch).float()

    feature_frequencies = model.reset_counters()
    features = model.up_proj.weight.cpu().numpy()

meme_search_backend = "http://localhost:1707/"
memes_url = "https://i.osmarks.net/memes-or-something/"
meme_search_url = "https://mse.osmarks.net/?e="

def emb_url(embedding):
    return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8")

async def get_exemplars():
    async with aiohttp.ClientSession():
        for base in tqdm(range(0, len(features), retrieve_batch_size)):
            chunk = features[base:base + retrieve_batch_size]
            with open(f"feature_dumps/features{base}.html", "w") as f:
                f.write("""<!DOCTYPE html>
<title>Embeddings SAE Features</title>
<style>
div img {
    width: 20%
}
</style>
<body><h1>Embeddings SAE Features</h1>""")

                async def lookup(embedding):
                    async with aiohttp.request("POST", meme_search_backend, json={
                        "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
                        "k": 10
                    }) as res:
                        return (await res.json())["matches"]

                exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ])
                negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ])

                # render features in descending order of activation frequency
                for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]):
                    f.write(f"""
<h2>Feature {offset + base}</h2>
<h3>Frequency {frequency / len(validation_set)}</h3>
<div>
<h4><a href="{emb_url(feature)}">Max</a></h4>
""")
                    for match in exemplars[offset]:
                        f.write(f'<img loading="lazy" src="{memes_url+match[1]}">')
                    f.write(f'<h4><a href="{emb_url(-feature)}">Min</a></h4>')
                    for match in negative_exemplars[offset]:
                        f.write(f'<img loading="lazy" src="{memes_url+match[1]}">')
                    f.write("</div>")

asyncio.run(get_exemplars())
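The Max/Min links encode the feature direction itself as a search query: emb_url packs the float16 bytes into a urlsafe base64 string. Decoding is symmetric; a minimal sketch (url_to_emb is not part of the commit):

import base64
import numpy as np

def url_to_emb(url, prefix="https://mse.osmarks.net/?e="):
    # Inverse of emb_url: strip the prefix, base64-decode, reinterpret as float16
    return np.frombuffer(base64.urlsafe_b64decode(url[len(prefix):]), dtype=np.float16)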
43 sae/model.py Normal file
@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from functools import partial

@dataclass
class SAEConfig:
    d_emb: int
    d_hidden: int
    top_k: int
    up_proj_bias: bool
    device: str
    dtype: torch.dtype

class SAE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.up_proj = nn.Linear(config.d_emb, config.d_hidden, dtype=config.dtype, device=config.device, bias=config.up_proj_bias)
        self.down_proj = nn.Linear(config.d_hidden, config.d_emb, dtype=config.dtype, device=config.device)
        # initialize the decoder as the transpose of the encoder
        self.down_proj.weight = nn.Parameter(self.up_proj.weight.T.clone())
        self.feature_activation_counter = torch.zeros(config.d_hidden, dtype=torch.int32, device=config.device)
        self.reset_counters()

    def reset_counters(self):
        old = self.feature_activation_counter.detach().cpu().numpy()
        torch.zero_(self.feature_activation_counter)
        return old

    def forward(self, embs):
        x = self.up_proj(embs)
        x = F.relu(x)
        topk = torch.kthvalue(x, k=(self.config.d_hidden - self.config.top_k), dim=-1)
        thresholds = topk.values.unsqueeze(-1).expand_as(x)
        zero = torch.zeros_like(x)
        # If multiple values are equal, we don't actually pick exactly k values. This can happen
        # quite easily when many pre-activations are negative and get ReLUed to exactly zero.
        # It should not really happen, but it does. Using strict greater-than (rather than
        # greater-or-equal) works around this; the kthvalue call above compensates by setting k
        # off by one.
        mask = x > thresholds
        x = torch.where(mask, x, zero)
        self.feature_activation_counter += mask.sum(0, dtype=torch.int32)  # sum as int32 to match the counter dtype
        return self.down_proj(x)
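A quick sanity check of the thresholding trick (a standalone sketch, not part of the commit): kthvalue with k = d_hidden - top_k returns the largest value outside the top top_k, and the strict > then keeps at most top_k entries even when ReLU has produced many tied zeros.

import torch

x = torch.tensor([[0.0, 0.0, 0.2, 0.5, 0.0, 0.9]])  # pretend d_hidden=6, top_k=2
threshold = torch.kthvalue(x, k=x.shape[-1] - 2, dim=-1).values.unsqueeze(-1)
mask = x > threshold       # strict >, so ties at the threshold are dropped
print(mask.sum().item())   # 2: only the 0.5 and 0.9 activations survive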
12 sae/shared.py Normal file
@@ -0,0 +1,12 @@
import torch
import pyarrow as pa

torch.set_float32_matmul_precision("high")

with pa.memory_map("../../sample_1m.arrow", "r") as source:
    loaded_arrays = pa.ipc.open_file(source).read_all()

train_split = 0.8

def ckpt_path(steps):
    return f"ckpt/{steps}.pt", f"ckpt/{steps}.optim.pt"
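The embedding column stores raw float16 bytes; train.py and export_features.py both decode it with numpy.frombuffer. A minimal sketch of pulling one row out (the 1152 dimension is an assumption taken from the SAEConfig used elsewhere in this commit):

import numpy
from shared import loaded_arrays

emb = numpy.frombuffer(loaded_arrays["embedding"][0].as_py(), dtype=numpy.float16)
print(emb.shape)  # expected (1152,) given d_emb=1152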
51 sae/shuffle.py Normal file
@@ -0,0 +1,51 @@
# claude output

import pyarrow as pa
import numpy as np
import pyarrow.compute as pc

def shuffle_arrow_file(input_file, output_file, chunk_size=100000):
    # Open the input file
    with pa.memory_map(input_file, 'r') as source:
        reader = pa.ipc.open_file(source)

        # Get the schema and create a writer for the output file
        schema = reader.schema
        with pa.OSFile(output_file, 'wb') as sink:
            writer = pa.ipc.new_file(sink, schema)

            # Load all record batches and count the total number of rows
            all_batches = [reader.get_batch(i) for i in range(reader.num_record_batches)]
            total_rows = sum(batch.num_rows for batch in all_batches)

            # Generate shuffled indices
            indices = np.random.permutation(total_rows)

            # Process in chunks
            for i in range(0, total_rows, chunk_size):
                # Get indices for this chunk
                chunk_indices = indices[i:i+chunk_size]

                # Take rows using these indices
                chunk_data = []
                for idx in chunk_indices:
                    # Walk the batches to find the one containing this global row index
                    batch_idx = 0
                    row_idx = idx
                    while row_idx >= all_batches[batch_idx].num_rows:
                        row_idx -= all_batches[batch_idx].num_rows
                        batch_idx += 1
                    chunk_data.append(all_batches[batch_idx].slice(row_idx, 1))

                chunk = pa.Table.from_batches(chunk_data)

                # Write the chunk
                writer.write_table(chunk)

            # Close the writer
            writer.close()

# Usage
input_file = "../../sample_1m.arrow"
output_file = "shuffled_trainset.arrow"
shuffle_arrow_file(input_file, output_file)
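Since all_batches is materialized in memory anyway, the same shuffle can be expressed more directly with Table.take; a sketch under that assumption:

import numpy as np
import pyarrow as pa

with pa.memory_map("../../sample_1m.arrow", "r") as source:
    table = pa.ipc.open_file(source).read_all()

shuffled = table.take(np.random.permutation(len(table)))  # one global permutation

with pa.OSFile("shuffled_trainset.arrow", "wb") as sink:
    with pa.ipc.new_file(sink, shuffled.schema) as writer:
        writer.write_table(shuffled)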
101 sae/train.py Normal file
@@ -0,0 +1,101 @@
import torch.nn
import torch.nn.functional as F
import torch
import numpy
import json
import math
import time
from tqdm import tqdm
from dataclasses import dataclass, asdict

from model import SAEConfig, SAE
from shared import train_split, loaded_arrays, ckpt_path

device = "cuda"

@dataclass
class TrainConfig:
    model: SAEConfig
    lr: float
    weight_decay: float
    batch_size: int
    epochs: int
    compile: bool

config = TrainConfig(
    model=SAEConfig(
        d_emb=1152,
        d_hidden=65536,
        top_k=128,
        device=device,
        dtype=torch.float32,
        up_proj_bias=False
    ),
    lr=3e-4,
    weight_decay=0.0,
    batch_size=64,
    epochs=5,
    compile=True,
)

def exprange(min, max, n):
    # n log-spaced values from min to max inclusive
    lmin, lmax = math.log(min), math.log(max)
    step = (lmax - lmin) / (n - 1)
    return (math.exp(lmin + step * i) for i in range(n))

model = SAE(config.model)
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

def train_step(model, batch):
    optimizer.zero_grad()
    reconstructions = model(batch).float()
    loss = F.mse_loss(reconstructions, batch)
    loss.backward()
    optimizer.step()
    return loss

if config.compile:
    print("compiling...")
    train_step = torch.compile(train_step)

def save_ckpt(log, steps):
    #print("saving...")
    modelc, optimc = ckpt_path(steps)
    torch.save(optimizer.state_dict(), optimc)
    torch.save({"model": model.state_dict(), "config": config}, modelc)

class JSONEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, torch.dtype):
            return str(o)
        else: return super().default(o)

logfile = f"logs/log-{time.time()}.jsonl"
with open(logfile, "w") as log:
    steps = 0
    log.write(JSONEncoder().encode(asdict(config)) + "\n")
    for epoch in range(config.epochs):
        batch = []
        t = tqdm(range(0, int(len(loaded_arrays) * train_split), config.batch_size))
        for batch_start in t:
            batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in loaded_arrays["embedding"][batch_start:batch_start + config.batch_size] ])

            if len(batch) == config.batch_size:
                batch = torch.Tensor(batch).to(device)
                loss = train_step(model, batch)
                loss = loss.detach().cpu().item()
                t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}")
                log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
                if steps % 5000 == 0: save_ckpt(log, steps)
                steps += 1

    save_ckpt(log, steps)
    print(model.feature_activation_counter.cpu())
    ctr = model.reset_counters()
    print(ctr)
    numpy.save(f"ckpt/{steps}.counters.npy", ctr)

print(logfile)
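exprange is defined but unused in the training loop; it yields n log-spaced values, presumably for hyperparameter sweeps. For instance, list(exprange(1e-4, 1e-2, 3)) comes out to roughly [1e-4, 1e-3, 1e-2].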