@@ -82,6 +82,18 @@ Base.convert(::Type{Partials{N,V}}, partials::Partials{N,V}) where {N,V} = parti
82
82
@inline Base.:- (partials:: Partials ) = Partials (minus_tuple (partials. values))
83
83
@inline Base.:* (x:: Real , partials:: Partials ) = partials* x
84
84
85
+ @inline function Base.:* (partials:: Partials , x:: Real )
86
+ return Partials (scale_tuple (partials. values, x))
87
+ end
88
+
89
+ @inline function Base.:/ (partials:: Partials , x:: Real )
90
+ return Partials (div_tuple_by_scalar (partials. values, x))
91
+ end
92
+
93
+ @inline function _mul_partials (a:: Partials{N} , b:: Partials{N} , x_a, x_b) where N
94
+ return Partials (mul_tuples (a. values, b. values, x_a, x_b))
95
+ end
96
+
85
97
@inline function _div_partials (a:: Partials , b:: Partials , aval, bval)
86
98
return _mul_partials (a, b, inv (bval), - (aval / (bval* bval)))
87
99
end
90
102
# ----------------------#
91
103
92
104
if NANSAFE_MODE_ENABLED
93
- @inline function Base.:* (partials:: Partials , x:: Real )
94
- x = ifelse (! isfinite (x) && iszero (partials), one (x), x)
95
- return Partials (scale_tuple (partials. values, x))
96
- end
97
-
98
- @inline function Base.:/ (partials:: Partials , x:: Real )
99
- x = ifelse (x == zero (x) && iszero (partials), one (x), x)
100
- return Partials (div_tuple_by_scalar (partials. values, x))
105
+ @inline function _mul_partial (partial:: Real , x:: Real )
106
+ y = partial * x
107
+ return ! isfinite (x) && iszero (partial) ? zero (y) : y
101
108
end
102
-
103
- @inline function _mul_partials (a:: Partials{N} , b:: Partials{N} , x_a, x_b) where N
104
- x_a = ifelse (! isfinite (x_a) && iszero (a), one (x_a), x_a)
105
- x_b = ifelse (! isfinite (x_b) && iszero (b), one (x_b), x_b)
106
- return Partials (mul_tuples (a. values, b. values, x_a, x_b))
109
+ @inline function _div_partial (partial:: Real , x:: Real )
110
+ y = partial / x
111
+ return iszero (x) && iszero (partial) ? zero (y) : y
107
112
end
108
113
else
109
- @inline function Base.:* (partials:: Partials , x:: Real )
110
- return Partials (scale_tuple (partials. values, x))
111
- end
112
-
113
- @inline function Base.:/ (partials:: Partials , x:: Real )
114
- return Partials (div_tuple_by_scalar (partials. values, x))
115
- end
116
-
117
- @inline function _mul_partials (a:: Partials{N} , b:: Partials{N} , x_a, x_b) where N
118
- return Partials (mul_tuples (a. values, b. values, x_a, x_b))
119
- end
114
+ @inline _mul_partial (partial:: Real , x:: Real ) = partial * x
115
+ @inline _div_partial (partial:: Real , x:: Real ) = partial / x
120
116
end
121
117
122
118
# edge cases where N == 0 #
@@ -197,11 +193,11 @@ end
197
193
end
198
194
199
195
@generated function scale_tuple (tup:: NTuple{N} , x) where N
200
- return tupexpr (i -> :(tup[$ i] * x ), N)
196
+ return tupexpr (i -> :(_mul_partial ( tup[$ i], x) ), N)
201
197
end
202
198
203
199
@generated function div_tuple_by_scalar (tup:: NTuple{N} , x) where N
204
- return tupexpr (i -> :(tup[$ i] / x ), N)
200
+ return tupexpr (i -> :(_div_partial ( tup[$ i], x) ), N)
205
201
end
206
202
207
203
@generated function add_tuples (a:: NTuple{N} , b:: NTuple{N} ) where N
217
213
end
218
214
219
215
@generated function mul_tuples (a:: NTuple{N} , b:: NTuple{N} , afactor, bfactor) where N
220
- return tupexpr (i -> :((afactor * a[$ i]) + (bfactor * b[$ i])), N)
216
+ return tupexpr (i -> :(_mul_partial ( a[$ i], afactor ) + _mul_partial ( b[$ i], bfactor )), N)
221
217
end
222
218
223
219
# ##################
0 commit comments