1
0
mirror of https://github.com/zenorogue/hyperrogue.git synced 2024-06-16 10:19:58 +00:00

optimize_sag_loglik for match

This commit is contained in:
Zeno Rogue 2022-10-23 17:32:28 +02:00
parent e011f86671
commit 5cdbca3b21

View File

@ -9,6 +9,7 @@
#include "dhrg/dhrg.h"
#include <thread>
#include "leastsquare.cpp"
namespace rogueviz {
@ -84,7 +85,7 @@ namespace sag {
bool loglik_repeat;
/* parameters for smMatch */
ld match_a = 1, match_b = 1;
ld match_a = 1, match_b = 0;
/* parameters for smLogistic */
dhrg::logistic lgsag(1, 1);
@ -102,8 +103,6 @@ namespace sag {
dhrg::logistic lgemb(1, 1);
vector<hyperpoint> placement;
void optimize_sag_loglik();
string distance_file;
void compute_dists() {
@ -256,7 +255,9 @@ namespace sag {
edgeinfo *ei = vd.edges[j].second;
int t2 = vd.edges[j].first;
if(sagid[t2] != -1) {
ld dist = sagdist[sid][sagid[t2]] - match_a / ei->weight2 - match_b;
ld cdist = sagdist[sid][sagid[t2]];
ld expect = match_a / ei->weight2 + match_b;
ld dist = cdist - expect;
cost += dist * dist;
}
}
@ -516,7 +517,7 @@ namespace sag {
if(method == smLogistic) compute_loglik_tab();
}
void optimize_sag_loglik() {
void optimize_sag_loglik_logistic() {
vector<int> indist(max_sag_dist, 0);
const int mul = 1;
@ -575,6 +576,29 @@ namespace sag {
}
}
void optimize_sag_loglik_match() {
lsq::leastsquare_solver<2> lsqs;
for(auto& ei: sagedges) {
ld y = sagdist[sagid[ei.i]][sagid[ei.j]];
ld x = 1. / ei.weight;
lsqs.add_data({{x, 1}}, y);
}
array<ld, 2> solution = lsqs.solve();
match_a = solution[0];
match_b = solution[1];
println(hlog, "got a = ", match_a, " b = ", match_b);
if(method == smMatch)
prepare_graph();
}
void optimize_sag_loglik_auto() {
if(method == smLogistic) optimize_sag_loglik_logistic();
if(method == smMatch) optimize_sag_loglik_match();
}
void disttable_add(ld dist, int qty0, int qty1) {
using namespace dhrg;
size_t i = dist * llcont_approx_prec;
@ -1149,6 +1173,12 @@ int readArgs() {
if(method == smLogistic) compute_loglik_tab();
}
else if(argis("-sagmatch-ab")) {
shift(); sag::match_a = argf();
shift(); sag::match_b = argf();
if(method == smMatch) prepare_graph();
}
else if(argis("-sagrt-auto")) {
compute_auto_rt();
}
@ -1162,6 +1192,7 @@ int readArgs() {
if(mtd == 4) method = smMatch, loglik_repeat = true;
if(method == smLogistic)
compute_loglik_tab();
if(method == smMatch) prepare_graph();
}
else if(argis("-sagminhelp")) {
@ -1239,8 +1270,14 @@ int readArgs() {
PHASE(3); shift(); auto_save = args();
}
// (6) output loglikelihood
else if(argis("-sagloglik")) {
sag::optimize_sag_loglik();
else if(argis("-sagloglik-l")) {
sag::optimize_sag_loglik_logistic();
}
else if(argis("-sagloglik-m")) {
sag::optimize_sag_loglik_match();
}
else if(argis("-sagloglik-a")) {
sag::optimize_sag_loglik_auto();
}
else if(argis("-sagmode")) {
shift();