1
0
mirror of https://github.com/zenorogue/hyperrogue.git synced 2025-01-12 02:10:34 +00:00

further improvements to Kohonen

This commit is contained in:
Zeno Rogue 2017-09-01 22:13:41 +02:00
parent b0489b7312
commit 951566399e
3 changed files with 332 additions and 47 deletions

View File

@ -1524,6 +1524,7 @@ void addauraspecial(const hyperpoint& h, int col, int dir);
void drawBug(const cellwalker& cw, int col); void drawBug(const cellwalker& cw, int col);
void mainloop(); void mainloop();
void mainloopiter();
extern bool showstartmenu; extern bool showstartmenu;
void selectLanguageScreen(); void selectLanguageScreen();

View File

@ -30,6 +30,8 @@ struct neuron {
int samples, csample, bestsample; int samples, csample, bestsample;
}; };
vector<string> colnames;
kohvec weights; kohvec weights;
vector<neuron> net; vector<neuron> net;
@ -95,6 +97,8 @@ void loadsamples(const char *fname) {
fclose(f); fclose(f);
samples = size(data); samples = size(data);
normalize(); normalize();
colnames.resize(cols);
for(int i=0; i<cols; i++) colnames[i] = "Column " + its(i);
uninit(0); sominit(1); uninit(0); sominit(1);
} }
@ -178,7 +182,7 @@ void coloring() {
vector<double> listing; vector<double> listing;
for(neuron& n: net) switch(c) { for(neuron& n: net) switch(c) {
case -4: case -4:
listing.push_back(n.samples); listing.push_back(log(5+n.samples));
break; break;
case -3: case -3:
@ -477,6 +481,42 @@ void uninit(int initto) {
if(inited > initto) inited = initto; if(inited > initto) inited = initto;
} }
void showsample(int id) {
for(int ii: samples_shown)
if(ii == id)
return;
int i = vdata.size();
samples_shown.push_back(id);
vdata.emplace_back();
auto& v = vdata.back();
v.name = data[id].name;
v.cp = dftcolor;
createViz(i, cwt.c, Id);
v.m->store();
}
void showsample(string s) {
if(s == "") return;
for(int i=0; i<samples; i++) {
if(s[0] != '*' && data[i].name == s)
showsample(i);
if(s[0] == '*' && data[i].name.find(s.substr(1)) != string::npos)
showsample(i);
}
}
void showbestsamples() {
vector<int> samplesbak;
for(auto& n: net)
if(n.samples)
showsample(n.bestsample);
analyze();
for(auto& n: net) n.samples = 0;
for(int i=0; i<samples; i++)
if(whowon[i])
whowon[i]->samples++;
}
void sominit(int initto) { void sominit(int initto) {
if(inited < 1 && initto >= 1) { if(inited < 1 && initto >= 1) {
@ -563,19 +603,151 @@ void describe(cell *c) {
if(cmode & sm::HELP) return; if(cmode & sm::HELP) return;
neuron *n = getNeuronSlow(c); neuron *n = getNeuronSlow(c);
if(!n) return; if(!n) return;
help += "cell number: " + its(neuronId(*n)) + "\n"; help += "cell number: " + its(neuronId(*n)) + " (" + its(n->samples) + ")\n";
help += "parameters:"; for(int k=0; k<cols; k++) help += " " + fts(n->net[k]); help += "parameters:"; for(int k=0; k<cols; k++) help += " " + fts(n->net[k]);
help += ", u-matrix = " + fts(n->udist); help += ", u-matrix = " + fts(n->udist);
help += "\n"; help += "\n";
int qty = 0; vector<pair<double, int>> v;
for(int s=0; s<samples; s++) if(whowon[s] == n) { for(int s=0; s<samples; s++) if(whowon[s] == n) v.emplace_back(vnorm(n->net, data[s].val), s);
random_shuffle(v.begin(), v.end());
sort(v.begin(), v.end(), [] (auto a, auto b) { return a.first < b.first; });
for(int i=0; i<size(v) && i<20; i++) {
int s = v[i].second;
help += "sample "+its(s)+":"; help += "sample "+its(s)+":";
for(int k=0; k<cols; k++) help += " " + fts(data[s].val[k]); for(int k=0; k<cols; k++) help += " " + fts(data[s].val[k]);
help += " "; help += data[s].name; help += "\n"; help += " "; help += data[s].name; help += "\n";
qty++; if(qty >= 20) break;
} }
} }
namespace levelline {
struct levelline {
int column, qty;
unsigned int color;
vector<double> values;
bool modified;
};
vector<levelline> levellines;
bool on;
void create() {
int xlalpha = int(pow(ld(.5), ggamma) * 255);
for(int i=0; i<cols; i++) {
levellines.emplace_back();
levelline& lv = levellines.back();
lv.column = i;
lv.color = ((hrandpos() & 0xFFFFFF) << 8) | xlalpha;
lv.qty = 0;
}
}
void build() {
if(levellines.size() == 0) create();
on = false;
for(auto& lv: levellines) {
if(!lv.qty) { lv.values.clear(); continue; }
on = true;
if(!lv.modified) continue;
lv.modified = false;
vector<double> sample;
for(int j=0; j<=1024; j++) sample.push_back(data[hrand(samples)].val[lv.column]);
sort(sample.begin(), sample.end());
lv.values.clear();
lv.values.push_back(-1e10);
for(int j=0; j<=1024; j+=1024 >> (lv.qty)) lv.values.push_back(sample[j]);
lv.values.push_back(1e10);
}
}
void draw() {
if(!on) return;
for(auto& g: gmatrix) {
cell *c1 = g.first;
transmatrix T = g.second;
neuron *n1 = getNeuron(c1);
if(!n1) continue;
for(int i=0; i<c1->type; i++) {
cell *c2 = c1->mov[i];
if(!c2) continue;
cell *c3 = c1->mov[i ? i-1 : c1->type-1];
if(!c3) continue;
if(!gmatrix.count(c2)) continue;
if(!gmatrix.count(c3)) continue;
double d2 = hdist(tC0(T), tC0(gmatrix[c2]));
double d3 = hdist(tC0(T), tC0(gmatrix[c3]));
neuron *n2 = getNeuron(c2);
if(!n2) continue;
neuron *n3 = getNeuron(c3);
if(!n3) continue;
for(auto& l: levellines) {
auto val1 = n1->net[l.column];
auto val2 = n2->net[l.column];
auto val3 = n3->net[l.column];
auto v1 = lower_bound(l.values.begin(), l.values.end(), val1);
auto v2 = lower_bound(l.values.begin(), l.values.end(), val2);
auto v3 = lower_bound(l.values.begin(), l.values.end(), val3);
auto draw = [&] () {
auto vmid = *v1;
queueline(
tC0(T * ddspin(c1,i) * xpush(d2 * (vmid-val1) / (val2-val1))),
tC0(T * ddspin(c1,i-1) * xpush(d3 * (vmid-val1) / (val3-val1))),
l.color);
};
while(v1 < v2 && v1 < v3) {
draw();
v1++;
}
while(v1 > v2 && v1 > v3) {
v1--;
draw();
}
}
}
}
setindex(false);
}
void show() {
if(levellines.size() == 0) create();
gamescreen(0);
cmode = vid.xres > vid.yres * 1.4 ? sm::SIDE : 0;
dialog::init("level lines");
char nx = 'a';
for(auto &l : levellines) {
dialog::addSelItem(colnames[l.column], its(l.qty), nx++);
dialog::lastItem().colorv = l.color >> 8;
}
dialog::addItem("exit menu", '0');
dialog::addItem("shift+letter to change color", 0);
dialog::display();
keyhandler = [] (int sym, int uni) {
dialog::handleNavigation(sym, uni);
if(uni >= 'a' && uni - 'a' + size(levellines)) {
auto& l = levellines[uni - 'a'];
dialog::editNumber(l.qty, 0, 10, 1, 0, colnames[l.column],
XLAT("Controls the number of level lines."));
dialog::reaction = [&l] () {
l.modified = true;
build();
};
}
else if(uni >= 'A' && uni - 'A' + size(levellines)) {
auto& l = levellines[uni - 'A'];
dialog::openColorDialog(l.color, NULL);
}
else if(doexiton(sym, uni)) popScreen();
};
}
}
void ksave(const char *fname) { void ksave(const char *fname) {
sominit(1); sominit(1);
FILE *f = fopen(fname, "wt"); FILE *f = fopen(fname, "wt");
@ -586,7 +758,7 @@ void ksave(const char *fname) {
fprintf(f, "%d %d\n", cells, t); fprintf(f, "%d %d\n", cells, t);
for(neuron& n: net) { for(neuron& n: net) {
for(int k=0; k<cols; k++) for(int k=0; k<cols; k++)
fprintf(f, "%.4lf ", n.net[k]); fprintf(f, "%.9lf ", n.net[k]);
fprintf(f, "\n"); fprintf(f, "\n");
} }
fclose(f); fclose(f);
@ -612,26 +784,64 @@ void kload(const char *fname) {
analyze(); analyze();
} }
void kclassify(const char *fname) { void ksavew(const char *fname) {
sominit(1); sominit(1);
for(neuron& n: net) n.samples = 0;
FILE *f = fopen(fname, "wt"); FILE *f = fopen(fname, "wt");
if(!f) { if(!f) {
fprintf(stderr, "Could not save classification\n"); fprintf(stderr, "Could not save the weights\n");
return; return;
} }
for(int id=0; id<samples; id++) { for(int i=0; i<cols; i++)
auto& w = winner(id); fprintf(f, "%s=%.9lf\n", colnames[i].c_str(), weights[i]);
w.samples++;
if(id % 100000 == 0) printf("%d/%d\n", id, size(data));
fprintf(f, "%s;%d\n", data[id].name.c_str(), neuronId(w));
}
fclose(f); fclose(f);
coloring();
} }
void kclassify2(const char *fname_classify, const char *fname_samples) { void kloadw(const char *fname) {
sominit(1);
FILE *f = fopen(fname, "rt");
if(!f) {
fprintf(stderr, "Could not load the weights\n");
return;
}
for(int i=0; i<cols; i++) {
string s1, s2;
char kind = 0;
while(true) {
int c = fgetc(f);
if(c == 10 || c == 13 || c == -1) {
if(s1 == "" && !kind && c != -1) continue;
if(s1 != "") colnames[i] = s1;
if(kind == '=') weights[i] = atof(s2.c_str());
if(kind == '*') weights[i] *= atof(s2.c_str());
if(kind == '/') weights[i] /= atof(s2.c_str());
if(c == -1) break;
goto nexti;
}
else if(c == '=' || c == '/' || c == '*') kind = c;
else (kind?s2:s1) += c;
}
nexti: ;
}
fclose(f);
analyze();
}
unsigned lastprogress;
void progress(string s) {
if(SDL_GetTicks() >= lastprogress + (noGUI ? 500 : 100)) {
if(noGUI)
printf("%s\n", s.c_str());
else {
clearMessages();
addMessage(s);
mainloopiter();
}
lastprogress = SDL_GetTicks();
}
}
void kclassify(const char *fname_classify) {
sominit(1); sominit(1);
vector<double> bdiffs(samples, 1e20); vector<double> bdiffs(samples, 1e20);
@ -646,7 +856,8 @@ void kclassify2(const char *fname_classify, const char *fname_samples) {
double diff = vnorm(net[n].net, data[s].val); double diff = vnorm(net[n].net, data[s].val);
if(diff < bdiffs[s]) bdiffs[s] = diff, bids[s] = n, whowon[s] = &net[n]; if(diff < bdiffs[s]) bdiffs[s] = diff, bids[s] = n, whowon[s] = &net[n];
} }
if(s % 1000000 == 999999) printf("%d/%d\n", s, samples); if(!(s % 128))
progress("Classifying: " + its(s) + "/" + its(samples));
} }
vector<double> bdiffn(cells, 1e20); vector<double> bdiffn(cells, 1e20);
@ -673,30 +884,55 @@ void kclassify2(const char *fname_classify, const char *fname_samples) {
fclose(f); fclose(f);
} }
} }
coloring();
}
void klistsamples(const char *fname_samples, bool best, bool colorformat) {
if(fname_samples != NULL) { if(fname_samples != NULL) {
printf("Listing best samples...\n"); printf("Listing samples...\n");
FILE *f = fopen(fname_samples, "wt"); FILE *f = fopen(fname_samples, "wt");
if(!f) { if(!f) {
printf("Failed to open file\n"); printf("Failed to open file\n");
} }
else { else {
fprintf(f, "%d\n", cols); auto klistsample = [f, colorformat] (int id, int neu) {
for(int n=0; n<cells; n++) { if(colorformat) {
fflush(f); fprintf(f, "%s;+#%d\n", data[id].name.c_str(), neu);
if(!net[n].samples) { fprintf(f, "\n"); continue; } }
int s = net[n].bestsample; else {
for(int k=0; k<cols; k++) for(int k=0; k<cols; k++)
fprintf(f, "%.4lf ", data[s].val[k]); fprintf(f, "%.4lf ", data[id].val[k]);
fflush(f); fprintf(f, "!%s\n", data[id].name.c_str());
fprintf(f, "!%s\n", data[s].name.c_str()); }
fflush(f); };
} if(!colorformat) fprintf(f, "%d\n", cols);
if(best)
for(int n=0; n<cells; n++) {
if(!net[n].samples) { if(!colorformat) fprintf(f, "\n"); continue; }
klistsample(net[n].bestsample, n);
}
else
for(int i=0; i<size(samples_shown); i++)
klistsample(samples_shown[i], neuronId(*(whowon[i])));
fclose(f); fclose(f);
} }
} }
}
coloring(); void neurondisttable(const char *fname) {
FILE *f = fopen(fname, "wt");
if(!f) {
printf("Could not open file: %s\n", fname);
return;
}
int neurons = size(net);
fprintf(f, "%d\n", neurons);
for(int i=0; i<neurons; i++) {
for(int j=0; j<neurons; j++) fprintf(f, "%3d", celldistance(net[i].where, net[j].where));
// todo: build the table correctly for gaussian=2
fprintf(f, "\n");
}
fclose(f);
} }
void steps() { void steps() {
@ -717,9 +953,11 @@ void showMenu() {
else if(whattodraw[i] == -4) c = "number of samples"; else if(whattodraw[i] == -4) c = "number of samples";
else if(whattodraw[i] == -5) c = "best sample's color"; else if(whattodraw[i] == -5) c = "best sample's color";
else if(whattodraw[i] == -6) c = "sample names to colors"; else if(whattodraw[i] == -6) c = "sample names to colors";
else c = "column " + its(whattodraw[i]); else c = colnames[whattodraw[i]];
dialog::addSelItem(XLAT("coloring (%1)", parts[i]), c, '1'+i); dialog::addSelItem(XLAT("coloring (%1)", parts[i]), c, '1'+i);
} }
dialog::addItem("coloring (all)", '0');
dialog::addItem("level lines", '4');
} }
bool handleMenu(int sym, int uni) { bool handleMenu(int sym, int uni) {
@ -734,6 +972,10 @@ bool handleMenu(int sym, int uni) {
for(char x: {'1','2','3'}) handleMenu(x, x); for(char x: {'1','2','3'}) handleMenu(x, x);
return true; return true;
} }
if(uni == '4') {
pushScreen(levelline::show);
return true;
}
return false; return false;
} }
@ -800,7 +1042,10 @@ int readArgs() {
// #4: run, stop etc. // #4: run, stop etc.
else if(argis("-somrunto")) { else if(argis("-somrunto")) {
int i = argi(); int i = argi();
shift(); while(t > i) kohonen::step(); shift(); while(t > i) {
if(t % 128 == 0) progress("Steps left: " + its(t));
kohonen::step();
}
} }
else if(argis("-somstop")) { else if(argis("-somstop")) {
t = 0; t = 0;
@ -809,7 +1054,10 @@ int readArgs() {
noshow = true; noshow = true;
} }
else if(argis("-somfinish")) { else if(argis("-somfinish")) {
while(!finished()) kohonen::step(); while(!finished()) {
kohonen::step();
if(t % 128 == 0) progress("Steps left: " + its(t));
}
} }
// #5 save data, classify etc. // #5 save data, classify etc.
@ -817,22 +1065,43 @@ int readArgs() {
PHASE(3); PHASE(3);
shift(); kohonen::ksave(args()); shift(); kohonen::ksave(args());
} }
else if(argis("-somsavew")) {
PHASE(3);
shift(); kohonen::ksavew(args());
}
else if(argis("-somloadw")) {
PHASE(3);
shift(); kohonen::kloadw(args());
}
else if(argis("-somclassify")) { else if(argis("-somclassify")) {
PHASE(3); PHASE(3);
shift(); kohonen::kclassify(args()); shift(); kohonen::kclassify(args());
} }
else if(argis("-somclassify2")) { else if(argis("-somlistshown")) {
PHASE(3); PHASE(3);
shift(); const char *f1 = args(); shift(); kohonen::klistsamples(args(), false, false);
shift(); const char *f2 = args(); }
kohonen::kclassify2(f1, f2); else if(argis("-somlistbest")) {
PHASE(3);
shift(); kohonen::klistsamples(args(), true, false);
}
else if(argis("-somlistbestc")) {
PHASE(3);
shift(); kohonen::klistsamples(args(), true, true);
}
else if(argis("-somndist")) {
PHASE(3);
shift(); kohonen::neurondisttable(args());
}
else if(argis("-somshowbest")) {
showbestsamples();
} }
else return 1; else return 1;
return 0; return 0;
} }
auto hooks = addHook(hooks_args, 100, readArgs); auto hooks = addHook(hooks_args, 100, readArgs) + addHook(hooks_frame, 50, levelline::draw);
} }
void mark(cell *c) { void mark(cell *c) {

View File

@ -29,6 +29,8 @@ bool specialmark = false;
bool rog3 = false; bool rog3 = false;
ld ggamma = .5;
string fname; string fname;
// const char *fname; // const char *fname;
@ -1027,8 +1029,6 @@ void queuedisk(const transmatrix& V, const colorpair& cp, bool legend) {
if(cp.shade == 'm') queuepoly(V, shDiskM, cp.color2); if(cp.shade == 'm') queuepoly(V, shDiskM, cp.color2);
} }
ld ggamma = .5;
void drawVertex(const transmatrix &V, cell *c, shmup::monster *m) { void drawVertex(const transmatrix &V, cell *c, shmup::monster *m) {
if(m->dead) return; if(m->dead) return;
int i = m->pid; int i = m->pid;
@ -1292,7 +1292,16 @@ void readcolor(const char *cfname) {
colorpair x; colorpair x;
int c2 = fgetc(f); int c2 = fgetc(f);
if(c2 == '=') { if(kohonen::samples && c2 == '+') {
kohonen::showsample(lab);
c2 = fgetc(f);
if(c2 == 10 || c2 == 13) continue;
}
if(c2 == '#') {
while(c2 != 10 && c2 != 13 && c2 != -1) c2 = fgetc(f);
continue;
}
else if(c2 == '=') {
string lab2 = ""; string lab2 = "";
while(true) { while(true) {
int c = fgetc(f); int c = fgetc(f);
@ -1326,6 +1335,12 @@ void readcolor(const char *cfname) {
vdata[i].cp = x; vdata[i].cp = x;
} }
} }
else if(kohonen::samples) {
for(int i=0; i<size(vdata); i++)
if(vdata[i].name == lab) {
vdata[i].cp = x;
}
}
else { else {
int i = getid(lab); int i = getid(lab);
again: vdata[i].cp = x; again: vdata[i].cp = x;