diff --git a/sae/compare_plots.py b/sae/compare_plots.py
new file mode 100644
index 0000000..64d6f23
--- /dev/null
+++ b/sae/compare_plots.py
@@ -0,0 +1,40 @@
+# claude-3
+
+import json
+import matplotlib.pyplot as plt
+import sys
+
+logs = sys.argv[1:]
+
+def read_log(log):
+    # Read data from log.jsonl
+    data = []
+    with open(log, 'r') as file:
+        for line in file:
+            data.append(json.loads(line))
+
+    print(log, data[0]) # config
+
+    # Extract steps and loss (only entries which actually contain a loss)
+    steps = [entry['step'] for entry in data if "loss" in entry]
+    loss = [entry['loss'] for entry in data if "loss" in entry]
+
+    # Calculate rolling average for loss
+    window_size = 50
+    rolling_avg = [sum(loss[i:i+window_size])/window_size for i in range(len(loss)-window_size+1)]
+    rolling_steps = steps[window_size-1:]
+
+    return rolling_steps, rolling_avg
+
+# Create the plot
+plt.figure(figsize=(10, 6))
+#plt.plot(steps, loss, label='Loss')
+for i, log in enumerate(logs):
+    rolling_steps, rolling_avg = read_log(log)
+    plt.plot(rolling_steps, rolling_avg, label=f"{i}")
+
+plt.xlabel('Steps')
+plt.ylabel('Loss')
+plt.legend()
+plt.grid(True)
+plt.show()
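The rolling average above is an O(n·window) list comprehension; a minimal numpy sketch of the same smoothing (the rolling_mean helper here is hypothetical, not part of the diff) would be:

import numpy as np

def rolling_mean(values, window=50):
    # Same output as the list comprehension: len(values) - window + 1 points
    return np.convolve(values, np.ones(window) / window, mode="valid")
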
diff --git a/sae/export_features.py b/sae/export_features.py
new file mode 100644
index 0000000..d5cf64b
--- /dev/null
+++ b/sae/export_features.py
@@ -0,0 +1,93 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+import base64
+import asyncio
+import aiohttp
+import aioitertools
+
+from model import SAEConfig, SAE
+from shared import train_split, loaded_arrays, ckpt_path
+
+model_path, _ = ckpt_path(12991)
+
+model = SAE(SAEConfig(
+    d_emb=1152,
+    d_hidden=65536,
+    top_k=128,
+    device="cuda",
+    dtype=torch.float32,
+    up_proj_bias=False
+))
+
+# train.py saves {"model": state_dict, "config": config}, so unwrap the weights
+state_dict = torch.load(model_path)["model"]
+
+batch_size = 1024
+retrieve_batch_size = 512
+
+with torch.inference_mode():
+    model.load_state_dict(state_dict)
+    model.eval()
+
+    validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):]
+
+    for batch_start in tqdm(range(0, len(validation_set), batch_size)):
+        batch = np.stack([ np.frombuffer(embedding.as_py(), dtype=np.float16) for embedding in validation_set[batch_start:batch_start + batch_size] ])
+        batch = torch.Tensor(batch).to("cuda")
+        # The forward pass is only needed to update the feature activation counters
+        model(batch)
+
+    feature_frequencies = model.reset_counters()
+    features = model.up_proj.weight.cpu().numpy()
+
+meme_search_backend = "http://localhost:1707/"
+memes_url = "https://i.osmarks.net/memes-or-something/"
+meme_search_url = "https://mse.osmarks.net/?e="
+
+def emb_url(embedding):
+    return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8")
+
+async def get_exemplars():
+    async with aiohttp.ClientSession() as session:
+        for base in tqdm(range(0, len(features), retrieve_batch_size)):
+            chunk = features[base:base + retrieve_batch_size]
+            with open(f"feature_dumps/features{base}.html", "w") as f:
+                f.write("""<!DOCTYPE html>
+<title>Embeddings SAE Features</title>
+<body>
+<h1>Embeddings SAE Features</h1>
+""")
+
+                async def lookup(embedding):
+                    async with session.post(meme_search_backend, json={
+                        "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
+                        "k": 10
+                    }) as res:
+                        return (await res.json())["matches"]
+
+                exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ])
+                negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ])
+
+                for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]):
+                    f.write(f"""
+<h2>Feature {offset + base}</h2>
+<h3>Frequency {frequency / len(validation_set)}</h3>
+<div>
+<h4><a href="{emb_url(feature)}">Max</a></h4>
+""")
+                    # match layout assumed to be [score, filename, ...]
+                    for match in exemplars[offset]:
+                        f.write(f'<img loading="lazy" src="{memes_url + match[1]}">')
+                    f.write(f'<h4><a href="{emb_url(-feature)}">Min</a></h4>')
+                    for match in negative_exemplars[offset]:
+                        f.write(f'<img loading="lazy" src="{memes_url + match[1]}">')
+                    f.write("</div>")
+
+asyncio.run(get_exemplars())
\ No newline at end of file
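For reference, lookup() above amounts to one POST per feature direction against the local meme search backend. A minimal synchronous sketch of the same request (using requests instead of aiohttp; the exact shape of each match is the backend's, not specified in this diff):

import requests

res = requests.post("http://localhost:1707/", json={
    "terms": [{"embedding": [0.0] * 1152}],  # one d_emb-dimensional direction
    "k": 10,                                 # number of nearest matches to return
})
print(res.json()["matches"])
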
") + +asyncio.run(get_exemplars()) \ No newline at end of file diff --git a/sae/model.py b/sae/model.py new file mode 100644 index 0000000..d146441 --- /dev/null +++ b/sae/model.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import dataclass +from functools import partial + +@dataclass +class SAEConfig: + d_emb: int + d_hidden: int + top_k: int + up_proj_bias: bool + device: str + dtype: torch.dtype + +class SAE(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.up_proj = nn.Linear(config.d_emb, config.d_hidden, dtype=config.dtype, device=config.device, bias=config.up_proj_bias) + self.down_proj = nn.Linear(config.d_hidden, config.d_emb, dtype=config.dtype, device=config.device) + self.down_proj.weight = nn.Parameter(self.up_proj.weight.T.clone()) + self.feature_activation_counter = torch.zeros(config.d_hidden, dtype=torch.int32, device=config.device) + self.reset_counters() + + def reset_counters(self): + old = self.feature_activation_counter.detach().cpu().numpy() + torch.zero_(self.feature_activation_counter) + return old + + def forward(self, embs): + x = self.up_proj(embs) + x = F.relu(x) + topk = torch.kthvalue(x, k=(self.config.d_hidden - self.config.top_k), dim=-1) + thresholds = topk.values.unsqueeze(-1).expand_as(x) + zero = torch.zeros_like(x) + # If multiple values are the same, we don't actually pick exactly k values. This can happen quite easily if for some reason a lot of values are negative and thus get ReLUed to 0. + # This should not really happen but it does. + # This uses greater than rather than greater than or equal to work around this. We compensate for this by setting k off by one in the kthvalue call. + mask = x > thresholds + x = torch.where(mask, x, zero) + self.feature_activation_counter += mask.sum(0) + return self.down_proj(x) \ No newline at end of file diff --git a/sae/shared.py b/sae/shared.py new file mode 100644 index 0000000..f01fdbd --- /dev/null +++ b/sae/shared.py @@ -0,0 +1,12 @@ +import torch +import pyarrow as pa + +torch.set_float32_matmul_precision("high") + +with pa.memory_map("../../sample_1m.arrow", "r") as source: + loaded_arrays = pa.ipc.open_file(source).read_all() + +train_split = 0.8 + +def ckpt_path(steps): + return f"ckpt/{steps}.pt", f"ckpt/{steps}.optim.pt" \ No newline at end of file diff --git a/sae/shuffle.py b/sae/shuffle.py new file mode 100644 index 0000000..158ad9a --- /dev/null +++ b/sae/shuffle.py @@ -0,0 +1,51 @@ +# claude output + +import pyarrow as pa +import numpy as np +import pyarrow.compute as pc + +def shuffle_arrow_file(input_file, output_file, chunk_size=100000): + # Open the input file + with pa.memory_map(input_file, 'r') as source: + reader = pa.ipc.open_file(source) + + # Get the schema and create a writer for the output file + schema = reader.schema + with pa.OSFile(output_file, 'wb') as sink: + writer = pa.ipc.new_file(sink, schema) + + # Calculate total number of rows + total_rows = reader.num_record_batches + all_batches = [reader.get_batch(i) for i in range(total_rows)] + total_rows = sum(batch.num_rows for batch in all_batches) + + # Generate shuffled indices + indices = np.random.permutation(total_rows) + + # Process in chunks + for i in range(0, total_rows, chunk_size): + # Get indices for this chunk + chunk_indices = indices[i:i+chunk_size] + + # Take rows using these indices + chunk_data = [] + for idx in chunk_indices: + batch_idx = 0 + row_idx = idx + while row_idx >= 
diff --git a/sae/shuffle.py b/sae/shuffle.py
new file mode 100644
index 0000000..158ad9a
--- /dev/null
+++ b/sae/shuffle.py
@@ -0,0 +1,51 @@
+# claude output
+
+import pyarrow as pa
+import numpy as np
+
+def shuffle_arrow_file(input_file, output_file, chunk_size=100000):
+    # Open the input file
+    with pa.memory_map(input_file, 'r') as source:
+        reader = pa.ipc.open_file(source)
+
+        # Get the schema and create a writer for the output file
+        schema = reader.schema
+        with pa.OSFile(output_file, 'wb') as sink:
+            writer = pa.ipc.new_file(sink, schema)
+
+            # Load all record batches and view them as one table
+            all_batches = [reader.get_batch(i) for i in range(reader.num_record_batches)]
+            table = pa.Table.from_batches(all_batches, schema=schema)
+            total_rows = table.num_rows
+
+            # Generate shuffled indices
+            indices = np.random.permutation(total_rows)
+
+            # Write the shuffled rows out in chunks
+            for i in range(0, total_rows, chunk_size):
+                writer.write_table(table.take(indices[i:i+chunk_size]))
+
+            # Close the writer
+            writer.close()
+
+# Usage
+input_file = "../../sample_1m.arrow"
+output_file = "shuffled_trainset.arrow"
+shuffle_arrow_file(input_file, output_file)
diff --git a/sae/train.py b/sae/train.py
new file mode 100644
index 0000000..2c5e37e
--- /dev/null
+++ b/sae/train.py
@@ -0,0 +1,101 @@
+import torch.nn.functional as F
+import torch
+import numpy
+import json
+import math
+import time
+from tqdm import tqdm
+from dataclasses import dataclass, asdict
+
+from model import SAEConfig, SAE
+from shared import train_split, loaded_arrays, ckpt_path
+
+device = "cuda"
+
+@dataclass
+class TrainConfig:
+    model: SAEConfig
+    lr: float
+    weight_decay: float
+    batch_size: int
+    epochs: int
+    compile: bool
+
+config = TrainConfig(
+    model=SAEConfig(
+        d_emb=1152,
+        d_hidden=65536,
+        top_k=128,
+        device=device,
+        dtype=torch.float32,
+        up_proj_bias=False
+    ),
+    lr=3e-4,
+    weight_decay=0.0,
+    batch_size=64,
+    epochs=5,
+    compile=True,
+)
+
+# Log-spaced values between min and max (unused below; presumably for LR sweeps)
+def exprange(min, max, n):
+    lmin, lmax = math.log(min), math.log(max)
+    step = (lmax - lmin) / (n - 1)
+    return (math.exp(lmin + step * i) for i in range(n))
+
+model = SAE(config.model)
+params = sum(p.numel() for p in model.parameters())
+print(f"{params/1e6:.1f}M parameters")
+print(model)
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
+
+def train_step(model, batch):
+    optimizer.zero_grad()
+    reconstructions = model(batch).float()
+    loss = F.mse_loss(reconstructions, batch)
+    loss.backward()
+    optimizer.step()
+    return loss
+
+if config.compile:
+    print("compiling...")
+    train_step = torch.compile(train_step)
+
+def save_ckpt(log, steps):
+    #print("saving...")
+    modelc, optimc = ckpt_path(steps)
+    torch.save(optimizer.state_dict(), optimc)
+    torch.save({"model": model.state_dict(), "config": config}, modelc)
+
+# torch.dtype is not JSON-serializable, so stringify it for the config line
+class JSONEncoder(json.JSONEncoder):
+    def default(self, o):
+        if isinstance(o, torch.dtype):
+            return str(o)
+        else: return super().default(o)
+
+logfile = f"logs/log-{time.time()}.jsonl"
+with open(logfile, "w") as log:
+    steps = 0
+    log.write(JSONEncoder().encode(asdict(config)) + "\n")
+    for epoch in range(config.epochs):
+        t = tqdm(range(0, int(len(loaded_arrays) * train_split), config.batch_size))
+        for batch_start in t:
+            batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in loaded_arrays["embedding"][batch_start:batch_start + config.batch_size] ])
+
+            # Skip the ragged final batch
+            if len(batch) == config.batch_size:
+                batch = torch.Tensor(batch).to(device)
+                loss = train_step(model, batch)
+                loss = loss.detach().cpu().item()
+                t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}")
+                log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
+                if steps % 5000 == 0: save_ckpt(log, steps)
+                steps += 1
+
+        save_ckpt(log, steps)
+        print(model.feature_activation_counter.cpu())
+        ctr = model.reset_counters()
+        print(ctr)
+        numpy.save(f"ckpt/{steps}.counters.npy", ctr)
+
+print(logfile)
\ No newline at end of file
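Since save_ckpt() stores the weights under a "model" key alongside the config, reloading a checkpoint looks roughly like this (a sketch; step 12991 is taken from export_features.py, and weights_only=False is needed on newer torch because the checkpoint pickles a dataclass):

import torch
from model import SAE, SAEConfig
from shared import ckpt_path

model_path, _ = ckpt_path(12991)
# train.py saves {"model": state_dict, "config": TrainConfig}
state_dict = torch.load(model_path, weights_only=False)["model"]

model = SAE(SAEConfig(
    d_emb=1152, d_hidden=65536, top_k=128,
    device="cuda", dtype=torch.float32, up_proj_bias=False,
))
model.load_state_dict(state_dict)
model.eval()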