mirror of
				https://github.com/zenorogue/hyperrogue.git
				synced 2025-10-30 21:42:59 +00:00 
			
		
		
		
	optimize_sag_loglik for match
This commit is contained in:
		| @@ -9,6 +9,7 @@ | |||||||
|  |  | ||||||
| #include "dhrg/dhrg.h" | #include "dhrg/dhrg.h" | ||||||
| #include <thread> | #include <thread> | ||||||
|  | #include "leastsquare.cpp" | ||||||
|  |  | ||||||
| namespace rogueviz { | namespace rogueviz { | ||||||
|  |  | ||||||
| @@ -84,7 +85,7 @@ namespace sag { | |||||||
|   bool loglik_repeat; |   bool loglik_repeat; | ||||||
|  |  | ||||||
|   /* parameters for smMatch */ |   /* parameters for smMatch */ | ||||||
|   ld match_a = 1, match_b = 1; |   ld match_a = 1, match_b = 0; | ||||||
|  |  | ||||||
|   /* parameters for smLogistic */ |   /* parameters for smLogistic */ | ||||||
|   dhrg::logistic lgsag(1, 1); |   dhrg::logistic lgsag(1, 1); | ||||||
| @@ -102,8 +103,6 @@ namespace sag { | |||||||
|   dhrg::logistic lgemb(1, 1); |   dhrg::logistic lgemb(1, 1); | ||||||
|   vector<hyperpoint> placement;   |   vector<hyperpoint> placement;   | ||||||
|  |  | ||||||
|   void optimize_sag_loglik(); |  | ||||||
|  |  | ||||||
|   string distance_file; |   string distance_file; | ||||||
|  |  | ||||||
|   void compute_dists() { |   void compute_dists() { | ||||||
| @@ -256,7 +255,9 @@ namespace sag { | |||||||
|         edgeinfo *ei = vd.edges[j].second; |         edgeinfo *ei = vd.edges[j].second; | ||||||
|         int t2 = vd.edges[j].first; |         int t2 = vd.edges[j].first; | ||||||
|         if(sagid[t2] != -1) { |         if(sagid[t2] != -1) { | ||||||
|           ld dist = sagdist[sid][sagid[t2]] - match_a / ei->weight2 - match_b; |           ld cdist = sagdist[sid][sagid[t2]]; | ||||||
|  |           ld expect = match_a / ei->weight2 + match_b; | ||||||
|  |           ld dist = cdist - expect; | ||||||
|           cost += dist * dist; |           cost += dist * dist; | ||||||
|           } |           } | ||||||
|         } |         } | ||||||
| @@ -516,7 +517,7 @@ namespace sag { | |||||||
|     if(method == smLogistic) compute_loglik_tab(); |     if(method == smLogistic) compute_loglik_tab(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   void optimize_sag_loglik() { |   void optimize_sag_loglik_logistic() { | ||||||
|     vector<int> indist(max_sag_dist, 0); |     vector<int> indist(max_sag_dist, 0); | ||||||
|      |      | ||||||
|     const int mul = 1; |     const int mul = 1; | ||||||
| @@ -575,6 +576,29 @@ namespace sag { | |||||||
|       } |       } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |   void optimize_sag_loglik_match() { | ||||||
|  |     lsq::leastsquare_solver<2> lsqs; | ||||||
|  |  | ||||||
|  |     for(auto& ei: sagedges) { | ||||||
|  |       ld y = sagdist[sagid[ei.i]][sagid[ei.j]]; | ||||||
|  |       ld x = 1. / ei.weight; | ||||||
|  |       lsqs.add_data({{x, 1}}, y); | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |     array<ld, 2> solution = lsqs.solve(); | ||||||
|  |     match_a = solution[0]; | ||||||
|  |     match_b = solution[1]; | ||||||
|  |  | ||||||
|  |     println(hlog, "got a = ", match_a, " b = ", match_b); | ||||||
|  |     if(method == smMatch) | ||||||
|  |       prepare_graph(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |   void optimize_sag_loglik_auto() { | ||||||
|  |     if(method == smLogistic) optimize_sag_loglik_logistic(); | ||||||
|  |     if(method == smMatch) optimize_sag_loglik_match(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|   void disttable_add(ld dist, int qty0, int qty1) { |   void disttable_add(ld dist, int qty0, int qty1) { | ||||||
|     using namespace dhrg; |     using namespace dhrg; | ||||||
|     size_t i = dist * llcont_approx_prec; |     size_t i = dist * llcont_approx_prec; | ||||||
| @@ -1149,6 +1173,12 @@ int readArgs() { | |||||||
|     if(method == smLogistic) compute_loglik_tab(); |     if(method == smLogistic) compute_loglik_tab(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |   else if(argis("-sagmatch-ab")) { | ||||||
|  |     shift(); sag::match_a = argf(); | ||||||
|  |     shift(); sag::match_b = argf(); | ||||||
|  |     if(method == smMatch) prepare_graph(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|   else if(argis("-sagrt-auto")) { |   else if(argis("-sagrt-auto")) { | ||||||
|     compute_auto_rt(); |     compute_auto_rt(); | ||||||
|     } |     } | ||||||
| @@ -1162,6 +1192,7 @@ int readArgs() { | |||||||
|     if(mtd == 4) method = smMatch, loglik_repeat = true; |     if(mtd == 4) method = smMatch, loglik_repeat = true; | ||||||
|     if(method == smLogistic) |     if(method == smLogistic) | ||||||
|       compute_loglik_tab(); |       compute_loglik_tab(); | ||||||
|  |     if(method == smMatch) prepare_graph(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   else if(argis("-sagminhelp")) { |   else if(argis("-sagminhelp")) { | ||||||
| @@ -1239,8 +1270,14 @@ int readArgs() { | |||||||
|     PHASE(3); shift(); auto_save = args(); |     PHASE(3); shift(); auto_save = args(); | ||||||
|     } |     } | ||||||
| // (6) output loglikelihood | // (6) output loglikelihood | ||||||
|   else if(argis("-sagloglik")) { |   else if(argis("-sagloglik-l")) { | ||||||
|     sag::optimize_sag_loglik(); |     sag::optimize_sag_loglik_logistic(); | ||||||
|  |     } | ||||||
|  |   else if(argis("-sagloglik-m")) { | ||||||
|  |     sag::optimize_sag_loglik_match(); | ||||||
|  |     } | ||||||
|  |   else if(argis("-sagloglik-a")) { | ||||||
|  |     sag::optimize_sag_loglik_auto(); | ||||||
|     } |     } | ||||||
|   else if(argis("-sagmode")) { |   else if(argis("-sagmode")) { | ||||||
|     shift(); |     shift(); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Zeno Rogue
					Zeno Rogue