[WIP] Make FastDEQs fast again #45


Merged: 76 commits, Jun 6, 2022

Commits (76)
d593c41
move towards efl
avik-pal Apr 9, 2022
a032635
Some model code
avik-pal Apr 9, 2022
8ec596d
Update model
avik-pal Apr 10, 2022
c6fd5e5
Update examples
avik-pal Apr 10, 2022
9b6c74c
add cifar10 training script
avik-pal Apr 10, 2022
cbc2adf
Add dates dep
avik-pal Apr 10, 2022
7f7312c
Fix training
avik-pal Apr 10, 2022
2c296a1
RIP ordering
avik-pal Apr 10, 2022
98a4003
Some printing
avik-pal Apr 10, 2022
b3a6a08
Some printing
avik-pal Apr 10, 2022
47f9ea3
Stop printing
avik-pal Apr 10, 2022
6ec4c7b
missing acc
avik-pal Apr 10, 2022
0844f88
Add distributed training
avik-pal Apr 11, 2022
7f3273a
Make float32
avik-pal Apr 11, 2022
37f5c87
HPC safe precompile
avik-pal Apr 11, 2022
30da3c5
Change solver from VCABM4 since it changes inputs to FP64
avik-pal Apr 11, 2022
b3f301e
ADAMW syncing not working for MPI
avik-pal Apr 12, 2022
ed4c40f
Reduce using MPI
avik-pal Apr 12, 2022
e89b401
Fixed depth for pretraining
avik-pal Apr 13, 2022
07634bc
Fixed depth testing
avik-pal Apr 13, 2022
3968f1c
Missing state
avik-pal Apr 13, 2022
a951f98
Imagenet model configuration
avik-pal Apr 13, 2022
005f35f
Imagenet model configuration
avik-pal Apr 13, 2022
3745309
Update config
avik-pal Apr 13, 2022
f67a776
Int divide
avik-pal Apr 13, 2022
7a401be
exptconfig
avik-pal Apr 13, 2022
e51a483
Oops
avik-pal Apr 13, 2022
e3f411f
Remove nothing
avik-pal Apr 13, 2022
43fd1ac
Warmup the pretraining mode
avik-pal Apr 13, 2022
1b57511
updates
avik-pal Apr 14, 2022
b57102d
Update training
avik-pal Apr 16, 2022
8e84ae2
Update defaults
avik-pal Apr 17, 2022
ab2afcd
Run for different configurations
avik-pal Apr 18, 2022
a016807
Fix model construction
avik-pal Apr 18, 2022
c3cb712
Better scheduling
avik-pal Apr 19, 2022
27703a5
Dont track statistics
avik-pal Apr 20, 2022
ce8cddf
Temp fix for keys error
avik-pal Apr 20, 2022
d12347f
Move to component arrays
avik-pal Apr 20, 2022
c019b20
Use named tuple
avik-pal Apr 21, 2022
1ba6736
Add termination.jl
avik-pal Apr 21, 2022
94f8fb8
extra in train
avik-pal Apr 21, 2022
5e45c4f
Update config
avik-pal Apr 21, 2022
7934860
Modify architecture to make it GPU friendly
avik-pal Apr 23, 2022
4953087
Update
avik-pal Apr 23, 2022
0aef672
Use pretraining
avik-pal Apr 24, 2022
3a12e0a
Fix backpass for variational autoencoder
avik-pal Apr 24, 2022
9c2626d
Fix dep
avik-pal Apr 24, 2022
0b6a927
Partial update to Lux
avik-pal May 3, 2022
336a19a
Update FastDEQExperiments
avik-pal May 3, 2022
cad992e
Compat entries
avik-pal May 3, 2022
db98546
Dep fixes
avik-pal May 6, 2022
19f2911
Dep fixes
avik-pal May 6, 2022
0eca98f
Update src
avik-pal May 12, 2022
25cea0f
Docs
avik-pal May 12, 2022
c0ed127
Docs
avik-pal May 12, 2022
7b71612
Docs
avik-pal May 12, 2022
ea24bfc
make it val
avik-pal May 13, 2022
afc8b79
Type Inference fixes
avik-pal May 13, 2022
34517ea
Cant typeassert for state
avik-pal May 14, 2022
fdc460b
Remove typeasserts
avik-pal May 14, 2022
be6f72d
Update script
avik-pal May 16, 2022
ff9462c
Pretraining
avik-pal May 17, 2022
57a7abb
Config
avik-pal May 17, 2022
513172a
Fix data augmentation
avik-pal May 18, 2022
d709846
Wandb logging
avik-pal May 23, 2022
50b879d
modify normalization layers
avik-pal May 23, 2022
d709d7f
modify normalization layers
avik-pal May 24, 2022
6897ec3
Modify model architecture
avik-pal May 26, 2022
4555e10
Modify model architecture
avik-pal May 26, 2022
9fa1c32
Use weight norm
avik-pal May 26, 2022
0ef3ffa
Update model
avik-pal May 26, 2022
3df322d
Relax types
avik-pal May 26, 2022
1cd9a6e
Minor fix
avik-pal May 26, 2022
edbb7a3
Make post_deq type-stable
avik-pal May 26, 2022
a9af2e3
No MPI for cifar10
avik-pal May 26, 2022
7e42c90
Decay the weight of skip
avik-pal May 27, 2022
Files changed
7 changes: 5 additions & 2 deletions .github/workflows/CI.yml
@@ -29,8 +29,11 @@ jobs:
${{ runner.os }}-test-
${{ runner.os }}-
- name: Install dependencies
run: julia --project=. -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/FluxExperimental.jl"); Pkg.add(url="https://github.com/SciML/DiffEqSensitivity.jl", rev="ap/fastdeq"); Pkg.instantiate()'
# FIXME: Remove once Lux.jl is registered
run: julia --project=. -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/Lux.jl"); Pkg.instantiate()'
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v2
with:
coverage: false
files: lcov.info
7 changes: 4 additions & 3 deletions .github/workflows/Documentation.yml
@@ -4,7 +4,7 @@ on:
push:
branches:
- main
tags: '*'
tags: "*"
pull_request:

jobs:
@@ -14,9 +14,10 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: '1'
version: "1"
- name: Install dependencies
run: julia --project=docs -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/FluxExperimental.jl"); Pkg.add(url="https://github.com/SciML/DiffEqSensitivity.jl", rev="ap/fastdeq"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
# FIXME: Remove once Lux.jl is registered
run: julia --project=docs -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/Lux.jl"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
- name: Build and deploy
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token
8 changes: 6 additions & 2 deletions .gitignore
@@ -2,5 +2,9 @@
wandb/
.vscode
data/
/Manifest.toml
build
Manifest.toml
build
statprof
profs
logs
benchmarking
23 changes: 15 additions & 8 deletions Project.toml
@@ -5,46 +5,53 @@ version = "0.1.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FluxExperimental = "c0d22e4d-7f3e-44a4-9c97-37045f84daf2"
FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CUDA = "3"
ChainRulesCore = "1"
DiffEqBase = "6"
DiffEqCallbacks = "2.20.1"
DiffEqSensitivity = "6.64"
Flux = "0.12"
FluxMPI = "0.1.1"
Functors = "0.2"
LinearSolve = "1"
Lux = "0.4"
MLUtils = "0.2"
OrdinaryDiffEq = "6"
SciMLBase = "1.19"
Setfield = "0.8, 1"
SteadyStateDiffEq = "1.6"
UnPack = "1"
Zygote = "0.6.34"
julia = "1.7"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FluxExperimental = "c0d22e4d-7f3e-44a4-9c97-37045f84daf2"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CUDA", "Flux", "FluxExperimental", "LinearAlgebra", "Random", "Test"]
test = ["CUDA", "LinearAlgebra", "Lux", "Random", "Test"]
5 changes: 4 additions & 1 deletion README.md
@@ -1,6 +1,9 @@
# FastDEQ

![Dynamics Overview](assets/dynamics_overview.gif)

[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://fastdeq.sciml.ai/dev/)
[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://fastdeq.sciml.ai/stable/)
[![codecov](https://codecov.io/gh/SciML/FastDEQ.jl/branch/main/graph/badge.svg?token=plksEh6pUG)](https://codecov.io/gh/SciML/FastDEQ.jl)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)

Deep Equilibrium Networks using [Lux.jl](https://lux.csail.mit.edu/dev) and [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/)
12 changes: 6 additions & 6 deletions docs/make.jl
@@ -14,12 +14,12 @@ makedocs(
canonical="https://fastdeq.sciml.ai/stable/"),
pages = [
"FastDEQ: Fast Deep Equilibrium Networks" => "index.md",
"API" => [
"Dynamical Systems" => "api/solvers.md",
"Non Linear Solvers" => "api/nlsolve.md",
"General Purpose Layers" => "api/layers.md",
"DEQ Layers" => "api/deqs.md",
"Miscellaneous" => "api/misc.md",
"Manual" => [
"Dynamical Systems" => "manual/solvers.md",
"Non Linear Solvers" => "manual/nlsolve.md",
"General Purpose Layers" => "manual/layers.md",
"DEQ Layers" => "manual/deqs.md",
"Miscellaneous" => "manual/misc.md",
],
"References" => "references.md",
]
10 changes: 0 additions & 10 deletions docs/src/api/misc.md

This file was deleted.

15 changes: 0 additions & 15 deletions docs/src/api/solvers.md

This file was deleted.

17 changes: 3 additions & 14 deletions docs/src/index.md
@@ -1,14 +1,13 @@
# FastDEQ: (Fast) Deep Equlibrium Networks

FastDEQ.jl is a framework built on top of [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/) and [Flux.jl](https://fluxml.ai) enabling the efficient training and inference for Deep Equilibrium Networks (Infinitely Deep Neural Networks).
FastDEQ.jl is a framework built on top of [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/) and [Lux.jl](https://lux.csail.mit.edu/dev/) enabling the efficient training and inference for Deep Equilibrium Networks (Infinitely Deep Neural Networks).

## Installation

Currently the package is not registered and requires manually installing a few dependencies. We are working towards upstream fixes which will make installation easier.

```julia
] add https://github.com/SciML/DiffEqSensitivity.jl.git#ap/fastdeq
] add https://github.com/avik-pal/FluxExperimental.jl.git#main
] add https://github.com/avik-pal/Lux.jl.git#main
] add https://github.com/SciML/FastDEQ.jl
```

@@ -27,14 +26,4 @@ If you are using this project for research or other academic purposes consider c
}
```

For specific algorithms, check the respective documentation and cite the corresponding papers.

## FAQs

#### How do I reproduce the experiments in the paper -- *Mixing Implicit and Explicit Deep Learning with Skip DEQs and Infinite Time Neural ODEs (Continuous DEQs)*?

Check out the `ap/paper` branch for the code corresponding to that paper.

#### Are there some tutorials?

We are working on adding some in the near future. In the meantime, please checkout the `experiments` directory in the `ap/paper` branch. You can also check `test/runtests.jl` for some simple examples.
For specific algorithms, check the respective documentation and cite the corresponding papers.
File renamed without changes.
1 change: 0 additions & 1 deletion docs/src/api/layers.md → docs/src/manual/layers.md
@@ -2,5 +2,4 @@

```@docs
DEQChain
MultiParallelNet
```
7 changes: 7 additions & 0 deletions docs/src/manual/misc.md
@@ -0,0 +1,7 @@
# Miscellaneous

```@docs
DeepEquilibriumAdjoint
DeepEquilibriumSolution
NormalInitializer
```
2 changes: 1 addition & 1 deletion docs/src/api/nlsolve.md → docs/src/manual/nlsolve.md
@@ -9,4 +9,4 @@ We provide the following NonLinear Solvers for DEQs. These are compatible with G
```@docs
BroydenSolver
LimitedMemoryBroydenSolver
```
```
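
A minimal usage sketch (not part of this diff) of how one of these nonlinear solvers might be paired with the discrete DEQ solver documented in the next file; the keyword names `abstol_termination` and `reltol_termination` are assumptions based on the termination-condition notes below, so check the `DiscreteDEQSolver` docstring for the actual signature:

```julia
using FastDEQ

# Hypothetical sketch: drive the fixed-point iteration of a discrete DEQ with
# the limited-memory Broyden solver. Keyword names are assumed, not verified.
solver = DiscreteDEQSolver(LimitedMemoryBroydenSolver();
                           abstol_termination=1f-8,
                           reltol_termination=1f-8)
```
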
38 changes: 38 additions & 0 deletions docs/src/manual/solvers.md
@@ -0,0 +1,38 @@
# Dynamical System Variants

[baideep2019](@cite) introduced Discrete Deep Equilibrium Models, which drive a Discrete Dynamical System to its steady state. [pal2022mixing](@cite) extends this framework to Continuous Dynamical Systems, which converge to the steady state in a more stable fashion. For a detailed discussion, refer to [pal2022mixing](@cite).

## Continuous DEQs

```@docs
ContinuousDEQSolver
```

## Discrete DEQs

```@docs
DiscreteDEQSolver
```

## Termination Conditions

#### Termination on Absolute Tolerance

* `:abs`: Terminates if ``all \left( | \frac{\partial u}{\partial t} | \leq abstol \right)``
* `:abs_norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq abstol``
* `:abs_deq_default`: Essentially `abs_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges)
* `:abs_deq_best`: Same as `:abs_deq_default` but uses the best solution found so far, i.e. deviates only if the solution has not converged

#### Termination on Relative Tolerance

* `:rel`: Terminates if ``all \left(| \frac{\partial u}{\partial t} | \leq reltol \times | u | \right)``
* `:rel_norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|``
* `:rel_deq_default`: Essentially `rel_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges)
* `:rel_deq_best`: Same as `:rel_deq_default` but uses the best solution found so far, i.e. deviates only if the solution has not converged

#### Termination using both Absolute and Relative Tolerances

* `:norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` &
``\| \frac{\partial u}{\partial t} \| \leq abstol``
* `fallback`: Checks whether all values of the derivative are close to zero with respect to both the relative and absolute tolerances. This is usable for small problems but doesn't scale well for neural networks, and should be avoided unless absolutely necessary.
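
As a hedged illustration (not part of this diff), one of these termination modes would typically be selected when constructing the continuous solver; the `mode`, `abstol`, and `reltol` keyword names here are assumptions, so consult the `ContinuousDEQSolver` docstring for the authoritative signature:

```julia
using FastDEQ, OrdinaryDiffEq

# Hypothetical sketch: integrate the DEQ dynamics with an adaptive
# variable-coefficient Adams-Bashforth method and stop using the
# relative-tolerance "best solution so far" criterion described above.
# Keyword names are assumed, not verified.
solver = ContinuousDEQSolver(VCAB3();
                             mode=:rel_deq_best,
                             abstol=1f-8,
                             reltol=1f-8)
```
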
57 changes: 57 additions & 0 deletions examples/Project.toml
@@ -0,0 +1,57 @@
name = "FastDEQExperiments"
uuid = "5aa64bb0-ce80-4310-96b1-36313c344f92"
authors = ["Avik Pal <[email protected]>"]
version = "0.1.0"

[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
Augmentor = "02898b10-1f73-11ea-317c-6393d7073e15"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastDEQ = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b"
Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDataPattern = "9920b226-0b2a-5f5f-9153-9aa70a013f8b"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
Wandb = "ad70616a-06c9-5745-b1f1-6a5f42545108"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CUDA = "3"
DataLoaders = "0.1"
Flux = "0.13"
FluxMPI = "0.5.3"
Format = "1.3"
Lux = "0.4"
MLDatasets = "0.5"
MLUtils = "0.2"
MPI = "0.19"
NNlib = "0.8"
Optimisers = "0.2"
OrdinaryDiffEq = "6"
ParameterSchedulers = "0.3"
Setfield = "0.8, 1"
Wandb = "0.4.3"
Zygote = "0.6"
julia = "1.6"