Skip to content

Commit 6f14719

Browse files
committed
modified workspace return
1 parent 1809565 commit 6f14719

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

src/ops/random_sample/cuda/random_sample.cu

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ void inclusive_sum(
109109
stream);
110110
}
111111

112-
void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
113-
int voc, DT dtype) {
112+
infiniopStatus_t random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
113+
int voc, DT dtype) {
114114
if (dtype_eq(dtype, F16)) {
115115
sort_pairs_descending<half, uint64_t>(nullptr, size_radix_sort,
116116
nullptr, nullptr,
@@ -121,6 +121,7 @@ void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
121121
nullptr, size_scan,
122122
nullptr, voc,
123123
nullptr);
124+
return STATUS_SUCCESS;
124125
} else if (dtype_eq(dtype, F32)) {
125126
sort_pairs_descending<float, uint64_t>(nullptr, size_radix_sort,
126127
nullptr, nullptr,
@@ -131,6 +132,7 @@ void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
131132
nullptr, size_scan,
132133
nullptr, voc,
133134
nullptr);
135+
return STATUS_SUCCESS;
134136
} else if (dtype_eq(dtype, F64)) {
135137
sort_pairs_descending<double, uint64_t>(nullptr, size_radix_sort,
136138
nullptr, nullptr,
@@ -141,8 +143,9 @@ void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
141143
nullptr, size_scan,
142144
nullptr, voc,
143145
nullptr);
146+
return STATUS_SUCCESS;
144147
} else {
145-
throw std::invalid_argument("Unsupported dtype provided.");
148+
return STATUS_BAD_TENSOR_DTYPE;
146149
}
147150
}
148151
__global__ void random_sample_kernel(uint64_t *result,

src/ops/random_sample/cuda/random_sample.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ typedef struct RandomSampleCudaDescriptor *RandomSampleCudaDescriptor_t;
1919
infiniopStatus_t cudaCreateRandomSampleDescriptor(CudaHandle_t handle,
2020
RandomSampleCudaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result,
2121
infiniopTensorDescriptor_t probs);
22-
void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
23-
int voc, DT dtype);
22+
infiniopStatus_t random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
23+
int voc, DT dtype);
2424
infiniopStatus_t cudaGetRandomSampleWorkspaceSize(RandomSampleCudaDescriptor_t desc, unsigned long int *size);
2525

2626
infiniopStatus_t cudaRandomSample(RandomSampleCudaDescriptor_t desc,

src/ops/random_sample/cuda/random_sample_cuda.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,11 @@ infiniopStatus_t cudaCreateRandomSampleDescriptor(CudaHandle_t handle,
3333
infiniopStatus_t cudaGetRandomSampleWorkspaceSize(RandomSampleCudaDescriptor_t desc, unsigned long int *size) {
3434
size_t size_radix_sort;
3535
size_t size_scan;
36-
random_sample_workspace(size_radix_sort, size_scan,
37-
desc->voc, desc->dtype);
36+
infiniopStatus_t status = random_sample_workspace(size_radix_sort, size_scan,
37+
desc->voc, desc->dtype);
38+
if (status != STATUS_SUCCESS) {
39+
return status;
40+
}
3841
*size = desc->step + std::max(size_radix_sort, size_scan);
3942
return STATUS_SUCCESS;
4043
}

0 commit comments

Comments
 (0)