@@ -84,17 +84,22 @@ _logsumexp_onepass_reduce(X, ::Base.EltypeUnknown) = reduce(_logsumexp_onepass_o
8484
8585# reduce two numbers
8686function _logsumexp_onepass_op (x1:: T , x2:: T ) where {T<: Number }
87- xmax, a = if x1 == x2
88- # handle `x1 = x2 = ±Inf` correctly
89- x2, zero (x1 - x2)
90- elseif isnan (x1) || isnan (x2)
87+ xmax, a = if isnan (x1) || isnan (x2)
9188 # ensure that `NaN` is propagated correctly for complex numbers
9289 z = oftype (x1, NaN )
9390 z, exp (z)
94- elseif real (x1) > real (x2)
95- x1, x2 - x1
9691 else
97- x2, x1 - x2
92+ real_x1 = real (x1)
93+ real_x2 = real (x2)
94+ if real_x1 > real_x2
95+ x1, x2 - x1
96+ elseif real_x1 < real_x2
97+ x2, x1 - x2
98+ else
99+ # handle `x1 = x2 = ±Inf` correctly
100+ # checking inequalities above instead of equality fixes issue #59
101+ x2, zero (x1 - x2)
102+ end
98103 end
99104 r = exp (a)
100105 return xmax, r
@@ -109,17 +114,22 @@ _logsumexp_onepass_op((xmax, r)::Tuple{<:Number,<:Number}, x::Number) =
109114_logsumexp_onepass_op (x:: Number , xmax:: Number , r:: Number ) =
110115 _logsumexp_onepass_op (promote (x, xmax)... , r)
111116function _logsumexp_onepass_op (x:: T , xmax:: T , r:: Number ) where {T<: Number }
112- _xmax, _r = if x == xmax
113- # handle `x = xmax = ±Inf` correctly
114- xmax, r + exp (zero (x - xmax))
115- elseif isnan (x) || isnan (xmax)
117+ _xmax, _r = if isnan (x) || isnan (xmax)
116118 # ensure that `NaN` is propagated correctly for complex numbers
117119 z = oftype (x, NaN )
118120 z, r + exp (z)
119- elseif real (x) > real (xmax)
120- x, (r + one (r)) * exp (xmax - x)
121121 else
122- xmax, r + exp (x - xmax)
122+ real_x = real (x)
123+ real_xmax = real (xmax)
124+ if real_x > real_xmax
125+ x, (r + one (r)) * exp (xmax - x)
126+ elseif real_x < real_xmax
127+ xmax, r + exp (x - xmax)
128+ else
129+ # handle `x = xmax = ±Inf` correctly
130+ # checking inequalities above instead of equality fixes issue #59
131+ xmax, r + exp (zero (x - xmax))
132+ end
123133 end
124134 return _xmax, _r
125135end
@@ -134,17 +144,22 @@ function _logsumexp_onepass_op(xmax1::Number, xmax2::Number, r1::Number, r2::Num
134144 return _logsumexp_onepass_op (promote (xmax1, xmax2)... , promote (r1, r2)... )
135145end
136146function _logsumexp_onepass_op (xmax1:: T , xmax2:: T , r1:: R , r2:: R ) where {T<: Number ,R<: Number }
137- xmax, r = if xmax1 == xmax2
138- # handle `xmax1 = xmax2 = ±Inf` correctly
139- xmax2, r2 + (r1 + one (r1)) * exp (zero (xmax1 - xmax2))
140- elseif isnan (xmax1) || isnan (xmax2)
147+ xmax, r = if isnan (xmax1) || isnan (xmax2)
141148 # ensure that `NaN` is propagated correctly for complex numbers
142149 z = oftype (xmax1, NaN )
143150 z, r1 + exp (z)
144- elseif real (xmax1) > real (xmax2)
145- xmax1, r1 + (r2 + one (r2)) * exp (xmax2 - xmax1)
146151 else
147- xmax2, r2 + (r1 + one (r1)) * exp (xmax1 - xmax2)
152+ real_xmax1 = real (xmax1)
153+ real_xmax2 = real (xmax2)
154+ if real_xmax1 > real_xmax2
155+ xmax1, r1 + (r2 + one (r2)) * exp (xmax2 - xmax1)
156+ elseif real_xmax1 < real_xmax2
157+ xmax2, r2 + (r1 + one (r1)) * exp (xmax1 - xmax2)
158+ else
159+ # handle `xmax1 = xmax2 = ±Inf` correctly
160+ # checking inequalities above instead of equality fixes issue #59
161+ xmax2, r2 + (r1 + one (r1)) * exp (zero (xmax1 - xmax2))
162+ end
148163 end
149164 return xmax, r
150165end
0 commit comments