1
0
mirror of https://github.com/zenorogue/hyperrogue.git synced 2025-05-18 23:24:07 +00:00

rogueviz::sag:: prepared for further methods

This commit is contained in:
Zeno Rogue 2022-10-23 16:03:06 +02:00
parent d828ff0e7e
commit bb5fa51965

View File

@ -78,7 +78,11 @@ namespace sag {
vector<edgeinfo> sagedges; vector<edgeinfo> sagedges;
vector<vector<int>> edges_yes, edges_no; vector<vector<int>> edges_yes, edges_no;
int logistic_cost; /* 0 = disable, 1 = enable */ enum eSagMethod { smClosest, smLogistic };
eSagMethod method;
bool logistic_repeat;
dhrg::logistic lgsag(1, 1); dhrg::logistic lgsag(1, 1);
vector<ld> loglik_tab_y, loglik_tab_n; vector<ld> loglik_tab_y, loglik_tab_n;
@ -232,7 +236,7 @@ namespace sag {
if(vid < 0) return 0; if(vid < 0) return 0;
double cost = 0; double cost = 0;
if(logistic_cost) { if(method == smLogistic) {
auto &s = sagdist[sid]; auto &s = sagdist[sid];
for(auto j: edges_yes[vid]) for(auto j: edges_yes[vid])
cost += loglik_tab_y[s[sagid[j]]]; cost += loglik_tab_y[s[sagid[j]]];
@ -491,7 +495,7 @@ namespace sag {
lgsag.R = sum1 / sum0; lgsag.R = sum1 / sum0;
lgsag.T = sqrt((sum2 - sum1*sum1/sum0) / sum0); lgsag.T = sqrt((sum2 - sum1*sum1/sum0) / sum0);
println(hlog, "automatically set R = ", lgsag.R, " and ", lgsag.T, " max_sag_dist = ", max_sag_dist); println(hlog, "automatically set R = ", lgsag.R, " and ", lgsag.T, " max_sag_dist = ", max_sag_dist);
if(logistic_cost) compute_loglik_tab(); if(method == smLogistic) compute_loglik_tab();
} }
void optimize_sag_loglik() { void optimize_sag_loglik() {
@ -546,7 +550,7 @@ namespace sag {
dhrg::fast_loglik_cont(lgsag, logisticf, nullptr, 1, 1e-5); dhrg::fast_loglik_cont(lgsag, logisticf, nullptr, 1, 1e-5);
println(hlog, "loglikelihood logistic = ", logisticf(lgsag), " R= ", lgsag.R, " T= ", lgsag.T); println(hlog, "loglikelihood logistic = ", logisticf(lgsag), " R= ", lgsag.R, " T= ", lgsag.T);
if(logistic_cost) { if(method == smLogistic) {
compute_loglik_tab(); compute_loglik_tab();
prepare_graph(); prepare_graph();
println(hlog, "cost = ", cost); println(hlog, "cost = ", cost);
@ -1073,7 +1077,7 @@ int readArgs() {
else if(argis("-sagrt")) { else if(argis("-sagrt")) {
shift(); sag::lgsag.R = argf(); shift(); sag::lgsag.R = argf();
shift(); sag::lgsag.T = argf(); shift(); sag::lgsag.T = argf();
if(sag::logistic_cost) compute_loglik_tab(); if(method == smLogistic) compute_loglik_tab();
} }
else if(argis("-sagrt-auto")) { else if(argis("-sagrt-auto")) {
@ -1081,9 +1085,14 @@ int readArgs() {
} }
else if(argis("-sag_use_loglik")) { else if(argis("-sag_use_loglik")) {
shift(); sag::logistic_cost = argi(); shift(); int mtd = argi();
if(sag::logistic_cost) compute_loglik_tab(); if(mtd == 0) method = smClosest, logistic_repeat = false;
if(mtd == 1) method = smLogistic, logistic_repeat = false;
if(mtd == 2) method = smLogistic, logistic_repeat = true;
if(method == smLogistic)
compute_loglik_tab();
} }
else if(argis("-sagminhelp")) { else if(argis("-sagminhelp")) {
shift_arg_formula(default_edgetype.visible_from_help); shift_arg_formula(default_edgetype.visible_from_help);
} }
@ -1190,7 +1199,7 @@ bool turn(int delta) {
if(vizsa_start) { if(vizsa_start) {
auto t = ticks; auto t = ticks;
double d = (t-vizsa_start) / (1000. * vizsa_len); double d = (t-vizsa_start) / (1000. * vizsa_len);
if(d > 1 && logistic_cost == 2) { if(d > 1 && method == smLogistic && logistic_repeat) {
vizsa_start = ticks; vizsa_start = ticks;
optimize_sag_loglik(); optimize_sag_loglik();
output_stats(); output_stats();