1919 MulAdd {StyleA,StyleB,StyleC} (α, A, B, β, C)
2020end
2121
# Convenience constructor: the three style parameters of `MulAdd` are filled in with
# the types of `MemoryLayout(AA)`, `MemoryLayout(BB)` and `MemoryLayout(CC)`, where
# `AA`, `BB`, `CC` are the types of `A`, `B` and `C`.
@inline MulAdd(α, A::AA, B::BB, β, C::CC) where {AA,BB,CC} =
    MulAdd{typeof(MemoryLayout(AA)), typeof(MemoryLayout(BB)), typeof(MemoryLayout(CC))}(α, A, B, β, C)
2424
# Two-argument form: wrap `A` and `B` in a `Mul` and defer to `MulAdd(::Mul)`.
MulAdd(A, B) = MulAdd(Mul(A, B))
# Build a `MulAdd` from a `Mul`: α is `scalarone(TV)`, β is `scalarzero(TV)` and the
# `C` argument is produced by `mulzeros(TV, M)`, where `TV = eltype(M)`.
function MulAdd(M::Mul)
    TV = eltype(M)
    return MulAdd(scalarone(TV), M.A, M.B, scalarzero(TV), mulzeros(TV, M))
end
3030
3131@inline eltype (:: MulAdd{StyleA,StyleB,StyleC,T,AA,BB,CC} ) where {StyleA,StyleB,StyleC,T,AA,BB,CC} =
@@ -69,18 +69,11 @@ muladd!(α, A, B, β, C) = materialize!(MulAdd(α, A, B, β, C))
# Materialise a `MulAdd` by instantiating it and copying the result.
materialize(M::MulAdd) = copy(instantiate(M))
# Allocate a destination with `similar(M)` and evaluate `M` into it via `copyto!`.
copy(M::MulAdd) = copyto!(similar(M), M)
7171
# Initialise `dest` from `C`; the result is passed on as the `C` argument of `muladd!`.
# The `Zeros` method calls `zero!(dest)` instead of the generic `copyto!`.
_fill_copyto!(dest, C) = copyto!(dest, C)
_fill_copyto!(dest, C::Zeros) = zero!(dest) # exploit special fill! overload

# Evaluate the lazy `MulAdd` into `dest`: `dest` is first initialised from `M.C`, then
# `muladd!` is called with `M.A` and `M.B` each passed through `unalias(dest, ·)`.
@inline copyto!(dest::AbstractArray{T}, M::MulAdd) where T =
    muladd!(M.α, unalias(dest, M.A), unalias(dest, M.B), M.β, _fill_copyto!(dest, M.C))
