Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ext/ReactantKernelAbstractionsExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ReactantKernelAbstractionsExt

using Reactant: Reactant
using ReactantCore: ReactantCore

using Adapt: Adapt
using KernelAbstractions: KernelAbstractions
Expand Down Expand Up @@ -101,6 +102,14 @@ function tokw(ndrange, workgroupsize, obj, args...)
end

function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsize=nothing)
# If we're already inside a compilation/tracing context, or if any arguments are traced,
# we should trace through this kernel call instead of trying to compile it again.
if Reactant.within_compile() || any(ReactantCore.is_traced, args)
return Reactant.call_with_reactant(
Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
)
end

if Reactant.precompiling()
Reactant.@code_hlo optimize = false tokw(ndrange, workgroupsize, obj, args...)
else
Expand Down
2 changes: 2 additions & 0 deletions ext/ReactantSparseArraysExt/ReactantSparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ using SparseArrays:
include("Errors.jl")
include("ReadOnly.jl")

Reactant.use_overlayed_version(::AbstractSparseArray) = false

end
9 changes: 5 additions & 4 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,19 @@ for (cT, aT, bT) in (
@reactant_overlay @noinline function LinearAlgebra.mul!(
C::$cT, A::$aT, B::$bT, α::Number, β::Number
)
A, B = aos_to_soa(A), aos_to_soa(B)
A2, B2 = aos_to_soa(A), aos_to_soa(B)
C2 = aos_to_soa(C)
if use_overlayed_version((C2, A, B))
TracedLinearAlgebra.overloaded_mul!(C2, A, B, α, β)
# A2 can also be a SparseMatrix, which should be handled by its own methods
if use_overlayed_version(A2) && use_overlayed_version((C2, A2, B2))
TracedLinearAlgebra.overloaded_mul!(C2, A2, B2, α, β)
if C2 !== C
C .= C2
end
else
# Inference barrier is required when calling function recursively within
# overload. This is required since otherwise type inference will think this
# is a recursive edge rather than a call to the base method
Base.inferencebarrier(LinearAlgebra.mul!)(C, A, B, α, β)
Base.inferencebarrier(LinearAlgebra.mul!)(C2, A2, B2, α, β)
end
return C
end
Expand Down
Loading