saving and loading functions

This commit is contained in:
Gustav Louw 2018-03-30 15:42:20 -07:00
parent a166e79f8e
commit cca0b71032
3 changed files with 63 additions and 11 deletions

36
Tinn.c
View File

@ -1,5 +1,6 @@
#include "Tinn.h"
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
@ -90,8 +91,7 @@ static void forewards(const Tinn t, const double* in)
// Randomizes weights and biases.
static void twrand(const Tinn t)
{
int wgts = t.nhid * (t.nips + t.nops);
for(int i = 0; i < wgts; i++) t.w[i] = frand() - 0.5;
for(int i = 0; i < t.nw; i++) t.w[i] = frand() - 0.5;
for(int i = 0; i < t.nb; i++) t.b[i] = frand() - 0.5;
}
@ -113,7 +113,8 @@ Tinn xtbuild(int nips, int nhid, int nops)
Tinn t;
// Tinn only supports one hidden layer so there are two biases.
t.nb = 2;
t.w = (double*) calloc(nhid * (nips + nops), sizeof(*t.w));
t.nw = nhid * (nips + nops);
t.w = (double*) calloc(t.nw, sizeof(*t.w));
t.b = (double*) calloc(t.nb, sizeof(*t.b));
t.h = (double*) calloc(nhid, sizeof(*t.h));
t.o = (double*) calloc(nops, sizeof(*t.o));
@ -125,9 +126,38 @@ Tinn xtbuild(int nips, int nhid, int nops)
return t;
}
void xtsave(const Tinn t, const char* path)
{
FILE* file = fopen(path, "w");
// Header.
fprintf(file, "%d %d %d\n", t.nips, t.nhid, t.nops);
// Biases and weights.
for(int i = 0; i < t.nb; i++) fprintf(file, "%lf\n", t.b[i]);
for(int i = 0; i < t.nw; i++) fprintf(file, "%lf\n", t.w[i]);
fclose(file);
}
Tinn xtload(const char* path)
{
FILE* file = fopen(path, "r");
int nips = 0;
int nhid = 0;
int nops = 0;
// Header.
fscanf(file, "%d %d %d\n", &nips, &nhid, &nops);
// A new tinn is returned.
Tinn t = xtbuild(nips, nhid, nips);
// Biases and weights.
for(int i = 0; i < t.nb; i++) fscanf(file, "%lf\n", &t.b[i]);
for(int i = 0; i < t.nw; i++) fscanf(file, "%lf\n", &t.w[i]);
fclose(file);
return t;
}
void xtfree(const Tinn t)
{
free(t.w);
free(t.b);
free(t.h);
free(t.o);
}

20
Tinn.h
View File

@ -10,16 +10,32 @@ typedef struct
// Number of biases - always two - Tinn only supports a single hidden layer.
int nb;
// Number of weights.
int nw;
int nips; // Number of inputs.
int nhid; // Number of hidden neurons.
int nops; // Number of outputs.
}
Tinn;
// Trains a tinn with an input and target output with a learning rate.
// Returns error rate of the neural network.
double xttrain(const Tinn, const double* in, const double* tg, double rate);
// Builds a new tinn object given number of inputs (nips),
// number of hidden neurons for the hidden layer (nhid),
// and number of outputs (nops).
Tinn xtbuild(int nips, int nhid, int nops);
void xtfree(Tinn);
// Returns an output prediction given an input.
double* xpredict(const Tinn, const double* in);
// Saves the tinn to disk.
void xtsave(const Tinn, const char* path);
// Loads a new tinn from disk.
Tinn xtload(const char* path);
// Frees a tinn from the heap.
void xtfree(const Tinn);

18
test.c
View File

@ -133,12 +133,12 @@ int main()
// Hyper Parameters.
// Learning rate is annealed and thus not constant.
const int nhid = 32;
double rate = 0.5;
double rate = 1.0;
// Load the training set.
const Data data = build("semeion.data", nips, nops);
// Rock and roll.
// Train, baby, train.
const Tinn tinn = xtbuild(nips, nhid, nops);
for(int i = 0; i < 100; i++)
for(int i = 0; i < 30; i++)
{
shuffle(data);
double error = 0.0;
@ -149,16 +149,22 @@ int main()
error += xttrain(tinn, in, tg, rate);
}
printf("error %.12f :: rate %f\n", error / data.rows, rate);
rate *= 0.99;
rate *= 0.9;
}
// This is how you save the neural network to disk.
xtsave(tinn, "saved.tinn");
xtfree(tinn);
// This is how you load the neural network from disk.
const Tinn loaded = xtload("saved.tinn");
// Ideally, you would load a testing set for predictions,
// but for the sake of brevity the training set is reused.
const double* const in = data.in[0];
const double* const tg = data.tg[0];
const double* const pd = xpredict(tinn, in);
const double* const pd = xpredict(loaded, in);
for(int i = 0; i < data.nops; i++) { printf("%f ", tg[i]); } printf("\n");
for(int i = 0; i < data.nops; i++) { printf("%f ", pd[i]); } printf("\n");
xtfree(tinn);
// Cleanup.
xtfree(loaded);
dfree(data);
return 0;
}