Skip to content

Commit e0e2ca5

Browse files
committed
Removed device::combinatorial_kalman_filter_algorithm::propagate_to_next_surface_kernel_payload.
1 parent f14b528 commit e0e2ca5

13 files changed

Lines changed: 128 additions & 185 deletions

device/alpaka/include/traccc/alpaka/finding/combinatorial_kalman_filter_algorithm.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,16 @@ class combinatorial_kalman_filter_algorithm
132132
/// Launch the @c propagate_to_next_surface kernel
133133
///
134134
/// @param n_threads The number of threads to launch the kernel with
135+
/// @param config The track finding configuration
136+
/// @param det The detector object
137+
/// @param bfield The magnetic field object
135138
/// @param payload The payload for the kernel
136139
///
137140
void propagate_to_next_surface_kernel(
138-
unsigned int n_threads,
139-
const propagate_to_next_surface_kernel_payload& payload) const override;
141+
unsigned int n_threads, const finding_config& config,
142+
const detector_buffer& det, const magnetic_field& bfield,
143+
const device::propagate_to_next_surface_payload& payload)
144+
const override;
140145

141146
/// Launch the @c gather_best_tips_per_measurement kernel
142147
///

device/alpaka/src/finding/combinatorial_kalman_filter_algorithm.cpp

Lines changed: 19 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,14 @@ struct propagate_to_next_surface {
131131
template <typename TAcc>
132132
ALPAKA_FN_ACC void operator()(
133133
TAcc const& acc, const finding_config& cfg,
134-
const device::propagate_to_next_surface_payload<propagator_t, bfield_t>*
135-
payload) const {
134+
const typename propagator_t::detector_type::const_view_type* det_data,
135+
const bfield_t& field_data,
136+
const device::propagate_to_next_surface_payload& payload) const {
136137

137138
device::global_index_t globalThreadIdx =
138139
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
139140
device::propagate_to_next_surface<propagator_t, bfield_t>(
140-
globalThreadIdx, cfg, *payload);
141+
globalThreadIdx, cfg, *det_data, field_data, payload);
141142
}
142143
};
143144

@@ -379,8 +380,9 @@ void combinatorial_kalman_filter_algorithm::sort_param_ids_by_keys(
379380
}
380381

