things are getting better

This commit is contained in:
Gustav Louw 2018-03-26 19:09:24 -07:00
parent 4dbef70f4b
commit 48c91d609d
2 changed files with 47 additions and 6 deletions

11
README.md Normal file
View File

@ -0,0 +1,11 @@
# Shaper
Shaper learns hand written digits.
Get the training data:
wget http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data
Build and run:
make; ./shaper

42
main.c
View File

@ -6,6 +6,7 @@
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <time.h>
#include "genann.h"
@ -94,12 +95,33 @@ static void dfree(const Data d)
free(d.od);
}
static void shuffle(const Data d)
{
srand(time(0));
for(int a = 0; a < d.rows; a++)
{
const int b = rand() % d.rows;
double* ot = d.od[a];
double* it = d.id[a];
// Swap output.
d.od[a] = d.od[b];
d.od[b] = ot;
// Swap input.
d.id[a] = d.id[b];
d.id[b] = it;
}
}
static genann* dtrain(const Data d, const int ntimes, const int layers, const int neurons, const int rate)
{
genann* const ann = genann_init(d.icols, layers, neurons, d.ocols);
for(int i = 0; i < ntimes; i++)
for(int j = 0; j < d.rows; j++)
genann_train(ann, d.id[j], d.od[j], rate);
{
shuffle(d);
for(int j = 0; j < d.rows; j++)
genann_train(ann, d.id[j], d.od[j], rate);
printf("%f\n", (double) i / ntimes);
}
return ann;
}
@ -107,9 +129,15 @@ static void dpredict(genann* ann, const Data d)
{
for(int i = 0; i < d.rows; i++)
{
printf("%d: ", i);
// Prediciton.
const double* const pred = genann_run(ann, d.id[i]);
for(int j = 0; j < d.ocols; j++)
printf("%s%d", j > 0 ? " " : "", pred[j] > 0.9);
printf("%s", " :: ");
// Actual.
for(int j = 0; j < d.ocols; j++)
printf("%s%d", j > 0 ? " " : "", (int) d.od[i][j]);
putchar('\n');
}
}
@ -133,10 +161,12 @@ int main(int argc, char* argv[])
{
(void) argc;
(void) argv;
const Data data = dbuild("semeion.data", 256, 10);
genann* ann = dtrain(data, 256, 1, 32, 1);
dpredict(ann, data);
const Data test = dbuild("semeion.data", 256, 10);
const Data vald = dbuild("written.data", 256, 10);
genann* ann = dtrain(test, 256, 1, 128, 1);
dpredict(ann, vald);
genann_free(ann);
dfree(data);
dfree(test);
dfree(vald);
return 0;
}