diff --git a/Makefile b/Makefile index 40382d5..a16be38 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,7 @@ CFLAGS += -flto LDFLAGS = LDFLAGS += -lm +LDFLAGS += -lSDL2 ifdef ComSpec RM = del /F /Q diff --git a/test.c b/test.c index 7362aa8..938f3e5 100644 --- a/test.c +++ b/test.c @@ -1,8 +1,19 @@ #include "Tinn.h" #include +#include #include +#include #include #include +#include + +typedef struct +{ + bool down; + int x; + int y; +} +Input; typedef struct { @@ -125,6 +136,93 @@ static Data build(const char* path, const int nips, const int nops) return data; } +void dprint(const float* const p, const int size) +{ + for(int i = 0; i < size; i++) + printf("%f ", (double) p[i]); + printf("\n"); +} + +typedef struct +{ + int i; + float val; +} +Index; + +Index ixmax(const float* const p, const int size) +{ + Index ix; + ix.val = -FLT_MAX; + for(int i = 0; i < size; i++) + if(p[i] > ix.val) + ix.val = p[ix.i = i]; + return ix; +} + +void dploop(const Tinn tinn, const Data data) +{ + SDL_Renderer* renderer; + SDL_Window* window; + #define W 16 + #define H 16 + #define S 20 + const int xres = W * S; + const int yres = H * S; + SDL_CreateWindowAndRenderer(xres, yres, 0, &window, &renderer); + static float digit[W * H]; + Input input = { false, 0, 0 }; + for(SDL_Event e; true; SDL_PollEvent(&e)) + { + if(e.type == SDL_QUIT) + exit(1); + const int button = SDL_GetMouseState(&input.x, &input.y); + // Draw digit. + if(button) + { + const int xx = input.x / S; + const int yy = input.y / S; + const int w = 2; + for(int i = 0; i < w; i++) + for(int j = 0; j < w; j++) + digit[(xx + i) + W * (yy + j)] = 1.0f; + input.down = true; + } + // Predict. + else + { + if(input.down) + { + const float* const pred = xpredict(tinn, digit); + dprint(pred, data.nops); + const Index ix = ixmax(pred, data.nops); + if(ix.val > 0.9f) + printf("%d\n", ix.i); + else + printf("I do not recognize that digit\n"); + memset((void*) digit, 0, sizeof(digit)); + } + input.down = false; + } + // Draw digit to screen. + for(int x = 0; x < xres; x++) + for(int y = 0; y < yres; y++) + { + const int xx = x / S; + const int yy = y / S; + digit[xx + W * yy] == 1.0f ? + SDL_SetRenderDrawColor(renderer, 0xFF, 0xFF, 0xFF, 0xFF): + SDL_SetRenderDrawColor(renderer, 0x00, 0x00, 0x00, 0xFF); + SDL_RenderDrawPoint(renderer, x, y); + } + SDL_RenderPresent(renderer); + SDL_Delay(15); + } + #undef W + #undef H + #undef S +} + int main() { // Tinn does not seed the random number generator. @@ -144,7 +242,7 @@ int main() const Data data = build("semeion.data", nips, nops); // Train, baby, train. const Tinn tinn = xtbuild(nips, nhid, nops); - for(int i = 0; i < 100; i++) + for(int i = 0; i < 200; i++) { shuffle(data); float error = 0.0f; @@ -163,14 +261,10 @@ int main() // This is how you load the neural network from disk. const Tinn loaded = xtload("saved.tinn"); // Now we do a prediction with the neural network we loaded from disk. - // Ideally, we would also load a testing set to make the prediction with, - // but for the sake of brevity here we just reuse the training set from earlier. - const float* const in = data.in[0]; - const float* const tg = data.tg[0]; - const float* const pd = xpredict(loaded, in); - for(int i = 0; i < data.nops; i++) { printf("%f ", (double) tg[i]); } printf("\n"); - for(int i = 0; i < data.nops; i++) { printf("%f ", (double) pd[i]); } printf("\n"); - // All done. Let's clean up. + // SDL will create a window so that you can draw digits. + // Enter the draw and predict loop: + dploop(loaded, data); + // All done. Let's clean up xtfree(loaded); dfree(data); return 0;