11module NDIteration
22
3+ import Base. MultiplicativeInverses: SignedMultiplicativeInverse
4+
5+ # CartesianIndex uses Int instead of Int32
6+
7+ @eval EmptySMI () = $ (Expr (:new , SignedMultiplicativeInverse{Int32}, Int32 (0 ), typemax (Int32), 0 % Int8, 0 % UInt8))
8+ SMI (i) = i == 0 ? EmptySMI () : SignedMultiplicativeInverse {Int32} (i)
9+
10+ struct FastCartesianIndices{N} <: AbstractArray{CartesianIndex{N}, N}
11+ inverses:: NTuple{N, SignedMultiplicativeInverse{Int32}}
12+ end
13+
14+ function FastCartesianIndices (indices:: NTuple{N} ) where {N}
15+ inverses = map (i -> SMI (Int32 (i)), indices)
16+ FastCartesianIndices (inverses)
17+ end
18+
19+ function Base. size (FCI:: FastCartesianIndices{N} ) where {N}
20+ ntuple (Val (N)) do I
21+ FCI. inverses[I]. divisor
22+ end
23+ end
24+
25+ @inline function Base. getindex (:: FastCartesianIndices{0} )
26+ return CartesianIndex ()
27+ end
28+
29+ @inline function Base. getindex (iter:: FastCartesianIndices{N} , I:: Vararg{Int, N} ) where {N}
30+ @boundscheck checkbounds (iter, I... )
31+ index = map (iter. inverses, I) do inv, i
32+ @inbounds getindex (Base. OneTo (inv. divisor), i)
33+ end
34+ CartesianIndex (index)
35+ end
36+
37+ _ind2sub_recuse (:: Tuple{} , ind) = (ind + 1 ,)
38+ function _ind2sub_recurse (indslast:: NTuple{1} , ind)
39+ Base. @_inline_meta
40+ (_lookup (ind, indslast[1 ]),)
41+ end
42+
43+ function _ind2sub_recurse (inds, ind)
44+ Base. @_inline_meta
45+ inv = inds[1 ]
46+ indnext, f, l = _div (ind, inv)
47+ (ind - l * indnext + f, _ind2sub_recurse (Base. tail (inds), indnext)... )
48+ end
49+
50+ _lookup (ind, inv:: SignedMultiplicativeInverse ) = ind + 1
51+ function _div (ind, inv:: SignedMultiplicativeInverse )
52+ inv. divisor == 0 && throw (DivideError ())
53+ div (ind % Int32, inv), 1 , inv. divisor
54+ end
55+
56+ function Base. _ind2sub (inv:: FastCartesianIndices , ind)
57+ Base. @_inline_meta
58+ _ind2sub_recurse (inv. inverses, ind - 1 )
59+ end
60+
361export _Size, StaticSize, DynamicSize, get
462export NDRange, blocks, workitems, expand
563export DynamicCheck, NoDynamicCheck
@@ -50,18 +108,32 @@ struct NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems}
50108 blocks:: DynamicBlock
51109 workitems:: DynamicWorkitems
52110
53- function NDRange {N, B, W} () where {N, B, W}
54- new {N, B, W, Nothing, Nothing} (nothing , nothing )
55- end
56-
57- function NDRange {N, B, W} (blocks, workitems) where {N, B, W}
111+ function NDRange {N, B, W} (blocks:: Union{Nothing, FastCartesianIndices{N}} , workitems:: Union{Nothing, FastCartesianIndices{N}} ) where {N, B, W}
112+ @assert B <: _Size
113+ @assert W <: _Size
58114 new {N, B, W, typeof(blocks), typeof(workitems)} (blocks, workitems)
59115 end
60116end
61117
62- @inline workitems (range:: NDRange{N, B, W} ) where {N, B, W <: DynamicSize } = range. workitems:: CartesianIndices{N}
118+ function NDRange {N, B, W} () where {N, B, W}
119+ NDRange {N, B, W} (nothing , nothing )
120+ end
121+
122+ function NDRange {N, B, W} (blocks:: CartesianIndices , workitems:: CartesianIndices ) where {N, B, W}
123+ return NDRange {N, B, W} (FastCartesianIndices (size (blocks)), FastCartesianIndices (size (workitems)))
124+ end
125+
126+ function NDRange {N, B, W} (blocks:: Nothing , workitems:: CartesianIndices ) where {N, B, W}
127+ return NDRange {N, B, W} (blocks, FastCartesianIndices (size (workitems)))
128+ end
129+
130+ function NDRange {N, B, W} (blocks:: CartesianIndices , workitems:: Nothing ) where {N, B, W}
131+ return NDRange {N, B, W} (FastCartesianIndices (size (blocks)), workitems)
132+ end
133+
134+ @inline workitems (range:: NDRange{N, B, W} ) where {N, B, W <: DynamicSize } = range. workitems:: FastCartesianIndices{N}
63135@inline workitems (range:: NDRange{N, B, W} ) where {N, B, W <: StaticSize } = CartesianIndices (get (W)):: CartesianIndices{N}
64- @inline blocks (range:: NDRange{N, B} ) where {N, B <: DynamicSize } = range. blocks:: CartesianIndices {N}
136+ @inline blocks (range:: NDRange{N, B} ) where {N, B <: DynamicSize } = range. blocks:: FastCartesianIndices {N}
65137@inline blocks (range:: NDRange{N, B} ) where {N, B <: StaticSize } = CartesianIndices (get (B)):: CartesianIndices{N}
66138
67139import Base. iterate
@@ -80,8 +152,8 @@ Base.length(range::NDRange) = length(blocks(range))
80152 CartesianIndex (nI)
81153end
82154
83- Base. @propagate_inbounds function expand (ndrange:: NDRange , groupidx:: Integer , idx:: Integer )
84- expand (ndrange, blocks (ndrange)[groupidx], workitems (ndrange)[idx])
155+ Base. @propagate_inbounds function expand (ndrange:: NDRange{N} , groupidx:: Integer , idx:: Integer ) where {N}
156+ return expand (ndrange, blocks (ndrange)[groupidx], workitems (ndrange)[idx])
85157end
86158
87159Base. @propagate_inbounds function expand (ndrange:: NDRange{N} , groupidx:: CartesianIndex{N} , idx:: Integer ) where {N}
0 commit comments