mirror of
https://github.com/codeplea/genann
synced 2024-11-21 22:11:34 +03:00
adapt genann_train to general activation function
Note: requires the user to specify a "differential expression" for the activation function, by which I mean its derivative in terms of its function value. Thus limited to strictly increasing, differentiable functions.
This commit is contained in:
parent
4f72209510
commit
a7893ba85c
11
genann.c
11
genann.c
@ -73,6 +73,10 @@ double genann_act_sigmoid(const genann *ann unused, double a) {
|
||||
return 1.0 / (1 + exp(-a));
|
||||
}
|
||||
|
||||
double genann_act_diffexpr_sigmoid(const genann * ann unused, double y) {
|
||||
return y*(1.0-y);
|
||||
}
|
||||
|
||||
void genann_init_sigmoid_lookup(const genann *ann) {
|
||||
const double f = (sigmoid_dom_max - sigmoid_dom_min) / LOOKUP_SIZE;
|
||||
int i;
|
||||
@ -143,6 +147,9 @@ genann *genann_init(int inputs, int hidden_layers, int hidden, int outputs) {
|
||||
|
||||
genann_init_sigmoid_lookup(ret);
|
||||
|
||||
ret->diffexpr_activation_hidden = genann_act_diffexpr_sigmoid;
|
||||
ret->diffexpr_activation_output = genann_act_diffexpr_sigmoid;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -296,7 +303,7 @@ void genann_train(genann const *ann, double const *inputs, double const *desired
|
||||
}
|
||||
} else {
|
||||
for (j = 0; j < ann->outputs; ++j) {
|
||||
*d++ = (*t - *o) * *o * (1.0 - *o);
|
||||
*d++ = (*t - *o) * ann->diffexpr_activation_output(ann, *o);
|
||||
++o; ++t;
|
||||
}
|
||||
}
|
||||
@ -328,7 +335,7 @@ void genann_train(genann const *ann, double const *inputs, double const *desired
|
||||
delta += forward_delta * forward_weight;
|
||||
}
|
||||
|
||||
*d = *o * (1.0-*o) * delta;
|
||||
*d = ann->diffexpr_activation_hidden(ann, *o) * delta;
|
||||
++d; ++o;
|
||||
}
|
||||
}
|
||||
|
6
genann.h
6
genann.h
@ -53,6 +53,11 @@ typedef struct genann {
|
||||
/* Which activation function to use for output. Default: gennann_act_sigmoid_cached*/
|
||||
genann_actfun activation_output;
|
||||
|
||||
/* Derivative of the activation function, expressed in terms of the function value; i.e. f'(f_inverse(y))
|
||||
* Used for backpropagation. Default: y(1-y), corresponding to the sigmoid. */
|
||||
genann_actfun diffexpr_activation_hidden;
|
||||
genann_actfun diffexpr_activation_output;
|
||||
|
||||
/* Total number of weights, and size of weights buffer. */
|
||||
int total_weights;
|
||||
|
||||
@ -97,6 +102,7 @@ void genann_write(genann const *ann, FILE *out);
|
||||
void genann_init_sigmoid_lookup(const genann *ann);
|
||||
double genann_act_sigmoid(const genann *ann, double a);
|
||||
double genann_act_sigmoid_cached(const genann *ann, double a);
|
||||
double genann_act_diffexpr_sigmoid(const genann *ann, double y);
|
||||
double genann_act_threshold(const genann *ann, double a);
|
||||
double genann_act_linear(const genann *ann, double a);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user