@@ -57,11 +57,13 @@ function add_broadcast!(
5757 bloopsyms = Symbol[k]
5858 cloopsyms = Symbol[m]
5959 reductdeps = Symbol[m, k]
60+ kvec = bloopsyms
6061 elseif ndims (B) == 2
6162 n = loopsyms[2 ];
6263 bloopsyms = Symbol[k,n]
6364 cloopsyms = Symbol[m,n]
6465 reductdeps = Symbol[m, k, n]
66+ kvec = Symbol[k]
6567 else
6668 throw (" B must be a vector or matrix." )
6769 end
@@ -72,13 +74,22 @@ function add_broadcast!(
7274 loadB = add_broadcast! (ls, gensym (:B ), mB, bloopsyms, B, elementbytes)
7375 # set Cₘₙ = 0
7476 # setC = add_constant!(ls, zero(promote_type(recursive_eltype(A), recursive_eltype(B))), cloopsyms, mC, elementbytes)
77+ # targetC will be used for reduce_to_add
78+ mCt = gensym (mC)
79+ targetC = add_constant! (ls, gensym (:zero ), cloopsyms, mCt, elementbytes, :numericconstant )
80+ push! (ls. preamble_zeros, (identifier (targetC), IntOrFloat))
7581 setC = add_constant! (ls, gensym (:zero ), cloopsyms, mC, elementbytes, :numericconstant )
7682 push! (ls. preamble_zeros, (identifier (setC), IntOrFloat))
83+ setC. reduced_children = kvec
7784 # compute Cₘₙ += Aₘₖ * Bₖₙ
7885 reductop = Operation (
79- ls, mC, elementbytes, :vmuladd , compute, reductdeps, Symbol[k] , Operation[loadA, loadB, setC]
86+ ls, mC, elementbytes, :vmuladd , compute, reductdeps, kvec , Operation[loadA, loadB, setC]
8087 )
81- pushop! (ls, reductop, mC)
88+ reductop = pushop! (ls, reductop, mC)
89+ reductfinal = Operation (
90+ ls, mCt, elementbytes, :reduce_to_add , compute, cloopsyms, kvec, Operation[reductop, targetC]
91+ )
92+ pushop! (ls, reductfinal, mCt)
8293end
8394
8495struct LowDimArray{D,T,N,A<: DenseArray{T,N} } <: DenseArray{T,N}
0 commit comments