interactive

This commit is contained in:
Gustav Louw 2018-04-02 12:42:19 -07:00
parent ea130335da
commit be0f1f3247
2 changed files with 104 additions and 9 deletions

View File

@ -25,6 +25,7 @@ CFLAGS += -flto
LDFLAGS =
LDFLAGS += -lm
LDFLAGS += -lSDL2
ifdef ComSpec
RM = del /F /Q

112
test.c
View File

@ -1,8 +1,19 @@
#include "Tinn.h"
#include <stdio.h>
#include <stdbool.h>
#include <time.h>
#include <float.h>
#include <string.h>
#include <stdlib.h>
#include <SDL2/SDL.h>
typedef struct
{
bool down;
int x;
int y;
}
Input;
typedef struct
{
@ -125,6 +136,93 @@ static Data build(const char* path, const int nips, const int nops)
return data;
}
void dprint(const float* const p, const int size)
{
for(int i = 0; i < size; i++)
printf("%f ", (double) p[i]);
printf("\n");
}
typedef struct
{
int i;
float val;
}
Index;
Index ixmax(const float* const p, const int size)
{
Index ix;
ix.val = -FLT_MAX;
for(int i = 0; i < size; i++)
if(p[i] > ix.val)
ix.val = p[ix.i = i];
return ix;
}
void dploop(const Tinn tinn, const Data data)
{
SDL_Renderer* renderer;
SDL_Window* window;
#define W 16
#define H 16
#define S 20
const int xres = W * S;
const int yres = H * S;
SDL_CreateWindowAndRenderer(xres, yres, 0, &window, &renderer);
static float digit[W * H];
Input input = { false, 0, 0 };
for(SDL_Event e; true; SDL_PollEvent(&e))
{
if(e.type == SDL_QUIT)
exit(1);
const int button = SDL_GetMouseState(&input.x, &input.y);
// Draw digit.
if(button)
{
const int xx = input.x / S;
const int yy = input.y / S;
const int w = 2;
for(int i = 0; i < w; i++)
for(int j = 0; j < w; j++)
digit[(xx + i) + W * (yy + j)] = 1.0f;
input.down = true;
}
// Predict.
else
{
if(input.down)
{
const float* const pred = xpredict(tinn, digit);
dprint(pred, data.nops);
const Index ix = ixmax(pred, data.nops);
if(ix.val > 0.9f)
printf("%d\n", ix.i);
else
printf("I do not recognize that digit\n");
memset((void*) digit, 0, sizeof(digit));
}
input.down = false;
}
// Draw digit to screen.
for(int x = 0; x < xres; x++)
for(int y = 0; y < yres; y++)
{
const int xx = x / S;
const int yy = y / S;
digit[xx + W * yy] == 1.0f ?
SDL_SetRenderDrawColor(renderer, 0xFF, 0xFF, 0xFF, 0xFF):
SDL_SetRenderDrawColor(renderer, 0x00, 0x00, 0x00, 0xFF);
SDL_RenderDrawPoint(renderer, x, y);
}
SDL_RenderPresent(renderer);
SDL_Delay(15);
}
#undef W
#undef H
#undef S
}
int main()
{
// Tinn does not seed the random number generator.
@ -144,7 +242,7 @@ int main()
const Data data = build("semeion.data", nips, nops);
// Train, baby, train.
const Tinn tinn = xtbuild(nips, nhid, nops);
for(int i = 0; i < 100; i++)
for(int i = 0; i < 200; i++)
{
shuffle(data);
float error = 0.0f;
@ -163,14 +261,10 @@ int main()
// This is how you load the neural network from disk.
const Tinn loaded = xtload("saved.tinn");
// Now we do a prediction with the neural network we loaded from disk.
// Ideally, we would also load a testing set to make the prediction with,
// but for the sake of brevity here we just reuse the training set from earlier.
const float* const in = data.in[0];
const float* const tg = data.tg[0];
const float* const pd = xpredict(loaded, in);
for(int i = 0; i < data.nops; i++) { printf("%f ", (double) tg[i]); } printf("\n");
for(int i = 0; i < data.nops; i++) { printf("%f ", (double) pd[i]); } printf("\n");
// All done. Let's clean up.
// SDL will create a window so that you can draw digits.
// Enter the draw and predict loop:
dploop(loaded, data);
// All done. Let's clean up
xtfree(loaded);
dfree(data);
return 0;