Skip to content

Commit 1bda732

Browse files
committed
Fix NaN-safe mode
1 parent 7223b5d commit 1bda732

File tree

1 file changed

+23
-27
lines changed

1 file changed

+23
-27
lines changed

src/partials.jl

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ Base.convert(::Type{Partials{N,V}}, partials::Partials{N,V}) where {N,V} = parti
8282
@inline Base.:-(partials::Partials) = Partials(minus_tuple(partials.values))
8383
@inline Base.:*(x::Real, partials::Partials) = partials*x
8484

85+
@inline function Base.:*(partials::Partials, x::Real)
86+
return Partials(scale_tuple(partials.values, x))
87+
end
88+
89+
@inline function Base.:/(partials::Partials, x::Real)
90+
return Partials(div_tuple_by_scalar(partials.values, x))
91+
end
92+
93+
@inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N
94+
return Partials(mul_tuples(a.values, b.values, x_a, x_b))
95+
end
96+
8597
@inline function _div_partials(a::Partials, b::Partials, aval, bval)
8698
return _mul_partials(a, b, inv(bval), -(aval / (bval*bval)))
8799
end
@@ -90,33 +102,17 @@ end
90102
#----------------------#
91103

92104
if NANSAFE_MODE_ENABLED
93-
@inline function Base.:*(partials::Partials, x::Real)
94-
x = ifelse(!isfinite(x) && iszero(partials), one(x), x)
95-
return Partials(scale_tuple(partials.values, x))
96-
end
97-
98-
@inline function Base.:/(partials::Partials, x::Real)
99-
x = ifelse(x == zero(x) && iszero(partials), one(x), x)
100-
return Partials(div_tuple_by_scalar(partials.values, x))
105+
@inline function _mul_partial(partial::Real, x::Real)
106+
y = partial * x
107+
return !isfinite(x) && iszero(partial) ? zero(y) : y
101108
end
102-
103-
@inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N
104-
x_a = ifelse(!isfinite(x_a) && iszero(a), one(x_a), x_a)
105-
x_b = ifelse(!isfinite(x_b) && iszero(b), one(x_b), x_b)
106-
return Partials(mul_tuples(a.values, b.values, x_a, x_b))
109+
@inline function _div_partial(partial::Real, x::Real)
110+
y = partial / x
111+
return iszero(x) && iszero(partial) ? zero(y) : y
107112
end
108113
else
109-
@inline function Base.:*(partials::Partials, x::Real)
110-
return Partials(scale_tuple(partials.values, x))
111-
end
112-
113-
@inline function Base.:/(partials::Partials, x::Real)
114-
return Partials(div_tuple_by_scalar(partials.values, x))
115-
end
116-
117-
@inline function _mul_partials(a::Partials{N}, b::Partials{N}, x_a, x_b) where N
118-
return Partials(mul_tuples(a.values, b.values, x_a, x_b))
119-
end
114+
@inline _mul_partial(partial::Real, x::Real) = partial * x
115+
@inline _div_partial(partial::Real, x::Real) = partial / x
120116
end
121117

122118
# edge cases where N == 0 #
@@ -197,11 +193,11 @@ end
197193
end
198194

199195
@generated function scale_tuple(tup::NTuple{N}, x) where N
200-
return tupexpr(i -> :(tup[$i] * x), N)
196+
return tupexpr(i -> :(_mul_partial(tup[$i], x)), N)
201197
end
202198

203199
@generated function div_tuple_by_scalar(tup::NTuple{N}, x) where N
204-
return tupexpr(i -> :(tup[$i] / x), N)
200+
return tupexpr(i -> :(_div_partial(tup[$i], x)), N)
205201
end
206202

207203
@generated function add_tuples(a::NTuple{N}, b::NTuple{N}) where N
@@ -217,7 +213,7 @@ end
217213
end
218214

219215
@generated function mul_tuples(a::NTuple{N}, b::NTuple{N}, afactor, bfactor) where N
220-
return tupexpr(i -> :((afactor * a[$i]) + (bfactor * b[$i])), N)
216+
return tupexpr(i -> :(_mul_partial(a[$i], afactor) + _mul_partial(b[$i], bfactor)), N)
221217
end
222218

223219
###################

0 commit comments

Comments
 (0)