Skip to content

Commit 2a41a66

Browse files
authored
replace at::cuda::getCurrentCUDASparseHandle with custom func (fix NVIDIA#308) (NVIDIA#315)
* Force-initialize the cuSPARSE handle.
* Replace every call to at::cuda::getCurrentCUDASparseHandle with a custom function.
* Update the changelog.
1 parent 8910910 commit 2a41a66

8 files changed

+24
-7
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
- spmm average cuda function
66
- SparseTensor list operators (cat, mean, sum, var)
77
- MinkowskiStack containers
8+
- Replace all at::cuda::getCurrentCUDASparseHandle calls with a custom getCurrentCUDASparseHandle function (issue #308)
89

910
## [0.5.1]
1011

src/broadcast_gpu.cu

+2-2
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ at::Tensor BroadcastForwardGPU(
8888
torch::empty({in_feat.size(0), in_feat.size(1)}, in_feat.options());
8989

9090
auto stream = at::cuda::getCurrentCUDAStream();
91-
cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle();
91+
cusparseHandle_t handle = getCurrentCUDASparseHandle();
9292
cusparseSetStream(handle, stream);
9393

9494
AT_DISPATCH_FLOATING_TYPES(
@@ -158,7 +158,7 @@ std::pair<at::Tensor, at::Tensor> BroadcastBackwardGPU(
158158
const auto &in_outs = p_map_manager->origin_map(p_in_map_key);
159159

160160
auto stream = at::cuda::getCurrentCUDAStream();
161-
cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle();
161+
cusparseHandle_t handle = getCurrentCUDASparseHandle();
162162
cusparseSetStream(handle, stream);
163163

164164
AT_DISPATCH_FLOATING_TYPES(

src/global_pooling_gpu.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ std::tuple<at::Tensor, at::Tensor> GlobalPoolingForwardGPU(
135135
case PoolingMode::GLOBAL_AVG_POOLING_KERNEL: {
136136
const auto &in_outs = p_map_manager->origin_map(p_in_map_key);
137137
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
138-
cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle();
138+
cusparseHandle_t handle = getCurrentCUDASparseHandle();
139139
cusparseSetStream(handle, stream);
140140

141141
TemplatedAllocator<char> byte_allocator;

src/gpu.cu

+6
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ const char *cusparseGetErrorString(cusparseStatus_t error) {
115115
return "<unknown>";
116116
}
117117

118+
// Return a cuSPARSE handle for the current thread, creating it lazily on
// first use.
//
// Replacement for at::cuda::getCurrentCUDASparseHandle (issue #308): we
// deliberately avoid ATen's handle pool, but we must not create a brand-new
// handle on every call either — the previous body called cusparseCreate()
// each time and never cusparseDestroy()'d the result, leaking one handle
// (plus its internal workspace) per kernel launch. Callers rebind the handle
// to the current stream via cusparseSetStream() immediately after this call,
// so one cached handle per thread is sufficient.
//
// NOTE(review): the handle is intentionally never destroyed; it lives for
// the lifetime of the thread/process, matching the usual cuSPARSE pattern.
// Assumes all callers are on the same CUDA device per thread — TODO confirm
// for multi-device use.
cusparseHandle_t getCurrentCUDASparseHandle() {
  static thread_local cusparseHandle_t handle = nullptr;
  if (handle == nullptr) {
    CUSPARSE_CHECK(cusparseCreate(&handle));
  }
  return handle;
}
123+
118124
static std::string format_size(uint64_t size) {
119125
std::ostringstream os;
120126
os.precision(2);

src/gpu.cuh

+2
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ const char *cublasGetErrorString(cublasStatus_t error);
163163
// CUSparse error reporting.
164164
const char *cusparseGetErrorString(cusparseStatus_t error);
165165

166+
cusparseHandle_t getCurrentCUDASparseHandle();
167+
166168
constexpr uint32_t CUDA_NUM_THREADS = 128;
167169

168170
constexpr uint32_t SHARED_BLOCK_SIZE = 32;

src/local_pooling_gpu.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ std::pair<at::Tensor, at::Tensor> LocalPoolingForwardGPU(
117117
num_nonzero.resize_({out_nrows});
118118
num_nonzero.zero_();
119119
}
120-
cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle();
120+
cusparseHandle_t handle = getCurrentCUDASparseHandle();
121121
cusparseSetStream(handle, stream);
122122

123123
AT_DISPATCH_FLOATING_TYPES(

src/local_pooling_transpose_gpu.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ std::pair<at::Tensor, at::Tensor> LocalPoolingTransposeForwardGPU(
111111
at::Tensor num_nonzero =
112112
torch::empty({0}, in_feat.options().requires_grad(false));
113113

114-
cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle();
114+
cusparseHandle_t handle = getCurrentCUDASparseHandle();
115115
cusparseSetStream(handle, stream);
116116

117117
AT_DISPATCH_FLOATING_TYPES(

src/spmm.cu

+10-2
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,11 @@ torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols,
186186

187187
// Dense matrices have to be contiguous for cusparseSpMM to work
188188
torch::Tensor const mat2_contig = mat2.contiguous();
189-
auto cusparse_handle = at::cuda::getCurrentCUDASparseHandle();
189+
// Issue 308
190+
// auto cusparse_handle = at::cuda::getCurrentCUDASparseHandle();
191+
auto stream = at::cuda::getCurrentCUDAStream();
192+
cusparseHandle_t cusparse_handle = getCurrentCUDASparseHandle();
193+
cusparseSetStream(cusparse_handle, stream);
190194

191195
torch::Scalar beta = 0;
192196
torch::Scalar alpha = 1;
@@ -442,7 +446,11 @@ coo_spmm_average(torch::Tensor const &rows, torch::Tensor const &cols,
442446

443447
// Dense matrices have to be contiguous for cusparseSpMM to work
444448
torch::Tensor const mat2_contig = mat2.contiguous();
445-
auto cusparse_handle = at::cuda::getCurrentCUDASparseHandle();
449+
// Issue 308
450+
// auto cusparse_handle = at::cuda::getCurrentCUDASparseHandle();
451+
auto stream = at::cuda::getCurrentCUDAStream();
452+
cusparseHandle_t cusparse_handle = getCurrentCUDASparseHandle();
453+
cusparseSetStream(cusparse_handle, stream);
446454

447455
torch::Scalar beta = 0;
448456
torch::Scalar alpha = 1;

0 commit comments

Comments
 (0)