genann/test.c

277 lines
6.4 KiB
C
Raw Normal View History

2016-02-10 02:53:54 +03:00
/*
* GENANN - Minimal C Artificial Neural Network
*
* Copyright (c) 2015, 2016 Lewis Van Winkle
*
* http://CodePlea.com
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgement in the product documentation would be
* appreciated but is not required.
* 2. Altered source versions must be plainly marked as such, and must not be
* misrepresented as being the original software.
* 3. This notice may not be removed or altered from any source distribution.
*
*/
#include "genann.h"
#include "minctest.h"
#include <stdio.h>
#include <math.h>
#include <stdlib.h>
void basic() {
2016-02-11 23:38:42 +03:00
genann *ann = genann_init(1, 0, 0, 1);
2016-02-10 02:53:54 +03:00
lequal(ann->total_weights, 2);
double a;
a = 0;
ann->weight[0] = 0;
ann->weight[1] = 0;
lfequal(0.5, *genann_run(ann, &a));
a = 1;
lfequal(0.5, *genann_run(ann, &a));
a = 11;
lfequal(0.5, *genann_run(ann, &a));
a = 1;
ann->weight[0] = 1;
ann->weight[1] = 1;
lfequal(0.5, *genann_run(ann, &a));
a = 10;
ann->weight[0] = 1;
ann->weight[1] = 1;
lfequal(1.0, *genann_run(ann, &a));
a = -10;
lfequal(0.0, *genann_run(ann, &a));
genann_free(ann);
}
void xor() {
2016-02-11 23:38:42 +03:00
genann *ann = genann_init(2, 1, 2, 1);
2016-02-10 02:53:54 +03:00
ann->activation_hidden = genann_act_threshold;
ann->activation_output = genann_act_threshold;
lequal(ann->total_weights, 9);
/* First hidden. */
ann->weight[0] = .5;
ann->weight[1] = 1;
ann->weight[2] = 1;
/* Second hidden. */
ann->weight[3] = 1;
ann->weight[4] = 1;
ann->weight[5] = 1;
/* Output. */
ann->weight[6] = .5;
ann->weight[7] = 1;
ann->weight[8] = -1;
double input[4][2] = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
double output[4] = {0, 1, 1, 0};
lfequal(output[0], *genann_run(ann, input[0]));
lfequal(output[1], *genann_run(ann, input[1]));
lfequal(output[2], *genann_run(ann, input[2]));
lfequal(output[3], *genann_run(ann, input[3]));
genann_free(ann);
}
void backprop() {
2016-02-11 23:38:42 +03:00
genann *ann = genann_init(1, 0, 0, 1);
2016-02-10 02:53:54 +03:00
double input, output;
input = .5;
output = 1;
double first_try = *genann_run(ann, &input);
genann_train(ann, &input, &output, .5);
double second_try = *genann_run(ann, &input);
lok(fabs(first_try - output) > fabs(second_try - output));
genann_free(ann);
}
void train_and() {
double input[4][2] = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
double output[4] = {0, 0, 0, 1};
2016-02-11 23:38:42 +03:00
genann *ann = genann_init(2, 0, 0, 1);
2016-02-10 02:53:54 +03:00
int i, j;
for (i = 0; i < 50; ++i) {
for (j = 0; j < 4; ++j) {
genann_train(ann, input[j], output + j, .8);
}
}
ann->activation_output = genann_act_threshold;
lfequal(output[0], *genann_run(ann, input[0]));
lfequal(output[1], *genann_run(ann, input[1]));
lfequal(output[2], *genann_run(ann, input[2]));
lfequal(output[3], *genann_run(ann, input[3]));
genann_free(ann);
}
void train_or() {
double input[4][2] = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
double output[4] = {0, 1, 1, 1};
2016-02-11 23:38:42 +03:00
genann *ann = genann_init(2, 0, 0, 1);
2016-02-10 02:53:54 +03:00
genann_randomize(ann);
int i, j;
for (i = 0; i < 50; ++i) {
for (j = 0; j < 4; ++j) {
genann_train(ann, input[j], output + j, .8);
}
}
ann->activation_output = genann_act_threshold;
lfequal(output[0], *genann_run(ann, input[0]));
lfequal(output[1], *genann_run(ann, input[1]));
lfequal(output[2], *genann_run(ann, input[2]));
lfequal(output[3], *genann_run(ann, input[3]));
genann_free(ann);
}
void train_xor() {
double input[4][2] = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
double output[4] = {0, 1, 1, 0};
2016-02-11 23:38:42 +03:00
genann *ann = genann_init(2, 1, 2, 1);
2016-02-10 02:53:54 +03:00
int i, j;
for (i = 0; i < 300; ++i) {
for (j = 0; j < 4; ++j) {
genann_train(ann, input[j], output + j, 3);
}
/* printf("%1.2f ", xor_score(ann)); */
}
ann->activation_output = genann_act_threshold;
lfequal(output[0], *genann_run(ann, input[0]));
lfequal(output[1], *genann_run(ann, input[1]));
lfequal(output[2], *genann_run(ann, input[2]));
lfequal(output[3], *genann_run(ann, input[3]));
genann_free(ann);
}
void persist() {
2016-02-11 23:38:42 +03:00
genann *first = genann_init(1000, 5, 50, 10);
2016-02-10 02:53:54 +03:00
FILE *out = fopen("persist.txt", "w");
genann_write(first, out);
fclose(out);
FILE *in = fopen("persist.txt", "r");
2016-02-11 23:38:42 +03:00
genann *second = genann_read(in);
2016-02-10 02:53:54 +03:00
fclose(out);
lequal(first->inputs, second->inputs);
lequal(first->hidden_layers, second->hidden_layers);
lequal(first->hidden, second->hidden);
lequal(first->outputs, second->outputs);
lequal(first->total_weights, second->total_weights);
int i;
for (i = 0; i < first->total_weights; ++i) {
lok(first->weight[i] == second->weight[i]);
}
genann_free(first);
genann_free(second);
}
void copy() {
2016-02-11 23:38:42 +03:00
genann *first = genann_init(1000, 5, 50, 10);
2016-02-10 02:53:54 +03:00
2016-02-11 23:38:42 +03:00
genann *second = genann_copy(first);
2016-02-10 02:53:54 +03:00
lequal(first->inputs, second->inputs);
lequal(first->hidden_layers, second->hidden_layers);
lequal(first->hidden, second->hidden);
lequal(first->outputs, second->outputs);
lequal(first->total_weights, second->total_weights);
int i;
for (i = 0; i < first->total_weights; ++i) {
lfequal(first->weight[i], second->weight[i]);
}
genann_free(first);
genann_free(second);
}
/*
 * Sweeps x from -20 up to (but not including) 20 in steps of 0.0001 and
 * checks, with lfequal, that genann_act_sigmoid_cached(x) agrees with
 * genann_act_sigmoid(x) at every sample point.
 */
void sigmoid() {
    const double upper = 20;
    const double step = .0001;
    double x;

    for (x = -20; x < upper; x += step) {
        lfequal(genann_act_sigmoid(x), genann_act_sigmoid_cached(x));
    }
}
/*
 * Test-suite entry point: prints a banner, runs every test in order via
 * lrun, prints the summary via lresults, and returns 1 if any check failed
 * (0 otherwise). Command-line arguments are not used.
 */
int main(int argc, char *argv[])
{
    (void)argc;
    (void)argv;

    printf("GENANN TEST SUITE\n");

    /* Constant seed, so anything drawing from rand() repeats run to run. */
    srand(100);

    /* Hand-weighted forward-pass tests. */
    lrun("basic", basic);
    lrun("xor", xor);

    /* Training tests. */
    lrun("backprop", backprop);
    lrun("train and", train_and);
    lrun("train or", train_or);
    lrun("train xor", train_xor);

    /* Serialization, duplication and activation-function tests. */
    lrun("persist", persist);
    lrun("copy", copy);
    lrun("sigmoid", sigmoid);

    lresults();

    return lfails ? 1 : 0;
}