The official PyTorch implementation of the paper: Xili Dai, Shengbang Tong, et al. "Closed-Loop Data Transcription to an LDR via Minimaxing Rate Reduction."

CTRL — Closed-Loop Data Transcription to an LDR via Minimaxing Rate Reduction

This repository contains the official PyTorch implementation of the paper: Xili Dai, Shengbang Tong, Mingyang Li, Ziyang Wu, Michael Psenka, Kwan Ho Ryan Chan, Pengyuan Zhai, Yaodong Yu, Xiaojun Yuan, Heung Yeung Shum, Yi Ma. "Closed-Loop Data Transcription to an LDR via Minimaxing Rate Reduction", published in the Special Issue "Information Theory and Machine Learning" of Entropy.

Introduction

This work proposes a new computational framework for learning a structured generative model for real-world datasets. In particular, we propose a framework for closed-loop data transcription (CTRL) between a multi-class, high-dimensional data distribution and a linear discriminative representation (LDR) that consists of multiple independent, multi-dimensional linear subspaces in the feature space. The framework unifies the concepts and benefits of auto-encoding (AE) and generative adversarial networks (GAN), and naturally extends them to learning a representation that is both discriminative and generative for multi-class, high-dimensional, real-world data. Extensive experiments on many benchmark image datasets demonstrate the potential of this closed-loop formulation: under fair comparison, the visual quality of the learned decoder and the classification performance of the encoder are competitive with, and often better than, existing methods based on GAN, VAE, or a combination of both. We hope that this repository serves as a reproducible baseline for future research in this area.

The encoder f has dual roles: it learns an LDR z for the data x by maximizing the rate reduction of z, and it also acts as a “feedback sensor” for any discrepancy between the data x and the decoded \hat{x}. The decoder g also has dual roles: it is a “controller” that corrects the discrepancy between x and \hat{x}, and it also aims to minimize the overall coding rate of the learned LDR.
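The minimax game between f and g can be illustrated numerically. The NumPy sketch below follows the coding-rate and rate-reduction formulas of the MCR²/CTRL line of work as we understand them: R(Z) = 1/2 logdet(I + d/(n eps^2) Z Z^T), ΔR(Z) = R(Z) - R_c(Z | Π), and ΔR(Z, Ẑ) = R(Z ∪ Ẑ) - (R(Z) + R(Ẑ))/2, with the encoder maximizing ΔR(Z) + ΔR(Ẑ) + Σ_j ΔR(Z_j, Ẑ_j) and the decoder minimizing it. The function names, the precision eps, and the column-per-sample layout are our assumptions rather than this repository's API, and the repository's loss code may normalize differently. It uses np.linalg.slogdet in place of torch.logdet so it runs without a GPU.

```python
import numpy as np

# Sketch only: names and eps are assumptions, not this repository's loss code.
# Z has shape (d, n): one d-dimensional feature per column.


def coding_rate(Z, eps=0.5):
    """R(Z) = 1/2 * logdet(I + d / (n * eps^2) * Z Z^T)."""
    d, n = Z.shape
    # I + (positive semidefinite) is positive definite, so the sign from slogdet is +1.
    _, logdet = np.linalg.slogdet(np.eye(d) + (d / (n * eps**2)) * (Z @ Z.T))
    return 0.5 * logdet


def class_coding_rate(Z, labels, eps=0.5):
    """R_c(Z | Pi): per-class coding rates weighted by class size n_j / n."""
    n = Z.shape[1]
    total = 0.0
    for j in np.unique(labels):
        Zj = Z[:, labels == j]
        total += (Zj.shape[1] / n) * coding_rate(Zj, eps)
    return total


def rate_reduction(Z, labels, eps=0.5):
    """Delta R(Z) = R(Z) - R_c(Z | Pi); large when classes span independent subspaces."""
    return coding_rate(Z, eps) - class_coding_rate(Z, labels, eps)


def pairwise_rate_reduction(Z, Z_hat, eps=0.5):
    """Delta R(Z, Z_hat) = R(Z u Z_hat) - (R(Z) + R(Z_hat)) / 2; zero when Z_hat == Z."""
    union = np.hstack([Z, Z_hat])
    return coding_rate(union, eps) - 0.5 * (coding_rate(Z, eps) + coding_rate(Z_hat, eps))


def ctrl_objective(Z, Z_hat, labels, eps=0.5):
    """Utility the encoder f maximizes and the decoder g minimizes (sketch).

    Z = f(x) and Z_hat = f(g(f(x))) are assumed to share the same labels.
    """
    value = rate_reduction(Z, labels, eps) + rate_reduction(Z_hat, labels, eps)
    for j in np.unique(labels):
        mask = labels == j
        value += pairwise_rate_reduction(Z[:, mask], Z_hat[:, mask], eps)
    return value


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Z = rng.standard_normal((8, 20))
    labels = np.repeat([0, 1], 10)
    Z_hat = Z + 0.1 * rng.standard_normal(Z.shape)  # a slightly imperfect "reconstruction"
    print(ctrl_objective(Z, Z_hat, labels))
```

When Ẑ equals Z, every pairwise term vanishes, which matches the intuition that a perfect closed loop leaves nothing for the encoder to "sense".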

Reproducing Results

Installation for Reproducibility

For ease of reproducibility, we suggest you install Miniconda (or Anaconda if you prefer) before executing the following commands.

git clone https://github.com/Delay-Xili/LDR
cd LDR
conda create -y -n clt
conda activate clt
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
pip install git+https://github.com/kwotsin/mimicry.git
mkdir data logs

Note: we strongly encourage you to use a version of torch later than 1.10.0, since it gives a large speedup when computing torch.logdet.

More installation details can be found here.
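To confirm that an environment meets the 1.10.0 recommendation in the note above, the installed version can be compared as a numeric tuple. This is a small helper of our own (version_tuple and is_at_least are hypothetical names, not part of the repository); comparing tuples avoids the string-comparison pitfall where "1.9" sorts after "1.10".

```python
import re

# Hypothetical helper, not part of the repository.


def version_tuple(version):
    """Parse the leading 'major.minor[.patch]' of a string such as '1.10.0+cu102'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def is_at_least(version, minimum=(1, 10, 0)):
    """True if `version` is at least `minimum`, compared numerically (so 1.9 < 1.10)."""
    return version_tuple(version) >= minimum


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed")
    else:
        if is_at_least(torch.__version__):
            print(f"torch {torch.__version__}: OK")
        else:
            print(f"torch {torch.__version__}: consider upgrading for a faster torch.logdet")
```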

Training

To retrain the networks from scratch on your own machine, run the following commands:

CUDA_VISIBLE_DEVICES=0 python main.py --cfg experiments/mnist.yaml DATA.ROOT path/to/the/dataset
CUDA_VISIBLE_DEVICES=0 python main.py --cfg experiments/tmnist.yaml DATA.ROOT path/to/the/dataset
CUDA_VISIBLE_DEVICES=0 python main.py --cfg experiments/cifar10.yaml DATA.ROOT path/to/the/dataset
CUDA_VISIBLE_DEVICES=0,1 python main.py --cfg experiments/stl10.yaml DATA.ROOT path/to/the/dataset
CUDA_VISIBLE_DEVICES=0,1,2 python main.py --cfg experiments/CelebA.yaml DATA.ROOT path/to/the/dataset
CUDA_VISIBLE_DEVICES=0,1,2 python main.py --cfg experiments/LSUN.yaml DATA.ROOT path/to/the/dataset
CUDA_VISIBLE_DEVICES=0,1,2 python main.py --cfg experiments/ImageNet.yaml DATA.ROOT path/to/the/dataset

Some hyper-parameters can be changed directly in the corresponding experiments/*.yaml file. We ran the experiments on NVIDIA RTX 3090 GPUs with 24 GB of memory each; set the CUDA_VISIBLE_DEVICES environment variable to match the GPUs available on your machine.
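The trailing DATA.ROOT argument in the training commands is a dotted-key override applied on top of the YAML file passed with --cfg. This style is commonly handled by yacs (CfgNode.merge_from_list), but we have not verified which config library this repository uses, so the dictionary-based sketch below only illustrates the idea. The apply_overrides helper and every config key other than DATA.ROOT are hypothetical.

```python
import ast
import copy

# Illustration only: apply_overrides and keys other than DATA.ROOT are hypothetical.


def _parse_value(text):
    """Interpret a value as a Python literal when possible (e.g. '0.01', '[1, 2]')."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return text  # fall back to a plain string, e.g. a dataset path


def apply_overrides(config, opts):
    """Return a copy of a nested dict with KEY VALUE pairs like ['DATA.ROOT', '/data'] applied."""
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    merged = copy.deepcopy(config)
    for key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")
        node = merged
        for name in parents:
            node = node[name]  # KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f"unknown config key: {key}")  # reject typos instead of adding keys
        node[leaf] = _parse_value(raw)
    return merged


if __name__ == "__main__":
    base = {"DATA": {"ROOT": "./data"}, "TRAIN": {"BATCH_SIZE": 128}}
    print(apply_overrides(base, ["DATA.ROOT", "/datasets/cifar10", "TRAIN.BATCH_SIZE", "64"]))
```

Rejecting unknown keys mirrors the strict behavior of yacs-style configs, so a mistyped override fails loudly instead of being silently ignored.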

Pre-trained Models

You can download our trained models from the following links:

Dataset    Model        Results
MNIST      mini dcgan   link
TMNIST     mini dcgan   link
CIFAR-10   mini dcgan   link
CIFAR-10   sngan32      TBD
STL-10     sngan48      TBD
CelebA     sngan128     link
LSUN       sngan128     link
ImageNet   sngan128     link

Each link contains the corresponding results, which consist of three items: checkpoints, images, and data.
checkpoints: all checkpoint files of the generator (decoder g) and discriminator (encoder f) saved during training.
images: all input and reconstructed images saved during training.
data: the TensorBoard file, which records the losses and learning rates of the discriminator and generator during training.
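To choose which checkpoints to pass to evaluation.py or test_acc.py, it can help to index the saved files by epoch. The sketch below assumes, hypothetically, that each checkpoint filename starts with a model tag (netG or netD) and ends with the epoch number before a .pth or .pt extension, for example netG_epoch_4500.pth; check the actual names in the downloaded checkpoints folder and adjust the pattern and tags accordingly.

```python
import re
from pathlib import Path

# Hypothetical naming convention: "<tag>...<epoch>.pth" or ".pt", e.g. "netG_epoch_4500.pth".
_EPOCH_RE = re.compile(r"(\d+)\.pth?$")


def checkpoints_by_epoch(ckpt_dir, tag):
    """Map epoch -> path for files whose name starts with `tag` and ends in <epoch>.pth/.pt."""
    found = {}
    for path in Path(ckpt_dir).iterdir():
        match = _EPOCH_RE.search(path.name)
        if path.is_file() and path.name.startswith(tag) and match:
            found[int(match.group(1))] = path
    return dict(sorted(found.items()))


def latest_pair(ckpt_dir, g_tag="netG", d_tag="netD"):
    """Return (epoch, generator path, discriminator path) for the newest epoch saved for both."""
    g_ckpts = checkpoints_by_epoch(ckpt_dir, g_tag)
    d_ckpts = checkpoints_by_epoch(ckpt_dir, d_tag)
    common = set(g_ckpts) & set(d_ckpts)
    if not common:
        raise FileNotFoundError(f"no matching {g_tag}/{d_tag} checkpoints in {ckpt_dir}")
    epoch = max(common)
    return epoch, g_ckpts[epoch], d_ckpts[epoch]
```

The two paths returned by latest_pair correspond to the EVAL.NETG_CKPT and EVAL.NETD_CKPT values used by evaluation.py. Requiring an epoch present for both networks avoids pairing a generator and discriminator from different points in training.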

Evaluating the FID and IS scores

To evaluate the FID and IS scores of the checkpoints saved under checkpoints/, run:

CUDA_VISIBLE_DEVICES=0 python evaluation.py --cfg experiments/mnist.yaml EVAL.NETD_CKPT path/to/netD/ckpt EVAL.NETG_CKPT path/to/netG/ckpt
CUDA_VISIBLE_DEVICES=0 python evaluation.py --cfg experiments/cifar10.yaml EVAL.NETD_CKPT path/to/netD/ckpt EVAL.NETG_CKPT path/to/netG/ckpt
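evaluation.py computes these scores itself; mimicry, installed above, ships FID and IS metrics, although we have not checked exactly which code path the script uses. As a reference for what the numbers mean, the NumPy sketch below implements the standard formulas on precomputed Inception features and class probabilities. The function names are ours, the Inception network is omitted, and IS is computed over a single split instead of the usual average over several splits.

```python
import numpy as np

# Reference sketch of the standard FID/IS formulas; not the repository's evaluation code.


def _sqrtm_psd(S):
    """Symmetric square root of a positive semidefinite matrix via eigendecomposition."""
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}).

    Tr((S1 S2)^{1/2}) is taken from S1^{1/2} S2 S1^{1/2}, which has the same eigenvalues
    as S1 S2 but is symmetric, so eigvalsh applies.
    """
    diff = mu1 - mu2
    root1 = _sqrtm_psd(sigma1)
    middle = root1 @ sigma2 @ root1
    middle = 0.5 * (middle + middle.T)  # remove round-off asymmetry
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def fid_from_features(feats_real, feats_fake):
    """FID between two (n_samples, n_features) arrays of precomputed features."""
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    s1 = np.cov(feats_real, rowvar=False)
    s2 = np.cov(feats_fake, rowvar=False)
    return frechet_distance(mu1, s1, mu2, s2)


def inception_score(probs, eps=1e-12):
    """IS = exp(E_x KL(p(y|x) || p(y))) from an (n_samples, n_classes) softmax matrix."""
    p = np.clip(probs, eps, 1.0)  # avoid log(0) for confident predictions
    p_y = p.mean(axis=0, keepdims=True)
    kl = (p * (np.log(p) - np.log(p_y))).sum(axis=1)
    return float(np.exp(kl.mean()))
```

Lower FID and higher IS are better: FID is 0 when the two feature distributions have identical mean and covariance, and IS ranges from 1 (uninformative predictions) up to the number of classes (confident, evenly spread predictions).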

Testing the classification accuracy

To test the classification accuracy of the learned discriminator (the encoder f), run:

CUDA_VISIBLE_DEVICES=0 python test_acc.py --cfg path/to/mnist/result/config.yaml --ckpt_epochs 4500 EVAL.DATA_SAMPLE 1000
CUDA_VISIBLE_DEVICES=0 python test_acc.py --cfg path/to/cifar/result/config.yaml --ckpt_epochs 45000 EVAL.DATA_SAMPLE 1000

MNIST classification accuracy: 97.69%, CIFAR-10 classification accuracy: 73.05%.
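The text above does not say how test_acc.py turns encoder features into class predictions. In the rate-reduction line of work, one common choice for LDR features is a nearest-subspace classifier: fit a low-dimensional principal subspace to each class's training features and assign a test feature to the class whose subspace leaves the smallest residual. The sketch below illustrates that idea only; treating it as what test_acc.py does, the function names, and the subspace dimension k are all our assumptions.

```python
import numpy as np

# Assumed nearest-subspace classifier; not necessarily what test_acc.py implements.


def fit_class_subspaces(feats, labels, k):
    """For each class, return an orthonormal (d, k) basis of its top-k principal directions.

    feats has shape (n_samples, d). Features are not centered, because an LDR models
    each class as a linear subspace through the origin.
    """
    bases = {}
    for c in np.unique(labels):
        class_feats = feats[labels == c]
        # Rows of vt are right singular vectors, i.e. directions in feature space.
        _, _, vt = np.linalg.svd(class_feats, full_matrices=False)
        bases[c] = vt[:k].T
    return bases


def predict_nearest_subspace(feats, bases):
    """Assign each feature to the class whose subspace gives the smallest projection residual."""
    classes = sorted(bases)
    residuals = np.stack(
        [np.linalg.norm(feats - (feats @ bases[c]) @ bases[c].T, axis=1) for c in classes],
        axis=1,
    )
    return np.asarray(classes)[residuals.argmin(axis=1)]
```

Because the classifier only needs per-class subspaces, it directly tests whether the encoder has mapped each class onto its own low-dimensional linear subspace, which is the property the LDR objective is designed to produce.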

Citation

If you find CTRL useful in your research, please consider citing:

@article{dai2021closed,
  title={CTRL: Closed-Loop Transcription to an LDR via Minimaxing Rate Reduction},
  author={Dai, Xili and Tong, Shengbang and Li, Mingyang and Wu, Ziyang and Chan, Kwan Ho Ryan and Zhai, Pengyuan and Yu, Yaodong and Psenka, Michael and Yuan, Xiaojun and Shum, Heung Yeung and others},
  journal = {Entropy},
  volume = {24},
  year = {2022},
  number = {4},
  article-number = {456},
  url = {https://www.mdpi.com/1099-4300/24/4/456},
  issn = {1099-4300},
  doi = {10.3390/e24040456}
}

License

See LICENSE for details.
