mirror of
https://github.com/glouw/tinn
synced 2024-11-21 22:11:21 +03:00
comments
This commit is contained in:
parent
31031c8da3
commit
6bd4e7a47f
16
test.c
16
test.c
@ -126,14 +126,17 @@ static Data build(const char* path, const int nips, const int nops)
|
||||
|
||||
int main()
|
||||
{
|
||||
// Input and output size is harded coded here,
|
||||
// so make sure the training data sizes match.
|
||||
// Input and output size is harded coded here as machine learning
|
||||
// repositories usually don't include the input and output size in the data itself.
|
||||
const int nips = 256;
|
||||
const int nops = 10;
|
||||
// Hyper Parameters.
|
||||
// Learning rate is annealed and thus not constant.
|
||||
// It can be fine tuned along with the number of hidden layers.
|
||||
// Feel free to modify the anneal rate as well.
|
||||
const int nhid = 32;
|
||||
double rate = 1.0;
|
||||
const double anneal = 0.9;
|
||||
// Load the training set.
|
||||
const Data data = build("semeion.data", nips, nops);
|
||||
// Train, baby, train.
|
||||
@ -149,21 +152,22 @@ int main()
|
||||
error += xttrain(tinn, in, tg, rate);
|
||||
}
|
||||
printf("error %.12f :: rate %f\n", error / data.rows, rate);
|
||||
rate *= 0.9;
|
||||
rate *= anneal;
|
||||
}
|
||||
// This is how you save the neural network to disk.
|
||||
xtsave(tinn, "saved.tinn");
|
||||
xtfree(tinn);
|
||||
// This is how you load the neural network from disk.
|
||||
const Tinn loaded = xtload("saved.tinn");
|
||||
// Ideally, you would load a testing set for predictions,
|
||||
// but for the sake of brevity the training set is reused.
|
||||
// Now we do a prediction with the neural network we loaded from disk.
|
||||
// Ideally, we would also load a testing set to make the prediction with,
|
||||
// but for the sake of brevity here we just reuse the training set from earlier.
|
||||
const double* const in = data.in[0];
|
||||
const double* const tg = data.tg[0];
|
||||
const double* const pd = xpredict(loaded, in);
|
||||
for(int i = 0; i < data.nops; i++) { printf("%f ", tg[i]); } printf("\n");
|
||||
for(int i = 0; i < data.nops; i++) { printf("%f ", pd[i]); } printf("\n");
|
||||
// Cleanup.
|
||||
// All done. Let's clean up.
|
||||
xtfree(loaded);
|
||||
dfree(data);
|
||||
return 0;
|
||||
|
Loading…
Reference in New Issue
Block a user