@@ -54,10 +54,17 @@ template <typename T1, typename T2> T1 ceiling_quotient(T1 n, T2 m)
5454 return ceiling_quotient<T1>(n, static_cast <T1>(m));
5555}
5656
57- template <typename inputT, typename outputT, typename IndexerT, size_t n_wi>
57+ template <typename inputT,
58+ typename outputT,
59+ size_t n_wi,
60+ typename IndexerT,
61+ typename TransformerT>
5862class inclusive_scan_rec_local_scan_krn ;
5963
60- template <typename inputT, typename outputT, typename IndexerT>
64+ template <typename inputT,
65+ typename outputT,
66+ typename IndexerT,
67+ typename TransformerT>
6168class inclusive_scan_rec_chunk_update_krn ;
6269
6370struct NoOpIndexer
@@ -136,29 +143,40 @@ struct Strided1DCyclicIndexer
136143 py::ssize_t step = 1 ;
137144};
138145
139- template <typename _IndexerFn > struct ZeroChecker
/**
 * @brief Unary functor mapping a value to a 0/1 indicator of being non-zero.
 *
 * Returns outputT(0) when the argument compares equal to inputT(0), and
 * outputT(1) otherwise.
 *
 * @tparam inputT  Type of the inspected value; must be constructible from 0
 *                 in a constant expression and comparable with operator==.
 * @tparam outputT Type of the produced indicator; must be constructible
 *                 from 0 and from 1 in a constant expression.
 */
template <typename inputT, typename outputT> struct NonZeroIndicator
{
    // Defaulted rather than user-provided: the functor carries no state, so
    // it stays trivially default-constructible and usable in constant
    // expressions.
    constexpr NonZeroIndicator() = default;

    /**
     * @param val Value to test against zero.
     * @return outputT(0) if val == inputT(0), otherwise outputT(1).
     */
    constexpr outputT operator()(const inputT &val) const
    {
        constexpr outputT out_one(1);
        constexpr outputT out_zero(0);
        constexpr inputT val_zero(0);

        return (val == val_zero) ? out_zero : out_one;
    }
};
151159
152- private:
153- _IndexerFn indexer_fn;
/**
 * @brief Identity functor: returns a copy of its argument unchanged.
 *
 * @tparam T Type of the value passed through; must be copy-constructible.
 */
