@@ -16,6 +16,7 @@
 #include <ATen/ops/_addmm_activation_native.h>
 #include <ATen/ops/_efficientzerotensor.h>
 #include <ATen/ops/_scaled_mm_native.h>
+#include <ATen/ops/_unsafe_view_native.h>
 #include <ATen/ops/addmm_native.h>
 #include <ATen/ops/addmv_native.h>
 #include <ATen/ops/baddbmm_native.h>
@@ -369,12 +370,10 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 }
 
 const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
-  IntArrayRef batch1_sizes = batch1.sizes();
-
   // handle pathological cases that blas may not like
   if (result.numel() == 0) {
     return result;
-  } else if (batch1_sizes[2] == 0) {
+  } else if (batch1.size(2) == 0) {
     if (beta.to<c10::complex<double>>() == 0.0) {
       return result.zero_();
     } else {
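
A note on the first hunk: `batch1_sizes` was only read once, so the patch queries `batch1.size(2)` directly. That dimension is the contraction size of the batched product; when it is zero the product is empty and `baddbmm` degenerates to `result = beta * self`, which is exactly the early-out above. A minimal, hedged libtorch sketch of that behavior (a standalone illustration, not part of the patch):

// Sketch assuming libtorch: what the batch1.size(2) == 0 early-out computes.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto self   = torch::ones({2, 3, 4});
  auto batch1 = torch::empty({2, 3, 0}); // contraction dim is 0
  auto batch2 = torch::empty({2, 0, 4});
  // The batched product is empty, so baddbmm reduces to beta * self:
  // beta == 0 zeroes the result, any other beta just scales self.
  auto r0 = torch::baddbmm(self, batch1, batch2, /*beta=*/0, /*alpha=*/1);
  auto r2 = torch::baddbmm(self, batch1, batch2, /*beta=*/2, /*alpha=*/1);
  std::cout << r0.sum().item<float>() << " "   // 0
            << r2.sum().item<float>() << "\n"; // 48 == 2 * 24 ones
}
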
@@ -421,17 +420,30 @@ const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, co
     const scalar_t* batch1_ptr = batch1_->const_data_ptr<scalar_t>();
     const scalar_t* batch2_ptr = batch2_->const_data_ptr<scalar_t>();
     scalar_t* result_ptr = result_->mutable_data_ptr<scalar_t>();
-    at::cuda::blas::bgemm<scalar_t>(
-      transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n',
-      transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n',
-      m, n, k,
-      alpha_val,
-      batch1_ptr, lda, batch1_->strides()[0],
-      batch2_ptr, ldb, batch2_->strides()[0],
-      beta_val,
-      result_ptr, ldc, result_->strides()[0],
-      num_batches
-    );
+    const auto transa = transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n';
+    const auto transb = transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n';
+    // If batch is 1, call gemm rather than bgemm
+    if (num_batches == 1) {
+      at::cuda::blas::gemm<scalar_t>(
+          transa, transb,
+          m, n, k,
+          alpha_val,
+          batch1_ptr, lda,
+          batch2_ptr, ldb,
+          beta_val,
+          result_ptr, ldc);
+    } else {
+      at::cuda::blas::bgemm<scalar_t>(
+          transa, transb,
+          m, n, k,
+          alpha_val,
+          batch1_ptr, lda, batch1_->strides()[0],
+          batch2_ptr, ldb, batch2_->strides()[0],
+          beta_val,
+          result_ptr, ldc, result_->strides()[0],
+          num_batches
+      );
+    }
   });
   if (!result.is_same(*result_)) {
     result.copy_(*result_);
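
The second hunk is the substance of the change: the transpose/conjugate flags are hoisted into `transa`/`transb`, and a single-batch `baddbmm` now dispatches to a plain `gemm` instead of a strided-batched `bgemm`, presumably so cuBLAS can use its non-batched kernel paths. The flag encoding follows the usual BLAS convention; below is a small sketch of that mapping (a hypothetical helper mirroring the nested ternary, not PyTorch API):

// Map (transpose, conjugate) to a BLAS op code, assuming the usual
// cuBLAS-style semantics:
//   'n' -> use the matrix as-is
//   't' -> transpose
//   'c' -> conjugate transpose
char blas_op(bool transpose, bool is_conj) {
  if (!transpose) {
    return 'n'; // non-transposed conjugation is resolved before this point
  }
  return is_conj ? 'c' : 't';
}
// Usage, matching the patch:
//   const auto transa = blas_op(transpose_batch1, batch1_->is_conj());
//   const auto transb = blas_op(transpose_batch2, batch2_->is_conj());
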