@@ -82,6 +82,18 @@ Base.convert(::Type{Partials{N,V}}, partials::Partials{N,V}) where {N,V} = parti
82
82
@inline Base.:- (partials:: Partials ) = Partials (minus_tuple (partials. values))
83
83
@inline Base.:* (x:: Real , partials:: Partials ) = partials* x
84
84
85
+ @inline function Base.:* (partials:: Partials , x:: Real )
86
+ return Partials (scale_tuple (partials. values, x))
87
+ end
88
+
89
+ @inline function Base.:/ (partials:: Partials , x:: Real )
90
+ return Partials (div_tuple_by_scalar (partials. values, x))
91
+ end
92
+
93
+ @inline function _mul_partials (a:: Partials{N} , b:: Partials{N} , x_a, x_b) where N
94
+ return Partials (mul_tuples (a. values, b. values, x_a, x_b))
95
+ end
96
+
85
97
@inline function _div_partials (a:: Partials , b:: Partials , aval, bval)
86
98
return _mul_partials (a, b, inv (bval), - (aval / (bval* bval)))
87
99
end
90
102
# ----------------------#
91
103
92
104
if NANSAFE_MODE_ENABLED
93
- @inline function Base.: * (partials :: Partials , x :: Real )
94
- x = ifelse ( ! isfinite (x) && iszero (partials), one (x), x)
95
- return Partials ( scale_tuple (partials . values, x))
96
- end
97
-
98
- @inline function Base.: / (partials :: Partials , x:: Real )
99
- x = ifelse (x == zero (x) && iszero (partials), one (x), x)
100
- return Partials ( div_tuple_by_scalar (partials . values, x))
105
+ # A dual number with a zero partial is just an unperturbed non-dual number
106
+ # Hence when propagated the resulting dual number is unperturbed as well,
107
+ # ie., its partial is zero as well, regardless of the primal value
108
+ # However, standard floating point multiplication/division would return `NaN`
109
+ # if the primal is not-finite/zero
110
+ @inline function _mul_partial (partial :: Real , x:: Real )
111
+ y = partial * x
112
+ return iszero (partial) ? zero (y) : y
101
113
end
102
-
103
- @inline function _mul_partials (a:: Partials{N} , b:: Partials{N} , x_a, x_b) where N
104
- x_a = ifelse (! isfinite (x_a) && iszero (a), one (x_a), x_a)
105
- x_b = ifelse (! isfinite (x_b) && iszero (b), one (x_b), x_b)
106
- return Partials (mul_tuples (a. values, b. values, x_a, x_b))
114
+ @inline function _div_partial (partial:: Real , x:: Real )
115
+ y = partial / x
116
+ return iszero (partial) ? zero (y) : y
107
117
end
108
118
else
109
- @inline function Base.:* (partials:: Partials , x:: Real )
110
- return Partials (scale_tuple (partials. values, x))
111
- end
112
-
113
- @inline function Base.:/ (partials:: Partials , x:: Real )
114
- return Partials (div_tuple_by_scalar (partials. values, x))
115
- end
116
-
117
- @inline function _mul_partials (a:: Partials{N} , b:: Partials{N} , x_a, x_b) where N
118
- return Partials (mul_tuples (a. values, b. values, x_a, x_b))
119
- end
119
+ @inline _mul_partial (partial:: Real , x:: Real ) = partial * x
120
+ @inline _div_partial (partial:: Real , x:: Real ) = partial / x
120
121
end
121
122
122
123
# edge cases where N == 0 #
@@ -197,11 +198,11 @@ end
197
198
end
198
199
199
200
@generated function scale_tuple (tup:: NTuple{N} , x) where N
200
- return tupexpr (i -> :(tup[$ i] * x ), N)
201
+ return tupexpr (i -> :(_mul_partial ( tup[$ i], x) ), N)
201
202
end
202
203
203
204
@generated function div_tuple_by_scalar (tup:: NTuple{N} , x) where N
204
- return tupexpr (i -> :(tup[$ i] / x ), N)
205
+ return tupexpr (i -> :(_div_partial ( tup[$ i], x) ), N)
205
206
end
206
207
207
208
@generated function add_tuples (a:: NTuple{N} , b:: NTuple{N} ) where N
217
218
end
218
219
219
220
@generated function mul_tuples (a:: NTuple{N} , b:: NTuple{N} , afactor, bfactor) where N
220
- return tupexpr (i -> :((afactor * a[$ i]) + (bfactor * b[$ i])), N)
221
+ return tupexpr (i -> :(_mul_partial ( a[$ i], afactor ) + _mul_partial ( b[$ i], bfactor )), N)
221
222
end
222
223
223
224
# ##################
0 commit comments