New example: Mnist

This example offers a simple yet practical application of genann:
digit recognition on the MNIST dataset.

Co-authored-by: Daniel Akbarinia <daniel.akbarinia@telecom-paris.fr>
Co-authored-by: Antoine Heitzmann <antoine.heitzmann@telecom-paris.fr>
Co-authored-by: Edouard Pompee <edouard.pompee@telecom-paris.fr>
This commit is contained in:
Roann CANTEL 2024-06-26 08:22:52 +02:00
parent 4f72209510
commit ae10110fb8
4 changed files with 321 additions and 3 deletions

View File

@ -1,7 +1,7 @@
CFLAGS = -Wall -Wshadow -O3 -g -march=native
CFLAGS = -Wall -Wshadow -O3 -I. -g -march=native
LDLIBS = -lm
all: check example1 example2 example3 example4
all: check example1 example2 example3 example4 mnist
sigmoid: CFLAGS += -Dgenann_act=genann_act_sigmoid_cached
sigmoid: all
@ -25,10 +25,11 @@ example3: example3.o genann.o
example4: example4.o genann.o
mnist: mnist.o mnist_db.o genann.o
clean:
$(RM) *.o
$(RM) test example1 example2 example3 example4 *.exe
$(RM) test example1 example2 example3 example4 mnist *.exe
$(RM) persist.txt
.PHONY: sigmoid threshold linear clean

93
mnist.c Normal file
View File

@ -0,0 +1,93 @@
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include "genann.h"
#include "mnist_db.h"
#define CLASS_COUNT 10
/*
 * Parse a command-line argument as a base-10 int.
 *
 * Returns 0 and stores the value in *out on success. Returns -1 when the
 * text is empty, is followed by trailing characters, or does not fit in
 * an int.
 */
static int parse_int_arg(const char *text, int *out)
{
    char *end = NULL;
    long value;

    errno = 0;
    value = strtol(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE)
        return -1;
    if (value < INT_MIN || value > INT_MAX)
        return -1;

    *out = (int) value;
    return 0;
}

/*
 * Train a genann network on the MNIST training set and save it to a file.
 *
 * Usage: ./mnist HIDDEN_LAYERS NEURONS_PER_HIDDEN_LAYER ITERATIONS OUTPUT_FILE
 *
 * Returns 0 when the trained network was written successfully and 1 on any
 * error (bad arguments, unreadable dataset, invalid label, network creation
 * failure, or failure to write the output file).
 */
int main(int argc, char* argv[])
{
    const double learning_rate = 0.25;
    /* One-hot target vector; only the current label's slot is ever set. */
    double output[CLASS_COUNT] = {0};
    MnistDataset training, tests;
    genann *ann = NULL;
    FILE *output_file;
    int hidden_layers, hidden_neurons, iterations;
    int status = 1;
    size_t i;
    int j;

    if (argc != 5) {
        fprintf(stderr, "./mnist [NUMBER OF HIDDEN LAYERS] [NEURON PER HIDDEN LAYERS] [TRAINING ITERATION COUNT] [OUTPUT FILE]\n");
        return 1;
    }

    /* Validate the numeric arguments once, before the costly dataset load. */
    if (parse_int_arg(argv[1], &hidden_layers) ||
        parse_int_arg(argv[2], &hidden_neurons) ||
        parse_int_arg(argv[3], &iterations)) {
        fprintf(stderr, "mnist: the first three arguments must be integers\n");
        return 1;
    }

    if (mnist_init(&training,
                   "mnist_data/train-images-idx3-ubyte",
                   "mnist_data/train-labels-idx1-ubyte",
                   0, 0))
        return 1;

    if (mnist_load_batch(&training) != training.batch_size) {
        mnist_free(&training);
        return 1;
    }

    /*
     * The test set is loaded and its dimensions are checked against the
     * training set, but it is not evaluated by this program.
     */
    if (mnist_init(&tests,
                   "mnist_data/t10k-images-idx3-ubyte",
                   "mnist_data/t10k-labels-idx1-ubyte",
                   0, 0)) {
        mnist_free(&training);
        return 1;
    }

    if (mnist_load_batch(&tests) != tests.batch_size)
        goto cleanup;

    /* Runtime checks instead of assert(): these must survive -DNDEBUG. */
    if (training.width != tests.width || training.height != tests.height) {
        fprintf(stderr, "mnist: training and test images have different dimensions\n");
        goto cleanup;
    }
    if (training.width == 0 || training.height == 0) {
        fprintf(stderr, "mnist: images have a zero dimension\n");
        goto cleanup;
    }

    ann = genann_init(training.width * training.height,
                      hidden_layers,
                      hidden_neurons,
                      CLASS_COUNT);
    if (ann == NULL) {
        fprintf(stderr, "mnist: could not create the network (check the layer arguments)\n");
        goto cleanup;
    }

    for (j = 0; j < iterations; j++) {
        for (i = 0; i < training.batch_size; ++i) {
            const int label = training.batch_entries[i].class;

            /* The label indexes output[]; reject anything outside it. */
            if (label < 0 || label >= CLASS_COUNT) {
                fprintf(stderr, "\nmnist: training entry %zu has invalid label %d\n",
                        i, label);
                goto cleanup;
            }

            printf("[Training number %d]: %zu%%\r",
                   j + 1,
                   (100 * (i + 1)) / training.batch_size);

            output[label] = 1;
            genann_train(ann, training.batch_entries[i].pixels, output, learning_rate);
            output[label] = 0;
        }
        printf("\n");
    }

    output_file = fopen(argv[4], "w");
    if (!output_file) {
        perror("fopen");
        goto cleanup;
    }
    genann_write(ann, output_file);
    /* fclose flushes buffered data, so a failure here means a bad file. */
    if (fclose(output_file) != 0) {
        perror("fclose");
        goto cleanup;
    }

    status = 0;

cleanup:
    if (ann)
        genann_free(ann);
    mnist_free(&training);
    mnist_free(&tests);
    return status;
}

