mirror of
				https://github.com/osmarks/meme-search-engine.git
				synced 2025-10-30 23:12:58 +00:00 
			
		
		
		
	fixed SAE export code
This commit is contained in:
		| @@ -13,41 +13,45 @@ import aiohttp | ||||
| import aioitertools | ||||
|  | ||||
| from model import SAEConfig, SAE | ||||
| from shared import train_split, loaded_arrays, ckpt_path | ||||
| from shared import train_split, loaded_arrays, ckpt_path, loaded_arrays_permutation | ||||
|  | ||||
| model_path, _ = ckpt_path(12991) | ||||
| model_path, _ = ckpt_path(111885) | ||||
|  | ||||
| model = SAE(SAEConfig( | ||||
|     d_emb=1152, | ||||
|     d_hidden=65536, | ||||
|     top_k=128, | ||||
|     device="cuda", | ||||
|     dtype=torch.float32, | ||||
|     up_proj_bias=False | ||||
| )) | ||||
| @dataclass | ||||
| class TrainConfig: | ||||
|     model: SAEConfig | ||||
|     lr: float | ||||
|     weight_decay: float | ||||
|     batch_size: int | ||||
|     epochs: int | ||||
|     compile: bool | ||||
|  | ||||
| state_dict = torch.load(model_path) | ||||
| model = SAE(state_dict["config"].model) | ||||
|  | ||||
| batch_size = 1024 | ||||
| retrieve_batch_size = 512 | ||||
|  | ||||
| with torch.inference_mode(): | ||||
|     model.load_state_dict(state_dict) | ||||
|     model.load_state_dict(state_dict["model"]) | ||||
|     model.eval() | ||||
|  | ||||
|     validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):] | ||||
|     print("loading val set") | ||||
|     train_set_size = int(len(loaded_arrays_permutation) * train_split) | ||||
|     val_set_size = len(loaded_arrays_permutation) - train_set_size | ||||
|     print("sliced. executing.") | ||||
|  | ||||
|     for batch_start in tqdm(range(0, len(validation_set), batch_size)): | ||||
|         batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + 1024] ]) | ||||
|     for batch_start in tqdm(range(train_set_size, train_set_size+val_set_size, batch_size)): | ||||
|         batch = numpy.stack([ numpy.frombuffer(embedding, dtype=numpy.float16) for embedding in loaded_arrays[loaded_arrays_permutation[batch_start:batch_start + batch_size]] ]) | ||||
|         batch = torch.Tensor(batch).to("cuda") | ||||
|         reconstructions = model(batch).float() | ||||
|  | ||||
|     feature_frequencies = model.reset_counters() | ||||
|     features = model.down_proj.weight.cpu().numpy() | ||||
|  | ||||
| meme_search_backend = "http://localhost:1707/" | ||||
| memes_url = "https://i.osmarks.net/memes-or-something/" | ||||
| meme_search_url = "https://mse.osmarks.net/?e=" | ||||
| meme_search_backend = "http://localhost:5601/" | ||||
| memes_url = "" | ||||
| meme_search_url = "https://nooscope.osmarks.net/?page=advanced&e=" | ||||
|  | ||||
| def emb_url(embedding): | ||||
|     return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8") | ||||
| @@ -71,7 +75,7 @@ async def get_exemplars(): | ||||
|                         "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry | ||||
|                         "k": 10 | ||||
|                     }) as res: | ||||
|                         return (await res.json())["matches"] | ||||
|                         return (await res.json())["matches"][:10] | ||||
|  | ||||
|                 exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ]) | ||||
|                 negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ]) | ||||
| @@ -79,7 +83,7 @@ async def get_exemplars(): | ||||
|                 for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]): | ||||
|                     f.write(f""" | ||||
|     <h2>Feature {offset + base}</h2> | ||||
|     <h3>Frequency {frequency / len(validation_set)}</h3> | ||||
|     <h3>Frequency {frequency / val_set_size}</h3> | ||||
|     <div> | ||||
|     <h4><a href="{emb_url(feature)}">Max</a></h4> | ||||
|     """) | ||||
|   | ||||
| @@ -74,6 +74,7 @@ class JSONEncoder(json.JSONEncoder): | ||||
|             return str(o) | ||||
|         else: return super().default(o) | ||||
|  | ||||
| """ | ||||
| logfile = f"logs/log-{time.time()}.jsonl" | ||||
| with open(logfile, "w") as log: | ||||
|     steps = 0 | ||||
| @@ -90,7 +91,7 @@ with open(logfile, "w") as log: | ||||
|                 loss = loss.detach().cpu().item() | ||||
|                 t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}") | ||||
|                 log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n") | ||||
|                 if steps % 5000 == 0: save_ckpt(log, steps) | ||||
|                 #if steps % 5000 == 0: save_ckpt(log, steps) | ||||
|                 steps += 1 | ||||
|  | ||||
|         save_ckpt(log, steps) | ||||
| @@ -100,3 +101,4 @@ with open(logfile, "w") as log: | ||||
|         numpy.save(f"ckpt/{steps}.counters.npy", ctr) | ||||
|  | ||||
| print(logfile) | ||||
| """ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 osmarks
					osmarks