Skip to content

Commit f33d535

Browse files
committed
Fix NaN-safe mode
1 parent 7223b5d commit f33d535

File tree

5 files changed

+115
-50
lines changed

5 files changed

+115
-50
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ name: CI
22

33
on:
44
pull_request:
5-
branches:
6-
- master
75
push:
86
branches:
97
- master

src/partials.jl

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ Base.convert(::Type{Partials{N,V}}, partials::Partials{N,V}) where {N,V} = parti
8282
@inline Base.:-(partials::Partials) = Partials(minus_tuple(partials.values))
8383
@inline Base.:*(x::Real, partials::Partials) = partials*x
8484

85+
@inline function Base.:*(partials::Partials, x::Real)
86+
return Partials(scale_tuple(partials.values, x))
87+
end
88+
89+
@inline function Base.:/(partials::Partials, x::Real)
90+
return Partials(div_tuple_by_scalar(partials.values, x))
91+
end
92+
93+
@inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N
94+
return Partials(mul_tuples(a.values, b.values, x_a, x_b))
95+
end
96+
8597
@inline function _div_partials(a::Partials, b::Partials, aval, bval)
8698
return _mul_partials(a, b, inv(bval), -(aval / (bval*bval)))
8799
end
@@ -90,33 +102,22 @@ end
90102
#----------------------#
91103

92104
if NANSAFE_MODE_ENABLED
93-
@inline function Base.:*(partials::Partials, x::Real)
94-
x = ifelse(!isfinite(x) && iszero(partials), one(x), x)
95-
return Partials(scale_tuple(partials.values, x))
96-
end
97-
98-
@inline function Base.:/(partials::Partials, x::Real)
99-
x = ifelse(x == zero(x) && iszero(partials), one(x), x)
100-
return Partials(div_tuple_by_scalar(partials.values, x))
105+
# A dual number with a zero partial is just an unperturbed non-dual number
106+
# Hence when propagated the resulting dual number is unperturbed as well,
107+
# ie., its partial is zero as well, regardless of the primal value
108+
# However, standard floating point multiplication/division would return `NaN`
109+
# if the primal is not-finite/zero
110+
@inline function _mul_partial(partial::Real, x::Real)
111+
y = partial * x
112+
return iszero(partial) ? zero(y) : y
101113
end
102-
103-
@inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N
104-
x_a = ifelse(!isfinite(x_a) && iszero(a), one(x_a), x_a)
105-
x_b = ifelse(!isfinite(x_b) && iszero(b), one(x_b), x_b)
106-
return Partials(mul_tuples(a.values, b.values, x_a, x_b))
114+
@inline function _div_partial(partial::Real, x::Real)
115+
y = partial / x
116+
return iszero(partial) ? zero(y) : y
107117
end
108118
else
109-
@inline function Base.:*(partials::Partials, x::Real)
110-
return Partials(scale_tuple(partials.values, x))
111-
end
112-
113-
@inline function Base.:/(partials::Partials, x::Real)
114-
return Partials(div_tuple_by_scalar(partials.values, x))
115-
end
116-
117-
@inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N
118-
return Partials(mul_tuples(a.values, b.values, x_a, x_b))
119-
end
119+
@inline _mul_partial(partial::Real, x::Real) = partial * x
120+
@inline _div_partial(partial::Real, x::Real) = partial / x
120121
end
121122

122123
# edge cases where N == 0 #
@@ -197,11 +198,11 @@ end
197198
end
198199

199200
@generated function scale_tuple(tup::NTuple{N}, x) where N
200-
return tupexpr(i -> :(tup[$i] * x), N)
201+
return tupexpr(i -> :(_mul_partial(tup[$i], x)), N)
201202
end
202203

203204
@generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N
204-
return tupexpr(i -> :(tup[$i] / x), N)
205+
return tupexpr(i -> :(_div_partial(tup[$i], x)), N)
205206
end
206207

207208
@generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N
@@ -217,7 +218,7 @@ end
217218
end
218219