176
mnist_db.c Normal file
View File

@ -0,0 +1,176 @@
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "mnist_db.h"
#include "utils.h"
/*
 * Read one 32-bit big-endian unsigned integer from fp.
 * Decoding byte by byte makes the result independent of host endianness.
 * Returns 0 on success and -1 on a short read.
 */
static int mnist_read_be32(FILE *fp, unsigned long *value)
{
    unsigned char raw[4];

    if (fread(raw, 1, sizeof raw, fp) != sizeof raw)
        return -1;

    *value = ((unsigned long) raw[0] << 24)
           | ((unsigned long) raw[1] << 16)
           | ((unsigned long) raw[2] << 8)
           |  (unsigned long) raw[3];
    return 0;
}

/*
 * Open an image/label file pair and allocate storage for one batch.
 *
 * output      dataset to initialise; any previous content is overwritten.
 * images_file path of the image file; the entry count, width and height are
 *             read from bytes 4..15 of its header.
 * labels_file path of the label file.
 * transpose   stored in the dataset and applied by mnist_load_batch().
 * batch_size  number of entries per batch; 0 selects the whole dataset.
 *
 * Returns 0 on success and -1 on failure. On failure every resource acquired
 * so far is released and *output is zeroed, so the caller has nothing to
 * free.
 */
