mirror of https://github.com/osmarks/meme-search-engine.git
fixed SAE export code
commit ec27deddbf (parent 899fbb7092)
@@ -13,41 +13,45 @@ import aiohttp
 import aioitertools
 
 from model import SAEConfig, SAE
-from shared import train_split, loaded_arrays, ckpt_path
+from shared import train_split, loaded_arrays, ckpt_path, loaded_arrays_permutation
 
-model_path, _ = ckpt_path(12991)
+model_path, _ = ckpt_path(111885)
 
-model = SAE(SAEConfig(
-    d_emb=1152,
-    d_hidden=65536,
-    top_k=128,
-    device="cuda",
-    dtype=torch.float32,
-    up_proj_bias=False
-))
+@dataclass
+class TrainConfig:
+    model: SAEConfig
+    lr: float
+    weight_decay: float
+    batch_size: int
+    epochs: int
+    compile: bool
 
 state_dict = torch.load(model_path)
+model = SAE(state_dict["config"].model)
 
 batch_size = 1024
 retrieve_batch_size = 512
 
 with torch.inference_mode():
-    model.load_state_dict(state_dict)
+    model.load_state_dict(state_dict["model"])
     model.eval()
 
-    validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):]
+    print("loading val set")
+    train_set_size = int(len(loaded_arrays_permutation) * train_split)
+    val_set_size = len(loaded_arrays_permutation) - train_set_size
+    print("sliced. executing.")
 
-    for batch_start in tqdm(range(0, len(validation_set), batch_size)):
-        batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + 1024] ])
+    for batch_start in tqdm(range(train_set_size, train_set_size+val_set_size, batch_size)):
+        batch = numpy.stack([ numpy.frombuffer(embedding, dtype=numpy.float16) for embedding in loaded_arrays[loaded_arrays_permutation[batch_start:batch_start + batch_size]] ])
         batch = torch.Tensor(batch).to("cuda")
         reconstructions = model(batch).float()
 
 feature_frequencies = model.reset_counters()
 features = model.down_proj.weight.cpu().numpy()
 
-meme_search_backend = "http://localhost:1707/"
-memes_url = "https://i.osmarks.net/memes-or-something/"
-meme_search_url = "https://mse.osmarks.net/?e="
+meme_search_backend = "http://localhost:5601/"
+memes_url = ""
+meme_search_url = "https://nooscope.osmarks.net/?page=advanced&e="
 
 def emb_url(embedding):
     return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8")
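What changed here: checkpoints now bundle the training configuration (a `TrainConfig` carrying the nested `SAEConfig`) together with the weights, so the export script rebuilds the SAE from the checkpoint instead of repeating `d_emb`/`d_hidden`/`top_k` by hand, and the validation slice is addressed through `loaded_arrays_permutation` rather than a contiguous tail of the embedding array. A minimal sketch of the assumed round trip follows; the "config" and "model" keys are visible in the diff, but the save-side helper is hypothetical:

# Sketch of the assumed checkpoint layout. Only the "config" and "model"
# keys are confirmed by the diff; save_ckpt_sketch is illustrative.
import torch
from dataclasses import dataclass
from model import SAEConfig, SAE

@dataclass
class TrainConfig:
    model: SAEConfig
    lr: float
    weight_decay: float
    batch_size: int
    epochs: int
    compile: bool

def save_ckpt_sketch(path, sae, config: TrainConfig):
    # trainer side: persist hyperparameters next to the weights...
    torch.save({"config": config, "model": sae.state_dict()}, path)

def load_for_export(path):
    # ...export side: reconstruct the SAE purely from the checkpoint
    ckpt = torch.load(path)
    sae = SAE(ckpt["config"].model)
    sae.load_state_dict(ckpt["model"])
    sae.eval()
    return sae

Loading a new-format checkpoint as if it were a raw state dict is presumably what "fixed SAE export code" refers to.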
@@ -71,7 +75,7 @@ async def get_exemplars():
                 "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry
                 "k": 10
             }) as res:
-                return (await res.json())["matches"]
+                return (await res.json())["matches"][:10]
 
     exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ])
     negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ])
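The `[:10]` guard suggests the backend can return more than `k` matches. For context, a rough reconstruction of the `lookup` coroutine around this hunk, inferred from the visible context lines; the request-body keys ("terms", "k") and the "matches" response field come from the diff, while the session plumbing is an assumption:

# Hypothetical reconstruction of the exemplar lookup around this hunk.
import aiohttp

MEME_SEARCH_BACKEND = "http://localhost:5601/"  # value from the diff

async def lookup(embedding):
    async with aiohttp.ClientSession() as sess:
        async with sess.post(MEME_SEARCH_BACKEND, json={
            "terms": [{ "embedding": [float(x) for x in embedding] }],
            "k": 10,
        }) as res:
            # slice defensively in case the backend over-returns,
            # which is presumably why [:10] was added
            return (await res.json())["matches"][:10]

Negative exemplars reuse the same helper with the feature direction negated (`lookup(-feature)`), as the gathered calls above show.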
@@ -79,7 +83,7 @@ async def get_exemplars():
     for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]):
         f.write(f"""
 <h2>Feature {offset + base}</h2>
-<h3>Frequency {frequency / len(validation_set)}</h3>
+<h3>Frequency {frequency / val_set_size}</h3>
 <div>
 <h4><a href="{emb_url(feature)}">Max</a></h4>
 """)
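Two things to note in this hunk: the frequency denominator switches to `val_set_size` because `validation_set` no longer exists after the permutation rewrite above, and each feature's "Max" link embeds the feature direction itself via `emb_url`: the `e` query parameter is just the raw float16 bytes of the vector, base64url-encoded. A sketch of the assumed decoding on the frontend side:

# Inverse of emb_url (assumed): recover the float16 vector from the "e"
# query parameter. Base64 padding handling is glossed over in this sketch.
import base64
import numpy as np

def decode_e_param(e: str) -> np.ndarray:
    return np.frombuffer(base64.urlsafe_b64decode(e), dtype=np.float16)

The remaining hunks touch a different file, evidently the training script: their context lines (`class JSONEncoder`, `with open(logfile, "w") as log:`) do not appear in the export code.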
@@ -74,6 +74,7 @@ class JSONEncoder(json.JSONEncoder):
             return str(o)
         else: return super().default(o)
 
+"""
 logfile = f"logs/log-{time.time()}.jsonl"
 with open(logfile, "w") as log:
     steps = 0
@@ -90,7 +91,7 @@ with open(logfile, "w") as log:
             loss = loss.detach().cpu().item()
             t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}")
             log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
-            if steps % 5000 == 0: save_ckpt(log, steps)
+            #if steps % 5000 == 0: save_ckpt(log, steps)
             steps += 1
 
     save_ckpt(log, steps)
@@ -100,3 +101,4 @@
     numpy.save(f"ckpt/{steps}.counters.npy", ctr)
 
     print(logfile)
+"""
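The training-script hunks disable the run wholesale: the whole loop is wrapped in a module-level string literal (`"""` opened after the `JSONEncoder`, closed after `print(logfile)`), which Python evaluates and discards, and the periodic `save_ckpt` call is commented out. A more explicit gate, sketched here purely for comparison and not what the commit does, would be a main guard with a flag:

# Sketch only: a conventional alternative to the string-literal trick.
import argparse

def train():
    ...  # the training loop shown in the hunks above

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", action="store_true", help="actually run training")
    if parser.parse_args().train:
        train()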