338338
339339
340340# # matrix multiplication
341-
342- function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , a:: Number , b:: Number ) where {T,S,R}
# Legacy entry point: takes the two scaling factors as plain numbers, packs them
# into a `MulAddMul`, and forwards to the `MulAddMul`-based method.
function generic_matmatmul!(C::AbstractArray, A::AbstractArray, B::AbstractArray,
                            a::Number, b::Number)
    return generic_matmatmul!(C, A, B, MulAddMul(a, b))
end
344+ function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , add:: MulAddMul ) where {T,S,R}
343345 if size (A,2 ) != size (B,1 )
344346 throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
345347 end
@@ -350,20 +352,18 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
350352 return fill! (C, zero (R))
351353 end
352354
353- add = MulAddMul (a, b)
354-
355355 gpu_call (C, A, B; name= " matmatmul!" ) do ctx, C, A, B
356356 idx = @linearidx C
357357 assume .(size (C) .> 0 )
358358 i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
359359
360360 @inbounds if i <= size (A,1 ) && j <= size (B,2 )
361361 z2 = zero (A[i, 1 ]* B[1 , j] + A[i, 1 ]* B[1 , j])
362- Ctmp = convert (promote_type (R, typeof (z2)), z2)
362+ Cij = convert (promote_type (R, typeof (z2)), z2)
363363 for k in 1 : size (A,2 )
364- Ctmp += A[i, k]* B[k, j]
364+ Cij += A[i, k]* B[k, j]
365365 end
366- C[i,j] = add (Ctmp , C[i,j])
366+ C[i,j] = add (Cij , C[i,j])
367367 end
368368
369369 return
@@ -372,42 +372,229 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
372372 C
373373end
374374
# LinearAlgebra hook methods for GPU arrays. The two branches differ only in how the
# scaling factors arrive: as a `MulAddMul` object before Julia 1.12, and as two plain
# numbers `a`, `b` from 1.12 on. `@static` keeps the unused branch (and the macro it
# contains) from being expanded on the other Julia version.
@static if VERSION < v"1.12.0-"
# matrix-vector product: forward the supplied `MulAddMul` unchanged
function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar,
                                          A::AbstractGPUMatrix, B::AbstractGPUVector,
                                          _add::MulAddMul = MulAddMul())
    return generic_matmatmul!(C, wrap(A, tA), B, _add)
end

# matrix-matrix product: both operands are wrapped according to their flags
function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB,
                                          A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat,
                                          _add::MulAddMul = MulAddMul())
    return generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end
else
# matrix-vector product: build the `MulAddMul` from `a` and `b` via `@stable_muladdmul`
function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar,
                                          A::AbstractGPUMatrix, B::AbstractGPUVector,
                                          a::Number, b::Number)
    return LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), B, MulAddMul(a, b))
end

# matrix-matrix product: same construction, with both operands wrapped
function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB,
                                          A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat,
                                          a::Number, b::Number)
    return LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(a, b))
