mirror of
				https://github.com/osmarks/meme-search-engine.git
				synced 2025-10-31 07:13:04 +00:00 
			
		
		
		
	repurpose meme rater
This commit is contained in:
		| @@ -7,6 +7,7 @@ import numpy | |||||||
| import json | import json | ||||||
| import time | import time | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
|  | import sys | ||||||
|  |  | ||||||
| from model import Config, BradleyTerry | from model import Config, BradleyTerry | ||||||
| import shared | import shared | ||||||
| @@ -20,11 +21,12 @@ config = Config( | |||||||
|     n_hidden=1, |     n_hidden=1, | ||||||
|     n_ensemble=16, |     n_ensemble=16, | ||||||
|     device=device, |     device=device, | ||||||
|     dtype=torch.bfloat16, |     dtype=torch.float32, | ||||||
|     dropout=0.5 |     output_channels=3, | ||||||
|  |     dropout=0.1 | ||||||
| ) | ) | ||||||
| model = BradleyTerry(config) | model = BradleyTerry(config) | ||||||
| modelc, _ = shared.checkpoint_for(2250) | modelc, _ = shared.checkpoint_for(int(sys.argv[1])) | ||||||
| model.load_state_dict(torch.load(modelc)) | model.load_state_dict(torch.load(modelc)) | ||||||
| params = sum(p.numel() for p in model.parameters()) | params = sum(p.numel() for p in model.parameters()) | ||||||
| print(f"{params/1e6:.1f}M parameters") | print(f"{params/1e6:.1f}M parameters") | ||||||
| @@ -45,11 +47,13 @@ with torch.inference_mode(): | |||||||
|         embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ]) |         embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ]) | ||||||
|         inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device) |         inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device) | ||||||
|         win_probs = model(inputs) |         win_probs = model(inputs) | ||||||
|  |         #print(win_probs, win_probs.shape) | ||||||
|         #print(win_probs.shape) |         #print(win_probs.shape) | ||||||
|         batchvar = torch.var(win_probs, dim=0) |         batchvar = torch.var(win_probs, dim=0).max(-1).values | ||||||
|  |         #print(batchvar, batchvar.shape) | ||||||
|         for filename, var in zip(filenames, batchvar): |         for filename, var in zip(filenames, batchvar): | ||||||
|             variance[filename] = float(var) |             variance[filename] = float(var) | ||||||
|  |  | ||||||
| top = sorted(variance.items(), key=lambda x: -x[1]) | top = sorted(variance.items(), key=lambda x: -x[1]) | ||||||
| with open("top.json", "w") as f: | with open("top.json", "w") as f: | ||||||
|     json.dump(top[:256], f) |     json.dump(top[:100], f) | ||||||
|   | |||||||
							
								
								
									
										64
									
								
								meme-rater/active_learning_find_top.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								meme-rater/active_learning_find_top.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,64 @@ | |||||||
