1+ #include " simsycl/sycl/group.hh"
2+ #include " simsycl/sycl/khr/sub_group_queries.hh"
3+ #include " simsycl/sycl/nd_item.hh"
4+
15#include < simsycl/detail/utils.hh>
26#include < simsycl/schedule.hh>
37#include < simsycl/sycl/device.hh>
@@ -181,6 +185,23 @@ void cooperative_for_nd_range(const sycl::device &device, const sycl::nd_range<D
181185 std::vector<detail::concurrent_sub_group> concurrent_sub_groups (num_concurrent_sub_groups);
182186 std::vector<detail::concurrent_nd_item> num_concurrent_nd_items (num_concurrent_items);
183187
188+ #if SIMSYCL_ENABLE_SYCL_KHR_WORK_ITEM_QUERIES
189+ std::vector<const sycl::nd_item<Dimensions> *> concurrent_khr_wi_query_nd_item_ptrs (num_concurrent_items, nullptr );
190+
191+ auto update_global_khr_wi_query_data = [&](int cc_g_idx = -1 ) {
192+ if (cc_g_idx != -1 && concurrent_khr_wi_query_nd_item_ptrs[cc_g_idx] != nullptr ) {
193+ const auto nd_item = *concurrent_khr_wi_query_nd_item_ptrs[cc_g_idx];
194+ khr::detail::g_khr_wi_query_this_nd_item<Dimensions> = nd_item;
195+ khr::detail::g_khr_wi_query_this_group<Dimensions> = nd_item.get_group ();
196+ khr::detail::g_khr_wi_query_this_sub_group = nd_item.get_sub_group ();
197+ } else {
198+ khr::detail::g_khr_wi_query_this_nd_item<Dimensions> = std::nullopt ;
199+ khr::detail::g_khr_wi_query_this_group<Dimensions> = std::nullopt ;
200+ khr::detail::g_khr_wi_query_this_sub_group = std::nullopt ;
201+ }
202+ };
203+ #endif // SIMSYCL_ENABLE_SYCL_KHR_WORK_ITEM_QUERIES
204+
184205 for (auto &cgroup : concurrent_groups) {
185206 cgroup.local_memory_allocations .resize (local_memory.size ());
186207 for (size_t i = 0 ; i < local_memory.size (); ++i) {
@@ -220,8 +241,13 @@ void cooperative_for_nd_range(const sycl::device &device, const sycl::nd_range<D
220241 group_linear_range, sub_group_linear_id_in_group, sub_group_linear_range_in_group,
221242 sub_group_max_local_linear_range, sub_group_max_local_range, thread_id_in_sub_group,
222243 sub_group_id_in_group, sub_group_range_in_group, &concurrent_nd_item, &concurrent_group,
223- &concurrent_sub_group, &kernel, &concurrent_items_exited, &caught_exceptions,
224- &range](boost::context::continuation &&scheduler) //
244+ &concurrent_sub_group, &kernel, &concurrent_items_exited, &caught_exceptions, &range
245+ #if SIMSYCL_ENABLE_SYCL_KHR_WORK_ITEM_QUERIES
246+ ,
247+ concurrent_global_idx, &concurrent_khr_wi_query_nd_item_ptrs,
248+ &update_global_khr_wi_query_data
249+ #endif
250+ ](boost::context::continuation &&scheduler) //
225251 {
226252 // yield immediately to allow the scheduling loop to set up local memory pointers
227253 enter_kernel_fiber (std::move (scheduler));
@@ -245,7 +271,8 @@ void cooperative_for_nd_range(const sycl::device &device, const sycl::nd_range<D
245271
246272 SIMSYCL_START_IGNORING_DEPRECATIONS;
247273 const auto group_id = linear_index_to_id (group_range, group_linear_id);
248- const auto global_id = range.get_offset () + (group_id * sycl::id<Dimensions>(local_range)) + local_id;
274+ const auto global_id
275+ = range.get_offset () + (group_id * sycl::id<Dimensions>(local_range)) + local_id;
249276
250277 // if sub-group range is not divisible by local range, the last sub-group will be smaller
251278 const auto sub_group_local_linear_range = std::min (sub_group_max_local_linear_range,
@@ -265,6 +292,12 @@ void cooperative_for_nd_range(const sycl::device &device, const sycl::nd_range<D
265292 const auto nd_item
266293 = detail::make_nd_item (global_item, local_item, group, sub_group, &concurrent_nd_item);
267294
295+ #if SIMSYCL_ENABLE_SYCL_KHR_WORK_ITEM_QUERIES
296+ concurrent_khr_wi_query_nd_item_ptrs[concurrent_global_idx] = &nd_item;
297+ // adjust the globals now that the data is available, before starting the kernel
298+ update_global_khr_wi_query_data (concurrent_global_idx);
299+ #endif // SIMSYCL_ENABLE_SYCL_KHR_WORK_ITEM_QUERIES
300+
268301 try {
269302 kernel (nd_item);
270303 // Add an implicit "exit" operations to groups and sub-groups to catch potential divergence on
@@ -311,11 +344,21 @@ void cooperative_for_nd_range(const sycl::device &device, const sycl::nd_range<D
311344 *local_memory[i].ptr = concurrent_groups[concurrent_group_idx].local_memory_allocations [i].get ();
312345 }
313346
347+ #if SIMSYCL_ENABLE_SYCL_KHR_WORK_ITEM_QUERIES
348+ // adjust globals before switching fibers
349+ update_global_khr_wi_query_data (concurrent_global_idx);
350+ #endif // SIMSYCL_ENABLE_SYCL_KHR_WORK_ITEM_QUERIES
351+
314352 fibers[concurrent_global_idx] = fibers[concurrent_global_idx].resume ();
315353 }
316354 schedule_state = schedule.update (schedule_state, order);
317355 }
318356
357+ #if SIMSYCL_ENABLE_SYCL_KHR_WORK_ITEM_QUERIES
358+ // reset globals
359+ update_global_khr_wi_query_data ();
360+ #endif // SIMSYCL_ENABLE_SYCL_KHR_WORK_ITEM_QUERIES
361+
319362 // rethrow any encountered exceptions
320363 for (auto &exception : caught_exceptions) { std::rethrow_exception (exception); }
321364}
0 commit comments