A Triton-based implementation of fused Triangle Self Attention kernels. TriFast provides an optimized version of triangle self attention, built on the ideas of Flash Attention.
- Memory Efficient: Achieves n² memory complexity (compared to n³ for pure PyTorch implementations)
- High Performance:
- ⚡ ~4x faster forward pass than the next fastest implementation (DS4S evoformer kernel)
- ⚡ ~2x faster backward pass than the next fastest implementation (DS4S evoformer kernel)
- Multiple Precision Support: Works with float32, bfloat16, and float16 data types
- GPU Accelerated: Benchmarked on NVIDIA GPUs with excellent scaling properties
- Auto-tuning: Includes built-in kernel tuning capabilities to optimize for specific workloads
All benchmarks were performed on an NVIDIA GeForce RTX 3090 GPU using BFloat16 precision.
```bash
pip install trifast
```
Basic usage of the triangle attention function:
```python
import torch
from trifast import triangle_attention
from trifast.utils import gen_tensors

# Example problem size: sequence length n, head dim d, number of heads h
n, d, h = 128, 32, 4
device, dtype, scale = torch.device("cuda"), torch.bfloat16, 1.0

# Generate tensors (query, key, value, bias, mask)
q, k, v, bias, mask = gen_tensors(n, d, h, use_mask=True, device=device, dtype=dtype, std=scale)

# Apply triangle self attention
out = triangle_attention(q, k, v, bias, mask)
```
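For reference, a naive PyTorch version of roughly the same computation is sketched below. The tensor shapes and the exact bias/mask semantics are assumptions for illustration, but the sketch shows the full (h, n, n, n) logits tensor that drives the n³ memory cost TriFast avoids.

```python
import torch

def naive_triangle_attention(q, k, v, bias, mask):
    # Assumed shapes (illustrative only; check trifast's docs for the exact layout):
    #   q, k, v: (n, n, h, d)   bias: (h, n, n)   mask: (n, n) bool, True = valid key
    n, _, h, d = q.shape
    # Full logits tensor of shape (h, n, n, n) -- the n^3 memory cost TriFast avoids
    logits = torch.einsum("ijhd,ikhd->hijk", q, k) / d**0.5
    logits = logits + bias[:, None, :, :]                        # pair bias, broadcast over rows
    logits = logits.masked_fill(~mask[None, :, None, :], float("-inf"))
    attn = torch.softmax(logits, dim=-1)
    return torch.einsum("hijk,ikhd->ijhd", attn, v)              # back to (n, n, h, d)
```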
TriFast modifies triton.autotune to cache the best config to disk. This means the first time the kernel is run (for a given input shape) will generally be slower than subsequent runs.
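A quick, illustrative way to see this effect is to time repeated calls with the same input shape; the parameter values below are arbitrary example choices:

```python
import time
import torch
from trifast import triangle_attention
from trifast.utils import gen_tensors

q, k, v, bias, mask = gen_tensors(
    256, 32, 4, use_mask=True, device=torch.device("cuda"), dtype=torch.bfloat16, std=1.0
)

for i in range(3):
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = triangle_attention(q, k, v, bias, mask)
    torch.cuda.synchronize()
    # Call 0 pays the one-time, per-shape autotuning / config-loading cost;
    # later calls with the same shape reuse the selected config.
    print(f"call {i}: {time.perf_counter() - start:.4f} s")
```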
The package provides a command-line script for auto-tuning:
```bash
# Basic usage
trifast-tune

# For a more extensive tuning process
TRIFAST_FORCE_TUNE=1 trifast-tune

# With custom parameters
trifast-tune --min-n 32 --max-n 2048 --dtype bfloat16 --h 4,8 --d 32,64
```
- `--min-n`: Minimum sequence length to tune (default: 16)
- `--max-n`: Maximum sequence length to tune (default: 1024)
- `--dtype`: PyTorch datatypes to use (comma-separated). Options: float32, bfloat16, float16, all (default: bfloat16)
- `--h`: Numbers of heads to tune (comma-separated integers, e.g., "1,2,4") (default: 4)
- `--d`: Dimensions to tune (comma-separated integers, e.g., "16,32,64") (default: 32)
After tuning, the best configurations are cached to disk using platformdirs (e.g. under ~/.config/trifast/<version> on Linux).
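If you want to locate (or clear) that cache, the base directory can be resolved with platformdirs directly; the exact subdirectory layout below is assumed from the example path above:

```python
from importlib.metadata import version
import platformdirs

# Resolves to ~/.config/trifast/<version> on Linux (subdirectory layout assumed
# from the example path above; platformdirs picks the right base dir per OS).
print(f"{platformdirs.user_config_dir('trifast')}/{version('trifast')}")
```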
TriFast implements Triangle Self Attention using Triton to create optimized GPU kernels. It is essentially Flash Attention applied to triangle self attention, resulting in significant performance gains compared to naive PyTorch implementations.
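The core idea can be sketched in plain PyTorch for a single (row, head) slice. The sketch below is illustrative only (hypothetical helper, no masking, arbitrary block size); the real kernels are written in Triton:

```python
import torch

def flash_style_attention_row(q, k, v, bias, block=64):
    # One (row, head) slice of triangle attention in plain PyTorch:
    #   q, k, v: (m, d)    bias: (m, m) added to the logits
    # Keys are processed in blocks with a running max / denominator, so the
    # full (m, m) score matrix is never materialized -- the Flash Attention
    # trick that TriFast's Triton kernels apply to triangle self attention.
    m, d = q.shape
    out = torch.zeros_like(q)
    row_max = torch.full((m,), float("-inf"), device=q.device, dtype=q.dtype)
    denom = torch.zeros(m, device=q.device, dtype=q.dtype)
    for start in range(0, m, block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T / d**0.5 + bias[:, start:start + block]    # (m, block) scores
        new_max = torch.maximum(row_max, s.max(dim=-1).values)
        correction = torch.exp(row_max - new_max)               # rescale old accumulators
        p = torch.exp(s - new_max[:, None])
        denom = denom * correction + p.sum(dim=-1)
        out = out * correction[:, None] + p @ vb
        row_max = new_max
    return out / denom[:, None]
```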
This implementation draws inspiration from Flash Attention, FlagAttention, and the official Triton tutorials.
Planned future work:
- Explore performance optimizations for dq/db/dkv transposed operations
- Implement pipelined writes to global memory in backward kernels
Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.
Thanks to:
- The FlagOpen team for their work on FlagAttention
- The Triton team for providing excellent documentation and tutorials
- The DS4S team for their evoformer kernel implementation used in benchmarking