int mnist_init(MnistDataset *output,
               const char *images_file,
               const char *labels_file,
               int transpose,
               size_t batch_size
)
{
    unsigned long count, width, height;
    size_t pixels_per_entry;
    size_t i;
    double *buf;

    if (!output)
        return -1;

    memset(output, 0, sizeof *output);
    output->transpose = transpose;

    /* Both files are binary: "rb" avoids newline translation on Windows. */
#ifndef _MSC_VER
    output->fimage = fopen(images_file, "rb");
    if (!output->fimage) {
        perror("fopen");
        goto fail;
    }
    output->flabel = fopen(labels_file, "rb");
    if (!output->flabel) {
        perror("fopen");
        goto fail;
    }
#else
    if (fopen_s(&output->fimage, images_file, "rb")) {
        output->fimage = NULL;
        goto fail;
    }
    if (fopen_s(&output->flabel, labels_file, "rb")) {
        output->flabel = NULL;
        goto fail;
    }
#endif

    /* Skip the 4-byte magic number at the start of the image file. */
    if (fseek(output->fimage, 4, SEEK_SET) != 0) {
        perror("fseek");
        goto fail;
    }
    if (mnist_read_be32(output->fimage, &count)) {
        perror("fread1");
        goto fail;
    }
    if (mnist_read_be32(output->fimage, &width)) {
        perror("fread2");
        goto fail;
    }
    if (mnist_read_be32(output->fimage, &height)) {
        perror("fread3");
        goto fail;
    }

    output->entries_count = (size_t) count;
    output->width = (unsigned int) width;
    output->height = (unsigned int) height;

    if (batch_size != 0)
        output->batch_size = batch_size;
    else
        output->batch_size = output->entries_count;

    /* An empty batch or empty image would leave nothing valid to allocate. */
    if (output->batch_size == 0 || output->width == 0 || output->height == 0) {
        fprintf(stderr, "mnist_init: empty dataset or zero-sized images\n");
        goto fail;
    }

    printf("Batch size: %zu; Width: %u; Height: %u\n",
           output->batch_size, output->width, output->height);

    /* Guard every size computation: the dimensions come from the file. */
    if (output->height > (size_t) -1 / output->width) {
        fprintf(stderr, "mnist_init: image dimensions are too large\n");
        goto fail;
    }
    pixels_per_entry = (size_t) output->width * output->height;

    if (output->batch_size > (size_t) -1 / sizeof *output->batch_entries ||
        output->batch_size > ((size_t) -1 / sizeof *buf) / pixels_per_entry) {
        fprintf(stderr, "mnist_init: batch is too large\n");
        goto fail;
    }

    output->batch_entries = malloc(sizeof *output->batch_entries * output->batch_size);
    if (!output->batch_entries) {
        perror("malloc");
        goto fail;
    }

    /*
     * One contiguous buffer holds the pixels of every entry in the batch;
     * each entry points at its own slice. mnist_free() relies on entry 0
     * pointing at the start of this buffer.
     */
    buf = malloc(sizeof *buf * pixels_per_entry * output->batch_size);
    if (!buf) {
        perror("malloc");
        goto fail;
    }

    for (i = 0; i < output->batch_size; i++) {
        output->batch_entries[i].class = 0;
        output->batch_entries[i].pixels = buf + i * pixels_per_entry;
    }

    return 0;

fail:
    if (output->fimage)
        fclose(output->fimage);
    if (output->flabel)
        fclose(output->flabel);
    free(output->batch_entries);
    memset(output, 0, sizeof *output);
    return -1;
}
/*
 * Load the next batch of entries from the dataset files.
 *
 * Reading starts at dt->entries_read and wraps back to the first entry once
 * the end of the dataset has been reached. Each pixel byte is scaled to the
 * range [0, 1] by dividing by 255. When dt->transpose is set, the entries
 * loaded by this call are mirrored in place across the main diagonal (a true
 * transpose only for square images).
 *
 * Returns the number of entries actually loaded, which is smaller than
 * dt->batch_size when the dataset ends or a read fails, and 0 when dt is
 * NULL, not initialised, or the scratch buffer cannot be allocated.
 */
size_t mnist_load_batch(MnistDataset *dt)
{
    /* Bytes preceding the first record, as skipped by the seeks below. */
    enum { IMAGE_HEADER_SIZE = 16, LABEL_HEADER_SIZE = 8 };
    size_t loaded;
    size_t i, j;
    size_t x, y;
    size_t entry_size;
    unsigned char label;
    unsigned char *raw;
    double tmp;
    MnistEntry *entry;

    if (!dt || !dt->fimage || !dt->flabel || !dt->batch_entries)
        return 0;

    entry_size = (size_t) dt->width * dt->height;

    /* Heap scratch buffer: the size comes from the file header, so a
       stack VLA would be unbounded (and VLAs are unavailable on MSVC). */
    raw = malloc(entry_size + 1);
    if (!raw) {
        perror("malloc");
        return 0;
    }

    if (dt->entries_read >= dt->entries_count)
        dt->entries_read = 0;

    for (loaded = 0; loaded < dt->batch_size; loaded++, dt->entries_read++) {
        entry = &dt->batch_entries[loaded];
        if (dt->entries_read >= dt->entries_count)
            break;

        if (fseek(dt->fimage, (long) (entry_size * dt->entries_read + IMAGE_HEADER_SIZE), SEEK_SET) != 0 ||
            fseek(dt->flabel, (long) (dt->entries_read + LABEL_HEADER_SIZE), SEEK_SET) != 0) {
            perror("fseek");
            break;
        }

        /* Read the label */
        if (fread(&label, 1, 1, dt->flabel) != 1) {
            perror("fread");
            break;
        }
        entry->class = (int) label;

        /* Read the image */
        if (fread(raw, entry_size, 1, dt->fimage) != 1) {
            perror("fread");
            break;
        }
        for (j = 0; j < entry_size; j++)
            entry->pixels[j] = ((double) raw[j]) / 255.;
    }

    free(raw);

    if (dt->transpose) {
        /* Only touch entries loaded by this call; a separate index keeps
           the loaded count intact for the return value. */
        for (i = 0; i < loaded; i++) {
            entry = &dt->batch_entries[i];
            for (x = 0; x < dt->width; x++) {
                for (y = x + 1; y < dt->height; y++) {
                    /* Swap pixels (x, y) and (y, x). */
                    tmp = entry->pixels[x + dt->width * y];
                    entry->pixels[x + dt->width * y] = entry->pixels[y + dt->width * x];
                    entry->pixels[y + dt->width * x] = tmp;
                }
            }
        }
    }

    return loaded;
}
/*
 * Release every resource held by a dataset and zero the structure.
 *
 * Safe to call with NULL, with a zeroed structure, and more than once on the
 * same dataset: members that are already NULL are skipped.
 */
