This repository demonstrates the application of Kolmogorov-Arnold Networks (KANs) to the classification of handwritten digits from the MNIST dataset. The original implementation of KAN is available here.
The MNIST dataset is a benchmark dataset in the machine learning community, containing 70,000 images of handwritten digits (0-9). This project explores how Kolmogorov-Arnold Networks (KANs) can be applied to build an efficient and effective classification model.
- Custom KAN Implementation: We built the KAN architecture from scratch.
- Optimized Training: PyTorch is used to train the model efficiently, balancing speed and accuracy.
- Interactive Visualizations: Explore the learned representations and performance metrics through provided visual tools.
- Dataset Preparation: The MNIST dataset is automatically downloaded via PyTorch's `torchvision.datasets`. It is preprocessed and normalized to meet the input requirements for KAN.
- KAN Implementation: The KAN architecture is implemented following the Kolmogorov-Arnold representation theorem, which decomposes complex functions into simpler components for efficient computation.
- Training the Model: The model is trained using PyTorch, minimizing cross-entropy loss while optimizing the KAN layer weights. Hyperparameters such as learning rate, batch size, and number of epochs can be configured.
- Evaluation: After training, the model is evaluated on the test set for accuracy in classifying digits. Visualizations are provided for better understanding.
- Saving the Model: The trained model is saved to `kan_mnist_model.pth` for future inference or fine-tuning.
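The training step above can be sketched as follows. This is a minimal illustration only: a plain `nn.Linear` stands in for the KAN model, and the hyperparameters are assumptions, not the repository's actual training code.

```python
import torch
import torch.nn as nn

# Stand-in model: a single linear layer instead of the actual KAN.
model = nn.Linear(28 * 28, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Dummy batch shaped like flattened MNIST images (batch of 64).
images = torch.randn(64, 28 * 28)
labels = torch.randint(0, 10, (64,))

# One optimization step: forward pass, loss, backward pass, update.
optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
```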
The MNIST dataset is automatically downloaded using PyTorch's `torchvision.datasets` package. No manual download is required; the dataset is stored in the `data/` folder when the code is run.
```python
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
```
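Once the datasets are created, they are typically wrapped in `DataLoader`s for batched, shuffled training. The sketch below uses random tensors with MNIST-like shapes in place of the real datasets, so it runs without downloading anything; the batch size is an illustrative value.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the MNIST datasets above: random tensors with the
# same shapes (1x28x28 images, integer labels 0-9).
images = torch.randn(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))
dataset = TensorDataset(images, labels)

# Batched, shuffled iteration as used during training.
loader = DataLoader(dataset, batch_size=64, shuffle=True)
batch_images, batch_labels = next(iter(loader))
print(batch_images.shape)  # torch.Size([64, 1, 28, 28])
```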
```
KAN-MNIST/
├── handwrittenExamples/   # Example images for prediction
│   └── m8.png             # Sample image for testing
├── efficient_kan.py       # Optimized KAN model
├── KANguess.py            # Prediction script for new images
├── train.py               # Training script for KAN
├── requirements.txt       # Required packages
└── README.md              # Documentation
```
Ensure the following are installed:
- Python 3.8+
- PyTorch
- Git
1. Clone the Repository:

   ```bash
   git clone https://github.com/CheetahCodes21/KAN-MNIST.git
   cd KAN-MNIST
   ```

2. Create a Virtual Environment:

   ```bash
   python3 -m venv kan_env
   source kan_env/bin/activate   # On Mac/Linux
   kan_env\Scripts\activate      # On Windows
   ```

3. Install Dependencies:

   ```bash
   pip install -r requirements.txt
   ```

4. Run the Training Script:

   ```bash
   python kan_mnist.py
   ```

5. Evaluate the Model:

   ```bash
   python kan_mnist.py --test
   ```

6. Use the Pretrained Model:

   ```bash
   python kan_mnist.py --load_model
   ```
The KAN model is defined in `efficient_kan.py`. The architecture used in this project is a custom neural network with the following components:

- Custom linear layers (`KANLinear`) that incorporate B-splines for interpolation.
- The `KAN` model is initialized with a list of hidden layer sizes, and it constructs a series of `KANLinear` layers based on this list.

```python
model = KAN([28 * 28, 64, 10])  # input size 28*28, one hidden layer of size 64, output size 10
```
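To give a rough, self-contained feel for the idea behind `KANLinear` (this is not the repository's actual implementation), each input-output edge can carry a learnable univariate function built from a fixed basis. The toy layer below uses Gaussian bumps in place of B-splines; the class name and all parameters are illustrative assumptions.

```python
import torch
import torch.nn as nn

class ToyKANLinear(nn.Module):
    """Simplified KAN-style layer: each edge applies a learnable
    univariate function, here a combination of Gaussian bumps
    standing in for the B-spline basis."""
    def __init__(self, in_features, out_features, num_basis=8):
        super().__init__()
        # Fixed basis centers spread over the expected input range.
        self.register_buffer("centers", torch.linspace(-2, 2, num_basis))
        # One coefficient per (output, input, basis function).
        self.coeffs = nn.Parameter(
            torch.randn(out_features, in_features, num_basis) * 0.1)

    def forward(self, x):
        # x: (batch, in) -> basis values: (batch, in, num_basis)
        basis = torch.exp(-((x.unsqueeze(-1) - self.centers) ** 2))
        # Sum the learnable edge functions into each output unit.
        return torch.einsum("bik,oik->bo", basis, self.coeffs)

layer = ToyKANLinear(28 * 28, 64)
out = layer(torch.randn(5, 28 * 28))
print(out.shape)  # torch.Size([5, 64])
```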
The `KANguess.py` script is used to make predictions on new images. It loads the trained model, preprocesses the input image, and outputs the predicted digit. To use it:

1. Place your image in the `handwrittenExamples` directory.
2. Update the `image_path` variable in `KANguess.py` to point to your image.
3. Run the script:

```bash
python KANguess.py
```
Make sure the input image is a grayscale image of a handwritten digit (similar to the MNIST dataset format) for accurate results.
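The preprocessing a prediction script of this kind typically performs can be sketched as follows. This is a NumPy-only illustration with a random array standing in for a loaded image; the actual loading and preprocessing in `KANguess.py` may differ.

```python
import numpy as np

# Stand-in for a loaded grayscale image: pixel values 0-255, 28x28.
pixel_array = np.random.randint(0, 256, size=(28, 28)).astype(np.float32)

# Scale to [0, 1], then apply the MNIST normalization used in training.
scaled = pixel_array / 255.0
normalized = (scaled - 0.1307) / 0.3081

# Flatten to the (1, 784) shape a model over 28*28 inputs expects.
model_input = normalized.reshape(1, 28 * 28)
print(model_input.shape)  # (1, 784)
```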