8477
8578# Modified from LinearAlgebra._generic_matmatmul!
8679function tile_size (T, S, R)
@@ -226,32 +219,28 @@ function _default_blasmul!(::IndexCartesian, α, A::AbstractMatrix, B::AbstractV
226219 C
227220end
228221
# Matrix-vector entry point: forward to `_default_blasmul!`, passing
# `Base.IndexStyle(typeof(A))` as the leading dispatch argument.
default_blasmul!(α, A::AbstractMatrix, B::AbstractVector, β, C::AbstractVector) =
    _default_blasmul!(Base.IndexStyle(typeof(A)), α, A, B, β, C)
231224
# Generic matrix-matrix `MulAdd`: run `default_blasmul!` with `A` and `B` passed through
# `unalias(C, ·)`. A β for which `iszero(β)` holds is replaced by the literal `false`
# before the call.
function materialize!(M::MatMulMatAdd)
    α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
    return default_blasmul!(α, unalias(C, A), unalias(C, B), iszero(β) ? false : β, C)
end
239229
240230function materialize! (M:: MatMulMatAdd{<:AbstractStridedLayout,<:AbstractStridedLayout,<:AbstractStridedLayout} )
241- α, A, B, β, C = M. α, M. A, M. B, M. β, M. C
242- if C ≡ B
243- B = copy (B)
244- end
231+ α, Ain, Bin, β, C = M. α, M. A, M. B, M. β, M. C
232+ A = unalias (C, Ain)
233+ B = unalias (C, Bin)
245234 ts = tile_size (eltype (A), eltype (B), eltype (C))
246235 if iszero (β) # false is a "strong" zero to wipe out NaNs
247236 if ts == 0 || ! (axes (A) isa NTuple{2 ,OneTo{Int}}) || ! (axes (B) isa NTuple{2 ,OneTo{Int}}) || ! (axes (C) isa NTuple{2 ,OneTo{Int}})
248- default_blasmul! (α, A, B, false , C)
249- else
237+ default_blasmul! (α, A, B, false , C)
238+ else
250239 tiled_blasmul! (ts, α, A, B, false , C)
251240 end
252241 else
253242 if ts == 0 || ! (axes (A) isa NTuple{2 ,OneTo{Int}}) || ! (axes (B) isa NTuple{2 ,OneTo{Int}}) || ! (axes (C) isa NTuple{2 ,OneTo{Int}})
254- default_blasmul! (α, A, B, β, C)
243+ default_blasmul! (α, A, B, β, C)
255244 else
256245 tiled_blasmul! (ts, α, A, B, β, C)
257246 end
@@ -260,29 +249,11 @@ end
260249
# Generic matrix-vector `MulAdd`: run `default_blasmul!` with `A` and `B` passed through
# `unalias(C, ·)`. A β for which `iszero(β)` holds is replaced by the literal `false`
# before the call.
function materialize!(M::MatMulVecAdd)
    α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
    return default_blasmul!(α, unalias(C, A), unalias(C, B), iszero(β) ? false : β, C)
end
268254
# Thin wrappers around `BLAS.gemv!` / `BLAS.gemm!`: every input operand is passed
# through `unalias` against the output array before the BLAS call.
@inline _gemv!(tA, α, A, x, β, y) = BLAS.gemv!(tA, α, unalias(y, A), unalias(y, x), β, y)
@inline _gemm!(tA, tB, α, A, B, β, C) = BLAS.gemm!(tA, tB, α, unalias(C, A), unalias(C, B), β, C)
286257
287258# work around pointer issues
288259@inline materialize! (M:: BlasMatMulVecAdd{<:AbstractColumnMajor,<:AbstractStridedLayout,<:AbstractStridedLayout} ) =
350321# ##
351322
# Thin wrappers around `BLAS.symv!` / `BLAS.hemv!`: `A` and `x` are passed through
# `unalias(y, ·)` before the BLAS call, so the inputs handed to BLAS have been
# checked against the output `y`.
@inline _symv!(tA, α, A, x, β, y) = BLAS.symv!(tA, α, unalias(y, A), unalias(y, x), β, y)
@inline _hemv!(tA, α, A, x, β, y) = BLAS.hemv!(tA, α, unalias(y, A), unalias(y, x), β, y)
368326
369327
370328materialize! (M:: BlasMatMulVecAdd{<:SymmetricLayout{<:AbstractColumnMajor},<:AbstractStridedLayout,<:AbstractStridedLayout} ) =
@@ -411,10 +369,28 @@ scalarone(::Type{<:AbstractArray{T}}) where T = scalarone(T)
# Zero of the underlying scalar type: `zero(T)` for a non-array type, and for an array
# type the rule is applied recursively to its element type.
scalarzero(::Type{T}) where T = zero(T)
scalarzero(::Type{<:AbstractArray{T}}) where T = scalarzero(T)
413371
# Number-valued case: the `C` placeholder is a lazy `Zeros{T}` over the given axes.
fillzeros(::Type{T}, ax) where T<:Number = Zeros{T}(ax)
# For a `Number` element type, `mulzeros` is `fillzeros` over `axes(M)`.
mulzeros(::Type{T}, M) where T<:Number = fillzeros(T, axes(M))
374+
# initiate array-valued MulAdd
#
# Each entry of `dest` is set to `similar(Mul(a, b), eltype(T))`, where `a` is taken
# from column 1 of `A` and `b` from the first entry (vector `dest`) or first row
# (matrix `dest`) of `B`.
# NOTE(review): index `1` is hard-coded, so this presumably assumes `A` and `B` are
# 1-based with a non-empty inner dimension — confirm.
function _mulzeros!(dest::AbstractVector{T}, A, B) where T
    for k in axes(dest, 1)
        dest[k] = similar(Mul(A[k, 1], B[1]), eltype(T))
    end
    return dest
end

function _mulzeros!(dest::AbstractMatrix{T}, A, B) where T
    for j in axes(dest, 2), k in axes(dest, 1)
        dest[k, j] = similar(Mul(A[k, 1], B[1, j]), eltype(T))
    end
    return dest
end
389+
# Array-valued case: allocate `similar(Array{T}, axes(M))` and let `_mulzeros!` fill in
# every entry from `M.A` and `M.B`.
mulzeros(::Type{T}, M) where T<:AbstractArray = _mulzeros!(similar(Array{T}, axes(M)), M.A, M.B)
415391
416392# ##
417- # Fill
393+ # Fill
418394# ##
419395
# When all three arguments have a fill layout, evaluate the expression directly with
# `*` and `+` rather than going through `copyto!`.
copy(M::MulAdd{<:AbstractFillLayout,<:AbstractFillLayout,<:AbstractFillLayout}) = M.α*M.A*M.B + M.β*M.C
0 commit comments