@@ -235,15 +235,15 @@ torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols,
235
235
CUDA_CHECK (cudaMemcpy (sorted_val_ptr, values_ptr, nnz * sizeof (scalar_t ),
236
236
cudaMemcpyDeviceToDevice));
237
237
238
- thrust::sort_by_key (thrust::device, //
239
- sorted_row_ptr, // key begin
240
- sorted_row_ptr + nnz, // key end
241
- thrust::make_zip_iterator ( // value begin
242
- thrust::make_tuple ( //
243
- sorted_col_ptr, //
244
- sorted_val_ptr //
245
- ) //
246
- ));
238
+ THRUST_CHECK ( thrust::sort_by_key (thrust::device, //
239
+ sorted_row_ptr, // key begin
240
+ sorted_row_ptr + nnz, // key end
241
+ thrust::make_zip_iterator ( // value begin
242
+ thrust::make_tuple ( //
243
+ sorted_col_ptr, //
244
+ sorted_val_ptr //
245
+ ) //
246
+ ) ));
247
247
LOG_DEBUG (" sorted row" , cudaDeviceSynchronize ());
248
248
} else {
249
249
sorted_row_ptr = row_indices_ptr;
@@ -481,10 +481,10 @@ coo_spmm_average(torch::Tensor const &rows, torch::Tensor const &cols,
481
481
CUDA_CHECK (cudaMemcpy (sorted_col_ptr, col_indices_ptr,
482
482
nnz * sizeof (th_int_type), cudaMemcpyDeviceToDevice));
483
483
484
- thrust::sort_by_key (thrust::device, //
485
- sorted_row_ptr, // key begin
486
- sorted_row_ptr + nnz, // key end
487
- sorted_col_ptr);
484
+ THRUST_CHECK ( thrust::sort_by_key (thrust::device, //
485
+ sorted_row_ptr, // key begin
486
+ sorted_row_ptr + nnz, // key end
487
+ sorted_col_ptr) );
488
488
489
489
// ///////////////////////////////////////////////////////////////////////
490
490
// Create vals
@@ -496,21 +496,20 @@ coo_spmm_average(torch::Tensor const &rows, torch::Tensor const &cols,
496
496
(scalar_t *)c10::cuda::CUDACachingAllocator::raw_alloc (
497
497
nnz * sizeof (scalar_t ));
498
498
torch::Tensor ones = at::ones ({nnz}, mat2.options ());
499
-
500
- // reduce by key
501
- auto end = thrust::reduce_by_key (
502
- thrust::device, // policy
503
- sorted_row_ptr, // key begin
504
- sorted_row_ptr + nnz, // key end
505
- reinterpret_cast <scalar_t *>(ones.data_ptr ()), // value begin
506
- unique_row_ptr, // key out begin
507
- reduced_val_ptr // value out begin
508
- );
509
-
510
- int num_unique_keys = end.first - unique_row_ptr;
511
- LOG_DEBUG (" Num unique keys:" , num_unique_keys);
512
-
513
- // Create values
499
+ int num_unique_keys;
500
+ try {
501
+ // reduce by key
502
+ auto end = thrust::reduce_by_key (
503
+ thrust::device, // policy
504
+ sorted_row_ptr, // key begin
505
+ sorted_row_ptr + nnz, // key end
506
+ reinterpret_cast <scalar_t *>(ones.data_ptr ()), // value begin
507
+ unique_row_ptr, // key out begin
508
+ reduced_val_ptr // value out begin
509
+ );
510
+ num_unique_keys = end.first - unique_row_ptr;
511
+ LOG_DEBUG (" Num unique keys:" , num_unique_keys);
512
+ } THRUST_CATCH;
514
513
515
514
// Copy the results to the correct output
516
515
inverse_val<th_int_type, scalar_t >
0 commit comments