|  | import torch.nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  | import torch | ||||||
|  | import sqlite3 | ||||||
|  | import random | ||||||
|  | import numpy | ||||||
|  | import json | ||||||
|  | import time | ||||||
|  | from tqdm import tqdm | ||||||
|  | import sys | ||||||
|  | from collections import defaultdict | ||||||
|  |  | ||||||
|  | from model import Config, BradleyTerry | ||||||
|  | import shared | ||||||
|  |  | ||||||
|  | batch_size = 128 | ||||||
|  | num_pairs = batch_size * 1024 | ||||||
|  | device = "cuda" | ||||||
|  |  | ||||||
|  | config = Config( | ||||||
|  |     d_emb=1152, | ||||||
|  |     n_hidden=1, | ||||||
|  |     n_ensemble=16, | ||||||
|  |     device=device, | ||||||
|  |     dtype=torch.float32, | ||||||
|  |     output_channels=3, | ||||||
|  |     dropout=0.1 | ||||||
|  | ) | ||||||
|  | model = BradleyTerry(config) | ||||||
|  | modelc, _ = shared.checkpoint_for(int(sys.argv[1])) | ||||||
|  | model.load_state_dict(torch.load(modelc)) | ||||||
|  | params = sum(p.numel() for p in model.parameters()) | ||||||
|  | print(f"{params/1e6:.1f}M parameters") | ||||||
|  | print(model) | ||||||
|  |  | ||||||
|  | files = shared.fetch_all_files() | ||||||
|  | results = {} | ||||||
|  |  | ||||||
|  | model.eval() | ||||||
|  | with torch.inference_mode(): | ||||||
|  |     for bstart in tqdm(range(0, len(files), batch_size)): | ||||||
|  |         batch = files[bstart:bstart + batch_size] | ||||||
|  |         filenames = [ f1 for f1, e1 in batch ] | ||||||
|  |         embs = torch.stack([ torch.Tensor(e1).to(config.dtype) for f1, e1 in batch ]) | ||||||
|  |         inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device) | ||||||
|  |         scores = model.ensemble(inputs).median(dim=0).values.cpu().numpy() | ||||||
|  |         #print(batchvar, batchvar.shape) | ||||||
|  |         for filename, score in zip(filenames, scores): | ||||||
|  |             results[filename] = score | ||||||
|  |  | ||||||
|  | channel = int(sys.argv[2]) | ||||||
|  | percentile = float(sys.argv[3]) | ||||||
|  | output_pairs = int(sys.argv[4]) | ||||||
|  | mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()])) | ||||||
|  | top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True) | ||||||
|  | select_from = top[:int(len(top) * percentile)] | ||||||
|  |  | ||||||
|  | out = [] | ||||||
|  | for _ in range(output_pairs): | ||||||
|  |     # dummy score for compatibility with existing code | ||||||
|  |     out.append(((random.choice(select_from)[0], random.choice(select_from)[0]), 0)) | ||||||
|  |  | ||||||
|  | with open("top.json", "w") as f: | ||||||
|  |     json.dump(out, f) | ||||||
| @@ -8,11 +8,11 @@ import json | |||||||
| import time | import time | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| from torch.func import functional_call, vmap, grad | from torch.func import functional_call, vmap, grad | ||||||
|  | import sys | ||||||
| 
 | 
 | ||||||
| from model import Config, BradleyTerry | from model import Config, BradleyTerry | ||||||
| import shared | import shared | ||||||
| 
 | 
 | ||||||
| steps = 855 |  | ||||||
| batch_size = 128 | batch_size = 128 | ||||||
| num_pairs = batch_size * 1024 | num_pairs = batch_size * 1024 | ||||||
| device = "cuda" | device = "cuda" | ||||||
| @@ -22,10 +22,12 @@ config = Config( | |||||||
|     n_hidden=1, |     n_hidden=1, | ||||||
|     n_ensemble=1, |     n_ensemble=1, | ||||||
|     device=device, |     device=device, | ||||||
|     dtype=torch.bfloat16 |     dtype=torch.float32, | ||||||
|  |     output_channels=3, | ||||||
|  |     dropout=0.1 | ||||||
| ) | ) | ||||||
| model = BradleyTerry(config) | model = BradleyTerry(config) | ||||||
| modelc, _ = shared.checkpoint_for(855) | modelc, _ = shared.checkpoint_for(int(sys.argv[1])) | ||||||
| model.load_state_dict(torch.load(modelc)) | model.load_state_dict(torch.load(modelc)) | ||||||
| params = sum(p.numel() for p in model.parameters()) | params = sum(p.numel() for p in model.parameters()) | ||||||
| print(f"{params/1e6:.1f}M parameters") | print(f"{params/1e6:.1f}M parameters") | ||||||
| @@ -61,7 +63,7 @@ for bstart in tqdm(range(0, len(pairs), batch_size)): | |||||||
|     #win_probs = model(inputs) |     #win_probs = model(inputs) | ||||||
|     # TODO gradients |     # TODO gradients | ||||||
|     # don't take variance: do backwards pass and compute gradient norm |     # don't take variance: do backwards pass and compute gradient norm | ||||||
|     grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size), 0.95).to(device)) |     grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size, config.output_channels), 0.95).to(device)) | ||||||
|     total_grad_norms = torch.zeros(batch_size).to(device) |     total_grad_norms = torch.zeros(batch_size).to(device) | ||||||
|     for k, v in grads.items(): |     for k, v in grads.items(): | ||||||
|         param_dims = tuple(range(1, len(v.shape))) |         param_dims = tuple(range(1, len(v.shape))) | ||||||
| @@ -73,4 +75,4 @@ for bstart in tqdm(range(0, len(pairs), batch_size)): | |||||||
| 
 | 
 | ||||||
