// Mirror of https://github.com/zenorogue/hyperrogue.git (synced 2024-11-04 21:26:17 +00:00)
// File: hyperrogue/devmods/knuth-bendix.cpp (last changed 2022-07-12 14:46:31 +02:00)
// 595 lines, 16 KiB, C++

#include "../hyper.h"
#include <semaphore.h>
#include <cstring>
#include <fstream>
#include <utility>
namespace hr {
/* Forward declaration of rulegen::set_general, which is defined outside this file.
 * test_all() below calls it with a file name and only proceeds when it returns true. */
namespace rulegen {
bool set_general(const string& s);
}
/** One letter of a rewriting word.
 *
 *  trace() checks `sid` against the shape id of the current cell and then
 *  rotates the walker by `eid` before stepping, so a symbol reads as
 *  "on a tile of shape sid, leave through edge eid".
 *
 *  Symbols are compared through their two bytes packed into an int16_t in
 *  memory order. This reproduces the ordering of the former reference cast,
 *  but goes through memcpy so that no object is read through an lvalue of an
 *  unrelated type.
 */
struct symbol {
  unsigned char sid;
  unsigned char eid;

  /** The raw two-byte representation reinterpreted as a signed 16-bit integer. */
  int16_t packed() const {
    static_assert(sizeof(symbol) == sizeof(int16_t), "symbol must occupy exactly two bytes");
    int16_t value;
    std::memcpy(&value, this, sizeof value);
    return value;
    }

