diff --git a/rogueviz/sag.cpp b/rogueviz/sag.cpp index 161a4bb0..5c8f8666 100644 --- a/rogueviz/sag.cpp +++ b/rogueviz/sag.cpp @@ -9,6 +9,7 @@ #include "dhrg/dhrg.h" #include +#include "leastsquare.cpp" namespace rogueviz { @@ -84,7 +85,7 @@ namespace sag { bool loglik_repeat; /* parameters for smMatch */ - ld match_a = 1, match_b = 1; + ld match_a = 1, match_b = 0; /* parameters for smLogistic */ dhrg::logistic lgsag(1, 1); @@ -102,8 +103,6 @@ namespace sag { dhrg::logistic lgemb(1, 1); vector placement; - void optimize_sag_loglik(); - string distance_file; void compute_dists() { @@ -256,7 +255,9 @@ namespace sag { edgeinfo *ei = vd.edges[j].second; int t2 = vd.edges[j].first; if(sagid[t2] != -1) { - ld dist = sagdist[sid][sagid[t2]] - match_a / ei->weight2 - match_b; + ld cdist = sagdist[sid][sagid[t2]]; + ld expect = match_a / ei->weight2 + match_b; + ld dist = cdist - expect; cost += dist * dist; } } @@ -516,7 +517,7 @@ namespace sag { if(method == smLogistic) compute_loglik_tab(); } - void optimize_sag_loglik() { + void optimize_sag_loglik_logistic() { vector indist(max_sag_dist, 0); const int mul = 1; @@ -575,6 +576,29 @@ namespace sag { } } + void optimize_sag_loglik_match() { + lsq::leastsquare_solver<2> lsqs; + + for(auto& ei: sagedges) { + ld y = sagdist[sagid[ei.i]][sagid[ei.j]]; + ld x = 1. / ei.weight; + lsqs.add_data({{x, 1}}, y); + } + + array solution = lsqs.solve(); + match_a = solution[0]; + match_b = solution[1]; + + println(hlog, "got a = ", match_a, " b = ", match_b); + if(method == smMatch) + prepare_graph(); + } + + void optimize_sag_loglik_auto() { + if(method == smLogistic) optimize_sag_loglik_logistic(); + if(method == smMatch) optimize_sag_loglik_match(); + } + void disttable_add(ld dist, int qty0, int qty1) { using namespace dhrg; size_t i = dist * llcont_approx_prec; @@ -1149,6 +1173,12 @@ int readArgs() { if(method == smLogistic) compute_loglik_tab(); } + else if(argis("-sagmatch-ab")) { + shift(); sag::match_a = argf(); + shift(); sag::match_b = argf(); + if(method == smMatch) prepare_graph(); + } + else if(argis("-sagrt-auto")) { compute_auto_rt(); } @@ -1162,6 +1192,7 @@ int readArgs() { if(mtd == 4) method = smMatch, loglik_repeat = true; if(method == smLogistic) compute_loglik_tab(); + if(method == smMatch) prepare_graph(); } else if(argis("-sagminhelp")) { @@ -1239,8 +1270,14 @@ int readArgs() { PHASE(3); shift(); auto_save = args(); } // (6) output loglikelihood - else if(argis("-sagloglik")) { - sag::optimize_sag_loglik(); + else if(argis("-sagloglik-l")) { + sag::optimize_sag_loglik_logistic(); + } + else if(argis("-sagloglik-m")) { + sag::optimize_sag_loglik_match(); + } + else if(argis("-sagloglik-a")) { + sag::optimize_sag_loglik_auto(); } else if(argis("-sagmode")) { shift();