mirror of https://github.com/glouw/tinn
interactive
This commit is contained in:
parent
ea130335da
commit
be0f1f3247
1
Makefile
1
Makefile
|
@ -25,6 +25,7 @@ CFLAGS += -flto
|
||||||
|
|
||||||
LDFLAGS =
|
LDFLAGS =
|
||||||
LDFLAGS += -lm
|
LDFLAGS += -lm
|
||||||
|
LDFLAGS += -lSDL2
|
||||||
|
|
||||||
ifdef ComSpec
|
ifdef ComSpec
|
||||||
RM = del /F /Q
|
RM = del /F /Q
|
||||||
|
|
112
test.c
112
test.c
|
@ -1,8 +1,19 @@
|
||||||
#include "Tinn.h"
|
#include "Tinn.h"
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
#include <stdbool.h>
|
||||||
#include <time.h>
|
#include <time.h>
|
||||||
|
#include <float.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
#include <SDL2/SDL.h>
|
||||||
|
|
||||||
|
typedef struct
|
||||||
|
{
|
||||||
|
bool down;
|
||||||
|
int x;
|
||||||
|
int y;
|
||||||
|
}
|
||||||
|
Input;
|
||||||
|
|
||||||
typedef struct
|
typedef struct
|
||||||
{
|
{
|
||||||
|
@ -125,6 +136,93 @@ static Data build(const char* path, const int nips, const int nops)
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void dprint(const float* const p, const int size)
|
||||||
|
{
|
||||||
|
for(int i = 0; i < size; i++)
|
||||||
|
printf("%f ", (double) p[i]);
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef struct
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
float val;
|
||||||
|
}
|
||||||
|
Index;
|
||||||
|
|
||||||
|
Index ixmax(const float* const p, const int size)
|
||||||
|
{
|
||||||
|
Index ix;
|
||||||
|
ix.val = -FLT_MAX;
|
||||||
|
for(int i = 0; i < size; i++)
|
||||||
|
if(p[i] > ix.val)
|
||||||
|
ix.val = p[ix.i = i];
|
||||||
|
return ix;
|
||||||
|
}
|
||||||
|
|
||||||
|
void dploop(const Tinn tinn, const Data data)
|
||||||
|
{
|
||||||
|
SDL_Renderer* renderer;
|
||||||
|
SDL_Window* window;
|
||||||
|
#define W 16
|
||||||
|
#define H 16
|
||||||
|
#define S 20
|
||||||
|
const int xres = W * S;
|
||||||
|
const int yres = H * S;
|
||||||
|
SDL_CreateWindowAndRenderer(xres, yres, 0, &window, &renderer);
|
||||||
|
static float digit[W * H];
|
||||||
|
Input input = { false, 0, 0 };
|
||||||
|
for(SDL_Event e; true; SDL_PollEvent(&e))
|
||||||
|
{
|
||||||
|
if(e.type == SDL_QUIT)
|
||||||
|
exit(1);
|
||||||
|
const int button = SDL_GetMouseState(&input.x, &input.y);
|
||||||
|
// Draw digit.
|
||||||
|
if(button)
|
||||||
|
{
|
||||||
|
const int xx = input.x / S;
|
||||||
|
const int yy = input.y / S;
|
||||||
|
const int w = 2;
|
||||||
|
for(int i = 0; i < w; i++)
|
||||||
|
for(int j = 0; j < w; j++)
|
||||||
|
digit[(xx + i) + W * (yy + j)] = 1.0f;
|
||||||
|
input.down = true;
|
||||||
|
}
|
||||||
|
// Predict.
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if(input.down)
|
||||||
|
{
|
||||||
|
const float* const pred = xpredict(tinn, digit);
|
||||||
|
dprint(pred, data.nops);
|
||||||
|
const Index ix = ixmax(pred, data.nops);
|
||||||
|
if(ix.val > 0.9f)
|
||||||
|
printf("%d\n", ix.i);
|
||||||
|
else
|
||||||
|
printf("I do not recognize that digit\n");
|
||||||
|
memset((void*) digit, 0, sizeof(digit));
|
||||||
|
}
|
||||||
|
input.down = false;
|
||||||
|
}
|
||||||
|
// Draw digit to screen.
|
||||||
|
for(int x = 0; x < xres; x++)
|
||||||
|
for(int y = 0; y < yres; y++)
|
||||||
|
{
|
||||||
|
const int xx = x / S;
|
||||||
|
const int yy = y / S;
|
||||||
|
digit[xx + W * yy] == 1.0f ?
|
||||||
|
SDL_SetRenderDrawColor(renderer, 0xFF, 0xFF, 0xFF, 0xFF):
|
||||||
|
SDL_SetRenderDrawColor(renderer, 0x00, 0x00, 0x00, 0xFF);
|
||||||
|
SDL_RenderDrawPoint(renderer, x, y);
|
||||||
|
}
|
||||||
|
SDL_RenderPresent(renderer);
|
||||||
|
SDL_Delay(15);
|
||||||
|
}
|
||||||
|
#undef W
|
||||||
|
#undef H
|
||||||
|
#undef S
|
||||||
|
}
|
||||||
|
|
||||||
int main()
|
int main()
|
||||||
{
|
{
|
||||||
// Tinn does not seed the random number generator.
|
// Tinn does not seed the random number generator.
|
||||||
|
@ -144,7 +242,7 @@ int main()
|
||||||
const Data data = build("semeion.data", nips, nops);
|
const Data data = build("semeion.data", nips, nops);
|
||||||
// Train, baby, train.
|
// Train, baby, train.
|
||||||
const Tinn tinn = xtbuild(nips, nhid, nops);
|
const Tinn tinn = xtbuild(nips, nhid, nops);
|
||||||
for(int i = 0; i < 100; i++)
|
for(int i = 0; i < 200; i++)
|
||||||
{
|
{
|
||||||
shuffle(data);
|
shuffle(data);
|
||||||
float error = 0.0f;
|
float error = 0.0f;
|
||||||
|
@ -163,14 +261,10 @@ int main()
|
||||||
// This is how you load the neural network from disk.
|
// This is how you load the neural network from disk.
|
||||||
const Tinn loaded = xtload("saved.tinn");
|
const Tinn loaded = xtload("saved.tinn");
|
||||||
// Now we do a prediction with the neural network we loaded from disk.
|
// Now we do a prediction with the neural network we loaded from disk.
|
||||||
// Ideally, we would also load a testing set to make the prediction with,
|
// SDL will create a window so that you can draw digits.
|
||||||
// but for the sake of brevity here we just reuse the training set from earlier.
|
// Enter the draw and predict loop:
|
||||||
const float* const in = data.in[0];
|
dploop(loaded, data);
|
||||||
const float* const tg = data.tg[0];
|
// All done. Let's clean up
|
||||||
const float* const pd = xpredict(loaded, in);
|
|
||||||
for(int i = 0; i < data.nops; i++) { printf("%f ", (double) tg[i]); } printf("\n");
|
|
||||||
for(int i = 0; i < data.nops; i++) { printf("%f ", (double) pd[i]); } printf("\n");
|
|
||||||
// All done. Let's clean up.
|
|
||||||
xtfree(loaded);
|
xtfree(loaded);
|
||||||
dfree(data);
|
dfree(data);
|
||||||
return 0;
|
return 0;
|
||||||
|
|
Loading…
Reference in New Issue