tinn/Tinn.c

134 lines
3.1 KiB
C
Raw Normal View History

2018-03-29 06:41:08 +03:00
#include "Tinn.h"
#include <stdlib.h>
#include <math.h>
2018-03-30 00:32:11 +03:00
#include <time.h>
2018-03-29 06:41:08 +03:00
2018-03-30 23:04:37 +03:00
// Error function.
static double err(double a, double b)
2018-03-29 06:41:08 +03:00
{
2018-03-30 23:04:37 +03:00
return 0.5 * pow(a - b, 2.0);
2018-03-29 06:41:08 +03:00
}
2018-03-30 23:04:37 +03:00
// Partial derivative of err() with respect to its first argument a,
// which is simply the signed difference a - b.
static double pderr(double a, double b)
{
    const double delta = a - b;
    return delta;
}
2018-03-30 23:04:37 +03:00
// Total error: the sum of err() over the first `size` target/output pairs,
// accumulated from index 0 upward. Returns 0.0 when size <= 0.
static double terr(const double* tg, const double* o, int size)
{
    double total = 0.0;
    int k = 0;
    while(k < size)
    {
        total += err(tg[k], o[k]);
        k++;
    }
    return total;
}
2018-03-30 23:04:37 +03:00
// Activation function.
static double act(double a)
{
return 1.0 / (1.0 + exp(-a));
}
// Derivative of the sigmoid expressed in terms of its OUTPUT: callers pass
// an already-activated value a = act(x) and get a * (1 - a).
static double pdact(double a)
{
    const double complement = 1.0 - a;
    return a * complement;
}
// Floating point random in [0.0, 1.0], both ends inclusive, since rand() can
// return 0 or RAND_MAX. Consumes one value from the global rand() sequence.
// Declared with (void) so it is a real prototype; in C before C23, empty
// parentheses declare a function with unspecified parameters.
static double frand(void)
{
    return rand() / (double) RAND_MAX;
}
2018-03-30 23:04:37 +03:00
// Back propagation.
static void backwards(const Tinn t, const double* in, const double* tg, double rate)
2018-03-29 06:41:08 +03:00
{
2018-03-29 21:40:14 +03:00
double* x = t.w + t.nhid * t.nips;
2018-03-30 23:04:37 +03:00
for(int i = 0; i < t.nhid; i++)
2018-03-29 06:41:08 +03:00
{
double sum = 0.0;
2018-03-30 23:04:37 +03:00
// Calculate total error change with respect to output.
for(int j = 0; j < t.nops; j++)
2018-03-29 06:41:08 +03:00
{
2018-03-30 23:04:37 +03:00
double a = pderr(t.o[j], tg[j]);
double b = pdact(t.o[j]);
sum += a * b * x[j * t.nhid + i];
// Correct weights in hidden to output layer.
x[j * t.nhid + i] -= rate * a * b * t.h[i];
2018-03-29 06:41:08 +03:00
}
2018-03-30 23:04:37 +03:00
// Correct weights in input to hidden layer.
for(int j = 0; j < t.nips; j++)
t.w[i * t.nips + j] -= rate * sum * pdact(t.h[i]) * in[j];
}
}
// Forward propagation.
static void forewards(const Tinn t, const double* in)
{
double* x = t.w + t.nhid * t.nips;
// Calculate hidden layer neuron values.
for(int i = 0; i < t.nhid; i++)
{
double sum = 0.0;
for(int j = 0; j < t.nips; j++)
sum += in[j] * t.w[i * t.nips + j];
2018-03-30 00:32:11 +03:00
t.h[i] = act(sum + t.b[0]);
2018-03-29 06:41:08 +03:00
}
2018-03-30 23:04:37 +03:00
// Calculate output layer neuron values.
for(int i = 0; i < t.nops; i++)
2018-03-29 06:41:08 +03:00
{
double sum = 0.0;
2018-03-30 23:04:37 +03:00
for(int j = 0; j < t.nhid; j++)
sum += t.h[j] * x[i * t.nhid + j];
2018-03-30 00:32:11 +03:00
t.o[i] = act(sum + t.b[1]);
2018-03-29 06:41:08 +03:00
}
}
2018-03-30 23:04:37 +03:00
// Randomizes weights and biases.
static void twrand(const Tinn t)
2018-03-29 06:41:08 +03:00
{
2018-03-30 00:35:39 +03:00
int wgts = t.nhid * (t.nips + t.nops);
2018-03-30 23:04:37 +03:00
for(int i = 0; i < wgts; i++) t.w[i] = frand() - 0.5;
for(int i = 0; i < t.nb; i++) t.b[i] = frand() - 0.5;
}
// Runs a forward pass on the t.nips values in `in` and returns t.o, the
// network's own buffer of t.nops output values. The buffer is overwritten
// by the next xpredict() or xttrain() call and is released by xtfree(),
// so callers must not free it.
double* xpredict(const Tinn t, const double* in)
{
    double* const out = t.o;
    forewards(t, in);
    return out;
}
2018-03-30 23:04:37 +03:00
double xttrain(const Tinn t, const double* in, const double* tg, double rate)
2018-03-29 08:04:47 +03:00
{
forewards(t, in);
backwards(t, in, tg, rate);
2018-03-30 23:04:37 +03:00
return terr(tg, t.o, t.nops);
2018-03-29 08:04:47 +03:00
}
2018-03-30 00:32:11 +03:00
Tinn xtbuild(int nips, int nhid, int nops)
2018-03-29 06:41:08 +03:00
{
Tinn t;
2018-03-30 23:04:37 +03:00
// Tinn only supports one hidden layer so there are two biases.
2018-03-30 00:32:11 +03:00
t.nb = 2;
2018-03-29 08:04:47 +03:00
t.w = (double*) calloc(nhid * (nips + nops), sizeof(*t.w));
2018-03-30 00:32:11 +03:00
t.b = (double*) calloc(t.nb, sizeof(*t.b));
t.h = (double*) calloc(nhid, sizeof(*t.h));
t.o = (double*) calloc(nops, sizeof(*t.o));
2018-03-29 08:04:47 +03:00
t.nips = nips;
2018-03-30 00:32:11 +03:00
t.nhid = nhid;
t.nops = nops;
srand(time(0));
2018-03-29 08:04:47 +03:00
twrand(t);
2018-03-29 06:41:08 +03:00
return t;
}
2018-03-30 23:04:37 +03:00
void xtfree(const Tinn t)
2018-03-29 06:41:08 +03:00
{
2018-03-29 08:04:47 +03:00
free(t.w);
free(t.h);
free(t.o);
2018-03-29 06:41:08 +03:00
}