mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2025-04-30 22:34:06 +00:00
repurpose meme rater
This commit is contained in:
parent
163dceca4b
commit
0a542ef579
@ -7,6 +7,7 @@ import numpy
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import sys
|
||||||
|
|
||||||
from model import Config, BradleyTerry
|
from model import Config, BradleyTerry
|
||||||
import shared
|
import shared
|
||||||
@ -20,11 +21,12 @@ config = Config(
|
|||||||
n_hidden=1,
|
n_hidden=1,
|
||||||
n_ensemble=16,
|
n_ensemble=16,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.bfloat16,
|
dtype=torch.float32,
|
||||||
dropout=0.5
|
output_channels=3,
|
||||||
|
dropout=0.1
|
||||||
)
|
)
|
||||||
model = BradleyTerry(config)
|
model = BradleyTerry(config)
|
||||||
modelc, _ = shared.checkpoint_for(2250)
|
modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
|
||||||
model.load_state_dict(torch.load(modelc))
|
model.load_state_dict(torch.load(modelc))
|
||||||
params = sum(p.numel() for p in model.parameters())
|
params = sum(p.numel() for p in model.parameters())
|
||||||
print(f"{params/1e6:.1f}M parameters")
|
print(f"{params/1e6:.1f}M parameters")
|
||||||
@ -45,11 +47,13 @@ with torch.inference_mode():
|
|||||||
embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
|
embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
|
||||||
inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
|
inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
|
||||||
win_probs = model(inputs)
|
win_probs = model(inputs)
|
||||||
|
#print(win_probs, win_probs.shape)
|
||||||
#print(win_probs.shape)
|
#print(win_probs.shape)
|
||||||
batchvar = torch.var(win_probs, dim=0)
|
batchvar = torch.var(win_probs, dim=0).max(-1).values
|
||||||
|
#print(batchvar, batchvar.shape)
|
||||||
for filename, var in zip(filenames, batchvar):
|
for filename, var in zip(filenames, batchvar):
|
||||||
variance[filename] = float(var)
|
variance[filename] = float(var)
|
||||||
|
|
||||||
top = sorted(variance.items(), key=lambda x: -x[1])
|
top = sorted(variance.items(), key=lambda x: -x[1])
|
||||||
with open("top.json", "w") as f:
|
with open("top.json", "w") as f:
|
||||||
json.dump(top[:256], f)
|
json.dump(top[:100], f)
|
||||||
|
64
meme-rater/active_learning_find_top.py
Normal file
64
meme-rater/active_learning_find_top.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import torch.nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch
|
||||||
|
import sqlite3
|
||||||
|
import random
|
||||||
|
import numpy
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from tqdm import tqdm
|
||||||
|
import sys
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from model import Config, BradleyTerry
|
||||||
|
import shared
|
||||||
|
|
||||||
|
batch_size = 128
|
||||||
|
num_pairs = batch_size * 1024
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
config = Config(
|
||||||
|
d_emb=1152,
|
||||||
|
n_hidden=1,
|
||||||
|
n_ensemble=16,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
output_channels=3,
|
||||||
|
dropout=0.1
|
||||||
|
)
|
||||||
|
model = BradleyTerry(config)
|
||||||
|
modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
|
||||||
|
model.load_state_dict(torch.load(modelc))
|
||||||
|
params = sum(p.numel() for p in model.parameters())
|
||||||
|
print(f"{params/1e6:.1f}M parameters")
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
files = shared.fetch_all_files()
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
with torch.inference_mode():
|
||||||
|
for bstart in tqdm(range(0, len(files), batch_size)):
|
||||||
|
batch = files[bstart:bstart + batch_size]
|
||||||
|
filenames = [ f1 for f1, e1 in batch ]
|
||||||
|
embs = torch.stack([ torch.Tensor(e1).to(config.dtype) for f1, e1 in batch ])
|
||||||
|
inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
|
||||||
|
scores = model.ensemble(inputs).median(dim=0).values.cpu().numpy()
|
||||||
|
#print(batchvar, batchvar.shape)
|
||||||
|
for filename, score in zip(filenames, scores):
|
||||||
|
results[filename] = score
|
||||||
|
|
||||||
|
channel = int(sys.argv[2])
|
||||||
|
percentile = float(sys.argv[3])
|
||||||
|
output_pairs = int(sys.argv[4])
|
||||||
|
mean_scores = numpy.mean(numpy.stack([score for filename, score in results.items()]))
|
||||||
|
top = sorted(((filename, score) for filename, score in results.items() if (score > mean_scores).all()), key=lambda x: x[1][channel], reverse=True)
|
||||||
|
select_from = top[:int(len(top) * percentile)]
|
||||||
|
|
||||||
|
out = []
|
||||||
|
for _ in range(output_pairs):
|
||||||
|
# dummy score for compatibility with existing code
|
||||||
|
out.append(((random.choice(select_from)[0], random.choice(select_from)[0]), 0))
|
||||||
|
|
||||||
|
with open("top.json", "w") as f:
|
||||||
|
json.dump(out, f)
|
@ -8,11 +8,11 @@ import json
|
|||||||
import time
|
import time
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torch.func import functional_call, vmap, grad
|
from torch.func import functional_call, vmap, grad
|
||||||
|
import sys
|
||||||
|
|
||||||
from model import Config, BradleyTerry
|
from model import Config, BradleyTerry
|
||||||
import shared
|
import shared
|
||||||
|
|
||||||
steps = 855
|
|
||||||
batch_size = 128
|
batch_size = 128
|
||||||
num_pairs = batch_size * 1024
|
num_pairs = batch_size * 1024
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
@ -22,10 +22,12 @@ config = Config(
|
|||||||
n_hidden=1,
|
n_hidden=1,
|
||||||
n_ensemble=1,
|
n_ensemble=1,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.bfloat16
|
dtype=torch.float32,
|
||||||
|
output_channels=3,
|
||||||
|
dropout=0.1
|
||||||
)
|
)
|
||||||
model = BradleyTerry(config)
|
model = BradleyTerry(config)
|
||||||
modelc, _ = shared.checkpoint_for(855)
|
modelc, _ = shared.checkpoint_for(int(sys.argv[1]))
|
||||||
model.load_state_dict(torch.load(modelc))
|
model.load_state_dict(torch.load(modelc))
|
||||||
params = sum(p.numel() for p in model.parameters())
|
params = sum(p.numel() for p in model.parameters())
|
||||||
print(f"{params/1e6:.1f}M parameters")
|
print(f"{params/1e6:.1f}M parameters")
|
||||||
@ -61,7 +63,7 @@ for bstart in tqdm(range(0, len(pairs), batch_size)):
|
|||||||
#win_probs = model(inputs)
|
#win_probs = model(inputs)
|
||||||
# TODO gradients
|
# TODO gradients
|
||||||
# don't take variance: do backwards pass and compute gradient norm
|
# don't take variance: do backwards pass and compute gradient norm
|
||||||
grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size), 0.95).to(device))
|
grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size, config.output_channels), 0.95).to(device))
|
||||||
total_grad_norms = torch.zeros(batch_size).to(device)
|
total_grad_norms = torch.zeros(batch_size).to(device)
|
||||||
for k, v in grads.items():
|
for k, v in grads.items():
|
||||||
param_dims = tuple(range(1, len(v.shape)))
|
param_dims = tuple(range(1, len(v.shape)))
|
20
meme-rater/load_from_json.py
Normal file
20
meme-rater/load_from_json.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import jsonlines
|
||||||
|
import sqlite3
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import shared
|
||||||
|
|
||||||
|
shared.db.executescript("""
|
||||||
|
CREATE TABLE IF NOT EXISTS files (
|
||||||
|
filename TEXT NOT NULL,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
link TEXT NOT NULL,
|
||||||
|
embedding BLOB NOT NULL,
|
||||||
|
UNIQUE (filename)
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
with jsonlines.open("sample.jsonl") as reader:
|
||||||
|
for obj in reader:
|
||||||
|
shared.db.execute("INSERT INTO files (filename, title, link, embedding) VALUES (?, ?, ?, ?)", (obj["metadata"]["final_url"], obj["title"], f"https://reddit.com/r/{obj['subreddit']}/comments/{obj['id']}", sqlite3.Binary(np.array(obj["embedding"], dtype=np.float16).tobytes())))
|
||||||
|
shared.db.commit()
|
@ -13,13 +13,14 @@ class Config:
|
|||||||
device: str
|
device: str
|
||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
dropout: float
|
dropout: float
|
||||||
|
output_channels: int
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ])
|
self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ])
|
||||||
self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ])
|
self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ])
|
||||||
self.output = nn.Linear(config.d_emb, 1, dtype=config.dtype, device=config.device)
|
self.output = nn.Linear(config.d_emb, config.output_channels, dtype=config.dtype, device=config.device)
|
||||||
|
|
||||||
def forward(self, embs):
|
def forward(self, embs):
|
||||||
x = embs
|
x = embs
|
||||||
@ -34,8 +35,7 @@ class Ensemble(nn.Module):
|
|||||||
|
|
||||||
# model batch
|
# model batch
|
||||||
def forward(self, embs):
|
def forward(self, embs):
|
||||||
xs = torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
|
return torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
|
||||||
return xs.squeeze(-1)
|
|
||||||
|
|
||||||
class BradleyTerry(nn.Module):
|
class BradleyTerry(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
@ -30,6 +30,7 @@ async def index(request):
|
|||||||
return web.Response(text=f"""
|
return web.Response(text=f"""
|
||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
<html>
|
<html>
|
||||||
|
<title>Data Labelling Frontend (Not Evil)</title>
|
||||||
<style>
|
<style>
|
||||||
.memes img {{
|
.memes img {{
|
||||||
width: 45%;
|
width: 45%;
|
||||||
@ -46,38 +47,71 @@ async def index(request):
|
|||||||
}}
|
}}
|
||||||
</style>
|
</style>
|
||||||
<body>
|
<body>
|
||||||
<h1>Meme Rating</h1>
|
<h1>Data Labelling Frontend (Not Evil)</h1>
|
||||||
<form action="/rate" method="POST">
|
<form action="/rate" method="POST">
|
||||||
<input type="radio" name="rating" value="1" id="rating1"> <label for="rating1">Meme 1 is better</label>
|
<table>
|
||||||
<input type="radio" name="rating" value="2" id="rating2"> <label for="rating2">Meme 2 is better</label>
|
<tr>
|
||||||
|
<td><input type="radio" name="rating-useful" value="1" id="rq1"> <label for="rq1">LHS is better (useful)</label></td>
|
||||||
|
<td><input type="radio" name="rating-useful" value="eq" id="rqe"> <label for="rqe">Tie</label></td>
|
||||||
|
<td><input type="radio" name="rating-useful" value="2" id="rq2"> <label for="rq2">RHS is better (useful)</label></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><input type="radio" name="rating-meme" value="1" id="rm1"> <label for="rm1">LHS is better (memetically)</label></td>
|
||||||
|
<td><input type="radio" name="rating-meme" value="eq" id="rme"> <label for="rme">Tie</label></td>
|
||||||
|
<td><input type="radio" name="rating-meme" value="2" id="rm2"> <label for="rm2">RHS is better (memetically)</label></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><input type="radio" name="rating-aesthetic" value="1" id="ra1"> <label for="ra1">LHS is better (aesthetically)</label></td>
|
||||||
|
<td><input type="radio" name="rating-aesthetic" value="eq" id="rae"> <label for="rae">Tie</label></td>
|
||||||
|
<td><input type="radio" name="rating-aesthetic" value="2" id="ra2"> <label for="ra2">RHS is better (aesthetically)</label></td>
|
||||||
|
</td>
|
||||||
|
</table>
|
||||||
|
|
||||||
<input type="hidden" name="meme1" value="{meme1}">
|
<input type="hidden" name="meme1" value="{meme1}">
|
||||||
<input type="hidden" name="meme2" value="{meme2}">
|
<input type="hidden" name="meme2" value="{meme2}">
|
||||||
<input type="hidden" name="iteration" value="{str(iteration or 0)}">
|
<input type="hidden" name="iteration" value="{str(iteration or 0)}">
|
||||||
<input type="submit" value="Submit">
|
<input type="submit" value="Submit">
|
||||||
<div class="memes">
|
<div class="memes">
|
||||||
<img src="/memes/{meme1}" id="meme1">
|
<img src="{meme1}" id="meme1">
|
||||||
<img src="/memes/{meme2}" id="meme2">
|
<img src="{meme2}" id="meme2">
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</form>
|
||||||
<script>
|
<script>
|
||||||
document.addEventListener("keypress", function(event) {{
|
const commitIfReady = () => {{
|
||||||
if (event.key === "1") {{
|
if (document.querySelector("input[name='rating-useful']:checked") && document.querySelector("input[name='rating-meme']:checked") && document.querySelector("input[name='rating-aesthetic']:checked")) {{
|
||||||
document.querySelector("input[name='rating'][value='1']").checked = true
|
|
||||||
document.querySelector("form").submit()
|
|
||||||
}} else if (event.key === "2") {{
|
|
||||||
document.querySelector("input[name='rating'][value='2']").checked = true
|
|
||||||
document.querySelector("form").submit()
|
document.querySelector("form").submit()
|
||||||
}}
|
}}
|
||||||
|
}}
|
||||||
|
document.addEventListener("keypress", function(event) {{
|
||||||
|
if (event.key === "q") {{
|
||||||
|
document.querySelector("input[name='rating-useful'][value='1']").checked = true
|
||||||
|
commitIfReady()
|
||||||
|
}} else if (event.key === "w") {{
|
||||||
|
document.querySelector("input[name='rating-useful'][value='eq']").checked = true
|
||||||
|
commitIfReady()
|
||||||
|
}} else if (event.key === "e") {{
|
||||||
|
document.querySelector("input[name='rating-useful'][value='2']").checked = true
|
||||||
|
commitIfReady()
|
||||||
|
}} else if (event.key === "a") {{
|
||||||
|
document.querySelector("input[name='rating-meme'][value='1']").checked = true
|
||||||
|
commitIfReady()
|
||||||
|
}} else if (event.key === "s") {{
|
||||||
|
document.querySelector("input[name='rating-meme'][value='eq']").checked = true
|
||||||
|
commitIfReady()
|
||||||
|
}} else if (event.key === "d") {{
|
||||||
|
document.querySelector("input[name='rating-meme'][value='2']").checked = true
|
||||||
|
commitIfReady()
|
||||||
|
}} else if (event.key === "z") {{
|
||||||
|
document.querySelector("input[name='rating-aesthetic'][value='1']").checked = true
|
||||||
|
commitIfReady()
|
||||||
|
}} else if (event.key === "x") {{
|
||||||
|
document.querySelector("input[name='rating-aesthetic'][value='eq']").checked = true
|
||||||
|
commitIfReady()
|
||||||
|
}} else if (event.key === "c") {{
|
||||||
|
document.querySelector("input[name='rating-aesthetic'][value='2']").checked = true
|
||||||
|
commitIfReady()
|
||||||
|
}}
|
||||||
}});
|
}});
|
||||||
document.querySelector("#meme1").addEventListener("click", function(event) {{
|
|
||||||
document.querySelector("input[name='rating'][value='1']").checked = true
|
|
||||||
document.querySelector("form").submit()
|
|
||||||
}})
|
|
||||||
document.querySelector("#meme2").addEventListener("click", function(event) {{
|
|
||||||
document.querySelector("input[name='rating'][value='2']").checked = true
|
|
||||||
document.querySelector("form").submit()
|
|
||||||
}})
|
|
||||||
</script>
|
</script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
@ -90,8 +124,8 @@ async def rate(request):
|
|||||||
meme1 = post["meme1"]
|
meme1 = post["meme1"]
|
||||||
meme2 = post["meme2"]
|
meme2 = post["meme2"]
|
||||||
iteration = post["iteration"]
|
iteration = post["iteration"]
|
||||||
rating = post["rating"]
|
rating = post["rating-useful"] + "," + post["rating-meme"] + "," + post["rating-aesthetic"]
|
||||||
await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration) VALUES (?, ?, ?, ?)", (meme1, meme2, rating, iteration))
|
await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration, ip) VALUES (?, ?, ?, ?, ?)", (meme1, meme2, rating, iteration, request.remote))
|
||||||
await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2))
|
await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2))
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return web.HTTPFound("/")
|
return web.HTTPFound("/")
|
||||||
@ -104,6 +138,7 @@ CREATE TABLE IF NOT EXISTS ratings (
|
|||||||
meme2 TEXT NOT NULL,
|
meme2 TEXT NOT NULL,
|
||||||
rating TEXT NOT NULL,
|
rating TEXT NOT NULL,
|
||||||
iteration TEXT,
|
iteration TEXT,
|
||||||
|
ip TEXT,
|
||||||
UNIQUE (meme1, meme2)
|
UNIQUE (meme1, meme2)
|
||||||
);
|
);
|
||||||
CREATE TABLE IF NOT EXISTS queue (
|
CREATE TABLE IF NOT EXISTS queue (
|
||||||
|
@ -2,10 +2,11 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import sys
|
||||||
|
|
||||||
# Read data from log.jsonl
|
# Read data from log.jsonl
|
||||||
data = []
|
data = []
|
||||||
with open('log.jsonl', 'r') as file:
|
with open(sys.argv[1], 'r') as file:
|
||||||
for line in file:
|
for line in file:
|
||||||
data.append(json.loads(line))
|
data.append(json.loads(line))
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ import hashlib
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import numpy
|
import numpy
|
||||||
import random
|
import random
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
db = sqlite3.connect("data.sqlite3")
|
db = sqlite3.connect("data.sqlite3")
|
||||||
db.row_factory = sqlite3.Row
|
db.row_factory = sqlite3.Row
|
||||||
@ -20,19 +21,24 @@ def fetch_embedding(filename):
|
|||||||
return x.copy() # PyTorch complains otherwise due to bad
|
return x.copy() # PyTorch complains otherwise due to bad
|
||||||
|
|
||||||
def map_rating(rating, uncertainty=0.05):
|
def map_rating(rating, uncertainty=0.05):
|
||||||
match rating:
|
def map_one(rating):
|
||||||
case "1": # meme 1 is better
|
match rating:
|
||||||
return 1 - uncertainty
|
case "1": # meme 1 is better
|
||||||
case "2":
|
return 1 - uncertainty
|
||||||
return uncertainty
|
case "2":
|
||||||
case _: raise ValueError("invalid rating, please fix")
|
return uncertainty
|
||||||
|
case "eq":
|
||||||
|
return 0.5
|
||||||
|
case _: raise ValueError("invalid rating, please fix")
|
||||||
|
|
||||||
|
return np.array([map_one(r) for r in rating.split(",")])
|
||||||
|
|
||||||
def fetch_ratings():
|
def fetch_ratings():
|
||||||
trains = defaultdict(list)
|
trains = defaultdict(list)
|
||||||
validations = defaultdict(list)
|
validations = defaultdict(list)
|
||||||
csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
|
csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
|
||||||
for meme1, meme2, rating, iteration in csr.fetchall():
|
for meme1, meme2, rating, iteration in csr.fetchall():
|
||||||
(validations if is_val_set(meme1, meme2) else trains)[int(iteration or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
|
(validations if is_val_set(meme1, meme2) else trains)[int((iteration and iteration.split("-")[0]) or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
|
||||||
csr.close()
|
csr.close()
|
||||||
return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))
|
return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))
|
||||||
|
|
||||||
|
@ -36,7 +36,8 @@ config = TrainConfig(
|
|||||||
n_ensemble=16,
|
n_ensemble=16,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
dropout=0.1
|
dropout=0.1,
|
||||||
|
output_channels=3
|
||||||
),
|
),
|
||||||
lr=3e-4,
|
lr=3e-4,
|
||||||
weight_decay=0.2,
|
weight_decay=0.2,
|
||||||
@ -72,12 +73,12 @@ if config.compile:
|
|||||||
print("compiling...")
|
print("compiling...")
|
||||||
train_step = torch.compile(train_step)
|
train_step = torch.compile(train_step)
|
||||||
|
|
||||||
def batch_from_inputs(inputs: list[tuple[numpy.ndarray, numpy.ndarray, float]]):
|
def batch_from_inputs(inputs: list[list[tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]]]):
|
||||||
batch_input = torch.stack([
|
batch_input = torch.stack([
|
||||||
torch.stack([ torch.stack((torch.Tensor(emb1).to(config.model.dtype), torch.Tensor(emb2).to(config.model.dtype))) for emb1, emb2, rating in input ])
|
torch.stack([ torch.stack((torch.Tensor(emb1).to(config.model.dtype), torch.Tensor(emb2).to(config.model.dtype))) for emb1, emb2, rating in input ])
|
||||||
for input in inputs
|
for input in inputs
|
||||||
]).to(device)
|
]).to(device)
|
||||||
target = torch.stack([ torch.Tensor([ rating for emb1, emb2, rating in input ]) for input in inputs ]).to(device)
|
target = torch.stack([ torch.Tensor(numpy.array([ rating for emb1, emb2, rating in input ])).to(config.model.dtype) for input in inputs ]).to(device)
|
||||||
return batch_input, target
|
return batch_input, target
|
||||||
|
|
||||||
def evaluate(steps):
|
def evaluate(steps):
|
||||||
@ -118,7 +119,7 @@ with open(logfile, "w") as log:
|
|||||||
print(steps, loss)
|
print(steps, loss)
|
||||||
log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
|
log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
|
||||||
if steps % 10 == 0:
|
if steps % 10 == 0:
|
||||||
if steps % 250 == 0: save_ckpt(log, steps)
|
if steps % 100 == 0: save_ckpt(log, steps)
|
||||||
loss = evaluate(steps)
|
loss = evaluate(steps)
|
||||||
#print(loss)
|
#print(loss)
|
||||||
#best = min(loss, best)
|
#best = min(loss, best)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user