Merged
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "ReverseDiff"
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
version = "1.14.3"
version = "1.14.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
11 changes: 11 additions & 0 deletions src/ReverseDiff.jl
@@ -29,6 +29,17 @@ const REAL_TYPES = (:Bool, :Integer, :(Irrational{:ℯ}), :(Irrational{:π}), :R
const SKIPPED_UNARY_SCALAR_FUNCS = Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]
const SKIPPED_BINARY_SCALAR_FUNCS = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]

# Some functions with derivatives in DiffRules are not supported
# For instance, ReverseDiff does not support functions with complex results and derivatives
const SKIPPED_DIFFRULES = Tuple{Symbol,Symbol}[
(:SpecialFunctions, :hankelh1),
(:SpecialFunctions, :hankelh1x),
(:SpecialFunctions, :hankelh2),
(:SpecialFunctions, :hankelh2x),
(:SpecialFunctions, :besselh),
(:SpecialFunctions, :besselhx),
]
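
Not part of the diff — an illustrative check of why these rules are skipped, assuming SpecialFunctions is installed:

using SpecialFunctions
z = hankelh1(0, 1.0)   # Hankel function of the first kind: J₀(1) + im*Y₀(1)
typeof(z)              # ComplexF64 — complex results (and derivatives) are unsupported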

include("tape.jl")
include("tracked.jl")
include("macros.jl")
128 changes: 106 additions & 22 deletions src/derivatives/elementwise.jl
@@ -72,6 +72,7 @@ for g! in (:map!, :broadcast!), (M, f, arity) in DiffRules.diffrules(; filter_mo
@warn "$M.$f is not available and hence rule for it can not be defined"
continue # Skip rules for methods not defined in the current scope
end
(M, f) in SKIPPED_DIFFRULES && continue
if arity == 1
@eval @inline Base.$(g!)(f::typeof($M.$f), out::TrackedArray, t::TrackedArray) = $(g!)(ForwardOptimize(f), out, t)
elseif arity == 2
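
For context — a sketch, not part of the diff: DiffRules.diffrules(; filter_modules=nothing) yields (module, function, arity) tuples, which is what the (M, f) in SKIPPED_DIFFRULES membership test above matches against:

using DiffRules
rules = DiffRules.diffrules(; filter_modules=nothing)
collect(Iterators.take(rules, 3))  # (Symbol, Symbol, arity) entries, e.g. (:Base, :sin, 1)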
@@ -122,23 +123,53 @@ for (g!, g) in ((:map!, :map), (:broadcast!, :broadcast))
return out
end
end
for A in ARRAY_TYPES, T in (:TrackedArray, :TrackedReal)
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray{S}, x::$(T){X}, y::$A) where {F,S,X}
result = DiffResults.GradientResult(SVector(zero(S), zero(S)))
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
for A in ARRAY_TYPES
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::TrackedReal{X,D}, y::$A) where {F,X,D}
result = DiffResults.DiffResult(zero(X), zero(D))
df = let result=result
(vx, vy) -> let vy=vy
ForwardDiff.derivative!(result, s -> f.f(s, vy), vx)
end
end
results = $(g)(df, value(x), value(y))
map!(DiffResults.value, value(out), results)
cache = (results, df, index_bound(x, out), index_bound(y, out))
record!(tape(x, y), SpecialInstruction, $(g), (x, y), out, cache)
record!(tape(x), SpecialInstruction, $(g), (x, y), out, cache)
return out
end
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::$A, y::$(T){Y}) where {F,Y}
result = DiffResults.GradientResult(SVector(zero(S), zero(S)))
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::$A, y::TrackedReal{Y,D}) where {F,Y,D}
result = DiffResults.DiffResult(zero(Y), zero(D))
df = let result=result
(vx, vy) -> let vx=vx
ForwardDiff.derivative!(result, s -> f.f(vx, s), vy)
end
end
results = $(g)(df, value(x), value(y))
map!(DiffResults.value, value(out), results)
cache = (results, df, index_bound(x, out), index_bound(y, out))
record!(tape(x, y), SpecialInstruction, $(g), (x, y), out, cache)
record!(tape(y), SpecialInstruction, $(g), (x, y), out, cache)
return out
end
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::TrackedArray{X}, y::$A) where {F,X}
result = DiffResults.GradientResult(SVector(zero(X)))
df = (vx, vy) -> let vy=vy
ForwardDiff.gradient!(result, s -> f.f(s[1], vy), SVector(vx))
Member: why s[1]?

Member (Author): Since s is an SVector with a single element, vx, which we want to use here. That's just the one-argument version of the current implementation on the master branch:

result = DiffResults.GradientResult(SVector(zero(S), zero(S)))
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
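
Not part of the diff — a minimal standalone sketch of the equivalence described above, assuming ForwardDiff and StaticArrays are available (f, vx, and vy are hypothetical stand-ins):

using ForwardDiff, StaticArrays
f = (a, b) -> a * sin(b)   # hypothetical example function
vx, vy = 2.0, 0.5
# With one tracked argument, vx is wrapped in a 1-element SVector, so s[1] unpacks it:
g = ForwardDiff.gradient(s -> f(s[1], vy), SVector(vx))
g[1] ≈ ForwardDiff.derivative(v -> f(v, vy), vx)  # true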

end
results = $(g)(df, value(x), value(y))
map!(DiffResults.value, value(out), results)
cache = (results, df, index_bound(x, out), index_bound(y, out))
record!(tape(x), SpecialInstruction, $(g), (x, y), out, cache)
return out
end
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::$A, y::TrackedArray{Y}) where {F,Y}
result = DiffResults.GradientResult(SVector(zero(Y)))
df = (vx, vy) -> let vx=vx
ForwardDiff.gradient!(result, s -> f.f(vx, s[1]), SVector(vy))
Member: s[1]?

Member (Author): Same as above.

end
results = $(g)(df, value(x), value(y))
map!(DiffResults.value, value(out), results)
cache = (results, df, index_bound(x, out), index_bound(y, out))
record!(tape(y), SpecialInstruction, $(g), (x, y), out, cache)
return out
end
end
@@ -166,6 +197,7 @@ for g in (:map, :broadcast), (M, f, arity) in DiffRules.diffrules(; filter_modul
if arity == 1
@eval @inline Base.$(g)(f::typeof($M.$f), t::TrackedArray) = $(g)(ForwardOptimize(f), t)
elseif arity == 2
(M, f) in SKIPPED_DIFFRULES && continue
# skip these definitions if `f` is one of the functions
# that will get a manually defined broadcast definition
# later (see "built-in infix operations" below)
Expand Down Expand Up @@ -207,20 +239,52 @@ for g in (:map, :broadcast)
record!(tp, SpecialInstruction, $(g), x, out, cache)
return out
end
for A in ARRAY_TYPES, T in (:TrackedArray, :TrackedReal)
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$(T){X,D}, y::$A) where {F,X,D}
result = DiffResults.GradientResult(SVector(zero(X), zero(D)))
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
for A in ARRAY_TYPES
@eval function Base.$(g)(f::ForwardOptimize{F}, x::TrackedReal{X,D}, y::$A) where {F,X,D}
result = DiffResults.DiffResult(zero(X), zero(D))
df = let result=result
(vx, vy) -> let vy=vy
ForwardDiff.derivative!(result, s -> f.f(s, vy), vx)
end
end
results = $(g)(df, value(x), value(y))
tp = tape(x)
out = track(DiffResults.value.(results), D, tp)
cache = (results, df, index_bound(x, out), index_bound(y, out))
record!(tp, SpecialInstruction, $(g), (x, y), out, cache)
return out
end
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$A, y::$(T){Y,D}) where {F,Y,D}
result = DiffResults.GradientResult(SVector(zero(Y), zero(D)))
df = (vx, vy) -> ForwardDiff.gradient!(result, s -> f.f(s[1], s[2]), SVector(vx, vy))
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$A, y::TrackedReal{Y,D}) where {F,Y,D}
result = DiffResults.DiffResult(zero(Y), zero(D))
df = let result=result
(vx, vy) -> let vx=vx
ForwardDiff.derivative!(result, s -> f.f(vx, s), vy)
end
end
results = $(g)(df, value(x), value(y))
tp = tape(y)
out = track(DiffResults.value.(results), D, tp)
cache = (results, df, index_bound(x, out), index_bound(y, out))
record!(tp, SpecialInstruction, $(g), (x, y), out, cache)
return out
end
@eval function Base.$(g)(f::ForwardOptimize{F}, x::TrackedArray{X,D}, y::$A) where {F,X,D}
result = DiffResults.GradientResult(SVector(zero(X)))
df = (vx, vy) -> let vy=vy
ForwardDiff.gradient!(result, s -> f.f(s[1], vy), SVector(vx))
Member: s[1]?

Member (Author): Same as above.

end
results = $(g)(df, value(x), value(y))
tp = tape(x)
out = track(DiffResults.value.(results), D, tp)
cache = (results, df, index_bound(x, out), index_bound(y, out))
record!(tp, SpecialInstruction, $(g), (x, y), out, cache)
return out
end
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$A, y::TrackedArray{Y,D}) where {F,Y,D}
result = DiffResults.GradientResult(SVector(zero(Y)))
df = (vx, vy) -> let vx=vx
ForwardDiff.gradient!(result, s -> f.f(vx, s[1]), SVector(vy))
Member: Same s[1] comment.

Member (Author): Same as above.

end
results = $(g)(df, value(x), value(y))
tp = tape(y)
out = track(DiffResults.value.(results), D, tp)
@@ -291,8 +355,15 @@ end
diffresult_increment_deriv!(input, output_deriv, results, 1)
else
a, b = input
istracked(a) && diffresult_increment_deriv!(a, output_deriv, results, 1)
istracked(b) && diffresult_increment_deriv!(b, output_deriv, results, 2)
p = 0
if istracked(a)
p += 1
diffresult_increment_deriv!(a, output_deriv, results, p)
end
if istracked(b)
p += 1
diffresult_increment_deriv!(b, output_deriv, results, p)
end
end
unseed!(output)
return nothing
@@ -311,12 +382,25 @@ end
end
else
a, b = input
p = 0
if size(a) == size(b)
istracked(a) && diffresult_increment_deriv!(a, output_deriv, results, 1)
istracked(b) && diffresult_increment_deriv!(b, output_deriv, results, 2)
if istracked(a)
p += 1
diffresult_increment_deriv!(a, output_deriv, results, p)
end
if istracked(b)
p += 1
diffresult_increment_deriv!(b, output_deriv, results, p)
end
else
istracked(a) && diffresult_increment_deriv!(a, output_deriv, results, 1, a_bound)
istracked(b) && diffresult_increment_deriv!(b, output_deriv, results, 2, b_bound)
if istracked(a)
p += 1
diffresult_increment_deriv!(a, output_deriv, results, p, a_bound)
Member: why change the value of p here?

Member (Author): To extract the correct partial. If a is tracked, its partial has index p = 1, but if only b is tracked, the first partial (p = 1) corresponds to b. And if both a and b are tracked, p = 1 corresponds to a and p = 2 to b. So incrementing p in the branches lets us avoid checking and handling all three scenarios separately.

Note that on the master branch, p = 1 for a and p = 2 for b are hardcoded. That only works because on the master branch the partials with respect to both arguments are always computed and stored, even if only one argument is tracked.
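
Not part of the diff — a minimal sketch of that index bookkeeping, with hypothetical stand-in flags istracked_a and istracked_b:

function partial_index_demo(istracked_a::Bool, istracked_b::Bool)
    p = 0
    pa = istracked_a ? (p += 1) : nothing  # a's partial comes first whenever a is tracked
    pb = istracked_b ? (p += 1) : nothing  # b's partial is next, or first if only b is tracked
    return pa, pb
end
partial_index_demo(true, true)   # (1, 2)
partial_index_demo(false, true)  # (nothing, 1)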

end
if istracked(b)
p += 1
diffresult_increment_deriv!(b, output_deriv, results, p, b_bound)
Member: Same p comment.

Member (Author): Same as above.
end
end
end
unseed!(output)
1 change: 1 addition & 0 deletions src/derivatives/scalars.jl
@@ -7,6 +7,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
@warn "$M.$f is not available and hence rule for it can not be defined"
continue # Skip rules for methods not defined in the current scope
end
(M, f) in SKIPPED_DIFFRULES && continue
if arity == 1
@eval @inline $M.$(f)(t::TrackedReal) = ForwardOptimize($M.$(f))(t)
elseif arity == 2
63 changes: 39 additions & 24 deletions test/derivatives/ElementwiseTests.jl
@@ -34,7 +34,7 @@ function test_elementwise(f, fopt, x, tp)
# reverse
out = similar(y, (length(x), length(x)))
ReverseDiff.seeded_reverse_pass!(out, yt, xt, tp)
test_approx(out, ForwardDiff.jacobian(z -> map(f, z), x))
test_approx(out, ForwardDiff.jacobian(z -> map(f, z), x); nans=true)

# forward
x2 = x .- offset
@@ -57,7 +57,7 @@
# reverse
out = similar(y, (length(x), length(x)))
ReverseDiff.seeded_reverse_pass!(out, yt, xt, tp)
test_approx(out, ForwardDiff.jacobian(z -> broadcast(f, z), x))
test_approx(out, ForwardDiff.jacobian(z -> broadcast(f, z), x); nans=true)

# forward
x2 = x .- offset
@@ -81,9 +81,9 @@ function test_map(f, fopt, a, b, tp)
@test length(tp) == 1

# reverse
out = similar(c, (length(a), length(a)))
out = similar(c, (length(c), length(a)))
ReverseDiff.seeded_reverse_pass!(out, ct, at, tp)
test_approx(out, ForwardDiff.jacobian(x -> map(f, x, b), a))
test_approx(out, ForwardDiff.jacobian(x -> map(f, x, b), a); nans=true)

# forward
a2 = a .- offset
@@ -102,9 +102,9 @@
@test length(tp) == 1

# reverse
out = similar(c, (length(a), length(a)))
out = similar(c, (length(c), length(b)))
ReverseDiff.seeded_reverse_pass!(out, ct, bt, tp)
test_approx(out, ForwardDiff.jacobian(x -> map(f, a, x), b))
test_approx(out, ForwardDiff.jacobian(x -> map(f, a, x), b); nans=true)

# forward
b2 = b .- offset
@@ -123,13 +123,17 @@
@test length(tp) == 1

# reverse
out_a = similar(c, (length(a), length(a)))
out_b = similar(c, (length(a), length(a)))
out_a = similar(c, (length(c), length(a)))
out_b = similar(c, (length(c), length(b)))
ReverseDiff.seeded_reverse_pass!(out_a, ct, at, tp)
ReverseDiff.seeded_reverse_pass!(out_b, ct, bt, tp)
test_approx(out_a, ForwardDiff.jacobian(x -> map(f, x, b), a))
test_approx(out_b, ForwardDiff.jacobian(x -> map(f, a, x), b))

jac = let a=a, b=b, f=f
ForwardDiff.jacobian(vcat(vec(a), vec(b))) do x
map(f, reshape(x[1:length(a)], size(a)), reshape(x[(length(a) + 1):end], size(b)))
end
end
test_approx(out_a, jac[:, 1:length(a)]; nans=true)
test_approx(out_b, jac[:, (length(a) + 1):end]; nans=true)
# forward
a2, b2 = a .- offset, b .- offset
ReverseDiff.value!(at, a2)
@@ -163,7 +167,7 @@ function test_broadcast(f, fopt, a::AbstractArray, b::AbstractArray, tp, builtin
# reverse
out = similar(c, (length(c), length(a)))
ReverseDiff.seeded_reverse_pass!(out, ct, at, tp)
test_approx(out, ForwardDiff.jacobian(x -> g(x, b), a))
test_approx(out, ForwardDiff.jacobian(x -> g(x, b), a); nans=true)

# forward
a2 = a .- offset
@@ -184,7 +188,7 @@
# reverse
out = similar(c, (length(c), length(b)))
ReverseDiff.seeded_reverse_pass!(out, ct, bt, tp)
test_approx(out, ForwardDiff.jacobian(x -> g(a, x), b))
test_approx(out, ForwardDiff.jacobian(x -> g(a, x), b); nans=true)

# forward
b2 = b .- offset
@@ -207,8 +211,13 @@
out_b = similar(c, (length(c), length(b)))
ReverseDiff.seeded_reverse_pass!(out_a, ct, at, tp)
ReverseDiff.seeded_reverse_pass!(out_b, ct, bt, tp)
test_approx(out_a, ForwardDiff.jacobian(x -> g(x, b), a))
test_approx(out_b, ForwardDiff.jacobian(x -> g(a, x), b))
jac = let a=a, b=b, g=g
ForwardDiff.jacobian(vcat(vec(a), vec(b))) do x
g(reshape(x[1:length(a)], size(a)), reshape(x[(length(a) + 1):end], size(b)))
end
end
test_approx(out_a, jac[:, 1:length(a)]; nans=true)
test_approx(out_b, jac[:, (length(a) + 1):end]; nans=true)

# forward
a2, b2 = a .- offset, b .- offset
@@ -243,7 +252,7 @@ function test_broadcast(f, fopt, n::Number, x::AbstractArray, tp, builtin::Bool
# reverse
out = similar(y)
ReverseDiff.seeded_reverse_pass!(out, yt, nt, tp)
test_approx(out, ForwardDiff.derivative(z -> g(z, x), n))
test_approx(out, ForwardDiff.derivative(z -> g(z, x), n); nans=true)

# forward
n2 = n + offset
@@ -264,7 +273,7 @@
# reverse
out = similar(y, (length(y), length(x)))
ReverseDiff.seeded_reverse_pass!(out, yt, xt, tp)
test_approx(out, ForwardDiff.jacobian(z -> g(n, z), x))
test_approx(out, ForwardDiff.jacobian(z -> g(n, z), x); nans=true)

# forward
x2 = x .- offset
@@ -287,8 +296,11 @@
out_x = similar(y, (length(y), length(x)))
ReverseDiff.seeded_reverse_pass!(out_n, yt, nt, tp)
ReverseDiff.seeded_reverse_pass!(out_x, yt, xt, tp)
test_approx(out_n, ForwardDiff.derivative(z -> g(z, x), n))
test_approx(out_x, ForwardDiff.jacobian(z -> g(n, z), x))
jac = let x=x, g=g
ForwardDiff.jacobian(z -> g(z[1], reshape(z[2:end], size(x))), vcat(n, vec(x)))
end
test_approx(out_n, reshape(jac[:, 1], size(y)); nans=true)
test_approx(out_x, jac[:, 2:end]; nans=true)

# forward
n2, x2 = n + offset, x .- offset
@@ -323,7 +335,7 @@ function test_broadcast(f, fopt, x::AbstractArray, n::Number, tp, builtin::Bool
# reverse
out = similar(y)
ReverseDiff.seeded_reverse_pass!(out, yt, nt, tp)
test_approx(out, ForwardDiff.derivative(z -> g(x, z), n))
test_approx(out, ForwardDiff.derivative(z -> g(x, z), n); nans=true)

# forward
n2 = n + offset
@@ -344,7 +356,7 @@
# reverse
out = similar(y, (length(y), length(x)))
ReverseDiff.seeded_reverse_pass!(out, yt, xt, tp)
test_approx(out, ForwardDiff.jacobian(z -> g(z, n), x))
test_approx(out, ForwardDiff.jacobian(z -> g(z, n), x); nans=true)

# forward
x2 = x .- offset
@@ -367,8 +379,11 @@
out_x = similar(y, (length(y), length(x)))
ReverseDiff.seeded_reverse_pass!(out_n, yt, nt, tp)
ReverseDiff.seeded_reverse_pass!(out_x, yt, xt, tp)
test_approx(out_n, ForwardDiff.derivative(z -> g(x, z), n))
test_approx(out_x, ForwardDiff.jacobian(z -> g(z, n), x))
jac = let x=x, g=g
ForwardDiff.jacobian(z -> g(reshape(z[1:(end - 1)], size(x)), z[end]), vcat(vec(x), n))
end
test_approx(out_x, jac[:, 1:(end - 1)]; nans=true)
test_approx(out_n, reshape(jac[:, end], size(y)); nans=true)

# forward
x2, n2 = x .- offset, n + offset
Expand All @@ -393,7 +408,7 @@ for (M, fsym, arity) in DiffRules.diffrules(; filter_modules=nothing)
if !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), fsym))
error("$M.$fsym is not available")
end
fsym === :rem2pi && continue
(M, fsym) in ReverseDiff.SKIPPED_DIFFRULES && continue
if arity == 1
f = eval(:($M.$fsym))
test_println("forward-mode unary scalar functions", f)