mirror of
				https://github.com/osmarks/meme-search-engine.git
				synced 2025-10-31 07:13:04 +00:00 
			
		
		
		
	fixed SAE export code
This commit is contained in:
		| @@ -13,41 +13,45 @@ import aiohttp | |||||||
| import aioitertools | import aioitertools | ||||||
|  |  | ||||||
| from model import SAEConfig, SAE | from model import SAEConfig, SAE | ||||||
| from shared import train_split, loaded_arrays, ckpt_path | from shared import train_split, loaded_arrays, ckpt_path, loaded_arrays_permutation | ||||||
|  |  | ||||||
| model_path, _ = ckpt_path(12991) | model_path, _ = ckpt_path(111885) | ||||||
|  |  | ||||||
| model = SAE(SAEConfig( | @dataclass | ||||||
|     d_emb=1152, | class TrainConfig: | ||||||
|     d_hidden=65536, |     model: SAEConfig | ||||||
|     top_k=128, |     lr: float | ||||||
|     device="cuda", |     weight_decay: float | ||||||
|     dtype=torch.float32, |     batch_size: int | ||||||
|     up_proj_bias=False |     epochs: int | ||||||
| )) |     compile: bool | ||||||
|  |  | ||||||
| state_dict = torch.load(model_path) | state_dict = torch.load(model_path) | ||||||
|  | model = SAE(state_dict["config"].model) | ||||||
|  |  | ||||||
| batch_size = 1024 | batch_size = 1024 | ||||||
| retrieve_batch_size = 512 | retrieve_batch_size = 512 | ||||||
|  |  | ||||||
| with torch.inference_mode(): | with torch.inference_mode(): | ||||||
|     model.load_state_dict(state_dict) |     model.load_state_dict(state_dict["model"]) | ||||||
|     model.eval() |     model.eval() | ||||||
|  |  | ||||||
|     validation_set = loaded_arrays["embedding"][int(len(loaded_arrays["embedding"]) * train_split):] |     print("loading val set") | ||||||
|  |     train_set_size = int(len(loaded_arrays_permutation) * train_split) | ||||||
|  |     val_set_size = len(loaded_arrays_permutation) - train_set_size | ||||||
|  |     print("sliced. executing.") | ||||||
|  |  | ||||||
|     for batch_start in tqdm(range(0, len(validation_set), batch_size)): |     for batch_start in tqdm(range(train_set_size, train_set_size+val_set_size, batch_size)): | ||||||
|         batch = numpy.stack([ numpy.frombuffer(embedding.as_py(), dtype=numpy.float16) for embedding in validation_set[batch_start:batch_start + 1024] ]) |         batch = numpy.stack([ numpy.frombuffer(embedding, dtype=numpy.float16) for embedding in loaded_arrays[loaded_arrays_permutation[batch_start:batch_start + batch_size]] ]) | ||||||
|         batch = torch.Tensor(batch).to("cuda") |         batch = torch.Tensor(batch).to("cuda") | ||||||
|         reconstructions = model(batch).float() |         reconstructions = model(batch).float() | ||||||
|  |  | ||||||
|     feature_frequencies = model.reset_counters() |     feature_frequencies = model.reset_counters() | ||||||
|     features = model.down_proj.weight.cpu().numpy() |     features = model.down_proj.weight.cpu().numpy() | ||||||
|  |  | ||||||
| meme_search_backend = "http://localhost:1707/" | meme_search_backend = "http://localhost:5601/" | ||||||
| memes_url = "https://i.osmarks.net/memes-or-something/" | memes_url = "" | ||||||
| meme_search_url = "https://mse.osmarks.net/?e=" | meme_search_url = "https://nooscope.osmarks.net/?page=advanced&e=" | ||||||
|  |  | ||||||
| def emb_url(embedding): | def emb_url(embedding): | ||||||
|     return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8") |     return meme_search_url + base64.urlsafe_b64encode(embedding.astype(np.float16).tobytes()).decode("utf-8") | ||||||
| @@ -71,7 +75,7 @@ async def get_exemplars(): | |||||||
|                         "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry |                         "terms": [{ "embedding": list(float(x) for x in embedding) }], # sorry | ||||||
|                         "k": 10 |                         "k": 10 | ||||||
|                     }) as res: |                     }) as res: | ||||||
|                         return (await res.json())["matches"] |                         return (await res.json())["matches"][:10] | ||||||
|  |  | ||||||
|                 exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ]) |                 exemplars = await aioitertools.asyncio.gather(*[ lookup(feature) for feature in chunk ]) | ||||||
|                 negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ]) |                 negative_exemplars = await aioitertools.asyncio.gather(*[ lookup(-feature) for feature in chunk ]) | ||||||
| @@ -79,7 +83,7 @@ async def get_exemplars(): | |||||||
|                 for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]): |                 for offset, (feature, frequency) in sorted(enumerate(zip(chunk, feature_frequencies[base:])), key=lambda x: -x[1][1]): | ||||||
|                     f.write(f""" |                     f.write(f""" | ||||||
|     <h2>Feature {offset + base}</h2> |     <h2>Feature {offset + base}</h2> | ||||||
|     <h3>Frequency {frequency / len(validation_set)}</h3> |     <h3>Frequency {frequency / val_set_size}</h3> | ||||||
|     <div> |     <div> | ||||||
|     <h4><a href="{emb_url(feature)}">Max</a></h4> |     <h4><a href="{emb_url(feature)}">Max</a></h4> | ||||||
|     """) |     """) | ||||||
|   | |||||||
| @@ -74,6 +74,7 @@ class JSONEncoder(json.JSONEncoder): | |||||||
|             return str(o) |             return str(o) | ||||||
|         else: return super().default(o) |         else: return super().default(o) | ||||||
|  |  | ||||||
|  | """ | ||||||
| logfile = f"logs/log-{time.time()}.jsonl" | logfile = f"logs/log-{time.time()}.jsonl" | ||||||
| with open(logfile, "w") as log: | with open(logfile, "w") as log: | ||||||
|     steps = 0 |     steps = 0 | ||||||
| @@ -90,7 +91,7 @@ with open(logfile, "w") as log: | |||||||
|                 loss = loss.detach().cpu().item() |                 loss = loss.detach().cpu().item() | ||||||
|                 t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}") |                 t.set_description_str(f"loss: {loss:.6f} epoch: {epoch}") | ||||||
|                 log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n") |                 log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n") | ||||||
|                 if steps % 5000 == 0: save_ckpt(log, steps) |                 #if steps % 5000 == 0: save_ckpt(log, steps) | ||||||
|                 steps += 1 |                 steps += 1 | ||||||
|  |  | ||||||
|         save_ckpt(log, steps) |         save_ckpt(log, steps) | ||||||
| @@ -100,3 +101,4 @@ with open(logfile, "w") as log: | |||||||
|         numpy.save(f"ckpt/{steps}.counters.npy", ctr) |         numpy.save(f"ckpt/{steps}.counters.npy", ctr) | ||||||
|  |  | ||||||
| print(logfile) | print(logfile) | ||||||
|  | """ | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 osmarks
					osmarks