mirror of
https://github.com/zenorogue/hyperrogue.git
synced 2025-01-12 18:30:34 +00:00
optimize_sag_loglik for match
This commit is contained in:
parent
e011f86671
commit
5cdbca3b21
@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include "dhrg/dhrg.h"
|
#include "dhrg/dhrg.h"
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
#include "leastsquare.cpp"
|
||||||
|
|
||||||
namespace rogueviz {
|
namespace rogueviz {
|
||||||
|
|
||||||
@ -84,7 +85,7 @@ namespace sag {
|
|||||||
bool loglik_repeat;
|
bool loglik_repeat;
|
||||||
|
|
||||||
/* parameters for smMatch */
|
/* parameters for smMatch */
|
||||||
ld match_a = 1, match_b = 1;
|
ld match_a = 1, match_b = 0;
|
||||||
|
|
||||||
/* parameters for smLogistic */
|
/* parameters for smLogistic */
|
||||||
dhrg::logistic lgsag(1, 1);
|
dhrg::logistic lgsag(1, 1);
|
||||||
@ -102,8 +103,6 @@ namespace sag {
|
|||||||
dhrg::logistic lgemb(1, 1);
|
dhrg::logistic lgemb(1, 1);
|
||||||
vector<hyperpoint> placement;
|
vector<hyperpoint> placement;
|
||||||
|
|
||||||
void optimize_sag_loglik();
|
|
||||||
|
|
||||||
string distance_file;
|
string distance_file;
|
||||||
|
|
||||||
void compute_dists() {
|
void compute_dists() {
|
||||||
@ -256,7 +255,9 @@ namespace sag {
|
|||||||
edgeinfo *ei = vd.edges[j].second;
|
edgeinfo *ei = vd.edges[j].second;
|
||||||
int t2 = vd.edges[j].first;
|
int t2 = vd.edges[j].first;
|
||||||
if(sagid[t2] != -1) {
|
if(sagid[t2] != -1) {
|
||||||
ld dist = sagdist[sid][sagid[t2]] - match_a / ei->weight2 - match_b;
|
ld cdist = sagdist[sid][sagid[t2]];
|
||||||
|
ld expect = match_a / ei->weight2 + match_b;
|
||||||
|
ld dist = cdist - expect;
|
||||||
cost += dist * dist;
|
cost += dist * dist;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -516,7 +517,7 @@ namespace sag {
|
|||||||
if(method == smLogistic) compute_loglik_tab();
|
if(method == smLogistic) compute_loglik_tab();
|
||||||
}
|
}
|
||||||
|
|
||||||
void optimize_sag_loglik() {
|
void optimize_sag_loglik_logistic() {
|
||||||
vector<int> indist(max_sag_dist, 0);
|
vector<int> indist(max_sag_dist, 0);
|
||||||
|
|
||||||
const int mul = 1;
|
const int mul = 1;
|
||||||
@ -575,6 +576,29 @@ namespace sag {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void optimize_sag_loglik_match() {
|
||||||
|
lsq::leastsquare_solver<2> lsqs;
|
||||||
|
|
||||||
|
for(auto& ei: sagedges) {
|
||||||
|
ld y = sagdist[sagid[ei.i]][sagid[ei.j]];
|
||||||
|
ld x = 1. / ei.weight;
|
||||||
|
lsqs.add_data({{x, 1}}, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
array<ld, 2> solution = lsqs.solve();
|
||||||
|
match_a = solution[0];
|
||||||
|
match_b = solution[1];
|
||||||
|
|
||||||
|
println(hlog, "got a = ", match_a, " b = ", match_b);
|
||||||
|
if(method == smMatch)
|
||||||
|
prepare_graph();
|
||||||
|
}
|
||||||
|
|
||||||
|
void optimize_sag_loglik_auto() {
|
||||||
|
if(method == smLogistic) optimize_sag_loglik_logistic();
|
||||||
|
if(method == smMatch) optimize_sag_loglik_match();
|
||||||
|
}
|
||||||
|
|
||||||
void disttable_add(ld dist, int qty0, int qty1) {
|
void disttable_add(ld dist, int qty0, int qty1) {
|
||||||
using namespace dhrg;
|
using namespace dhrg;
|
||||||
size_t i = dist * llcont_approx_prec;
|
size_t i = dist * llcont_approx_prec;
|
||||||
@ -1149,6 +1173,12 @@ int readArgs() {
|
|||||||
if(method == smLogistic) compute_loglik_tab();
|
if(method == smLogistic) compute_loglik_tab();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
else if(argis("-sagmatch-ab")) {
|
||||||
|
shift(); sag::match_a = argf();
|
||||||
|
shift(); sag::match_b = argf();
|
||||||
|
if(method == smMatch) prepare_graph();
|
||||||
|
}
|
||||||
|
|
||||||
else if(argis("-sagrt-auto")) {
|
else if(argis("-sagrt-auto")) {
|
||||||
compute_auto_rt();
|
compute_auto_rt();
|
||||||
}
|
}
|
||||||
@ -1162,6 +1192,7 @@ int readArgs() {
|
|||||||
if(mtd == 4) method = smMatch, loglik_repeat = true;
|
if(mtd == 4) method = smMatch, loglik_repeat = true;
|
||||||
if(method == smLogistic)
|
if(method == smLogistic)
|
||||||
compute_loglik_tab();
|
compute_loglik_tab();
|
||||||
|
if(method == smMatch) prepare_graph();
|
||||||
}
|
}
|
||||||
|
|
||||||
else if(argis("-sagminhelp")) {
|
else if(argis("-sagminhelp")) {
|
||||||
@ -1239,8 +1270,14 @@ int readArgs() {
|
|||||||
PHASE(3); shift(); auto_save = args();
|
PHASE(3); shift(); auto_save = args();
|
||||||
}
|
}
|
||||||
// (6) output loglikelihood
|
// (6) output loglikelihood
|
||||||
else if(argis("-sagloglik")) {
|
else if(argis("-sagloglik-l")) {
|
||||||
sag::optimize_sag_loglik();
|
sag::optimize_sag_loglik_logistic();
|
||||||
|
}
|
||||||
|
else if(argis("-sagloglik-m")) {
|
||||||
|
sag::optimize_sag_loglik_match();
|
||||||
|
}
|
||||||
|
else if(argis("-sagloglik-a")) {
|
||||||
|
sag::optimize_sag_loglik_auto();
|
||||||
}
|
}
|
||||||
else if(argis("-sagmode")) {
|
else if(argis("-sagmode")) {
|
||||||
shift();
|
shift();
|
||||||
|
Loading…
Reference in New Issue
Block a user