1
0
mirror of https://github.com/zenorogue/hyperrogue.git synced 2025-02-03 12:49:17 +00:00

rogueviz::sag:: smMatch method

This commit is contained in:
Zeno Rogue 2022-10-23 16:09:32 +02:00
parent bb5fa51965
commit 5a19fffd8f

View File

@ -78,11 +78,15 @@ namespace sag {
vector<edgeinfo> sagedges; vector<edgeinfo> sagedges;
vector<vector<int>> edges_yes, edges_no; vector<vector<int>> edges_yes, edges_no;
enum eSagMethod { smClosest, smLogistic }; enum eSagMethod { smClosest, smLogistic, smMatch };
eSagMethod method; eSagMethod method;
bool logistic_repeat; bool loglik_repeat;
/* parameters for smMatch */
ld match_a = 1, match_b = 1;
/* parameters for smLogistic */
dhrg::logistic lgsag(1, 1); dhrg::logistic lgsag(1, 1);
vector<ld> loglik_tab_y, loglik_tab_n; vector<ld> loglik_tab_y, loglik_tab_n;
@ -228,6 +232,7 @@ namespace sag {
for(int i=0; i<N; i++) ids[sagcells[i]] = i; for(int i=0; i<N; i++) ids[sagcells[i]] = i;
} }
/* separate hubs -- only for smClosest */
ld hub_penalty; ld hub_penalty;
string hub_filename; string hub_filename;
vector<int> hubval; vector<int> hubval;
@ -245,6 +250,19 @@ namespace sag {
return -cost; return -cost;
} }
if(method == smMatch) {
vertexdata& vd = vdata[vid];
for(int j=0; j<isize(vd.edges); j++) {
edgeinfo *ei = vd.edges[j].second;
int t2 = vd.edges[j].first;
if(sagid[t2] != -1) {
ld dist = sagdist[sid][sagid[t2]] - match_a / ei->weight2 - match_b;
cost += dist * dist;
}
}
return cost;
}
vertexdata& vd = vdata[vid]; vertexdata& vd = vdata[vid];
for(int j=0; j<isize(vd.edges); j++) { for(int j=0; j<isize(vd.edges); j++) {
edgeinfo *ei = vd.edges[j].second; edgeinfo *ei = vd.edges[j].second;
@ -1086,9 +1104,11 @@ int readArgs() {
else if(argis("-sag_use_loglik")) { else if(argis("-sag_use_loglik")) {
shift(); int mtd = argi(); shift(); int mtd = argi();
if(mtd == 0) method = smClosest, logistic_repeat = false; if(mtd == 0) method = smClosest, loglik_repeat = false;
if(mtd == 1) method = smLogistic, logistic_repeat = false; if(mtd == 1) method = smLogistic, loglik_repeat = false;
if(mtd == 2) method = smLogistic, logistic_repeat = true; if(mtd == 2) method = smLogistic, loglik_repeat = true;
if(mtd == 3) method = smMatch, loglik_repeat = false;
if(mtd == 4) method = smMatch, loglik_repeat = true;
if(method == smLogistic) if(method == smLogistic)
compute_loglik_tab(); compute_loglik_tab();
} }
@ -1199,7 +1219,7 @@ bool turn(int delta) {
if(vizsa_start) { if(vizsa_start) {
auto t = ticks; auto t = ticks;
double d = (t-vizsa_start) / (1000. * vizsa_len); double d = (t-vizsa_start) / (1000. * vizsa_len);
if(d > 1 && method == smLogistic && logistic_repeat) { if(d > 1 && loglik_repeat) {
vizsa_start = ticks; vizsa_start = ticks;
optimize_sag_loglik(); optimize_sag_loglik();
output_stats(); output_stats();