mirror of
				https://github.com/osmarks/meme-search-engine.git
				synced 2025-10-31 15:23:04 +00:00 
			
		
		
		
	better evals
This commit is contained in:
		
							
								
								
									
										69
									
								
								meme-rater/auroc_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								meme-rater/auroc_test.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | ||||
| import torch.nn | ||||
| import torch.nn.functional as F | ||||
| import torch | ||||
| import sqlite3 | ||||
| import random | ||||
| import numpy | ||||
| import json | ||||
| import time | ||||
| from tqdm import tqdm | ||||
| import torch | ||||
|  | ||||
| from model import Config, BradleyTerry | ||||
| import shared | ||||
|  | ||||
| batch_size = 128 | ||||
| device = "cuda" | ||||
|  | ||||
| config = Config( | ||||
|     d_emb=1152, | ||||
|     n_hidden=1, | ||||
|     n_ensemble=16, | ||||
|     device=device, | ||||
|     dtype=torch.float32, | ||||
|     dropout=0.1 | ||||
| ) | ||||
| model = BradleyTerry(config).float() | ||||
| modelc, _ = shared.checkpoint_for(1500) | ||||
| model.load_state_dict(torch.load(modelc)) | ||||
| params = sum(p.numel() for p in model.parameters()) | ||||
| print(f"{params/1e6:.1f}M parameters") | ||||
| print(model) | ||||
|  | ||||
| files = shared.fetch_all_files() | ||||
| ratings = {} | ||||
|  | ||||
| model.eval() | ||||
| with torch.inference_mode(): | ||||
|     for bstart in tqdm(range(0, len(files), batch_size)): | ||||
|         batch = files[bstart:bstart + batch_size] | ||||
|         filenames = [ filename for filename, embedding in batch ] | ||||
|         embs = torch.stack([ torch.Tensor(embedding) for filename, embedding in batch ]) | ||||
|         inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device) | ||||
|         scores = model.ensemble(inputs).float() | ||||
|         mscores = torch.median(scores, dim=0).values | ||||
|         for filename, mscore in zip(filenames, mscores): | ||||
|             ratings[filename] = float(mscore) | ||||
|  | ||||
| ratings = sorted(ratings.items(), key=lambda x: x[1]) | ||||
| random.shuffle(ratings) | ||||
|  | ||||
| N = 150 | ||||
|  | ||||
| buf = f"""<!DOCTYPE html> | ||||
| <div> | ||||
| {''.join(f'<div><img src="{"images/" + f}" width="30%"><br><input type=checkbox data-score="{s}"></div>' for i, (f, s) in enumerate(ratings[:N]))} | ||||
| </div> | ||||
| <script> | ||||
|     const dump = () => {{ | ||||
|         const data = [] | ||||
|         for (const x of document.querySelectorAll("input[type=checkbox]")) {{ | ||||
|             data.push([parseFloat(x.getAttribute("data-score")), x.checked]) | ||||
|         }} | ||||
|         console.log(JSON.stringify(data)) | ||||
|     }} | ||||
| </script> | ||||
| """ | ||||
|  | ||||
| with open("eval.html", "w") as f: | ||||
|     f.write(buf) | ||||
							
								
								
									
										32
									
								
								meme-rater/roc_plot.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								meme-rater/roc_plot.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | ||||