381382
void combinatorial_kalman_filter_algorithm::propagate_to_next_surface_kernel(
382-
unsigned int n_threads,
383-
const propagate_to_next_surface_kernel_payload& payload) const {
383+
unsigned int n_threads, const finding_config& config,
384+
const detector_buffer& detector, const magnetic_field& field,
385+
const device::propagate_to_next_surface_payload& payload) const {
384386

385387
// Establish the kernel launch parameters.
386388
const unsigned int deviceThreads = warp_size() * 2;
@@ -390,47 +392,26 @@ void combinatorial_kalman_filter_algorithm::propagate_to_next_surface_kernel(
390392
// Launch the kernel for the appropriate detector and magnetic field type.
391393
detector_buffer_magnetic_field_visitor<detector_type_list,
392394
alpaka::bfield_type_list<scalar>>(
393-
payload.det, payload.field,
395+
detector, field,
394396
[&]<typename detector_traits_t, typename bfield_view_t>(
395-
const typename detector_traits_t::view& detector,
397+
const typename detector_traits_t::view& det,
396398
const bfield_view_t& bfield) {
397-
// Propagator type to use.
398-
using propagator_t = traccc::details::ckf_propagator_t<
399-
typename detector_traits_t::device, bfield_view_t>;
400-
// Allocate the kernel's payload in host memory.
401-
using payload_t =
402-
device::propagate_to_next_surface_payload<propagator_t,
403-
bfield_view_t>;
404-
const payload_t host_payload{
405-
.det_data = detector,
406-
.field_data = bfield,
407-
.params_view = payload.params,
408-
.params_liveness_view = payload.params_liveness,
409-
.param_ids_view = payload.param_ids,
410-
.links_view = payload.links,
411-
.prev_links_idx = payload.prev_links_idx,
412-
.step = payload.step,
413-
.n_in_params = n_threads,
414-
.tips_view = payload.tips,
415-
.tip_lengths_view = payload.tip_lengths,
416-
.tmp_jacobian_ptr = payload.tmp_jacobian.ptr()};
417-
// Now copy it to device memory.
418-
vecmem::data::vector_buffer<payload_t> device_payload(1u,
419-
mr().main);
420-
copy().setup(device_payload)->ignore();
421-
copy()(
422-
vecmem::data::vector_view<const payload_t>(1u, &host_payload),
423-
device_payload)
424-
->ignore();
399+
// Copy the detector data to device memory.
400+
vecmem::data::vector_buffer<typename detector_traits_t::view>
401+
device_det(1u, mr().main);
402+
copy().setup(device_det)->ignore();
403+
copy()({1u, &det}, device_det)->ignore();
425404

426405
// Launch the kernel to propagate all active tracks to the next
427406
// surface.
428407
::alpaka::exec<Acc>(
429408
details::get_queue(queue()),
430409
makeWorkDiv<Acc>(deviceBlocks, deviceThreads),
431-
kernels::propagate_to_next_surface<propagator_t,
432-
bfield_view_t>{},
433-
payload.config, device_payload.ptr());
410+
kernels::propagate_to_next_surface<
411+
traccc::details::ckf_propagator_t<
412+
typename detector_traits_t::device, bfield_view_t>,
413+
bfield_view_t>{},
414+
config, device_det.ptr(), bfield, payload);
434415
});
435416
}
436417

device/common/include/traccc/finding/device/combinatorial_kalman_filter_algorithm.hpp

Lines changed: 7 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "traccc/finding/device/find_tracks_payload.hpp"
1818
#include "traccc/finding/device/gather_best_tips_per_measurement.hpp"
1919
#include "traccc/finding/device/gather_measurement_votes.hpp"
20+
#include "traccc/finding/device/propagate_to_next_surface.hpp"
2021
#include "traccc/finding/device/remove_duplicates.hpp"
2122
#include "traccc/finding/device/update_tip_length_buffer.hpp"
2223

@@ -185,42 +186,18 @@ class combinatorial_kalman_filter_algorithm
185186
vecmem::data::vector_view<device::sort_key>& keys,
186187
vecmem::data::vector_view<unsigned int>& param_ids) const = 0;
187188

188-
/// Payload for the @c propagate_to_next_surface_kernel function
189-
struct propagate_to_next_surface_kernel_payload {
190-
/// The track finding configuration
191-
const finding_config& config;
192-
/// The detector object
193-
const detector_buffer& det;
194-
/// The magnetic field object
195-
const magnetic_field& field;
196-
/// The vector of track parameters
197-
bound_track_parameters_collection_types::view& params;
198-
/// The vector of track parameter liveness values
199-
vecmem::data::vector_view<unsigned int>& params_liveness;
200-
/// Sorted parameter identifiers
201-
const vecmem::data::vector_view<const unsigned int>& param_ids;
202-
/// The vector of candidate links
203-
const vecmem::data::vector_view<const candidate_link>& links;
204-
/// Index in the link vector at which the current step starts
205-
unsigned int prev_links_idx;
206-
/// Current CKF step number
207-
unsigned int step;
208-
/// The vector of tips
209-
vecmem::data::vector_view<unsigned int>& tips;
210-
/// The number of track states per tip
211-
vecmem::data::vector_view<unsigned int>& tip_lengths;
212-
/// The temporary Jacobian buffer
213-
vecmem::data::vector_view<bound_matrix<default_algebra>>& tmp_jacobian;
214-
};
215-
216189
/// Launch the @c propagate_to_next_surface kernel
217190
///
218191
/// @param n_threads The number of threads to launch the kernel with
192+
/// @param config The track finding configuration
193+
/// @param det The detector object
194+
/// @param bfield The magnetic field object
219195
/// @param payload The payload for the kernel
220196
///
221197
virtual void propagate_to_next_surface_kernel(
222-
unsigned int n_threads,
223-
const propagate_to_next_surface_kernel_payload& payload) const = 0;
198+
unsigned int n_threads, const finding_config& config,
199+
const detector_buffer& det, const magnetic_field& bfield,
200+
const device::propagate_to_next_surface_payload& payload) const = 0;
224201

