mirror of
https://github.com/zenorogue/hyperrogue.git
synced 2025-01-12 10:20:32 +00:00
optimize_sag_loglik for match
This commit is contained in:
parent
e011f86671
commit
5cdbca3b21
@ -9,6 +9,7 @@
|
||||
|
||||
#include "dhrg/dhrg.h"
|
||||
#include <thread>
|
||||
#include "leastsquare.cpp"
|
||||
|
||||
namespace rogueviz {
|
||||
|
||||
@ -84,7 +85,7 @@ namespace sag {
|
||||
bool loglik_repeat;
|
||||
|
||||
/* parameters for smMatch */
|
||||
ld match_a = 1, match_b = 1;
|
||||
ld match_a = 1, match_b = 0;
|
||||
|
||||
/* parameters for smLogistic */
|
||||
dhrg::logistic lgsag(1, 1);
|
||||
@ -102,8 +103,6 @@ namespace sag {
|
||||
dhrg::logistic lgemb(1, 1);
|
||||
vector<hyperpoint> placement;
|
||||
|
||||
void optimize_sag_loglik();
|
||||
|
||||
string distance_file;
|
||||
|
||||
void compute_dists() {
|
||||
@ -256,7 +255,9 @@ namespace sag {
|
||||
edgeinfo *ei = vd.edges[j].second;
|
||||
int t2 = vd.edges[j].first;
|
||||
if(sagid[t2] != -1) {
|
||||
ld dist = sagdist[sid][sagid[t2]] - match_a / ei->weight2 - match_b;
|
||||
ld cdist = sagdist[sid][sagid[t2]];
|
||||
ld expect = match_a / ei->weight2 + match_b;
|
||||
ld dist = cdist - expect;
|
||||
cost += dist * dist;
|
||||
}
|
||||
}
|
||||
@ -516,7 +517,7 @@ namespace sag {
|
||||
if(method == smLogistic) compute_loglik_tab();
|
||||
}
|
||||
|
||||
void optimize_sag_loglik() {
|
||||
void optimize_sag_loglik_logistic() {
|
||||
vector<int> indist(max_sag_dist, 0);
|
||||
|
||||
const int mul = 1;
|
||||
@ -575,6 +576,29 @@ namespace sag {
|
||||
}
|
||||
}
|
||||
|
||||
void optimize_sag_loglik_match() {
|
||||
lsq::leastsquare_solver<2> lsqs;
|
||||
|
||||
for(auto& ei: sagedges) {
|
||||
ld y = sagdist[sagid[ei.i]][sagid[ei.j]];
|
||||
ld x = 1. / ei.weight;
|
||||
lsqs.add_data({{x, 1}}, y);
|
||||
}
|
||||
|
||||
array<ld, 2> solution = lsqs.solve();
|
||||
match_a = solution[0];
|
||||
match_b = solution[1];
|
||||
|
||||
println(hlog, "got a = ", match_a, " b = ", match_b);
|
||||
if(method == smMatch)
|
||||
prepare_graph();
|
||||
}
|
||||
|
||||
void optimize_sag_loglik_auto() {
|
||||
if(method == smLogistic) optimize_sag_loglik_logistic();
|
||||
if(method == smMatch) optimize_sag_loglik_match();
|
||||
}
|
||||
|
||||
void disttable_add(ld dist, int qty0, int qty1) {
|
||||
using namespace dhrg;
|
||||
size_t i = dist * llcont_approx_prec;
|
||||
@ -1149,6 +1173,12 @@ int readArgs() {
|
||||
if(method == smLogistic) compute_loglik_tab();
|
||||
}
|
||||
|
||||
else if(argis("-sagmatch-ab")) {
|
||||
shift(); sag::match_a = argf();
|
||||
shift(); sag::match_b = argf();
|
||||
if(method == smMatch) prepare_graph();
|
||||
}
|
||||
|
||||
else if(argis("-sagrt-auto")) {
|
||||
compute_auto_rt();
|
||||
}
|
||||
@ -1162,6 +1192,7 @@ int readArgs() {
|
||||
if(mtd == 4) method = smMatch, loglik_repeat = true;
|
||||
if(method == smLogistic)
|
||||
compute_loglik_tab();
|
||||
if(method == smMatch) prepare_graph();
|
||||
}
|
||||
|
||||
else if(argis("-sagminhelp")) {
|
||||
@ -1239,8 +1270,14 @@ int readArgs() {
|
||||
PHASE(3); shift(); auto_save = args();
|
||||
}
|
||||
// (6) output loglikelihood
|
||||
else if(argis("-sagloglik")) {
|
||||
sag::optimize_sag_loglik();
|
||||
else if(argis("-sagloglik-l")) {
|
||||
sag::optimize_sag_loglik_logistic();
|
||||
}
|
||||
else if(argis("-sagloglik-m")) {
|
||||
sag::optimize_sag_loglik_match();
|
||||
}
|
||||
else if(argis("-sagloglik-a")) {
|
||||
sag::optimize_sag_loglik_auto();
|
||||
}
|
||||
else if(argis("-sagmode")) {
|
||||
shift();
|
||||
|
Loading…
Reference in New Issue
Block a user