adapt genann_train to general activation function

Note: this requires the user to supply a "differential expression" for the
activation function, i.e. its derivative expressed in terms of its
function value (for the sigmoid y = f(x), this is f'(x) = y(1 - y)).
The approach is thus limited to strictly increasing, differentiable functions.
This commit is contained in:
PJ v M 2023-09-14 13:36:40 +00:00
parent 4f72209510
commit a7893ba85c
2 changed files with 15 additions and 2 deletions

View File

@ -73,6 +73,10 @@ double genann_act_sigmoid(const genann *ann unused, double a) {
return 1.0 / (1 + exp(-a));
}
/* Derivative of the sigmoid, written as a function of the sigmoid's own
 * output value y rather than of its input: y * (1 - y).
 * The ann argument is not used. */
double genann_act_diffexpr_sigmoid(const genann *ann unused, double y) {
    const double complement = 1.0 - y;
    return y * complement;
}
void genann_init_sigmoid_lookup(const genann *ann) {
const double f = (sigmoid_dom_max - sigmoid_dom_min) / LOOKUP_SIZE;
int i;
@ -143,6 +147,9 @@ genann *genann_init(int inputs, int hidden_layers, int hidden, int outputs) {
genann_init_sigmoid_lookup(ret);
ret->diffexpr_activation_hidden = genann_act_diffexpr_sigmoid;
ret->diffexpr_activation_output = genann_act_diffexpr_sigmoid;
return ret;
}
@ -296,7 +303,7 @@ void genann_train(genann const *ann, double const *inputs, double const *desired
}
} else {
for (j = 0; j < ann->outputs; ++j) {
*d++ = (*t - *o) * *o * (1.0 - *o);
*d++ = (*t - *o) * ann->diffexpr_activation_output(ann, *o);
++o; ++t;
}
}
@ -328,7 +335,7 @@ void genann_train(genann const *ann, double const *inputs, double const *desired
delta += forward_delta * forward_weight;
}
*d = *o * (1.0-*o) * delta;
*d = ann->diffexpr_activation_hidden(ann, *o) * delta;
++d; ++o;
}
}

View File

@ -53,6 +53,11 @@ typedef struct genann {
/* Which activation function to use for output. Default: genann_act_sigmoid_cached */
genann_actfun activation_output;
/* Derivative of the activation function, expressed in terms of the function value; i.e. f'(f_inverse(y))
* Used for backpropagation. Default: y(1-y), corresponding to the sigmoid. */
genann_actfun diffexpr_activation_hidden;
genann_actfun diffexpr_activation_output;
/* Total number of weights, and size of weights buffer. */
int total_weights;
@ -97,6 +102,7 @@ void genann_write(genann const *ann, FILE *out);
void genann_init_sigmoid_lookup(const genann *ann);
double genann_act_sigmoid(const genann *ann, double a);
double genann_act_sigmoid_cached(const genann *ann, double a);
double genann_act_diffexpr_sigmoid(const genann *ann, double y);
double genann_act_threshold(const genann *ann, double a);
double genann_act_linear(const genann *ann, double a);