📚 Split Q + QK Fine-grained Tiling
What's Changed
📚 Split Q + QK Fine-grained Tiling (O(16xd) SRAM vs FA2 O(4xBrxd) SRAM, Headdim -> 1024)
Currently, for small-scale attention (B<=4, H<=48, SeqLen<=8192), these kernels can run faster than official FA2/SDPA on some devices. For example, on NVIDIA RTX 3080 Laptop, 📚 Split Q + Fully Shared QKV SMEM can achieve 55 TFLOPS (D=64), which is almost ~1.5x 🎉 faster than FA2. Moreover, on NVIDIA L20, 📚 Split Q + QK Fine-grained Tiling can achieve 81 TFLOPS (D=512), which is almost ~1.4x 🎉 faster than SDPA (EFFICIENT_ATTENTION). However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~
- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop), Faster than FA2~🎉🎉
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
-------------------------------------------B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
torch(unfused): ['-0.00514603 ', '0.05783081 ', '-0.00026727 '], time:20.999861ms, TFLOPS:6.67 (+0.00%)
mma(split-kv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:5.120730ms, TFLOPS:27.36 (+310.10%)
mma(split-kv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:5.004287ms, TFLOPS:28.00 (+2.33%)
mma(split-q+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:3.462291ms, TFLOPS:40.47 (+44.54%)
mma(split-q+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:3.658915ms, TFLOPS:38.30
mma(split-q+share-qkv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.551699ms, TFLOPS:54.91 (+35.69%)
mma(split-q+share-qkv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.532172ms, TFLOPS:55.34 (+0.77%)
mma(split-q+share-kv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.776575ms, TFLOPS:50.46
mma(split-q+share-kv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.596927ms, TFLOPS:53.96
(flash): ['-0.00516129 ', '0.05783081 ', '-0.00027728 '], time:3.776550ms, TFLOPS:37.10
----------------------------------------------------------------------------------------------------------------------------------
- Example: B=1, H=8, N=8192, D=512 (NVIDIA RTX 3080 Laptop), FA2 not supported; QK Fine-grained Tiling is faster than SDPA~🎉🎉
python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D 512 # NVIDIA RTX 3080 Laptop, Faster than SDPA
------------------------------------------B=1, H=8, N=8192, D=512, Warmup: 1, Iters: 10-------------------------------------------
mma(split-q+tiling-qk+stage1): ['-0.00433731 ', '0.02165222 ', '-0.01544189 '], time:48.775554ms, TFLOPS:22.60 (+0.00%)
mma(split-q+tiling-qk+stage2): ['-0.00433731 ', '0.02165222 ', '-0.01544189 '], time:47.503424ms, TFLOPS:23.20 (+2.68%)
(sdpa): ['-0.00438309 ', '0.02174377 ', '-0.01551056 '], time:66.486573ms, TFLOPS:16.58
----------------------------------------------------------------------------------------------------------------------------------
- Example: B=1, H=48, N=16384, D=512 (NVIDIA L20), FA2 not supported; QK Fine-grained Tiling is faster than SDPA~🎉🎉
python3 flash_attn_mma.py --B 1 --H 48 --D 512 --N 16384 --show-all --check --iters 10 --sdpa
-----------------------------------------B=1, H=48, N=16384, D=512, Warmup: 1, Iters: 10------------------------------------------
mma(split-q+tiling-qk+stage1): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:387.384224ms, TFLOPS:68.28 (+0.00%)
mma(split-q+tiling-qk+stage2): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:325.593209ms, TFLOPS:81.24 (+18.98%)
(sdpa): ['0.00790405 ', '-0.02330017 ', '0.00875854 '], time:452.067018ms, TFLOPS:58.51
----------------------------------------------------------------------------------------------------------------------------------
- 📚 Split Q + Fully Shared QKV SMEM (1/4 SRAM vs FA2)
// Q, K and V fully share the same shared memory buffer, and Q is prefetched from SMEM
// into registers (s2r), which improves block occupancy and reduces Q SMEM IO accesses
// (a simplified sketch of this buffer reuse follows the kernel signature below).
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
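The buffer-reuse idea can be sketched in a few lines of plain CUDA. This is a simplified illustration only, not the repo's `flash_attn_mma_stages_split_q_shared_qkv_kernel`: the tile sizes (`kBr`, `kBc`, `kHeadDim`), the one-thread-per-row register layout, and the sketch's kernel name are assumptions chosen for readability, and the actual MMA/ldmatrix paths, swizzling, and multi-stage prefetch are elided.

```cuda
#include <cuda_fp16.h>

// Hypothetical tile sizes, chosen only for this illustration; the real kernel
// derives its tiling from the MMA atom shapes and the launch configuration.
constexpr int kBr      = 64;  // rows of the Q tile owned by one thread block
constexpr int kBc      = 64;  // rows of the K/V tile processed per iteration
constexpr int kHeadDim = 64;  // head dimension d

// Minimal sketch of the "fully shared QKV SMEM" idea: a single SMEM buffer is
// used first for the Q tile, then reused for every K tile and every V tile.
// Prefetching Q into registers (s2r) is what makes overwriting the buffer safe.
// Q/K/V are viewed as row-major [seqlen, kHeadDim] slices of one (batch, head);
// seqlen is assumed to be a multiple of kBr and kBc, and blockDim.x >= kBr.
__global__ void shared_qkv_smem_sketch(const half* __restrict__ Q,
                                       const half* __restrict__ K,
                                       const half* __restrict__ V,
                                       half* __restrict__ O, int seqlen) {
  __shared__ half smem[kBr * kHeadDim];  // ONE buffer shared by Q, K and V tiles
  const int tid    = threadIdx.x;
  const int q_row0 = blockIdx.x * kBr;   // first Q row handled by this block

  // 1) g2s: load the Q tile into SMEM, then s2r: prefetch it into registers.
  for (int i = tid; i < kBr * kHeadDim; i += blockDim.x)
    smem[i] = Q[(q_row0 + i / kHeadDim) * kHeadDim + i % kHeadDim];
  __syncthreads();
  half q_frag[kHeadDim];                 // toy layout: one thread per Q row
  if (tid < kBr)
    for (int d = 0; d < kHeadDim; ++d) q_frag[d] = smem[tid * kHeadDim + d];
  __syncthreads();                       // Q now lives in registers; SMEM is free

  // 2) Stream K and V tiles through the SAME buffer that previously held Q.
  for (int kv_row0 = 0; kv_row0 < seqlen; kv_row0 += kBc) {
    for (int i = tid; i < kBc * kHeadDim; i += blockDim.x)   // g2s: K tile
      smem[i] = K[(kv_row0 + i / kHeadDim) * kHeadDim + i % kHeadDim];
    __syncthreads();
    // ... compute S = Q @ K^T from q_frag (registers) and smem (K tile) ...
    __syncthreads();

    for (int i = tid; i < kBc * kHeadDim; i += blockDim.x)   // g2s: V tile
      smem[i] = V[(kv_row0 + i / kHeadDim) * kHeadDim + i % kHeadDim];
    __syncthreads();
    // ... online-softmax rescale and accumulate the output tile with P @ V ...
    __syncthreads();
  }
  (void)O;  // the real kernel writes the accumulated O tile back to global memory here
}
```

Because Q, K and V tiles all pass through the same buffer, the per-block SMEM footprint is roughly one tile instead of three or four, which is what enables the higher block occupancy, while the s2r prefetch removes repeated Q SMEM reads from the inner loop.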
- 📚 Split Q + QK Fine-grained Tiling (O(16xd) SRAM vs FA2 O(4xBrxd) SRAM, Headdim -> 1024)
// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
// extend D (head dimension) up to 1024. Performance optimizations are ongoing; stay tuned
// for updates ~ (a rough SMEM estimate follows the kernel signature below).
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
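To see why the O(kMmaAtomK x d) bound matters, here is a rough back-of-the-envelope SMEM estimate following the complexity stated in the comments above (64 x kMmaAtomK halves for each of Q and K, kMmaAtomK x d for V, versus roughly 4 x Br x d for an FA2-style layout). The concrete values kMmaAtomK = 16, Br = 64 and the single-stage, non-swizzled layout are assumptions for illustration, so the real kernels' exact SMEM usage will differ (e.g. with stage2 double buffering).

```cuda
#include <cstdio>

int main() {
  // Assumed shapes for this estimate (not read from the kernel source):
  const int Br        = 64;  // Q-tile rows in an FA2-style layout
  const int kMmaAtomK = 16;  // K-dim of one m16n8k16 MMA atom
  const int bytes     = 2;   // fp16 element size

  for (int d = 64; d <= 1024; d *= 2) {
    // FA2-style tiling keeps whole Q/K/V/O tiles resident: ~O(4 * Br * d).
    long fa2_bytes = 4L * Br * d * bytes;
    // QK fine-grained tiling keeps a (64 x kMmaAtomK) slice of Q and of K,
    // plus a (kMmaAtomK x d) slice of V: O(kMmaAtomK * d) overall.
    long tiling_qk_bytes = (2L * 64 * kMmaAtomK + 1L * kMmaAtomK * d) * bytes;
    printf("d=%4d  FA2-style: %4ld KiB   tiling-qk: %3ld KiB\n",
           d, fa2_bytes / 1024, tiling_qk_bytes / 1024);
  }
  // With roughly 100 KiB of SMEM available per thread block on Ampere/Ada GPUs,
  // the O(4*Br*d) layout stops fitting long before d = 1024, while the
  // O(kMmaAtomK * d) layout still fits comfortably.
  return 0;
}
```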
Full Changelog: v2.6.10...v2.6.11