diff --git a/torchsparse/backend/convolution/convolution_backward_wgrad_implicit_gemm_cuda.cu b/torchsparse/backend/convolution/convolution_backward_wgrad_implicit_gemm_cuda.cu
index 3804704..d2d3123 100644
--- a/torchsparse/backend/convolution/convolution_backward_wgrad_implicit_gemm_cuda.cu
+++ b/torchsparse/backend/convolution/convolution_backward_wgrad_implicit_gemm_cuda.cu
@@ -1,4 +1,5 @@
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include "convolution_backward_wgrad_implicit_gemm_cuda.h"
 #include "../utils/memory.cuh"
 #include
@@ -1616,6 +1617,7 @@ at::Tensor conv_backward_wgrad_implicit_gemm_cuda(
     torch::Tensor _out_in_map, const int split_k_iters, bool allow_tf32,
     bool allow_fp16) {
+  c10::cuda::CUDAGuard guard(_in_feats.device());
   bool is_tf = allow_tf32;
   int num_in_feats = _in_feats.size(0);
   int num_in_channels = _in_feats.size(1);
diff --git a/torchsparse/backend/convolution/convolution_backward_wgrad_implicit_gemm_sorted_cuda.cu b/torchsparse/backend/convolution/convolution_backward_wgrad_implicit_gemm_sorted_cuda.cu
index d357bc0..5bf2abb 100644
--- a/torchsparse/backend/convolution/convolution_backward_wgrad_implicit_gemm_sorted_cuda.cu
+++ b/torchsparse/backend/convolution/convolution_backward_wgrad_implicit_gemm_sorted_cuda.cu
@@ -1,4 +1,5 @@
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include "convolution_backward_wgrad_implicit_gemm_sorted_cuda.h"
 #include "../utils/memory.cuh"
 #include
@@ -1747,6 +1748,7 @@ at::Tensor conv_backward_wgrad_implicit_gemm_sorted_cuda(
     torch::Tensor _reorder_loc, const int split_k_iters, bool allow_tf32,
     bool allow_fp16) {
+  c10::cuda::CUDAGuard guard(_in_feats.device());
   bool is_tf = allow_tf32;
   int num_in_feats = _in_feats.size(0);
   int num_in_channels = _in_feats.size(1);
diff --git a/torchsparse/backend/convolution/convolution_forward_fetch_on_demand_cuda.cu b/torchsparse/backend/convolution/convolution_forward_fetch_on_demand_cuda.cu
index 05ea14e..c54347b 100644
--- a/torchsparse/backend/convolution/convolution_forward_fetch_on_demand_cuda.cu
+++ b/torchsparse/backend/convolution/convolution_forward_fetch_on_demand_cuda.cu
@@ -10,6 +10,7 @@ Please consider citing the following paper when using the code:
 */
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include
 #include
 #include
@@ -1979,12 +1980,12 @@ __global__ void fetch_on_demand_gemm_no_fusion_fp16(
 // with unused weights having 0 and neighbor_offset[k^3/2]
 // holding w[0,0].
 at::Tensor conv_forward_fetch_on_demand_cuda(
-    at::Tensor in_feat, at::Tensor kernel,
-    at::Tensor neighbor_map, const int sum_nnz,
+    at::Tensor in_feat, at::Tensor kernel,
+    at::Tensor neighbor_map, const int sum_nnz,
     at::Tensor neighbor_address, at::Tensor q_neighbor_address,
-    const int output_size, const int qsum_nnz, const bool transpose,
+    const int output_size, const int qsum_nnz, const bool transpose,
     const bool allow_tf32, const bool allow_fp16) {
-
+  c10::cuda::CUDAGuard guard(in_feat.device());
   // int sum_nnz = (int)torch::sum(neighbor_offset).item();
   int input_size = in_feat.size(0);
   int in_channel = in_feat.size(1);
@@ -2135,10 +2136,10 @@ at::Tensor conv_forward_fetch_on_demand_cuda(
 at::Tensor conv_forward_fetch_on_demand_no_fusion_cuda(
     at::Tensor in_feat, at::Tensor kernel,
-    at::Tensor neighbor_map, at::Tensor neighbor_offset,
-    const int sum_nnz, const int output_size, const bool transpose,
+    at::Tensor neighbor_map, at::Tensor neighbor_offset,
+    const int sum_nnz, const int output_size, const bool transpose,
     const bool allow_tf32, const bool allow_fp16){
-
+  c10::cuda::CUDAGuard guard(in_feat.device());
   // int sum_nnz = (int)torch::sum(neighbor_offset).item();
   int input_size = in_feat.size(0);
   int in_channel = in_feat.size(1);
diff --git a/torchsparse/backend/convolution/convolution_forward_implicit_gemm_cuda.cu b/torchsparse/backend/convolution/convolution_forward_implicit_gemm_cuda.cu
index 65cc8b8..3b2ca03 100644
--- a/torchsparse/backend/convolution/convolution_forward_implicit_gemm_cuda.cu
+++ b/torchsparse/backend/convolution/convolution_forward_implicit_gemm_cuda.cu
@@ -1,4 +1,5 @@
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include "convolution_forward_implicit_gemm_cuda.h"
 #include "../utils/memory.cuh"
 #include
@@ -1531,6 +1532,7 @@ at::Tensor conv_forward_implicit_gemm_cuda(
     int num_out_feats, int num_out_channels,
     bool allow_tf32, bool allow_fp16) {
+  c10::cuda::CUDAGuard guard(_in_feats.device());
   bool is_tf = allow_tf32;
   int num_in_feats = _in_feats.size(0);
   int num_in_channels = _in_feats.size(1);
diff --git a/torchsparse/backend/convolution/convolution_forward_implicit_gemm_sorted_cuda.cu b/torchsparse/backend/convolution/convolution_forward_implicit_gemm_sorted_cuda.cu
index a9f61f2..bbf60a0 100644
--- a/torchsparse/backend/convolution/convolution_forward_implicit_gemm_sorted_cuda.cu
+++ b/torchsparse/backend/convolution/convolution_forward_implicit_gemm_sorted_cuda.cu
@@ -1,4 +1,5 @@
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include "convolution_forward_implicit_gemm_sorted_cuda.h"
 #include "../utils/memory.cuh"
 #include
@@ -1759,6 +1760,7 @@ at::Tensor conv_forward_implicit_gemm_sorted_cuda(
     int num_out_feats, int num_out_channels,
     bool allow_tf32, bool allow_fp16) {
+  c10::cuda::CUDAGuard guard(_in_feats.device());
   bool is_tf = allow_tf32;
   int num_in_feats = _in_feats.size(0);
   int num_in_channels = _in_feats.size(1);
diff --git a/torchsparse/backend/convolution/convolution_gather_scatter_cuda.cu b/torchsparse/backend/convolution/convolution_gather_scatter_cuda.cu
index ec342b8..b1d63cd 100644
--- a/torchsparse/backend/convolution/convolution_gather_scatter_cuda.cu
+++ b/torchsparse/backend/convolution/convolution_gather_scatter_cuda.cu
@@ -1,4 +1,5 @@
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include
 #include
 #include
@@ -281,6 +282,7 @@ at::Tensor conv_forward_gather_scatter_cuda(
     at::Tensor neighbor_offset, at::Tensor input_mask, at::Tensor output_mask,
     const int output_size, const float epsilon, const int mm_thresh,
     const int conv_mode, const bool transpose, at::Tensor global_buffer) {
+  c10::cuda::CUDAGuard guard(in_feat.device());
   int buffer_size = (int)torch::sum(neighbor_offset).item<int>();
   // be careful about the fallback setting
@@ -412,6 +414,7 @@ at::Tensor conv_forward_gather_scatter_cuda_latest(
     at::Tensor neighbor_offset, at::Tensor input_mask, at::Tensor output_mask,
     const int output_size, const float epsilon, const int mm_thresh,
     const int conv_mode, const bool transpose, at::Tensor global_buffer) {
+  c10::cuda::CUDAGuard guard(in_feat.device());
   if (in_feat.size(1) != _kernel.size(1)) {
     throw std::invalid_argument("Input feature size and kernel size mismatch");
   }
@@ -682,6 +685,7 @@ at::Tensor conv_forward_gather_scatter_cuda_fallback(
     at::Tensor in_feat, at::Tensor kernel, at::Tensor neighbor_map,
     const int output_size, const int conv_mode, at::Tensor neighbor_offset,
     const bool transpose) {
+  c10::cuda::CUDAGuard guard(in_feat.device());
   if (in_feat.size(1) != kernel.size(1)) {
     throw std::invalid_argument("Input feature size and kernel size mismatch");
   }
@@ -817,6 +821,7 @@ void conv_backward_gather_scatter_cuda(at::Tensor in_feat, at::Tensor grad_in_fe
                                        at::Tensor grad_kernel, at::Tensor neighbor_map,
                                        at::Tensor neighbor_offset, const bool transpose) {
+  c10::cuda::CUDAGuard guard(in_feat.device());
   grad_in_feat.resize_as_(in_feat);
   grad_in_feat.zero_();
   grad_kernel.resize_as_(kernel);
diff --git a/torchsparse/backend/devoxelize/devoxelize_cuda.cu b/torchsparse/backend/devoxelize/devoxelize_cuda.cu
index 62d6492..40004f6 100644
--- a/torchsparse/backend/devoxelize/devoxelize_cuda.cu
+++ b/torchsparse/backend/devoxelize/devoxelize_cuda.cu
@@ -3,6 +3,7 @@
 #include
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include
 
 // input features (n, c), indices (N, 8), weight (N, 8) -> output features (N,
@@ -61,6 +62,7 @@ __global__ void devoxelize_backward_kernel(
 at::Tensor devoxelize_forward_cuda(const at::Tensor feat,
                                    const at::Tensor indices,
                                    const at::Tensor weight) {
+  c10::cuda::CUDAGuard guard(feat.device());
   int c = feat.size(1);
   int N = indices.size(0);
@@ -82,6 +84,7 @@ at::Tensor devoxelize_forward_cuda(const at::Tensor feat,
 at::Tensor devoxelize_backward_cuda(const at::Tensor top_grad,
                                     const at::Tensor indices,
                                     const at::Tensor weight, int n) {
+  c10::cuda::CUDAGuard guard(top_grad.device());
   int c = top_grad.size(1);
   int N = top_grad.size(0);
   at::Tensor bottom_grad = torch::zeros(
diff --git a/torchsparse/backend/hash/hash_cuda.cu b/torchsparse/backend/hash/hash_cuda.cu
index f846239..0b55828 100644
--- a/torchsparse/backend/hash/hash_cuda.cu
+++ b/torchsparse/backend/hash/hash_cuda.cu
@@ -2,6 +2,7 @@
 #include
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include
 #include
 
 // hashing
@@ -64,6 +65,7 @@ void hash_wrapper(int N, const int *data, int64_t *out) {
 }
 
 at::Tensor hash_cuda(const at::Tensor idx) {
+  c10::cuda::CUDAGuard guard(idx.device());
   int N = idx.size(0);
   at::Tensor out =
       torch::zeros({N}, at::device(idx.device()).dtype(at::ScalarType::Long));
@@ -73,6 +75,7 @@ at::Tensor hash_cuda(const at::Tensor idx) {
 at::Tensor kernel_hash_cuda(const at::Tensor idx,
                             const at::Tensor kernel_offset) {
+  c10::cuda::CUDAGuard guard(idx.device());
   int N = idx.size(0);
   int K = kernel_offset.size(0);
   at::Tensor out = torch::zeros(
diff --git a/torchsparse/backend/others/count_cuda.cu b/torchsparse/backend/others/count_cuda.cu
index 4860422..5b57c2b 100644
--- a/torchsparse/backend/others/count_cuda.cu
+++ b/torchsparse/backend/others/count_cuda.cu
@@ -2,6 +2,7 @@
 #include
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include
 #include
 
@@ -23,6 +24,7 @@ void count_wrapper(int N, const int *data, int *out) {
 // feat: (b,c,n) indices: (b,n) -> out: (b,c,s), out_indices: (b,n)
 // (preprocessed indices)
 at::Tensor count_cuda(const at::Tensor idx, const int s) {
+  c10::cuda::CUDAGuard guard(idx.device());
   int N = idx.size(0);
   at::Tensor out =
       torch::zeros({s}, at::device(idx.device()).dtype(at::ScalarType::Int));
diff --git a/torchsparse/backend/others/downsample_cuda.cu b/torchsparse/backend/others/downsample_cuda.cu
index e3850e4..4a251eb 100644
--- a/torchsparse/backend/others/downsample_cuda.cu
+++ b/torchsparse/backend/others/downsample_cuda.cu
@@ -1,6 +1,7 @@
 #include
 
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include
 #include
 #include
@@ -147,7 +148,7 @@ Idea: launch get_output_coords_kernel then inverse_transform_coords_kernel
 at::Tensor downsample_cuda(at::Tensor _in_coords, at::Tensor _coords_max,
                            at::Tensor _coords_min, at::Tensor _kernel_sizes,
                            at::Tensor _stride, at::Tensor _padding) {
-
+  c10::cuda::CUDAGuard guard(_in_coords.device());
   int N = _in_coords.size(0);
   int kernel_volume = (int)(torch::prod(_kernel_sizes).item<int>());
   int *in_coords = _in_coords.data_ptr<int>();
diff --git a/torchsparse/backend/others/exclusive_scan_cuda.cu b/torchsparse/backend/others/exclusive_scan_cuda.cu
index 37cd186..7440087 100644
--- a/torchsparse/backend/others/exclusive_scan_cuda.cu
+++ b/torchsparse/backend/others/exclusive_scan_cuda.cu
@@ -1,6 +1,7 @@
 #include
 #include
+#include <c10/cuda/CUDAGuard.h>
 
 #include "exclusive_scan_cuda.h"
 
 // to derive quantified address of activated features
@@ -29,9 +30,9 @@ __global__ void exclusive_scan_for_kernel_quantified(
 }
 
 at::Tensor exclusive_scan_quantified_wrapper(
-    const int k_vol, at::Tensor neighbor_offset,
+    const int k_vol, at::Tensor neighbor_offset,
     at::Tensor neighbor_address, at::Tensor q_neighbor_address){
-
+  c10::cuda::CUDAGuard guard(neighbor_offset.device());
   int *knnz_ptr = neighbor_offset.data_ptr<int>();
   int *kpos_ptr = neighbor_address.data_ptr<int>();
   int *qkpos_ptr = q_neighbor_address.data_ptr<int>();
diff --git a/torchsparse/backend/others/query_cuda.cu b/torchsparse/backend/others/query_cuda.cu
index 2eac8b7..02967dd 100644
--- a/torchsparse/backend/others/query_cuda.cu
+++ b/torchsparse/backend/others/query_cuda.cu
@@ -1,5 +1,6 @@
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include
 #include
 #include
@@ -33,6 +34,7 @@ __global__ void derive_bit_mask_from_out_in_map_kernel(int* out_in_map, int* bit
 at::Tensor hash_query_cuda(const at::Tensor hash_query,
                            const at::Tensor hash_target,
                            const at::Tensor idx_target) {
+  c10::cuda::CUDAGuard guard(hash_query.device());
   // return group_point_forward_gpu(points, indices);
   int n = hash_target.size(0);
   int n1 = hash_query.size(0);
@@ -49,6 +51,7 @@ at::Tensor hash_query_cuda(const at::Tensor hash_query,
 void convert_transposed_out_in_map(const at::Tensor out_in_map, at::Tensor out_in_map_t) {
+  c10::cuda::CUDAGuard guard(out_in_map.device());
   convert_out_in_map_kernel<<<(out_in_map.size(0) * out_in_map.size(1) + 255) / 256, 256>>>(
       out_in_map.data_ptr<int>(), out_in_map_t.data_ptr<int>(), out_in_map.size(0), out_in_map.size(1));
 }
@@ -57,6 +60,7 @@ void convert_transposed_out_in_map(const at::Tensor out_in_map,
 at::Tensor derive_bitmask_from_out_in_map(const at::Tensor out_in_map, const int split_mask_num, int valid_n) {
+  c10::cuda::CUDAGuard guard(out_in_map.device());
   at::Tensor bitmask = torch::full(
       {split_mask_num, out_in_map.size(0)}, -1,
       at::device(out_in_map.device()).dtype(at::ScalarType::Int));
   derive_bit_mask_from_out_in_map_kernel<<<(split_mask_num * out_in_map.size(0) + 255) / 256, 256>>>(
diff --git a/torchsparse/backend/others/reduce_bitmask_cuda.cu b/torchsparse/backend/others/reduce_bitmask_cuda.cu
index d792b69..f757a0a 100644
--- a/torchsparse/backend/others/reduce_bitmask_cuda.cu
+++ b/torchsparse/backend/others/reduce_bitmask_cuda.cu
@@ -1,4 +1,5 @@
 #include
+#include <c10/cuda/CUDAGuard.h>
 
 #include "reduce_bitmask_cuda.h"
 
@@ -61,6 +62,7 @@ torch::Tensor reduce_bitmask_cuda(
     torch::Tensor _bitmask_int,
     int M_tile
 ){
+  c10::cuda::CUDAGuard guard(_bitmask_int.device());
   if (M_tile % 4 != 0)
   {
     throw std::runtime_error("[Bitmask reduce] reduce tile size must be multiple of 4.");
diff --git a/torchsparse/backend/others/reorder_map_cuda.cu b/torchsparse/backend/others/reorder_map_cuda.cu
index 7f7b2ea..c8bc6af 100644
--- a/torchsparse/backend/others/reorder_map_cuda.cu
+++ b/torchsparse/backend/others/reorder_map_cuda.cu
@@ -1,4 +1,5 @@
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include "reorder_map_cuda.h"
 
 #define cta_M 128
@@ -26,7 +27,7 @@ at::Tensor reorder_out_in_map_cuda(
     torch::Tensor _out_in_map,
     torch::Tensor _reorder_loc
 ){
-
+  c10::cuda::CUDAGuard guard(_out_in_map.device());
   int M = _out_in_map.size(0);
   int kernel_volume = _out_in_map.size(1);
   int split_mask_num = _reorder_loc.size(0);
diff --git a/torchsparse/backend/others/sparsemapping_cuda.cu b/torchsparse/backend/others/sparsemapping_cuda.cu
index 49792d4..46dffc4 100644
--- a/torchsparse/backend/others/sparsemapping_cuda.cu
+++ b/torchsparse/backend/others/sparsemapping_cuda.cu
@@ -1,6 +1,7 @@
 #include
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include
 #include
 #include
@@ -261,6 +262,7 @@ std::vector<at::Tensor> build_kernel_map_subm_hashmap_int32(
     at::Tensor _in_coords, at::Tensor _coords_min, at::Tensor _coords_max,
     at::Tensor _kernel_sizes, at::Tensor _stride, at::Tensor _padding,
     bool to_insert) {
+  c10::cuda::CUDAGuard guard(_in_coords.device());
   int n_points = _in_coords.size(0);
   int kernel_volume = (int)(torch::prod(_kernel_sizes).item<int>());
   int *in_coords = _in_coords.data_ptr<int>();
@@ -305,6 +307,7 @@ std::vector<at::Tensor> build_kernel_map_subm_hashmap(
     at::Tensor _in_coords, at::Tensor _coords_min, at::Tensor _coords_max,
     at::Tensor _kernel_sizes, at::Tensor _stride, at::Tensor _padding,
     bool to_insert) {
+  c10::cuda::CUDAGuard guard(_in_coords.device());
   int n_points = _in_coords.size(0);
   int kernel_volume = (int)(torch::prod(_kernel_sizes).item<int>());
   int *in_coords = _in_coords.data_ptr<int>();
@@ -349,6 +352,7 @@ std::vector<at::Tensor> build_kernel_map_downsample_hashmap_int32(
     at::Tensor _in_coords, at::Tensor _coords_min, at::Tensor _coords_max,
     at::Tensor _kernel_sizes, at::Tensor _stride, at::Tensor _padding,
     bool to_insert) {
+  c10::cuda::CUDAGuard guard(_in_coords.device());
   int n_points = _in_coords.size(0);
   int kernel_volume = (int)(torch::prod(_kernel_sizes).item<int>());
   int *in_coords = _in_coords.data_ptr<int>();
@@ -431,6 +435,7 @@ std::vector<at::Tensor> build_kernel_map_downsample_hashmap(
     at::Tensor _in_coords, at::Tensor _coords_min, at::Tensor _coords_max,
     at::Tensor _kernel_sizes, at::Tensor _stride, at::Tensor _padding,
     bool to_insert) {
+  c10::cuda::CUDAGuard guard(_in_coords.device());
   int n_points = _in_coords.size(0);
   int kernel_volume = (int)(torch::prod(_kernel_sizes).item<int>());
   int *in_coords = _in_coords.data_ptr<int>();
@@ -510,6 +515,7 @@ std::vector<at::Tensor> build_kernel_map_downsample_hashmap(
 std::vector<at::Tensor> build_mask_from_kmap(int n_points, int n_out_points,
                                              at::Tensor _kmap,
                                              at::Tensor _kmap_sizes) {
+  c10::cuda::CUDAGuard guard(_kmap.device());
   int kernel_volume = _kmap_sizes.size(0);
   auto options =
       torch::TensorOptions().dtype(at::ScalarType::Int).device(_kmap.device());
diff --git a/torchsparse/backend/voxelize/voxelize_cuda.cu b/torchsparse/backend/voxelize/voxelize_cuda.cu
index 0be3fff..3f2a47a 100644
--- a/torchsparse/backend/voxelize/voxelize_cuda.cu
+++ b/torchsparse/backend/voxelize/voxelize_cuda.cu
@@ -2,6 +2,7 @@
 #include
 #include
+#include <c10/cuda/CUDAGuard.h>
 #include
 #include
 
@@ -78,6 +79,7 @@ __global__ void voxelize_backward_kernel(int N, int c, int s,
 at::Tensor voxelize_forward_cuda(const at::Tensor inputs, const at::Tensor idx,
                                  const at::Tensor counts) {
+  c10::cuda::CUDAGuard guard(inputs.device());
   int N = inputs.size(0);
   int c = inputs.size(1);
   int N1 = counts.size(0);
@@ -98,6 +100,7 @@ at::Tensor voxelize_backward_cuda(const at::Tensor top_grad,
                                   const at::Tensor idx,
                                   const at::Tensor counts, const int N) {
+  c10::cuda::CUDAGuard guard(top_grad.device());
   int c = top_grad.size(1);
   int N1 = counts.size(0);
@@ -116,6 +119,7 @@ at::Tensor voxelize_backward_cuda(const at::Tensor top_grad,
 void to_dense_forward_cuda(const at::Tensor inputs, const at::Tensor idx,
                            const at::Tensor range, at::Tensor outputs) {
+  c10::cuda::CUDAGuard guard(inputs.device());
   int N = inputs.size(0);
   int c = inputs.size(1);
@@ -130,6 +134,7 @@ void to_dense_backward_cuda(const at::Tensor top_grad, const at::Tensor idx,
                             const at::Tensor range,
                             const at::Tensor bottom_grad) {
+  c10::cuda::CUDAGuard guard(top_grad.device());
   int N = bottom_grad.size(0);
   int c = bottom_grad.size(1);
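
Every hunk above applies the same one-line fix: a c10::cuda::CUDAGuard is constructed from the input tensor's device at the top of each CUDA entry point, so that the allocations and kernel launches that follow target that tensor's GPU rather than whatever device happens to be current on the calling thread. Below is a minimal sketch of the pattern on a standalone example; scale_cuda and scale_kernel are hypothetical illustrations, not part of this diff or of torchsparse.

// Hypothetical example (not from this diff): the device-guard pattern
// applied to a trivial element-wise kernel.
#include <c10/cuda/CUDAGuard.h>
#include <torch/torch.h>

__global__ void scale_kernel(const float *in, float *out, float alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * alpha;
}

at::Tensor scale_cuda(at::Tensor input, float alpha) {
  // Make input's GPU the current device for this scope; the previously
  // current device is restored when `guard` is destroyed. Without this,
  // the allocation and launch below would go to the thread's current
  // device, which can differ from input.device() in multi-GPU programs.
  c10::cuda::CUDAGuard guard(input.device());
  at::Tensor in = input.contiguous();
  at::Tensor out = torch::empty_like(in);  // allocated on input's device
  int n = static_cast<int>(in.numel());
  scale_kernel<<<(n + 255) / 256, 256>>>(
      in.data_ptr<float>(), out.data_ptr<float>(), alpha, n);
  return out;
}

This mirrors the convention the diff establishes throughout the backend: guard first, then size queries, allocations, and kernel launches.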