From bb5fa51965e9c8a6822a11dfb472b8dec35b0d99 Mon Sep 17 00:00:00 2001 From: Zeno Rogue Date: Sun, 23 Oct 2022 16:03:06 +0200 Subject: [PATCH] rogueviz::sag:: prepared for further methods --- rogueviz/sag.cpp | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/rogueviz/sag.cpp b/rogueviz/sag.cpp index 4dca7583..9af9cde5 100644 --- a/rogueviz/sag.cpp +++ b/rogueviz/sag.cpp @@ -78,7 +78,11 @@ namespace sag { vector sagedges; vector> edges_yes, edges_no; - int logistic_cost; /* 0 = disable, 1 = enable */ + enum eSagMethod { smClosest, smLogistic }; + eSagMethod method; + + bool logistic_repeat; + dhrg::logistic lgsag(1, 1); vector loglik_tab_y, loglik_tab_n; @@ -232,7 +236,7 @@ namespace sag { if(vid < 0) return 0; double cost = 0; - if(logistic_cost) { + if(method == smLogistic) { auto &s = sagdist[sid]; for(auto j: edges_yes[vid]) cost += loglik_tab_y[s[sagid[j]]]; @@ -491,7 +495,7 @@ namespace sag { lgsag.R = sum1 / sum0; lgsag.T = sqrt((sum2 - sum1*sum1/sum0) / sum0); println(hlog, "automatically set R = ", lgsag.R, " and ", lgsag.T, " max_sag_dist = ", max_sag_dist); - if(logistic_cost) compute_loglik_tab(); + if(method == smLogistic) compute_loglik_tab(); } void optimize_sag_loglik() { @@ -546,7 +550,7 @@ namespace sag { dhrg::fast_loglik_cont(lgsag, logisticf, nullptr, 1, 1e-5); println(hlog, "loglikelihood logistic = ", logisticf(lgsag), " R= ", lgsag.R, " T= ", lgsag.T); - if(logistic_cost) { + if(method == smLogistic) { compute_loglik_tab(); prepare_graph(); println(hlog, "cost = ", cost); @@ -1073,7 +1077,7 @@ int readArgs() { else if(argis("-sagrt")) { shift(); sag::lgsag.R = argf(); shift(); sag::lgsag.T = argf(); - if(sag::logistic_cost) compute_loglik_tab(); + if(method == smLogistic) compute_loglik_tab(); } else if(argis("-sagrt-auto")) { @@ -1081,9 +1085,14 @@ int readArgs() { } else if(argis("-sag_use_loglik")) { - shift(); sag::logistic_cost = argi(); - if(sag::logistic_cost) compute_loglik_tab(); + shift(); int mtd = argi(); + if(mtd == 0) method = smClosest, logistic_repeat = false; + if(mtd == 1) method = smLogistic, logistic_repeat = false; + if(mtd == 2) method = smLogistic, logistic_repeat = true; + if(method == smLogistic) + compute_loglik_tab(); } + else if(argis("-sagminhelp")) { shift_arg_formula(default_edgetype.visible_from_help); } @@ -1190,7 +1199,7 @@ bool turn(int delta) { if(vizsa_start) { auto t = ticks; double d = (t-vizsa_start) / (1000. * vizsa_len); - if(d > 1 && logistic_cost == 2) { + if(d > 1 && method == smLogistic && logistic_repeat) { vizsa_start = ticks; optimize_sag_loglik(); output_stats();