@@ -6,6 +6,9 @@ using DiffResults
66using ReverseDiff
77using Test
88
9+ struct MyStruct end
10+ f (:: MyStruct , x) = sum (4 x .+ 1 )
11+ f (x, y:: MyStruct ) = sum (4 x .+ 1 )
912f (x) = sum (4 x .+ 1 )
1013
1114function ChainRulesCore. rrule (:: typeof (f), x)
@@ -20,21 +23,37 @@ function ChainRulesCore.rrule(::typeof(f), x)
2023 rather than 4 when we compute the derivative of `f`, it means
2124 the importing mechanism works.
2225 =#
23- return ChainRulesCore. NoTangent (), fill (3 * d, size (x))
26+ return NoTangent (), fill (3 * d, size (x))
27+ end
28+ return r, back
29+ end
30+ function ChainRulesCore. rrule (:: typeof (f), :: MyStruct , x)
31+ r = f (MyStruct (), x)
32+ function back (d)
33+ return NoTangent (), NoTangent (), fill (3 * d, size (x))
34+ end
35+ return r, back
36+ end
37+ function ChainRulesCore. rrule (:: typeof (f), x, :: MyStruct )
38+ r = f (x, MyStruct ())
39+ function back (d)
40+ return NoTangent (), fill (3 * d, size (x)), NoTangent ()
2441 end
2542 return r, back
2643end
2744
2845ReverseDiff. @grad_from_chainrules f (x:: ReverseDiff.TrackedArray )
29-
46+ # test arg type hygiene
47+ ReverseDiff. @grad_from_chainrules f (:: MyStruct , x:: ReverseDiff.TrackedArray )
48+ ReverseDiff. @grad_from_chainrules f (x:: ReverseDiff.TrackedArray , y:: MyStruct )
3049
3150g (x, y) = sum (4 x .+ 4 y)
3251
3352function ChainRulesCore. rrule (:: typeof (g), x, y)
3453 r = g (x, y)
3554 function back (d)
3655 # same as above, use 3 and 5 as the derivatives
37- return ChainRulesCore . NoTangent (), fill (3 * d, size (x)), fill (5 * d, size (x))
56+ return NoTangent (), fill (3 * d, size (x)), fill (5 * d, size (x))
3857 end
3958 return r, back
4059end
@@ -93,6 +112,19 @@ ReverseDiff.@grad_from_chainrules g(x::ReverseDiff.TrackedArray, y::ReverseDiff.
93112
94113end
95114
115+ @testset " custom struct input" begin
116+ input = rand (3 , 3 )
117+ output, back = ChainRulesCore. rrule (f, MyStruct (), input);
118+ _, _, d = back (1 )
119+ @test output == f (MyStruct (), input)
120+ @test d == fill (3 , size (input))
121+
122+ output, back = ChainRulesCore. rrule (f, input, MyStruct ());
123+ _, d, _ = back (1 )
124+ @test output == f (input, MyStruct ())
125+ @test d == fill (3 , size (input))
126+ end
127+
96128# ## Tape test
97129@testset " Tape test: Ensure ordinary call is not tracked" begin
98130 tp = ReverseDiff. InstructionTape ()
@@ -112,7 +144,7 @@ f_vararg(x, args...) = sum(4x .+ sum(args))
112144function ChainRulesCore. rrule (:: typeof (f_vararg), x, args... )
113145 r = f_vararg (x, args... )
114146 function back (d)
115- return ChainRulesCore . NoTangent (), fill (3 * d, size (x))
147+ return NoTangent (), fill (3 * d, size (x))
116148 end
117149 return r, back
118150end
@@ -136,7 +168,7 @@ f_kw(x, args...; k=1, kwargs...) = sum(4x .+ sum(args) .+ (k + kwargs[:j]))
136168function ChainRulesCore. rrule (:: typeof (f_kw), x, args... ; k= 1 , kwargs... )
137169 r = f_kw (x, args... ; k= k, kwargs... )
138170 function back (d)
139- return ChainRulesCore . NoTangent (), fill (3 * d, size (x))
171+ return NoTangent (), fill (3 * d, size (x))
140172 end
141173 return r, back
142174end
@@ -175,20 +207,20 @@ end
175207# ## Isolated Scope
176208module IsolatedModuleForTestingScoping
177209using ChainRulesCore
178- using ReverseDiff: @grad_from_chainrules
210+ using ReverseDiff: ReverseDiff, @grad_from_chainrules
179211
180212f (x) = sum (4 x .+ 1 )
181213
182214function ChainRulesCore. rrule (:: typeof (f), x)
183215 r = f (x)
184216 function back (d)
185217 # return a distinguishable but improper grad
186- return ChainRulesCore . NoTangent (), fill (3 * d, size (x))
218+ return NoTangent (), fill (3 * d, size (x))
187219 end
188220 return r, back
189221end
190222
191- @grad_from_chainrules f (x:: TrackedArray )
223+ @grad_from_chainrules f (x:: ReverseDiff. TrackedArray )
192224
193225module SubModule
194226using Test
0 commit comments