mirror of
https://github.com/osmarks/meme-search-engine.git
synced 2024-11-10 22:09:54 +00:00
19 lines
665 B
Python
19 lines
665 B
Python
|
import numpy
|
||
|
import xgboost as xgb
|
||
|
|
||
|
import shared
|
||
|
|
||
|
def _flatten_batches(batches):
    """Return one list holding every sample from every batch, in order."""
    flattened = []
    for batch in batches:
        flattened.extend(batch)
    return flattened


def _stack_and_join(pairs):
    """Stack each two-element pair, then concatenate the stacks along axis 0."""
    return numpy.concatenate([numpy.stack(pair) for pair in pairs])


# NOTE(review): the return value of shared.fetch_ratings() is presumably
# (training batches, validation batches) -- confirm against `shared`.
# `validations` is not used anywhere in this part of the script.
trains, validations = shared.fetch_ratings()

# Every sample is unpacked below as a (meme1, meme2, rating) triple.
flat_samples = _flatten_batches(trains)

# Each sample contributes two consecutive rows: one per meme in the pair.
X = _stack_and_join((meme1, meme2) for meme1, meme2, _ in flat_samples)

# Labels line up with the rows of X: int(rating) for the first meme of the
# pair and int(1 - rating) for the second.
Y = _stack_and_join(
    (int(rating), int(1 - rating)) for _, _, rating in flat_samples
)

# Both rows of a sample share the sample's index as their query id, so each
# comparison forms its own two-item ranking group.
qid = _stack_and_join((index, index) for index, _ in enumerate(flat_samples))

# Hyperparameters are forwarded to XGBoost unchanged; see the XGBoost
# learning-to-rank documentation for the meaning of the lambdarank_* options.
ranker = xgb.XGBRanker(
    objective="rank:ndcg",
    lambdarank_pair_method="topk",
    lambdarank_num_pair_per_sample=8,
    tree_method="hist",
    device="cuda",
)
ranker.fit(X, Y, qid=qid, verbose=True)
|