@@ -347,6 +347,43 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
347347 R
348348end
349349
350+ # # Base interface
351+
352+ Base. _accumulate! (op, output:: AnyJLArray , input:: AnyJLVector , dims:: Nothing , init:: Nothing ) =
353+ accumulate! (op, typed_data (output), typed_data (input); dims= 1 )
354+
355+ Base. _accumulate! (op, output:: AnyJLArray , input:: AnyJLArray , dims:: Integer , init:: Nothing ) =
356+ accumulate! (op, typed_data (output), typed_data (input); dims)
357+
358+ Base. _accumulate! (op, output:: AnyJLArray , input:: AnyJLVector , dims:: Nothing , init:: Some ) =
359+ accumulate! (op, typed_data (output), typed_data (input); dims= 1 , init= something (init))
360+
361+ Base. _accumulate! (op, output:: AnyJLArray , input:: AnyJLArray , dims:: Integer , init:: Some ) =
362+ accumulate! (op, typed_data (output), typed_data (input); dims, init= something (init))
363+
364+ Base. accumulate_pairwise! (op, result:: AnyJLVector , v:: AnyJLVector ) = accumulate! (op, result, v)
365+
366+ # default behavior unless dims are specified by the user
367+ function Base. accumulate (op, A:: AnyJLArray ;
368+ dims:: Union{Nothing,Integer} = nothing , kw... )
369+ nt = values (kw)
370+ if dims === nothing && ! (A isa AbstractVector)
371+ # This branch takes care of the cases not handled by `_accumulate!`.
372+ return reshape (accumulate (op, typed_data (A)[:]; kw... ), size (A))
373+ end
374+ if isempty (kw)
375+ out = similar (A, Base. promote_op (op, eltype (A), eltype (A)))
376+ init = AK. neutral_element (op, eltype (out))
377+ elseif keys (nt) === (:init ,)
378+ out = similar (A, Base. promote_op (op, typeof (nt. init), eltype (A)))
379+ init = nt. init
380+ else
381+ throw (ArgumentError (" accumulate does not support the keyword arguments $(setdiff (keys (nt), (:init ,))) " ))
382+ end
383+ accumulate! (op, typed_data (out), typed_data (A); dims, init)
384+ end
385+
386+
350387# # KernelAbstractions interface
351388
352389KernelAbstractions. get_backend (a:: JLA ) where JLA <: JLArray = JLBackend ()
0 commit comments