mirror of
https://github.com/glouw/tinn
synced 2024-11-24 23:39:38 +03:00
things are getting better
This commit is contained in:
parent
4dbef70f4b
commit
48c91d609d
11
README.md
Normal file
11
README.md
Normal file
@ -0,0 +1,11 @@
|
||||
# Shaper
|
||||
|
||||
Shaper learns hand written digits.
|
||||
|
||||
Get the training data:
|
||||
|
||||
wget http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data
|
||||
|
||||
Build and run:
|
||||
|
||||
make; ./shaper
|
42
main.c
42
main.c
@ -6,6 +6,7 @@
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
#include <time.h>
|
||||
|
||||
#include "genann.h"
|
||||
|
||||
@ -94,12 +95,33 @@ static void dfree(const Data d)
|
||||
free(d.od);
|
||||
}
|
||||
|
||||
static void shuffle(const Data d)
|
||||
{
|
||||
srand(time(0));
|
||||
for(int a = 0; a < d.rows; a++)
|
||||
{
|
||||
const int b = rand() % d.rows;
|
||||
double* ot = d.od[a];
|
||||
double* it = d.id[a];
|
||||
// Swap output.
|
||||
d.od[a] = d.od[b];
|
||||
d.od[b] = ot;
|
||||
// Swap input.
|
||||
d.id[a] = d.id[b];
|
||||
d.id[b] = it;
|
||||
}
|
||||
}
|
||||
|
||||
static genann* dtrain(const Data d, const int ntimes, const int layers, const int neurons, const int rate)
|
||||
{
|
||||
genann* const ann = genann_init(d.icols, layers, neurons, d.ocols);
|
||||
for(int i = 0; i < ntimes; i++)
|
||||
for(int j = 0; j < d.rows; j++)
|
||||
genann_train(ann, d.id[j], d.od[j], rate);
|
||||
{
|
||||
shuffle(d);
|
||||
for(int j = 0; j < d.rows; j++)
|
||||
genann_train(ann, d.id[j], d.od[j], rate);
|
||||
printf("%f\n", (double) i / ntimes);
|
||||
}
|
||||
return ann;
|
||||
}
|
||||
|
||||
@ -107,9 +129,15 @@ static void dpredict(genann* ann, const Data d)
|
||||
{
|
||||
for(int i = 0; i < d.rows; i++)
|
||||
{
|
||||
printf("%d: ", i);
|
||||
// Prediciton.
|
||||
const double* const pred = genann_run(ann, d.id[i]);
|
||||
for(int j = 0; j < d.ocols; j++)
|
||||
printf("%s%d", j > 0 ? " " : "", pred[j] > 0.9);
|
||||
printf("%s", " :: ");
|
||||
// Actual.
|
||||
for(int j = 0; j < d.ocols; j++)
|
||||
printf("%s%d", j > 0 ? " " : "", (int) d.od[i][j]);
|
||||
putchar('\n');
|
||||
}
|
||||
}
|
||||
@ -133,10 +161,12 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
(void) argc;
|
||||
(void) argv;
|
||||
const Data data = dbuild("semeion.data", 256, 10);
|
||||
genann* ann = dtrain(data, 256, 1, 32, 1);
|
||||
dpredict(ann, data);
|
||||
const Data test = dbuild("semeion.data", 256, 10);
|
||||
const Data vald = dbuild("written.data", 256, 10);
|
||||
genann* ann = dtrain(test, 256, 1, 128, 1);
|
||||
dpredict(ann, vald);
|
||||
genann_free(ann);
|
||||
dfree(data);
|
||||
dfree(test);
|
||||
dfree(vald);
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user