Skip to content

Commit c2ad19b

Browse files
committed
Julia 1.12 matmatmul! dispatch
1 parent 7f6a693 commit c2ad19b

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

lib/mkl/linalg.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ function LinearAlgebra.generic_matvecmul!(Y::oneVector, tA::AbstractChar, A::one
104104
end
105105
end
106106
end
107-
LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, MulAddMul(alpha, beta))
107+
LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, alpha, beta)
108108
end
109109

110110
# triangular
@@ -120,14 +120,24 @@ LinearAlgebra.generic_trimatdiv!(C::oneStridedVector{T}, uploc, isunitc, tfun::F
120120
# BLAS 3
121121
#
122122

123+
if VERSION >= v"1.12-"
124+
# Otherwise dispatches onto:
125+
# https://github.com/JuliaLang/LinearAlgebra.jl/blob/4e7c3f40316a956119ac419a97c4b8aad7a17e6c/src/matmul.jl#L490
126+
for blas_flag in (LinearAlgebra.BlasFlag.SyrkHerkGemm, LinearAlgebra.BlasFlag.SymmHemmGeneric)
127+
@eval LinearAlgebra.generic_matmatmul_wrapper!(
128+
C::oneStridedMatrix, tA::AbstractChar, tB::AbstractChar, A::oneStridedVecOrMat, B::oneStridedVecOrMat,
129+
alpha::Number, beta::Number, ::$blas_flag) =
130+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
131+
end
132+
end
133+
123134
LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStridedVecOrMat, B::oneStridedVecOrMat, _add::MulAddMul=MulAddMul()) =
124135
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
125136
function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStridedVecOrMat, B::oneStridedVecOrMat, a::Number, b::Number)
126-
T = eltype(C)
127137
alpha, beta = promote(a, b, zero(T))
128138
mA, nA = size(A, tA == 'N' ? 1 : 2), size(A, tA == 'N' ? 2 : 1)
129139
mB, nB = size(B, tB == 'N' ? 1 : 2), size(B, tB == 'N' ? 2 : 1)
130-
140+
T = eltype(C)
131141
if nA != mB
132142
throw(DimensionMismatch("A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
133143
end

0 commit comments

Comments
 (0)