1
0
mirror of https://github.com/zenorogue/hyperrogue.git synced 2025-01-12 18:30:34 +00:00

optimize_sag_loglik for match

This commit is contained in:
Zeno Rogue 2022-10-23 17:32:28 +02:00
parent e011f86671
commit 5cdbca3b21

View File

@ -9,6 +9,7 @@
#include "dhrg/dhrg.h" #include "dhrg/dhrg.h"
#include <thread> #include <thread>
#include "leastsquare.cpp"
namespace rogueviz { namespace rogueviz {
@ -84,7 +85,7 @@ namespace sag {
bool loglik_repeat; bool loglik_repeat;
/* parameters for smMatch */ /* parameters for smMatch */
ld match_a = 1, match_b = 1; ld match_a = 1, match_b = 0;
/* parameters for smLogistic */ /* parameters for smLogistic */
dhrg::logistic lgsag(1, 1); dhrg::logistic lgsag(1, 1);
@ -102,8 +103,6 @@ namespace sag {
dhrg::logistic lgemb(1, 1); dhrg::logistic lgemb(1, 1);
vector<hyperpoint> placement; vector<hyperpoint> placement;
void optimize_sag_loglik();
string distance_file; string distance_file;
void compute_dists() { void compute_dists() {
@ -256,7 +255,9 @@ namespace sag {
edgeinfo *ei = vd.edges[j].second; edgeinfo *ei = vd.edges[j].second;
int t2 = vd.edges[j].first; int t2 = vd.edges[j].first;
if(sagid[t2] != -1) { if(sagid[t2] != -1) {
ld dist = sagdist[sid][sagid[t2]] - match_a / ei->weight2 - match_b; ld cdist = sagdist[sid][sagid[t2]];
ld expect = match_a / ei->weight2 + match_b;
ld dist = cdist - expect;
cost += dist * dist; cost += dist * dist;
} }
} }
@ -516,7 +517,7 @@ namespace sag {
if(method == smLogistic) compute_loglik_tab(); if(method == smLogistic) compute_loglik_tab();
} }
void optimize_sag_loglik() { void optimize_sag_loglik_logistic() {
vector<int> indist(max_sag_dist, 0); vector<int> indist(max_sag_dist, 0);
const int mul = 1; const int mul = 1;
@ -575,6 +576,29 @@ namespace sag {
} }
} }
void optimize_sag_loglik_match() {
lsq::leastsquare_solver<2> lsqs;
for(auto& ei: sagedges) {
ld y = sagdist[sagid[ei.i]][sagid[ei.j]];
ld x = 1. / ei.weight;
lsqs.add_data({{x, 1}}, y);
}
array<ld, 2> solution = lsqs.solve();
match_a = solution[0];
match_b = solution[1];
println(hlog, "got a = ", match_a, " b = ", match_b);
if(method == smMatch)
prepare_graph();
}
void optimize_sag_loglik_auto() {
if(method == smLogistic) optimize_sag_loglik_logistic();
if(method == smMatch) optimize_sag_loglik_match();
}
void disttable_add(ld dist, int qty0, int qty1) { void disttable_add(ld dist, int qty0, int qty1) {
using namespace dhrg; using namespace dhrg;
size_t i = dist * llcont_approx_prec; size_t i = dist * llcont_approx_prec;
@ -1149,6 +1173,12 @@ int readArgs() {
if(method == smLogistic) compute_loglik_tab(); if(method == smLogistic) compute_loglik_tab();
} }
else if(argis("-sagmatch-ab")) {
shift(); sag::match_a = argf();
shift(); sag::match_b = argf();
if(method == smMatch) prepare_graph();
}
else if(argis("-sagrt-auto")) { else if(argis("-sagrt-auto")) {
compute_auto_rt(); compute_auto_rt();
} }
@ -1162,6 +1192,7 @@ int readArgs() {
if(mtd == 4) method = smMatch, loglik_repeat = true; if(mtd == 4) method = smMatch, loglik_repeat = true;
if(method == smLogistic) if(method == smLogistic)
compute_loglik_tab(); compute_loglik_tab();
if(method == smMatch) prepare_graph();
} }
else if(argis("-sagminhelp")) { else if(argis("-sagminhelp")) {
@ -1239,8 +1270,14 @@ int readArgs() {
PHASE(3); shift(); auto_save = args(); PHASE(3); shift(); auto_save = args();
} }
// (6) output loglikelihood // (6) output loglikelihood
else if(argis("-sagloglik")) { else if(argis("-sagloglik-l")) {
sag::optimize_sag_loglik(); sag::optimize_sag_loglik_logistic();
}
else if(argis("-sagloglik-m")) {
sag::optimize_sag_loglik_match();
}
else if(argis("-sagloglik-a")) {
sag::optimize_sag_loglik_auto();
} }
else if(argis("-sagmode")) { else if(argis("-sagmode")) {
shift(); shift();