Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.19.0"
version = "1.19.1"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
4 changes: 2 additions & 2 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,13 +403,13 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
function (::Core.kwftype(typeof(ChainRulesCore.frule)))(
@nospecialize($kwargs::Any),
frule::typeof(ChainRulesCore.frule),
@nospecialize(::Any),
::$RuleConfig,
$(map(esc, primal_sig_parts)...),
)
return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent())
end
function ChainRulesCore.frule(
@nospecialize(::Any), $(map(esc, primal_sig_parts)...)
::$RuleConfig, $(map(esc, primal_sig_parts)...)
)
$(__source__)
# Julia functions always only have 1 output, so return a single NoTangent()
Expand Down
18 changes: 18 additions & 0 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,24 @@ end
@test pullback(4.5) == (NoTangent(), NoTangent(), NoTangent())
end

@testset "interactions with configs" begin
struct AllConfig <: RuleConfig{Union{HasForwardsMode,NoReverseMode}} end

foo_ndc1(x) = string(x)
@non_differentiable foo_ndc1(x)
@test frule(AllConfig(), foo_ndc1, 2.0) == (string(2.0), NoTangent())
r1, pb1 = rrule(AllConfig(), foo_ndc1, 2.0)
@test r1 == string(2.0)
@test pb1(NoTangent()) == (NoTangent(), NoTangent())

foo_ndc2(x; y=0) = string(x + y)
@non_differentiable foo_ndc2(x)
@test frule(AllConfig(), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent())
r2, pb2 = rrule(AllConfig(), foo_ndc2, 2.0; y=4.0)
@test r2 == string(6.0)
@test pb2(NoTangent()) == (NoTangent(), NoTangent())
end

@testset "Not supported (Yet)" begin
# Where clauses are not supported.
@test_macro_throws(
Expand Down