KAN to classify handwritten digits from the MNIST dataset, providing efficient predictions and automated data handling.

CheetahCodes21/KAN-MNIST

KAN-MNIST: Kolmogorov-Arnold Networks for MNIST Classification

This repository demonstrates the application of Kolmogorov-Arnold Networks (KAN) to the classification of handwritten digits from the MNIST dataset. The original implementation of KAN is available here.

Overview

The MNIST dataset is a standard benchmark in the machine learning community. It contains 70,000 grayscale images of handwritten digits (0-9), each 28x28 pixels, split into 60,000 training images and 10,000 test images. This project explores how Kolmogorov-Arnold Networks (KANs) can be applied to build an efficient and effective classification model for it.

Key Features

  • Custom KAN Implementation: We built the KAN architecture from scratch.
  • Optimized Training: Training is implemented in PyTorch, with configurable hyperparameters for trading off speed against accuracy.
  • Interactive Visualizations: Explore the learned representations and performance metrics through provided visual tools.

Approach

  1. Dataset Preparation
    The MNIST dataset is automatically downloaded via PyTorch's torchvision.datasets. It is preprocessed and normalized to meet the input requirements for KAN.

  2. KAN Implementation
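    The KAN architecture is implemented following the Kolmogorov-Arnold representation theorem, which states that any continuous multivariate function on a bounded domain can be written as a finite composition of continuous univariate functions and addition. KAN layers therefore learn univariate functions instead of applying fixed activations.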

  3. Training the Model
    The model is trained using PyTorch, minimizing cross-entropy loss while optimizing KAN layer weights. Hyperparameters such as learning rate, batch size, and epochs can be configured.

  4. Evaluation
    After training, the model is evaluated on the test set for accuracy in classifying digits. Visualizations are provided for better understanding.

  5. Saving the Model
    The trained model is saved to kan_mnist_model.pth for future inference or fine-tuning.
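
For reference, the theorem mentioned in step 2 can be stated as follows. This is the standard textbook form, not something taken from this repository's code; the symbols $\Phi_q$ and $\phi_{q,p}$ are the conventional names for the outer and inner univariate functions, and $f$ is assumed to be continuous on $[0,1]^n$.

```latex
f(x_1, \dots, x_n) \;=\; \sum_{q=0}^{2n} \Phi_q\!\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)
```

A KAN relaxes this fixed two-level form into stacked layers of arbitrary width whose univariate functions are learned. In this project they are parameterized with B-splines, as described under Model Architecture.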

Dataset

The MNIST dataset is automatically downloaded using PyTorch's torchvision.datasets package. No manual download is required; the dataset is stored in the data/ folder upon running the code.

from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
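
To make the effect of the transform concrete, here is a dependency-free sketch of the arithmetic it applies to a single pixel. The constants are copied from the snippet above; the helper name normalize_pixel is made up for illustration and is not part of torchvision or this repository.

```python
# Constants copied from the transform above.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize_pixel(raw: int) -> float:
    """Map a raw 0-255 grayscale value to the value the model receives."""
    scaled = raw / 255.0                      # ToTensor scales 8-bit pixels to [0, 1]
    return (scaled - MNIST_MEAN) / MNIST_STD  # Normalize subtracts the mean, divides by the std


background = normalize_pixel(0)    # black background pixel
stroke = normalize_pixel(255)      # brightest stroke pixel
print(f"background={background:.4f}, stroke={stroke:.4f}")
```

Background pixels end up at roughly -0.42 and the brightest stroke pixels at roughly 2.82, so any image passed to the model at prediction time should go through the same two steps.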

Project Structure

KAN-MNIST/
├── handwrittenExamples/   # Example images for prediction
│   └── m8.png             # Sample image for testing
├── efficient_kan.py       # Optimized KAN model
├── KANguess.py            # Prediction script for new images
├── train.py               # Training script for KAN
├── requirements.txt       # Required packages
└── README.md              # Documentation

Running the Project Locally

Prerequisites

Ensure the following are installed:

  • Python 3.8+
  • PyTorch
  • Git

Setup Instructions

  1. Clone the Repository:

    git clone https://github.com/CheetahCodes21/KAN-MNIST.git
    cd KAN-MNIST
  2. Create a Virtual Environment:

    python3 -m venv kan_env
    source kan_env/bin/activate  # On Mac/Linux
    kan_env\Scripts\activate     # On Windows
  3. Install Dependencies:

    pip install -r requirements.txt
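  4. Run the Training Script:

    python kan_mnist.py
  5. Evaluate the Model:

    python kan_mnist.py --test
  6. Use the Pretrained Model:

    python kan_mnist.py --load_model

    Note: kan_mnist.py does not appear in the Project Structure listing above, which names train.py as the training script. If kan_mnist.py is missing from your checkout, run train.py instead and check it for the equivalent evaluation and model-loading options.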
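
A quick sanity check for a training run: the Approach section says the model minimizes cross-entropy loss, and a model that has learned nothing assigns equal probability to all ten digits. The sketch below computes that loss by hand with only the standard library. The function cross_entropy is a hypothetical helper for illustration, not the PyTorch loss the training script uses.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    peak = max(logits)  # subtract the max before exponentiating for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[target]


# Equal scores for all 10 digits: the model has no preference yet.
uniform_loss = cross_entropy([0.0] * 10, target=3)

# A large score on the correct digit: confident and right.
confident_loss = cross_entropy([0.0] * 3 + [8.0] + [0.0] * 6, target=3)

print(f"untrained: {uniform_loss:.4f}, confident and correct: {confident_loss:.4f}")
```

The untrained value is ln(10), about 2.30. If the loss reported during training starts far above this or never drops below it, the learning rate or the input normalization is worth checking first.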

Model Architecture

The KAN model is defined in efficient_kan.py. It is built from two pieces:

  • KANLinear: a custom linear layer that incorporates B-splines for interpolation.
  • KAN: the full model. It is initialized with a list of layer sizes (input, hidden, and output) and constructs a series of KANLinear layers based on this list.
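
To illustrate what "incorporating B-splines" means, here is a minimal pure-Python sketch of a single learnable univariate function built as a weighted sum of B-spline basis functions, using the Cox-de Boor recursion. It is not the code in efficient_kan.py: the grid range, grid size, degree, coefficients, and the names bspline_basis and edge_function are all assumptions chosen for illustration.

```python
def bspline_basis(i, k, knots, x):
    """Value at x of the i-th B-spline basis function of degree k (Cox-de Boor recursion)."""
    if k == 0:
        return 1.0 if knots[i] <= x < knots[i + 1] else 0.0
    left_den = knots[i + k] - knots[i]
    right_den = knots[i + k + 1] - knots[i + 1]
    left = 0.0
    if left_den != 0:
        left = (x - knots[i]) / left_den * bspline_basis(i, k - 1, knots, x)
    right = 0.0
    if right_den != 0:
        right = (knots[i + k + 1] - x) / right_den * bspline_basis(i + 1, k - 1, knots, x)
    return left + right


def edge_function(x, coeffs, knots, degree):
    """A learnable univariate function: weighted sum of B-spline basis functions."""
    return sum(c * bspline_basis(i, degree, knots, x) for i, c in enumerate(coeffs))


# Assumed setup: cubic splines on a uniform grid over [-1, 1],
# extended by `degree` intervals on each side so the basis covers the whole range.
degree, grid_size = 3, 5
step = 2.0 / grid_size
knots = [-1.0 + step * (j - degree) for j in range(grid_size + 2 * degree + 1)]
num_basis = len(knots) - degree - 1            # coefficients per univariate function
coeffs = [0.1 * i for i in range(num_basis)]   # stand-in for trained weights

basis_sum = sum(bspline_basis(i, degree, knots, 0.3) for i in range(num_basis))
y = edge_function(0.3, coeffs, knots, degree)
print(num_basis, round(basis_sum, 6), round(y, 6))
```

Inside the grid range the basis functions sum to 1, so the coefficients act as local control points: changing one coefficient reshapes the function only near its knots. Training adjusts these coefficients rather than a single scalar weight per connection.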

Example Initialization

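from efficient_kan import KAN  # the KAN class is defined in efficient_kan.py

# Input size 28*28 (one flattened MNIST image), one hidden layer of size 64,
# and output size 10 (one score per digit)
model = KAN([28 * 28, 64, 10])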
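
The way a size list maps to layers can be sketched without importing the model. This assumes, as is usual for this kind of constructor, that each consecutive pair of sizes defines one layer; the variable names are illustrative only.

```python
layer_sizes = [28 * 28, 64, 10]

# Each consecutive (in_features, out_features) pair corresponds to one layer.
layer_shapes = list(zip(layer_sizes[:-1], layer_sizes[1:]))

# In a KAN every input-output connection carries its own univariate function,
# so the number of learned functions equals the number of connections.
num_edges = sum(n_in * n_out for n_in, n_out in layer_shapes)

print(layer_shapes)  # [(784, 64), (64, 10)]
print(num_edges)     # 50816
```

Adding a hidden layer is then a matter of extending the list, for example [28 * 28, 64, 32, 10] for two hidden layers.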

Making Predictions

The KANguess.py script is used to make predictions on new images. It loads the trained model, preprocesses the input image, and outputs the predicted digit.

  1. Place your image in the handwrittenExamples directory.
  2. Update the image_path variable in KANguess.py to point to your image.
  3. Run the script:

python KANguess.py

Make sure the input image is a grayscale image of a handwritten digit (similar to the MNIST dataset format) for accurate results.
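
The preprocessing inside KANguess.py is not shown here, so the following is only a sketch of the numeric steps an input image would need in order to match the training transform from the Dataset section. It assumes the image has already been loaded and resized to a 28x28 grid of 0-255 grayscale values. The invert option reflects the fact that MNIST digits are light strokes on a dark background, while a scanned page is usually the opposite. The names preprocess and predict_digit are hypothetical.

```python
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # same constants as the training transform


def preprocess(pixels, invert=False):
    """Turn a 28x28 grid of 0-255 grayscale values into a flat, normalized 784-vector."""
    if len(pixels) != 28 or any(len(row) != 28 for row in pixels):
        raise ValueError("expected a 28x28 image; resize before calling")
    flat = []
    for row in pixels:
        for value in row:
            if invert:
                value = 255 - value  # dark ink on light paper -> light stroke on dark background
            flat.append((value / 255.0 - MNIST_MEAN) / MNIST_STD)
    return flat


def predict_digit(scores):
    """Index of the largest score; with 10 output scores this is the predicted digit."""
    return max(range(len(scores)), key=lambda i: scores[i])


# Synthetic example: a white page with one dark vertical stroke in column 14.
page = [[0 if col == 14 else 255 for col in range(28)] for _ in range(28)]
features = preprocess(page, invert=True)

# Made-up output scores standing in for what a trained model would return.
digit = predict_digit([0.1, 0.2, 3.5, 0.0, -1.0, 0.3, 0.2, 0.1, 0.0, 0.4])
print(len(features), digit)
```

If predictions on your own images look random, a missing inversion or a skipped normalization step is a common cause and worth ruling out first.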
