| 
 | 1 | +## Base interface  | 
 | 2 | + | 
 | 3 | +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUVector, dims::Nothing, init::Nothing) =  | 
 | 4 | +    AK.accumulate!(op, output, input, get_backend(output); dims, init=AK.neutral_element(op, eltype(output)))  | 
 | 5 | + | 
 | 6 | +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray, dims::Integer, init::Nothing) =  | 
 | 7 | +    AK.accumulate!(op, output, input, get_backend(output); dims, init=AK.neutral_element(op, eltype(output)))  | 
 | 8 | + | 
 | 9 | +Base._accumulate!(op, output::AnyGPUArray, input::MtlVector, dims::Nothing, init::Some) =  | 
 | 10 | +    AK.accumulate!(op, output, input, get_backend(output); dims, init=something(init))  | 
 | 11 | + | 
 | 12 | +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray, dims::Integer, init::Some) =  | 
 | 13 | +    AK.accumulate!(op, output, input, get_backend(output); dims, init=something(init))  | 
 | 14 | + | 
 | 15 | +Base.accumulate_pairwise!(op, result::AnyGPUVector, v::AnyGPUVector) = accumulate!(op, result, v)  | 
 | 16 | + | 
 | 17 | +# default behavior unless dims are specified by the user  | 
 | 18 | +function Base.accumulate(op, A::WrappedGPUArray;  | 
 | 19 | +                         dims::Union{Nothing,Integer}=nothing, kw...)  | 
 | 20 | +    nt = values(kw)  | 
 | 21 | +    if dims === nothing && !(A isa AbstractVector)  | 
 | 22 | +        # This branch takes care of the cases not handled by `_accumulate!`.  | 
 | 23 | +        return reshape(AK.accumulate(op, A[:], get_backend(A); init = (:init in keys(kw) ? nt.init : AK.neutral_element(op, eltype(A)))), size(A))  | 
 | 24 | +    end  | 
 | 25 | +    if isempty(kw)  | 
 | 26 | +        out = similar(A, Base.promote_op(op, eltype(A), eltype(A)))  | 
 | 27 | +        init = AK.neutral_element(op, eltype(out))  | 
 | 28 | +    elseif keys(nt) === (:init,)  | 
 | 29 | +        out = similar(A, Base.promote_op(op, typeof(nt.init), eltype(A)))  | 
 | 30 | +        init = nt.init  | 
 | 31 | +    else  | 
 | 32 | +        throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))"))  | 
 | 33 | +    end  | 
 | 34 | +    AK.accumulate!(op, out, A, get_backend(A); dims, init)  | 
 | 35 | +end  | 
0 commit comments