@@ -2,9 +2,25 @@ struct Product{A,B}
22 a:: A
33 b:: B
44end
5+ function Base. size (p:: Product )
6+ M = size (p. a, 1 )
7+ (M, Base. tail (size (p. b))... )
8+ end
9+ @inline Base. length (p:: Product ) = prod (size (p))
10+ @inline Base. broadcastable (p:: Product ) = p
11+ @inline Base. ndims (p:: Type{Product{A,B}} ) where {A,B} = ndims (B)
12+
13+ Base. Broadcast. _broadcast_getindex_eltype (:: Product{A,B} ) where {T, A <: AbstractVecOrMat{T} , B <: AbstractVecOrMat{T} } = T
14+ function Base. Broadcast. _broadcast_getindex_eltype (p:: Product )
15+ promote_type (
16+ Base. Broadcast. _broadcast_getindex_eltype (p. a),
17+ Base. Broadcast. _broadcast_getindex_eltype (p. b)
18+ )
19+ end
20+
521
622@inline ∗ (a:: A , b:: B ) where {A,B} = Product {A,B} (a, b)
7- @inline Base. Broadcast. Broadcasted (:: typeof (∗ ), a:: A , b:: B ) where {A, B} = Product {A,B} (a, b)
23+ @inline Base. Broadcast. broadcasted (:: typeof (∗ ), a:: A , b:: B ) where {A, B} = Product {A,B} (a, b)
824# TODO : Need to make this handle A or B being (1 or 2)-D broadcast objects.
925function add_broadcast! (
1026 ls:: LoopSet , mC:: Symbol , bcname:: Symbol , loopsyms:: Vector{Symbol} ,
@@ -19,17 +35,29 @@ function add_broadcast!(
1935
2036 k = gensym (:k )
2137 ls. loops[k] = Loop (k, K)
22- m = loopsyms[1 ]; n = loopsyms[2 ];
38+ m = loopsyms[1 ];
39+ if ndims (B) == 1
40+ bloopsyms = Symbol[k]
41+ cloopsyms = Symbol[m]
42+ reductdeps = Symbol[m, k]
43+ elseif ndims (B) == 2
44+ n = loopsyms[2 ];
45+ bloopsyms = Symbol[k,n]
46+ cloopsyms = Symbol[m,n]
47+ reductdeps = Symbol[m, k, n]
48+ else
49+ throw (" B must be a vector or matrix." )
50+ end
2351 # load A
2452 # loadA = add_load!(ls, gensym(:A), productref(A, mA, m, k), elementbytes)
25- loadA = add_broadcast! (ls, gensym (:A ), mA, [m,k], A, elementbytes)
53+ loadA = add_broadcast! (ls, gensym (:A ), mA, Symbol [m,k], A, elementbytes)
2654 # load B
27- loadB = add_broadcast! (ls, gensym (:B ), mB, [k,n] , B, elementbytes)
55+ loadB = add_broadcast! (ls, gensym (:B ), mB, bloopsyms , B, elementbytes)
2856 # set Cₘₙ = 0
29- setC = add_constant! (ls, 0.0 , Symbol[m, k] , mC, elementbytes)
57+ setC = add_constant! (ls, 0.0 , cloopsyms , mC, elementbytes)
3058 # compute Cₘₙ += Aₘₖ * Bₖₙ
3159 reductop = Operation (
32- ls, mC, elementbytes, :vmuladd , compute, Symbol[m, k, n] , Symbol[k], Operation[loadA, loadB, setC]
60+ ls, mC, elementbytes, :vmuladd , compute, reductdeps , Symbol[k], Operation[loadA, loadB, setC]
3361 )
3462 pushop! (ls, reductop, mC)
3563end
102130# size of dest determines loops
103131@generated function vmaterialize! (
104132 dest:: AbstractArray{T,N} , bc:: BC
105- # ) where {T, N, BC <: Broadcasted}
106- ) where {N, T, BC <: Broadcasted }
133+ ) where {T, N, BC <: Broadcasted }
134+ # ) where {N, T, BC <: Broadcasted}
107135 # we have an N dimensional loop.
108136 # need to construct the LoopSet
109137 loopsyms = [gensym (:n ) for n ∈ 1 : N]
0 commit comments