mirror of
				https://github.com/zenorogue/hyperrogue.git
				synced 2025-10-30 13:32:59 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			328 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			328 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| // RogueViz -- SAG embedder: implementation of SAG method evaluation
 | |
| // Copyright (C) 2011-24 Zeno Rogue, see 'hyper.cpp' for details
 | |
| 
 | |
| #include "../rogueviz.h"
 | |
| 
 | |
| #include "../dhrg/dhrg.h"
 | |
| #include "../statistics.cpp"
 | |
| 
 | |
| namespace rogueviz {
 | |
| namespace sag {
 | |
| 
 | |
| enum eSagMethod { smClosest, smLogistic, smMatch };
 | |
| eSagMethod method;
 | |
| 
 | |
| vector<string> method_names = {"closest", "logistic", "match"};
 | |
| 
 | |
| int method_count = 3;
 | |
| 
 | |
| /* parameters for smMatch */
 | |
| ld match_a = 1, match_b = 0;
 | |
| 
 | |
| /* parameters for smLogistic */
 | |
| dhrg::logistic lgsag(1, 1), lgsag_pre(1, 1);
 | |
| 
 | |
| vector<ld> loglik_tab_y, loglik_tab_n;
 | |
| 
 | |
| dhrg::logistic best;
 | |
| 
 | |
| bool opt_debug = false;
 | |
| 
 | |
| bool should_good = false;
 | |
| 
 | |
| double costat(int vid, int sid) {
 | |
|   if(vid < 0) return 0;
 | |
|   double cost = 0;
 | |
| 
 | |
|   switch(method) {
 | |
|     case smLogistic: {
 | |
|       auto s = sagdist[sid];
 | |
|       for(auto j: edges_yes[vid]) if(sagid[j] >= -1)
 | |
|         cost += loglik_tab_y[s[sagid[j]]];
 | |
|       for(auto j: edges_no[vid]) if(sagid[j] >= -1)
 | |
|         cost += loglik_tab_n[s[sagid[j]]];
 | |
|       return -cost;
 | |
|       }
 | |
| 
 | |
|     case smMatch: {
 | |
|       for(auto& e: edge_weights[vid]) {
 | |
|         auto t2 = e.first;
 | |
|         if(sagid[t2] != -1) {
 | |
|           ld cdist = sagdist[sid][sagid[t2]];
 | |
|           ld expect = match_a / e.second + match_b;
 | |
|           ld dist = cdist - expect;
 | |
|           cost += dist * dist;
 | |
|           }
 | |
|         }
 | |
|       return cost;
 | |
|       }
 | |
| 
 | |
|     case smClosest: {
 | |
|       for(auto& e: edge_weights[vid]) {
 | |
|         auto t2 = e.first;
 | |
|         if(sagid[t2] != -1) cost += sagdist[sid][sagid[t2]] * e.second;
 | |
|         }
 | |
|       
 | |
|       if(!hubval.empty()) {
 | |
|         for(auto sid2: neighbors[sid]) {
 | |
|           int vid2 = sagnode[sid2];
 | |
|           if(vid2 >= 0 && (hubval[vid] & hubval[vid]) == 0)
 | |
|             cost += hub_penalty;
 | |
|           }
 | |
|         }
 | |
|       
 | |
|       return cost;
 | |
|       }
 | |
|     }
 | |
| 
 | |
|   throw hr_exception("unknwon SAG method");
 | |
|   }
 | |
| 
 | |
| double cost;
 | |
| 
 | |
| double best_cost = 1000000000;
 | |
| 
 | |
| void compute_cost() {
 | |
|   int DN = isize(sagid);
 | |
|   cost = 0;
 | |
|   for(int i=0; i<DN; i++)
 | |
|     cost += costat(i, sagid[i]);
 | |
|   cost /= 2;
 | |
|   }
 | |
|   
 | |
| void compute_loglik_tab() {
 | |
|   loglik_tab_y.resize(max_sag_dist);
 | |
|   loglik_tab_n.resize(max_sag_dist);
 | |
|   for(int i=0; i<max_sag_dist; i++) {
 | |
|     loglik_tab_y[i] = lgsag.lyes(i);
 | |
|     loglik_tab_n[i] = lgsag.lno(i);
 | |
|     }
 | |
|   }
 | |
| 
 | |
| void compute_auto_rt() {
 | |
|   ld sum0 = 0, sum1 = 0, sum2 = 0;
 | |
| 
 | |
|   for(auto i: sagdist) {
 | |
|     sum0 ++;
 | |
|     sum1 += i;
 | |
|     sum2 += i*i;
 | |
|     }
 | |
| 
 | |
|   lgsag.R = sum1 / sum0;
 | |
|   lgsag.T = sqrt((sum2 - sum1*sum1/sum0) / sum0);
 | |
|   println(hlog, "automatically set R = ", lgsag.R, " and ", lgsag.T, " max_sag_dist = ", max_sag_dist);
 | |
|   if(method == smLogistic) compute_loglik_tab();
 | |
|   }
 | |
| 
 | |
/** Fit the logistic model lgsag (R, T) to the current placement.
 *
 *  Edges are bucketed by the distance between the cells of their endpoints. The model's
 *  log-likelihood of "edge" vs "no edge" per distance is then handed to
 *  dhrg::fast_loglik_cont, which presumably updates lgsag in place to maximise it
 *  (see dhrg; not visible here). When smLogistic is the active method, the log-likelihood
 *  tables and `cost` are recomputed afterwards.
 *
 *  Throws hr_exception("computation error") when should_good is set and -cost differs from
 *  the model log-likelihood by more than 0.1.
 */
void optimize_sag_loglik_logistic() {
  // indist[d]: number of unordered vertex pairs whose cells are at distance d
  vector<int> indist(max_sag_dist, 0);

  // scale applied to distances when indexing pedge (1 = no scaling)
  const int mul = 1;

  int N = isize(sagid);
  for(int i=0; i<N; i++)
  for(int j=0; j<i; j++) {
    int d = sagdist[sagid[i]][sagid[j]];
    indist[d]++;
    }

  // pedge[d]: number of counted edges whose endpoints' cells are at distance d
  vector<int> pedge(max_sag_dist, 0);

  for(int i=0; i<isize(sagedges); i++) {
    edgeinfo& ei = sagedges[i];
    // if(int(sagdist[sagid[ei.i]][sagid[ei.j]] * mul) == 136) printf("E %d,%d\n", ei.i, ei.j);
    // self-loops and edges lighter than sag_edge->visible_from are not counted
    if(ei.i != ei.j)
    if(ei.weight >= sag_edge->visible_from)
      pedge[sagdist[sagid[ei.i]][sagid[ei.j]] * mul]++;
    }

  // debug table: distance, edges / pairs, percentage of pairs that are edges
  if(opt_debug) for(int d=0; d<max_sag_dist; d++)
    if(indist[d])
      printf("%2d: %7d/%7d %7.3lf\n",
        d, pedge[d], indist[d], double(pedge[d] * 100. / indist[d]));

  // log-likelihood of the saturated model (a free edge probability per distance):
  // sum of p log p + q log q - (p+q) log(p+q); only used for the debug output below
  ld loglik = 0;
  for(int d=0; d<max_sag_dist; d++) {
    int p = pedge[d], pq = indist[d];
    int q = pq - p;
    if(p && q) {
      loglik += p * log(p) + q * log(q) - pq * log(pq);
      if(opt_debug) println(hlog, tie(d, p, q), loglik);
      }
    }

  if(opt_debug) println(hlog, "loglikelihood best = ", fts(loglik));

  // log-likelihood of the bucketed counts under model l: l.lyes(d) per edge and
  // l.lno(d) per non-edge pair at distance d
  auto logisticf = [&] (dhrg::logistic&  l) {
    ld loglik = 0;
    for(int d=0; d<max_sag_dist; d++) {
      int p = pedge[d], pq = indist[d];
      if(p) loglik += p * l.lyes(d);
      if(pq > p) loglik += (pq-p) * l.lno(d);
      }
    return loglik;
    };

  // consistency check: with smLogistic, costat() returns the negated log-likelihood, so
  // cost + logisticf should be ~0 (assumes `cost` is up to date -- TODO confirm at callers)
  if(opt_debug) println(hlog, "cost = ", cost, " logisticf = ", logisticf(lgsag), " R= ", lgsag.R, " T= ", lgsag.T);
  if(should_good && abs(cost + logisticf(lgsag)) > 0.1) throw hr_exception("computation error");

  dhrg::fast_loglik_cont(lgsag, logisticf, nullptr, 1, 1e-5);
  if(opt_debug) println(hlog, "loglikelihood logistic = ", logisticf(lgsag), " R= ", lgsag.R, " T= ", lgsag.T);

  // the tables and the cost depend on lgsag, so refresh them after the fit
  if(method == smLogistic) {
    compute_loglik_tab();
    compute_cost();
    if(opt_debug) println(hlog, "cost = ", cost);
    }
  }
 | |
| 
 | |
| void optimize_sag_loglik_match() {
 | |
|   if(state &~ SS_WEIGHTED) return;
 | |
|   stats::leastsquare_solver<2> lsqs;
 | |
| 
 | |
|   for(auto& ei: sagedges) {
 | |
|     ld y = sagdist[sagid[ei.i]][sagid[ei.j]];
 | |
|     ld x = 1. / ei.weight;
 | |
|     lsqs.add_data({{x, 1}}, y);
 | |
|     }
 | |
| 
 | |
|   array<ld, 2> solution = lsqs.solve();
 | |
|   match_a = solution[0];
 | |
|   match_b = solution[1];
 | |
| 
 | |
|   println(hlog, "got a = ", match_a, " b = ", match_b);
 | |
|   if(method == smMatch)
 | |
|     prepare_graph();
 | |
|   }
 | |
| 
 | |
| void optimize_sag_loglik_auto() {
 | |
|   if(method == smLogistic) optimize_sag_loglik_logistic();
 | |
|   if(method == smMatch) optimize_sag_loglik_match();
 | |
|   }
 | |
| 
 | |
| pair<ld, ld> compute_mAP() {
 | |
|   int DN = isize(sagid);
 | |
| 
 | |
|   ld meanrank = 0;
 | |
|   int tgood = 0;
 | |
|   ld maprank = 0;
 | |
| 
 | |
|   for(int i=0; i<DN; i++) {
 | |
|     vector<int> alldist(max_sag_dist, 0);
 | |
|     for(int j=0; j<DN; j++) if(i != j) alldist[sagdist[sagid[i]][sagid[j]]]++;
 | |
|     vector<int> edgedist(max_sag_dist, 0);
 | |
|     for(auto j: edges_yes[i]) edgedist[sagdist[sagid[i]][sagid[j]]]++;
 | |
| 
 | |
|     int pgood = 0;
 | |
|     ld bad = 0;
 | |
|     ld ap = 0;
 | |
|     ld pall = 0;
 | |
| 
 | |
|     for(int j=0; j<max_sag_dist; j++) {
 | |
|       ld good = edgedist[j];
 | |
|       ld all = alldist[j];
 | |
|       ld err = all - good;
 | |
| 
 | |
|       bad += err / 2.;
 | |
|       meanrank += bad * good;
 | |
|       bad += err / 2.;
 | |
| 
 | |
|       for(int k=0; k<good; k++) {
 | |
|         pgood++, pall++;
 | |
|         pall += err/2 / good;
 | |
|         ap += pgood / pall;
 | |
|         pall += err/2 / good;
 | |
|         }
 | |
|       if(!good) pall += err;
 | |
|       }
 | |
| 
 | |
|     tgood += pgood;
 | |
|     if(pgood) maprank += ap / pgood;
 | |
|     }
 | |
|   return make_pair(maprank / DN, meanrank / tgood);
 | |
|   }
 | |
| 
 | |
| ld kendall;
 | |
| void compute_kendall() {
 | |
|   compute_cost(); println(hlog, "cost = ", cost);
 | |
|   vector<vector<ld> > weights;
 | |
|   int DN = isize(sagid);
 | |
|   weights.resize(DN);
 | |
|   for(int i=0; i<DN; i++) weights[i].resize(DN, 0);
 | |
|   for(auto& e: sagedges) weights[e.i][e.j] += e.weight2, weights[e.j][e.i] += e.weight2;
 | |
|   vector<pair<int, ld>> kdata;
 | |
|   for(int i=0; i<DN; i++) for(int j=0; j<i; j++) kdata.emplace_back(sagdist[sagid[i]][sagid[j]], -weights[i][j]);
 | |
|   kendall = stats::kendall(kdata);
 | |
|   }
 | |
| 
 | |
| void prepare_method() {
 | |
|   if(method == smLogistic) compute_loglik_tab();
 | |
|   optimize_sag_loglik_auto();
 | |
|   if(method == smClosest) compute_cost();
 | |
|   }
 | |
| 
 | |
| bool known_pairs = false;
 | |
| 
 | |
| int function_read_args() {
 | |
| #if CAP_COMMANDLINE
 | |
|   using namespace arg;
 | |
| 
 | |
|   if(0) ;
 | |
| 
 | |
|   else if(argis("-sagrt")) {
 | |
|     shift(); sag::lgsag.R = argf();
 | |
|     shift(); sag::lgsag.T = argf();
 | |
|     if(method == smLogistic) compute_loglik_tab();
 | |
|     }
 | |
| 
 | |
|   else if(argis("-sagmatch-ab")) {
 | |
|     shift(); sag::match_a = argf();
 | |
|     shift(); sag::match_b = argf();
 | |
|     if(method == smMatch) prepare_graph();
 | |
|     }
 | |
| 
 | |
|   else if(argis("-sagrt-auto")) {
 | |
|     compute_auto_rt();
 | |
|     }
 | |
| 
 | |
|   else if(argis("-sag-kendall")) {
 | |
|     compute_kendall();
 | |
|     println(hlog, "kendall = ", kendall);
 | |
|     }
 | |
| 
 | |
|   else if(argis("-sag-method-closest")) {  
 | |
|     method = smClosest; prepare_method();
 | |
|     }
 | |
| 
 | |
|   else if(argis("-sag-method-logistic")) {  
 | |
|     method = smLogistic; prepare_method();
 | |
|     }
 | |
| 
 | |
|   else if(argis("-sag-method-match")) {  
 | |
|     method = smMatch; prepare_method();
 | |
|     }
 | |
| 
 | |
|   else if(argis("-sag-logistic-recalc")) {
 | |
|     method = smLogistic; prepare_method();
 | |
|     }
 | |
| 
 | |
|   else if(argis("-sagloglik-l")) {
 | |
|     sag::optimize_sag_loglik_logistic();
 | |
|     }
 | |
|   else if(argis("-sagloglik-m")) {
 | |
|     sag::optimize_sag_loglik_match();
 | |
|     }
 | |
|   else if(argis("-sagloglik-a")) {
 | |
|     sag::optimize_sag_loglik_auto();
 | |
|     }
 | |
| 
 | |
|   else return 1;  
 | |
| #endif
 | |
|   return 0;
 | |
|   }
 | |
| 
 | |
| int ahfun = addHook(hooks_args, 100, function_read_args);
 | |
| 
 | |
| }
 | |
| }
 | 
