rogueviz::sag:: smMatch method

This commit is contained in:
Zeno Rogue 2022-10-23 16:09:32 +02:00
parent bb5fa51965
commit 5a19fffd8f
1 changed files with 26 additions and 6 deletions

View File

@ -78,11 +78,15 @@ namespace sag {
vector<edgeinfo> sagedges;
vector<vector<int>> edges_yes, edges_no;
enum eSagMethod { smClosest, smLogistic };
enum eSagMethod { smClosest, smLogistic, smMatch };
eSagMethod method;
bool logistic_repeat;
bool loglik_repeat;
/* parameters for smMatch */
ld match_a = 1, match_b = 1;
/* parameters for smLogistic */
dhrg::logistic lgsag(1, 1);
vector<ld> loglik_tab_y, loglik_tab_n;
@ -228,6 +232,7 @@ namespace sag {
for(int i=0; i<N; i++) ids[sagcells[i]] = i;
}
/* separate hubs -- only for smClosest */
ld hub_penalty;
string hub_filename;
vector<int> hubval;
@ -245,6 +250,19 @@ namespace sag {
return -cost;
}
if(method == smMatch) {
vertexdata& vd = vdata[vid];
for(int j=0; j<isize(vd.edges); j++) {
edgeinfo *ei = vd.edges[j].second;
int t2 = vd.edges[j].first;
if(sagid[t2] != -1) {
ld dist = sagdist[sid][sagid[t2]] - match_a / ei->weight2 - match_b;
cost += dist * dist;
}
}
return cost;
}
vertexdata& vd = vdata[vid];
for(int j=0; j<isize(vd.edges); j++) {
edgeinfo *ei = vd.edges[j].second;
@ -1086,9 +1104,11 @@ int readArgs() {
else if(argis("-sag_use_loglik")) {
shift(); int mtd = argi();
if(mtd == 0) method = smClosest, logistic_repeat = false;
if(mtd == 1) method = smLogistic, logistic_repeat = false;
if(mtd == 2) method = smLogistic, logistic_repeat = true;
if(mtd == 0) method = smClosest, loglik_repeat = false;
if(mtd == 1) method = smLogistic, loglik_repeat = false;
if(mtd == 2) method = smLogistic, loglik_repeat = true;
if(mtd == 3) method = smMatch, loglik_repeat = false;
if(mtd == 4) method = smMatch, loglik_repeat = true;
if(method == smLogistic)
compute_loglik_tab();
}
@ -1199,7 +1219,7 @@ bool turn(int delta) {
if(vizsa_start) {
auto t = ticks;
double d = (t-vizsa_start) / (1000. * vizsa_len);
if(d > 1 && method == smLogistic && logistic_repeat) {
if(d > 1 && loglik_repeat) {
vizsa_start = ticks;
optimize_sag_loglik();
output_stats();