This repository implements a simple Artificial Neural Network (ANN) from scratch using only NumPy.

ProfessorNova/ANN-Scratch


ANN-Scratch

This repository implements a simple Artificial Neural Network (ANN) from scratch using only NumPy. It was tested on the MNIST dataset and reached an accuracy of around 95% after 50 epochs (the hyperparameters were not tuned, so there is room for improvement).
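Accuracy here means the fraction of samples whose predicted class matches the true label. The following is a minimal sketch of that metric. It assumes the network outputs one row of class scores per sample and that labels are one-hot encoded, as the data loader described under Components does. The one_hot and accuracy helpers are hypothetical and are not taken from the repository.

```python
import numpy as np


def one_hot(labels, num_classes=10):
    """Convert integer labels of shape (n,) to one-hot rows of shape (n, num_classes)."""
    encoded = np.zeros((labels.shape[0], num_classes))
    encoded[np.arange(labels.shape[0]), labels] = 1.0
    return encoded


def accuracy(predictions, targets):
    """Fraction of rows whose highest-scoring class equals the one-hot target class."""
    return float(np.mean(np.argmax(predictions, axis=1) == np.argmax(targets, axis=1)))


# Example: 4 samples, the last one is misclassified as class 0.
labels = np.array([3, 1, 4, 1])
targets = one_hot(labels)
predictions = np.full((4, 10), 0.01)
predictions[np.arange(4), [3, 1, 4, 0]] = 0.9
print(accuracy(predictions, targets))  # 0.75
```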


Getting Started

Installation

The only required dependency is NumPy; Matplotlib is needed only if you want to plot the data. Both can be installed with pip:

pip install numpy matplotlib

Then clone the repository:

git clone https://github.com/ProfessorNova/ANN-Scratch.git
cd ANN-Scratch

The code was tested on Python 3.10.11.

Usage

The notebook train.ipynb demonstrates the functionality with visualisations. Alternatively, run train.py with the following command; it only prints the loss and accuracy to the console:

python train.py

Components

The repository is divided into the following components:

  • lib/activations_functions.py: Contains the activation functions and their derivatives. The following activation functions are implemented:

    • Sigmoid
    • Linear
    • Softmax
    • ReLU
  • lib/neural_layer.py: Contains the NeuralLayer class, which represents a single layer of the network. It provides forward and backward methods as well as a method to update the layer's weights and biases.

  • lib/neural_network.py: Contains the NeuralNetwork class, which represents the full network. It implements backpropagation and stochastic gradient descent, and provides methods to save and load a model.

  • lib/data_loader.py: Contains a function that loads the provided mnist_train.csv and mnist_test.csv files. It also preprocesses the data automatically by normalizing it and converting the labels to one-hot encoding.
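The activation functions listed above and their derivatives can be sketched in NumPy as follows. The function names are hypothetical and may not match lib/activations_functions.py. Softmax is shown without a standalone derivative. When softmax is paired with a cross-entropy loss, the gradient with respect to its input is commonly simplified to (softmax output minus one-hot target). Whether this repository takes that approach is an assumption.

```python
import numpy as np

# Hypothetical names; the repository's own functions may be named or structured differently.


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sigmoid_derivative(x):
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
    s = sigmoid(x)
    return s * (1.0 - s)


def linear(x):
    return x


def linear_derivative(x):
    return np.ones_like(x)


def relu(x):
    return np.maximum(0.0, x)


def relu_derivative(x):
    # Convention: the derivative at exactly 0 is taken as 0.
    return (x > 0).astype(float)


def softmax(x):
    # Rows are samples; subtracting the row maximum avoids overflow in np.exp.
    shifted = x - np.max(x, axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=1, keepdims=True)


z = np.array([[-1.0, 0.0, 2.0]])
print(relu(z))                                # [[0. 0. 2.]]
print(relu_derivative(z))                     # [[0. 0. 1.]]
print(softmax(np.array([[1000.0, 1000.0]])))  # [[0.5 0.5]]
```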

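The layer and network components described above fit together roughly as in the following minimal sketch. It includes a hypothetical DenseLayer with forward, backward and update methods, backpropagation through a list of layers, and mini-batch stochastic gradient descent. The sketch assumes sigmoid and linear activations, a mean squared error loss and a toy regression target instead of MNIST. It is an illustration under those assumptions, not the repository's NeuralLayer or NeuralNetwork implementation.

```python
import numpy as np


class DenseLayer:
    """Hypothetical fully connected layer computing a = f(x @ W + b)."""

    def __init__(self, n_in, n_out, activation="sigmoid", rng=None):
        rng = np.random.default_rng(0) if rng is None else rng
        # Normal weights scaled by 1/sqrt(n_in): an assumed initialisation.
        self.W = rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out))
        self.b = np.zeros(n_out)
        self.activation = activation  # "sigmoid" or "linear"

    def forward(self, x):
        self.x = x                    # cache the input for backward()
        z = x @ self.W + self.b       # pre-activation, shape (batch, n_out)
        self.a = 1.0 / (1.0 + np.exp(-z)) if self.activation == "sigmoid" else z
        return self.a

    def backward(self, grad_a):
        """Take per-sample dL/da, store batch-averaged dL/dW and dL/db, return dL/dx."""
        grad_z = grad_a * self.a * (1.0 - self.a) if self.activation == "sigmoid" else grad_a
        n = self.x.shape[0]
        self.dW = self.x.T @ grad_z / n
        self.db = grad_z.mean(axis=0)
        return grad_z @ self.W.T      # uses W before update() changes it

    def update(self, lr):
        # One gradient-descent step on this layer's parameters.
        self.W -= lr * self.dW
        self.b -= lr * self.db


def mse_loss(pred, target):
    # L = mean over samples of 0.5 * ||pred - target||^2
    return 0.5 * np.mean(np.sum((pred - target) ** 2, axis=1))


def predict(layers, x):
    out = x
    for layer in layers:
        out = layer.forward(out)
    return out


def train_step(layers, x, y, lr):
    out = predict(layers, x)          # forward pass
    grad = out - y                    # per-sample dL/d(output) for mse_loss
    for layer in reversed(layers):    # backpropagation
        grad = layer.backward(grad)
    for layer in layers:              # parameter update
        layer.update(lr)


rng = np.random.default_rng(42)
x = rng.uniform(0.0, 1.0, size=(32, 3))
y = 0.5 * x.sum(axis=1, keepdims=True)   # toy regression target
layers = [DenseLayer(3, 8, "sigmoid", rng), DenseLayer(8, 1, "linear", rng)]

initial_loss = mse_loss(predict(layers, x), y)
for epoch in range(100):
    order = rng.permutation(len(x))      # reshuffle every epoch
    for start in range(0, len(x), 8):    # mini-batches of 8 samples
        idx = order[start:start + 8]
        train_step(layers, x[idx], y[idx], lr=0.1)
final_loss = mse_loss(predict(layers, x), y)
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

Drawing mini-batches from a fresh random permutation each epoch is one common way to implement SGD. For classification on MNIST, the output layer would typically use softmax with a cross-entropy loss instead of the linear output and squared error used here.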