Commit 1a3e3c7

atalman and malfet authored
[CUDA] baddbmm should fall back to addmm for batch=1 (#114992) (#116518)
I.e. it seems reasonable to always call `at::cuda::gemm` rather than `at::cuda::bgemm` when `num_batches == 1`. After the change, benchmark results for torch built with CUDA-12, collected on an A100 using the [following perf script](https://gist.github.com/malfet/6a17156d7f5663b8b12054a1beff3fe1), are as follows:

| Shape | bmm_time | mm_time | slow down (%) |
| -------------- | --------- | --------- | ------------- |
| 1x1x4096 | 14.18 | 14.31 | -0.89 |
| 1x1x8192 | 14.37 | 14.37 | -0.05 |
| 1x1x16384 | 14.03 | 14.12 | -0.68 |
| 1x1x32768 | 14.19 | 14.24 | -0.35 |
| 1x1x65536 | 14.85 | 14.52 | 2.30 |
| 1x1x131072 | 14.03 | 14.07 | -0.33 |
| 128x128x128 | 11.34 | 11.06 | 2.56 |
| 256x256x256 | 14.85 | 14.40 | 3.15 |
| 512x512x512 | 27.22 | 27.22 | -0.01 |
| 1024x1024x1024 | 129.66 | 129.50 | 0.12 |
| 2048x2048x2048 | 972.18 | 973.24 | -0.11 |
| 129x127x129 | 11.21 | 11.25 | -0.39 |
| 257x255x257 | 14.50 | 14.43 | 0.44 |
| 513x511x513 | 29.01 | 29.01 | 0.01 |
| 1025x1023x1025 | 137.65 | 137.64 | 0.01 |
| 2049x2047x2049 | 982.58 | 982.65 | -0.01 |
| 4097x3x4097 | 86.65 | 86.64 | 0.01 |
| 8193x3x8193 | 384.02 | 383.96 | 0.02 |
| 16385x3x16385 | 1106.73 | 1107.32 | -0.05 |
| 32769x3x32769 | 4739.49 | 4739.48 | 0.00 |
| 65537x3x65537 | 17377.78 | 17378.74 | -0.01 |
| 4097x5x4097 | 87.09 | 87.12 | -0.03 |
| 8193x5x8193 | 301.38 | 301.36 | 0.01 |
| 16385x5x16385 | 1107.38 | 1108.04 | -0.06 |
| 32769x5x32769 | 4743.73 | 4744.07 | -0.01 |
| 65537x5x65537 | 17392.32 | 17395.42 | -0.02 |
| 4097x7x4097 | 87.17 | 87.19 | -0.02 |
| 8193x7x8193 | 301.94 | 302.00 | -0.02 |
| 16385x7x16385 | 1107.17 | 1106.79 | 0.03 |
| 32769x7x32769 | 4747.15 | 4747.13 | 0.00 |
| 65537x7x65537 | 17403.85 | 17405.02 | -0.01 |

Fixes the perf problem reported in #114911.

Pull Request resolved: #114992
Approved by: https://github.com/Skylion007, https://github.com/eqy

Co-authored-by: Nikita Shulga <[email protected]>
1 parent ab7505f commit 1a3e3c7
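The numbers in the table above come from the linked gist. As a rough illustration of that kind of measurement, here is a minimal timing sketch; the helper `bench_us`, the default float32 dtype, and the shape list are illustrative assumptions, not the gist itself:

```python
# Minimal sketch (not the linked gist): compare a batch-1 bmm against the equivalent mm.
import torch
from torch.utils.benchmark import Timer

def bench_us(stmt, **env):
    # Median runtime of `stmt` in microseconds, measured with torch's benchmark Timer.
    return Timer(stmt=stmt, globals=env).blocked_autorange().median * 1e6

for m, n, k in [(1, 1, 4096), (128, 128, 128), (4097, 3, 4097)]:
    x = torch.rand(1, m, k, device="cuda")   # batch dimension of 1
    y = torch.rand(1, k, n, device="cuda")
    bmm_us = bench_us("torch.bmm(x, y)", torch=torch, x=x, y=y)
    mm_us = bench_us("torch.mm(a, b)", torch=torch, a=x[0], b=y[0])
    print(f"{m}x{n}x{k}: bmm={bmm_us:.2f}us  mm={mm_us:.2f}us")
```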

2 files changed: +28 −18 lines

aten/src/ATen/native/cuda/Blas.cpp

+26 −14

@@ -16,6 +16,7 @@
 #include <ATen/ops/_addmm_activation_native.h>
 #include <ATen/ops/_efficientzerotensor.h>
 #include <ATen/ops/_scaled_mm_native.h>
+#include <ATen/ops/_unsafe_view_native.h>
 #include <ATen/ops/addmm_native.h>
 #include <ATen/ops/addmv_native.h>
 #include <ATen/ops/baddbmm_native.h>
@@ -369,12 +370,10 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 }

 const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
-  IntArrayRef batch1_sizes = batch1.sizes();
-
   // handle pathological cases that blas may not like
   if (result.numel() == 0) {
     return result;
-  } else if (batch1_sizes[2] == 0) {
+  } else if (batch1.size(2) == 0) {
     if (beta.to<c10::complex<double>>() == 0.0) {
       return result.zero_();
     } else {
@@ -421,17 +420,30 @@ const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, co
     const scalar_t* batch1_ptr = batch1_->const_data_ptr<scalar_t>();
     const scalar_t* batch2_ptr = batch2_->const_data_ptr<scalar_t>();
     scalar_t* result_ptr = result_->mutable_data_ptr<scalar_t>();
-    at::cuda::blas::bgemm<scalar_t>(
-      transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n',
-      transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n',
-      m, n, k,
-      alpha_val,
-      batch1_ptr, lda, batch1_->strides()[0],
-      batch2_ptr, ldb, batch2_->strides()[0],
-      beta_val,
-      result_ptr, ldc, result_->strides()[0],
-      num_batches
-    );
+    const auto transa = transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n';
+    const auto transb = transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n';
+    // If batch is 1 call gemm rather than bgemm
+    if (num_batches == 1) {
+      at::cuda::blas::gemm<scalar_t>(
+          transa, transb,
+          m, n, k,
+          alpha_val,
+          batch1_ptr, lda,
+          batch2_ptr, ldb,
+          beta_val,
+          result_ptr, ldc);
+    } else {
+      at::cuda::blas::bgemm<scalar_t>(
+        transa, transb,
+        m, n, k,
+        alpha_val,
+        batch1_ptr, lda, batch1_->strides()[0],
+        batch2_ptr, ldb, batch2_->strides()[0],
+        beta_val,
+        result_ptr, ldc, result_->strides()[0],
+        num_batches
+      );
+    }
   });
   if (!result.is_same(*result_)) {
     result.copy_(*result_);
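
As a usage sketch of the contract the diff above relies on (my own illustration, not part of the PR's tests): with a single batch, `torch.baddbmm` should agree numerically with the equivalent `torch.addmm` call, which is what makes the gemm fallback safe.

```python
# Sanity-check sketch (not from the PR): baddbmm with num_batches == 1 should match addmm.
import torch

if torch.cuda.is_available():
    beta, alpha = 0.5, 2.0
    inp = torch.rand(1, 64, 32, device="cuda")
    b1 = torch.rand(1, 64, 128, device="cuda")
    b2 = torch.rand(1, 128, 32, device="cuda")

    out_batched = torch.baddbmm(inp, b1, b2, beta=beta, alpha=alpha)       # batch of 1
    out_single = torch.addmm(inp[0], b1[0], b2[0], beta=beta, alpha=alpha)  # single-matrix path

    torch.testing.assert_close(out_batched[0], out_single)
```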

torch/testing/_internal/common_methods_invocations.py

+2 −4

@@ -26,7 +26,7 @@
     skipCPUIfNoMklSparse,
     toleranceOverride, tol)
 from torch.testing._internal.common_cuda import (
-    PLATFORM_SUPPORTS_FLASH_ATTENTION, SM53OrLater, SM60OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION, SM53OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
     _get_torch_cuda_version, _get_torch_rocm_version,
 )
 from torch.testing._internal.common_utils import (
@@ -15937,9 +15937,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            op=lambda tensors, equation: torch.einsum(equation, tensors),
            dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
-           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16]
-                                                                if (SM60OrLater or
-                                                                    TEST_WITH_ROCM) else []),
+           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
            supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
