@@ -109,8 +109,8 @@ void inclusive_sum(
109109 stream);
110110}
111111
112- void random_sample_workspace (size_t &size_radix_sort, size_t &size_scan,
113- int voc, DT dtype) {
112+ infiniopStatus_t random_sample_workspace (size_t &size_radix_sort, size_t &size_scan,
113+ int voc, DT dtype) {
114114 if (dtype_eq (dtype, F16)) {
115115 sort_pairs_descending<half, uint64_t >(nullptr , size_radix_sort,
116116 nullptr , nullptr ,
@@ -121,6 +121,7 @@ void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
121121 nullptr , size_scan,
122122 nullptr , voc,
123123 nullptr );
124+ return STATUS_SUCCESS;
124125 } else if (dtype_eq (dtype, F32)) {
125126 sort_pairs_descending<float , uint64_t >(nullptr , size_radix_sort,
126127 nullptr , nullptr ,
@@ -131,6 +132,7 @@ void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
131132 nullptr , size_scan,
132133 nullptr , voc,
133134 nullptr );
135+ return STATUS_SUCCESS;
134136 } else if (dtype_eq (dtype, F64)) {
135137 sort_pairs_descending<double , uint64_t >(nullptr , size_radix_sort,
136138 nullptr , nullptr ,
@@ -141,8 +143,9 @@ void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
141143 nullptr , size_scan,
142144 nullptr , voc,
143145 nullptr );
146+ return STATUS_SUCCESS;
144147 } else {
145- throw std::invalid_argument ( " Unsupported dtype provided. " ) ;
148+ return STATUS_BAD_TENSOR_DTYPE ;
146149 }
147150}
148151__global__ void random_sample_kernel (uint64_t *result,
0 commit comments