end
end
392+
# Triangular-times-dense product `C = op(A) * B`, where `A` holds the triangular factor
# and `op` is selected by `tfun`.
#
# Arguments
# - `C`       : output vector/matrix; it is overwritten with the product.
# - `uploc`   : `'U'` when the stored triangle of `A` is the upper one; any other value
#               is treated as lower.
# - `isunitc` : `'U'` when the diagonal is taken to be one, in which case `A[i,i]` is
#               never read.
# - `tfun`    : `identity`, `transpose` or `adjoint`; any other function raises
#               `error("Not supported")`.
# - `A`, `B`  : triangular factor and dense right-hand operand.
#
# Throws `DimensionMismatch` when `size(A,2) != size(B,1)` or when `C` does not have size
# `(size(A,1), size(B,2))`. Returns `C`.
#
# NOTE(review): each kernel instance reads entries of `B` while other instances write
# `C`, so `C` presumably must not alias `B` -- confirm against callers.
function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function,
                            A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R}
    if size(A,2) != size(B,1)
        throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
    end
    if size(C,1) != size(A,1) || size(C,2) != size(B,2)
        throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))"))
    end
    if isempty(A) || isempty(B)
        # empty inner dimension: the product is identically zero
        return fill!(C, zero(R))
    end

    # `upper` is the triangle of op(A): for any `tfun` other than `identity` the stored
    # triangle is flipped.
    upper = tfun === identity ? uploc == 'U' : uploc != 'U'
    unit  = isunitc == 'U'

    # op(A) = A
    function trimatmul(ctx, C, A, B)
        idx = @linearidx C
        assume.(size(C) .> 0)
        # the trailing `1` supplies the column index when `C` is a vector
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
        l, m, n = size(A, 1), size(B, 1), size(B, 2)

        @inbounds if i <= l && j <= n
            # accumulator of the promoted element type, started at zero
            z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            # diagonal term first, then only the off-diagonal part of the triangle
            Cij += (unit ? one(Cij) : A[i,i]) * B[i,j]
            for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1))
                Cij += A[i,k] * B[k,j]
            end
            # Overwrite instead of accumulating: `C` is a pure output (the empty-input
            # path above also overwrites it), so stale contents must not leak in.
            C[i,j] = Cij
        end

        return
    end

    # op(A) = transpose(A): element (i,k) of op(A) is transpose(A[k,i])
    function trimatmul_t(ctx, C, A, B)
        idx = @linearidx C
        assume.(size(C) .> 0)
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
        l, m, n = size(A, 1), size(B, 1), size(B, 2)

        @inbounds if i <= l && j <= n
            z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            Cij += (unit ? one(Cij) : transpose(A[i,i])) * B[i,j]
            for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1))
                Cij += transpose(A[k,i]) * B[k,j]
            end
            C[i,j] = Cij
        end

        return
    end

    # op(A) = adjoint(A): element (i,k) of op(A) is adjoint(A[k,i])
    function trimatmul_a(ctx, C, A, B)
        idx = @linearidx C
        assume.(size(C) .> 0)
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
        l, m, n = size(A, 1), size(B, 1), size(B, 2)

        @inbounds if i <= l && j <= n
            z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            Cij += (unit ? one(Cij) : adjoint(A[i,i])) * B[i,j]
            for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1))
                Cij += adjoint(A[k,i]) * B[k,j]
            end
            C[i,j] = Cij
        end

        return
    end

    # pick the kernel matching `tfun`
    if tfun === identity
        gpu_call(trimatmul, C, A, B; name="trimatmul")
    elseif tfun === transpose
        gpu_call(trimatmul_t, C, A, B; name="trimatmul_t")
    elseif tfun === adjoint
        gpu_call(trimatmul_a, C, A, B; name="trimatmul_a")
    else
        error("Not supported")
    end

    C
end
476+
# Dense-times-triangular product `C = A * op(B)`, where `B` holds the triangular factor
# and `op` is selected by `tfun`.
#
# Arguments
# - `C`       : output vector/matrix; it is overwritten with the product.
# - `uploc`   : `'U'` when the stored triangle of `B` is the upper one; any other value
#               is treated as lower.
# - `isunitc` : `'U'` when the diagonal is taken to be one, in which case `B[j,j]` is
#               never read.
# - `tfun`    : `identity`, `transpose` or `adjoint`; any other function raises
#               `error("Not supported")`.
# - `A`, `B`  : dense left-hand operand and triangular factor.
#
# Throws `DimensionMismatch` when `size(A,2) != size(B,1)` or when `C` does not have size
# `(size(A,1), size(B,2))`. Returns `C`.
#
# NOTE(review): each kernel instance reads entries of `A` while other instances write
# `C`, so `C` presumably must not alias `A` -- confirm against callers.
function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function,
                            A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R}
    if size(A,2) != size(B,1)
        throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
    end
    if size(C,1) != size(A,1) || size(C,2) != size(B,2)
        throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))"))
    end
    if isempty(A) || isempty(B)
        # empty inner dimension: the product is identically zero
        return fill!(C, zero(R))
    end

    # `upper` is the triangle of op(B): for any `tfun` other than `identity` the stored
    # triangle is flipped.
    upper = tfun === identity ? uploc == 'U' : uploc != 'U'
    unit  = isunitc == 'U'

    # op(B) = B
    function mattrimul(ctx, C, A, B)
        idx = @linearidx C
        assume.(size(C) .> 0)
        # the trailing `1` supplies the column index when `C` is a vector
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
        l, m, n = size(A, 1), size(B, 1), size(B, 2)

        @inbounds if i <= l && j <= n
            # accumulator of the promoted element type, started at zero
            z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            # diagonal term first, then only the off-diagonal part of the triangle
            Cij += A[i,j] * (unit ? one(Cij) : B[j,j])
            for k in (upper ? 1 : (j + 1)):(upper ? (j - 1) : m)
                Cij += A[i,k] * B[k,j]
            end
            # Overwrite instead of accumulating: `C` is a pure output (the empty-input
            # path above also overwrites it), so stale contents must not leak in.
            C[i,j] = Cij
        end

        return
    end

    # op(B) = transpose(B): element (k,j) of op(B) is transpose(B[j,k])
    function mattrimul_t(ctx, C, A, B)
        idx = @linearidx C
        assume.(size(C) .> 0)
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
        l, m, n = size(A, 1), size(B, 1), size(B, 2)

        @inbounds if i <= l && j <= n
            z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            Cij += A[i,j] * (unit ? one(Cij) : transpose(B[j,j]))
            for k in (upper ? 1 : (j + 1)):(upper ? (j - 1) : m)
                Cij += A[i,k] * transpose(B[j,k])
            end
            C[i,j] = Cij
        end

        return
    end

    # op(B) = adjoint(B): element (k,j) of op(B) is adjoint(B[j,k])
    function mattrimul_a(ctx, C, A, B)
        idx = @linearidx C
        assume.(size(C) .> 0)
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
        l, m, n = size(A, 1), size(B, 1), size(B, 2)

        @inbounds if i <= l && j <= n
            z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            Cij += A[i,j] * (unit ? one(Cij) : adjoint(B[j,j]))
            for k in (upper ? 1 : (j + 1)):(upper ? (j - 1) : m)
                Cij += A[i,k] * adjoint(B[j,k])
            end
            C[i,j] = Cij
        end

        return
    end

    # pick the kernel matching `tfun`
    if tfun === identity
        gpu_call(mattrimul, C, A, B; name="mattrimul")
    elseif tfun === transpose
        gpu_call(mattrimul_t, C, A, B; name="mattrimul_t")
    elseif tfun === adjoint
        gpu_call(mattrimul_a, C, A, B; name="mattrimul_a")
    else
        error("Not supported")
    end

    C
