The official implementation of "Quantum Generative Models for Image Generation: Insights from MNIST and MedMNIST", a novel approach to image generation using quantum-enhanced diffusion models.

ChiShengChen/Quantum_MNIST_Diffusion

Quantum Diffusion Models

The official implementation of "Quantum Generative Models for Image Generation: Insights from MNIST and MedMNIST", a novel approach to image generation using quantum-enhanced diffusion models. This project implements diffusion models enhanced with quantum circuits for standard and medical image generation.

📋 Overview

This repository explores the integration of quantum computing into diffusion models for image generation. The implementation provides both classical and quantum-enhanced versions of diffusion models for MNIST and PathMNIST datasets.

Key features:

  • Quantum-enhanced attention mechanism for diffusion models
  • Classical vs quantum model comparison framework
  • Evaluation metrics (FID, SSIM) for generated images
  • Support for MNIST and PathMNIST medical datasets
  • Low-data training: the quantum diffusion model was trained on fewer than 100 images, which suggests an advantage for quantum layers in low-data regimes

🚀 Models

Diffusion Model Architecture

  • U-Net backbone with residual blocks and skip connections
  • Flexible channels for both MNIST (grayscale) and PathMNIST (RGB)
  • Timestep embedding using sinusoidal positional encoding
  • Cosine beta scheduling for improved sampling
  • Exponential Moving Average (EMA) for stable training
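
Three of the components listed above (cosine beta scheduling, sinusoidal timestep embedding, and EMA) can be sketched without any dependencies. This is a minimal illustration and not the repository's code: the function names, the offset `s=0.008`, the clipping value `0.999`, the frequency base `10000`, and the default decay are assumptions borrowed from common diffusion-model implementations.

```python
import math


def cosine_beta_schedule(timesteps, s=0.008, max_beta=0.999):
    """Return a list of betas derived from a squared-cosine alpha-bar curve.

    `s` and `max_beta` are assumed defaults, not values taken from the repository.
    """
    def alpha_bar(t):
        # Cumulative signal level after t steps.
        return math.cos((t / timesteps + s) / (1 + s) * math.pi / 2) ** 2

    betas = []
    for t in range(timesteps):
        beta = 1 - alpha_bar(t + 1) / alpha_bar(t)
        betas.append(min(max(beta, 0.0), max_beta))  # clip for numerical safety
    return betas


def sinusoidal_embedding(t, dim):
    """Embed a scalar timestep as [sin..., cos...] features (dim assumed even)."""
    half = dim // 2
    freqs = [math.exp(-math.log(10000.0) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


def ema_update(ema_params, params, decay=0.999):
    """One EMA step: move each averaged parameter slightly toward the live one."""
    return [decay * e + (1 - decay) * p for e, p in zip(ema_params, params)]
```

In a PyTorch training loop the same formulas would normally be applied to tensors, with the EMA step run once after each optimizer update.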

Quantum Enhancement

  • Hybrid quantum-classical model with quantum attention layers
  • Parameterized quantum circuits implemented using PennyLane
  • RY and RZ rotations with CNOT entanglement structure
  • Quantum feature re-weighting mechanism
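
The circuit pattern described above (RY and RZ rotations followed by CNOT entanglement, with measured expectation values used to re-weight features) can be illustrated with a small statevector simulation. The repository implements its circuits with PennyLane; the sketch below uses plain Python only so that it runs anywhere. The angle encoding, the layer layout, the chain of CNOTs, and the mapping of each Pauli-Z expectation to a gate `(1 + <Z>) / 2` are all assumptions made for illustration, and every function name is hypothetical.

```python
import cmath
import math


def ry(theta):
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -s], [s, c]]


def rz(theta):
    return [[cmath.exp(-1j * theta / 2), 0], [0, cmath.exp(1j * theta / 2)]]


def apply_1q(state, gate, q, n):
    """Apply a 2x2 gate to qubit q of an n-qubit state (qubit 0 = most significant bit)."""
    new = state[:]
    bit = 1 << (n - 1 - q)
    for i in range(len(state)):
        if i & bit == 0:
            a, b = state[i], state[i | bit]
            new[i] = gate[0][0] * a + gate[0][1] * b
            new[i | bit] = gate[1][0] * a + gate[1][1] * b
    return new


def apply_cnot(state, control, target, n):
    """Swap the target-0 and target-1 amplitudes wherever the control bit is 1."""
    new = state[:]
    cbit, tbit = 1 << (n - 1 - control), 1 << (n - 1 - target)
    for i in range(len(state)):
        if (i & cbit) and not (i & tbit):
            new[i], new[i | tbit] = state[i | tbit], state[i]
    return new


def expval_z(state, q, n):
    """Expectation value of Pauli-Z on qubit q."""
    bit = 1 << (n - 1 - q)
    return sum((-1 if i & bit else 1) * abs(a) ** 2 for i, a in enumerate(state))


def quantum_layer(features, weights):
    """Angle-encode features, apply variational layers, return <Z> per qubit.

    `weights` is a list of layers; each layer holds one (ry_angle, rz_angle)
    pair per qubit. This layout is an assumption, not the repository's.
    """
    n = len(features)
    state = [0j] * (2 ** n)
    state[0] = 1 + 0j
    for q, x in enumerate(features):
        state = apply_1q(state, ry(x), q, n)  # assumed RY angle encoding
    for layer in weights:
        for q, (a, b) in enumerate(layer):
            state = apply_1q(state, ry(a), q, n)
            state = apply_1q(state, rz(b), q, n)
        for q in range(n - 1):
            state = apply_cnot(state, q, q + 1, n)  # chain of CNOTs
    return [expval_z(state, q, n) for q in range(n)]


def reweight(features, z_values):
    """Scale each feature by a gate in [0, 1] derived from <Z> (assumed mapping)."""
    return [f * (1 + z) / 2 for f, z in zip(features, z_values)]
```

In a hybrid model the rotation angles in `weights` would be trainable parameters, and the number of qubits would be kept small because the simulated state grows as 2^n.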

💿 Datasets

MNIST

  • Standard handwritten digit recognition dataset
  • Trained on individual digit classes (0-9)
  • Grayscale images (1-channel, 28×28)

PathMNIST

  • Medical imaging dataset from MedMNIST collection
  • Colorectal cancer histology patches
  • RGB images (3-channel, 28×28)
  • Class-conditional training

📊 Results

Training Progression Comparison

The following GIFs demonstrate the training progression of both classical and quantum diffusion models for each MNIST digit. Notice how the models learn to generate increasingly refined digit representations over 30 epochs:

[Animated GIFs, one pair per digit from 0 to 9, shown side by side: Classical Model and Quantum Model]

Quantitative Evaluation

The project evaluates generated images using:

  • Fréchet Inception Distance (FID): measures the distance between the feature distributions of generated and real images (lower is better)
  • Structural Similarity Index (SSIM): measures the perceptual similarity between images (higher is better)
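
To make the SSIM definition concrete, the sketch below computes a single-window (global) SSIM between two flattened images. It is a simplified illustration and not the repository's evaluation script: library implementations such as scikit-image typically average SSIM over local sliding windows, so their values will differ. The function name and the constants `k1=0.01` and `k2=0.03` follow the commonly used defaults and are assumptions here.

```python
def global_ssim(x, y, data_range=1.0, k1=0.01, k2=0.03):
    """Single-window SSIM for two equal-length lists of pixel values."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    var_x = sum((a - mean_x) ** 2 for a in x) / n
    var_y = sum((b - mean_y) ** 2 for b in y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y)) / n

    # Small constants that stabilize the ratio when means or variances are near zero.
    c1, c2 = (k1 * data_range) ** 2, (k2 * data_range) ** 2

    luminance = (2 * mean_x * mean_y + c1) / (mean_x ** 2 + mean_y ** 2 + c1)
    structure = (2 * cov + c2) / (var_x + var_y + c2)
    return luminance * structure
```

Identical images score 1.0, and the score drops as brightness, contrast, or structure diverge.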

Sample results comparing classical and quantum models:

| Model | Dataset | FID↓ | SSIM↑ |
|-----------|-----------|--------|--------|
| Classical | MNIST | 271.05 | 0.1085 |
| Quantum | MNIST | 259.25 | 0.1263 |
| Classical | PathMNIST | 95.72 | 0.4107 |
| Quantum | PathMNIST | 84.40 | 0.0931 |

I also tried the full skip-connection U-Net (v8) for this generation task, but it didn't outperform the lightweight one.

| Model | Dataset | FID↓ | SSIM↑ |
|-----------|---------|--------|--------|
| Classical | MNIST | 275.68 | 0.0267 |
| Quantum | MNIST | 288.40 | 0.0323 |

🔧 Implementation

Training

# Train classical diffusion model on MNIST
python quantum_difussion_mnist_v7.py  # --use_quantum=False

# Train quantum diffusion model on MNIST
python quantum_difussion_mnist_v7.py  # --use_quantum=True

# Train on PathMNIST
python quantum_difussion_pathmnist_v7.py  # --use_quantum=True/False

Evaluation

# Evaluate generated PathMNIST samples
python cal_fid_ssim_medmnist.py

# Evaluate generated MNIST samples
python cal_fid_ssim.py

# Debug image splitting for FID calculation
python debug_img.py

📦 Installation

# Clone the repository
git clone https://github.com/ChiShengChen/Quantum_MNIST_Diffusion.git
cd Quantum_MNIST_Diffusion

# Create a conda environment
conda create -n quantum-diffusion python=3.8
conda activate quantum-diffusion

# Install dependencies
pip install torch torchvision tqdm matplotlib pennylane medmnist scikit-image scipy

๐Ÿ“ Citation

If you use this code for your research, please cite:

@article{chen2025quantum,
  title={Quantum Generative Models for Image Generation: Insights from MNIST and MedMNIST},
  author={Chen, Chi-Sheng and Hou, Wei An and Hu, Siang-Wei and Cai, Zhen-Sheng},
  journal={arXiv preprint arXiv:2504.00034},
  year={2025}
}

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.