template <typename T> struct NoOpTransformer
{
    // Defaulted rather than user-provided: the functor carries no state, so
    // it stays trivially default-constructible and usable in constant
    // expressions.
    constexpr NoOpTransformer() = default;

    /**
     * @param val Value to pass through.
     * @return A copy of val.
     */
    constexpr T operator()(const T &val) const
    {
        return val;
    }
};
155169
156170/*
157171 * for input element type inputT and output type outputT,
158172 * output[j] = sum( transformer(input[indexer(s0 + i * s1)]), 0 <= i <= j)
159173 * for 0 <= j < n_elems
160174 */
161- template <typename inputT, typename outputT, typename IndexerT, size_t n_wi>
175+ template <typename inputT,
176+ typename outputT,
177+ size_t n_wi,
178+ typename IndexerT,
179+ typename TransformerT>
162180sycl::event inclusive_scan_rec (sycl::queue exec_q,
163181 size_t n_elems,
164182 size_t wg_size,
@@ -167,6 +185,7 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
167185 size_t s0,
168186 size_t s1,
169187 IndexerT indexer,
188+ TransformerT transformer,
170189 std::vector<sycl::event> const &depends = {})
171190{
172191 size_t n_groups = ceiling_quotient (n_elems, n_wi * wg_size);
@@ -181,9 +200,7 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
181200
182201 slmT slm_iscan_tmp (lws, cgh);
183202
184- ZeroChecker<IndexerT> is_zero_fn (indexer);
185-
186- cgh.parallel_for <class inclusive_scan_rec_local_scan_krn <inputT, outputT, ZeroChecker<IndexerT>, n_wi>>(
203+ cgh.parallel_for <class inclusive_scan_rec_local_scan_krn <inputT, outputT, n_wi, IndexerT, decltype (transformer)>>(
187204 sycl::nd_range<1 >(gws, lws),
188205 [=](sycl::nd_item<1 > it)
189206 {
@@ -195,11 +212,10 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
195212 size_t i = chunk_gid * n_wi;
196213 for (size_t m_wi = 0 ; m_wi < n_wi; ++m_wi) {
197214 constexpr outputT out_zero (0 );
198- constexpr outputT out_one ( 1 );
215+
199216 local_isum[m_wi] =
200217 (i + m_wi < n_elems)
201- ? (is_zero_fn (input, s0 + s1 * (i + m_wi)) ? out_zero
202- : out_one)
218+ ? transformer (input[indexer (s0 + s1 * (i + m_wi))])
203219 : out_zero;
204220 }
205221
@@ -240,14 +256,17 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
240256 auto chunk_size = wg_size * n_wi;
241257
242258 NoOpIndexer _no_op_indexer{};
243- auto e2 = inclusive_scan_rec<outputT, outputT, NoOpIndexer, n_wi>(
259+ NoOpTransformer<outputT> _no_op_transformer{};
260+ auto e2 = inclusive_scan_rec<outputT, outputT, n_wi, NoOpIndexer,
261+ decltype (_no_op_transformer)>(
244262 exec_q, n_groups - 1 , wg_size, output, temp, chunk_size - 1 ,
245- chunk_size, _no_op_indexer, {inc_scan_phase1_ev});
263+ chunk_size, _no_op_indexer, _no_op_transformer,
264+ {inc_scan_phase1_ev});
246265
247266 // output[ chunk_size * (i + 1) + j] += temp[i]
248267 auto e3 = exec_q.submit ([&](sycl::handler &cgh) {
249268 cgh.depends_on (e2 );
250- cgh.parallel_for <class inclusive_scan_rec_chunk_update_krn <inputT, outputT, IndexerT>>(
269+ cgh.parallel_for <class inclusive_scan_rec_chunk_update_krn <inputT, outputT, IndexerT, decltype (transformer) >>(
251270 {n_elems},
252271 [=](auto wiid)
253272 {
@@ -258,14 +277,13 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
258277 );
259278 });
260279
261- // dangling task to free the temporary
262- exec_q.submit ([&](sycl::handler &cgh) {
280+ sycl::event e4 = exec_q.submit ([&](sycl::handler &cgh) {
263281 cgh.depends_on (e3 );
264282 auto ctx = exec_q.get_context ();
265283 cgh.host_task ([ctx, temp]() { sycl::free (temp, ctx); });
266284 });
267285
268- out_event = e3 ;
286+ out_event = e4 ;
269287 }
270288
271289 return out_event;
@@ -502,10 +520,13 @@ size_t mask_positions_contig_impl(sycl::queue q,
502520 size_t wg_size = 128 ;
503521
504522 NoOpIndexer flat_indexer{};
523+ NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};
505524
506- sycl::event comp_ev = inclusive_scan_rec<maskT, cumsumT, NoOpIndexer, n_wi>(
507- q, n_elems, wg_size, mask_data_ptr, cumsum_data_ptr, 0 , 1 , flat_indexer,
508- depends);
525+ sycl::event comp_ev =
526+ inclusive_scan_rec<maskT, cumsumT, n_wi, decltype (flat_indexer),
527+ decltype (non_zero_indicator)>(
528+ q, n_elems, wg_size, mask_data_ptr, cumsum_data_ptr, 0 , 1 ,
529+ flat_indexer, non_zero_indicator, depends);
509530
510531 cumsumT *last_elem = cumsum_data_ptr + (n_elems - 1 );
511532
@@ -558,11 +579,13 @@ size_t mask_positions_strided_impl(sycl::queue q,
558579 size_t wg_size = 128 ;
559580
560581 StridedIndexer strided_indexer{nd, input_offset, shape_strides};
582+ NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};
561583
562584 sycl::event comp_ev =
563- inclusive_scan_rec<maskT, cumsumT, StridedIndexer, n_wi>(
585+ inclusive_scan_rec<maskT, cumsumT, n_wi, decltype (strided_indexer),
586+ decltype (non_zero_indicator)>(
564587 q, n_elems, wg_size, mask_data_ptr, cumsum_data_ptr, 0 , 1 ,
565- strided_indexer, depends);
588+ strided_indexer, non_zero_indicator, depends);
566589
567590 cumsumT *last_elem = cumsum_data_ptr + (n_elems - 1 );
568591
0 commit comments