import torch.nn import torch.nn.functional as F import torch import numpy import json import time from tqdm import tqdm from dataclasses import dataclass, asdict import numpy as np import base64 import asyncio import aiohttp import aioitertools from model import SAEConfig, SAE from shared import train_split, loaded_arrays, ckpt_path model_path, _ = ckpt_path(12991) model = SAE(SAEConfig( d_emb=1152, d_hidden=65536, top_k=128, device="cuda", dtype=torch.float32, up_proj_bias=False )) state_dict = torch.load(model_path) batch_size = 1024 retrieve_batch_size = 512 with torch.inference_mode(): model.load_state_dict(state_dict) model.eval() validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):] for batch_start in tqdm(range(0, len(validation_set), batch_size)): batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + 1024] ]) batch = torch.Tensor(batch).to("cuda") reconstructions = model(batch).float() feature_frequencies = model.reset_counters() features = model.up_proj.weight.cpu().numpy() meme_search_backend = "http://localhost:1707/" memes_url = "https://i.osmarks.net/memes-or-something/" meme_search_url = "https://mse.osmarks.net/?e=" def emb_url(embedding): return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8") async def get_exemplars(): async with aiohttp.ClientSession(): for base in tqdm(range(0, len(features), retrieve_batch_size)): chunk = features[base:base + retrieve_batch_size] with open(f"feature_dumps/features{base}.html", "w") as f: f.write("""