As prompt lengths continue to increase, the computational and memory-bandwidth demands of Large Language Models (LLMs) grow significantly, making efficient processing more challenging. Attention patterns, however, are often sparse, and by fully exploiting this sparsity we can skip a large share of the attention computation and reduce inference cost. This lets LLMs handle longer and more complex prompts without a proportional increase in resource consumption. To this end, we introduce Block Sparse Attention, a library of sparse attention kernels that supports several sparse patterns: streaming attention with token granularity, streaming attention with block granularity, and block-sparse attention. By supporting these patterns, Block Sparse Attention can significantly reduce the computational cost of LLMs, improving their efficiency and scalability.
We release the implementation of Block Sparse Attention, which is modified from FlashAttention 2.4.2.
- [2024/10] We release both the forward (fwd) and backward (bwd) passes of Block Sparse Attention.
Block Sparse Attention supports four patterns:
- Dense attention: computes the full attention matrix.
- Streaming attention with token granularity: computes attention over a fixed number of sink tokens and local tokens. See StreamingLLM for more details.
- Streaming attention with block granularity (`block_size = 128`): computes attention over a fixed number of sink blocks and local blocks.
- Blocksparse attention (`block_size = 128`): takes a block mask as input and computes attention only for the blocks the mask selects.
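To make the patterns listed above concrete, the pure-Python sketch below builds, at block granularity, the set of (query block, key block) pairs each pattern computes. It is an illustration, not the library's kernel logic: the helper names (`num_blocks`, `dense_block_mask`, `block_streaming_mask`, `token_mask_from_blocks`) are hypothetical, and the exact streaming rule used here (causal, equal query and key lengths, the query block's own diagonal block counted as one of the local blocks) is an assumption.

```python
import math

BLOCK_SIZE = 128  # block size used by the block-granularity patterns


def num_blocks(seqlen: int, block_size: int = BLOCK_SIZE) -> int:
    """Blocks needed to cover `seqlen` tokens; the last block may be partial."""
    return math.ceil(seqlen / block_size)


def dense_block_mask(n_blocks: int, causal: bool = True) -> list:
    """Dense attention: every key block is computed (lower triangle if causal)."""
    return [[(not causal) or k <= q for k in range(n_blocks)] for q in range(n_blocks)]


def block_streaming_mask(n_blocks: int, sink_blocks: int, local_blocks: int) -> list:
    """Causal block streaming (assumed rule): query block q keeps the first
    `sink_blocks` key blocks plus the `local_blocks` most recent key blocks,
    counting its own diagonal block as one of them."""
    mask = []
    for q in range(n_blocks):
        row = []
        for k in range(n_blocks):
            is_sink = k < sink_blocks
            is_local = q - local_blocks < k <= q
            row.append(k <= q and (is_sink or is_local))
        mask.append(row)
    return mask


def token_mask_from_blocks(block_mask: list, seqlen: int, block_size: int = BLOCK_SIZE) -> list:
    """Blocksparse attention: token pair (i, j) is computed iff its block is selected.
    Any causal masking inside a selected block is ignored here."""
    return [[block_mask[i // block_size][j // block_size] for j in range(seqlen)]
            for i in range(seqlen)]


n = num_blocks(1000)  # 1000 tokens need 8 blocks of 128
stream = block_streaming_mask(n, sink_blocks=1, local_blocks=3)
# The last query block keeps sink block 0 and local blocks 5, 6, 7.
kept = [k for k, keep in enumerate(stream[n - 1]) if keep]
```

In the library itself the blocksparse mask is supplied by the caller rather than generated; the expansion helper only shows how a block-level decision covers every token pair in a 128 x 128 tile.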
Importantly, different heads can be assigned different patterns.
You can use `head_mask_type` to specify the pattern for each head. It is a list of integers with one entry per query head.
For a given head, `mask_type = 0` means dense attention, `mask_type = -1` means streaming attention (either block streaming or exact streaming), and `mask_type = 1` means blocksparse attention, in which case the head uses `basemask[mask_type - 1]` as its attention mask.
For example, if you have 8 heads and
```python
head_mask_type = [1, 1, 0, 0, 0, -1, 0, -1]
```
then head0 and head1 use the blocksparse mask, head2 to head4 and head6 use the dense mask, and head5 and head7 use the streaming mask.
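The per-head encoding can be decoded as in the sketch below, which reproduces the 8-head example. `decode_head_mask_type` is a hypothetical helper, not part of the library. The text only names `mask_type = 1` for blocksparse heads; treating every positive value as an index into `basemask` via `mask_type - 1` is an assumption drawn from that indexing expression.

```python
DENSE, STREAMING, BLOCKSPARSE = "dense", "streaming", "blocksparse"


def decode_head_mask_type(head_mask_type):
    """Return one (pattern, base_mask_index) pair per query head.

    0 -> dense, -1 -> streaming (block or exact streaming),
    positive -> blocksparse using basemask[mask_type - 1] (assumed for values > 1).
    """
    patterns = []
    for head, mask_type in enumerate(head_mask_type):
        if mask_type == 0:
            patterns.append((DENSE, None))
        elif mask_type == -1:
            patterns.append((STREAMING, None))
        elif mask_type > 0:
            patterns.append((BLOCKSPARSE, mask_type - 1))
        else:
            raise ValueError(f"head{head}: unsupported mask_type {mask_type}")
    return patterns


patterns = decode_head_mask_type([1, 1, 0, 0, 0, -1, 0, -1])
dense_heads = [h for h, (p, _) in enumerate(patterns) if p == DENSE]          # [2, 3, 4, 6]
streaming_heads = [h for h, (p, _) in enumerate(patterns) if p == STREAMING]  # [5, 7]
```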
The interfaces are:
```python
from block_sparse_attn import block_sparse_attn_func
block_sparse_attn_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q, cu_seqlens_k,
    head_mask_type,
    streaming_info,
    base_blockmask,
    max_seqlen_q_, max_seqlen_k_,
    p_dropout,
    deterministic=False,
    softmax_scale=None,
    is_causal=False,
    exact_streaming=False,
    return_attn_probs=False,
)
```
```python
from block_sparse_attn import block_streaming_attn_func
block_streaming_attn_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q, cu_seqlens_k,
    head_mask_type,
    streaming_info,
    max_seqlen_q, max_seqlen_k,
    p_dropout,
    deterministic=False,
    softmax_scale=None,
    is_causal=True,
    return_attn_probs=False,
)
```
```python
from block_sparse_attn import token_streaming_attn_func
# bwd pass is not yet supported
token_streaming_attn_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q, cu_seqlens_k,
    head_mask_type,
    streaming_info,
    max_seqlen_q, max_seqlen_k,
    deterministic=False,
    softmax_scale=None,
    return_attn_probs=False,
)
```
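The `*_unpad` tensors and the `cu_seqlens_*` / `max_seqlen_*` arguments follow the packed (unpadded) layout of FlashAttention's varlen API, on which this library is built; the sketch below assumes that convention carries over unchanged. `cu_seqlens` holds cumulative sequence lengths starting at 0 (batch size + 1 entries), so sequence `i` occupies rows `cu_seqlens[i]` to `cu_seqlens[i + 1] - 1` of the packed tensors, and `max_seqlen` is the longest sequence in the batch. `varlen_metadata` is a hypothetical pure-Python helper.

```python
from itertools import accumulate


def varlen_metadata(seqlens):
    """Return (cu_seqlens, max_seqlen) for sequences packed back to back.

    Assumes the FlashAttention-style varlen convention described above.
    """
    if not seqlens or any(n <= 0 for n in seqlens):
        raise ValueError("seqlens must be a non-empty list of positive lengths")
    cu_seqlens = [0, *accumulate(seqlens)]
    return cu_seqlens, max(seqlens)


cu_seqlens, max_seqlen = varlen_metadata([300, 1024, 77])
# cu_seqlens == [0, 300, 1324, 1401], max_seqlen == 1024
# Sequence 1 occupies packed rows 300..1323.
```

For an actual call, FlashAttention's varlen functions expect `cu_seqlens` as an int32 CUDA tensor, e.g. `torch.tensor(cu_seqlens, dtype=torch.int32, device="cuda")`.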
The figures above illustrate the speedup of Block Sparse Attention over dense FlashAttention 2.4.2. Speedups were measured on an A100 GPU with a head dimension of 128 and 32 attention heads.
Duo Attention introduces a hybrid mask scenario in which half of the attention heads use a dense mask and the other half use a streaming mask. This pattern has also been shown to be accurate for LLM inference.
The graph above shows the performance of our kernel on this workload. For token-level streaming masks, we use 64 sink tokens and 256 local tokens. For block-level streaming masks, we use 1 sink block and 3 local blocks of 128 tokens each. Speedups were measured on an A100 GPU against dense FlashAttention2, with a head dimension of 128, 32 attention heads, and a batch size of 1.
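To give a feel for how much work these streaming settings skip, the sketch below counts the query-key pairs (or 128 x 128 tiles) each mask computes under causal attention. It is an idealized, FLOP-proportional count under an assumed rule (query `i` attends the first `sink` units and the `local` most recent units, itself included; partial tiles and the causal mask inside diagonal tiles are ignored). It does not model memory traffic or kernel overhead, so it is not a prediction of the measured speedups; the function names are hypothetical.

```python
def dense_causal_pairs(n: int) -> int:
    """Query-key pairs (or tiles) computed by causal dense attention over n units."""
    return n * (n + 1) // 2


def streaming_pairs(n: int, sink: int, local: int) -> int:
    """Causal streaming (assumed rule): unit i attends units j <= i with j < sink
    or j > i - local. That union always has min(i + 1, sink + local) members."""
    return sum(min(i + 1, sink + local) for i in range(n))


def hybrid_fraction(streaming_ratio: float, streaming_head_share: float = 0.5) -> float:
    """Average work relative to all-dense heads when a share of heads streams."""
    return (1 - streaming_head_share) + streaming_head_share * streaming_ratio


seqlen = 32768
token_ratio = streaming_pairs(seqlen, sink=64, local=256) / dense_causal_pairs(seqlen)
n_blocks = seqlen // 128
block_ratio = streaming_pairs(n_blocks, sink=1, local=3) / dense_causal_pairs(n_blocks)
print(f"token streaming: {token_ratio:.3f} of dense work, hybrid: {hybrid_fraction(token_ratio):.3f}")
print(f"block streaming: {block_ratio:.3f} of dense tiles, hybrid: {hybrid_fraction(block_ratio):.3f}")
```

Because the dense half of the heads still does full work in the hybrid setting, the idealized work can never fall below one half of dense, whatever the streaming configuration.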
Requirements:
- CUDA 11.6 and above.
- PyTorch 1.12 and above.
- Linux.
```bash
pip install packaging
pip install ninja
python setup.py install
```
Block Sparse Interface: block_sparse_attn/block_sparse_attn_interface.py
Block Sparse Attention currently supports:
- Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
- Head dimension 32, 64, 128.
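A configuration can be checked against the support list above before launching a kernel, as in the sketch below. `unsupported_reasons` and the `"fp16"`/`"bf16"` labels are hypothetical; with PyTorch, the `(major, minor)` compute capability can be read with `torch.cuda.get_device_capability()`.

```python
SUPPORTED_DTYPES = {"fp16", "bf16"}
SUPPORTED_HEAD_DIMS = {32, 64, 128}
AMPERE = (8, 0)  # compute capability of the first Ampere GPUs (e.g. A100)


def unsupported_reasons(dtype: str, head_dim: int, compute_capability: tuple) -> list:
    """List why a configuration falls outside the support list (empty if supported)."""
    reasons = []
    if dtype not in SUPPORTED_DTYPES:
        reasons.append(f"dtype {dtype!r} is not fp16 or bf16")
    if head_dim not in SUPPORTED_HEAD_DIMS:
        reasons.append(f"head dimension {head_dim} is not one of 32, 64, 128")
    # Ampere (8.x), Ada (8.9), and Hopper (9.0) all have compute capability >= 8.0.
    if dtype == "bf16" and tuple(compute_capability) < AMPERE:
        reasons.append("bf16 requires an Ampere, Ada, or Hopper GPU")
    return reasons


ok = unsupported_reasons("bf16", 128, (8, 0))   # A100: no reasons
bad = unsupported_reasons("bf16", 96, (7, 5))   # T4: head dim and bf16 both rejected
```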
To run the correctness tests:
```bash
pip install pytest
```
- For fwd only
  ```bash
  cd ./block_sparse_tests/fwd/test_correctness
  pytest full_test.py
  ```
- For fwd and bwd
  ```bash
  cd ./block_sparse_tests/fwd_bwd/test_correctness
  pytest full_test.py
  ```
To run the performance tests:
- For fwd only
  ```bash
  cd ./block_sparse_tests/fwd/test_performance/
  python token_streaming.py
  python blocksparse.py
  ```
- For fwd and bwd
  ```bash
  cd ./block_sparse_tests/fwd_bwd/test_performance/
  python block_streaming.py
  python blocksparse.py
  ```
- Junxian Guo: SJTU, MIT
- Haotian Tang: MIT
- Shang Yang: MIT
- Zhekai Zhang: MIT
- Zhijian Liu: Nvidia, MIT
- Song Han: Nvidia, MIT
- FlashAttention: the codebase we built upon. Thanks for their wonderful work. The design of block sparse attention in FlashAttention v1.0 was very inspiring.
- FlashAttention, FlashAttention-2, Big Bird, ETC: sources of the idea of block sparse attention and of how it can be implemented.
- StreamingLLM: source of the idea of streaming attention.
- Duo Attention, MInference 1.0: sources of the idea of hybrid masks.
```bibtex
@misc{guo2024blocksparse,
  author       = {Guo, Junxian and Tang, Haotian and Yang, Shang and Zhang, Zhekai and Liu, Zhijian and Han, Song},
  title        = {{Block Sparse Attention}},
  year         = {2024},
  publisher    = {GitHub},
  journal      = {GitHub repository},
  howpublished = {\url{https://github.com/mit-han-lab/Block-Sparse-Attention}}
}
```