void mnist_free(MnistDataset *dt)
{
    if (!dt)
        return;

    if (dt->fimage)
        fclose(dt->fimage);
    if (dt->flabel)
        fclose(dt->flabel);

    if (dt->batch_entries) {
        /*
         * The pixels of all images live in one contiguous buffer whose
         * start address is the pixels pointer of the first MnistEntry,
         * so freeing that single pointer releases every image.
         */
        if (dt->batch_size > 0)
            free(dt->batch_entries[0].pixels);
        free(dt->batch_entries);
    }

    /* Zeroing leaves no dangling pointers behind. */
    memset(dt, 0, sizeof *dt);
}

48
mnist_db.h Normal file
View File

@ -0,0 +1,48 @@
#ifndef _MNIST_DB_H_
#define _MNIST_DB_H_
#include <stddef.h>
#include <stdio.h>
/*
 * Reverse the byte order of the low 32 bits of `a`.
 *
 * The arithmetic is done in unsigned long so that an argument of a signed
 * type cannot overflow in the `<< 24` step; the result has type
 * unsigned long. Bits above the low 32 are discarded.
 * NOTE: `a` is evaluated four times, so it must not have side effects.
 */
#define CHANGE_ENDIANNESS(a) \
( \
((((unsigned long) (a)) >> 0) & 0xFFul) << 24 | \
((((unsigned long) (a)) >> 8) & 0xFFul) << 16 | \
((((unsigned long) (a)) >> 16) & 0xFFul) << 8 | \
((((unsigned long) (a)) >> 24) & 0xFFul) << 0 \
)
/* One labelled image of a dataset batch. */
typedef struct MnistEntry {
    /* Label byte read from the label file. */
    int class;
    /* Pixel values of the image, scaled to [0, 1] when loaded; points into
       a buffer shared by every entry of the batch. */
    double *pixels;
} MnistEntry;
/* State of an opened dataset: file handles, header fields and one batch. */
typedef struct MnistDataset {
    /* Image dimensions, read from the image file header. */
    unsigned int width;
    unsigned int height;
    /* Non-zero to mirror each loaded image across its main diagonal. */
    int transpose;
    /* Number of entries announced by the image file header. */
    size_t entries_count;
    /* Index of the next entry to read from the files. */
    size_t entries_read;
    /* Open handles on the image and label files. */
    FILE *fimage;
    FILE *flabel;
    /* Capacity of batch_entries, in entries. */
    size_t batch_size;
    /* Storage for the entries of the current batch. */
    MnistEntry *batch_entries;
} MnistDataset;
/* Number of output classes; mnist.c uses it as the size of the target
   vector and as the network's output count. */
#define CLASS_COUNT 10
/* Open a dataset from an image file and a label file, read the entry count
   and image dimensions from the image file header, and allocate storage for
   one batch. A batch_size of 0 selects the whole dataset.
   Returns 0 on success and a non-zero value in case of a failure. */
int mnist_init(MnistDataset *output, const char *images_file, const char *labels_file, int transpose, size_t batch_size);
/* Fill dt->batch_entries with up to dt->batch_size entries read from the
   dataset files, starting at dt->entries_read, and return an entry count. */
size_t mnist_load_batch(MnistDataset *dt);
/* Release the dataset from memory: closes both files, frees the batch
   storage and zeroes the structure. */
void mnist_free(MnistDataset *dt);
#endif /* _MNIST_DB_H_ */