This repository includes the following:
- The `diffusion` package, which provides a clean, modular, and minimalistic implementation of the components and algorithms used in diffusion-based generative modeling (with references to key papers).
- The `playground` folder, which contains a collection of examples that demonstrate diffusion-based generative modeling on different kinds of data (2D points, MNIST, CIFAR10, 3D point clouds, etc.).
The package consists of three core modules:
- The `denoisers` module provides:
  - `KarrasDenoiser`: A thin wrapper around arbitrary neural networks that enables preconditioning of inputs and outputs, as proposed by Karras et al. (2022). The wrapper is agnostic to the model architecture and only expects the shapes of the input and output tensors to match.
  - `KarrasOptimalDenoiser`: The optimal denoiser that corresponds to the analytical minimum of the denoising loss for a given training dataset.
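
To make the two denoisers above more concrete, here is a minimal NumPy sketch of the underlying formulas from Karras et al. (2022). It is not the package's implementation (which wraps PyTorch models): the function names `karras_precondition` and `optimal_denoiser`, the `net(x, c_noise)` calling convention, and the default `SIGMA_DATA = 0.5` are assumptions made for illustration.

```python
import numpy as np

SIGMA_DATA = 0.5  # assumed standard deviation of the data (default used by Karras et al.)


def karras_precondition(net, x, sigma, sigma_data=SIGMA_DATA):
    """Preconditioned denoiser: D(x; sigma) = c_skip * x + c_out * net(c_in * x, c_noise).

    `net` is any callable taking (scaled_input, c_noise) and returning an array
    with the same shape as its input (hypothetical calling convention).
    """
    sigma = np.asarray(sigma, dtype=np.float64).reshape(-1, 1)  # one noise level per row of x
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / np.sqrt(sigma**2 + sigma_data**2)
    c_in = 1.0 / np.sqrt(sigma**2 + sigma_data**2)
    c_noise = np.log(sigma) / 4.0
    return c_skip * x + c_out * net(c_in * x, c_noise)


def optimal_denoiser(x, sigma, data):
    """Analytical minimizer of the denoising loss for a finite training set.

    The result is a softmax-weighted average of the training points, with
    logits -||x - y_i||^2 / (2 * sigma^2).
    """
    sigma = np.asarray(sigma, dtype=np.float64).reshape(-1, 1)
    # Squared distance from every noisy input to every training point: (batch, n_data)
    sq_dist = ((x[:, None, :] - data[None, :, :]) ** 2).sum(axis=-1)
    logits = -sq_dist / (2.0 * sigma**2)
    logits -= logits.max(axis=1, keepdims=True)  # subtract the row maximum for numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ data
```

At small noise levels the optimal denoiser snaps each input to its nearest training point, while at very large noise levels it returns approximately the dataset mean.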
- The `training` module provides functionality for training diffusion models:
  - Loss functions: denoising MSE loss functions, including the original simple denoising loss of Ho et al. (2020) and the preconditioned MSE loss of Karras et al. (2022).
  - Loss weighting schemes: a collection of schemes that assign different weights to the losses computed at different noise levels, including the SNR-based weighting proposed by Hang et al. (2022).
  - Noise level samplers: determine how noise levels are sampled at each training step; the denoising loss is computed for the sampled noise levels, averaged, and optimized with respect to the model parameters.
  - Lightning model: a `LightningModule` class that puts all the pieces together and enables training denoising models with PyTorch Lightning.
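
The way these training pieces fit together can be sketched in a few lines of NumPy. This is an illustrative sketch rather than the package's API: all function names are hypothetical, the log-normal sampler parameters and `SIGMA_DATA` are the defaults reported by Karras et al. (2022), and the SNR-based weight assumes `SNR = sigma_data^2 / sigma^2` with the clamped form `min(SNR, gamma)` that Hang et al. describe for models predicting the clean sample.

```python
import numpy as np

SIGMA_DATA = 0.5  # assumed standard deviation of the data


def sample_lognormal_sigmas(rng, batch_size, p_mean=-1.2, p_std=1.2):
    """Noise level sampler: ln(sigma) ~ N(p_mean, p_std^2)."""
    return np.exp(p_mean + p_std * rng.standard_normal(batch_size))


def karras_weight(sigma, sigma_data=SIGMA_DATA):
    """Loss weight of Karras et al.: (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


def min_snr_weight(sigma, gamma=5.0, sigma_data=SIGMA_DATA):
    """SNR-based loss weight: the signal-to-noise ratio clamped at gamma.

    Assumes SNR = sigma_data^2 / sigma^2 for noisy inputs x = y + sigma * noise.
    """
    snr = sigma_data**2 / sigma**2
    return np.minimum(snr, gamma)


def denoising_loss(denoiser, clean, rng, sampler, weight_fn):
    """Weighted denoising MSE, averaged over one batch of sampled noise levels."""
    sigma = sampler(rng, clean.shape[0])                      # one noise level per sample
    noisy = clean + sigma[:, None] * rng.standard_normal(clean.shape)
    per_sample_mse = ((denoiser(noisy, sigma) - clean) ** 2).mean(axis=1)
    return float((weight_fn(sigma) * per_sample_mse).mean())
```

In an actual training loop the returned scalar would be differentiated with respect to the model parameters; here `denoiser` is any callable taking `(noisy, sigma)`.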
- The `inference` module provides functionality for sampling from trained diffusion models.
The swiss-roll example is a toy problem in which each data instance is a 2D point lying on a swiss-roll 1D manifold. Because the data is so simple, it is a good playground for experimenting with different approaches to training and inference, visualizing diffusion trajectories, and building intuition. Both training and inference run comfortably on a laptop (it takes a minute or so to train the model to convergence).
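
For readers who want to reproduce this kind of data, the following is a minimal sketch of one common way to generate 2D swiss-roll points. The function name `make_swiss_roll_2d` and the parametrization (an angle drawn uniformly from 1.5π to 4.5π) are assumptions; the playground's own data generation may differ.

```python
import numpy as np


def make_swiss_roll_2d(n_points, rng, noise_std=0.0):
    """Sample 2D points on a spiral (a 1D manifold embedded in the plane).

    Returns the points with shape (n_points, 2) and the position `t` of each
    point along the spiral.
    """
    # Angle along the spiral, uniform in [1.5 * pi, 4.5 * pi)
    t = 1.5 * np.pi * (1.0 + 2.0 * rng.random(n_points))
    # The radius grows linearly with the angle, which produces the roll shape
    points = np.stack([t * np.cos(t), t * np.sin(t)], axis=1)
    # Optional isotropic Gaussian jitter around the manifold
    points += noise_std * rng.standard_normal(points.shape)
    return points, t


rng = np.random.default_rng(0)
points, t = make_swiss_roll_2d(1000, rng)
```

With `noise_std=0.0`, the distance of each point from the origin equals its angle `t`, which makes it easy to check visually whether generated samples stay on the manifold.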
Colab notebook: (TODO: add link to the notebook)
The MNIST example is another toy problem, in which a diffusion model is trained on MNIST. The model architectures are scaled-down versions of the U-Nets used on the CIFAR10 and ImageNet benchmarks (all the architecture code is copied verbatim from https://github.com/NVlabs/edm/blob/main/training/networks.py). Training an MNIST denoiser for about 20 epochs takes about 1 hour in Google Colab on a T4 GPU, and running inference takes just a few seconds.
Colab notebook: (TODO: add link to the notebook)
In this example, we train a U-Net diffusion model on the CIFAR10 benchmark.
The model can be trained using the `playground/cifar10/train.py` script (training takes a few days on multiple GPUs), with the architecture and the best hyperparameters given by Karras et al. (2022).
Running inference takes just a few seconds and can be done using different ODE solvers.
Colab notebook: (TODO: add link to the notebook)
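
As an illustration of what sampling with different ODE solvers involves, here is a minimal NumPy sketch of deterministic sampling along the probability-flow ODE `dx/dsigma = (x - D(x; sigma)) / sigma`, using the noise schedule of Karras et al. (2022) with either Euler or Heun steps. The function names and the `denoiser(x, sigma)` calling convention are assumptions, not the interface of the `inference` module.

```python
import numpy as np


def karras_sigmas(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels decreasing from sigma_max to sigma_min, followed by a final 0."""
    ramp = np.linspace(0.0, 1.0, n_steps)
    max_inv_rho = sigma_max ** (1.0 / rho)
    min_inv_rho = sigma_min ** (1.0 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return np.append(sigmas, 0.0)


def sample_ode(denoiser, x, sigmas, solver="heun"):
    """Integrate the probability-flow ODE from sigmas[0] down to sigmas[-1].

    `denoiser(x, sigma)` returns an estimate of the clean sample.
    `solver` is "euler" (first order) or "heun" (second order).
    """
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoiser(x, sigma)) / sigma              # ODE slope at the current point
        x_next = x + (sigma_next - sigma) * d             # Euler step
        if solver == "heun" and sigma_next > 0:
            # Second-order correction: average the slopes at both ends of the step
            d_next = (x_next - denoiser(x_next, sigma_next)) / sigma_next
            x_next = x + (sigma_next - sigma) * 0.5 * (d + d_next)
        x = x_next
    return x
```

Sampling starts from pure noise, for example `x = sigmas[0] * rng.standard_normal(shape)`. The Heun variant calls the denoiser twice per step (except for the last one) in exchange for a lower discretization error.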