From bfbfbefcc9e137fa6c1b0add9f0fb9f182477849 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 8 Oct 2025 15:53:07 +0200 Subject: [PATCH] Sparse GPU array and broadcasting support --- Project.toml | 2 + lib/JLArrays/Project.toml | 4 + lib/JLArrays/src/JLArrays.jl | 131 ++++++- src/GPUArrays.jl | 2 + src/device/sparse.jl | 135 +++++++ src/host/sparse.jl | 658 +++++++++++++++++++++++++++++++++++ test/Project.toml | 1 + test/runtests.jl | 8 +- test/testsuite.jl | 1 + test/testsuite/sparse.jl | 147 ++++++++ 10 files changed, 1087 insertions(+), 2 deletions(-) create mode 100644 src/device/sparse.jl create mode 100644 src/host/sparse.jl create mode 100644 test/testsuite/sparse.jl diff --git a/Project.toml b/Project.toml index c399348c..e996af59 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] @@ -33,5 +34,6 @@ Random = "1" Reexport = "1" ScopedValues = "1" Serialization = "1" +SparseArrays = "1" Statistics = "1" julia = "1.10" diff --git a/lib/JLArrays/Project.toml b/lib/JLArrays/Project.toml index 700a31aa..32674276 100644 --- a/lib/JLArrays/Project.toml +++ b/lib/JLArrays/Project.toml @@ -7,11 +7,15 @@ version = "0.3.0" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] Adapt = "2.0, 3.0, 4.0" GPUArrays = "11.1" KernelAbstractions = "0.9, 0.10" +LinearAlgebra = "1" Random = "1" +SparseArrays = "1" julia = "1.8" diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index 4b238fa0..5ada9767 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -6,11 +6,14 @@ module JLArrays -export JLArray, JLVector, JLMatrix, jl, JLBackend +export JLArray, JLVector, JLMatrix, jl, JLBackend, JLSparseVector, JLSparseMatrixCSC, JLSparseMatrixCSR using GPUArrays using Adapt +using SparseArrays, LinearAlgebra + +import GPUArrays: _dense_array_type import KernelAbstractions import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config @@ -115,7 +118,90 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N} end end +mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, Ti} + iPtr::JLArray{Ti, 1} + nzVal::JLArray{Tv, 1} + len::Int + nnz::Ti + + function JLSparseVector{Tv, Ti}(iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1}, + len::Integer) where {Tv, Ti <: Integer} + new{Tv, Ti}(iPtr, nzVal, len, length(nzVal)) + end +end +SparseArrays.SparseVector(x::JLSparseVector) = SparseVector(length(x), Array(x.iPtr), Array(x.nzVal)) +SparseArrays.nnz(x::JLSparseVector) = x.nnz +SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr +SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal + +mutable struct JLSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC{Tv, Ti} + colPtr::JLArray{Ti, 1} + rowVal::JLArray{Ti, 1} + nzVal::JLArray{Tv, 1} + dims::NTuple{2,Int} + nnz::Ti + + function JLSparseMatrixCSC{Tv, Ti}(colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 
1}, + nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer} + new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal)) + end +end +function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer} + return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims) +end +SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(x.rowVal), Array(x.nzVal)) + +JLSparseMatrixCSC(A::JLSparseMatrixCSC) = A + +function Base.getindex(A::JLSparseMatrixCSC{Tv, Ti}, i::Integer, j::Integer) where {Tv, Ti} + r1 = Int(@inbounds A.colPtr[j]) + r2 = Int(@inbounds A.colPtr[j+1]-1) + (r1 > r2) && return zero(Tv) + r1 = searchsortedfirst(view(A.rowVal, r1:r2), i) + r1 - 1 + ((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1] +end + +mutable struct JLSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR{Tv, Ti} + rowPtr::JLArray{Ti, 1} + colVal::JLArray{Ti, 1} + nzVal::JLArray{Tv, 1} + dims::NTuple{2,Int} + nnz::Ti + + function JLSparseMatrixCSR{Tv, Ti}(rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1}, + nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti<:Integer} + new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal)) + end +end +function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer} + return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims) +end +function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR) + x_transpose = SparseMatrixCSC(size(x, 2), size(x, 1), Array(x.rowPtr), Array(x.colVal), Array(x.nzVal)) + return SparseMatrixCSC(transpose(x_transpose)) +end + +JLSparseMatrixCSR(A::JLSparseMatrixCSR) = A + GPUArrays.storage(a::JLArray) = a.data +GPUArrays._dense_array_type(a::JLArray{T, N}) where {T, N} = JLArray{T, N} +GPUArrays._dense_array_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, N} +GPUArrays._dense_vector_type(a::JLArray{T, N}) where {T, N} = JLArray{T, 1} +GPUArrays._dense_vector_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, 1} + +GPUArrays._sparse_array_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSC +GPUArrays._sparse_array_type(::Type{<:JLSparseMatrixCSC}) = JLSparseMatrixCSC +GPUArrays._sparse_array_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSR +GPUArrays._sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR +GPUArrays._sparse_array_type(sa::JLSparseVector) = JLSparseVector +GPUArrays._sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector + +GPUArrays._dense_array_type(sa::JLSparseVector) = JLArray +GPUArrays._dense_array_type(::Type{<:JLSparseVector}) = JLArray +GPUArrays._dense_array_type(sa::JLSparseMatrixCSC) = JLArray +GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray +GPUArrays._dense_array_type(sa::JLSparseMatrixCSR) = JLArray +GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray # conversion of untyped data to a typed Array function typed_data(x::JLArray{T}) where {T} @@ -217,6 +303,41 @@ JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs) (::Type{JLArray{T,N} where T})(x::AbstractArray{S,N}) where {S,N} = JLArray{S,N}(x) JLArray(A::AbstractArray{T,N}) where {T,N} = JLArray{T,N}(A) +function JLSparseVector(xs::SparseVector{Tv, Ti}) where {Ti, Tv} + iPtr = JLVector{Ti}(undef, length(xs.nzind)) + nzVal = JLVector{Tv}(undef, length(xs.nzval)) + copyto!(iPtr, convert(Vector{Ti}, xs.nzind)) + copyto!(nzVal, 
convert(Vector{Tv}, xs.nzval)) + return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs),) +end +Base.length(x::JLSparseVector) = x.len +Base.size(x::JLSparseVector) = (x.len,) + +function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv} + colPtr = JLVector{Ti}(undef, length(xs.colptr)) + rowVal = JLVector{Ti}(undef, length(xs.rowval)) + nzVal = JLVector{Tv}(undef, length(xs.nzval)) + copyto!(colPtr, convert(Vector{Ti}, xs.colptr)) + copyto!(rowVal, convert(Vector{Ti}, xs.rowval)) + copyto!(nzVal, convert(Vector{Tv}, xs.nzval)) + return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, (xs.m, xs.n)) +end +Base.length(x::JLSparseMatrixCSC) = prod(x.dims) +Base.size(x::JLSparseMatrixCSC) = x.dims + +function JLSparseMatrixCSR(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv} + csr_xs = SparseMatrixCSC(transpose(xs)) + rowPtr = JLVector{Ti}(undef, length(csr_xs.colptr)) + colVal = JLVector{Ti}(undef, length(csr_xs.rowval)) + nzVal = JLVector{Tv}(undef, length(csr_xs.nzval)) + copyto!(rowPtr, convert(Vector{Ti}, csr_xs.colptr)) + copyto!(colVal, convert(Vector{Ti}, csr_xs.rowval)) + copyto!(nzVal, convert(Vector{Tv}, csr_xs.nzval)) + return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, (xs.m, xs.n)) +end +Base.length(x::JLSparseMatrixCSR) = prod(x.dims) +Base.size(x::JLSparseMatrixCSR) = x.dims + # idempotency JLArray{T,N}(xs::JLArray{T,N}) where {T,N} = xs @@ -358,9 +479,17 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br R end +Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = +GPUSparseDeviceMatrixCSC{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz) +Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv,Ti}) where {Tv,Ti} = +GPUSparseDeviceMatrixCSR{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz) +Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv,Ti}) where {Tv,Ti} = +GPUSparseDeviceVector{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz) + ## KernelAbstractions interface KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend() +KernelAbstractions.get_backend(a::JLA) where JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector} = JLBackend() function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace) diff --git a/src/GPUArrays.jl b/src/GPUArrays.jl index 8c1fc14e..a35c1ff0 100644 --- a/src/GPUArrays.jl +++ b/src/GPUArrays.jl @@ -19,6 +19,7 @@ using KernelAbstractions # device functionality include("device/abstractarray.jl") +include("device/sparse.jl") # host abstractions include("host/abstractarray.jl") @@ -34,6 +35,7 @@ include("host/random.jl") include("host/quirks.jl") include("host/uniformscaling.jl") include("host/statistics.jl") +include("host/sparse.jl") include("host/alloc_cache.jl") diff --git a/src/device/sparse.jl b/src/device/sparse.jl new file mode 100644 index 00000000..77abe0df --- /dev/null +++ b/src/device/sparse.jl @@ -0,0 +1,135 @@ +# on-device sparse array types +# should be excluded from coverage counts +# COV_EXCL_START +using SparseArrays + +# NOTE: this functionality is currently very bare-bones, only defining the array types +# without any device-compatible sparse array functionality + + 
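For orientation, a short usage sketch (illustrative only, not part of the patch, and assuming the patch is applied): the JLArrays constructors above upload from SparseArrays types, the SparseVector/SparseMatrixCSC methods round-trip back to the CPU, and the Adapt rules turn the host wrappers into the device-side structs defined below when passed to a kernel.

    using SparseArrays, JLArrays

    x  = sprand(Float32, 100, 0.1)      # host SparseVector
    dx = JLSparseVector(x)              # upload: indices and values become JLArrays
    @assert SparseVector(dx) == x       # download round-trip

    A  = sprand(Float32, 4, 6, 0.5)
    dA = JLSparseMatrixCSR(A)           # CSR upload transposes to build row pointers
    @assert SparseMatrixCSC(dA) == A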
+# core types + +export GPUSparseDeviceVector, GPUSparseDeviceMatrixCSC, GPUSparseDeviceMatrixCSR, + GPUSparseDeviceMatrixBSR, GPUSparseDeviceMatrixCOO + +abstract type AbstractGPUSparseDeviceMatrix{Tv, Ti} <: AbstractSparseMatrix{Tv, Ti} end + + +struct GPUSparseDeviceVector{Tv,Ti,Vi,Vv} <: AbstractSparseVector{Tv,Ti} + iPtr::Vi + nzVal::Vv + len::Int + nnz::Ti +end + +Base.length(g::GPUSparseDeviceVector) = g.len +Base.size(g::GPUSparseDeviceVector) = (g.len,) +SparseArrays.nnz(g::GPUSparseDeviceVector) = g.nnz +SparseArrays.nonzeroinds(g::GPUSparseDeviceVector) = g.iPtr +SparseArrays.nonzeros(g::GPUSparseDeviceVector) = g.nzVal + +struct GPUSparseDeviceMatrixCSC{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti} + colPtr::Vi + rowVal::Vi + nzVal::Vv + dims::NTuple{2,Int} + nnz::Ti +end + +SparseArrays.rowvals(g::GPUSparseDeviceMatrixCSC) = g.rowVal +SparseArrays.getcolptr(g::GPUSparseDeviceMatrixCSC) = g.colPtr +SparseArrays.nzrange(g::GPUSparseDeviceMatrixCSC, col::Integer) = SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col+1]-1) + +struct GPUSparseDeviceMatrixCSR{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti} + rowPtr::Vi + colVal::Vi + nzVal::Vv + dims::NTuple{2, Int} + nnz::Ti +end + +struct GPUSparseDeviceMatrixBSR{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti} + rowPtr::Vi + colVal::Vi + nzVal::Vv + dims::NTuple{2,Int} + blockDim::Ti + dir::Char + nnz::Ti +end + +struct GPUSparseDeviceMatrixCOO{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti} + rowInd::Vi + colInd::Vi + nzVal::Vv + dims::NTuple{2,Int} + nnz::Ti +end + +Base.length(g::AbstractGPUSparseDeviceMatrix) = prod(g.dims) +Base.size(g::AbstractGPUSparseDeviceMatrix) = g.dims +SparseArrays.nnz(g::AbstractGPUSparseDeviceMatrix) = g.nnz +SparseArrays.getnzval(g::AbstractGPUSparseDeviceMatrix) = g.nzVal + +struct GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M} <: AbstractSparseArray{Tv, Ti, N} + rowPtr::Vi + colVal::Vi + nzVal::Vv + dims::NTuple{N, Int} + nnz::Ti +end + +function GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N}(rowPtr::Vi, colVal::Vi, nzVal::Vv, dims::NTuple{N,<:Integer}) where {Tv, Ti<:Integer, M, Vi<:AbstractDeviceArray{<:Integer,M}, Vv<:AbstractDeviceArray{Tv, M}, N} + @assert M == N - 1 "GPUSparseDeviceArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1" + GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M}(rowPtr, colVal, nzVal, dims, length(nzVal)) +end + +Base.length(g::GPUSparseDeviceArrayCSR) = prod(g.dims) +Base.size(g::GPUSparseDeviceArrayCSR) = g.dims +SparseArrays.nnz(g::GPUSparseDeviceArrayCSR) = g.nnz +SparseArrays.getnzval(g::GPUSparseDeviceArrayCSR) = g.nzVal + +# input/output + +function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceVector) + println(io, "$(length(A))-element device sparse vector at:") + println(io, " iPtr: $(A.iPtr)") + print(io, " nzVal: $(A.nzVal)") +end + +function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSR) + println(io, "$(length(A))-element device sparse matrix CSR at:") + println(io, " rowPtr: $(A.rowPtr)") + println(io, " colVal: $(A.colVal)") + print(io, " nzVal: $(A.nzVal)") +end + +function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSC) + println(io, "$(length(A))-element device sparse matrix CSC at:") + println(io, " colPtr: $(A.colPtr)") + println(io, " rowVal: $(A.rowVal)") + print(io, " nzVal: $(A.nzVal)") +end + +function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixBSR) + println(io, "$(length(A))-element device sparse matrix BSR at:") + 
println(io, " rowPtr: $(A.rowPtr)") + println(io, " colVal: $(A.colVal)") + print(io, " nzVal: $(A.nzVal)") +end + +function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCOO) + println(io, "$(length(A))-element device sparse matrix COO at:") + println(io, " rowPtr: $(A.rowPtr)") + println(io, " colInd: $(A.colInd)") + print(io, " nzVal: $(A.nzVal)") +end + +function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceArrayCSR) + println(io, "$(length(A))-element device sparse array CSR at:") + println(io, " rowPtr: $(A.rowPtr)") + println(io, " colVal: $(A.colVal)") + print(io, " nzVal: $(A.nzVal)") +end + +# COV_EXCL_STOP diff --git a/src/host/sparse.jl b/src/host/sparse.jl new file mode 100644 index 00000000..ff820304 --- /dev/null +++ b/src/host/sparse.jl @@ -0,0 +1,658 @@ +abstract type AbstractGPUSparseArray{Tv, Ti, N} <: AbstractSparseArray{Tv, Ti, N} end +const AbstractGPUSparseVector{Tv, Ti} = AbstractGPUSparseArray{Tv, Ti, 1} +const AbstractGPUSparseMatrix{Tv, Ti} = AbstractGPUSparseArray{Tv, Ti, 2} + +abstract type AbstractGPUSparseMatrixCSC{Tv, Ti} <: AbstractGPUSparseArray{Tv, Ti, 2} end +abstract type AbstractGPUSparseMatrixCSR{Tv, Ti} <: AbstractGPUSparseArray{Tv, Ti, 2} end +abstract type AbstractGPUSparseMatrixCOO{Tv, Ti} <: AbstractGPUSparseArray{Tv, Ti, 2} end +abstract type AbstractGPUSparseMatrixBSR{Tv, Ti} <: AbstractGPUSparseArray{Tv, Ti, 2} end + +const AbstractGPUSparseVecOrMat = Union{AbstractGPUSparseVector,AbstractGPUSparseMatrix} + +Base.convert(T::Type{<:AbstractGPUSparseArray}, m::AbstractArray) = m isa T ? m : T(m) + +_dense_array_type(sa::SparseVector) = SparseVector +_dense_array_type(::Type{SparseVector}) = SparseVector +_sparse_array_type(sa::SparseVector) = SparseVector +_dense_vector_type(sa::AbstractSparseArray) = Vector +_dense_vector_type(sa::AbstractArray) = Vector +_dense_vector_type(::Type{<:AbstractSparseArray}) = Vector +_dense_vector_type(::Type{<:AbstractArray}) = Vector +_dense_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC +_dense_array_type(::Type{SparseMatrixCSC}) = SparseMatrixCSC +_sparse_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC + +function _sparse_array_type(sa::AbstractGPUSparseArray) end +function _dense_array_type(sa::AbstractGPUSparseArray) end + +### BROADCAST + +# broadcast container type promotion for combinations of sparse arrays and other types +struct GPUSparseVecStyle <: Broadcast.AbstractArrayStyle{1} end +struct GPUSparseMatStyle <: Broadcast.AbstractArrayStyle{2} end +Broadcast.BroadcastStyle(::Type{<:AbstractGPUSparseVector}) = GPUSparseVecStyle() +Broadcast.BroadcastStyle(::Type{<:AbstractGPUSparseMatrix}) = GPUSparseMatStyle() +const SPVM = Union{GPUSparseVecStyle,GPUSparseMatStyle} + +# GPUSparseVecStyle handles 0-1 dimensions, GPUSparseMatStyle 0-2 dimensions. +# GPUSparseVecStyle promotes to GPUSparseMatStyle for 2 dimensions. +# Fall back to DefaultArrayStyle for higher dimensionality. 
+GPUSparseVecStyle(::Val{0}) = GPUSparseVecStyle() +GPUSparseVecStyle(::Val{1}) = GPUSparseVecStyle() +GPUSparseVecStyle(::Val{2}) = GPUSparseMatStyle() +GPUSparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() +GPUSparseMatStyle(::Val{0}) = GPUSparseMatStyle() +GPUSparseMatStyle(::Val{1}) = GPUSparseMatStyle() +GPUSparseMatStyle(::Val{2}) = GPUSparseMatStyle() +GPUSparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() + +Broadcast.BroadcastStyle(::GPUSparseVecStyle, ::AbstractGPUArrayStyle{1}) = GPUSparseVecStyle() +Broadcast.BroadcastStyle(::GPUSparseVecStyle, ::AbstractGPUArrayStyle{2}) = GPUSparseMatStyle() +Broadcast.BroadcastStyle(::GPUSparseMatStyle, ::AbstractGPUArrayStyle{2}) = GPUSparseMatStyle() + +# don't wrap sparse arrays with Extruded +Broadcast.extrude(x::AbstractGPUSparseVecOrMat) = x + +## detection of zero-preserving functions + +# modified from SparseArrays.jl + +# capturescalars takes a function (f) and a tuple of broadcast arguments, and returns a +# partially-evaluated function and a reduced argument tuple where all scalar operations have +# been applied already. +@inline function capturescalars(f, mixedargs) + let (passedsrcargstup, makeargs) = _capturescalars(mixedargs...) + parevalf = (passed...) -> f(makeargs(passed...)...) + return (parevalf, passedsrcargstup) + end +end + +## sparse broadcast style + +# Work around losing Type{T}s as DataTypes within the tuple that makeargs creates +@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Vararg{Any}}) where {T} = + capturescalars((args...)->f(T, args...), Base.tail(mixedargs)) +@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Ref{Type{S}}, Vararg{Any}}) where {T, S} = + # This definition is identical to the one above and necessary only for + # avoiding method ambiguity. + capturescalars((args...)->f(T, args...), Base.tail(mixedargs)) +@inline capturescalars(f, mixedargs::Tuple{AbstractGPUSparseVecOrMat, Ref{Type{T}}, Vararg{Any}}) where {T} = + capturescalars((a1, args...)->f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...)) +@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{<:Any,0}}, Ref{Type{T}}, Vararg{Any}}) where {T} = + capturescalars((args...)->f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs))) + +scalararg(::Number) = true +scalararg(::Any) = false +scalarwrappedarg(::Union{AbstractArray{<:Any,0},Ref}) = true +scalarwrappedarg(::Any) = false + +@inline function _capturescalars() + return (), () -> () +end +@inline function _capturescalars(arg, mixedargs...) + let (rest, f) = _capturescalars(mixedargs...) + if scalararg(arg) + return rest, @inline function(tail...) + (arg, f(tail...)...) + end # add back scalararg after (in makeargs) + elseif scalarwrappedarg(arg) + return rest, @inline function(tail...) + (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple + end # unwrap and add back scalararg after (in makeargs) + else + return (arg, rest...), @inline function(head, tail...) + (head, f(tail...)...) 
+            end # pass-through to broadcast
+        end
+    end
+end
+@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
+    if scalararg(arg)
+        return (), () -> (arg,) # add scalararg
+    elseif scalarwrappedarg(arg)
+        return (), () -> (arg[],) # unwrap
+    else
+        return (arg,), (head,) -> (head,) # pass-through
+    end
+end
+
+@inline _iszero(x) = x == 0
+@inline _iszero(x::Number) = Base.iszero(x)
+@inline _iszero(x::AbstractArray) = Base.iszero(x)
+@inline _zeros_eltypes(A) = (zero(eltype(A)),)
+@inline _zeros_eltypes(A, Bs...) = (zero(eltype(A)), _zeros_eltypes(Bs...)...)
+
+## COV_EXCL_START
+## iteration helpers
+
+"""
+    CSRIterator{Ti}(row, args...)
+
+A GPU-compatible iterator for accessing the elements of a single row `row` of several CSR
+matrices `args` in one go. The row should be in-bounds for every sparse argument. Each
+iteration returns a 2-element tuple: the current column, and each argument's pointer index
+(or 0 if that input didn't have an element at that column). The pointers can then be used
+to access the elements themselves.
+
+For convenience, this iterator can be passed non-sparse arguments as well, which will be
+ignored (with the returned `col`/`ptr` values set to 0).
+"""
+struct CSRIterator{Ti,N,ATs}
+    row::Ti
+    col_ends::NTuple{N, Ti}
+    args::ATs
+end
+
+function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti,N}
+    # check that `row` is valid for all arguments
+    @boundscheck begin
+        ntuple(Val(N)) do i
+            arg = @inbounds args[i]
+            arg isa GPUSparseDeviceMatrixCSR && checkbounds(arg, row, 1)
+        end
+    end
+
+    col_ends = ntuple(Val(N)) do i
+        arg = @inbounds args[i]
+        if arg isa GPUSparseDeviceMatrixCSR
+            @inbounds(arg.rowPtr[row+1])
+        else
+            zero(Ti)
+        end
+    end
+
+    CSRIterator{Ti, N, typeof(args)}(row, col_ends, args)
+end
+
+@inline function Base.iterate(iter::CSRIterator{Ti,N}, state=nothing) where {Ti,N}
+    # helper function to get the column of a sparse array at a specific pointer
+    @inline function get_col(i, ptr)
+        arg = @inbounds iter.args[i]
+        if arg isa GPUSparseDeviceMatrixCSR
+            col_end = @inbounds iter.col_ends[i]
+            if ptr < col_end
+                return @inbounds arg.colVal[ptr] % Ti
+            end
+        end
+        typemax(Ti)
+    end
+
+    # initialize the state
+    # - ptr: the current index into the colVal/nzVal arrays
+    # - col: the current column index (cached so that we don't have to re-read each time)
+    state = something(state,
+        ntuple(Val(N)) do i
+            arg = @inbounds iter.args[i]
+            if arg isa GPUSparseDeviceMatrixCSR
+                ptr = @inbounds iter.args[i].rowPtr[iter.row] % Ti
+                col = @inbounds get_col(i, ptr)
+            else
+                ptr = typemax(Ti)
+                col = typemax(Ti)
+            end
+            (; ptr, col)
+        end
+    )
+
+    # determine the column we're currently processing
+    cols = ntuple(i -> @inbounds(state[i].col), Val(N))
+    cur_col = min(cols...)
+    cur_col == typemax(Ti) && return
+
+    # fetch the pointers (we don't look up the values, as the caller might want to index
+    # the sparse array directly, e.g., to mutate it). we don't return `ptrs` from the state
+    # directly, but first convert the `typemax(Ti)` to a more convenient zero value.
+    # NOTE: these values may end up unused by the caller (e.g. in the count_nnzs kernels),
+    # but LLVM appears smart enough to filter them away.
+    ptrs = ntuple(Val(N)) do i
+        ptr, col = @inbounds state[i]
+        col == cur_col ? ptr : zero(Ti)
+    end
+
+    # advance the state
+    new_state = ntuple(Val(N)) do i
+        ptr, col = @inbounds state[i]
+        if col == cur_col
+            ptr += one(Ti)
+            col = get_col(i, ptr)
+        end
+        (; ptr, col)
+    end
+
+    return (cur_col, ptrs), new_state
+end
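The same merge-style traversal drives both the CSR iterator above and the CSC iterator below. As a plain-Julia sketch of the idea (illustrative only, not part of the patch), here is one row of two CSR inputs with column indices [1, 4, 7] and [2, 4, 9]:

    # merge two sorted column-index lists, yielding (column, pointers) pairs
    colsA = [1, 4, 7]; colsB = [2, 4, 9]
    ptrA = 1; ptrB = 1
    while ptrA <= length(colsA) || ptrB <= length(colsB)
        cA = ptrA <= length(colsA) ? colsA[ptrA] : typemax(Int)
        cB = ptrB <= length(colsB) ? colsB[ptrB] : typemax(Int)
        col = min(cA, cB)                     # column currently being processed
        ptrs = (cA == col ? ptrA : 0,         # 0 signals "no entry at this column"
                cB == col ? ptrB : 0)
        println((col, ptrs))                  # what iterate() yields: (cur_col, ptrs)
        ptrA += cA == col; ptrB += cB == col  # advance whichever inputs matched
    end

which prints (1, (1, 0)), (2, (0, 1)), (4, (2, 2)), (7, (3, 0)), (9, (0, 3)).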
+struct CSCIterator{Ti,N,ATs}
+    col::Ti
+    row_ends::NTuple{N, Ti}
+    args::ATs
+end
+
+function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti,N}
+    # check that `col` is valid for all arguments
+    @boundscheck begin
+        ntuple(Val(N)) do i
+            arg = @inbounds args[i]
+            arg isa GPUSparseDeviceMatrixCSC && checkbounds(arg, 1, col)
+        end
+    end
+
+    row_ends = ntuple(Val(N)) do i
+        arg = @inbounds args[i]
+        if arg isa GPUSparseDeviceMatrixCSC
+            @inbounds(arg.colPtr[col+1])
+        else
+            zero(Ti)
+        end
+    end
+
+    CSCIterator{Ti, N, typeof(args)}(col, row_ends, args)
+end
+
+@inline function Base.iterate(iter::CSCIterator{Ti,N}, state=nothing) where {Ti,N}
+    # helper function to get the row of a sparse array at a specific pointer
+    @inline function get_col(i, ptr)
+        arg = @inbounds iter.args[i]
+        if arg isa GPUSparseDeviceMatrixCSC
+            col_end = @inbounds iter.row_ends[i]
+            if ptr < col_end
+                return @inbounds arg.rowVal[ptr] % Ti
+            end
+        end
+        typemax(Ti)
+    end
+
+    # initialize the state
+    # - ptr: the current index into the rowVal/nzVal arrays
+    # - row: the current row index (cached so that we don't have to re-read each time)
+    state = something(state,
+        ntuple(Val(N)) do i
+            arg = @inbounds iter.args[i]
+            if arg isa GPUSparseDeviceMatrixCSC
+                ptr = @inbounds iter.args[i].colPtr[iter.col] % Ti
+                row = @inbounds get_col(i, ptr)
+            else
+                ptr = typemax(Ti)
+                row = typemax(Ti)
+            end
+            (; ptr, row)
+        end
+    )
+
+    # determine the row we're currently processing
+    rows = ntuple(i -> @inbounds(state[i].row), Val(N))
+    cur_row = min(rows...)
+    cur_row == typemax(Ti) && return
+
+    # fetch the pointers (we don't look up the values, as the caller might want to index
+    # the sparse array directly, e.g., to mutate it). we don't return `ptrs` from the state
+    # directly, but first convert the `typemax(Ti)` to a more convenient zero value.
+    # NOTE: these values may end up unused by the caller (e.g. in the count_nnzs kernels),
+    # but LLVM appears smart enough to filter them away.
+    ptrs = ntuple(Val(N)) do i
+        ptr, row = @inbounds state[i]
+        row == cur_row ? ptr : zero(Ti)
+    end
+
+    # advance the state
+    new_state = ntuple(Val(N)) do i
+        ptr, row = @inbounds state[i]
+        if row == cur_row
+            ptr += one(Ti)
+            row = get_col(i, ptr)
+        end
+        (; ptr, row)
+    end
+
+    return (cur_row, ptrs), new_state
+end
+
+# helpers to index a sparse or dense array
+function _getindex(arg::Union{<:GPUSparseDeviceMatrixCSR,GPUSparseDeviceMatrixCSC}, I, ptr)
+    if ptr == 0
+        zero(eltype(arg))
+    else
+        @inbounds arg.nzVal[ptr]
+    end
+end
+@inline function _getindex(arg::AbstractDeviceArray{Tv}, I, ptr)::Tv where {Tv}
+    return @inbounds arg[I]::Tv
+end
+@inline _getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)
+
+## sparse broadcast implementation
+iter_type(::Type{<:AbstractGPUSparseMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
+iter_type(::Type{<:AbstractGPUSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
+iter_type(::Type{<:GPUSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
+iter_type(::Type{<:GPUSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
+
+_has_row(A, offsets, row, fpreszeros::Bool) = fpreszeros ?
0 : row +_has_row(A::AbstractDeviceArray, offsets, row, ::Bool) = row +function _has_row(A::GPUSparseDeviceVector, offsets, row, ::Bool) + for row_ix in 1:length(A.iPtr) + arg_row = @inbounds A.iPtr[row_ix] + arg_row == row && return row_ix + arg_row > row && break + end + return 0 +end + +@kernel function compute_offsets_kernel(::Type{<:AbstractGPUSparseVector}, first_row::Ti, last_row::Ti, + fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, + args...) where {Ti, N} + my_ix = @index(Global, Linear) + row = my_ix + first_row - one(eltype(my_ix)) + if row ≤ last_row + # TODO load arg.iPtr slices into shared memory + arg_row_is_nnz = ntuple(Val(N)) do i + arg = @inbounds args[i] + _has_row(arg, offsets, row, fpreszeros) + end + row_is_nnz = 0 + for i in 1:N + row_is_nnz |= @inbounds arg_row_is_nnz[i] + end + key = (row_is_nnz == 0) ? typemax(Ti) : row + @inbounds offsets[my_ix] = key => arg_row_is_nnz + end +end + +# kernel to count the number of non-zeros in a row, to determine the row offsets +@kernel function compute_offsets_kernel(T::Type{<:Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC}}, + offsets::AbstractVector{Ti}, args...) where Ti + # every thread processes an entire row + leading_dim = @index(Global, Linear) + if leading_dim ≤ length(offsets)-1 + iter = @inbounds iter_type(T, Ti)(leading_dim, args...) + + # count the nonzero leading_dims of all inputs + accum = zero(Ti) + for (leading_dim, vals) in iter + accum += one(Ti) + end + + # the way we write the nnz counts is a bit strange, but done so that the result + # after accumulation can be directly used as the rowPtr/colPtr array of a CSR/CSC matrix. + @inbounds begin + if leading_dim == 1 + offsets[1] = 1 + end + offsets[leading_dim+1] = accum + end + end +end + +@kernel function sparse_to_sparse_broadcast_kernel(f::F, output::GPUSparseDeviceVector{Tv,Ti}, + offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, + args...) where {Tv, Ti, N, F} + row_ix = @index(Global, Linear) + if row_ix ≤ output.nnz + row_and_ptrs = @inbounds offsets[row_ix] + row = @inbounds row_and_ptrs[1] + arg_ptrs = @inbounds row_and_ptrs[2] + vals = ntuple(Val(N)) do i + @inline + arg = @inbounds args[i] + # ptr is 0 if the sparse vector doesn't have an element at this row + # ptr is 0 if the arg is a scalar AND f preserves zeros + ptr = @inbounds arg_ptrs[i] + _getindex(arg, row, ptr) + end + output_val = f(vals...) + @inbounds output.iPtr[row_ix] = row + @inbounds output.nzVal[row_ix] = output_val + end +end + +@kernel function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{<:AbstractArray,Nothing}, + args...) where {Ti, T<:Union{GPUSparseDeviceMatrixCSR{<:Any,Ti}, + GPUSparseDeviceMatrixCSC{<:Any,Ti}}} + # every thread processes an entire row + leading_dim = @index(Global, Linear) + leading_dim_size = output isa GPUSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2) + if leading_dim ≤ leading_dim_size + iter = @inbounds iter_type(T, Ti)(leading_dim, args...) + + + output_ptrs = output isa GPUSparseDeviceMatrixCSR ? output.rowPtr : output.colPtr + output_ivals = output isa GPUSparseDeviceMatrixCSR ? 
output.colVal : output.rowVal + # fetch the row offset, and write it to the output + @inbounds begin + output_ptr = output_ptrs[leading_dim] = offsets[leading_dim] + if leading_dim == leading_dim_size + output_ptrs[leading_dim+one(eltype(leading_dim))] = offsets[leading_dim+one(eltype(leading_dim))] + end + end + + # set the values for this row + for (sub_leading_dim, ptrs) in iter + index_first = output isa GPUSparseDeviceMatrixCSR ? leading_dim : sub_leading_dim + index_second = output isa GPUSparseDeviceMatrixCSR ? sub_leading_dim : leading_dim + I = CartesianIndex(index_first, index_second) + vals = ntuple(Val(length(args))) do i + arg = @inbounds args[i] + ptr = @inbounds ptrs[i] + _getindex(arg, I, ptr) + end + + @inbounds output_ivals[output_ptr] = sub_leading_dim + @inbounds output.nzVal[output_ptr] = f(vals...) + output_ptr += one(Ti) + end + end +end +@kernel function sparse_to_dense_broadcast_kernel(T::Type{<:Union{AbstractGPUSparseMatrixCSR{Tv, Ti}, + AbstractGPUSparseMatrixCSC{Tv, Ti}}}, + f, output::AbstractDeviceArray, args...) where {Tv, Ti} + # every thread processes an entire row + leading_dim = @index(Global, Linear) + leading_dim_size = T <: AbstractGPUSparseMatrixCSR ? size(output, 1) : size(output, 2) + if leading_dim ≤ leading_dim_size + iter = @inbounds iter_type(T, Ti)(leading_dim, args...) + + # set the values for this row + for (sub_leading_dim, ptrs) in iter + index_first = T <: AbstractGPUSparseMatrixCSR ? leading_dim : sub_leading_dim + index_second = T <: AbstractGPUSparseMatrixCSR ? sub_leading_dim : leading_dim + I = CartesianIndex(index_first, index_second) + vals = ntuple(Val(length(args))) do i + arg = @inbounds args[i] + ptr = @inbounds ptrs[i] + _getindex(arg, I, ptr) + end + + @inbounds output[I] = f(vals...) + end + end +end + +@kernel function sparse_to_dense_broadcast_kernel(::Type{<:AbstractGPUSparseVector}, f::F, + output::AbstractDeviceArray{Tv}, + offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, + args...) where {Tv, F, N, Ti} + # every thread processes an entire row + row_ix = @index(Global, Linear) + if row_ix ≤ length(output) + row_and_ptrs = @inbounds offsets[row_ix] + row = @inbounds row_and_ptrs[1] + arg_ptrs = @inbounds row_and_ptrs[2] + vals = ntuple(Val(length(args))) do i + @inline + arg = @inbounds args[i] + # ptr is 0 if the sparse vector doesn't have an element at this row + # ptr is row if the arg is dense OR a scalar with non-zero-preserving f + # ptr is 0 if the arg is a scalar AND f preserves zeros + ptr = @inbounds arg_ptrs[i] + _getindex(arg, row, ptr) + end + @inbounds output[row] = f(vals...) 
+ end +end +## COV_EXCL_STOP + +function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatStyle}}) + # find the sparse inputs + bc = Broadcast.flatten(bc) + sparse_args = findall(bc.args) do arg + arg isa AbstractGPUSparseArray + end + sparse_types = unique(map(i->nameof(typeof(bc.args[i])), sparse_args)) + if length(sparse_types) > 1 + error("broadcast with multiple types of sparse arrays ($(join(sparse_types, ", "))) is not supported") + end + sparse_typ = typeof(bc.args[first(sparse_args)]) + sparse_typ <: Union{AbstractGPUSparseMatrixCSR,AbstractGPUSparseMatrixCSC,AbstractGPUSparseVector} || + error("broadcast with sparse arrays is currently only implemented for vectors and CSR and CSC matrices") + Ti = if sparse_typ <: AbstractGPUSparseMatrixCSR + reduce(promote_type, map(i->eltype(bc.args[i].rowPtr), sparse_args)) + elseif sparse_typ <: AbstractGPUSparseMatrixCSC + reduce(promote_type, map(i->eltype(bc.args[i].colPtr), sparse_args)) + elseif sparse_typ <: AbstractGPUSparseVector + reduce(promote_type, map(i->eltype(bc.args[i].iPtr), sparse_args)) + end + + # determine the output type + Tv = Broadcast.combine_eltypes(bc.f, eltype.(bc.args)) + if !Base.isconcretetype(Tv) + error("""GPU sparse broadcast resulted in non-concrete element type $Tv. + This probably means that the function you are broadcasting contains an error or type instability.""") + end + + # partially-evaluate the function, removing scalars. + parevalf, passedsrcargstup = capturescalars(bc.f, bc.args) + # check if the partially-evaluated function preserves zeros. if so, we'll only need to + # apply it to the sparse input arguments, preserving the sparse structure. + if all(arg->isa(arg, AbstractSparseArray), passedsrcargstup) + fofzeros = parevalf(_zeros_eltypes(passedsrcargstup...)...) + fpreszeros = _iszero(fofzeros) + else + fpreszeros = false + end + + # the kernels below parallelize across rows or cols, not elements, so it's unlikely + # we'll launch many threads. to maximize utilization, parallelize across blocks first. + rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1) + # `size(bc, ::Int)` is missing + # for AbstractGPUSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20 + # but the only rows present in any sparse vector input are between 2 and 128, no need to + # launch massive threads. + # TODO: use the difference here to set the thread count + overall_first_row = one(Ti) + overall_last_row = Ti(rows) + offsets = nothing + # allocate the output container + sparse_arg = bc.args[first(sparse_args)] + if !fpreszeros && sparse_typ <: Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC} + # either we have dense inputs, or the function isn't preserving zeros, + # so use a dense output to broadcast into. + val_array = sparse_arg.nzVal + output = similar(val_array, Tv, size(bc)) + # since we'll be iterating the sparse inputs, we need to pre-fill the dense output + # with appropriate values (while setting the sparse inputs to zero). we do this by + # re-using the dense broadcast implementation. + nonsparse_args = map(bc.args) do arg + # NOTE: this assumes the broadcast is flattened, but not yet preprocessed + if arg isa AbstractGPUSparseArray + zero(eltype(arg)) + else + arg + end + end + broadcast!(bc.f, output, nonsparse_args...) + elseif length(sparse_args) == 1 && sparse_typ <: Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC} + # we only have a single sparse input, so we can reuse its structure for the output. 
+        # this avoids a kernel launch and costly synchronization.
+        if sparse_typ <: AbstractGPUSparseMatrixCSR
+            offsets = rowPtr = sparse_arg.rowPtr
+            colVal = similar(sparse_arg.colVal)
+            nzVal = similar(sparse_arg.nzVal, Tv)
+            output = _sparse_array_type(sparse_typ)(rowPtr, colVal, nzVal, size(bc))
+        elseif sparse_typ <: AbstractGPUSparseMatrixCSC
+            offsets = colPtr = sparse_arg.colPtr
+            rowVal = similar(sparse_arg.rowVal)
+            nzVal = similar(sparse_arg.nzVal, Tv)
+            output = _sparse_array_type(sparse_typ)(colPtr, rowVal, nzVal, size(bc))
+        end
+    else
+        # determine the number of non-zero elements per row so that we can create an
+        # appropriately-structured output container
+        offsets = if sparse_typ <: AbstractGPUSparseMatrixCSR
+            ptr_array = sparse_arg.rowPtr
+            similar(ptr_array, Ti, rows+1)
+        elseif sparse_typ <: AbstractGPUSparseMatrixCSC
+            ptr_array = sparse_arg.colPtr
+            similar(ptr_array, Ti, cols+1)
+        elseif sparse_typ <: AbstractGPUSparseVector
+            ptr_array = sparse_arg.iPtr
+            @allowscalar begin
+                arg_first_rows = ntuple(Val(length(bc.args))) do i
+                    bc.args[i] isa AbstractGPUSparseVector && return bc.args[i].iPtr[1]
+                    return one(Ti)
+                end
+                arg_last_rows = ntuple(Val(length(bc.args))) do i
+                    bc.args[i] isa AbstractGPUSparseVector && return bc.args[i].iPtr[end]
+                    return Ti(rows)
+                end
+            end
+            overall_first_row = min(arg_first_rows...)
+            overall_last_row = max(arg_last_rows...)
+            similar(ptr_array, Pair{Ti, NTuple{length(bc.args), Ti}}, overall_last_row - overall_first_row + 1)
+        end
+        let
+            args = if sparse_typ <: AbstractGPUSparseVector
+                (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...)
+            else
+                (sparse_typ, offsets, bc.args...)
+            end
+            kernel = compute_offsets_kernel(get_backend(bc.args[first(sparse_args)]))
+            kernel(args...; ndrange=length(offsets))
+        end
+        # accumulate these values so that we can use them directly as row pointer offsets,
+        # as well as to get the total nnz count to allocate the sparse output array.
+        # cusparseXcsrgeam2Nnz computes this in one go, but it doesn't seem worth the effort
+        if !(sparse_typ <: AbstractGPUSparseVector)
+            @allowscalar accumulate!(Base.add_sum, offsets, offsets)
+            total_nnz = @allowscalar last(offsets[end]) - 1
+        else
+            @allowscalar sort!(offsets; by=first)
+            total_nnz = mapreduce(x->first(x) != typemax(first(x)), +, offsets)
+        end
+        output = if sparse_typ <: Union{AbstractGPUSparseMatrixCSR,AbstractGPUSparseMatrixCSC}
+            ixVal = similar(offsets, Ti, total_nnz)
+            nzVal = similar(offsets, Tv, total_nnz)
+            sparse_typ(offsets, ixVal, nzVal, size(bc))
+        elseif sparse_typ <: AbstractGPUSparseVector && !fpreszeros
+            val_array = bc.args[first(sparse_args)].nzVal
+            similar(val_array, Tv, size(bc))
+        elseif sparse_typ <: AbstractGPUSparseVector && fpreszeros
+            iPtr = similar(offsets, Ti, total_nnz)
+            nzVal = similar(offsets, Tv, total_nnz)
+            _sparse_array_type(sparse_arg){Tv, Ti}(iPtr, nzVal, rows)
+        end
+        if sparse_typ <: AbstractGPUSparseVector && !fpreszeros
+            nonsparse_args = map(bc.args) do arg
+                # NOTE: this assumes the broadcast is flattened, but not yet preprocessed
+                if arg isa AbstractGPUSparseArray
+                    zero(eltype(arg))
+                else
+                    arg
+                end
+            end
+            broadcast!(bc.f, output, nonsparse_args...)
+        end
+    end
+    # perform the actual broadcast
+    if output isa AbstractGPUSparseArray
+        args = (bc.f, output, offsets, bc.args...)
+        kernel = sparse_to_sparse_broadcast_kernel(get_backend(bc.args[first(sparse_args)]))
+        ndrange = output.nnz
+    else
+        args = sparse_typ <: AbstractGPUSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) :
+                                                       (sparse_typ, bc.f, output, bc.args...)
+        kernel = sparse_to_dense_broadcast_kernel(get_backend(bc.args[first(sparse_args)]))
+        ndrange = sparse_typ <: AbstractGPUSparseMatrixCSC ? size(output, 2) : size(output, 1)
+    end
+    kernel(args...; ndrange)
+    return output
+end
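In practice, the dispatch above means the container type of a sparse broadcast follows zero-preservation. A minimal sketch (not part of the patch), assuming it is applied and JLArrays is loaded:

    using SparseArrays, JLArrays

    x = JLSparseVector(sprand(Float32, 16, 0.25))
    x .* 2f0   # f(0) == 0: the result stays a JLSparseVector{Float32}
    x .+ 1f0   # f(0) != 0: the result falls back to a dense JLVector{Float32}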
diff --git a/test/Project.toml b/test/Project.toml
index e6f21d04..a274db35 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -10,5 +10,6 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/test/runtests.jl b/test/runtests.jl
index 766c2041..f59f07a7 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,6 @@
 using Distributed
 using Dates
+using SparseArrays
 import REPL

 using Printf: @sprintf
@@ -47,7 +48,12 @@ include("setup.jl") # make sure everything is precompiled
 # choose tests
 const tests = []
 const test_runners = Dict()
-for AT in (JLArray, Array), name in keys(TestSuite.tests)
+for AT in (JLArray, Array), name in filter(n->n != "sparse", keys(TestSuite.tests))
+    push!(tests, "$(AT)/$name")
+    test_runners["$(AT)/$name"] = ()->TestSuite.tests[name](AT)
+end
+
+for AT in (JLSparseMatrixCSR, JLSparseMatrixCSC, JLSparseVector, SparseMatrixCSC, SparseVector), name in ["sparse"]
     push!(tests, "$(AT)/$name")
     test_runners["$(AT)/$name"] = ()->TestSuite.tests[name](AT)
 end
diff --git a/test/testsuite.jl b/test/testsuite.jl
index f2ec6388..819b7d6e 100644
--- a/test/testsuite.jl
+++ b/test/testsuite.jl
@@ -97,6 +97,7 @@ include("testsuite/math.jl")
 include("testsuite/random.jl")
 include("testsuite/uniformscaling.jl")
 include("testsuite/statistics.jl")
+include("testsuite/sparse.jl")
 include("testsuite/alloc_cache.jl")
 include("testsuite/jld2ext.jl")
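For a single backend, this wiring boils down to calls like the following (a usage sketch, not part of the patch; it assumes the test harness above is loaded and that `TestSuite` is the module defined in test/testsuite.jl):

    using SparseArrays, JLArrays
    TestSuite.tests["sparse"](JLSparseMatrixCSC)   # run the sparse suite for one array type
    TestSuite.tests["sparse"](SparseMatrixCSC)     # CPU reference types exercise the same tests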
diff --git a/test/testsuite/sparse.jl b/test/testsuite/sparse.jl
new file mode 100644
index 00000000..d9783e1e
--- /dev/null
+++ b/test/testsuite/sparse.jl
@@ -0,0 +1,147 @@
+using SparseArrays
+
+@testsuite "sparse" (AT, eltypes)->begin
+    if AT <: AbstractSparseVector
+        broadcasting_vector(AT, eltypes)
+    elseif AT <: AbstractSparseMatrix
+        broadcasting_matrix(AT, eltypes)
+    end
+end
+
+function broadcasting_vector(AT, eltypes)
+    dense_AT = GPUArrays._dense_array_type(AT)
+    dense_VT = GPUArrays._dense_vector_type(AT)
+    for ET in eltypes
+        @testset "SparseVector($ET)" begin
+            m = 64
+            p = 0.5
+            x = sprand(ET, m, p)
+            dx = AT(x)
+
+            # zero-preserving
+            y = x .* ET(1)
+            dy = dx .* ET(1)
+            @test dy isa AT{ET}
+            @test collect(SparseArrays.nonzeroinds(dy)) == collect(SparseArrays.nonzeroinds(dx))
+            @test collect(SparseArrays.nonzeroinds(dy)) == SparseArrays.nonzeroinds(y)
+            @test collect(SparseArrays.nonzeros(dy)) == SparseArrays.nonzeros(y)
+            @test y == SparseVector(dy)
+
+            # not zero-preserving
+            y = x .+ ET(1)
+            dy = dx .+ ET(1)
+            @test dy isa dense_AT{ET}
+            hy = Array(dy)
+            @test Array(y) == hy
+
+            # involving something dense
+            y = x .+ ones(ET, m)
+            dy = dx .+ dense_AT(ones(ET, m))
+            @test dy isa dense_AT{ET}
+            @test Array(y) == Array(dy)
+
+            # sparse to sparse
+            dx = AT(x)
+            y = sprand(ET, m, p)
+            dy = AT(y)
+            z = x .* y
+            dz = dx .* dy
+            @test dz isa AT{ET}
+            @test z == SparseVector(dz)
+
+            # multiple sparse inputs
+            y = sprand(ET, m, p)
+            w = sprand(ET, m, p)
+            dy = AT(y)
+            dx = AT(x)
+            dw = AT(w)
+            z = @. x * y * w
+            dz = @. dx * dy * dw
+            @test dz isa AT{ET}
+            @test z == SparseVector(dz)
+
+            # multiple sparse inputs mixed with a dense input
+            y = sprand(ET, m, p)
+            w = sprand(ET, m, p)
+            dense_arr = rand(ET, m)
+            d_dense_arr = dense_AT(dense_arr)
+            dy = AT(y)
+            dw = AT(w)
+            z = @. x * y * w * dense_arr
+            dz = @. dx * dy * dw * d_dense_arr
+            @test dz isa dense_AT{ET}
+            @test Array(z) == Array(dz)
+
+            # sparse inputs mixed with a scalar
+            y = sprand(ET, m, p)
+            dy = AT(y)
+            dx = AT(x)
+            z = x .* y .* ET(2)
+            dz = dx .* dy .* ET(2)
+            @test dz isa AT{ET}
+            @test z == SparseVector(dz)
+
+            # type-mismatching
+            ## non-zero-preserving
+            dx = AT(x)
+            dy = dx .+ 1
+            y = x .+ 1
+            @test dy isa dense_AT{promote_type(ET, Int)}
+            @test Array(y) == Array(dy)
+            ## zero-preserving
+            dy = dx .* 1
+            y = x .* 1
+            @test dy isa AT{promote_type(ET, Int)}
+            @test collect(SparseArrays.nonzeroinds(dy)) == collect(SparseArrays.nonzeroinds(dx))
+            @test collect(SparseArrays.nonzeroinds(dy)) == SparseArrays.nonzeroinds(y)
+            @test collect(SparseArrays.nonzeros(dy)) == SparseArrays.nonzeros(y)
+            @test y == SparseVector(dy)
+        end
+    end
+end
+
+function broadcasting_matrix(AT, eltypes)
+    dense_AT = GPUArrays._dense_array_type(AT)
+    dense_VT = GPUArrays._dense_vector_type(AT)
+    for ET in eltypes
+        @testset "SparseMatrix($ET)" begin
+            m, n = 5, 6
+            p = 0.5
+            x = sprand(ET, m, n, p)
+            dx = AT(x)
+            # zero-preserving
+            y = x .* ET(1)
+            dy = dx .* ET(1)
+            @test dy isa AT{ET}
+            @test y == SparseMatrixCSC(dy)
+
+            # not zero-preserving
+            y = x .+ ET(1)
+            dy = dx .+ ET(1)
+            @test dy isa dense_AT{ET}
+            hy = Array(dy)
+            dense_y = Array(y)
+            @test dense_y == hy
+
+            # involving something dense
+            y = x .* ones(ET, m, n)
+            dy = dx .* dense_AT(ones(ET, m, n))
+            @test dy isa dense_AT{ET}
+            @test Array(y) == Array(dy)
+
+            # two sparse inputs and a scalar
+            y = sprand(ET, m, n, p)
+            dy = AT(y)
+            z = x .* y .* ET(2)
+            dz = dx .* dy .* ET(2)
+            @test dz isa AT{ET}
+            @test z == SparseMatrixCSC(dz)
+
+            # three sparse inputs
+            w = sprand(ET, m, n, p)
+            dw = AT(w)
+            z = x .* y .* w
+            dz = dx .* dy .* dw
+            @test dz isa AT{ET}
+            @test z == SparseMatrixCSC(dz)
+        end
+    end
+end
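End to end, the matrix path exercised by these tests looks like this from the user's side (an illustrative sketch, not part of the patch, assuming it is applied):

    using SparseArrays, JLArrays

    A  = sprand(Float32, 5, 6, 0.5)
    dA = JLSparseMatrixCSR(A)
    dB = dA .* 2f0                   # zero-preserving: reuses the CSR structure
    @assert SparseMatrixCSC(dB) == A .* 2f0
    dC = dA .+ 1f0                   # not zero-preserving: dense JLMatrix output
    @assert Array(dC) == Array(A) .+ 1f0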