225202
/// Launch the @c gather_best_tips_per_measurement kernel
226203
///

device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/** TRACCC library, part of the ACTS project (R&D line)
22
*
3-
* (c) 2023-2025 CERN for the benefit of the ACTS project
3+
* (c) 2023-2026 CERN for the benefit of the ACTS project
44
*
55
* Mozilla Public License Version 2.0
66
*/
@@ -21,7 +21,9 @@ namespace traccc::device {
2121
template <typename propagator_t, typename bfield_t>
2222
TRACCC_HOST_DEVICE inline void propagate_to_next_surface(
2323
const global_index_t globalIndex, const finding_config& cfg,
24-
const propagate_to_next_surface_payload<propagator_t, bfield_t>& payload) {
24+
const typename propagator_t::detector_type::const_view_type& det_data,
25+
const bfield_t& field_data,
26+
const propagate_to_next_surface_payload& payload) {
2527

2628
using scalar_t = propagator_t::detector_type::scalar_type;
2729

@@ -51,7 +53,7 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface(
5153
vecmem::device_vector<unsigned int> tip_lengths(payload.tip_lengths_view);
5254

5355
// Detector
54-
typename propagator_t::detector_type det(payload.det_data);
56+
typename propagator_t::detector_type det(det_data);
5557

5658
// Parameters
5759
bound_track_parameters_collection_types::device params(payload.params_view);
@@ -69,7 +71,7 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface(
6971
propagator_t propagator(prop_cfg);
7072

7173
// Create propagator state
72-
typename propagator_t::state propagation(in_par, payload.field_data, det);
74+
typename propagator_t::state propagation(in_par, field_data, det);
7375
propagation.set_particle(
7476
detail::correct_particle_hypothesis(cfg.ptc_hypothesis, in_par));
7577
propagation._stepping
@@ -107,11 +109,13 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface(
107109
* is set to the multiplicative identity.
108110
*/
109111
if (cfg.run_mbf_smoother) {
110-
assert(payload.tmp_jacobian_ptr != nullptr);
112+
assert(payload.tmp_jacobian_view.ptr() != nullptr);
111113

112-
payload.tmp_jacobian_ptr[param_id] = matrix::identity<
113-
bound_matrix<typename propagator_t::detector_type::algebra_type>>();
114-
s1._full_jacobian_ptr = &payload.tmp_jacobian_ptr[param_id];
114+
vecmem::device_vector<bound_matrix<default_algebra>> tmp_jacobian(
115+
payload.tmp_jacobian_view);
116+
tmp_jacobian.at(param_id) =
117+
matrix::identity<bound_matrix<default_algebra>>();
118+
s1._full_jacobian_ptr = &(tmp_jacobian.at(param_id));
115119
}
116120

117121
s5.min_pT(static_cast<scalar_t>(cfg.min_pT));

device/common/include/traccc/finding/device/propagate_to_next_surface.hpp

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/** TRACCC library, part of the ACTS project (R&D line)
22
*
3-
* (c) 2023-2025 CERN for the benefit of the ACTS project
3+
* (c) 2023-2026 CERN for the benefit of the ACTS project
44
*
55
* Mozilla Public License Version 2.0
66
*/
@@ -23,18 +23,7 @@ namespace traccc::device {
2323

2424
/// (Event Data) Payload for the @c traccc::device::propagate_to_next_surface
2525
/// function
26-
template <typename propagator_t, typename bfield_t>
2726
struct propagate_to_next_surface_payload {
28-
/**
29-
* @brief View object to the tracking detector description
30-
*/
31-
typename propagator_t::detector_type::const_view_type det_data;
32-
33-
/**
34-
* @brief View object to the magnetic field
35-
*/
36-
bfield_t field_data;
37-
3827
/**
3928
* @brief View object to the vector of track parameters
4029
*/
@@ -80,8 +69,7 @@ struct propagate_to_next_surface_payload {
8069
*/
8170
vecmem::data::vector_view<unsigned int> tip_lengths_view;
8271

83-
bound_matrix<typename propagator_t::detector_type::algebra_type>*
84-
tmp_jacobian_ptr;
72+
vecmem::data::vector_view<bound_matrix<default_algebra> > tmp_jacobian_view;
8573
};
8674

8775
/// Function for propagating the kalman-updated tracks to the next surface
@@ -93,12 +81,17 @@ struct propagate_to_next_surface_payload {
9381
///
9482
/// @param[in] globalIndex The index of the current thread
9583
/// @param[in] cfg Track finding config object
84+
/// @param[in] det_data View object to the tracking detector
85+
/// description
86+
/// @param[in] field_data View object to the magnetic field
9687
/// @param[inout] payload The function call payload
9788
///
9889
template <typename propagator_t, typename bfield_t>
9990
TRACCC_HOST_DEVICE inline void propagate_to_next_surface(
10091
global_index_t globalIndex, const finding_config& cfg,
101-
const propagate_to_next_surface_payload<propagator_t, bfield_t>& payload);
92+
const typename propagator_t::detector_type::const_view_type& det_data,
93+
const bfield_t& field_data,
94+
const propagate_to_next_surface_payload& payload);
10295

10396
} // namespace traccc::device
10497

device/common/src/finding/combinatorial_kalman_filter_algorithm.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -379,18 +379,17 @@ auto combinatorial_kalman_filter_algorithm::operator()(
379379
}
380380

381381
propagate_to_next_surface_kernel(
382-
n_candidates, {.config = m_data->m_config,
383-
.det = det,
384-
.field = bfield,
385-
.params = in_params_buffer,
386-
.params_liveness = param_liveness_buffer,
387-
.param_ids = param_ids_buffer,
388-
.links = links_buffer,
389-
.prev_links_idx = step_to_link_idx_map[step],
390-
.step = step,
391-
.tips = tips_buffer,
392-
.tip_lengths = tip_length_buffer,
393-
.tmp_jacobian = tmp_jacobian_buffer});
382+
n_candidates, m_data->m_config, det, bfield,
383+
{.params_view = in_params_buffer,
384+
.params_liveness_view = param_liveness_buffer,
385+
.param_ids_view = param_ids_buffer,
386+
.links_view = links_buffer,
387+
.prev_links_idx = step_to_link_idx_map[step],
388+
.step = step,
389+
.n_in_params = n_candidates,
390+
.tips_view = tips_buffer,
391+
.tip_lengths_view = tip_length_buffer,
392+
.tmp_jacobian_view = tmp_jacobian_buffer});
394393
}
395394

396395
n_in_params = n_candidates;

device/cuda/include/traccc/cuda/finding/combinatorial_kalman_filter_algorithm.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,16 @@ class combinatorial_kalman_filter_algorithm
132132
/// Launch the @c propagate_to_next_surface kernel
133133
///
134134
/// @param n_threads The number of threads to launch the kernel with
135+
/// @param config The track finding configuration
136+
/// @param det The detector object
137+
/// @param bfield The magnetic field object
135138
/// @param payload The payload for the kernel
136139
///
137140
void propagate_to_next_surface_kernel(
138-
unsigned int n_threads,
139-
const propagate_to_next_surface_kernel_payload& payload) const override;
141+
unsigned int n_threads, const finding_config& config,
142+
const detector_buffer& det, const magnetic_field& bfield,
143+
const device::propagate_to_next_surface_payload& payload)
144+
const override;
140145

141146
/// Launch the @c gather_best_tips_per_measurement kernel
142147
///

device/cuda/src/finding/combinatorial_kalman_filter_algorithm.cu

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,9 @@ void combinatorial_kalman_filter_algorithm::sort_param_ids_by_keys(
214214
}
215215
216216
void combinatorial_kalman_filter_algorithm::propagate_to_next_surface_kernel(
217-
unsigned int n_threads,
218-
const propagate_to_next_surface_kernel_payload& payload) const {
217+
unsigned int n_threads, const finding_config& config,
218+
const detector_buffer& detector, const magnetic_field& field,
219+
const device::propagate_to_next_surface_payload& payload) const {
219220
220221
// Establish the kernel launch parameters.
221222
const unsigned int deviceThreads = warp_size() * 2;
@@ -225,32 +226,16 @@ void combinatorial_kalman_filter_algorithm::propagate_to_next_surface_kernel(
225226
// Launch the kernel for the appropriate detector and magnetic field type.
226227
detector_buffer_magnetic_field_visitor<detector_type_list,
227228
cuda::bfield_type_list<scalar>>(
228-
payload.det, payload.field,
229+
detector, field,
229230
[&]<typename detector_traits_t, typename bfield_view_t>(
230-
const typename detector_traits_t::view& detector,
231+
const typename detector_traits_t::view& det,
231232
const bfield_view_t& bfield) {
232233
propagate_to_next_surface<
233234
traccc::details::ckf_propagator_t<
234235
typename detector_traits_t::device, bfield_view_t>,
235-
bfield_view_t>(
236-
deviceBlocks, deviceThreads, 0u, details::get_stream(stream()),
237-
payload.config,
238-
device::propagate_to_next_surface_payload<
239-
traccc::details::ckf_propagator_t<
240-
typename detector_traits_t::device, bfield_view_t>,
241-
bfield_view_t>{
242-
.det_data = detector,
243-
.field_data = bfield,
244-
.params_view = payload.params,
245-
.params_liveness_view = payload.params_liveness,
246-
.param_ids_view = payload.param_ids,
247-
.links_view = payload.links,
248-
.prev_links_idx = payload.prev_links_idx,
249-
.step = payload.step,
250-
.n_in_params = n_threads,
251-
.tips_view = payload.tips,
252-
.tip_lengths_view = payload.tip_lengths,
253-
.tmp_jacobian_ptr = payload.tmp_jacobian.ptr()});
236+
bfield_view_t>(deviceBlocks, deviceThreads, 0u,
237+
details::get_stream(stream()), config, det,
238+
bfield, payload);
254239
});
255240
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
256241
}

device/cuda/src/finding/kernels/propagate_to_next_surface.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/** TRACCC library, part of the ACTS project (R&D line)
22
*
3-
* (c) 2023-2025 CERN for the benefit of the ACTS project
3+
* (c) 2023-2026 CERN for the benefit of the ACTS project
44
*
55
* Mozilla Public License Version 2.0
66
*/
@@ -19,7 +19,9 @@ namespace traccc::cuda {
1919
template <typename propagator_t, typename bfield_t>
2020
void propagate_to_next_surface(
2121
const dim3& grid_size, const dim3& block_size, std::size_t shared_mem_size,
22-
const cudaStream_t& stream, const finding_config cfg,
23-
device::propagate_to_next_surface_payload<propagator_t, bfield_t> payload);
22+
const cudaStream_t& stream, const finding_config& cfg,
23+
const typename propagator_t::detector_type::const_view_type& det_data,
24+
const bfield_t& field_data,
25+
const device::propagate_to_next_surface_payload& payload);
2426

2527
} // namespace traccc::cuda

0 commit comments

Comments
 (0)