rogueviz::sag:: prepared for further methods

This commit is contained in:
Zeno Rogue 2022-10-23 16:03:06 +02:00
parent d828ff0e7e
commit bb5fa51965
1 changed files with 17 additions and 8 deletions

View File

@ -78,7 +78,11 @@ namespace sag {
vector<edgeinfo> sagedges;
vector<vector<int>> edges_yes, edges_no;
int logistic_cost; /* 0 = disable, 1 = enable */
enum eSagMethod { smClosest, smLogistic };
eSagMethod method;
bool logistic_repeat;
dhrg::logistic lgsag(1, 1);
vector<ld> loglik_tab_y, loglik_tab_n;
@ -232,7 +236,7 @@ namespace sag {
if(vid < 0) return 0;
double cost = 0;
if(logistic_cost) {
if(method == smLogistic) {
auto &s = sagdist[sid];
for(auto j: edges_yes[vid])
cost += loglik_tab_y[s[sagid[j]]];
@ -491,7 +495,7 @@ namespace sag {
lgsag.R = sum1 / sum0;
lgsag.T = sqrt((sum2 - sum1*sum1/sum0) / sum0);
println(hlog, "automatically set R = ", lgsag.R, " and ", lgsag.T, " max_sag_dist = ", max_sag_dist);
if(logistic_cost) compute_loglik_tab();
if(method == smLogistic) compute_loglik_tab();
}
void optimize_sag_loglik() {
@ -546,7 +550,7 @@ namespace sag {
dhrg::fast_loglik_cont(lgsag, logisticf, nullptr, 1, 1e-5);
println(hlog, "loglikelihood logistic = ", logisticf(lgsag), " R= ", lgsag.R, " T= ", lgsag.T);
if(logistic_cost) {
if(method == smLogistic) {
compute_loglik_tab();
prepare_graph();
println(hlog, "cost = ", cost);
@ -1073,7 +1077,7 @@ int readArgs() {
else if(argis("-sagrt")) {
shift(); sag::lgsag.R = argf();
shift(); sag::lgsag.T = argf();
if(sag::logistic_cost) compute_loglik_tab();
if(method == smLogistic) compute_loglik_tab();
}
else if(argis("-sagrt-auto")) {
@ -1081,9 +1085,14 @@ int readArgs() {
}
else if(argis("-sag_use_loglik")) {
shift(); sag::logistic_cost = argi();
if(sag::logistic_cost) compute_loglik_tab();
shift(); int mtd = argi();
if(mtd == 0) method = smClosest, logistic_repeat = false;
if(mtd == 1) method = smLogistic, logistic_repeat = false;
if(mtd == 2) method = smLogistic, logistic_repeat = true;
if(method == smLogistic)
compute_loglik_tab();
}
else if(argis("-sagminhelp")) {
shift_arg_formula(default_edgetype.visible_from_help);
}
@ -1190,7 +1199,7 @@ bool turn(int delta) {
if(vizsa_start) {
auto t = ticks;
double d = (t-vizsa_start) / (1000. * vizsa_len);
if(d > 1 && logistic_cost == 2) {
if(d > 1 && method == smLogistic && logistic_repeat) {
vizsa_start = ticks;
optimize_sag_loglik();
output_stats();