@@ -49,6 +49,13 @@ function check_mul_axes(A, B, C...)
4949 check_mul_axes (B, C... )
5050end
5151
52+ # we need to special case AbstractQ as it allows non-compatiple multiplication
53+ function check_mul_axes (A:: AbstractQ , B, C... )
54+ axes (A. factors, 1 ) == axes (B, 1 ) || axes (A. factors, 2 ) == axes (B, 1 ) ||
55+ throw (DimensionMismatch (" First axis of B, $(axes (B,1 )) must match either axes of A, $(axes (A)) " ))
56+ check_mul_axes (B, C... )
57+ end
58+
5259
5360function instantiate (M:: MulAdd )
5461 @boundscheck check_mul_axes (M. α, M. A, M. B)
@@ -373,4 +380,85 @@ copy(M::MulAdd{<:AbstractFillLayout,<:AbstractFillLayout,<:AbstractFillLayout})
373380copy (M:: MulAdd{<:Any,<:DiagonalLayout{<:AbstractFillLayout}} ) = (M. α * getindex_value (M. B. diag)) .* M. A .+ M. β .* M. C
374381copy (M:: MulAdd{<:Any,<:DiagonalLayout{<:AbstractFillLayout},ZerosLayout} ) = (M. α * getindex_value (M. B. diag)) .* M. A
375382
376- BroadcastStyle (:: Type{<:MulAdd} ) = ApplyBroadcastStyle ()
383+ BroadcastStyle (:: Type{<:MulAdd} ) = ApplyBroadcastStyle ()
384+
385+ scalarone (:: Type{T} ) where T = one (T)
386+ scalarone (:: Type{<:AbstractArray{T}} ) where T = scalarone (T)
387+ scalarzero (:: Type{T} ) where T = zero (T)
388+ scalarzero (:: Type{<:AbstractArray{T}} ) where T = scalarzero (T)
389+
390+ fillzeros (:: Type{T} , ax) where T = Zeros {T} (ax)
391+
392+ function mul! (dest:: AbstractArray{W} , A:: AbstractArray{T} , b:: AbstractArray{V} ) where {T,V,W}
393+ TVW = promote_type (W, _mul_eltype (T,V))
394+ muladd! (scalarone (TVW), A, b, scalarzero (TVW), dest)
395+ end
396+
397+ function MulAdd (A:: AbstractArray{T} , B:: AbstractVector{V} ) where {T,V}
398+ TV = _mul_eltype (eltype (A), eltype (B))
399+ MulAdd (scalarone (TV), A, B, scalarzero (TV), fillzeros (TV,(axes (A,1 ))))
400+ end
401+
402+ function MulAdd (A:: AbstractArray{T} , B:: AbstractMatrix{V} ) where {T,V}
403+ TV = _mul_eltype (eltype (A), eltype (B))
404+ MulAdd (scalarone (TV), A, B, scalarzero (TV), fillzeros (TV,(axes (A,1 ),axes (B,2 ))))
405+ end
406+
407+ mul (A:: AbstractArray , B:: AbstractArray ) = materialize (MulAdd (A,B))
408+
409+ macro lazymul (Typ)
410+ ret = quote
411+ LinearAlgebra. mul! (dest:: AbstractVector , A:: $Typ , b:: AbstractVector ) =
412+ ArrayLayouts. mul! (dest,A,b)
413+
414+ LinearAlgebra. mul! (dest:: AbstractMatrix , A:: $Typ , b:: AbstractMatrix ) =
415+ ArrayLayouts. mul! (dest,A,b)
416+ LinearAlgebra. mul! (dest:: AbstractMatrix , A:: $Typ , b:: $Typ ) =
417+ ArrayLayouts. mul! (dest,A,b)
418+
419+ Base.:* (A:: $Typ , B:: $Typ ) = ArrayLayouts. mul (A,B)
420+ Base.:* (A:: $Typ , B:: AbstractMatrix ) = ArrayLayouts. mul (A,B)
421+ Base.:* (A:: $Typ , B:: AbstractVector ) = ArrayLayouts. mul (A,B)
422+ Base.:* (A:: AbstractMatrix , B:: $Typ ) = ArrayLayouts. mul (A,B)
423+ Base.:* (A:: LinearAlgebra.AdjointAbsVec , B:: $Typ ) = ArrayLayouts. mul (A,B)
424+ Base.:* (A:: LinearAlgebra.TransposeAbsVec , B:: $Typ ) = ArrayLayouts. mul (A,B)
425+
426+ Base.:* (A:: LinearAlgebra.AbstractQ , B:: $Typ ) = ArrayLayouts. lmul (A,B)
427+ Base.:* (A:: $Typ , B:: LinearAlgebra.AbstractQ ) = ArrayLayouts. rmul (A,B)
428+ end
429+ for Struc in (:AbstractTriangular , :Diagonal )
430+ ret = quote
431+ $ ret
432+
433+ Base.:* (A:: LinearAlgebra. $ Struc, B:: $Typ ) = ArrayLayouts. mul (A,B)
434+ Base.:* (A:: $Typ , B:: LinearAlgebra. $ Struc) = ArrayLayouts. mul (A,B)
435+ end
436+ end
437+ for Mod in (:Adjoint , :Transpose , :Symmetric , :Hermitian )
438+ ret = quote
439+ $ ret
440+
441+ LinearAlgebra. mul! (dest:: AbstractMatrix , A:: $Typ , b:: $Mod{<:Any,<:AbstractMatrix} ) =
442+ ArrayLayouts. mul! (dest,A,b)
443+
444+ LinearAlgebra. mul! (dest:: AbstractVector , A:: $Mod{<:Any,<:$Typ} , b:: AbstractVector ) =
445+ ArrayLayouts. mul! (dest,A,b)
446+
447+ Base.:* (A:: $Mod{<:Any,<:$Typ} , B:: $Mod{<:Any,<:$Typ} ) = ArrayLayouts. mul (A,B)
448+ Base.:* (A:: $Mod{<:Any,<:$Typ} , B:: AbstractMatrix ) = ArrayLayouts. mul (A,B)
449+ Base.:* (A:: AbstractMatrix , B:: $Mod{<:Any,<:$Typ} ) = ArrayLayouts. mul (A,B)
450+ Base.:* (A:: $Mod{<:Any,<:$Typ} , B:: AbstractVector ) = ArrayLayouts. mul (A,B)
451+
452+ Base.:* (A:: $Mod{<:Any,<:$Typ} , B:: $Typ ) = ArrayLayouts. mul (A,B)
453+ Base.:* (A:: $Typ , B:: $Mod{<:Any,<:$Typ} ) = ArrayLayouts. mul (A,B)
454+
455+ Base.:* (A:: $Mod{<:Any,<:$Typ} , B:: Diagonal ) = ArrayLayouts. mul (A,B)
456+ Base.:* (A:: Diagonal , B:: $Mod{<:Any,<:$Typ} ) = ArrayLayouts. mul (A,B)
457+
458+ Base.:* (A:: LinearAlgebra.AbstractTriangular , B:: $Mod{<:Any,<:$Typ} ) = ArrayLayouts. mul (A,B)
459+ Base.:* (A:: $Mod{<:Any,<:$Typ} , B:: LinearAlgebra.AbstractTriangular ) = ArrayLayouts. mul (A,B)
460+ end
461+ end
462+
463+ esc (ret)
464+ end
0 commit comments