@@ -17,9 +17,9 @@ infiniopStatus_t cpuCreateConvDescriptor(infiniopHandle_t,
                                          infiniopTensorDescriptor_t y,
                                          infiniopTensorDescriptor_t x,
                                          infiniopTensorDescriptor_t w,
-                                         void const *pads,
-                                         void const *strides,
-                                         void const *dilations,
+                                         uint64_t const *pads,
+                                         int64_t const *strides,
+                                         uint64_t const *dilations,
                                          uint64_t n) {
     uint64_t ndim = y->ndim;
     if (ndim < 3 || ndim != x->ndim || ndim != w->ndim) {
@@ -36,27 +36,39 @@ infiniopStatus_t cpuCreateConvDescriptor(infiniopHandle_t,
     }
 
     uint64_t y_size = getTotalSize(y->shape, ndim);
-    const auto pads_ = reinterpret_cast<uint64_t const *>(pads);
-    uint64_t padded_x_size = requirePadding(pads_, ndim) ? getPaddedSize(ndim, x->shape, pads_) : 0;
+    uint64_t padded_x_size = requirePadding(pads, ndim) ? getPaddedSize(ndim, x->shape, pads) : 0;
     uint64_t *x_shape = new uint64_t[ndim];
     uint64_t *w_shape = new uint64_t[ndim];
     uint64_t *y_shape = new uint64_t[ndim];
+    uint64_t *pads_ = new uint64_t[n];
+    int64_t *strides_ = new int64_t[n];
+    uint64_t *dilations_ = new uint64_t[n];
     memcpy(x_shape, x->shape, ndim * sizeof(uint64_t));
     memcpy(w_shape, w->shape, ndim * sizeof(uint64_t));
     memcpy(y_shape, y->shape, ndim * sizeof(uint64_t));
+    memcpy(pads_, pads, n * sizeof(*pads));
+    memcpy(strides_, strides, n * sizeof(*strides));
+    memcpy(dilations_, dilations, n * sizeof(*dilations));
+
+    uint64_t *padded_shape = nullptr;
+    if (padded_x_size > 0) {
+        padded_shape = new uint64_t[ndim];
+        getPaddedShape(ndim, x_shape, pads_, padded_shape);
+    }
 
     *desc_ptr = new ConvCpuDescriptor{
         DevCpu,
         y->dt,
         ndim,
         y_size,
         padded_x_size,
+        padded_shape,
         x_shape,
         w_shape,
         y_shape,
-        reinterpret_cast<uint64_t const *>(pads),
-        reinterpret_cast<int64_t const *>(strides),
-        reinterpret_cast<uint64_t const *>(dilations),
+        pads_,
+        strides_,
+        dilations_,
     };
 
     return STATUS_SUCCESS;
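
The descriptor now deep-copies `pads`, `strides`, and `dilations` into arrays it owns, rather than storing `reinterpret_cast` views of the caller's pointers. A minimal sketch of this ownership pattern, using hypothetical names (`ParamDesc`, `makeDesc`, `destroyDesc`) rather than the repository's types:

```cpp
// Sketch only: descriptor-owned deep copies of the parameter arrays.
// Struct and function names here are hypothetical.
#include <cstdint>
#include <cstring>

struct ParamDesc {
    uint64_t n;
    uint64_t *pads;      // owned copy, released in destroyDesc
    int64_t *strides;    // owned copy
    uint64_t *dilations; // owned copy
};

ParamDesc *makeDesc(uint64_t n, uint64_t const *pads,
                    int64_t const *strides, uint64_t const *dilations) {
    auto *pads_ = new uint64_t[n];
    auto *strides_ = new int64_t[n];
    auto *dilations_ = new uint64_t[n];
    std::memcpy(pads_, pads, n * sizeof(*pads));
    std::memcpy(strides_, strides, n * sizeof(*strides));
    std::memcpy(dilations_, dilations, n * sizeof(*dilations));
    return new ParamDesc{n, pads_, strides_, dilations_};
}

void destroyDesc(ParamDesc *desc) {
    delete[] desc->pads;
    delete[] desc->strides;
    delete[] desc->dilations;
    delete desc;
}
```

The copies cost a few extra allocations at creation time, but they make the descriptor independent of the caller's argument lifetimes, which is what the added `delete[]` calls in `cpuDestroyConvDescriptor` below pair with.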
@@ -71,9 +83,13 @@ infiniopStatus_t cpuGetConvWorkspaceSize(ConvCpuDescriptor_t desc, uint64_t *size
 }
 
 infiniopStatus_t cpuDestroyConvDescriptor(ConvCpuDescriptor_t desc) {
+    delete[] desc->padded_shape;
     delete[] desc->x_shape;
     delete[] desc->w_shape;
     delete[] desc->y_shape;
+    delete[] desc->pads;
+    delete[] desc->strides;
+    delete[] desc->dilations;
     delete desc;
     return STATUS_SUCCESS;
 }
@@ -121,6 +137,7 @@ void _applyConv(ConvCpuDescriptor_t desc, Ydata *y, Xdata const *x,
 
     // perform all the convolutions along this axis
     for (size_t i = 0; i < steps; ++i, ++y_index) {
+#pragma unroll
         // perform a single convolution
         for (size_t k = 0; k < kernel_size; ++k) {
             // calculate the current indices
@@ -129,7 +146,7 @@ void _applyConv(ConvCpuDescriptor_t desc, Ydata *y, Xdata const *x,
 
             // base case (last dimension)
             if (ndim == desc->ndim - 1) {
-                if (desc->dtype == F16) {
+                if constexpr (std::is_same_v<Xdata, uint16_t>) {
                     y[y_index] += f16_to_f32(x[curr_x_index]) * f16_to_f32(w[curr_w_index]);
                 } else {
                     y[y_index] += x[curr_x_index] * w[curr_w_index];
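
Replacing the runtime `desc->dtype == F16` check with `if constexpr` on the template parameter lets each instantiation compile only the branch it can actually take. A self-contained sketch of the same dispatch, where `f16_to_f32` is merely a stand-in for the library's helper and only decodes normal/inf/NaN half-precision values:

```cpp
// Sketch only: compile-time dispatch on the element type, as in the hunk above.
#include <cstdint>
#include <cstring>
#include <type_traits>

// Stand-in for the library's f16_to_f32; decodes IEEE 754 binary16 bits
// (denormals are flushed to zero to keep the sketch short).
static float f16_to_f32(uint16_t h) {
    uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
    uint32_t exp = (h >> 10) & 0x1F;
    uint32_t frac = h & 0x3FF;
    uint32_t bits;
    if (exp == 0) {
        bits = sign;                              // zero / flushed denormal
    } else if (exp == 31) {
        bits = sign | 0x7F800000u | (frac << 13); // inf / NaN
    } else {
        bits = sign | ((exp + 112) << 23) | (frac << 13);
    }
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// One multiply-accumulate step; uint16_t is treated as raw fp16 storage.
template <typename Xdata, typename Ydata>
void fmaElement(Ydata &acc, Xdata x, Xdata w) {
    if constexpr (std::is_same_v<Xdata, uint16_t>) {
        acc += f16_to_f32(x) * f16_to_f32(w);
    } else {
        acc += x * w;
    }
}
```

Because `Xdata` is known at compile time, the discarded branch is never instantiated, so non-fp16 instantiations carry neither the conversion calls nor the per-element branch.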
@@ -173,12 +190,9 @@ void _conv_cpu(ConvCpuDescriptor_t desc, void *workspace, uint64_t workspace_size,
                Ydata *y, Xdata const *x, Xdata const *w) {
     if (desc->padded_x_size > 0) {
         auto padded_x = reinterpret_cast<Xdata *>(workspace);
-        std::vector<uint64_t> padded_shape_(desc->ndim);
-        auto padded_shape = padded_shape_.data();
         std::fill(padded_x, padded_x + desc->padded_x_size, 0);
-        getPaddedShape(desc->ndim, desc->x_shape, desc->pads, padded_shape);
-        fillPaddedInput<Xdata>(desc, padded_shape, padded_x, x, desc->pads, 0, 0, 0);
-        applyConv<Xdata, Ydata>(desc, y, padded_x, w, padded_shape);
+        fillPaddedInput<Xdata>(desc, desc->padded_shape, padded_x, x, desc->pads, 0, 0, 0);
+        applyConv<Xdata, Ydata>(desc, y, padded_x, w, desc->padded_shape);
     } else {
         applyConv<Xdata, Ydata>(desc, y, x, w, desc->x_shape);
     }
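
With `padded_shape` stored in the descriptor, `_conv_cpu` no longer rebuilds the padded shape (via a per-call `std::vector` and `getPaddedShape`) on every invocation. A rough sketch of what that precomputation amounts to, assuming symmetric per-spatial-dimension padding; the repository's `getPaddedShape` may differ in detail:

```cpp
// Sketch only: padded shape and workspace element count computed once,
// under the assumption that pads[i] is applied symmetrically to spatial dim i.
#include <cstdint>
#include <vector>

std::vector<uint64_t> computePaddedShape(std::vector<uint64_t> const &x_shape,
                                         std::vector<uint64_t> const &pads) {
    std::vector<uint64_t> padded = x_shape;       // dims 0,1 = batch, channels
    for (size_t i = 2; i < x_shape.size(); ++i) { // pad spatial dims only
        padded[i] += 2 * pads[i - 2];
    }
    return padded;
}

uint64_t paddedElementCount(std::vector<uint64_t> const &shape) {
    uint64_t total = 1;
    for (auto d : shape) {
        total *= d;
    }
    return total; // number of elements the padding workspace must hold
}
```

Trading a per-call heap allocation for `ndim` extra `uint64_t` values in the descriptor is a small memory cost, and it keeps the hot path of `_conv_cpu` allocation-free apart from the caller-provided workspace.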