fixed test case 1

This commit is contained in:
Krishna Vedala 2020-05-31 13:04:56 -04:00
parent 47b653e7ca
commit 65b2d92977

View File

@ -205,33 +205,42 @@ void fit(struct adaline *ada, const double **X, const int *y, const int N)
*/
void test1(double eta)
{
const int num_features = 2;
struct adaline ada = new_adaline(num_features, eta); // 2 features
struct adaline ada = new_adaline(2, eta); // 2 features
const int N = 10; // number of sample points
const double saved_X[10][2] = {{0, 1}, {1, -2}, {2, 3}, {3, -1},
{4, 1}, {6, -5}, {-7, -3}, {-8, 5},
{-9, 2}, {-10, -15}};
const double X[10][2] = {{0, 1}, {1, -2}, {2, 3}, {3, -1}, {4, 1},
{6, -5}, {-7, -3}, {-8, 5}, {-9, 2}, {-10, -15}};
const int y[10] = {1, -1, 1, -1, -1,
double **X = (double **)malloc(N * sizeof(double *));
const int Y[10] = {1, -1, 1, -1, -1,
-1, 1, 1, 1, -1}; // corresponding y-values
for (int i = 0; i < N; i++)
{
X[i] = (double *)saved_X[i];
}
printf("------- Test 1 -------\n");
printf("Model before fit: %s", get_weights_str(&ada));
fit(&ada, X, y, N);
fit(&ada, X, Y, N);
printf("Model after fit: %s\n", get_weights_str(&ada));
double test_x[] = {5, -3};
int pred = predict(&ada, test_x, NULL);
printf("Predict for x=(5,-3): %d", pred);
printf("Predict for x=(5,-3): % d", pred);
assert(pred == -1);
printf(" ...passed\n");
double test_x2[] = {5, 8};
pred = predict(&ada, test_x2, NULL);
printf("Predict for x=(5,8): % d\n", pred);
printf("Predict for x=(5, 8): % d", pred);
assert(pred == 1);
printf(" ...passed\n");
// for (int i = 0; i < N; i++)
// free(X[i]);
free(X);
}
/**
@ -376,7 +385,7 @@ int main(int argc, char **argv)
if (argc == 2) // read eta value from commandline argument if present
eta = strtof(argv[1], NULL);
// test1(eta);
test1(eta);
printf("Press ENTER to continue...\n");
getchar();