2 changes: 2 additions & 0 deletions Project.toml
@@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
@@ -33,5 +34,6 @@ Random = "1"
Reexport = "1"
ScopedValues = "1"
Serialization = "1"
SparseArrays = "1"
Statistics = "1"
julia = "1.10"
4 changes: 4 additions & 0 deletions lib/JLArrays/Project.toml
@@ -7,11 +7,15 @@ version = "0.3.0"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
Adapt = "2.0, 3.0, 4.0"
GPUArrays = "11.1"
KernelAbstractions = "0.9, 0.10"
LinearAlgebra = "1"
Random = "1"
SparseArrays = "1"
julia = "1.8"
131 changes: 130 additions & 1 deletion lib/JLArrays/src/JLArrays.jl
@@ -6,11 +6,14 @@

module JLArrays

export JLArray, JLVector, JLMatrix, jl, JLBackend
export JLArray, JLVector, JLMatrix, jl, JLBackend, JLSparseVector, JLSparseMatrixCSC, JLSparseMatrixCSR

using GPUArrays

using Adapt
using SparseArrays, LinearAlgebra

import GPUArrays: _dense_array_type

import KernelAbstractions
import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
@@ -115,7 +118,90 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
end
end

mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, Ti}
iPtr::JLArray{Ti, 1}
nzVal::JLArray{Tv, 1}
len::Int
nnz::Ti

function JLSparseVector{Tv, Ti}(iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1},
len::Integer) where {Tv, Ti <: Integer}
new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
end
end
SparseArrays.SparseVector(x::JLSparseVector) = SparseVector(length(x), Array(x.iPtr), Array(x.nzVal))
SparseArrays.nnz(x::JLSparseVector) = x.nnz
SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr
SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal

mutable struct JLSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC{Tv, Ti}
colPtr::JLArray{Ti, 1}
rowVal::JLArray{Ti, 1}
nzVal::JLArray{Tv, 1}
dims::NTuple{2,Int}
nnz::Ti

function JLSparseMatrixCSC{Tv, Ti}(colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 1},
nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
end
end
function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims)
end
SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(x.rowVal), Array(x.nzVal))

JLSparseMatrixCSC(A::JLSparseMatrixCSC) = A

function Base.getindex(A::JLSparseMatrixCSC{Tv, Ti}, i::Integer, j::Integer) where {Tv, Ti}
r1 = Int(@inbounds A.colPtr[j])
r2 = Int(@inbounds A.colPtr[j+1]-1)
(r1 > r2) && return zero(Tv)
r1 = searchsortedfirst(view(A.rowVal, r1:r2), i) + r1 - 1
((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1]
end
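
A quick usage sketch of the scalar indexing path above (illustrative, not part of the diff): the method narrows to the column's stored range via colPtr, then binary-searches rowVal with searchsortedfirst, returning the stored value or zero.

using SparseArrays, JLArrays

A = JLSparseMatrixCSC(sparse([1, 3], [1, 2], [10.0, 20.0], 3, 2))
A[1, 1]  # 10.0: row 1 is stored in column 1
A[2, 1]  # 0.0: no stored entry, the searched index falls past r2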

mutable struct JLSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR{Tv, Ti}
rowPtr::JLArray{Ti, 1}
colVal::JLArray{Ti, 1}
nzVal::JLArray{Tv, 1}
dims::NTuple{2,Int}
nnz::Ti

function JLSparseMatrixCSR{Tv, Ti}(rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1},
nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti<:Integer}
new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
end
end
function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims)
end
function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
x_transpose = SparseMatrixCSC(size(x, 2), size(x, 1), Array(x.rowPtr), Array(x.colVal), Array(x.nzVal))
return SparseMatrixCSC(transpose(x_transpose))
end

JLSparseMatrixCSR(A::JLSparseMatrixCSR) = A

GPUArrays.storage(a::JLArray) = a.data
GPUArrays._dense_array_type(a::JLArray{T, N}) where {T, N} = JLArray{T, N}
GPUArrays._dense_array_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, N}
GPUArrays._dense_vector_type(a::JLArray{T, N}) where {T, N} = JLArray{T, 1}
GPUArrays._dense_vector_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, 1}

GPUArrays._sparse_array_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSC
GPUArrays._sparse_array_type(::Type{<:JLSparseMatrixCSC}) = JLSparseMatrixCSC
GPUArrays._sparse_array_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSR
GPUArrays._sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR
GPUArrays._sparse_array_type(sa::JLSparseVector) = JLSparseVector
GPUArrays._sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector

GPUArrays._dense_array_type(sa::JLSparseVector) = JLArray
GPUArrays._dense_array_type(::Type{<:JLSparseVector}) = JLArray
GPUArrays._dense_array_type(sa::JLSparseMatrixCSC) = JLArray
GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
GPUArrays._dense_array_type(sa::JLSparseMatrixCSR) = JLArray
GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray
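
A brief illustration of the container-type mappings above (these are internal GPUArrays helpers; how they are consumed elsewhere in GPUArrays is not shown in this diff):

using GPUArrays, JLArrays

GPUArrays._sparse_array_type(JLSparseMatrixCSC{Float32, Int32})  # JLSparseMatrixCSC
GPUArrays._dense_array_type(JLSparseMatrixCSR{Float64, Int64})   # JLArray
GPUArrays._dense_vector_type(JLArray{Float32, 2})                # JLArray{Float32, 1}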

# conversion of untyped data to a typed Array
function typed_data(x::JLArray{T}) where {T}
@@ -217,6 +303,41 @@ JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs)
(::Type{JLArray{T,N} where T})(x::AbstractArray{S,N}) where {S,N} = JLArray{S,N}(x)
JLArray(A::AbstractArray{T,N}) where {T,N} = JLArray{T,N}(A)

function JLSparseVector(xs::SparseVector{Tv, Ti}) where {Ti, Tv}
iPtr = JLVector{Ti}(undef, length(xs.nzind))
nzVal = JLVector{Tv}(undef, length(xs.nzval))
copyto!(iPtr, convert(Vector{Ti}, xs.nzind))
copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs))
end
Base.length(x::JLSparseVector) = x.len
Base.size(x::JLSparseVector) = (x.len,)

function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
colPtr = JLVector{Ti}(undef, length(xs.colptr))
rowVal = JLVector{Ti}(undef, length(xs.rowval))
nzVal = JLVector{Tv}(undef, length(xs.nzval))
copyto!(colPtr, convert(Vector{Ti}, xs.colptr))
copyto!(rowVal, convert(Vector{Ti}, xs.rowval))
copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, (xs.m, xs.n))
end
Base.length(x::JLSparseMatrixCSC) = prod(x.dims)
Base.size(x::JLSparseMatrixCSC) = x.dims

function JLSparseMatrixCSR(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
csr_xs = SparseMatrixCSC(transpose(xs))
rowPtr = JLVector{Ti}(undef, length(csr_xs.colptr))
colVal = JLVector{Ti}(undef, length(csr_xs.rowval))
nzVal = JLVector{Tv}(undef, length(csr_xs.nzval))
copyto!(rowPtr, convert(Vector{Ti}, csr_xs.colptr))
copyto!(colVal, convert(Vector{Ti}, csr_xs.rowval))
copyto!(nzVal, convert(Vector{Tv}, csr_xs.nzval))
return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, (xs.m, xs.n))
end
Base.length(x::JLSparseMatrixCSR) = prod(x.dims)
Base.size(x::JLSparseMatrixCSR) = x.dims
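
A host-side usage sketch for the constructors above (illustrative, not part of the diff): each wrapper copies the index and value vectors of the SparseArrays object into JLVectors, and the conversions defined earlier recover the original host data.

using SparseArrays, JLArrays

x = sprand(Float32, 100, 0.1)
A = sprand(10, 8, 0.3)

dx  = JLSparseVector(x)
dAc = JLSparseMatrixCSC(A)
dAr = JLSparseMatrixCSR(A)   # built via the transpose, stored row-major

@assert SparseVector(dx) == x
@assert SparseMatrixCSC(dAc) == A
@assert SparseMatrixCSC(dAr) == A
@assert size(dAc) == size(A) == size(dAr)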

# idempotency
JLArray{T,N}(xs::JLArray{T,N}) where {T,N} = xs

@@ -358,9 +479,17 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
R
end

Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} =
GPUSparseDeviceMatrixCSC{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz)
Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv,Ti}) where {Tv,Ti} =
GPUSparseDeviceMatrixCSR{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz)
Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv,Ti}) where {Tv,Ti} =
GPUSparseDeviceVector{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz)

## KernelAbstractions interface

KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
KernelAbstractions.get_backend(a::JLA) where JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector} = JLBackend()
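
A minimal end-to-end kernel sketch (not part of the diff), assuming the JLBackend launch path adapts kernel arguments through the Adaptor methods above so the kernel sees a GPUSparseDeviceMatrixCSC; the kernel name and the column-sum computation are purely illustrative.

using KernelAbstractions, SparseArrays, JLArrays

@kernel function colsums!(out, A)
    col = @index(Global)
    acc = zero(eltype(out))
    # nzrange/getnzval are the device-side accessors from src/device/sparse.jl
    for k in SparseArrays.nzrange(A, col)
        acc += SparseArrays.getnzval(A)[k]
    end
    out[col] = acc
end

A   = JLSparseMatrixCSC(sprand(16, 8, 0.25))
out = JLArray(zeros(size(A, 2)))
colsums!(get_backend(A))(out, A; ndrange = size(A, 2))
KernelAbstractions.synchronize(get_backend(out))
@assert Array(out) ≈ vec(sum(SparseMatrixCSC(A); dims = 1))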

function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
2 changes: 2 additions & 0 deletions src/GPUArrays.jl
@@ -19,6 +19,7 @@ using KernelAbstractions

# device functionality
include("device/abstractarray.jl")
include("device/sparse.jl")

# host abstractions
include("host/abstractarray.jl")
@@ -34,6 +35,7 @@ include("host/random.jl")
include("host/quirks.jl")
include("host/uniformscaling.jl")
include("host/statistics.jl")
include("host/sparse.jl")
include("host/alloc_cache.jl")


135 changes: 135 additions & 0 deletions src/device/sparse.jl
@@ -0,0 +1,135 @@
# on-device sparse array types
# should be excluded from coverage counts
# COV_EXCL_START
using SparseArrays

# NOTE: this functionality is currently very bare-bones, only defining the array types
# without any device-compatible sparse array functionality


# core types

export GPUSparseDeviceVector, GPUSparseDeviceMatrixCSC, GPUSparseDeviceMatrixCSR,
GPUSparseDeviceMatrixBSR, GPUSparseDeviceMatrixCOO

abstract type AbstractGPUSparseDeviceMatrix{Tv, Ti} <: AbstractSparseMatrix{Tv, Ti} end


struct GPUSparseDeviceVector{Tv,Ti,Vi,Vv} <: AbstractSparseVector{Tv,Ti}
iPtr::Vi
nzVal::Vv
len::Int
nnz::Ti
end
Comment on lines +18 to +23

Member: It's a bit inconsistent that we leave the host sparse object layout to the back-end, but define the device one concretely here. I'm not sure whether it would be better to move the definitions entirely out of (or rather into) GPUArrays.jl, though. I guess back-ends may want additional control over the object layout to facilitate vendor library interactions, but maybe we should then also leave the device-side version up to the back-end and only implement things here in terms of the SparseArrays interfaces (rowvals, getcolptr, etc.). Thoughts?

Member Author: Yeah, I was torn on this too. The advantage here is that libraries get a working device-side implementation "for free": they can still implement their own (better) one and just give Adapt.jl the information on how to move their host-side structs to it.

Base.length(g::GPUSparseDeviceVector) = g.len
Base.size(g::GPUSparseDeviceVector) = (g.len,)
SparseArrays.nnz(g::GPUSparseDeviceVector) = g.nnz
SparseArrays.nonzeroinds(g::GPUSparseDeviceVector) = g.iPtr
SparseArrays.nonzeros(g::GPUSparseDeviceVector) = g.nzVal

struct GPUSparseDeviceMatrixCSC{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
colPtr::Vi
rowVal::Vi
nzVal::Vv
dims::NTuple{2,Int}
nnz::Ti
end

SparseArrays.rowvals(g::GPUSparseDeviceMatrixCSC) = g.rowVal
SparseArrays.getcolptr(g::GPUSparseDeviceMatrixCSC) = g.colPtr
SparseArrays.nzrange(g::GPUSparseDeviceMatrixCSC, col::Integer) = SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col+1]-1)
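
A sketch of what back-end-agnostic device code can look like against these accessors (hypothetical helper, not part of the diff); it touches only the SparseArrays interface, never the concrete field names, which is what lets back-ends swap in their own device types.

function column_dot(A::GPUSparseDeviceMatrixCSC, col::Integer, x)
    acc = zero(eltype(x))
    for k in SparseArrays.nzrange(A, col)
        # stored value times the matching entry of the dense vector x
        acc += SparseArrays.getnzval(A)[k] * x[SparseArrays.rowvals(A)[k]]
    end
    return acc
end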

struct GPUSparseDeviceMatrixCSR{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
rowPtr::Vi
colVal::Vi
nzVal::Vv
dims::NTuple{2, Int}
nnz::Ti
end

struct GPUSparseDeviceMatrixBSR{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
rowPtr::Vi
colVal::Vi
nzVal::Vv
dims::NTuple{2,Int}
blockDim::Ti
dir::Char
nnz::Ti
end

struct GPUSparseDeviceMatrixCOO{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
rowInd::Vi
colInd::Vi
nzVal::Vv
dims::NTuple{2,Int}
nnz::Ti
end

Base.length(g::AbstractGPUSparseDeviceMatrix) = prod(g.dims)
Base.size(g::AbstractGPUSparseDeviceMatrix) = g.dims
SparseArrays.nnz(g::AbstractGPUSparseDeviceMatrix) = g.nnz
SparseArrays.getnzval(g::AbstractGPUSparseDeviceMatrix) = g.nzVal

struct GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M} <: AbstractSparseArray{Tv, Ti, N}
rowPtr::Vi
colVal::Vi
nzVal::Vv
dims::NTuple{N, Int}
nnz::Ti
end

function GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N}(rowPtr::Vi, colVal::Vi, nzVal::Vv, dims::NTuple{N,<:Integer}) where {Tv, Ti<:Integer, M, Vi<:AbstractDeviceArray{<:Integer,M}, Vv<:AbstractDeviceArray{Tv, M}, N}
@assert M == N - 1 "GPUSparseDeviceArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1"
GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M}(rowPtr, colVal, nzVal, dims, length(nzVal))
end

Base.length(g::GPUSparseDeviceArrayCSR) = prod(g.dims)
Base.size(g::GPUSparseDeviceArrayCSR) = g.dims
SparseArrays.nnz(g::GPUSparseDeviceArrayCSR) = g.nnz
SparseArrays.getnzval(g::GPUSparseDeviceArrayCSR) = g.nzVal

# input/output

function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceVector)
println(io, "$(length(A))-element device sparse vector at:")
println(io, " iPtr: $(A.iPtr)")
print(io, " nzVal: $(A.nzVal)")
end

function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSR)
println(io, "$(length(A))-element device sparse matrix CSR at:")
println(io, " rowPtr: $(A.rowPtr)")
println(io, " colVal: $(A.colVal)")
print(io, " nzVal: $(A.nzVal)")
end

function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSC)
println(io, "$(length(A))-element device sparse matrix CSC at:")
println(io, " colPtr: $(A.colPtr)")
println(io, " rowVal: $(A.rowVal)")
print(io, " nzVal: $(A.nzVal)")
end

function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixBSR)
println(io, "$(length(A))-element device sparse matrix BSR at:")
println(io, " rowPtr: $(A.rowPtr)")
println(io, " colVal: $(A.colVal)")
print(io, " nzVal: $(A.nzVal)")
end

function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCOO)
println(io, "$(length(A))-element device sparse matrix COO at:")
println(io, " rowPtr: $(A.rowPtr)")
println(io, " colInd: $(A.colInd)")
print(io, " nzVal: $(A.nzVal)")
end

function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceArrayCSR)
println(io, "$(length(A))-element device sparse array CSR at:")
println(io, " rowPtr: $(A.rowPtr)")
println(io, " colVal: $(A.colVal)")
print(io, " nzVal: $(A.nzVal)")
end

# COV_EXCL_STOP