mirror of
https://github.com/codeplea/genann
synced 2024-11-21 22:11:34 +03:00
New example: Mnist
This example offers a simple yet practical application of genann: handwritten digit recognition on the MNIST dataset. Co-authored-by: Daniel Akbarinia <daniel.akbarinia@telecom-paris.fr> Co-authored-by: Antoine Heitzmann <antoine.heitzmann@telecom-paris.fr> Co-authored-by: Edouard Pompee <edouard.pompee@telecom-paris.fr>
This commit is contained in:
parent
4f72209510
commit
ae10110fb8
7
Makefile
7
Makefile
@ -1,7 +1,7 @@
|
||||
CFLAGS = -Wall -Wshadow -O3 -g -march=native
|
||||
CFLAGS = -Wall -Wshadow -O3 -I. -g -march=native
|
||||
LDLIBS = -lm
|
||||
|
||||
all: check example1 example2 example3 example4
|
||||
all: check example1 example2 example3 example4 mnist
|
||||
|
||||
sigmoid: CFLAGS += -Dgenann_act=genann_act_sigmoid_cached
|
||||
sigmoid: all
|
||||
@ -25,10 +25,11 @@ example3: example3.o genann.o
|
||||
|
||||
example4: example4.o genann.o
|
||||
|
||||
mnist: mnist.o mnist_db.o genann.o
|
||||
|
||||
clean:
|
||||
$(RM) *.o
|
||||
$(RM) test example1 example2 example3 example4 *.exe
|
||||
$(RM) test example1 example2 example3 example4 mnist *.exe
|
||||
$(RM) persist.txt
|
||||
|
||||
.PHONY: sigmoid threshold linear clean
|
||||
|
93
mnist.c
Normal file
93
mnist.c
Normal file
@ -0,0 +1,93 @@
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <strings.h>
|
||||
|
||||
#include "genann.h"
|
||||
#include "mnist_db.h"
|
||||
|
||||
#define CLASS_COUNT 10
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
size_t i;
|
||||
int j;
|
||||
double output[CLASS_COUNT];
|
||||
MnistDataset training, tests;
|
||||
|
||||
if(argc != 5) {
|
||||
printf("./mnist [NUMBER OF HIDDEN LAYERS] [NEURON PER HIDDEN LAYERS] [TRAINING ITERATION COUNT] [OUTPUT FILE]");
|
||||
return 1;
|
||||
}
|
||||
|
||||
if(mnist_init(&training,
|
||||
"mnist_data/train-images-idx3-ubyte",
|
||||
"mnist_data/train-labels-idx1-ubyte",
|
||||
0, 0
|
||||
))
|
||||
return 1;
|
||||
|
||||
if(mnist_load_batch(&training) != training.batch_size) {
|
||||
mnist_free(&training);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if(mnist_init(&tests,
|
||||
"mnist_data/t10k-images-idx3-ubyte",
|
||||
"mnist_data/t10k-labels-idx1-ubyte",
|
||||
0, 0
|
||||
)) {
|
||||
mnist_free(&training);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if(mnist_load_batch(&tests) != tests.batch_size) {
|
||||
mnist_free(&tests);
|
||||
mnist_free(&training);
|
||||
return 1;
|
||||
}
|
||||
|
||||
assert(training.width == tests.width);
|
||||
assert(training.height == tests.height);
|
||||
assert(training.width != 0);
|
||||
assert(training.height != 0);
|
||||
|
||||
genann *ann = genann_init(training.width * training.height,
|
||||
atoi(argv[1]),
|
||||
atoi(argv[2]),
|
||||
CLASS_COUNT
|
||||
);
|
||||
assert(ann != NULL);
|
||||
|
||||
memset(output, 0, CLASS_COUNT * sizeof(double));
|
||||
|
||||
for(j = 0; j < atoi(argv[3]); j ++) {
|
||||
for (i = 0; i < training.batch_size; ++i) {
|
||||
printf("[Training number %d]: %zd%%\r",
|
||||
j+1,
|
||||
(100 * (i+1)) / training.batch_size
|
||||
);
|
||||
|
||||
output[training.batch_entries[i].class] = 1;
|
||||
genann_train(ann, training.batch_entries[i].pixels, output, 0.25);
|
||||
output[training.batch_entries[i].class] = 0;
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
FILE *output_file = fopen(argv[4], "w");
|
||||
if(output_file) {
|
||||
genann_write(ann, output_file);
|
||||
fclose(output_file);
|
||||
} else
|
||||
perror("fopen");
|
||||
|
||||
genann_free(ann);
|
||||
|
||||
mnist_free(&training);
|
||||
mnist_free(&tests);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
176
mnist_db.c
Normal file
176
mnist_db.c
Normal file
@ -0,0 +1,176 @@
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "mnist_db.h"
|
||||
#include "utils.h"
|
||||
|
||||
/* Opens PATH for binary reading.
   Returns the stream, or NULL after reporting the failure on stderr. */
static FILE *mnist_open_binary(const char *path)
{
    FILE *stream = NULL;

#ifndef _MSC_VER
    /* "rb": the files are raw binary; text mode may alter bytes on some platforms. */
    stream = fopen(path, "rb");
    if(!stream)
        perror("fopen");
#else
    if(fopen_s(&stream, path, "rb")) {
        fprintf(stderr, "mnist_init: cannot open %s\n", path);
        stream = NULL;
    }
#endif

    return stream;
}

/* Reads one 32-bit big-endian unsigned integer from STREAM into *VALUE.
   The bytes are assembled explicitly, so the result does not depend on the
   byte order of the host. Returns 0 on success and -1 on a short read. */
static int mnist_read_be32(FILE *stream, unsigned long *value)
{
    unsigned char raw[4];

    if(fread(raw, sizeof raw, 1, stream) != 1)
        return -1;

    *value = ((unsigned long) raw[0] << 24)
           | ((unsigned long) raw[1] << 16)
           | ((unsigned long) raw[2] << 8)
           |  (unsigned long) raw[3];
    return 0;
}

/* Opens an image file and a label file and prepares OUTPUT for batch loading.

   output       dataset to initialise; it is zeroed first.
   images_file  path of the image file.
   labels_file  path of the label file.
   transpose    stored in output->transpose.
   batch_size   number of entries per batch; 0 means "as many entries as the
                image file header announces".

   The pixel storage of the whole batch is one contiguous array of doubles;
   batch_entries[i].pixels points at the i-th slice of it.

   Returns 0 on success. Returns -1 on failure; in that case every resource
   acquired here has been released and *output is all zeroes, so the caller
   has nothing to release.

   NOTE(review): the IDX image header presumably stores the row count before
   the column count, which would make `width` hold the rows and `height` the
   columns. This is harmless for square images; confirm before using
   non-square data. */
int mnist_init(MnistDataset *output,
               const char *images_file,
               const char *labels_file,
               int transpose,
               size_t batch_size
)
{
    const size_t size_max = (size_t) -1;
    unsigned long entries = 0, width = 0, height = 0;
    size_t pixels_per_entry;
    size_t i;
    double *buf;

    if(!output)
        return -1;

    memset(output, 0, sizeof(MnistDataset));
    output->transpose = transpose;

    if(!images_file || !labels_file)
        return -1;

    output->fimage = mnist_open_binary(images_file);
    if(!output->fimage)
        goto fail;

    output->flabel = mnist_open_binary(labels_file);
    if(!output->flabel)
        goto fail;

    /* Skip the first 4 bytes of the image header; the three fields that
       follow are the entry count and the two image dimensions. */
    if(fseek(output->fimage, 4, SEEK_SET) != 0) {
        perror("fseek");
        goto fail;
    }

    if(mnist_read_be32(output->fimage, &entries)
       || mnist_read_be32(output->fimage, &width)
       || mnist_read_be32(output->fimage, &height)) {
        fprintf(stderr, "mnist_init: truncated image file header\n");
        goto fail;
    }

    output->entries_count = (size_t) entries;
    output->width = (unsigned int) width;
    output->height = (unsigned int) height;

    output->batch_size = batch_size != 0 ? batch_size : output->entries_count;

    printf("Batch size: %zu; Width: %u; Height: %u\n",
           output->batch_size, output->width, output->height);

    /* An empty batch or zero-sized images would lead to zero-byte
       allocations, whose result is implementation-defined. */
    if(output->batch_size == 0 || output->width == 0 || output->height == 0) {
        fprintf(stderr, "mnist_init: empty dataset or zero-sized images\n");
        goto fail;
    }

    /* Reject sizes whose products would wrap around size_t. */
    if(output->height > size_max / output->width) {
        fprintf(stderr, "mnist_init: image dimensions are too large\n");
        goto fail;
    }
    pixels_per_entry = (size_t) output->width * output->height;

    if(output->batch_size > size_max / sizeof(MnistEntry)
       || output->batch_size > size_max / sizeof(double) / pixels_per_entry) {
        fprintf(stderr, "mnist_init: batch is too large\n");
        goto fail;
    }

    output->batch_entries = malloc(sizeof *output->batch_entries * output->batch_size);
    if(!output->batch_entries) {
        perror("malloc");
        goto fail;
    }

    buf = malloc(sizeof *buf * pixels_per_entry * output->batch_size);
    if(!buf) {
        perror("malloc");
        goto fail;
    }

    /* Hand each entry its slice of the shared pixel buffer. */
    for(i = 0; i < output->batch_size; i ++) {
        output->batch_entries[i].class = 0;
        output->batch_entries[i].pixels = buf + i * pixels_per_entry;
    }

    return 0;

fail:
    free(output->batch_entries);
    if(output->flabel)
        fclose(output->flabel);
    if(output->fimage)
        fclose(output->fimage);
    memset(output, 0, sizeof(MnistDataset));
    return -1;
}
|
||||
|
||||
/* Loads the next batch of entries from the dataset files into dt->batch_entries.

   Reading resumes at dt->entries_read and wraps back to the first entry once
   the end of the dataset has been reached by a previous call. Every pixel
   byte is scaled to a double in [0, 1] by dividing by 255. When dt->transpose
   is non-zero, pixels (x, y) and (y, x) of every loaded image are swapped.

   Returns the number of entries actually loaded by this call. This is smaller
   than dt->batch_size when the dataset ends or a read fails, and 0 when DT is
   not usable or the scratch buffer cannot be allocated.

   NOTE(review): the in-place swap is a true transpose only for square images. */
size_t mnist_load_batch(MnistDataset *dt)
{
    size_t loaded, i, j;
    size_t x, y;
    size_t entry_size;
    double tmp;
    unsigned char label;
    unsigned char *raw;
    MnistEntry *entry;

    if(!dt || !dt->batch_entries || !dt->fimage || !dt->flabel)
        return 0;

    entry_size = (size_t) dt->width * dt->height;

    /* Heap scratch buffer for one raw image: its size comes from the file, so
       it must not be placed on the stack. */
    raw = malloc(entry_size ? entry_size : 1);
    if(!raw) {
        perror("malloc");
        return 0;
    }

    if(dt->entries_read >= dt->entries_count)
        dt->entries_read = 0;

    for(loaded = 0; loaded < dt->batch_size; loaded ++, dt->entries_read ++) {
        if(dt->entries_read >= dt->entries_count)
            break;

        entry = &dt->batch_entries[loaded];

        /* Image data starts after a 16-byte header, labels after an 8-byte one. */
        if(fseek(dt->fimage, (long) (entry_size * dt->entries_read + 16), SEEK_SET) != 0
           || fseek(dt->flabel, (long) (dt->entries_read + 8), SEEK_SET) != 0) {
            perror("fseek");
            break;
        }

        /* Read the label */
        if(fread(&label, 1, 1, dt->flabel) != 1) {
            perror("fread");
            break;
        }
        entry->class = (int) label;

        /* Read the image */
        if(fread(raw, entry_size, 1, dt->fimage) != 1) {
            perror("fread");
            break;
        }

        for(j = 0; j < entry_size; j ++)
            entry->pixels[j] = ((double) raw[j]) / 255.;
    }

    free(raw);

    /* Only the entries loaded by this call are transposed, and the loop uses
       its own index so that the returned count is not overwritten. */
    if(dt->transpose) {
        for(i = 0; i < loaded; i ++) {
            entry = &dt->batch_entries[i];
            for(x = 0; x < dt->width; x ++) {
                for(y = x+1; y < dt->height; y ++) {
                    /* Swap entry->pixels[x + dt->width * y] and entry->pixels[y + dt->width * x] */
                    tmp = entry->pixels[x + dt->width * y];
                    entry->pixels[x + dt->width * y] = entry->pixels[y + dt->width * x];
                    entry->pixels[y + dt->width * x] = tmp;
                }
            }
        }
    }

    return loaded;
}
|
||||
|
||||
/* Releases everything owned by DT (both file streams and the batch storage)
   and zeroes the structure. A NULL pointer is ignored, and members that are
   NULL are skipped, so calling it on an all-zero dataset is harmless. */
void mnist_free(MnistDataset *dt)
{
    if(!dt)
        return;

    if(dt->fimage)
        fclose(dt->fimage);
    if(dt->flabel)
        fclose(dt->flabel);

    if(dt->batch_entries) {
        /*
          The pixels of all images live in one contiguous buffer whose start
          address is the pixel pointer of the first MnistEntry, so freeing
          that single pointer releases the pixels of every entry.
        */
        if(dt->batch_size > 0)
            free(dt->batch_entries[0].pixels);
        free(dt->batch_entries);
    }

    memset(dt, 0, sizeof(MnistDataset));
}
|
48
mnist_db.h
Normal file
48
mnist_db.h
Normal file
@ -0,0 +1,48 @@
|
||||
#ifndef _MNIST_DB_H_
|
||||
#define _MNIST_DB_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
/* Reverses the byte order of the low 32 bits of `a`; the result has type
   unsigned long and never exceeds 0xFFFFFFFF.
   The operand is widened to unsigned long before any shift, so moving a byte
   into bits 24..31 is well defined even when `a` has a signed type.
   `a` is evaluated four times: do not pass an expression with side effects. */
#define CHANGE_ENDIANNESS(a)                               \
    (                                                      \
        ((((unsigned long) (a)) >>  0) & 0xFFUL) << 24 |   \
        ((((unsigned long) (a)) >>  8) & 0xFFUL) << 16 |   \
        ((((unsigned long) (a)) >> 16) & 0xFFUL) <<  8 |   \
        ((((unsigned long) (a)) >> 24) & 0xFFUL) <<  0     \
    )
|
||||
|
||||
/* One labelled image of a batch. */
typedef struct MnistEntry {
    int class;      /* label of the image, as read from the label file */
    double *pixels; /* pixel values of the image; the storage is not owned per entry */
} MnistEntry;
|
||||
|
||||
typedef struct MnistDataset MnistDataset;
|
||||
struct MnistDataset {
|
||||
unsigned int width;
|
||||
unsigned int height;
|
||||
int transpose;
|
||||
|
||||
size_t entries_count;
|
||||
size_t entries_read;
|
||||
|
||||
FILE *fimage;
|
||||
FILE *flabel;
|
||||
|
||||
size_t batch_size;
|
||||
MnistEntry *batch_entries;
|
||||
};
|
||||
|
||||
#define CLASS_COUNT 10
|
||||
|
||||
/* Read a dataset from .
|
||||
Returns -1 in case of a failure, and 0 otherwise. */
|
||||
int mnist_init(MnistDataset *output, const char *images_file, const char *labels_file, int transpose, size_t batch_size);
|
||||
|
||||
size_t mnist_load_batch(MnistDataset *dt);
|
||||
|
||||
/* Libère la la base de donnée de la mémoire. */
|
||||
void mnist_free(MnistDataset *dt);
|
||||
|
||||
#endif /* _MNIST_DB_H_ */
|
Loading…
Reference in New Issue
Block a user