From ec27deddbf943019652ef04af535e1e24c791eae Mon Sep 17 00:00:00 2001 From: osmarks Date: Fri, 31 Jan 2025 17:12:25 +0000 Subject: [PATCH] fixed SAE export code --- sae/export_features.py | 42 +++++++++++++++++++++++------------------- sae/train.py | 4 +++- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/sae/export_features.py b/sae/export_features.py index 2fd2b96..98c4932 100644 --- a/sae/export_features.py +++ b/sae/export_features.py @@ -13,41 +13,45 @@ import aiohttp import aioitertools from model import SAEConfig, SAE -from shared import train_split, loaded_arrays, ckpt_path +from shared import train_split, loaded_arrays, ckpt_path, loaded_arrays_permutation -model_path, _ = ckpt_path(12991) +model_path, _ = ckpt_path(111885) -model = SAE(SAEConfig( - d_emb=1152, - d_hidden=65536, - top_k=128, - device="cuda", - dtype=torch.float32, - up_proj_bias=False -)) +@dataclass +class TrainConfig: + model: SAEConfig + lr: float + weight_decay: float + batch_size: int + epochs: int + compile: bool state_dict = torch.load(model_path) +model = SAE(state_dict["config"].model) batch_size = 1024 retrieve_batch_size = 512 with torch.inference_mode(): - model.load_state_dict(state_dict) + model.load_state_dict(state_dict["model"]) model.eval() - validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):] + print("loading val set") + train_set_size = int(len(loaded_arrays_permutation) * train_split) + val_set_size = len(loaded_arrays_permutation) - train_set_size + print("sliced. executing.") - for batch_start in tqdm(range(0, len(validation_set), batch_size)): - batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + 1024] ]) + for batch_start in tqdm(range(train_set_size, train_set_size+val_set_size, batch_size)): + batch = numpy.stack([ numpy.frombuffer(embedding, dtype=numpy.float16) for embedding in loaded_arrays[loaded_arrays_permutation[batch_start:batch_start + batch_size]] ]) batch = torch.Tensor(batch).to("cuda") reconstructions = model(batch).float() feature_frequencies = model.reset_counters() features = model.down_proj.weight.cpu().numpy() -meme_search_backend = "http://localhost:1707/" -memes_url = "https://i.osmarks.net/memes-or-something/" -meme_search_url = "https://mse.osmarks.net/?e=" +meme_search_backend = "http://localhost:5601/" +memes_url = "" +meme_search_url = "https://nooscope.osmarks.net/?page=advanced&e=" def emb_url(embedding): return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8") @@ -71,7 +75,7 @@ async def get_exemplars(): "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry "k": 10 }) as res: - return (await res.json())["matches"] + return (await res.json())["matches"][:10] exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ]) negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ]) @@ -79,7 +83,7 @@ async def get_exemplars(): for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]): f.write(f"""

Feature {offset + base}

-

Frequency {frequency / len(validation_set)}

+

Frequency {frequency / val_set_size}

Max

""") diff --git a/sae/train.py b/sae/train.py index 46816f5..d53fc84 100644 --- a/sae/train.py +++ b/sae/train.py @@ -74,6 +74,7 @@ class JSONEncoder(json.JSONEncoder): return str(o) else: return super().default(o) +""" logfile = f"logs/log-{time.time()}.jsonl" with open(logfile, "w") as log: steps = 0 @@ -90,7 +91,7 @@ with open(logfile, "w") as log: loss = loss.detach().cpu().item() t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}") log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n") - if steps % 5000 == 0: save_ckpt(log, steps) + #if steps % 5000 == 0: save_ckpt(log, steps) steps += 1 save_ckpt(log, steps) @@ -100,3 +101,4 @@ with open(logfile, "w") as log: numpy.save(f"ckpt/{steps}.counters.npy", ctr) print(logfile) +"""