Skip to content

Commit c958378

Browse files
willtebbuttgdalle
andauthored
Add AutoMooncake type (#89)
* Add AutoMooncake type * Export AutoMooncake * Add basic tests for AutoMooncake * Docs * Bump minor version * Deprecate AutoTapir * Deprecate AutoTapir * Re-include AutoTapir in docs * Tweak deprecation docstring * Update docstring for AutoMooncake * Line at EoF * Apply suggestions from code review * Preserve old AutoTapir docstring --------- Co-authored-by: Guillaume Dalle <[email protected]>
1 parent f0a724f commit c958378

File tree

7 files changed

+58
-12
lines changed

7 files changed

+58
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ADTypes"
22
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = ["Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors"]
4-
version = "1.8.1"
4+
version = "1.9.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ AutoGTPSA
3939
### Reverse mode
4040

4141
```@docs
42+
AutoMooncake
4243
AutoReverseDiff
43-
AutoTapir
4444
AutoTracker
4545
AutoZygote
4646
```
@@ -106,3 +106,9 @@ ADTypes.SymbolicMode
106106
```@docs
107107
ADTypes.Auto
108108
```
109+
110+
## Deprecated
111+
112+
```@docs
113+
AutoTapir
114+
```

src/ADTypes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ export AutoChainRules,
4242
AutoForwardDiff,
4343
AutoGTPSA,
4444
AutoModelingToolkit,
45+
AutoMooncake,
4546
AutoPolyesterForwardDiff,
4647
AutoReverseDiff,
4748
AutoSymbolics,

src/dense.jl

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,27 @@ function Base.show(io::IO, backend::AutoGTPSA{D}) where {D}
227227
print(io, ")")
228228
end
229229

230+
"""
231+
AutoMooncake
232+
233+
Struct used to select the [Mooncake.jl](https://github.com/compintell/Mooncake.jl) backend for automatic differentiation.
234+
235+
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
236+
237+
# Constructors
238+
239+
AutoMooncake(; config)
240+
241+
# Fields
242+
243+
- `config`: either `nothing` or an instance of `Mooncake.Config` -- see the docstring of `Mooncake.Config` for more information. `AutoMooncake(; config=nothing)` is equivalent to `AutoMooncake(; config=Mooncake.Config())`, i.e. the default configuration.
244+
"""
245+
Base.@kwdef struct AutoMooncake{Tconfig} <: AbstractADType
246+
config::Tconfig
247+
end
248+
249+
mode(::AutoMooncake) = ReverseMode()
250+
230251
"""
231252
AutoPolyesterForwardDiff{chunksize,T}
232253
@@ -323,7 +344,11 @@ mode(::AutoSymbolics) = SymbolicMode()
323344
"""
324345
AutoTapir
325346
326-
Struct used to select the [Tapir.jl](https://github.com/withbayes/Tapir.jl) backend for automatic differentiation.
347+
!!! danger
348+
349+
`AutoTapir` is deprecated following a package renaming, please use [`AutoMooncake`](@ref) instead.
350+
351+
Struct used to select the [Tapir.jl](https://github.com/compintell/Tapir.jl) backend for automatic differentiation.
327352
328353
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
329354
@@ -333,16 +358,10 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
333358
334359
# Fields
335360
336-
- `safe_mode::Bool`: whether to run additional checks to catch errors early. While this is
337-
on by default to ensure that users are aware of this option, you should generally turn
338-
it off for actual use, as it has substantial performance implications.
339-
If you encounter a problem with using Tapir (it fails to differentiate a function, or
340-
something truly nasty like a segfault occurs), then you should try switching `safe_mode`
341-
on and look at what happens. Often errors are caught earlier and the error messages are
342-
more useful.
361+
- `safe_mode::Bool`: whether to run additional checks to catch errors early.
343362
"""
344-
Base.@kwdef struct AutoTapir <: AbstractADType
345-
safe_mode::Bool = true
363+
struct AutoTapir <: AbstractADType
364+
safe_mode::Bool
346365
end
347366

348367
mode(::AutoTapir) = ReverseMode()

src/legacy.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,10 @@ function AutoModelingToolkit(; obj_sparse::Bool = false, cons_sparse::Bool = fal
3636
:AutoModelingToolkit; force = false)
3737
return mtk_to_symbolics(obj_sparse, cons_sparse)
3838
end
39+
40+
function AutoTapir(; safe_mode=true)
41+
Base.depwarn(
42+
"`AutoTapir` is deprecated in favour of `AutoMooncake`.", :AutoTapir; force=false
43+
)
44+
return AutoTapir(safe_mode)
45+
end

test/dense.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,14 @@ end
119119
@test ad.descriptor == Val(:descriptor)
120120
end
121121

122+
@testset "AutoMooncake" begin
123+
ad = AutoMooncake(; config=nothing)
124+
@test ad isa AbstractADType
125+
@test ad isa AutoMooncake
126+
@test mode(ad) isa ReverseMode
127+
@test ad.config === nothing
128+
end
129+
122130
@testset "AutoPolyesterForwardDiff" begin
123131
ad = AutoPolyesterForwardDiff()
124132
@test ad isa AbstractADType

test/legacy.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,8 @@ end
6868
ad = @test_deprecated AutoReverseDiff(true)
6969
@test ad.compile
7070
end
71+
72+
@testset "AutoTapir" begin
73+
@test_deprecated AutoTapir()
74+
@test_deprecated AutoTapir(; safe_mode=false)
75+
end

0 commit comments

Comments
 (0)