diff --git a/sae/compare_plots.py b/sae/compare_plots.py
new file mode 100644
index 0000000..64d6f23
--- /dev/null
+++ b/sae/compare_plots.py
@@ -0,0 +1,40 @@
+# claude-3
+
+import json
+import matplotlib.pyplot as plt
+import sys
+
+logs = sys.argv[1:]
+
+def read_log(log):
+ # Read data from log.jsonl
+ data = []
+ with open(log, 'r') as file:
+ for line in file:
+ data.append(json.loads(line))
+
+ print(log, data[0]) # config
+
+ # Extract steps, loss, and val_loss
+ steps = [entry['step'] for entry in data if "loss" in entry]
+ loss = [entry['loss'] for entry in data if "loss" in entry]
+
+ # Calculate rolling average for loss
+ window_size = 50
+ rolling_avg = [sum(loss[i:i+window_size])/window_size for i in range(len(loss)-window_size+1)]
+ rolling_steps = steps[window_size-1:]
+
+ return rolling_steps, rolling_avg
+
+# Create the plot
+plt.figure(figsize=(10, 6))
+#plt.plot(steps, loss, label='Loss')
+for i, log in enumerate(logs):
+ rolling_steps, rolling_avg = read_log(log)
+ plt.plot(rolling_steps, rolling_avg, label=f"{i}")
+
+plt.xlabel('Steps')
+plt.ylabel('Loss')
+plt.legend()
+plt.grid(True)
+plt.show()
diff --git a/sae/export_features.py b/sae/export_features.py
new file mode 100644
index 0000000..d5cf64b
--- /dev/null
+++ b/sae/export_features.py
@@ -0,0 +1,93 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import numpy
+import json
+import time
+from tqdm import tqdm
+from dataclasses import dataclass, asdict
+import numpy as np
+import base64
+import asyncio
+import aiohttp
+import aioitertools
+
+from model import SAEConfig, SAE
+from shared import train_split, loaded_arrays, ckpt_path
+
+model_path, _ = ckpt_path(12991)
+
+model = SAE(SAEConfig(
+ d_emb=1152,
+ d_hidden=65536,
+ top_k=128,
+ device="cuda",
+ dtype=torch.float32,
+ up_proj_bias=False
+))
+
+state_dict = torch.load(model_path)
+
+batch_size = 1024
+retrieve_batch_size = 512
+
+with torch.inference_mode():
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):]
+
+ for batch_start in tqdm(range(0, len(validation_set), batch_size)):
+ batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + 1024] ])
+ batch = torch.Tensor(batch).to("cuda")
+ reconstructions = model(batch).float()
+
+ feature_frequencies = model.reset_counters()
+ features = model.up_proj.weight.cpu().numpy()
+
+meme_search_backend = "http://localhost:1707/"
+memes_url = "https://i.osmarks.net/memes-or-something/"
+meme_search_url = "https://mse.osmarks.net/?e="
+
+def emb_url(embedding):
+ return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8")
+
+async def get_exemplars():
+ async with aiohttp.ClientSession():
+ for base in tqdm(range(0, len(features), retrieve_batch_size)):
+ chunk = features[base:base + retrieve_batch_size]
+ with open(f"feature_dumps/features{base}.html", "w") as f:
+ f.write("""
+
Embeddings SAE Features
+
+ Embeddings SAE Features
""")
+
+ async def lookup(embedding):
+ async with aiohttp.request("POST", meme_search_backend, json={
+ "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
+ "k": 10
+ }) as res:
+ return (await res.json())["matches"]
+
+ exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ])
+ negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ])
+
+ for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]):
+ f.write(f"""
+ Feature {offset + base}
+ Frequency {frequency / len(validation_set)}
+
+
+ """)
+ for match in exemplars[offset]:
+ f.write(f'
')
+ f.write(f'
')
+ for match in negative_exemplars[offset]:
+ f.write(f'
')
+ f.write("
")
+
+asyncio.run(get_exemplars())
\ No newline at end of file
diff --git a/sae/model.py b/sae/model.py
new file mode 100644
index 0000000..d146441
--- /dev/null
+++ b/sae/model.py
@@ -0,0 +1,43 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from dataclasses import dataclass
+from functools import partial
+
+@dataclass
+class SAEConfig:
+ d_emb: int
+ d_hidden: int
+ top_k: int
+ up_proj_bias: bool
+ device: str
+ dtype: torch.dtype
+
+class SAE(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.up_proj = nn.Linear(config.d_emb, config.d_hidden, dtype=config.dtype, device=config.device, bias=config.up_proj_bias)
+ self.down_proj = nn.Linear(config.d_hidden, config.d_emb, dtype=config.dtype, device=config.device)
+ self.down_proj.weight = nn.Parameter(self.up_proj.weight.T.clone())
+ self.feature_activation_counter = torch.zeros(config.d_hidden, dtype=torch.int32, device=config.device)
+ self.reset_counters()
+
+ def reset_counters(self):
+ old = self.feature_activation_counter.detach().cpu().numpy()
+ torch.zero_(self.feature_activation_counter)
+ return old
+
+ def forward(self, embs):
+ x = self.up_proj(embs)
+ x = F.relu(x)
+ topk = torch.kthvalue(x, k=(self.config.d_hidden - self.config.top_k), dim=-1)
+ thresholds = topk.values.unsqueeze(-1).expand_as(x)
+ zero = torch.zeros_like(x)
+ # If multiple values are the same, we don't actually pick exactly k values. This can happen quite easily if for some reason a lot of values are negative and thus get ReLUed to 0.
+ # This should not really happen but it does.
+ # This uses greater than rather than greater than or equal to work around this. We compensate for this by setting k off by one in the kthvalue call.
+ mask = x > thresholds
+ x = torch.where(mask, x, zero)
+ self.feature_activation_counter += mask.sum(0)
+ return self.down_proj(x)
\ No newline at end of file
diff --git a/sae/shared.py b/sae/shared.py
new file mode 100644
index 0000000..f01fdbd
--- /dev/null
+++ b/sae/shared.py
@@ -0,0 +1,12 @@
+import torch
+import pyarrow as pa
+
+torch.set_float32_matmul_precision("high")
+
+with pa.memory_map("../../sample_1m.arrow", "r") as source:
+ loaded_arrays = pa.ipc.open_file(source).read_all()
+
+train_split = 0.8
+
+def ckpt_path(steps):
+ return f"ckpt/{steps}.pt", f"ckpt/{steps}.optim.pt"
\ No newline at end of file
diff --git a/sae/shuffle.py b/sae/shuffle.py
new file mode 100644
index 0000000..158ad9a
--- /dev/null
+++ b/sae/shuffle.py
@@ -0,0 +1,51 @@
+# claude output
+
+import pyarrow as pa
+import numpy as np
+import pyarrow.compute as pc
+
+def shuffle_arrow_file(input_file, output_file, chunk_size=100000):
+ # Open the input file
+ with pa.memory_map(input_file, 'r') as source:
+ reader = pa.ipc.open_file(source)
+
+ # Get the schema and create a writer for the output file
+ schema = reader.schema
+ with pa.OSFile(output_file, 'wb') as sink:
+ writer = pa.ipc.new_file(sink, schema)
+
+ # Calculate total number of rows
+ total_rows = reader.num_record_batches
+ all_batches = [reader.get_batch(i) for i in range(total_rows)]
+ total_rows = sum(batch.num_rows for batch in all_batches)
+
+ # Generate shuffled indices
+ indices = np.random.permutation(total_rows)
+
+ # Process in chunks
+ for i in range(0, total_rows, chunk_size):
+ # Get indices for this chunk
+ chunk_indices = indices[i:i+chunk_size]
+
+ # Take rows using these indices
+ chunk_data = []
+ for idx in chunk_indices:
+ batch_idx = 0
+ row_idx = idx
+ while row_idx >= all_batches[batch_idx].num_rows:
+ row_idx -= all_batches[batch_idx].num_rows
+ batch_idx += 1
+ chunk_data.append(all_batches[batch_idx].slice(row_idx, 1))
+
+ chunk = pa.Table.from_batches(chunk_data)
+
+ # Write the chunk
+ writer.write_table(chunk)
+
+ # Close the writer
+ writer.close()
+
+# Usage
+input_file = "../../sample_1m.arrow"
+output_file = "shuffled_trainset.arrow"
+shuffle_arrow_file(input_file, output_file)
diff --git a/sae/train.py b/sae/train.py
new file mode 100644
index 0000000..2c5e37e
--- /dev/null
+++ b/sae/train.py
@@ -0,0 +1,101 @@
+import torch.nn
+import torch.nn.functional as F
+import torch
+import numpy
+import json
+import time
+from tqdm import tqdm
+from dataclasses import dataclass, asdict
+
+from model import SAEConfig, SAE
+from shared import train_split, loaded_arrays, ckpt_path
+
+device = "cuda"
+
+@dataclass
+class TrainConfig:
+ model: SAEConfig
+ lr: float
+ weight_decay: float
+ batch_size: int
+ epochs: int
+ compile: bool
+
+config = TrainConfig(
+ model=SAEConfig(
+ d_emb=1152,
+ d_hidden=65536,
+ top_k=128,
+ device=device,
+ dtype=torch.float32,
+ up_proj_bias=False
+ ),
+ lr=3e-4,
+ weight_decay=0.0,
+ batch_size=64,
+ epochs=5,
+ compile=True,
+)
+
+def exprange(min, max, n):
+ lmin, lmax = math.log(min), math.log(max)
+ step = (lmax - lmin) / (n - 1)
+ return (math.exp(lmin + step * i) for i in range(n))
+
+model = SAE(config.model)
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
+
+def train_step(model, batch):
+ optimizer.zero_grad()
+ reconstructions = model(batch).float()
+ loss = F.mse_loss(reconstructions, batch)
+ loss.backward()
+ optimizer.step()
+ return loss
+
+if config.compile:
+ print("compiling...")
+ train_step = torch.compile(train_step)
+
+def save_ckpt(log, steps):
+ #print("saving...")
+ modelc, optimc = ckpt_path(steps)
+ torch.save(optimizer.state_dict(), optimc)
+ torch.save({"model": model.state_dict(), "config": config}, modelc)
+
+class JSONEncoder(json.JSONEncoder):
+ def default(self, o):
+ if isinstance(o, torch.dtype):
+ return str(o)
+ else: return super().default(o)
+
+logfile = f"logs/log-{time.time()}.jsonl"
+with open(logfile, "w") as log:
+ steps = 0
+ log.write(JSONEncoder().encode(asdict(config)) + "\n")
+ for epoch in range(config.epochs):
+ batch = []
+ t = tqdm(range(0, int(len(loaded_arrays) * train_split), config.batch_size))
+ for batch_start in t:
+ batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in loaded_arrays["embedding"][batch_start:batch_start + config.batch_size] ])
+
+ if len(batch) == config.batch_size:
+ batch = torch.Tensor(batch).to(device)
+ loss = train_step(model, batch)
+ loss = loss.detach().cpu().item()
+ t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}")
+ log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
+ if steps % 5000 == 0: save_ckpt(log, steps)
+ steps += 1
+
+ save_ckpt(log, steps)
+ print(model.feature_activation_counter.cpu())
+ ctr = model.reset_counters()
+ print(ctr)
+ numpy.save(f"ckpt/{steps}.counters.npy", ctr)
+
+print(logfile)
\ No newline at end of file