Deep Generalized Schrödinger Bridge, NeurIPS 2022 Oral

Deep Generalized Schrödinger Bridge
[NeurIPS 2022 Oral]

Official PyTorch implementation of the paper "Deep Generalized Schrödinger Bridge (DeepGSB)", which introduces a new class of diffusion models as a scalable numerical solver for Mean-Field Games (MFGs) with hard distributional constraints, e.g., population modeling and opinion depolarization.

[Figures: population modeling (crowd navigation); opinion depolarization]

This repo is co-maintained by Guan-Horng Liu, Tianrong Chen, and Oswin So. Contact us if you have any questions! If you find this library useful, please cite ⬇️

@inproceedings{liu2022deep,
  title={Deep Generalized Schr{\"o}dinger Bridge},
  author={Liu, Guan-Horng and Chen, Tianrong and So, Oswin and Theodorou, Evangelos A},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}

Install

Install the dependencies with Anaconda and activate the deepgsb environment:

conda env create --file requirements.yaml
conda activate deepgsb

Run & Evaluate

The repo contains two classes of Mean-Field Games:

  • population modeling: GMM, Vneck, Stunnel
  • opinion depolarization: opinion, opinion-1k (dim=1000).

The commands to generate results similar to those shown in our paper can be found in run.sh. Results, checkpoints, and TensorBoard log files are saved to the folders results/, checkpoint/, and runs/, respectively.

bash run.sh <problem> # <problem> can be {GMM, Vneck, Stunnel, opinion, opinion-1k}
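
To run several problems back to back, a small Python driver along the following lines may help. This is only a sketch: the helpers `build_command` and `run_all` are hypothetical and not part of the repo, and the script assumes it is launched from the repo root, where run.sh lives. It defaults to a dry run that only prints the commands.

```python
import subprocess

# Problem names accepted by run.sh, as listed above.
PROBLEMS = ["GMM", "Vneck", "Stunnel", "opinion", "opinion-1k"]


def build_command(problem):
    """Return the argv list for `bash run.sh <problem>` (hypothetical helper)."""
    if problem not in PROBLEMS:
        raise ValueError(f"unknown problem {problem!r}; expected one of {PROBLEMS}")
    return ["bash", "run.sh", problem]


def run_all(problems=PROBLEMS, dry_run=True):
    """Run each problem in sequence; with dry_run=True only collect the commands."""
    commands = [build_command(p) for p in problems]
    if not dry_run:
        for cmd in commands:
            # check=True stops the loop at the first failing run.
            subprocess.run(cmd, check=True)
    return commands


# Dry run: print the commands that would be executed, without training anything.
for cmd in run_all(dry_run=True):
    print(" ".join(cmd))
```

Calling `run_all(dry_run=False)` would launch the actual runs one after another.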

You can visualize the trained DeepGSB policies by making a GIF animation:

python make_animation.py --load <path to checkpoint npz> --name <gif name>
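
Since `--load` expects an `.npz` file, it can be useful to look inside a checkpoint before animating it. The sketch below relies only on `numpy.load`. The array names that DeepGSB stores are not documented here, so the code simply lists whatever it finds. The helper `summarize_npz` is hypothetical, and the path in the usage comment is a placeholder.

```python
import numpy as np


def summarize_npz(path):
    """Return {array_name: (shape, dtype)} for every array in an .npz file."""
    # allow_pickle keeps its default (False), so entries holding pickled
    # Python objects raise an error instead of being loaded.
    with np.load(path) as data:
        return {name: (data[name].shape, str(data[name].dtype)) for name in data.files}


# Usage (placeholder path):
#   for name, (shape, dtype) in summarize_npz("<path to checkpoint npz>").items():
#       print(f"{name}: shape={shape}, dtype={dtype}")
```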

Structure

We briefly document the file structure in case you wish to integrate DeepGSB into your workflow.

deepgsb/
├── deepgsb.py       # the DeepGSB MFG solver
├── sb_policy.py     # the parametrized Schrödinger Bridge policy
├── loss_lib.py      # all loss functions (IPF/KL, TD, FK/grad)
├── eval_metrics.py  # all logging metrics (Wasserstein, etc)
├── replay_buffer.py
└── util.py
mfg/
├── mfg.py           # the Mean-Field Game environment
├── constraint.py    # the distributional boundary constraint (p0, pT)
├── state_cost.py    # all mean-field interaction state costs (F)
├── sde.py           # the associated stochastic processes (f, sigma)
├── opinion_lib.py   # all utilities for opinion depolarization MFG
├── plotting.py
└── util.py
models/              # the deep networks for parametrizing SB policy
configs/             # the configurations for each MFG
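
For orientation, `mfg/sde.py` is described above as holding the stochastic processes (f, sigma). A generic Euler-Maruyama simulation of an SDE of the form dX = f(X, t) dt + sigma dW can be sketched as below. This is not the repo's API: the function `euler_maruyama`, its arguments, and the constant scalar sigma are assumptions made purely for illustration, and the sketch uses NumPy rather than PyTorch to stay dependency-light.

```python
import numpy as np


def euler_maruyama(f, sigma, x0, T=1.0, n_steps=100, seed=0):
    """Simulate dX = f(X, t) dt + sigma dW with the Euler-Maruyama scheme.

    f:     callable (x, t) -> drift with the same shape as x
    sigma: constant scalar diffusion coefficient (an assumption of this sketch)
    x0:    array of shape (batch, dim) holding the initial samples
    Returns an array of shape (n_steps + 1, batch, dim).
    """
    rng = np.random.default_rng(seed)
    dt = T / n_steps
    x = np.asarray(x0, dtype=float)
    traj = [x]
    for k in range(n_steps):
        t = k * dt
        noise = rng.standard_normal(x.shape)
        # Drift step plus Brownian increment with standard deviation sqrt(dt).
        x = x + f(x, t) * dt + sigma * np.sqrt(dt) * noise
        traj.append(x)
    return np.stack(traj)


# Example: a linear drift pulling samples toward the origin.
xs = euler_maruyama(lambda x, t: -x, sigma=0.5, x0=np.ones((8, 2)))
print(xs.shape)  # (101, 8, 2)
```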
