diff --git a/sae/export_features.py b/sae/export_features.py index 2fd2b96..98c4932 100644 --- a/sae/export_features.py +++ b/sae/export_features.py @@ -13,41 +13,45 @@ import aiohttp import aioitertools from model import SAEConfig, SAE -from shared import train_split, loaded_arrays, ckpt_path +from shared import train_split, loaded_arrays, ckpt_path, loaded_arrays_permutation -model_path, _ = ckpt_path(12991) +model_path, _ = ckpt_path(111885) -model = SAE(SAEConfig( - d_emb=1152, - d_hidden=65536, - top_k=128, - device="cuda", - dtype=torch.float32, - up_proj_bias=False -)) +@dataclass +class TrainConfig: + model: SAEConfig + lr: float + weight_decay: float + batch_size: int + epochs: int + compile: bool state_dict = torch.load(model_path) +model = SAE(state_dict["config"].model) batch_size = 1024 retrieve_batch_size = 512 with torch.inference_mode(): - model.load_state_dict(state_dict) + model.load_state_dict(state_dict["model"]) model.eval() - validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):] + print("loading val set") + train_set_size = int(len(loaded_arrays_permutation) * train_split) + val_set_size = len(loaded_arrays_permutation) - train_set_size + print("sliced. executing.") - for batch_start in tqdm(range(0, len(validation_set), batch_size)): - batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + 1024] ]) + for batch_start in tqdm(range(train_set_size, train_set_size+val_set_size, batch_size)): + batch = numpy.stack([ numpy.frombuffer(embedding, dtype=numpy.float16) for embedding in loaded_arrays[loaded_arrays_permutation[batch_start:batch_start + batch_size]] ]) batch = torch.Tensor(batch).to("cuda") reconstructions = model(batch).float() feature_frequencies = model.reset_counters() features = model.down_proj.weight.cpu().numpy() -meme_search_backend = "http://localhost:1707/" -memes_url = "https://i.osmarks.net/memes-or-something/" -meme_search_url = "https://mse.osmarks.net/?e=" +meme_search_backend = "http://localhost:5601/" +memes_url = "" +meme_search_url = "https://nooscope.osmarks.net/?page=advanced&e=" def emb_url(embedding): return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8") @@ -71,7 +75,7 @@ async def get_exemplars(): "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry "k": 10 }) as res: - return (await res.json())["matches"] + return (await res.json())["matches"][:10] exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ]) negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ]) @@ -79,7 +83,7 @@ async def get_exemplars(): for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]): f.write(f"""