Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include "convolution_backward_wgrad_implicit_gemm_cuda.h"
#include "../utils/memory.cuh"
#include <cuda_fp16.h>
Expand Down Expand Up @@ -1616,6 +1617,7 @@ at::Tensor conv_backward_wgrad_implicit_gemm_cuda(
torch::Tensor _out_in_map, const int split_k_iters,
bool allow_tf32, bool allow_fp16)
{
c10::cuda::CUDAGuard guard(_in_feats.device());
bool is_tf = allow_tf32;
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include "convolution_backward_wgrad_implicit_gemm_sorted_cuda.h"
#include "../utils/memory.cuh"
#include <cuda_fp16.h>
Expand Down Expand Up @@ -1747,6 +1748,7 @@ at::Tensor conv_backward_wgrad_implicit_gemm_sorted_cuda(
torch::Tensor _reorder_loc, const int split_k_iters,
bool allow_tf32, bool allow_fp16)
{
c10::cuda::CUDAGuard guard(_in_feats.device());
bool is_tf = allow_tf32;
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Please consider citing the following paper when using the code:
*/

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
Expand Down Expand Up @@ -1979,12 +1980,12 @@ __global__ void fetch_on_demand_gemm_no_fusion_fp16(
// with unused weights having 0 and neighbor_offset[k^3/2]
// holding w[0,0].
at::Tensor conv_forward_fetch_on_demand_cuda(
at::Tensor in_feat, at::Tensor kernel,
at::Tensor neighbor_map, const int sum_nnz,
at::Tensor in_feat, at::Tensor kernel,
at::Tensor neighbor_map, const int sum_nnz,
at::Tensor neighbor_address, at::Tensor q_neighbor_address,
const int output_size, const int qsum_nnz, const bool transpose,
const int output_size, const int qsum_nnz, const bool transpose,
const bool allow_tf32, const bool allow_fp16) {

c10::cuda::CUDAGuard guard(in_feat.device());
// int sum_nnz = (int)torch::sum(neighbor_offset).item<int>();
int input_size = in_feat.size(0);
int in_channel = in_feat.size(1);
Expand Down Expand Up @@ -2135,10 +2136,10 @@ at::Tensor conv_forward_fetch_on_demand_cuda(

at::Tensor conv_forward_fetch_on_demand_no_fusion_cuda(
at::Tensor in_feat, at::Tensor kernel,
at::Tensor neighbor_map, at::Tensor neighbor_offset,
const int sum_nnz, const int output_size, const bool transpose,
at::Tensor neighbor_map, at::Tensor neighbor_offset,
const int sum_nnz, const int output_size, const bool transpose,
const bool allow_tf32, const bool allow_fp16){

c10::cuda::CUDAGuard guard(in_feat.device());
// int sum_nnz = (int)torch::sum(neighbor_offset).item<int>();
int input_size = in_feat.size(0);
int in_channel = in_feat.size(1);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include "convolution_forward_implicit_gemm_cuda.h"
#include "../utils/memory.cuh"
#include <cuda_fp16.h>
Expand Down Expand Up @@ -1531,6 +1532,7 @@ at::Tensor conv_forward_implicit_gemm_cuda(
int num_out_feats, int num_out_channels,
bool allow_tf32, bool allow_fp16)
{
c10::cuda::CUDAGuard guard(_in_feats.device());
bool is_tf = allow_tf32;
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include "convolution_forward_implicit_gemm_sorted_cuda.h"
#include "../utils/memory.cuh"
#include <cuda_fp16.h>
Expand Down Expand Up @@ -1759,6 +1760,7 @@ at::Tensor conv_forward_implicit_gemm_sorted_cuda(
int num_out_feats, int num_out_channels,
bool allow_tf32, bool allow_fp16)
{
c10::cuda::CUDAGuard guard(_in_feats.device());
bool is_tf = allow_tf32;
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
Expand Down Expand Up @@ -281,6 +282,7 @@ at::Tensor conv_forward_gather_scatter_cuda(
at::Tensor neighbor_offset, at::Tensor input_mask, at::Tensor output_mask,
const int output_size, const float epsilon, const int mm_thresh,
const int conv_mode, const bool transpose, at::Tensor global_buffer) {
c10::cuda::CUDAGuard guard(in_feat.device());
int buffer_size = (int)torch::sum(neighbor_offset).item<int>();
// be careful about the fallback setting

Expand Down Expand Up @@ -412,6 +414,7 @@ at::Tensor conv_forward_gather_scatter_cuda_latest(
at::Tensor neighbor_offset, at::Tensor input_mask, at::Tensor output_mask,
const int output_size, const float epsilon, const int mm_thresh,
const int conv_mode, const bool transpose, at::Tensor global_buffer) {
c10::cuda::CUDAGuard guard(in_feat.device());
if (in_feat.size(1) != _kernel.size(1)) {
throw std::invalid_argument("Input feature size and kernel size mismatch");
}
Expand Down Expand Up @@ -682,6 +685,7 @@ at::Tensor conv_forward_gather_scatter_cuda_fallback(
at::Tensor in_feat, at::Tensor kernel, at::Tensor neighbor_map,
const int output_size, const int conv_mode, at::Tensor neighbor_offset,
const bool transpose) {
c10::cuda::CUDAGuard guard(in_feat.device());
if (in_feat.size(1) != kernel.size(1)) {
throw std::invalid_argument("Input feature size and kernel size mismatch");
}
Expand Down Expand Up @@ -817,6 +821,7 @@ void conv_backward_gather_scatter_cuda(at::Tensor in_feat, at::Tensor grad_in_fe
at::Tensor grad_kernel, at::Tensor neighbor_map,
at::Tensor neighbor_offset,
const bool transpose) {
c10::cuda::CUDAGuard guard(in_feat.device());
grad_in_feat.resize_as_(in_feat);
grad_in_feat.zero_();
grad_kernel.resize_as_(kernel);
Expand Down
3 changes: 3 additions & 0 deletions torchsparse/backend/devoxelize/devoxelize_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <thrust/device_vector.h>
#include <torch/extension.h>

#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>

// input features (n, c), indices (N, 8), weight (N, 8) -> output features (N,
Expand Down Expand Up @@ -61,6 +62,7 @@ __global__ void devoxelize_backward_kernel(
at::Tensor devoxelize_forward_cuda(const at::Tensor feat,
const at::Tensor indices,
const at::Tensor weight) {
c10::cuda::CUDAGuard guard(feat.device());
int c = feat.size(1);
int N = indices.size(0);

Expand All @@ -82,6 +84,7 @@ at::Tensor devoxelize_forward_cuda(const at::Tensor feat,
at::Tensor devoxelize_backward_cuda(const at::Tensor top_grad,
const at::Tensor indices,
const at::Tensor weight, int n) {
c10::cuda::CUDAGuard guard(top_grad.device());
int c = top_grad.size(1);
int N = top_grad.size(0);
at::Tensor bottom_grad = torch::zeros(
Expand Down
3 changes: 3 additions & 0 deletions torchsparse/backend/hash/hash_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <stdlib.h>
#include <torch/torch.h>

#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include <vector>
// hashing
Expand Down Expand Up @@ -64,6 +65,7 @@ void hash_wrapper(int N, const int *data, int64_t *out) {
}

at::Tensor hash_cuda(const at::Tensor idx) {
c10::cuda::CUDAGuard guard(idx.device());
int N = idx.size(0);
at::Tensor out =
torch::zeros({N}, at::device(idx.device()).dtype(at::ScalarType::Long));
Expand All @@ -73,6 +75,7 @@ at::Tensor hash_cuda(const at::Tensor idx) {

at::Tensor kernel_hash_cuda(const at::Tensor idx,
const at::Tensor kernel_offset) {
c10::cuda::CUDAGuard guard(idx.device());
int N = idx.size(0);
int K = kernel_offset.size(0);
at::Tensor out = torch::zeros(
Expand Down
2 changes: 2 additions & 0 deletions torchsparse/backend/others/count_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <stdlib.h>
#include <torch/torch.h>

#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include <vector>

Expand All @@ -23,6 +24,7 @@ void count_wrapper(int N, const int *data, int *out) {
// feat: (b,c,n) indices: (b,n) -> out: (b,c,s), out_indices: (b,n)
// (preprocessed indices)
at::Tensor count_cuda(const at::Tensor idx, const int s) {
c10::cuda::CUDAGuard guard(idx.device());
int N = idx.size(0);
at::Tensor out =
torch::zeros({s}, at::device(idx.device()).dtype(at::ScalarType::Int));
Expand Down
3 changes: 2 additions & 1 deletion torchsparse/backend/others/downsample_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <torch/torch.h>

#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <cstdio>
#include <vector>
Expand Down Expand Up @@ -147,7 +148,7 @@ Idea: launch get_output_coords_kernel then inverse_transform_coords_kernel
at::Tensor downsample_cuda(at::Tensor _in_coords, at::Tensor _coords_max,
at::Tensor _coords_min, at::Tensor _kernel_sizes,
at::Tensor _stride, at::Tensor _padding) {

c10::cuda::CUDAGuard guard(_in_coords.device());
int N = _in_coords.size(0);
int kernel_volume = (int)(torch::prod(_kernel_sizes).item<int>());
int *in_coords = _in_coords.data_ptr<int>();
Expand Down
5 changes: 3 additions & 2 deletions torchsparse/backend/others/exclusive_scan_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <torch/torch.h>
#include <torch/extension.h>

#include <c10/cuda/CUDAGuard.h>
#include "exclusive_scan_cuda.h"

// to derive quantified address of activated features
Expand Down Expand Up @@ -29,9 +30,9 @@ __global__ void exclusive_scan_for_kernel_quantified(
}

at::Tensor exclusive_scan_quantified_wrapper(
const int k_vol, at::Tensor neighbor_offset,
const int k_vol, at::Tensor neighbor_offset,
at::Tensor neighbor_address, at::Tensor q_neighbor_address){

c10::cuda::CUDAGuard guard(neighbor_offset.device());
int *knnz_ptr = neighbor_offset.data_ptr<int>();
int *kpos_ptr = neighbor_address.data_ptr<int>();
int *qkpos_ptr = q_neighbor_address.data_ptr<int>();
Expand Down
4 changes: 4 additions & 0 deletions torchsparse/backend/others/query_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <torch/torch.h>

#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include <iostream>
#include <vector>
Expand Down Expand Up @@ -33,6 +34,7 @@ __global__ void derive_bit_mask_from_out_in_map_kernel(int* out_in_map, int* bit
at::Tensor hash_query_cuda(const at::Tensor hash_query,
const at::Tensor hash_target,
const at::Tensor idx_target) {
c10::cuda::CUDAGuard guard(hash_query.device());
// return group_point_forward_gpu(points, indices);
int n = hash_target.size(0);
int n1 = hash_query.size(0);
Expand All @@ -49,6 +51,7 @@ at::Tensor hash_query_cuda(const at::Tensor hash_query,

void convert_transposed_out_in_map(const at::Tensor out_in_map,
at::Tensor out_in_map_t) {
c10::cuda::CUDAGuard guard(out_in_map.device());
convert_out_in_map_kernel<<<(out_in_map.size(0) * out_in_map.size(1) + 255) / 256, 256>>>(
out_in_map.data_ptr<int>(), out_in_map_t.data_ptr<int>(), out_in_map.size(0), out_in_map.size(1));
}
Expand All @@ -57,6 +60,7 @@ void convert_transposed_out_in_map(const at::Tensor out_in_map,


at::Tensor derive_bitmask_from_out_in_map(const at::Tensor out_in_map, const int split_mask_num, int valid_n) {
c10::cuda::CUDAGuard guard(out_in_map.device());
at::Tensor bitmask = torch::full(
{split_mask_num, out_in_map.size(0)}, -1, at::device(out_in_map.device()).dtype(at::ScalarType::Int));
derive_bit_mask_from_out_in_map_kernel<<<(split_mask_num * out_in_map.size(0) + 255) / 256, 256>>>(
Expand Down
2 changes: 2 additions & 0 deletions torchsparse/backend/others/reduce_bitmask_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include "reduce_bitmask_cuda.h"


Expand Down Expand Up @@ -61,6 +62,7 @@ torch::Tensor reduce_bitmask_cuda(
torch::Tensor _bitmask_int,
int M_tile
){
c10::cuda::CUDAGuard guard(_bitmask_int.device());
if (M_tile % 4 != 0)
{
throw std::runtime_error("[Bitmask reduce] reduce tile size must be multiple of 4.");
Expand Down
3 changes: 2 additions & 1 deletion torchsparse/backend/others/reorder_map_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include "reorder_map_cuda.h"

#define cta_M 128
Expand Down Expand Up @@ -26,7 +27,7 @@ at::Tensor reorder_out_in_map_cuda(
torch::Tensor _out_in_map,
torch::Tensor _reorder_loc
){

c10::cuda::CUDAGuard guard(_out_in_map.device());
int M = _out_in_map.size(0);
int kernel_volume = _out_in_map.size(1);
int split_mask_num = _reorder_loc.size(0);
Expand Down
6 changes: 6 additions & 0 deletions torchsparse/backend/others/sparsemapping_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <torch/torch.h>

#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <cstdio>
#include <vector>
Expand Down Expand Up @@ -261,6 +262,7 @@ std::vector<at::Tensor> build_kernel_map_subm_hashmap_int32(
at::Tensor _in_coords, at::Tensor _coords_min, at::Tensor _coords_max,
at::Tensor _kernel_sizes, at::Tensor _stride,
at::Tensor _padding, bool to_insert) {
c10::cuda::CUDAGuard guard(_in_coords.device());
int n_points = _in_coords.size(0);
int kernel_volume = (int)(torch::prod(_kernel_sizes).item<int>());
int *in_coords = _in_coords.data_ptr<int>();
Expand Down Expand Up @@ -305,6 +307,7 @@ std::vector<at::Tensor> build_kernel_map_subm_hashmap(
at::Tensor _in_coords, at::Tensor _coords_min, at::Tensor _coords_max,
at::Tensor _kernel_sizes, at::Tensor _stride,
at::Tensor _padding, bool to_insert) {
c10::cuda::CUDAGuard guard(_in_coords.device());
int n_points = _in_coords.size(0);
int kernel_volume = (int)(torch::prod(_kernel_sizes).item<int>());
int *in_coords = _in_coords.data_ptr<int>();
Expand Down Expand Up @@ -349,6 +352,7 @@ std::vector<at::Tensor> build_kernel_map_downsample_hashmap_int32(
at::Tensor _in_coords, at::Tensor _coords_min, at::Tensor _coords_max,
at::Tensor _kernel_sizes, at::Tensor _stride,
at::Tensor _padding, bool to_insert) {
c10::cuda::CUDAGuard guard(_in_coords.device());
int n_points = _in_coords.size(0);
int kernel_volume = (int)(torch::prod(_kernel_sizes).item<int>());
int *in_coords = _in_coords.data_ptr<int>();
Expand Down Expand Up @@ -431,6 +435,7 @@ std::vector<at::Tensor> build_kernel_map_downsample_hashmap(
at::Tensor _in_coords, at::Tensor _coords_min, at::Tensor _coords_max,
at::Tensor _kernel_sizes, at::Tensor _stride,
at::Tensor _padding, bool to_insert) {
c10::cuda::CUDAGuard guard(_in_coords.device());
int n_points = _in_coords.size(0);
int kernel_volume = (int)(torch::prod(_kernel_sizes).item<int>());
int *in_coords = _in_coords.data_ptr<int>();
Expand Down Expand Up @@ -510,6 +515,7 @@ std::vector<at::Tensor> build_kernel_map_downsample_hashmap(
std::vector<at::Tensor> build_mask_from_kmap(int n_points, int n_out_points,
at::Tensor _kmap,
at::Tensor _kmap_sizes) {
c10::cuda::CUDAGuard guard(_kmap.device());
int kernel_volume = _kmap_sizes.size(0);
auto options =
torch::TensorOptions().dtype(at::ScalarType::Int).device(_kmap.device());
Expand Down
5 changes: 5 additions & 0 deletions torchsparse/backend/voxelize/voxelize_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <stdlib.h>
#include <torch/torch.h>

#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>
#include <cmath>

Expand Down Expand Up @@ -78,6 +79,7 @@ __global__ void voxelize_backward_kernel(int N, int c, int s,
at::Tensor voxelize_forward_cuda(const at::Tensor inputs, const at::Tensor idx,
const at::Tensor counts)
{
c10::cuda::CUDAGuard guard(inputs.device());
int N = inputs.size(0);
int c = inputs.size(1);
int N1 = counts.size(0);
Expand All @@ -98,6 +100,7 @@ at::Tensor voxelize_backward_cuda(const at::Tensor top_grad,
const at::Tensor idx, const at::Tensor counts,
const int N)
{
c10::cuda::CUDAGuard guard(top_grad.device());
int c = top_grad.size(1);
int N1 = counts.size(0);

Expand All @@ -116,6 +119,7 @@ at::Tensor voxelize_backward_cuda(const at::Tensor top_grad,
void to_dense_forward_cuda(const at::Tensor inputs, const at::Tensor idx,
const at::Tensor range, at::Tensor outputs)
{
c10::cuda::CUDAGuard guard(inputs.device());
int N = inputs.size(0);
int c = inputs.size(1);

Expand All @@ -130,6 +134,7 @@ void to_dense_backward_cuda(const at::Tensor top_grad,
const at::Tensor idx, const at::Tensor range,
const at::Tensor bottom_grad)
{
c10::cuda::CUDAGuard guard(top_grad.device());
int N = bottom_grad.size(0);
int c = bottom_grad.size(1);

Expand Down