mirror of https://github.com/osmarks/meme-search-engine.git

fixed SAE export code

This commit is contained in:
osmarks 2025-01-31 17:12:25 +00:00
parent 899fbb7092
commit ec27deddbf
2 changed files with 26 additions and 20 deletions

View File

@@ -13,41 +13,45 @@ import aiohttp
 import aioitertools
 from model import SAEConfig, SAE
-from shared import train_split, loaded_arrays, ckpt_path
+from shared import train_split, loaded_arrays, ckpt_path, loaded_arrays_permutation

-model_path, _ = ckpt_path(12991)
+model_path, _ = ckpt_path(111885)

-model = SAE(SAEConfig(
-    d_emb=1152,
-    d_hidden=65536,
-    top_k=128,
-    device="cuda",
-    dtype=torch.float32,
-    up_proj_bias=False
-))
+@dataclass
+class TrainConfig:
+    model: SAEConfig
+    lr: float
+    weight_decay: float
+    batch_size: int
+    epochs: int
+    compile: bool

 state_dict = torch.load(model_path)
+model = SAE(state_dict["config"].model)

 batch_size = 1024
 retrieve_batch_size = 512

 with torch.inference_mode():
-    model.load_state_dict(state_dict)
+    model.load_state_dict(state_dict["model"])
     model.eval()
-    validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):]
+    print("loading val set")
+    train_set_size = int(len(loaded_arrays_permutation) * train_split)
+    val_set_size = len(loaded_arrays_permutation) - train_set_size
+    print("sliced. executing.")
-    for batch_start in tqdm(range(0, len(validation_set), batch_size)):
-        batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + 1024] ])
+    for batch_start in tqdm(range(train_set_size, train_set_size+val_set_size, batch_size)):
+        batch = numpy.stack([ numpy.frombuffer(embedding, dtype=numpy.float16) for embedding in loaded_arrays[loaded_arrays_permutation[batch_start:batch_start + batch_size]] ])
         batch = torch.Tensor(batch).to("cuda")
         reconstructions = model(batch).float()

 feature_frequencies = model.reset_counters()
 features = model.down_proj.weight.cpu().numpy()

-meme_search_backend = "http://localhost:1707/"
-memes_url = "https://i.osmarks.net/memes-or-something/"
-meme_search_url = "https://mse.osmarks.net/?e="
+meme_search_backend = "http://localhost:5601/"
+memes_url = ""
+meme_search_url = "https://nooscope.osmarks.net/?page=advanced&e="

 def emb_url(embedding):
     return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8")
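`emb_url` packs a raw float16 embedding into a URL-safe base64 string for the `e=` query parameter. The consuming end presumably reverses this; a minimal round-trip sketch (this `decode_emb` helper is hypothetical, not part of the diff):

import base64
import numpy as np

def decode_emb(param: str) -> np.ndarray:
    # Hypothetical inverse of emb_url: urlsafe base64 -> float16 vector.
    return np.frombuffer(base64.urlsafe_b64decode(param), dtype=np.float16)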
@@ -71,7 +75,7 @@ async def get_exemplars():
                 "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
                 "k": 10
             }) as res:
-                return (await res.json())["matches"]
+                return (await res.json())["matches"][:10]
         exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ])
         negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ])
@@ -79,7 +83,7 @@ async def get_exemplars():
         for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]):
             f.write(f"""
 <h2>Feature {offset + base}</h2>
-<h3>Frequency {frequency / len(validation_set)}</h3>
+<h3>Frequency {frequency / val_set_size}</h3>
 <div>
     <h4><a href="{emb_url(feature)}">Max</a></h4>
 """)

View File

@@ -74,6 +74,7 @@ class JSONEncoder(json.JSONEncoder):
             return str(o)
         else: return super().default(o)
+"""
 logfile = f"logs/log-{time.time()}.jsonl"
 with open(logfile, "w") as log:
     steps = 0
@@ -90,7 +91,7 @@ with open(logfile, "w") as log:
             loss = loss.detach().cpu().item()
             t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}")
             log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
-            if steps % 5000 == 0: save_ckpt(log, steps)
+            #if steps % 5000 == 0: save_ckpt(log, steps)
             steps += 1
     save_ckpt(log, steps)
@@ -100,3 +101,4 @@ with open(logfile, "w") as log:
     numpy.save(f"ckpt/{steps}.counters.npy", ctr)
     print(logfile)
+"""