@@ -21,16 +21,23 @@ function mergesetdiffv!(
2121 end
2222 nothing
2323end
"""
    setdiffv!(s3, s1, s2)

Append to `s3` every element of `s1` that is in neither `s2` nor `s3`.

Elements are appended in the order they are first encountered in `s1`, so `s3`
never gains a duplicate. `s1` and `s2` are only read. Returns `nothing`.
"""
function setdiffv!(
    s3::AbstractVector{T}, s1::AbstractVector{T}, s2::AbstractVector{T}
) where {T}
    for s ∈ s1
        # Skip anything excluded by `s2` or already recorded in `s3`.
        if s ∉ s2 && s ∉ s3
            push!(s3, s)
        end
    end
    return nothing
end
"""
    setdiffv!(s4, s3, s1, s2)

Partition the elements of `s1` by membership in `s2`.

Each element of `s1` that is in `s2` is appended to `s4`; every other element
is appended to `s3`. An element is only appended when the destination does not
already contain it, so neither output gains duplicates. `s1` and `s2` are only
read. Returns `nothing`.
"""
function setdiffv!(
    s4::AbstractVector{T}, s3::AbstractVector{T}, s1::AbstractVector{T}, s2::AbstractVector{T}
) where {T}
    for s ∈ s1
        # Pick the destination first, then do a single de-duplicated append.
        dest = s ∈ s2 ? s4 : s3
        s ∉ dest && push!(dest, s)
    end
    return nothing
end
# Fold the dependencies of `parent` into the running dependency vectors.
#
# `deps` always receives the loop dependencies of `parent`. `reduceddeps`
# receives the reduced dependencies of `parent` unless `parent` is a load or a
# constant, or its instruction is one of the reduction instructions listed
# below. Both vectors are updated in place via `mergesetv!`; returns `nothing`.
function update_deps!(deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, parent::Operation)
    mergesetv!(deps, loopdependencies(parent))
    isreduceinstr = parent.instruction.instr ∈ (
        :reduced_add, :reduced_prod, :reduce_to_add, :reduce_to_prod
    )
    if !(isload(parent) || isconstant(parent)) && !isreduceinstr
        mergesetv!(reduceddeps, reduceddependencies(parent))
    end
    return nothing
end
3643
@@ -42,19 +49,19 @@ function pushparent!(mpref::ArrayReferenceMetaPosition, parent::Operation)
4249 pushparent! (mpref. parents, mpref. loopdependencies, mpref. reduceddeps, parent)
4350end
# Resolve `var` to an `Operation` and register it as a parent.
#
# How `var` is resolved depends on what it is:
#   * `Symbol` - looked up with `getop(ls, var, elementbytes)`;
#   * `Expr`   - added as a new operation under a fresh `gensym(:temporary)`
#                name, forwarding `position` to `add_operation!`;
#   * otherwise it is assumed to be a constant and added with `add_constant!`.
#
# The resulting operation is passed to `pushparent!` together with `parents`,
# `deps` and `reduceddeps`; the value of that call is returned.
function add_parent!(
    parents::Vector{Operation},
    deps::Vector{Symbol},
    reduceddeps::Vector{Symbol},
    ls::LoopSet,
    var,
    elementbytes::Int,
    position::Int
)
    parent = if var isa Symbol
        getop(ls, var, elementbytes)
    elseif var isa Expr # CSE candidate
        add_operation!(ls, gensym(:temporary), var, elementbytes, position)
    else # assumed constant
        add_constant!(ls, var, elementbytes)
    end
    return pushparent!(parents, deps, reduceddeps, parent)
end
5663function add_reduction! (
57- parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet , var:: Symbol , elementbytes:: Int = 8
64+ parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet , var:: Symbol , elementbytes:: Int
5865)
5966 get! (ls. opdict, var) do
6067 add_constant! (ls, var, elementbytes)
@@ -80,10 +87,10 @@ function update_reduction_status!(parentvec::Vector{Operation}, deps::Vector{Sym
8087 end
8188end
8289function add_reduction_update_parent! (
83- parents :: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet ,
84- var :: Symbol , instr:: Symbol , directdependency:: Bool , elementbytes:: Int = 8
90+ vparents :: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet ,
91+ parent :: Operation , instr:: Symbol , directdependency:: Bool , elementbytes:: Int
8592)
86- parent = getop (ls, var, elementbytes )
93+ var = name (parent )
8794 isouterreduction = parent. instruction === LOOPCONSTANT
8895 Instr = instruction (ls, instr)
8996 instrclass = reduction_instruction_class (Instr) # key allows for faster lookups
@@ -110,27 +117,27 @@ function add_reduction_update_parent!(
110117 reductsym = var
111118 reductcombine = Symbol (" " )
112119 end
113- setdiffv! (reduceddeps, deps, loopdependencies (reductinit))
114120 combineddeps = copy (deps); mergesetv! (combineddeps, reduceddeps)
115- directdependency && pushparent! (parents , deps, reduceddeps, reductinit)# parent) # deps and reduced deps will not be disjoint
116- update_reduction_status! (parents , combineddeps, name (reductinit))
121+ directdependency && pushparent! (vparents , deps, reduceddeps, reductinit)# parent) # deps and reduced deps will not be disjoint
122+ update_reduction_status! (vparents , combineddeps, name (reductinit))
117123 # this is the op added by add_compute
118- op = Operation (length (operations (ls)), reductsym, elementbytes, instr, compute, deps, reduceddeps, parents )
124+ op = Operation (length (operations (ls)), reductsym, elementbytes, instr, compute, deps, reduceddeps, vparents )
119125 parent. instruction === LOOPCONSTANT && push! (ls. outer_reductions, identifier (op))
120126 opout = pushop! (ls, op, var) # note this overwrites the entry in the operations dict, but not the vector
127+ # isouterreduction || iszero(length(reduceddeps)) && return opout
121128 isouterreduction && return opout
122129 # create child op, which is the reduction combination
123- childdeps = Symbol[]; childrdeps = Symbol[]; childparents = Operation[]
124- pushparent! (childparents, childdeps, childrdeps, op) # reduce op
125- pushparent! (childparents, childdeps, childrdeps, parent) # to
130+ childrdeps = Symbol[]; childparents = Operation[ op, parent ]
131+ childdeps = loopdependencies (reductinit)
132+ setdiffv! ( childrdeps, loopdependencies (op), childdeps)
126133 child = Operation (
127134 length (operations (ls)), name (parent), elementbytes, reductcombine, compute, childdeps, childrdeps, childparents
128135 )
129136 pushop! (ls, child, name (parent))
130137 opout
131138end
132139function add_compute! (
133- ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int = 8 ,
140+ ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int , position :: Int ,
134141 mpref:: Union{Nothing,ArrayReferenceMetaPosition} = nothing
135142)
136143 @assert ex. head === :call
@@ -149,12 +156,12 @@ function add_compute!(
149156 if isref
150157 if mpref == argref
151158 reduction = true
152- add_load! (ls, var, mpref , elementbytes)
159+ add_load! (ls, var, argref , elementbytes)
153160 else
154161 pushparent! (parents, deps, reduceddeps, add_load! (ls, gensym (:tempload ), argref, elementbytes))
155162 end
156163 else
157- add_parent! (parents, deps, reduceddeps, ls, arg, elementbytes)
164+ add_parent! (parents, deps, reduceddeps, ls, arg, elementbytes, position )
158165 end
159166 elseif arg ∈ ls. loopsymbols
160167 loopsym = gensym (arg)
@@ -164,11 +171,30 @@ function add_compute!(
164171 push! (ls. refs_aliasing_syms, loopsymop. ref)
165172 pushparent! (parents, deps, reduceddeps, loopsymop)
166173 else
167- add_parent! (parents, deps, reduceddeps, ls, arg, elementbytes)
174+ add_parent! (parents, deps, reduceddeps, ls, arg, elementbytes, position )
168175 end
169176 end
177+ if iszero (length (deps)) && reduction
178+ loopnestview = view (ls. loopsymbols, 1 : position)
179+ append! (deps, loopnestview)
180+ append! (reduceddeps, loopnestview)
181+ else
182+ loopnestview = view (ls. loopsymbols, 1 : position)
183+ newloopdeps = Symbol[]; newreduceddeps = Symbol[];
184+ setdiffv! (newloopdeps, newreduceddeps, deps, loopnestview)
185+ mergesetv! (newreduceddeps, reduceddeps)
186+ deps = newloopdeps; reduceddeps = newreduceddeps
187+ end
170188 if reduction || search_tree (parents, var)
171- add_reduction_update_parent! (parents, deps, reduceddeps, ls, var, instr, reduction, elementbytes)
189+ parent = getop (ls, var, elementbytes)
190+ setdiffv! (reduceddeps, deps, loopdependencies (parent))
191+ if length (reduceddeps) == 0
192+ push! (parents, parent)
193+ op = Operation (length (operations (ls)), var, elementbytes, instruction (ls,instr), compute, deps, reduceddeps, parents)
194+ pushop! (ls, op, var)
195+ else
196+ add_reduction_update_parent! (parents, deps, reduceddeps, ls, parent, instr, reduction, elementbytes)
197+ end
172198 else
173199 op = Operation (length (operations (ls)), var, elementbytes, instruction (ls,instr), compute, deps, reduceddeps, parents)
174200 pushop! (ls, op, var)
0 commit comments