  bool operator < (const symbol s1) const { return packed() < s1.packed(); }
  bool operator == (const symbol s1) const { return packed() == s1.packed(); }
  bool operator != (const symbol s1) const { return packed() != s1.packed(); }
  };
/* A word: a sequence of symbols. Both sides of every rewriting rule are of this type. */
using seq = vector<symbol>;
/** Flip the top bit of `s.sid`; is_blocked() tests this same bit. */
void block(symbol& s) {
  s.sid = s.sid ^ 0x80;
  }
/** Flip the top bit of `s.sid` again; the operation is the same toggle as block(). */
void unblock(symbol& s) {
  s.sid = s.sid ^ 0x80;
  }
/** True when the top bit (value 128) of `s.sid` is set. */
bool is_blocked(symbol& s) {
  const int marker = s.sid & 0x80;
  return marker != 0;
  }
/** Write a symbol to `hs` as "[sid:eid]", both fields in decimal. */
void print(hstream& hs, symbol s) {
  const int shape_id = s.sid;
  const int edge_id = s.eid;
  print(hs, "[", shape_id, ":", edge_id, "]");
  }
/* Selects the tie-breaking mode of lexless() for words of equal length; set by the -kb-rev option. */
bool reversed;
/** Word ordering used to orient rules and to key the rule map.
 *
 *  Shorter words always come first. Words of equal length are compared
 *  - with the ordinary lexicographic vector comparison when `reversed` is off;
 *  - otherwise from the last symbol towards the first: the first position at
 *    which a's symbol has eid < 120 and the two symbols differ decides, and if
 *    there is no such position the last symbols are compared.
 *
 *  Two empty words are never "less" than each other; the explicit check keeps
 *  the reversed branch from calling back() on an empty vector.
 */
bool lexless(const seq& a, const seq& b) {
  int al = isize(a);
  int bl = isize(b);
  if(al != bl) return al < bl;
  if(!reversed) return a < b;
  if(al == 0) return false;
  for(int i=al-1; i>=0; i--)
    if(a[i].eid < 120 && a[i] != b[i])
      return a[i] < b[i];
  return a.back() < b.back();
  }
/** Function object that forwards to lexless(), so it can serve as a map ordering. */
struct lexless_comparator {
  bool operator()(const seq& a, const seq& b) const {
    const bool a_first = lexless(a, b);
    return a_first;
    }
  };
/* The current rule set: maps a rule's left-hand word to its right-hand word, keyed in lexless() order. */
map<seq, seq, lexless_comparator> rules;
/** Copy of the `l` symbols of `s` starting at index `a` (empty when `l` is not positive). */
seq sub(const seq& s, int a, int l) {
  seq piece;
  const int stop = a + l;
  for(int idx = a; idx < stop; ++idx)
    piece.push_back(s[idx]);
  return piece;
  }
/** Concatenation of `a` followed by `b`, returned as a new word. */
seq cat(const seq& a, const seq& b) {
  seq joined = a;
  joined.insert(joined.end(), b.begin(), b.end());
  return joined;
  }
/** Append every symbol of `b` to `a` in place. `a` and `b` are expected to be distinct objects. */
void catto(seq& a, const seq& b) {
  const auto needed = a.size() + b.size();
  a.reserve(needed);
  a.insert(a.end(), b.begin(), b.end());
  }
/** Append to `a` the `len` symbols of `b` that start at index `pos`. */
void catto_sub(seq& a, const seq& b, int pos, int len) {
  a.reserve(a.size() + len);
  const int stop = pos + len;
  for(int idx = pos; idx < stop; ++idx)
    a.push_back(b[idx]);
  }
/* Largest combined length (left + right) of any pair processed by main_loop() so far. */
int max_rule_length;
/* Oriented pairs that main_loop() has already taken up, so each is handled only once. */
set<pair<seq, seq>> rules_checked;
/* Pending pairs (candidate rules) waiting to be processed by main_loop(). */
queue<pair<seq, seq>> rulequeue;
/** Append the pair (l, r) to `rulequeue` for later processing by main_loop().
 *
 *  The words are taken by value and moved into the queue, so a caller passing
 *  temporaries pays for no copy and a caller passing lvalues pays for one.
 */
void enqueue_rule(seq l, seq r) {
  rulequeue.emplace(std::move(l), std::move(r));
  }
/** Walk the word `sq` starting from `cw` and return the final walker.
 *
 *  For each symbol: report "error: wrong type" if the current cell's shape id
 *  differs from the symbol's sid, turn by eid, step across the edge, then
 *  reduce the spin by its remainder modulo the new shape's cycle_length.
 *  With `debug` set, the walker is printed before each step.
 */
cellwalker trace(cellwalker cw, const seq& sq, bool debug) {
  for(symbol step: sq) {
    const bool shape_matches = shvid(cw.at) == step.sid;
    if(!shape_matches) println(hlog, "error: wrong type");
    if(debug) println(hlog, "at: ", cw);
    cw += step.eid;
    cw += wstep;
    const auto& landed = arb::current.shapes[shvid(cw.at)];
    cw -= cw.spin % landed.cycle_length;
    }
  return cw;
  }
/** Return a cell whose shape id equals `sid`, searching the cells listed by
 *  a celllister started at currentmap->gamestart(). If none is found, a
 *  message is printed and the process exits with status 1.
 */
cell *get_sample(int sid) {
  celllister nearby(currentmap->gamestart(), 20, 100000, nullptr);
  for(cell *candidate: nearby.lst) {
    if(shvid(candidate) != sid) continue;
    return candidate;
    }
  println(hlog, "sample not found of ", sid);
  exit(1);
  }
/** Check a rule geometrically: trace both sides from a sample cell of the
 *  shape named by the first symbol of the left side.
 *  Returns "OK" if both walks end at the same walker, otherwise a "FAIL ..."
 *  string naming the two end points. The left side must be non-empty.
 */
string verify_rule(const pair<seq, seq>& rule, bool debug = false) {
  const seq& lhs = rule.first;
  const seq& rhs = rule.second;
  cellwalker origin(get_sample(lhs[0].sid));
  cellwalker via_left = trace(origin, lhs, debug);
  cellwalker via_right = trace(origin, rhs, debug);
  if(via_left == via_right) return "OK";
  return lalign(0, "FAIL ", via_left, " VS ", via_right);
  }
/** Insert the equation l = r into `rules` as an oriented rule.
 *
 *  The pair is first oriented so that l is the lexless-greater word. If a rule
 *  with the same left side already exists (l -> old), the new information is
 *  an equation between r and old, which is oriented and processed in turn:
 *  - r before old: continue with old -> r;
 *  - old before r: continue with r -> old.
 *  In both cases the new left side is strictly smaller than the previous one.
 *
 *  The second case used to assign r to l without replacing r, so the next
 *  iteration ran with l == r and could store an identity rule.
 *
 *  @return 1 if a rule was inserted, 0 if nothing new was learned.
 */
int add_rule(seq l, seq r) {
  indenter ind(2);
  if(l == r) return 0;
  if(lexless(l, r)) swap(l, r);
  while(true) {
    auto found = rules.find(l);
    if(found == rules.end()) {
      // println(hlog, "rule added: ", l, " -> ", r, " verify: ", verify_rule({l, r}));
      rules[l] = r;
      return 1;
      }
    const seq& old = found->second;
    if(old == r) return 0;
    if(lexless(r, old)) {
      /* l -> old and l = r with r smaller: derive old -> r */
      l = old;
      }
    else {
      /* l -> old and l = r with old smaller: derive r -> old */
      // println(hlog, "rule simplified: ", l, " -> ", r);
      l = r;
      r = old;
      }
    }
  }
/** Build a symbol from integer ids; each value is converted to unsigned char. */
symbol gen_symbol(int sid, int eid) {
  symbol made = { static_cast<unsigned char>(sid), static_cast<unsigned char>(eid) };
  return made;
  }
/* Chooses which families of initial rules prepare_rules() generates (see the conditions there). */
bool include_rotations = false;
/* Seed `rulequeue` with the initial equations derived from arb::current.
 * Every equation is handed to enqueue_rule(); `rules` itself is not modified here.
 * NOTE(review): edge ids of 120 and above and the symbol [255:255] appear to
 * encode rotations rather than moves -- confirm against the rest of the project. */
void prepare_rules() {
auto& ac = arb::current;
int N = isize(ac.shapes);
/* move and back and move */
/* Without rotations: for each shape i and each pair of edges (k, l), the word
 * [i:k] [connection of k] [i:l] is equated with the single symbol [i:(k1+l)%K],
 * where k1 is k rounded down to a multiple of cycle_length. */
if(!include_rotations) for(int i=0; i<N; i++) {
auto& sh = ac.shapes[i];
int K = sh.size();
for(int k=0; k<K; k++)
for(int l=0; l<K; l++) {
seq sq, sq1;
sq.push_back(gen_symbol(i, k));
/* connections are looked up modulo cycle_length; con is a copy, so the %= below is local */
auto con = sh.connections[k % sh.cycle_length];
auto& sh1 = ac.shapes[con.sid];
con.eid %= sh1.cycle_length;
sq.push_back(gen_symbol(con.sid, con.eid));
sq.push_back(gen_symbol(i, l));
int k1 = k / sh.cycle_length * sh.cycle_length;
// if(k1)
sq1.push_back(gen_symbol(i, (k1 + l) % K));
enqueue_rule(sq, sq1);
}
}
/* move and back -> nothing or rotate */
/* With rotations: [i:k] [connection of k] is equated with the empty word when
 * k1 == 0, and with [i:120+k1] [255:255] otherwise. */
if(include_rotations) for(int i=0; i<N; i++) {
auto& sh = ac.shapes[i];
int K = sh.size();
for(int k=0; k<K; k++) {
seq sq, sq1;
sq.push_back(gen_symbol(i, k));
auto con = sh.connections[k % sh.cycle_length];
auto& sh1 = ac.shapes[con.sid];
con.eid %= sh1.cycle_length;
sq.push_back(gen_symbol(con.sid, con.eid));
int k1 = k / sh.cycle_length * sh.cycle_length;
if(k1) sq1.push_back(gen_symbol(i, 120 + k1));
if(k1) sq1.push_back(gen_symbol(255, 255));
/* pass the lexless-greater word first */
if(lexless(sq, sq1)) swap(sq, sq1);
enqueue_rule(sq, sq1);
}
}
/* rotate + move -> move */
/* With rotations: for every positive multiple k of cycle_length below K and every edge l,
 * [i:120+k] [255:255] [i:l] is equated with [i:(k+l)%K]. */
if(include_rotations) for(int i=0; i<N; i++) {
auto& sh = ac.shapes[i];
int K = sh.size();
int C = sh.cycle_length;
for(int k=C; k<K; k+=C)
for(int l=0; l<K; l++) {
seq sq, sq1;
sq.push_back(gen_symbol(i, 120+k));
sq.push_back(gen_symbol(255, 255));
sq.push_back(gen_symbol(i, l));
sq1.push_back(gen_symbol(i, (k+l) % K));
if(lexless(sq, sq1)) swap(sq, sq1);
enqueue_rule(sq, sq1);
}
}
/* cycle a vertex */
/* For each shape i and each j below its cycle_length: take vertex_valence[j] steps, each one
 * advancing the edge index by one, recording the symbol and following the connection.
 * The resulting word is equated with the empty word. */
for(int i=0; i<N; i++) {
auto& sh = ac.shapes[i];
for(int j=0; j<sh.cycle_length; j++) {
int val = sh.vertex_valence[j]; // (j+1) % sh.cycle_length];
int ai = i, aj = j;
seq sq;
// cellwalker cw(get_sample(i), j);
// auto cw0 = cw;
for(int v=0; v<val; v++) {
// println(hlog, "at: ", cw);
/*
cw++;
cw += wstep;
*/
auto& ash = ac.shapes[ai];
aj += 1;
aj %= ash.size();
sq.push_back(gen_symbol(ai, aj));
auto co = ash.connections[aj % ash.cycle_length];
ai = co.sid;
aj = co.eid;
aj %= ac.shapes[ai].cycle_length;
}
// println(hlog, "finish at: ", cw, " from: ", cw0);
aj %= sh.cycle_length;
/* the walk must end where it started; otherwise the whole function is abandoned here,
 * skipping the remaining shapes and the printing loop below */
if(i != ai || j != aj /* || cw != cw0 */) { println(hlog, "bad cycling!"); return; }
enqueue_rule(sq, {});
}
}
/* print whatever is currently stored in `rules` */
for(auto r: rules) println(hlog, r);
}
void final_rules() {
auto& ac = arb::current;
int N = isize(ac.shapes);
for(int i=0; i<N; i++) {
auto& sh = ac.shapes[i];
int K = sh.size();
int C = sh.cycle_length;
for(int k=0; k<K; k++)
for(int l=0; l<K; l+=C) {
seq sq, sq1;
sq.push_back(gen_symbol(i, k));
auto con = sh.connections[k % sh.cycle_length];
auto& sh1 = ac.shapes[con.sid];
con.eid %= sh1.cycle_length;
sq.push_back(gen_symbol(con.sid, con.eid));
sq.push_back(gen_symbol(i, 120 + l));
int k1 = k / sh.cycle_length * sh.cycle_length;
// if(k1)
sq1.push_back(gen_symbol(i, 120 + (k1 + l) % K));
enqueue_rule(sq, sq1);
}
}
}
/* true iff `needle` occurs in `s` starting exactly at index `pos`
 * (the caller guarantees that pos + size of needle does not exceed the size of s) */
bool sub_at(const seq& s, int pos, const seq& needle) {
  const int wanted = isize(needle);
  int matched = 0;
  while(matched < wanted && s[pos + matched] == needle[matched])
    matched++;
  return matched == wanted;
  }
/* true iff the last `len` symbols of s1 equal the first `len` symbols of s2 */
bool suf_pref_agree(const seq& s1, const seq& s2, int len) {
  const int offset = isize(s1) - len;
  int matched = 0;
  while(matched < len && s1[offset + matched] == s2[matched])
    matched++;
  return matched >= len;
  }
bool is_reducible(const seq& s, const pair<seq, seq>& r, seq& result) {
int rl = isize(r.first);
int sl = isize(s);
for(int i=0; i<=sl-rl; i++)
if(sub_at(s, i, r.first)) {
result = sub(s, 0, i);
catto(result, r.second);
catto_sub(result, s, i+rl, sl-rl-i);
return true;
}
return false;
}
void find_critical(const pair<seq,seq>& p, const pair<seq,seq>& q) {
int pl = isize(p.first);
int ql = isize(q.first);
for(int i=1; i<pl && i < ql; i++)
if(suf_pref_agree(p.first, q.first, i)) {
seq nleft = sub(p.first, 0, pl-i);
catto(nleft, q.second);
seq nright = p.second;
catto_sub(nright, q.first, i, ql-i);
enqueue_rule(nleft, nright);
}
}
/** Process one oriented pair `nr` (left side nr.first, right side nr.second).
 *
 *  1. If either side of `nr` can be rewritten by an existing rule, the
 *     rewritten pair is enqueued instead and nothing else happens.
 *  2. Otherwise existing rules are reduced by `nr`:
 *     - a rule whose left side contains nr.first is removed and the rewritten
 *       equation is enqueued;
 *     - a rule whose right side contains nr.first is removed and re-enqueued
 *       with the rewritten right side.
 *     (The second test previously re-checked nr.second against the existing
 *     rule, which step 1 has already ruled out, so right sides were never reduced.)
 *  3. Critical pairs between `nr` and every remaining rule are enqueued, in both orders.
 *  4. `nr` is stored in `rules`.
 */
void handle_rule(const pair<seq, seq>& nr) {
  auto& lh = nr.first;
  auto& rh = nr.second;

  /* step 1: is the new pair itself reducible? */
  for(auto& r: rules) {
    seq res;
    if(is_reducible(lh, r, res)) {
      enqueue_rule(res, rh);
      return;
      }
    if(is_reducible(rh, r, res)) {
      enqueue_rule(lh, res);
      return;
      }
    }

  /* step 2: reduce existing rules by the new one; erasing is deferred until the iteration is over */
  vector<seq> to_erase;
  for(auto& r: rules) {
    seq res;
    if(is_reducible(r.first, nr, res)) {
      to_erase.push_back(r.first);
      enqueue_rule(res, r.second);
      }
    if(is_reducible(r.second, nr, res)) {
      to_erase.push_back(r.first);
      enqueue_rule(r.first, res);
      }
    }
  for(const auto& s: to_erase) rules.erase(s);

  /* step 3: overlaps with the surviving rules */
  for(auto& r: rules) {
    find_critical(r, nr);
    find_critical(nr, r);
    }

  /* step 4 */
  rules[nr.first] = nr.second;
  }
/* Outcome of the last test_knuth_bendix() run: 0 on success; 1 or 2 are main_loop()'s
 * codes from the first phase, 3 or 4 the same codes from the second phase (shifted by 2). */
int kb_result;
/* Processor time of the last test_knuth_bendix() run in seconds, measured with clock(). */
double total_time;
int main_loop(int timelimit = 5, int lenlimit = 300) {
clock_t start = clock();
while(!rulequeue.empty()) {
if(clock() > start + timelimit * CLOCKS_PER_SEC) return 1;
auto p = rulequeue.front();
rulequeue.pop();
if(p.first == p.second) continue;
if(lexless(p.first, p.second)) swap(p.first, p.second);
if(rules_checked.count(p)) continue;
rules_checked.insert(p);
int len = p.first.size() + p.second.size();
if(len > max_rule_length) max_rule_length = len;
if(len > lenlimit) return 2;
handle_rule(p);
}
println(hlog, "finished in ", (clock() - start) * 1. / CLOCKS_PER_SEC);
return 0;
}
/** Run the whole completion on the current game.
 *
 *  Resets all state, starts the game, seeds the queue with prepare_rules() and
 *  runs main_loop(10). If that succeeds, the intermediate rules are printed,
 *  final_rules() adds the second-phase equations and main_loop(50, 999999)
 *  runs; a failure code from this second run is stored increased by 2.
 *  The outcome goes to `kb_result`, the elapsed clock() time to `total_time`.
 */
void test_knuth_bendix() {
  /* forget everything from a previous run */
  max_rule_length = 0;
  rules_checked.clear();
  rulequeue = {};
  rules.clear();

  start_game();
  prepare_rules();

  const clock_t total_start = clock();
  println(hlog, "total_start = ", total_start * 1. / CLOCKS_PER_SEC);

  kb_result = main_loop(10);
  println(hlog, "after first phase = ", clock() * 1. / CLOCKS_PER_SEC);

  if(kb_result == 0) {
    println(hlog, "intermediate rules:");
    for(const auto& rule: rules)
      println(hlog, rule.first, " => ", rule.second);
    final_rules();
    const int second_phase = main_loop(50, 999999);
    kb_result = second_phase ? second_phase + 2 : 0;
    }

  println(hlog, "after second phase = ", clock() * 1. / CLOCKS_PER_SEC);
  total_time = (clock() - total_start) * 1. / CLOCKS_PER_SEC;
  }
bool is_end_reducible(const seq& s1) {
for(auto r: rules)
if(isize(s1) >= isize(r.first) && sub(s1, isize(s1)-isize(r.first), isize(r.first)) == r.first)
return true;
return false;
}
int count_tree_states() {
set<seq> all_prefixes;
for(auto r: rules) {
auto left = r.first;
for(int i=0; i<isize(left); i++) {
seq prefix = sub(left, 0, i);
seq s = prefix;
s.emplace_back(gen_symbol(left[i].eid, 120));
if(!is_end_reducible(s))
all_prefixes.emplace(prefix);
}
}
all_prefixes.erase(seq{});
for(auto x: all_prefixes)
println(hlog, x);
return isize(all_prefixes) + isize(arb::current.shapes);
}
/* When set, test_all() runs each test file in its own fork()ed child process. */
bool forked = false;
/* Upper bound on simultaneously running children in test_all(); set by -kb-forked. */
int max_children = 7;
/** Run test_knuth_bendix() on every file listed in devmods/rulegen-tests/<setname>.lst
 *  and print one "CSV;..." line per file.
 *
 *  In the list file, leading spaces are stripped; empty lines and lines starting
 *  with '#' are skipped.
 *
 *  With `forked` set, each file is handled in its own child process and at most
 *  `max_children` children run at a time. A child always terminates with exit()
 *  once it is done with its file -- also when the conversion fails -- so it never
 *  returns into the caller's code. A failed fork() is reported and the file is
 *  skipped without being counted as a running child.
 *
 *  Without `forked`, a conversion failure ends the whole batch (behaviour kept).
 *  NOTE(review): skipping just that file may have been intended -- confirm.
 *
 *  NOTE(review): `sem` is a local variable, while POSIX expects a process-shared
 *  semaphore to live in memory shared by the processes; after fork() each child
 *  presumably works on its own copy, so the output lines may not be serialised.
 */
void test_all(string setname) {
  floorshapes_level = 1;

  vector<string> filenames;
  std::ifstream is("devmods/rulegen-tests/"+setname+".lst");
  string s;
  while(getline(is, s)) {
    while(s != "" && s[0] == ' ') s = s.substr(1);
    if(s != "" && s[0] != '#') filenames.push_back(s);
    }

  println(hlog, "CSV;kbres;Tp;kbrules;tree;maxlen;file");

  int children = 0;
  sem_t sem;
  if(forked) sem_init(&sem, true, 1);
  /* flush before forking so buffered output is not duplicated into the children */
  fflush(stdout);

  for(const string& fname: filenames) {
    if(forked) {
      if(children >= max_children) {
        int status;
        wait(&status); children--;
        }
      int pid = fork();
      if(pid < 0) {
        println(hlog, "CSV; failed to fork for ", fname);
        continue;
        }
      if(pid > 0) {
        /* parent: account for the child and move on to the next file */
        children++;
        continue;
        }
      /* child: fall through and process this file */
      }

    if(rulegen::set_general(fname)) {
      if(!arb::in()) try {
        arb::convert::convert();
        arb::convert::activate();
        }
      catch(hr_exception& e) {
        println(hlog, "CSV; failed to convert ", fname);
        if(forked) exit(1);
        return;
        }
      // println(hlog, "will call test_knuth_bendix");
      test_knuth_bendix();
      // println(hlog, "after test_knuth_bendix");
      if(forked) sem_wait(&sem);
      println(hlog, "CSV;",kb_result,";", total_time, ";", isize(rules), ";", kb_result == 0 ? count_tree_states() : 0, ";", max_rule_length, ";", fname);
      if(forked) sem_post(&sem);
      fflush(stdout);
      }
    if(forked) exit(0);
    }

  /* only the parent gets here in forked mode: collect the remaining children */
  while(children) { int status; wait(&status); children--; }
  if(forked) sem_destroy(&sem);
  }
void kb_merge() {
vector<string> filenames;
std::ifstream is("devmods/rulegen-tests/all.lst");
string s;
while(getline(is, s)) {
while(s != "" && s[0] == ' ') s = s.substr(1);
if(s != "" && s[0] != '#') filenames.push_back(s);
}
map<string, string> cas;
if(std::ifstream is = std::ifstream("devmods/rulegen-tests/kbmerge.txt")) {;
while(getline(is, s)) {
int i = 0, j = 0;
for(char c: s) { i++; if(c == ';') j = i; }
string fname = s.substr(j);
if(cas.count(fname) && cas[fname].substr(0,5) != s.substr(0,5)) {
// println(hlog, "repetition:\n", cas[fname], "\n", s);
s = min(s, cas[fname]);
}
cas[fname] = s;
// println(hlog, "assigned to ", fname, ": ", s);
}
}
println(hlog, "CSV;kbres;Tp;kbrules;tree;maxlen;file");
for(auto ss: filenames)
if(!cas.count(ss))
println(hlog, "MISSING: ", ss);
else
println(hlog, cas[ss]);
}
/* Contents of first_seq from the previous kb_marker() call (moved there at the start of each call). */
map<cell*, seq> first_seq_last;
/* For each cell reached by view_kb_tree(): the smallest word (vector ordering) that reached it in this call. */
map<cell*, seq> first_seq;
/* For each cell: how many times view_kb_tree() visited it during the current kb_marker() call. */
map<cell*, int> qty;
void view_kb_tree(const shiftmatrix& V, cellwalker cw, seq s) {
qty[cw.at]++;
if(cw.at == lmouseover) println(hlog, s, " ; ", cw.spin);
for(int i=0; i<cw.at->type; i++) {
seq s1 = s;
s1.emplace_back(gen_symbol(shvid(cw.at), i));
if(is_end_reducible(s1)) continue;
auto cw1 = cw;
cw1 += i;
int is = cw1.spin;
cw1 += wstep;
if(!gmatrix.count(cw1.at)) continue;
transmatrix A = currentmap->adj(cw.at, is);
int eid = shvid(cw1.at);
int clen = arb::current.shapes[eid].cycle_length;
cw1 -= cw1.spin % clen;
// queueline(V * get_corner_position(cw.at, cw.spin, 6), V * A * get_corner_position(cw1.at, cw1.spin, 6), 0xC00000FF);
seq s2 = s1;
s2.emplace_back(gen_symbol(shvid(cw1.at), 120));
if(!is_end_reducible(s2)) { // first_seq_last[cw1.at] == s1) {
if(cw1.at == lmouseover) println(hlog, "is not end_reducible: ", s2);
vid.linewidth *= 3;
queueline(V * C0, V * mid(tC0(A), mid(C0, tC0(A))), 0xFFC0C0FF, 3);
vid.linewidth /= 3;
}
if(!first_seq.count(cw1.at))
first_seq[cw1.at] = s1;
else
first_seq[cw1.at] = min(s1, first_seq[cw1.at]);
view_kb_tree(V*A, cw1, s1);
}
}
void kb_marker() {
qty.clear();
first_seq_last = std::move(first_seq);
first_seq.clear();
view_kb_tree(gmatrix[cwt.at], cellwalker(cwt.at, 0), {});
for(auto p: qty)
queuestr(gmatrix[p.first], .5, its(p.second), 0xFFFFFFFF, 1);
}
/** Register kb_marker() in `hooks_markers` with priority 100. */
void add_kb_view() {
  const int priority = 100;
  addHook(hooks_markers, priority, kb_marker);
  }
/* Command-line options of this module, registered during static initialisation.
 * The value of `u` itself is not used anywhere in this file.
 *   -kb           run test_knuth_bendix() and print the result code, the rules and,
 *                 on success, the number of tree states
 *   -kb-test-all  run test_all() on the set named by the following argument
 *   -kb-forked    read an integer into max_children; `forked` becomes true iff it is nonzero
 *   -kb-rev       read an integer into `reversed`
 *   -kb-view      install the kb_marker hook
 *   -kb-merge     run kb_merge()
 * NOTE(review): arg::shift() presumably advances to the option's parameter -- confirm. */
int u = arg::add3("-kb", [] {
test_knuth_bendix();
println(hlog, "result: ", kb_result, " rules: ", isize(rules));
println(hlog, "final rules:");
for(auto s: rules) println(hlog, s.first, " => ", s.second);
if(kb_result == 0) println(hlog, "tree states: ", count_tree_states());
})
+ arg::add3("-kb-test-all", [] { arg::shift(); test_all(arg::args()); })
+ arg::add3("-kb-forked", [] { arg::shift(); max_children = arg::argi(); forked = max_children; })
+ arg::add3("-kb-rev", [] { arg::shift(); reversed = arg::argi(); })
+ arg::add3("-kb-view", add_kb_view)
+ arg::add3("-kb-merge", kb_merge);
}