end
560+
# From Julia 1.10 on, add GPU-array methods to LinearAlgebra's triangular multiply
# functions; each one forwards its arguments unchanged to the same-named function
# defined in this file.
if VERSION >= v"1.10-"
function LinearAlgebra.generic_trimatmul!(C::AbstractGPUVecOrMat, uploc, isunitc,
                                          tfun::Function, A::AbstractGPUMatrix,
                                          B::AbstractGPUVecOrMat)
    return generic_trimatmul!(C, uploc, isunitc, tfun, A, B)
end

function LinearAlgebra.generic_mattrimul!(C::AbstractGPUMatrix, uploc, isunitc,
                                          tfun::Function, A::AbstractGPUMatrix,
                                          B::AbstractGPUMatrix)
    return generic_mattrimul!(C, uploc, isunitc, tfun, A, B)
end
end
382569
if VERSION < v"1.10.0-DEV.1365"
# On these older Julia versions, also catch the BLAS-style entry points that
# LinearAlgebra's mul! calls and send them to the generic GPU kernel.
function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix,
                             B::AbstractGPUVector, a::Number, b::Number)
    return generic_matmatmul!(C, wrap(A, tA), B, MulAddMul(a, b))
end

# same body, restricted to BLAS floats for disambiguation
function LinearAlgebra.gemv!(C::AbstractGPUVector{T}, tA::AbstractChar,
                             A::AbstractGPUMatrix{T}, B::AbstractGPUVector{T},
                             a::Number, b::Number) where {T<:LinearAlgebra.BlasFloat}
    return generic_matmatmul!(C, wrap(A, tA), B, MulAddMul(a, b))
end

function LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat, tA::AbstractChar,
                                     tB::AbstractChar, A::AbstractGPUVecOrMat,
                                     B::AbstractGPUVecOrMat, _add::MulAddMul)
    return generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end

# same body, restricted to BLAS floats for disambiguation
function LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat{T}, tA::AbstractChar,
                                     tB::AbstractChar, A::AbstractGPUVecOrMat{T},
                                     B::AbstractGPUVecOrMat{T},
                                     _add::MulAddMul) where {T<:LinearAlgebra.BlasFloat}
    return generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end

# `A` is used as both operands; `tA` decides which side gets the 'T' wrapper.
function LinearAlgebra.syrk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar,
                                     A::AbstractGPUVecOrMat,
                                     _add::MulAddMul = MulAddMul())
    # any value other than 'T' is handled as 'N'
    return tA == 'T' ? generic_matmatmul!(C, wrap(A, 'T'), A, _add) :
                       generic_matmatmul!(C, A, wrap(A, 'T'), _add)
end

# `A` is used as both operands; `tA` decides which side gets the 'C' wrapper.
function LinearAlgebra.herk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar,
                                     A::AbstractGPUVecOrMat,
                                     _add::MulAddMul = MulAddMul())
    # any value other than 'C' is handled as 'N'
    return tA == 'C' ? generic_matmatmul!(C, wrap(A, 'C'), A, _add) :
                       generic_matmatmul!(C, A, wrap(A, 'C'), _add)
end
end # VERSION
0 commit comments