mirror of https://github.com/glouw/tinn
interactive
This commit is contained in:
parent
ea130335da
commit
be0f1f3247
1
Makefile
1
Makefile
|
@ -25,6 +25,7 @@ CFLAGS += -flto
|
|||
|
||||
LDFLAGS =
|
||||
LDFLAGS += -lm
|
||||
LDFLAGS += -lSDL2
|
||||
|
||||
ifdef ComSpec
|
||||
RM = del /F /Q
|
||||
|
|
112
test.c
112
test.c
|
@ -1,8 +1,19 @@
|
|||
#include "Tinn.h"
|
||||
#include <stdio.h>
|
||||
#include <stdbool.h>
|
||||
#include <time.h>
|
||||
#include <float.h>
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
#include <SDL2/SDL.h>
|
||||
|
||||
typedef struct
|
||||
{
|
||||
bool down;
|
||||
int x;
|
||||
int y;
|
||||
}
|
||||
Input;
|
||||
|
||||
typedef struct
|
||||
{
|
||||
|
@ -125,6 +136,93 @@ static Data build(const char* path, const int nips, const int nops)
|
|||
return data;
|
||||
}
|
||||
|
||||
void dprint(const float* const p, const int size)
|
||||
{
|
||||
for(int i = 0; i < size; i++)
|
||||
printf("%f ", (double) p[i]);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
typedef struct
|
||||
{
|
||||
int i;
|
||||
float val;
|
||||
}
|
||||
Index;
|
||||
|
||||
Index ixmax(const float* const p, const int size)
|
||||
{
|
||||
Index ix;
|
||||
ix.val = -FLT_MAX;
|
||||
for(int i = 0; i < size; i++)
|
||||
if(p[i] > ix.val)
|
||||
ix.val = p[ix.i = i];
|
||||
return ix;
|
||||
}
|
||||
|
||||
void dploop(const Tinn tinn, const Data data)
|
||||
{
|
||||
SDL_Renderer* renderer;
|
||||
SDL_Window* window;
|
||||
#define W 16
|
||||
#define H 16
|
||||
#define S 20
|
||||
const int xres = W * S;
|
||||
const int yres = H * S;
|
||||
SDL_CreateWindowAndRenderer(xres, yres, 0, &window, &renderer);
|
||||
static float digit[W * H];
|
||||
Input input = { false, 0, 0 };
|
||||
for(SDL_Event e; true; SDL_PollEvent(&e))
|
||||
{
|
||||
if(e.type == SDL_QUIT)
|
||||
exit(1);
|
||||
const int button = SDL_GetMouseState(&input.x, &input.y);
|
||||
// Draw digit.
|
||||
if(button)
|
||||
{
|
||||
const int xx = input.x / S;
|
||||
const int yy = input.y / S;
|
||||
const int w = 2;
|
||||
for(int i = 0; i < w; i++)
|
||||
for(int j = 0; j < w; j++)
|
||||
digit[(xx + i) + W * (yy + j)] = 1.0f;
|
||||
input.down = true;
|
||||
}
|
||||
// Predict.
|
||||
else
|
||||
{
|
||||
if(input.down)
|
||||
{
|
||||
const float* const pred = xpredict(tinn, digit);
|
||||
dprint(pred, data.nops);
|
||||
const Index ix = ixmax(pred, data.nops);
|
||||
if(ix.val > 0.9f)
|
||||
printf("%d\n", ix.i);
|
||||
else
|
||||
printf("I do not recognize that digit\n");
|
||||
memset((void*) digit, 0, sizeof(digit));
|
||||
}
|
||||
input.down = false;
|
||||
}
|
||||
// Draw digit to screen.
|
||||
for(int x = 0; x < xres; x++)
|
||||
for(int y = 0; y < yres; y++)
|
||||
{
|
||||
const int xx = x / S;
|
||||
const int yy = y / S;
|
||||
digit[xx + W * yy] == 1.0f ?
|
||||
SDL_SetRenderDrawColor(renderer, 0xFF, 0xFF, 0xFF, 0xFF):
|
||||
SDL_SetRenderDrawColor(renderer, 0x00, 0x00, 0x00, 0xFF);
|
||||
SDL_RenderDrawPoint(renderer, x, y);
|
||||
}
|
||||
SDL_RenderPresent(renderer);
|
||||
SDL_Delay(15);
|
||||
}
|
||||
#undef W
|
||||
#undef H
|
||||
#undef S
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
// Tinn does not seed the random number generator.
|
||||
|
@ -144,7 +242,7 @@ int main()
|
|||
const Data data = build("semeion.data", nips, nops);
|
||||
// Train, baby, train.
|
||||
const Tinn tinn = xtbuild(nips, nhid, nops);
|
||||
for(int i = 0; i < 100; i++)
|
||||
for(int i = 0; i < 200; i++)
|
||||
{
|
||||
shuffle(data);
|
||||
float error = 0.0f;
|
||||
|
@ -163,14 +261,10 @@ int main()
|
|||
// This is how you load the neural network from disk.
|
||||
const Tinn loaded = xtload("saved.tinn");
|
||||
// Now we do a prediction with the neural network we loaded from disk.
|
||||
// Ideally, we would also load a testing set to make the prediction with,
|
||||
// but for the sake of brevity here we just reuse the training set from earlier.
|
||||
const float* const in = data.in[0];
|
||||
const float* const tg = data.tg[0];
|
||||
const float* const pd = xpredict(loaded, in);
|
||||
for(int i = 0; i < data.nops; i++) { printf("%f ", (double) tg[i]); } printf("\n");
|
||||
for(int i = 0; i < data.nops; i++) { printf("%f ", (double) pd[i]); } printf("\n");
|
||||
// All done. Let's clean up.
|
||||
// SDL will create a window so that you can draw digits.
|
||||
// Enter the draw and predict loop:
|
||||
dploop(loaded, data);
|
||||
// All done. Let's clean up
|
||||
xtfree(loaded);
|
||||
dfree(data);
|
||||
return 0;
|
||||
|
|
Loading…
Reference in New Issue