219220
@generated function mul_tuples(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N
220-
return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
221+
return tupexpr(i -> :(_mul_partial(a[$i], afactor) + _mul_partial(b[$i], bfactor)), N)
221222
end
222223

223224
###################

test/DerivativeTest.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,4 +113,13 @@ end
113113
@test ForwardDiff.derivative(x -> (1+im)*x, 0) == (1+im)
114114
end
115115

116+
@testset "NaN-safe mode" begin
117+
x = ForwardDiff.derivative(log zero, 1.0)
118+
if ForwardDiff.NANSAFE_MODE_ENABLED
119+
@test iszero(x)
120+
else
121+
@test isnan(x)
122+
end
123+
end
124+
116125
end # module

test/GradientTest.jl

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,15 @@ end
148148
end
149149

150150
@testset "exponential function at base zero" begin
151-
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, -0.5]), [NaN, NaN])
152-
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.0]), [NaN, NaN])
153-
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.5]), [Inf, NaN])
151+
if ForwardDiff.NANSAFE_MODE_ENABLED
152+
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, -0.5]), [-Inf, -Inf])
153+
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.0]), [NaN, -Inf])
154+
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.5]), [Inf, 0.0])
155+
else
156+
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, -0.5]), [NaN, NaN])
157+
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.0]), [NaN, NaN])
158+
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 0.5]), [Inf, NaN])
159+
end
154160
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 1.5]), [0.0, 0.0])
155161
end
156162

@@ -207,11 +213,19 @@ end
207213
end
208214

209215
@testset "gradient for exponential with NaNMath" begin
210-
@test isnan(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[1]), [NaN, 1.0])[1])
216+
if ForwardDiff.NANSAFE_MODE_ENABLED
217+
@test isequal(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[2]), [NaN, 1.0]), [1.0, NaN])
218+
else
219+
@test isequal(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[2]), [NaN, 1.0]), [NaN, NaN])
220+
end
211221
@test ForwardDiff.gradient(x -> NaNMath.pow(x[1], x[2]), [1.0, 1.0]) == [1.0, 0.0]
212222
@test isnan(ForwardDiff.gradient((x) -> NaNMath.pow(x[1], x[2]), [-1.0, 0.5])[1])
213223

214-
@test isnan(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0])[1])
224+
if ForwardDiff.NANSAFE_MODE_ENABLED
225+
@test isequal(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0]), [1.0, NaN])
226+
else
227+
@test isequal(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0]), [NaN, NaN])
228+
end
215229
@test ForwardDiff.gradient(x -> x[1]^x[2], [1.0, 1.0]) == [1.0, 0.0]
216230
@test_throws DomainError ForwardDiff.gradient(x -> x[1]^x[2], [-1.0, 0.5])
217231
end
@@ -286,4 +300,34 @@ end
286300
@test grad == SVector{3}(der, der, der)
287301
end
288302

303+
@testset "NaN-safe mode" begin
304+
# issue #774
305+
f = x -> log(zero(x[1]) + x[2])
306+
x = [1.0, 0.0]
307+
y1 = ForwardDiff.gradient(f, x)
308+
y2 = ForwardDiff.gradient(f, x, ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{1}()))
309+
y3 = ForwardDiff.gradient(f, x, ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{2}()))
310+
for y in (y1, y2, y3)
311+
if ForwardDiff.NANSAFE_MODE_ENABLED
312+
@test y == [0.0, Inf]
313+
else
314+
@test isequal(y, [NaN, Inf])
315+
end
316+
end
317+
318+
# issue #745
319+
g = a -> a[1] * exp(-a[2])
320+
a = [1.0, -1e3]
321+
b1 = ForwardDiff.gradient(g, a)
322+
b2 = ForwardDiff.gradient(g, a, ForwardDiff.GradientConfig(g, a, ForwardDiff.Chunk{1}()))
323+
b3 = ForwardDiff.gradient(g, a, ForwardDiff.GradientConfig(g, a, ForwardDiff.Chunk{2}()))
324+
for b in (b1, b2, b3)
325+
if ForwardDiff.NANSAFE_MODE_ENABLED
326+
@test b == [Inf, -Inf]
327+
else
328+
@test isequal(b, [NaN, NaN])
329+
end
330+
end
331+
end
332+
289333
end # module

