mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-10 22:09:54 +00:00
meme rater model code (documentation "later")
This commit is contained in:
parent
0b0261f625
commit
58ce70bb5e
55
meme-rater/active_learning.py
Normal file
55
meme-rater/active_learning.py
Normal file
@ -0,0 +1,55 @@
|
||||
import torch.nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
import sqlite3
|
||||
import random
|
||||
import numpy
|
||||
import json
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
|
||||
from model import Config, BradleyTerry
|
||||
import shared
|
||||
|
||||
batch_size = 128
|
||||
num_pairs = batch_size * 1024
|
||||
device = "cuda"
|
||||
|
||||
config = Config(
|
||||
d_emb=1152,
|
||||
n_hidden=1,
|
||||
n_ensemble=16,
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
dropout=0.5
|
||||
)
|
||||
model = BradleyTerry(config)
|
||||
modelc, _ = shared.checkpoint_for(2250)
|
||||
model.load_state_dict(torch.load(modelc))
|
||||
params = sum(p.numel() for p in model.parameters())
|
||||
print(f"{params/1e6:.1f}M parameters")
|
||||
print(model)
|
||||
|
||||
files = shared.fetch_all_files()
|
||||
variance = {}
|
||||
|
||||
pairs = []
|
||||
for _ in range(num_pairs):
|
||||
pairs.append(tuple(random.sample(files, 2)))
|
||||
|
||||
model.eval()
|
||||
with torch.inference_mode():
|
||||
for bstart in tqdm(range(0, len(pairs), batch_size)):
|
||||
batch = pairs[bstart:bstart + batch_size]
|
||||
filenames = [ (f1, f2) for ((f1, e1), (f2, e2)) in batch ]
|
||||
embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
|
||||
inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
|
||||
win_probs = model(inputs)
|
||||
#print(win_probs.shape)
|
||||
batchvar = torch.var(win_probs, dim=0)
|
||||
for filename, var in zip(filenames, batchvar):
|
||||
variance[filename] = float(var)
|
||||
|
||||
top = sorted(variance.items(), key=lambda x: -x[1])
|
||||
with open("top.json", "w") as f:
|
||||
json.dump(top[:256], f)
|
76
meme-rater/al2.py
Normal file
76
meme-rater/al2.py
Normal file
@ -0,0 +1,76 @@
|
||||
import torch.nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
import sqlite3
|
||||
import random
|
||||
import numpy
|
||||
import json
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
from torch.func import functional_call, vmap, grad
|
||||
|
||||
from model import Config, BradleyTerry
|
||||
import shared
|
||||
|
||||
steps = 855
|
||||
batch_size = 128
|
||||
num_pairs = batch_size * 1024
|
||||
device = "cuda"
|
||||
|
||||
config = Config(
|
||||
d_emb=1152,
|
||||
n_hidden=1,
|
||||
n_ensemble=1,
|
||||
device=device,
|
||||
dtype=torch.bfloat16
|
||||
)
|
||||
model = BradleyTerry(config)
|
||||
modelc, _ = shared.checkpoint_for(855)
|
||||
model.load_state_dict(torch.load(modelc))
|
||||
params = sum(p.numel() for p in model.parameters())
|
||||
print(f"{params/1e6:.1f}M parameters")
|
||||
print(model)
|
||||
|
||||
files = shared.fetch_all_files()
|
||||
importance = {}
|
||||
|
||||
params = {k: v.detach() for k, v in model.named_parameters()}
|
||||
buffers = {k: v.detach() for k, v in model.named_buffers()}
|
||||
|
||||
# https://pytorch.org/tutorials/intermediate/per_sample_grads.html
|
||||
def compute_loss(params, buffers, sample, target):
|
||||
batch = sample.unsqueeze(0)
|
||||
targets = target.unsqueeze(0)
|
||||
|
||||
predictions = functional_call(model, (params, buffers), (batch,))
|
||||
loss = F.binary_cross_entropy(predictions, targets)
|
||||
return loss
|
||||
|
||||
ft_compute_grad = grad(compute_loss)
|
||||
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 1, 1))
|
||||
|
||||
pairs = []
|
||||
for _ in range(num_pairs):
|
||||
pairs.append(tuple(random.sample(files, 2)))
|
||||
|
||||
for bstart in tqdm(range(0, len(pairs), batch_size)):
|
||||
batch = pairs[bstart:bstart + batch_size]
|
||||
filenames = [ (f1, f2) for ((f1, e1), (f2, e2)) in batch ]
|
||||
embs = torch.stack([ torch.stack((torch.Tensor(e1).to(config.dtype), torch.Tensor(e2).to(config.dtype))) for ((f1, e1), (f2, e2)) in batch ])
|
||||
inputs = embs.unsqueeze(0).expand((config.n_ensemble, batch_size, 2, config.d_emb)).to(device)
|
||||
#win_probs = model(inputs)
|
||||
# TODO gradients
|
||||
# don't take variance: do backwards pass and compute gradient norm
|
||||
grads = ft_compute_sample_grad(params, buffers, inputs, torch.full((1, batch_size), 0.95).to(device))
|
||||
total_grad_norms = torch.zeros(batch_size).to(device)
|
||||
for k, v in grads.items():
|
||||
param_dims = tuple(range(1, len(v.shape)))
|
||||
total_grad_norms += torch.linalg.vector_norm(v, dim=param_dims)
|
||||
tgn = total_grad_norms.cpu().numpy()
|
||||
|
||||
for filename, tg in zip(filenames, tgn):
|
||||
importance[filename] = float(tg)
|
||||
|
||||
top = sorted(importance.items(), key=lambda x: -x[1])
|
||||
with open("top.json", "w") as f:
|
||||
json.dump(top[:256], f)
|
14
meme-rater/copy_into_queue.py
Normal file
14
meme-rater/copy_into_queue.py
Normal file
@ -0,0 +1,14 @@
|
||||
import sqlite3
|
||||
import json
|
||||
import sys
|
||||
|
||||
iteration = sys.argv[1]
|
||||
|
||||
db = sqlite3.connect("data.sqlite3")
|
||||
db.row_factory = sqlite3.Row
|
||||
|
||||
with open("top.json", "r") as f:
|
||||
listing = json.load(f)
|
||||
|
||||
db.executemany("INSERT INTO queue VALUES (?, ?, ?)", [ (x[0], x[1], iteration) for x, v in listing ])
|
||||
db.commit()
|
86
meme-rater/eval.py
Normal file
86
meme-rater/eval.py
Normal file
@ -0,0 +1,86 @@
|
||||
import torch.nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
import sqlite3
|
||||
import random
|
||||
import numpy
|
||||
import json
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from model import Config, BradleyTerry
|
||||
import shared
|
||||
|
||||
batch_size = 128
|
||||
device = "cuda"
|
||||
|
||||
config = Config(
|
||||
d_emb=1152,
|
||||
n_hidden=1,
|
||||
n_ensemble=16,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
dropout=0.1
|
||||
)
|
||||
model = BradleyTerry(config).float()
|
||||
modelc, _ = shared.checkpoint_for(1500)
|
||||
model.load_state_dict(torch.load(modelc))
|
||||
params = sum(p.numel() for p in model.parameters())
|
||||
print(f"{params/1e6:.1f}M parameters")
|
||||
print(model)
|
||||
|
||||
files = shared.fetch_all_files()
|
||||
ratings = {}
|
||||
|
||||
model.eval()
|
||||
with torch.inference_mode():
|
||||
for bstart in tqdm(range(0, len(files), batch_size)):
|
||||
batch = files[bstart:bstart + batch_size]
|
||||
filenames = [ filename for filename, embedding in batch ]
|
||||
embs = torch.stack([ torch.Tensor(embedding) for filename, embedding in batch ])
|
||||
inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
|
||||
scores = model.ensemble(inputs).float()
|
||||
mscores = torch.median(scores, dim=0).values
|
||||
for filename, mscore in zip(filenames, mscores):
|
||||
ratings[filename] = float(mscore)
|
||||
|
||||
ratings = sorted(ratings.items(), key=lambda x: x[1])
|
||||
|
||||
def percentile(p, n):
|
||||
base = round(p * len(ratings))
|
||||
return ratings[base:base + n]
|
||||
|
||||
N = 25
|
||||
def render_memeset(p):
|
||||
filenames = percentile(p, N)
|
||||
return f"""
|
||||
<div>
|
||||
<details><summary>Reveal Memeset</summary>{p}</details>
|
||||
{''.join(f'<div><img src="{"images/" + f}" width="30%"><br><input type=checkbox id="{"col-" + str(p) + "-" + str(i)}"></div>' for i, (f, s) in enumerate(filenames))}
|
||||
</div>
|
||||
"""
|
||||
|
||||
buf = """<!DOCTYPE html>"""
|
||||
probs = [0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 0.95, 0.98, 0.99]
|
||||
random.shuffle(probs)
|
||||
for p in probs:
|
||||
#for p in [0.3]:
|
||||
buf += render_memeset(p)
|
||||
|
||||
buf += """
|
||||
<script>
|
||||
const computeCounts = () => {
|
||||
const counts = {}
|
||||
for (const x of document.querySelectorAll("input[type=checkbox]")) {
|
||||
const [_, percentile, index] = x.getAttribute("id").split("-")
|
||||
counts[percentile] ??= 0
|
||||
if (x.checked) counts[percentile] += 1
|
||||
}
|
||||
console.log(counts)
|
||||
}
|
||||
</script>
|
||||
"""
|
||||
|
||||
with open("eval.html", "w") as f:
|
||||
f.write(buf)
|
33
meme-rater/final_eval_results.py
Normal file
33
meme-rater/final_eval_results.py
Normal file
@ -0,0 +1,33 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import json
|
||||
|
||||
# Data as a JSON string
|
||||
data_json = '{"0.95":22,"0.75":21,"0.5":15,"0.98":23,"0.25":3,"0.05":0,"0.99":24,"0.1":2,"0.01":0,"0.02":0}'
|
||||
|
||||
# Parse the JSON string into a dictionary
|
||||
data = json.loads(data_json)
|
||||
|
||||
# Extract the keys and values from the dictionary
|
||||
keys = list(data.keys())
|
||||
values = list(data.values())
|
||||
|
||||
# Convert the keys to floats
|
||||
keys = [float(key) for key in keys]
|
||||
|
||||
# Sort the keys and values based on the keys
|
||||
sorted_data = sorted(zip(keys, values))
|
||||
keys, values = zip(*sorted_data)
|
||||
|
||||
plt.plot(keys, values)
|
||||
|
||||
# Set the x-axis tick labels
|
||||
plt.xticks(keys, rotation=45)
|
||||
|
||||
# Add labels and title
|
||||
plt.xlabel('Percentile')
|
||||
plt.ylabel('Memes Kept')
|
||||
plt.title('Final Model Evaluation')
|
||||
|
||||
# Display the plot
|
||||
plt.tight_layout()
|
||||
plt.show()
|
52
meme-rater/model.py
Normal file
52
meme-rater/model.py
Normal file
@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
d_emb: int
|
||||
n_hidden: int
|
||||
n_ensemble: int
|
||||
device: str
|
||||
dtype: torch.dtype
|
||||
dropout: float
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.hidden = nn.ModuleList([ nn.Linear(config.d_emb, config.d_emb, dtype=config.dtype, device=config.device) for _ in range(config.n_hidden) ])
|
||||
self.dropout = nn.ModuleList([ nn.Dropout(p=config.dropout) for _ in range(config.n_hidden) ])
|
||||
self.output = nn.Linear(config.d_emb, 1, dtype=config.dtype, device=config.device)
|
||||
|
||||
def forward(self, embs):
|
||||
x = embs
|
||||
for (layer, dropout) in zip(self.hidden, self.dropout):
|
||||
x = F.silu(layer(dropout(x)))
|
||||
return self.output(x)
|
||||
|
||||
class Ensemble(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.models = nn.ModuleList([ Model(config) for i in range(config.n_ensemble) ])
|
||||
|
||||
# model batch
|
||||
def forward(self, embs):
|
||||
xs = torch.stack([ x(embs[i]) for i, x in enumerate(self.models) ]) # model batch output_dim=1
|
||||
return xs.squeeze(-1)
|
||||
|
||||
class BradleyTerry(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.ensemble = Ensemble(config)
|
||||
|
||||
def forward(self, embs): # model batch input=2 d_emb
|
||||
scores1 = self.ensemble(embs[:, :, 0]).float() # model batch
|
||||
scores2 = self.ensemble(embs[:, :, 1]).float()
|
||||
# win probabilities
|
||||
#print(scores1, scores2)
|
||||
probs = torch.sigmoid(scores1 - scores2) # model batch
|
||||
#print(probs)
|
||||
return probs
|
@ -11,15 +11,22 @@ routes = web.RouteTableDef()
|
||||
|
||||
async def get_pair(db):
|
||||
while True:
|
||||
filenames = [ x[0] for x in await db.execute_fetchall("SELECT filename FROM files", ()) ]
|
||||
m1, m2 = tuple(sorted(random.sample(filenames, 2)))
|
||||
csr = await db.execute("SELECT * FROM queue")
|
||||
row = await csr.fetchone()
|
||||
await csr.close()
|
||||
iteration = None
|
||||
if row:
|
||||
m1, m2, iteration = row
|
||||
else:
|
||||
filenames = [ x[0] for x in await db.execute_fetchall("SELECT filename FROM files", ()) ]
|
||||
m1, m2 = tuple(sorted(random.sample(filenames, 2)))
|
||||
csr = await db.execute("SELECT 1 FROM ratings WHERE meme1 = ? AND meme2 = ?", (m1, m2))
|
||||
if not await csr.fetchone():
|
||||
return m1, m2
|
||||
return m1, m2, iteration
|
||||
|
||||
@routes.get("/")
|
||||
async def index(request):
|
||||
meme1, meme2 = await get_pair(request.app["db"])
|
||||
meme1, meme2, iteration = await get_pair(request.app["db"])
|
||||
return web.Response(text=f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
@ -46,6 +53,7 @@ async def index(request):
|
||||
|
||||
<input type="hidden" name="meme1" value="{meme1}">
|
||||
<input type="hidden" name="meme2" value="{meme2}">
|
||||
<input type="hidden" name="iteration" value="{str(iteration or 0)}">
|
||||
<input type="submit" value="Submit">
|
||||
<div class="memes">
|
||||
<img src="/memes/{meme1}" id="meme1">
|
||||
@ -81,8 +89,10 @@ async def rate(request):
|
||||
post = await request.post()
|
||||
meme1 = post["meme1"]
|
||||
meme2 = post["meme2"]
|
||||
iteration = post["iteration"]
|
||||
rating = post["rating"]
|
||||
await db.execute("INSERT INTO ratings (meme1, meme2, rating) VALUES (?, ?, ?)", (meme1, meme2, rating))
|
||||
await db.execute("INSERT INTO ratings (meme1, meme2, rating, iteration) VALUES (?, ?, ?, ?)", (meme1, meme2, rating, iteration))
|
||||
await db.execute("DELETE FROM queue WHERE meme1 = ? AND meme2 = ?", (meme1, meme2))
|
||||
await db.commit()
|
||||
return web.HTTPFound("/")
|
||||
|
||||
@ -93,8 +103,15 @@ CREATE TABLE IF NOT EXISTS ratings (
|
||||
meme1 TEXT NOT NULL,
|
||||
meme2 TEXT NOT NULL,
|
||||
rating TEXT NOT NULL,
|
||||
iteration TEXT,
|
||||
UNIQUE (meme1, meme2)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS queue (
|
||||
meme1 TEXT NOT NULL,
|
||||
meme2 TEXT NOT NULL,
|
||||
iteration TEXT NOT NULL,
|
||||
UNIQUE (meme1, meme2, iteration)
|
||||
);
|
||||
""")
|
||||
app.router.add_routes(routes)
|
||||
app.router.add_static("/memes/", "./images")
|
||||
|
52
meme-rater/run_graph.py
Normal file
52
meme-rater/run_graph.py
Normal file
@ -0,0 +1,52 @@
|
||||
# claude-3
|
||||
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Read data from log.jsonl
|
||||
data = []
|
||||
with open('log.jsonl', 'r') as file:
|
||||
for line in file:
|
||||
data.append(json.loads(line))
|
||||
|
||||
# Extract steps, loss, and val_loss
|
||||
steps = [entry['step'] for entry in data if "loss" in entry]
|
||||
loss = [entry['loss'] for entry in data if "loss" in entry]
|
||||
val_loss_data = [entry['val_loss'] for entry in data if 'val_loss' in entry]
|
||||
val_steps = [entry['step'] for entry in data if 'val_loss' in entry]
|
||||
|
||||
# Extract individual validation loss series
|
||||
val_loss_series = {}
|
||||
for val_loss in val_loss_data:
|
||||
for key, value in val_loss.items():
|
||||
if key not in val_loss_series:
|
||||
val_loss_series[key] = []
|
||||
val_loss_series[key].append(value)
|
||||
|
||||
# Calculate rolling average for loss
|
||||
window_size = 50
|
||||
rolling_avg = [sum(loss[i:i+window_size])/window_size for i in range(len(loss)-window_size+1)]
|
||||
rolling_steps = steps[window_size-1:]
|
||||
|
||||
# Calculate rolling averages for validation loss series
|
||||
val_rolling_avgs = {}
|
||||
for key, series in val_loss_series.items():
|
||||
val_rolling_avgs[key] = [sum(series[i:i+window_size])/window_size for i in range(len(series)-window_size+1)]
|
||||
|
||||
print([(name, min(series)) for name, series in val_loss_series.items()])
|
||||
|
||||
# Create the plot
|
||||
plt.figure(figsize=(10, 6))
|
||||
#plt.plot(steps, loss, label='Loss')
|
||||
plt.plot(rolling_steps, rolling_avg, label='Rolling Average (Loss)')
|
||||
|
||||
for key, series in val_loss_series.items():
|
||||
#plt.plot(val_steps, series, marker='o', linestyle='', label=f'Validation Loss ({key})')
|
||||
plt.plot(val_steps[window_size-1:], val_rolling_avgs[key], label=f'Rolling Average (Validation Loss {key})')
|
||||
|
||||
plt.xlabel('Steps')
|
||||
plt.ylabel('Loss')
|
||||
plt.title('Loss and Validation Loss vs. Steps')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.show()
|
53
meme-rater/shared.py
Normal file
53
meme-rater/shared.py
Normal file
@ -0,0 +1,53 @@
|
||||
import sqlite3
|
||||
import hashlib
|
||||
from collections import defaultdict
|
||||
import numpy
|
||||
import random
|
||||
|
||||
db = sqlite3.connect("data.sqlite3")
|
||||
db.row_factory = sqlite3.Row
|
||||
|
||||
val_fraction = 0.2
|
||||
def is_val_set(meme1, meme2):
|
||||
def is_one_val(meme):
|
||||
return hashlib.sha256(meme.encode("utf-8")).digest()[0] / 255 < (val_fraction / 2) # not strictly correct but good enough
|
||||
return is_one_val(meme1) or is_one_val(meme2)
|
||||
|
||||
def fetch_embedding(filename):
|
||||
csr = db.execute("SELECT embedding_vector FROM files WHERE filename = ?", (filename,))
|
||||
x = numpy.frombuffer(csr.fetchone()[0], dtype="float16")
|
||||
csr.close()
|
||||
return x.copy() # PyTorch complains otherwise due to bad
|
||||
|
||||
def map_rating(rating, uncertainty=0.05):
|
||||
match rating:
|
||||
case "1": # meme 1 is better
|
||||
return 1 - uncertainty
|
||||
case "2":
|
||||
return uncertainty
|
||||
case _: raise ValueError("invalid rating, please fix")
|
||||
|
||||
def fetch_ratings():
|
||||
trains = defaultdict(list)
|
||||
validations = defaultdict(list)
|
||||
csr = db.execute("SELECT meme1, meme2, rating, iteration FROM ratings")
|
||||
for meme1, meme2, rating, iteration in csr.fetchall():
|
||||
(validations if is_val_set(meme1, meme2) else trains)[int(iteration or "0")].append((fetch_embedding(meme1), fetch_embedding(meme2), map_rating(rating)))
|
||||
csr.close()
|
||||
return list(x[1] for x in sorted(trains.items())), list(x[1] for x in sorted(validations.items()))
|
||||
|
||||
def generate_random_permutations(x, n):
|
||||
out = []
|
||||
for _ in range(n):
|
||||
random.shuffle(x)
|
||||
out.append(x.copy())
|
||||
return out
|
||||
|
||||
def fetch_all_files():
|
||||
csr = db.execute("SELECT filename, embedding_vector FROM files")
|
||||
x = [ (row[0], numpy.frombuffer(row[1], dtype="float16").copy()) for row in csr.fetchall() ]
|
||||
csr.close()
|
||||
return x
|
||||
|
||||
def checkpoint_for(steps):
|
||||
return f"./ckpt/model-{steps}.pt", f"./ckpt/optim-{steps}.pt"
|
129
meme-rater/train.py
Normal file
129
meme-rater/train.py
Normal file
@ -0,0 +1,129 @@
|
||||
import torch.nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
import sqlite3
|
||||
import random
|
||||
import numpy
|
||||
import json
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
import math
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
from model import Config as ModelConfig, BradleyTerry
|
||||
import shared
|
||||
|
||||
trains, validations = shared.fetch_ratings()
|
||||
for train, validation in zip(trains, validations):
|
||||
print(len(train), len(validation))
|
||||
|
||||
device = "cuda"
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
model: ModelConfig
|
||||
lr: float
|
||||
weight_decay: float
|
||||
batch_size: int
|
||||
epochs: int
|
||||
compile: bool
|
||||
data_grouped_by_iter: bool
|
||||
|
||||
config = TrainConfig(
|
||||
model=ModelConfig(
|
||||
d_emb=1152,
|
||||
n_hidden=1,
|
||||
n_ensemble=16,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
dropout=0.1
|
||||
),
|
||||
lr=3e-4,
|
||||
weight_decay=0.2,
|
||||
batch_size=1,
|
||||
epochs=5,
|
||||
compile=False,
|
||||
data_grouped_by_iter=False
|
||||
)
|
||||
|
||||
def exprange(min, max, n):
|
||||
lmin, lmax = math.log(min), math.log(max)
|
||||
step = (lmax - lmin) / (n - 1)
|
||||
return (math.exp(lmin + step * i) for i in range(n))
|
||||
|
||||
model = BradleyTerry(config.model)
|
||||
params = sum(p.numel() for p in model.parameters())
|
||||
print(f"{params/1e6:.1f}M parameters")
|
||||
print(model)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
|
||||
|
||||
def train_step(model, batch, real):
|
||||
optimizer.zero_grad()
|
||||
# model batch
|
||||
win_probabilities = model(batch).float()
|
||||
loss = F.binary_cross_entropy(win_probabilities, real)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
loss = loss.detach().cpu().item()
|
||||
return loss
|
||||
|
||||
if config.compile:
|
||||
print("compiling...")
|
||||
train_step = torch.compile(train_step)
|
||||
|
||||
def batch_from_inputs(inputs: list[tuple[numpy.ndarray, numpy.ndarray, float]]):
|
||||
batch_input = torch.stack([
|
||||
torch.stack([ torch.stack((torch.Tensor(emb1).to(config.model.dtype), torch.Tensor(emb2).to(config.model.dtype))) for emb1, emb2, rating in input ])
|
||||
for input in inputs
|
||||
]).to(device)
|
||||
target = torch.stack([ torch.Tensor([ rating for emb1, emb2, rating in input ]) for input in inputs ]).to(device)
|
||||
return batch_input, target
|
||||
|
||||
def evaluate(steps):
|
||||
print("evaluating...")
|
||||
model.eval()
|
||||
results = {"step": steps, "time": time.time(), "val_loss": {}}
|
||||
for vset, validation in enumerate(validations):
|
||||
with torch.no_grad():
|
||||
batch_input, target = batch_from_inputs([ validation[:128] for _ in range(config.model.n_ensemble) ])
|
||||
result = model(batch_input).float()
|
||||
val_loss = F.binary_cross_entropy(result, target).detach().cpu().item()
|
||||
model.train()
|
||||
results["val_loss"][vset] = val_loss
|
||||
log.write(json.dumps(results) + "\n")
|
||||
|
||||
def save_ckpt(log, steps):
|
||||
print("saving...")
|
||||
modelc, optimc = shared.checkpoint_for(steps)
|
||||
torch.save(optimizer.state_dict(), optimc)
|
||||
torch.save(model.state_dict(), modelc)
|
||||
|
||||
class JSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, torch.dtype):
|
||||
return str(o)
|
||||
else: return super().default(o)
|
||||
|
||||
logfile = f"logs/log-{time.time()}.jsonl"
|
||||
with open(logfile, "w") as log:
|
||||
steps = 0
|
||||
log.write(JSONEncoder().encode(asdict(config)) + "\n")
|
||||
for epoch in range(config.epochs):
|
||||
for train in (trains if config.data_grouped_by_iter else [[ sample for trainss in trains for sample in trainss ]]):
|
||||
data_orders = shared.generate_random_permutations(train, config.model.n_ensemble)
|
||||
for bstart in range(0, len(train), config.batch_size):
|
||||
batch_input, target = batch_from_inputs([ order[bstart:bstart + config.batch_size] for order in data_orders ])
|
||||
loss = train_step(model, batch_input, target)
|
||||
print(steps, loss)
|
||||
log.write(json.dumps({"loss": loss, "step": steps, "time": time.time()}) + "\n")
|
||||
if steps % 10 == 0:
|
||||
if steps % 250 == 0: save_ckpt(log, steps)
|
||||
loss = evaluate(steps)
|
||||
#print(loss)
|
||||
#best = min(loss, best)
|
||||
steps += 1
|
||||
|
||||
save_ckpt(log, steps)
|
||||
|
||||
print(logfile)
|
Loading…
Reference in New Issue
Block a user