1
0
mirror of https://github.com/osmarks/meme-search-engine.git synced 2025-04-27 13:03:12 +00:00

fixed SAE export code

This commit is contained in:
osmarks 2025-01-31 17:12:25 +00:00
parent 899fbb7092
commit ec27deddbf
2 changed files with 26 additions and 20 deletions

View File

@ -13,41 +13,45 @@ import aiohttp
import aioitertools
from model import SAEConfig, SAE
from shared import train_split, loaded_arrays, ckpt_path
from shared import train_split, loaded_arrays, ckpt_path, loaded_arrays_permutation
model_path, _ = ckpt_path(12991)
model_path, _ = ckpt_path(111885)
model = SAE(SAEConfig(
d_emb=1152,
d_hidden=65536,
top_k=128,
device="cuda",
dtype=torch.float32,
up_proj_bias=False
))
@dataclass
class TrainConfig:
model: SAEConfig
lr: float
weight_decay: float
batch_size: int
epochs: int
compile: bool
state_dict = torch.load(model_path)
model = SAE(state_dict["config"].model)
batch_size = 1024
retrieve_batch_size = 512
with torch.inference_mode():
model.load_state_dict(state_dict)
model.load_state_dict(state_dict["model"])
model.eval()
validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):]
print("loading val set")
train_set_size = int(len(loaded_arrays_permutation) * train_split)
val_set_size = len(loaded_arrays_permutation) - train_set_size
print("sliced. executing.")
for batch_start in tqdm(range(0, len(validation_set), batch_size)):
batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + 1024] ])
for batch_start in tqdm(range(train_set_size, train_set_size+val_set_size, batch_size)):
batch = numpy.stack([ numpy.frombuffer(embedding, dtype=numpy.float16) for embedding in loaded_arrays[loaded_arrays_permutation[batch_start:batch_start + batch_size]] ])
batch = torch.Tensor(batch).to("cuda")
reconstructions = model(batch).float()
feature_frequencies = model.reset_counters()
features = model.down_proj.weight.cpu().numpy()
meme_search_backend = "http://localhost:1707/"
memes_url = "https://i.osmarks.net/memes-or-something/"
meme_search_url = "https://mse.osmarks.net/?e="
meme_search_backend = "http://localhost:5601/"
memes_url = ""
meme_search_url = "https://nooscope.osmarks.net/?page=advanced&e="
def emb_url(embedding):
return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8")
@ -71,7 +75,7 @@ async def get_exemplars():
"terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
"k": 10
}) as res:
return (await res.json())["matches"]
return (await res.json())["matches"][:10]
exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ])
negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ])
@ -79,7 +83,7 @@ async def get_exemplars():
for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]):
f.write(f"""
<h2>Feature {offset + base}</h2>
<h3>Frequency {frequency / len(validation_set)}</h3>
<h3>Frequency {frequency / val_set_size}</h3>
<div>
<h4><a href="{emb_url(feature)}">Max</a></h4>
""")

View File

@ -74,6 +74,7 @@ class JSONEncoder(json.JSONEncoder):
return str(o)
else: return super().default(o)
"""
logfile = f"logs/log-{time.time()}.jsonl"
with open(logfile, "w") as log:
steps = 0
@ -90,7 +91,7 @@ with open(logfile, "w") as log:
loss = loss.detach().cpu().item()
t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}")
log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
if steps % 5000 == 0: save_ckpt(log, steps)
#if steps % 5000 == 0: save_ckpt(log, steps)
steps += 1
save_ckpt(log, steps)
@ -100,3 +101,4 @@ with open(logfile, "w") as log:
numpy.save(f"ckpt/{steps}.counters.npy", ctr)
print(logfile)
"""