test/PartialsTest.jl

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -111,24 +111,37 @@ samerng() = MersenneTwister(1)
111111
@test (PARTIALS / X).values == map(v -> v / X, VALUES)
112112

113113
if N > 0
114-
@test ForwardDiff._div_partials(PARTIALS, PARTIALS2, X, Y) == ForwardDiff._mul_partials(PARTIALS, PARTIALS2, inv(Y), -X/(Y^2))
115-
@test ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values == map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2)
116-
@test ForwardDiff._mul_partials(ZERO_PARTIALS, PARTIALS, X, Y) == Y * PARTIALS
117-
@test ForwardDiff._mul_partials(PARTIALS, ZERO_PARTIALS, X, Y) == X * PARTIALS
114+
# Only zero partials
115+
ALLZERO = Partials(ntuple(_ -> zero(T), N))
116+
# Mix of zero and non-zero partials
117+
FIRSTZERO = Partials(ntuple(i -> i == 1 ? zero(T) : rand(T), N))
118+
119+
# The following properties should always be satisfied, regardless of whether NaN-safe mode is enabled or disabled
120+
# We use `isequal` for comparisons in the presence of `NaN`s
121+
for p1 in (PARTIALS, ALLZERO, FIRSTZERO), p2 in (PARTIALS2, ALLZERO, FIRSTZERO), v1 in (X, NaN, Inf), v2 in (Y, NaN, Inf)
122+
@test isequal(ForwardDiff._div_partials(p1, p2, v1, v2), ForwardDiff._mul_partials(p1, p2, inv(v2), -v1/(v2^2)))
123+
@test isequal(ForwardDiff._mul_partials(p1, p2, v1, v2), v1 * p1 + v2 * p2)
124+
end
125+
for v1 in (X, NaN, Inf), v2 in (Y, NaN, Inf)
126+
@test isequal(ForwardDiff._mul_partials(ZERO_PARTIALS, PARTIALS, v1, v2), v2 * PARTIALS)
127+
@test isequal(ForwardDiff._mul_partials(PARTIALS, ZERO_PARTIALS, v1, v2), v1 * PARTIALS)
128+
end
118129

119130
if ForwardDiff.NANSAFE_MODE_ENABLED
120-
ZEROS = Partials((fill(zero(T), N)...,))
121-
122-
@test (NaN * ZEROS).values == ZEROS.values
123-
@test (Inf * ZEROS).values == ZEROS.values
124-
@test (ZEROS / 0).values == ZEROS.values
125-
126-
@test ForwardDiff._mul_partials(ZEROS, ZEROS, X, NaN).values == ZEROS.values
127-
@test ForwardDiff._mul_partials(ZEROS, ZEROS, NaN, X).values == ZEROS.values
128-
@test ForwardDiff._mul_partials(ZEROS, ZEROS, X, Inf).values == ZEROS.values
129-
@test ForwardDiff._mul_partials(ZEROS, ZEROS, Inf, X).values == ZEROS.values
130-
@test ForwardDiff._mul_partials(ZEROS, ZEROS, Inf, NaN).values == ZEROS.values
131-
@test ForwardDiff._mul_partials(ZEROS, ZEROS, NaN, Inf).values == ZEROS.values
131+
for f in ((p -> NaN * p), (p -> Inf * p), (p -> -Inf * p), (p -> p / 0), (p -> p / NaN), (p -> p / Inf), (p -> p / -Inf))
132+
# Only zero partials
133+
@test iszero(@inferred(f(ALLZERO)))
134+
135+
# Mix of zero and non-zero partials
136+
z = @inferred(f(FIRSTZERO))
137+
for i in 1:N
138+
if iszero(FIRSTZERO[i])
139+
@test iszero(z[i])
140+
else
141+
@test isequal(z[i], f(FIRSTZERO[i]))
142+
end
143+
end
144+
end
132145
end
133146
end
134147
end

0 commit comments

Comments
 (0)