| top = sorted(importance.items(), key=lambda x: -x[1]) | top = sorted(importance.items(), key=lambda x: -x[1]) | ||||||
| with open("top.json", "w") as f: | with open("top.json", "w") as f: | ||||||
|     json.dump(top[:256], f) |     json.dump(top[:256], f) | ||||||
							
								
								
									
										20
									
								
								meme-rater/load_from_json.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								meme-rater/load_from_json.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | |||||||
|  | import jsonlines | ||||||
|  | import sqlite3 | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  | import shared | ||||||
|  |  | ||||||
|  | shared.db.executescript(""" | ||||||
|  | CREATE TABLE IF NOT EXISTS files ( | ||||||
|  |     filename TEXT NOT NULL, | ||||||
|  |     title TEXT NOT NULL, | ||||||
|  |     link TEXT NOT NULL, | ||||||
|  |     embedding BLOB NOT NULL, | ||||||
|  |     UNIQUE (filename) | ||||||
|  | ); | ||||||
|  | """) | ||||||
|  |  | ||||||
|  | with jsonlines.open("sample.jsonl") as reader: | ||||||
|  |     for obj in reader: | ||||||
|  |         shared.db.execute("INSERT INTO files (filename, title, link, embedding) VALUES (?, ?, ?, ?)", (obj["metadata"]["final_url"], obj["title"], f"https://reddit.com/r/{obj['subreddit']}/comments/{obj['id']}", sqlite3.Binary(np.array(obj["embedding"], dtype=np.float16).tobytes()))) | ||||||
|  | shared.db.commit() | ||||||
| @@ -13,13 +13,14 @@ class Config: | |||||||
|     device: str |     device: str | ||||||
|     dtype: torch.dtype |     dtype: torch.dtype | ||||||
|     dropout: float |     dropout: float | ||||||
|  |     output_channels: int | ||||||
|  |  | ||||||
| class Model(nn.Module): | class Model(nn.Module): | ||||||
|     def __init__(self, config): |     def __init__(self, config): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ]) |         self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ]) | ||||||
|         self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ]) |         self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ]) | ||||||
|         self.output = nn.Linear(config.d_emb, 1, dtype=config.dtype, device=config.device) |         self.output = nn.Linear(config.d_emb, config.output_channels, dtype=config.dtype, device=config.device) | ||||||
|  |  | ||||||
|     def forward(self, embs): |     def forward(self, embs): | ||||||
|         x = embs |         x = embs | ||||||
| @@ -34,8 +35,7 @@ class Ensemble(nn.Module): | |||||||
|  |  | ||||||
|     # model batch |     # model batch | ||||||
|     def forward(self, embs): |     def forward(self, embs): | ||||||
|         xs = torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1 |         return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1 | ||||||
|         return xs.squeeze(-1) |  | ||||||
|  |  | ||||||
| class BradleyTerry(nn.Module): | class BradleyTerry(nn.Module): | ||||||
|     def __init__(self, config): |     def __init__(self, config): | ||||||
| @@ -49,4 +49,4 @@ class BradleyTerry(nn.Module): | |||||||
|         #print(scores1, scores2) |         #print(scores1, scores2) | ||||||
|         probs = torch.sigmoid(scores1 - scores2) # model batch |         probs = torch.sigmoid(scores1 - scores2) # model batch | ||||||
|         #print(probs) |         #print(probs) | ||||||
|         return probs |         return probs | ||||||
|   | |||||||
| @@ -30,9 +30,10 @@ async def index(request): | |||||||
|     return web.Response(text=f""" |     return web.Response(text=f""" | ||||||
| <!DOCTYPE html> | <!DOCTYPE html> | ||||||
| <html> | <html> | ||||||
|  |     <title>Data Labelling Frontend (Not Evil)</title> | ||||||
|     <style> |     <style> | ||||||
| .memes img {{ | .memes img {{ | ||||||
|   width: 45%;  |   width: 45%; | ||||||
| }} | }} | ||||||
|  |  | ||||||
| @media (max-width: 768px) {{ | @media (max-width: 768px) {{ | ||||||
| @@ -46,38 +47,71 @@ async def index(request): | |||||||
| }} | }} | ||||||
|     </style> |     </style> | ||||||
|     <body> |     <body> | ||||||
|         <h1>Meme Rating</h1> |         <h1>Data Labelling Frontend (Not Evil)</h1> | ||||||
|         <form action="/rate" method="POST"> |         <form action="/rate" method="POST"> | ||||||
|             <input type="radio" name="rating" value="1" id="rating1"> <label for="rating1">Meme 1 is better</label> |             <table> | ||||||
|             <input type="radio" name="rating" value="2" id="rating2"> <label for="rating2">Meme 2 is better</label> |             <tr> | ||||||
|  |             <td><input type="radio" name="rating-useful" value="1" id="rq1"> <label for="rq1">LHS is better (useful)</label></td> | ||||||
|  |             <td><input type="radio" name="rating-useful" value="eq" id="rqe"> <label for="rqe">Tie</label></td> | ||||||
|  |             <td><input type="radio" name="rating-useful" value="2" id="rq2"> <label for="rq2">RHS is better (useful)</label></td> | ||||||
|  |             </tr> | ||||||
|  |             <tr> | ||||||
|  |             <td><input type="radio" name="rating-meme" value="1" id="rm1"> <label for="rm1">LHS is better (memetically)</label></td> | ||||||
|  |             <td><input type="radio" name="rating-meme" value="eq" id="rme"> <label for="rme">Tie</label></td> | ||||||
|  |             <td><input type="radio" name="rating-meme" value="2" id="rm2"> <label for="rm2">RHS is better (memetically)</label></td> | ||||||
|  |             </tr> | ||||||
|  |             <tr> | ||||||
|  |             <td><input type="radio" name="rating-aesthetic" value="1" id="ra1"> <label for="ra1">LHS is better (aesthetically)</label></td> | ||||||
|  |             <td><input type="radio" name="rating-aesthetic" value="eq" id="rae"> <label for="rae">Tie</label></td> | ||||||
|  |             <td><input type="radio" name="rating-aesthetic" value="2" id="ra2"> <label for="ra2">RHS is better (aesthetically)</label></td> | ||||||
|  |             </td> | ||||||
|  |             </table> | ||||||
|  |  | ||||||
|             <input type="hidden" name="meme1" value="{meme1}"> |             <input type="hidden" name="meme1" value="{meme1}"> | ||||||
|             <input type="hidden" name="meme2" value="{meme2}"> |             <input type="hidden" name="meme2" value="{meme2}"> | ||||||
|             <input type="hidden" name="iteration" value="{str(iteration or 0)}"> |             <input type="hidden" name="iteration" value="{str(iteration or 0)}"> | ||||||
|             <input type="submit" value="Submit"> |             <input type="submit" value="Submit"> | ||||||
|             <div class="memes"> |             <div class="memes"> | ||||||
|                 <img src="/memes/{meme1}" id="meme1"> |                 <img src="{meme1}" id="meme1"> | ||||||
|                 <img src="/memes/{meme2}" id="meme2"> |                 <img src="{meme2}" id="meme2"> | ||||||
|             </div> |             </div> | ||||||
|         </form> |         </form> | ||||||
|         <script> |         <script> | ||||||
|             document.addEventListener("keypress", function(event) {{ |             const commitIfReady = () => {{ | ||||||
|                 if (event.key === "1") {{ |                 if (document.querySelector("input[name='rating-useful']:checked") && document.querySelector("input[name='rating-meme']:checked") && document.querySelector("input[name='rating-aesthetic']:checked")) {{ | ||||||
|                     document.querySelector("input[name='rating'][value='1']").checked = true |  | ||||||
|                     document.querySelector("form").submit() |  | ||||||
|                 }} else if (event.key === "2") {{ |  | ||||||
|                     document.querySelector("input[name='rating'][value='2']").checked = true |  | ||||||
|                     document.querySelector("form").submit() |                     document.querySelector("form").submit() | ||||||
|                 }} |                 }} | ||||||
|  |             }} | ||||||
|  |             document.addEventListener("keypress", function(event) {{ | ||||||
|  |                 if (event.key === "q") {{ | ||||||
|  |                     document.querySelector("input[name='rating-useful'][value='1']").checked = true | ||||||
|  |                     commitIfReady() | ||||||
|  |                 }} else if (event.key === "w") {{ | ||||||
|  |                     document.querySelector("input[name='rating-useful'][value='eq']").checked = true | ||||||
|  |                     commitIfReady() | ||||||
|  |                 }} else if (event.key === "e") {{ | ||||||
|  |                     document.querySelector("input[name='rating-useful'][value='2']").checked = true | ||||||
|  |                     commitIfReady() | ||||||
|  |                 }} else if (event.key === "a") {{ | ||||||
|  |                     document.querySelector("input[name='rating-meme'][value='1']").checked = true | ||||||
|  |                     commitIfReady() | ||||||
|  |                 }} else if (event.key === "s") {{ | ||||||
|  |                     document.querySelector("input[name='rating-meme'][value='eq']").checked = true | ||||||
|  |                     commitIfReady() | ||||||
|  |                 }} else if (event.key === "d") {{ | ||||||
|  |                     document.querySelector("input[name='rating-meme'][value='2']").checked = true | ||||||
|  |                     commitIfReady() | ||||||
|  |                 }} else if (event.key === "z") {{ | ||||||
|  |                     document.querySelector("input[name='rating-aesthetic'][value='1']").checked = true | ||||||
|  |                     commitIfReady() | ||||||
|  |                 }} else if (event.key === "x") {{ | ||||||
|  |                     document.querySelector("input[name='rating-aesthetic'][value='eq']").checked = true | ||||||
|  |                     commitIfReady() | ||||||
|  |                 }} else if (event.key === "c") {{ | ||||||
|  |                     document.querySelector("input[name='rating-aesthetic'][value='2']").checked = true | ||||||
|  |                     commitIfReady() | ||||||
|  |                 }} | ||||||
|             }}); |             }}); | ||||||
|             document.querySelector("#meme1").addEventListener("click", function(event) {{ |  | ||||||
|                 document.querySelector("input[name='rating'][value='1']").checked = true |  | ||||||
|                 document.querySelector("form").submit() |  | ||||||
|             }}) |  | ||||||
|             document.querySelector("#meme2").addEventListener("click", function(event) {{ |  | ||||||
|                 document.querySelector("input[name='rating'][value='2']").checked = true |  | ||||||
|                 document.querySelector("form").submit() |  | ||||||
|             }}) |  | ||||||
|         </script> |         </script> | ||||||
|     </body> |     </body> | ||||||
| </html> | </html> | ||||||
| @@ -90,8 +124,8 @@ async def rate(request): | |||||||
|     meme1 = post["meme1"] |     meme1 = post["meme1"] | ||||||
|     meme2 = post["meme2"] |     meme2 = post["meme2"] | ||||||
|     iteration = post["iteration"] |     iteration = post["iteration"] | ||||||
|     rating = post["rating"] |     rating = post["rating-useful"] + "," + post["rating-meme"] + "," + post["rating-aesthetic"] | ||||||
|     await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration) VALUES (?, ?, ?, ?)", (meme1, meme2, rating, iteration)) |     await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration, ip) VALUES (?, ?, ?, ?, ?)", (meme1, meme2, rating, iteration, request.remote)) | ||||||
|     await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2)) |     await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2)) | ||||||
|     await db.commit() |     await db.commit() | ||||||
|     return web.HTTPFound("/") |     return web.HTTPFound("/") | ||||||
| @@ -104,6 +138,7 @@ CREATE TABLE IF NOT EXISTS ratings ( | |||||||
|     meme2 TEXT NOT NULL, |     meme2 TEXT NOT NULL, | ||||||
|     rating TEXT NOT NULL, |     rating TEXT NOT NULL, | ||||||
|     iteration TEXT, |     iteration TEXT, | ||||||
|  |     ip TEXT, | ||||||
|     UNIQUE (meme1, meme2) |     UNIQUE (meme1, meme2) | ||||||
| ); | ); | ||||||
| CREATE TABLE IF NOT EXISTS queue ( | CREATE TABLE IF NOT EXISTS queue ( | ||||||
|   | |||||||
| @@ -2,10 +2,11 @@ | |||||||
|  |  | ||||||
| import json | import json | ||||||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||||||
|  | import sys | ||||||
|  |  | ||||||
| # Read data from log.jsonl | # Read data from log.jsonl | ||||||
| data = [] | data = [] | ||||||
| with open('log.jsonl', 'r') as file: | with open(sys.argv[1], 'r') as file: | ||||||
|     for line in file: |     for line in file: | ||||||
|         data.append(json.loads(line)) |         data.append(json.loads(line)) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ import hashlib | |||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| import numpy | import numpy | ||||||
| import random | import random | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
| db = sqlite3.connect("data.sqlite3") | db = sqlite3.connect("data.sqlite3") | ||||||
| db.row_factory = sqlite3.Row | db.row_factory = sqlite3.Row | ||||||
| @@ -20,19 +21,24 @@ def fetch_embedding(filename): | |||||||
|     return x.copy() # PyTorch complains otherwise due to bad |     return x.copy() # PyTorch complains otherwise due to bad | ||||||
|  |  | ||||||
| def map_rating(rating, uncertainty=0.05): | def map_rating(rating, uncertainty=0.05): | ||||||
|     match rating: |     def map_one(rating): | ||||||
|         case "1": # meme 1 is better |         match rating: | ||||||
|             return 1 - uncertainty |             case "1": # meme 1 is better | ||||||
|         case "2": |                 return 1 - uncertainty | ||||||
|             return uncertainty |             case "2": | ||||||
|         case _: raise ValueError("invalid rating, please fix") |                 return uncertainty | ||||||
|  |             case "eq": | ||||||
|  |                 return 0.5 | ||||||
|  |             case _: raise ValueError("invalid rating, please fix") | ||||||
|  |  | ||||||
|  |     return np.array([map_one(r) for r in rating.split(",")]) | ||||||
|  |  | ||||||
| def fetch_ratings(): | def fetch_ratings(): | ||||||
|     trains = defaultdict(list) |     trains = defaultdict(list) | ||||||
|     validations = defaultdict(list) |     validations = defaultdict(list) | ||||||
|     csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings") |     csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings") | ||||||
|     for meme1, meme2, rating, iteration in csr.fetchall(): |     for meme1, meme2, rating, iteration in csr.fetchall(): | ||||||
|         (validations if is_val_set(meme1, meme2) else trains)[int(iteration or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating))) |         (validations if is_val_set(meme1, meme2) else trains)[int((iteration and iteration.split("-")[0]) or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating))) | ||||||
|     csr.close() |     csr.close() | ||||||
|     return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items())) |     return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items())) | ||||||
|  |  | ||||||
| @@ -50,4 +56,4 @@ def fetch_all_files(): | |||||||
|     return x |     return x | ||||||
|  |  | ||||||
| def checkpoint_for(steps): | def checkpoint_for(steps): | ||||||
|     return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt" |     return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt" | ||||||
|   | |||||||
| @@ -36,7 +36,8 @@ config = TrainConfig( | |||||||
|         n_ensemble=16, |         n_ensemble=16, | ||||||
|         device=device, |         device=device, | ||||||
|         dtype=torch.float32, |         dtype=torch.float32, | ||||||
|         dropout=0.1 |         dropout=0.1, | ||||||
|  |         output_channels=3 | ||||||
|     ), |     ), | ||||||
|     lr=3e-4, |     lr=3e-4, | ||||||
|     weight_decay=0.2, |     weight_decay=0.2, | ||||||
| @@ -72,12 +73,12 @@ if config.compile: | |||||||
|     print("compiling...") |     print("compiling...") | ||||||
|     train_step = torch.compile(train_step) |     train_step = torch.compile(train_step) | ||||||
|  |  | ||||||
| def batch_from_inputs(inputs: list[tuple[numpy.ndarray, numpy.ndarray, float]]): | def batch_from_inputs(inputs: list[list[tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]]]): | ||||||
|     batch_input = torch.stack([ |     batch_input = torch.stack([ | ||||||
|         torch.stack([ torch.stack((torch.Tensor(emb1).to(config.model.dtype), torch.Tensor(emb2).to(config.model.dtype))) for emb1, emb2, rating in input ]) |         torch.stack([ torch.stack((torch.Tensor(emb1).to(config.model.dtype), torch.Tensor(emb2).to(config.model.dtype))) for emb1, emb2, rating in input ]) | ||||||
|         for input in inputs |         for input in inputs | ||||||
|     ]).to(device) |     ]).to(device) | ||||||
|     target = torch.stack([ torch.Tensor([ rating for emb1, emb2, rating in input ]) for input in inputs ]).to(device) |     target = torch.stack([ torch.Tensor(numpy.array([ rating for emb1, emb2, rating in input ])).to(config.model.dtype) for input in inputs ]).to(device) | ||||||
|     return batch_input, target |     return batch_input, target | ||||||
|  |  | ||||||
| def evaluate(steps): | def evaluate(steps): | ||||||
| @@ -118,7 +119,7 @@ with open(logfile, "w") as log: | |||||||
|                 print(steps, loss) |                 print(steps, loss) | ||||||
|                 log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n") |                 log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n") | ||||||
|                 if steps % 10 == 0: |                 if steps % 10 == 0: | ||||||
|                     if steps % 250 == 0: save_ckpt(log, steps) |                     if steps % 100 == 0: save_ckpt(log, steps) | ||||||
|                     loss = evaluate(steps) |                     loss = evaluate(steps) | ||||||
|                     #print(loss) |                     #print(loss) | ||||||
|                     #best = min(loss, best) |                     #best = min(loss, best) | ||||||
| @@ -126,4 +127,4 @@ with open(logfile, "w") as log: | |||||||
|  |  | ||||||
|         save_ckpt(log, steps) |         save_ckpt(log, steps) | ||||||
|  |  | ||||||
| print(logfile) | print(logfile) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 osmarks
					osmarks