| import matplotlib.pyplot as plt | ||||
| import json | ||||
|  | ||||
| data = json.loads("[[1.2792096138000488,true],[1.1153279542922974,true],[0.9720794558525085,true],[-0.5180545449256897,false],[1.4547114372253418,true],[1.3289614915847778,true],[1.8748269081115723,true],[0.05465051531791687,false],[0.7888763546943665,true],[1.368210792541504,true],[1.4808461666107178,true],[0.9501181244850159,true],[1.2592355012893677,true],[1.0127032995224,true],[-0.8805797100067139,false],[-0.08946493268013,true],[0.4224545955657959,false],[1.0051900148391724,true],[0.5121232271194458,false],[1.0876282453536987,false],[1.5552432537078857,true],[-0.3680466413497925,false],[0.45498305559158325,true],[1.3851803541183472,true],[-0.8842921853065491,false],[2.6869430541992188,false],[1.6892706155776978,false],[0.7087478637695312,false],[-0.5138207077980042,false],[0.16498255729675293,false],[1.265992283821106,true],[0.47311416268348694,false],[0.04918492212891579,false],[1.283980369567871,true],[1.0510015487670898,false],[1.6323922872543335,false],[0.4570896625518799,true],[1.5262614488601685,true],[1.4057230949401855,true],[1.0391144752502441,true],[0.9190238118171692,true],[1.2970502376556396,true],[2.025949478149414,true],[0.6396026611328125,true],[2.3505871295928955,true],[1.0854156017303467,false],[1.0216373205184937,true],[-1.163207769393921,false],[1.8854788541793823,true],[0.249663308262825,false],[-0.8619526028633118,false],[1.9995672702789307,true],[1.0939114093780518,false],[0.6106101870536804,false],[1.8383781909942627,false],[-0.0637127161026001,false],[-0.34953051805496216,false],[0.988452672958374,false],[0.5209289193153381,false],[-0.4708566963672638,false],[0.4715256690979004,false],[-0.7905446887016296,false],[2.0255637168884277,true],[0.8488644361495972,false],[1.6645262241363525,true],[1.0948383808135986,true],[-0.8315924406051636,false],[1.5533114671707153,true],[0.9333463907241821,true],[-0.5723654627799988,false],[1.9510998725891113,true],[0.2842162549495697,false],[1.1901239156723022,false],[1.5058742761611938,false],[0.7622374296188354,false],[0.2894713282585144,false],[0.0965774804353714,false],[0.6335093379020691,false],[-0.7369110584259033,false],[1.2673722505569458,true],[0.9775630235671997,false],[0.7889275550842285,false],[-0.9432369470596313,false],[0.24122865498065948,false],[1.075297474861145,false],[0.545269250869751,false],[-0.1398508995771408,false],[-0.31118375062942505,false],[1.47971510887146,false],[0.5115379691123962,true],[0.8894630074501038,true],[0.4365079700946808,true],[2.5944597721099854,true],[0.8613907694816589,false],[1.1540073156356812,false],[1.6798168420791626,true],[1.5266021490097046,true],[0.2556634545326233,false],[0.90388423204422,false],[0.36393579840660095,false],[1.297504186630249,true],[1.091887354850769,true],[0.931088924407959,true],[0.8854649066925049,true],[0.0385725162923336,false],[1.5259686708450317,true],[-0.725635826587677,false],[-1.72086501121521,false],[1.9044498205184937,true],[-0.10369344800710678,false],[-0.5889104604721069,true],[0.2478746473789215,false],[1.4628609418869019,false],[1.1434470415115356,false],[0.20635242760181427,false],[0.8324120044708252,false],[0.676543653011322,false],[1.1111537218093872,true],[0.0488731786608696,false],[0.8705015182495117,true],[0.5464357733726501,true],[0.6190940737724304,true],[0.33756133913993835,false],[0.8019527196884155,true],[1.1540179252624512,true],[-1.4343260526657104,true],[1.4069069623947144,true],[0.5078597664833069,true],[0.1831521838903427,false],[-0.5352457761764526,false],[1.3706591129302979,true],[-0.8636290431022644,false],[0.8164027333259583,false],[0.6665022969245911,false],[0.5028047561645508,false],[-0.7765756845474243,false],[1.204775333404541,false],[1.2527906894683838,false],[0.7420544028282166,false],[1.0363034009933472,true],[1.0559784173965454,false],[-0.72457355260849,false],[1.9217685461044312,true],[0.9770780205726624,false],[0.8808136582374573,true],[1.0174754858016968,false],[0.4287119507789612,false],[1.0718724727630615,true],[0.8409612774848938,true],[-1.3366127014160156,false]]") | ||||
| data = sorted(data, reverse=True) | ||||
|  | ||||
| tprs, fprs = [], [] | ||||
| positives = sum(1 for _, ground_truth in data if ground_truth) | ||||
| negatives = len(data) - positives | ||||
|  | ||||
| for threshold, _ in data: | ||||
|     tp = sum(1 for score, ground_truth in data if ground_truth and score >= threshold) | ||||
|     fp = sum(1 for score, ground_truth in data if not ground_truth and score >= threshold) | ||||
|     tpr = tp / positives | ||||
|     fpr = fp / negatives | ||||
|     tprs.append(tpr) | ||||
|     fprs.append(fpr) | ||||
|  | ||||
| auroc = 0 | ||||
| for i in range(len(fprs) - 1): | ||||
|     auroc += (fprs[i+1] - fprs[i]) * (tprs[i+1] + tprs[i]) / 2 | ||||
|  | ||||
| print(f"AUROC: {auroc}") | ||||
|  | ||||
| plt.plot(fprs, tprs) | ||||
|  | ||||
| plt.xlabel("FPR") | ||||
| plt.ylabel("TPR") | ||||
| plt.title("ROC") | ||||
|  | ||||
| plt.tight_layout() | ||||
| plt.show() | ||||
		Reference in New Issue
	
	Block a user