Mirror of https://github.com/osmarks/meme-search-engine.git (synced 2025-04-27 13:03:12 +00:00)

commit ec27deddbf (parent 899fbb7092): fixed SAE export code
@@ -13,41 +13,45 @@ import aiohttp
 import aioitertools

 from model import SAEConfig, SAE
-from shared import train_split, loaded_arrays, ckpt_path
+from shared import train_split, loaded_arrays, ckpt_path, loaded_arrays_permutation

-model_path, _ = ckpt_path(12991)
+model_path, _ = ckpt_path(111885)

-model = SAE(SAEConfig(
-    d_emb=1152,
-    d_hidden=65536,
-    top_k=128,
-    device="cuda",
-    dtype=torch.float32,
-    up_proj_bias=False
-))
+@dataclass
+class TrainConfig:
+    model: SAEConfig
+    lr: float
+    weight_decay: float
+    batch_size: int
+    epochs: int
+    compile: bool

 state_dict = torch.load(model_path)
+model = SAE(state_dict["config"].model)
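The removed block hardcoded the architecture; reading it from the checkpoint keeps export in sync with whatever was trained. The config names a top-k sparse autoencoder (d_emb=1152, d_hidden=65536, top_k=128). model.py is outside this diff, so the following is only a hedged sketch of what such a model computes; TopKSAE is a hypothetical stand-in, not the repo's SAE class.

import torch
from torch import nn

class TopKSAE(nn.Module):
    # hypothetical stand-in for model.SAE, shaped like the removed config
    def __init__(self, d_emb=1152, d_hidden=65536, top_k=128, up_proj_bias=False):
        super().__init__()
        self.up_proj = nn.Linear(d_emb, d_hidden, bias=up_proj_bias)
        self.down_proj = nn.Linear(d_hidden, d_emb)
        self.top_k = top_k

    def forward(self, x):
        acts = self.up_proj(x)
        # keep the top_k largest activations per example, zero the rest
        values, indices = torch.topk(acts, self.top_k, dim=-1)
        sparse = torch.zeros_like(acts).scatter_(-1, indices, values)
        # reconstruct the input embedding from the sparse code
        return self.down_proj(sparse)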
 batch_size = 1024
 retrieve_batch_size = 512

 with torch.inference_mode():
-    model.load_state_dict(state_dict)
+    model.load_state_dict(state_dict["model"])
     model.eval()
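The load fix implies the trainer saves a dict of config plus weights rather than a bare state dict; the TrainConfig dataclass is re-declared above presumably so torch.load can unpickle the saved config object. A minimal sketch of that layout, with save_checkpoint and load_checkpoint as hypothetical helper names:

import torch
from model import SAE, SAEConfig  # as imported by the export script

def save_checkpoint(path, train_config, model):
    # hypothetical trainer-side counterpart: config and weights in one dict
    torch.save({"config": train_config, "model": model.state_dict()}, path)

def load_checkpoint(path):
    state_dict = torch.load(path)
    model = SAE(state_dict["config"].model)     # rebuild from the saved SAEConfig
    model.load_state_dict(state_dict["model"])  # weights live under "model", not at top level
    return model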
-    validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):]
     print("loading val set")
+    train_set_size = int(len(loaded_arrays_permutation) * train_split)
+    val_set_size = len(loaded_arrays_permutation) - train_set_size
     print("sliced. executing.")

-    for batch_start in tqdm(range(0, len(validation_set), batch_size)):
-        batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + 1024] ])
+    for batch_start in tqdm(range(train_set_size, train_set_size + val_set_size, batch_size)):
+        batch = numpy.stack([ numpy.frombuffer(embedding, dtype=numpy.float16) for embedding in loaded_arrays[loaded_arrays_permutation[batch_start:batch_start + batch_size]] ])
         batch = torch.Tensor(batch).to("cuda")
         reconstructions = model(batch).float()

     feature_frequencies = model.reset_counters()
     features = model.down_proj.weight.cpu().numpy()

-meme_search_backend = "http://localhost:1707/"
-memes_url = "https://i.osmarks.net/memes-or-something/"
-meme_search_url = "https://mse.osmarks.net/?e="
+meme_search_backend = "http://localhost:5601/"
+memes_url = ""
+meme_search_url = "https://nooscope.osmarks.net/?page=advanced&e="

 def emb_url(embedding):
     return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8")
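emb_url links every feature to a live search by urlsafe-base64-encoding the float16 bytes of the direction vector into the query string. For illustration, a hypothetical decode_emb inverse (not part of this diff):

import base64
import numpy as np

def decode_emb(url_param: str) -> np.ndarray:
    # invert emb_url's encoding: urlsafe base64 -> raw bytes -> float16 vector
    return np.frombuffer(base64.urlsafe_b64decode(url_param), dtype=np.float16)

# round trip: decode_emb(emb_url(e)[len(meme_search_url):]) recovers e as float16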
@@ -71,7 +75,7 @@ async def get_exemplars():
             "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
             "k": 10
         }) as res:
-            return (await res.json())["matches"]
+            return (await res.json())["matches"][:10]

     exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ])
     negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ])
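From the context lines, each lookup POSTs a query embedding to the backend and the new [:10] truncates client-side in case the backend returns more than k matches. A self-contained sketch of that request, assuming the backend accepts a JSON POST at its root URL as the hunk suggests:

import aiohttp

async def lookup(embedding, session: aiohttp.ClientSession):
    # hypothetical standalone version of the script's inner helper
    async with session.post(meme_search_backend, json={
        "terms": [{ "embedding": [float(x) for x in embedding] }],
        "k": 10,
    }) as res:
        return (await res.json())["matches"][:10]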
@@ -79,7 +83,7 @@ async def get_exemplars():
     for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]):
         f.write(f"""
 <h2>Feature {offset + base}</h2>
-<h3>Frequency {frequency / len(validation_set)}</h3>
+<h3>Frequency {frequency / val_set_size}</h3>
 <div>
     <h4><a href="{emb_url(feature)}">Max</a></h4>
 """)
@@ -74,6 +74,7 @@ class JSONEncoder(json.JSONEncoder):
             return str(o)
         else: return super().default(o)

+"""
 logfile = f"logs/log-{time.time()}.jsonl"
 with open(logfile, "w") as log:
     steps = 0
@@ -90,7 +91,7 @@ with open(logfile, "w") as log:
             loss = loss.detach().cpu().item()
             t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}")
             log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
-            if steps % 5000 == 0: save_ckpt(log, steps)
+            #if steps % 5000 == 0: save_ckpt(log, steps)
             steps += 1

     save_ckpt(log, steps)
@@ -100,3 +101,4 @@ with open(logfile, "w") as log:
         numpy.save(f"ckpt/{steps}.counters.npy", ctr)

     print(logfile)
+"""
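This change wraps the trainer's logging and training loop in a module-level string literal, disabling it so the file's definitions can be imported without retraining. For reference, a minimal sketch of the JSONL log format that loop writes (the loss values here are placeholders):

import json, time

logfile = f"logs/log-{time.time()}.jsonl"
with open(logfile, "w") as log:
    for steps, loss in enumerate([0.5, 0.4, 0.3]):  # placeholder loss stream
        # one JSON object per line, matching the trainer's log.write call
        log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")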