논문 정리: FlashAttention
GPU가 워크로드를 처리하는 방법에 있어 I/O Aware한 최적화 알고리즘을 제안하여 처리 속도와 GPU 메모리 사용량 모두를 크게 개선한 성과를 낸 논문 FlashAttention에 대해 정리해 보고자 합니다. 굉장히 유명한 논문이며 현재는 AI워크로드에 필수적으로 널리 쓰이고 있는 기술이라 AI 관련 시스템 연구자는 꼭 읽어봐야 할 논문입니다.
Introduction
Transformer 구조는 시퀀스 길이가 길어지면 느리고 메모리 소모가 큽니다. Self-attention 레이어의 메모리 복잡도가 시퀀스 길이의 제곱에 비례해 증가하기 때문이죠. 기존의 Approximation attention method들은 모델 복잡도를 줄여서 compute complexity를 감소시키려는 시도를 했지만, 실질적인 wall-clock speedup을 달성해내진 못했습니다. 현대의 GPU들은 이미 computation 속도가 메모리 속도를 넘어섰고, transformer의 대부분 연산들은 memory access가 bottleneck이 되죠.
이 연구에서는 어텐션 알고리즘을 IO-aware하게 디자인해서, tiling 기법을 통해 GPU와 HBM간 메모리 읽기/쓰기 횟수를 줄이는 FlashAttention을 소개합니다. 주된 목표는 attention matrix의 HBM 왕복을 최대한 회피하는 것입니다. 이를 위해서는
-
Softmax 계산을 전체 input에 대한 접근 없이 계산하는 것
-
Backward pass에서 거대한 intermediate attention matrix를 (HBM에) 저장하지 않는 것
이 요구됩니다. 이 두 가지 조건을 충족하기 위해
-
Attention 연산의 구조를 재설계해서 입력을 여러 개의 block들로 나누고, 이 입력 블록들에 대해 여러 번의 pass를 거치게 해서 softmax reduction을 incremental하게 수행하는 것입니다. tiling 기법으로도 알려져 있습니다.
-
Forward pass에서 얻어진 softmax normalization factor를 저장해서, backward pass에서 attention을 on-chip에서 빠르게 재계산하도록 합니다. 이는 HBM으로부터 intermediate attention matrix를 읽어오는 표준적인 방법보다 더 빠릅니다.
이 연구는 또한 FlashAttention이 approximate attention algorithm을 설계할 때 참고하기에 매우 적합한 베이스 알고리즘이 될 수 있다고 합니다. 메모리 접근 오버헤드를 극복하는 특성 덕분이죠. 그 증명으로 이 FlashAttention을 block-sparse attention으로 확장하여 기존 approximate attention method보다 빠른 approximate attention 알고리즘을 제시합니다.
이 연구에서 FlashAttention을 평가한 결과 모델 트레이닝 속도를 향상시키고, 더 긴 context를 모델링함으로써 모델의 퀄리티도 향상시킨다고 합니다.
Background
중요한 것만 간단하게 정리했습니다. 자세한 내용은 논문을 참조하세요~
Hardware Performance
Performance Characteristics
연산과 메모리 접근의 비중에 따라 연산은 compute-bound 또는 memory-bound로 분류될 수 있습니다.
-
Compute-bound: 연산 시간이 산술 연산의 양에 따라 결정되고, 반면 HBM 접근 시간의 비중이 작은 연산들입니다. 전형적인 예는 큰 inner dimension을 가진 matrix multiplication, 다수의 채널을 가진 convolution이 있습니다.
-
Memory-bound: 메모리 접근량에 따라 시간이 결정되는 연산입니다. 다른 대부분의 연산들이 이에 속합니다: elementwise (예: activation, dropout), reduction (예: sum, softmax, batch norm, layer norm)
Kernel Fusion
Memory-bound 연산을 가속시키는 가장 전형적인 방법은 kernel fusion입니다. HBM으로부터 input을 한 번만 로드하고, 여러 개의 연산(여러 커널을 하나로 fusion)을 적용시키는 것이죠. 하지만 model training에서는 backward pass를 위해 중간값들이 HBM에 저장되어야 하고 단순한 kernel fusion의 효과를 경감시킵니다.
Standard Attention Implementation
주어진 input sequence들 $Q,K,V \in \Bbb{R} ^ {N \times d}$ ($N$은 시퀀스 길이, $d$는 head dimension)에 대해, attention output은 다음과 같이 계산을 합니다.
$S=QK^T \in \Bbb{R}^{N \times N}$, $P=softmax(S) \in \Bbb{R}^{N \times N}$, $O=PV \in \Bbb{R}^{N \times d}$
여기서 softmax는 row-wise로 적용됩니다.
표준 구현은 매트릭스 $S$와 $P$를 HBM에 materialize 하며, 이는 $O(N^2)$ 만큼 메모리를 먹습니다. 일반적으로 $N » d$ 입니다. 대부분의 연산이 memory-bound하기 때문에 많은 수의 메모리 접근이 실행 시간의 증가로 이어집니다. $S$에 적용된 masking이나 $P$에 적용된 dropout 등 다른 elementwise 연산들이 이 문제를 더 가중시킵니다. 표준 알고리즘은 다음과 같이 구현됩니다.
표준 Attention 구현
요구사항: $Q, K, V \in \Bbb{R} ^ {N \times d}$ 가 HBM에 저장되어 있음.
- HBM으로부터 $Q,K$ 블록을 로드하여 $S=QK^T$ 를 계산하고, $S$를 HBM에 씁니다.
- $S$를 다시 HBM에서 읽고, $P=softmax(S)$ 를 계산 후 $P$를 HBM에 씁니다.
- $P$와 $V$ 를 블록 단위로 HBM에서 읽고, $O=PV$ 계산 후, $O$를 다시 HBM에 씁니다.
- $O$를 리턴합니다.
여기서 표준 구현에서도 매트릭스를 블록 단위로 읽는다는 것이 FlashAttention의 고유 아이디어와 혼동될 수 있습니다만, SRAM의 크기가 작기 때문에 원래도 매트릭스를 SRAM에 대응시키기 위헤 block 단위로 쪼개는 것이 일반적입니다. 중대한 차이는, 표준 구현에서는 한 번의 연산이 끝날 때마다 그 결과를 HBM에 쓰는 것을 반복하지만 FlashAttention은 HBM에 쓰지 않고, 그대로 끝까지 SRAM 안에서 블록 단위의 연산을 수행한다는 것이죠. 다음 섹션에서 자세히 살펴봅시다.
FlashAttention: 알고리즘, 분석, 확장(Extensions)
Algorithm
앞서 언급했듯, tiling과 recomputation이라는 2개의 기법이 핵심이며, 메인 아이디어는 $Q,K,V$ 를 블록들로 나누고, 해당 블록들에 대해 attention output을 계산하는 것입니다. 각 블록의 output을 더하기 전 적절한 normalization factor로 스케일링함으로써 정확한 결과를 얻을 수 있습니다.
Tiling
수식은 생략하겠습니다.
원래 softmax는 $K$의 열들을 묶기 때문에 softmax를 decompose할 필요가 있습니다. 이를 위해서는 추가적인 통계값 2개 $(m(x), l(x))$ (softmax normalization statistics)를 관리하면 됩니다. 그러면 나누어진 block에 대한 softmax 값을 계산할 수 있으며, 마지막에 결과를 합칠 수 있습니다.
Recomputation
Backward pass에서는 $S, P \in \Bbb{R}^{N \times N}$ 매트릭스가 필요합니다. 그래야 $Q, K, V$에 대한 gradient를 계산할 수 있기 때문이죠. 이 매트릭스들을 직접 저장하는 대신, $O$와 softmax normalization statistics $m, l$을 저장함으로써 SRAM에 저장되어 있는 $Q, K, V$ 블록들로부터 $S$와 $P$를 쉽게 재계산해낼 수 있습니다.
Implementation Details: Kernel Fusion
Tiling 덕분에 단일 CUDA 커널에서 HBM 에서 입력 데이터 로드 -> 모든 연산 과정 수행(행렬곱, softmax, 선택적인 masking 및 dropout, 행렬곱) -> 결과를 HBM에 쓰기 라는 알고리즘의 구현이 가능합니다. 이는 HBM에 대한 반복적인 읽기/쓰기를 방지하죠.
알고리즘은 생략하겠습니다.
몇가지 짚고 넘어가자면, Block size를 $M/4d$로 하는 이유는 SRAM에 4개의 블록 $K_j$, $V_j$, $Q_i$, $O_i$ 가 동시에 존재해야 하고, 계산의 편의를 위해 $Q$ 블록 크기와 $K$ 블록 크기를 거의 같다가 가정했기 때문입니다.
Theorem 1. 이 FlashAttention 알고리즘은 $O = softmax(QK^T)V$를 $O(N^2d)$ FLOPs 만에 계산해내고 오직 $O(N)$ (아까 봤던 $m, l$)만의 추가 메모리를 필요로 합니다.
Analysis: IO Complexity of FlashAttention
이 섹션에서는 표준 attention과 비교해서, HBM 접근 수의 상당한 감소를 보여주고, lower bound를 제시함으로써 모든 SRAM 크기에 대해 어떠한 exact attention algorithm도 HBM 접근 횟수를 asymptotically하게 개선할 수 없음을 보입니다. 증명은 논문의 Appendix에 있습니다.
Theorem 2. $N$이 시퀀스 길이, $d$ 가 head dimension, $M$이 $d \le M \le Nd$ 를 만족하는 SRAM 크기라고 하자. 표준 attention 알고리즘은 $\Theta(Nd+N^2)$ 번의 HBM 접근을 요구하지만, FlashAttention은 오직 $\Theta(N^2d^2M^{-1})$ 번의 HBM 접근만 요구한다.
$d^2$의 값은 일반적으로 $M$보다 훨씬 작기 때문에, FlashAttention이 표준 구현보다 훨씬 HBM 접근 수를 적게 요구합니다. 위 식의 도출 과정은 생략하겠습니다.
SRAM의 범위가 저렇게 가정된 이유는 우선 최소 크기는 한 column(row) vector보다는 커야 SRAM이 최소한의 의미를 갖기 때문이고, 상한선은 입력 행렬 $Q,K,V \in \Bbb{R}^{N \times d}$ 전체의 크기에 해당하기 때문입니다.
다음은 Lower bound에 대한 설명입니다.
Proposition 3. (Lower bound) $N$이 시퀀스 길이, $d$ 가 head dimension, $M$이 $d \le M \le Nd$ 를 만족하는 SRAM 크기라고 하자. $O(N^2d^2M^{-1})$ 번의 HBM 접근으로 exact attention을 M의 모든 범위 $[d, Nd]$에 대해 계산해내는 알고리즘은 없다.
이 증명은 $M=\Theta(Nd)$에 대해 어떤 알고리즘이든 반드시 $\Omega(N^2d^2M^{-1})=\Omega(Nd)$ 번의 HBM 접근을 수행해야 한다는 사실에 의존합니다.
| Attention | Standard | FlashAttention |
|---|---|---|
| GFLOPs | 66.6 | 75.2 |
| HBM R/W (GB) | 35.3 | 4.4 |
| Runtime (ms) | 35.1 | 11.7 |
이 표를 보시면 FlashAttention이 더 많은 FLOP수를 갖고 있음에도 런타임은 훨씬 짧은 것을 확인할 수 있습니다. 그만큼 HBM R/W 수가 성능에 크게 영향을 미친다는 뜻입니다.
블록 크기를 키우면 그만큼 HBM을 왕복하는 데이터양(GB)도 줄어들지만($O, m, l$의 이동량이 줄어드니까), 그 대신 연산 등 다른 요소가 bottleneck이 되어 일정 수준 이상 runtime 감소가 이뤄지지 않습니다. 더욱이, SRAM의 크기 제한때문에 무한정 블록 크기를 늘릴 수도 없습니다.
Extension: Block-Sparse FlashAttention
이 섹션에서는 FlashAttention을 approximate attention 알고리즘인 block-sparse FlashAttention으로 확장합니다. 이 알고리즘은 FlashAttention보다 sparsity에 비례하여 IO 복잡도를 작게 가져갑니다.
이 알고리즘은 주어진 입력 $Q,K,V \in \Bbb{R}^{N \times d}$와 mask matrix $\tilde{M} \in {0,1}^{N \times N}$에 대해 다음을 계산하는 것입니다.
$S=QK^T \in \Bbb{R}^{N \times N}$, $P=softmax(S \odot \Bbb{1}_{\tilde{M}}) \in \Bbb{R}^{N \times N}$, $O=PV \in \Bbb{R}^{N \times d}$
여기서 $\tilde{M}{kl} = 1$일 때 $(S \odot \Bbb{1}{\tilde{M}}){kl}=S{kl}$ 이며 $\tilde{M}_{kl} = 0$ 일 때에는 $-\infty$입니다.
원래의 attention 계산과 비교해 볼까요?
$S=QK^T \in \Bbb{R}^{N \times N}$, $P=softmax(S) \in \Bbb{R}^{N \times N}$, $O=PV \in \Bbb{R}^{N \times d}$
softmax 계산식만 바뀌었다는 것을 확인할 수 있습니다.
FlashAttention을 적용하기 위해선, 이 식에서 $\tilde{M}$이 block 형태를 가져야 할 필요가 있습니다. 어떤 블록 크기 $B_r, B_c$, 모든 $k,l$에 대해, $\tilde{M}{k, l}={M}{ij}$ 이며 이때 어떤 $M \in {0,1}^{N/B_r \times N/B_c}$ 에 대해 $i=\lfloor k/B_r \rfloor, j=\lfloor l/B_c \rfloor$ 입니다.
미리 정의된 block sparsity mask $M \in {0, 1}^{N/B_r \times N/B_c}$가 주어지면 FlashAttention 알고리즘을 쉽게 적용하여 attention matrix의 nonzero block들만 계산할 수 있습니다. 즉 block-sparse FlashAttention은 zero block들을 스킵한다는 점만 빼면 FlashAttention의 알고리즘과 동일합니다.
IO complexity 분석은 생략하겠습니다.
Experiments
Transformer 모델을 트레이닝할 때 FlashAttention의 영향을 평가하였습니다.
실험 결과는 논문에서 확인하실 수 있습니다.
Limitations and Future Directions
Compiling to CUDA
이 논문에서 소개된 attention 알고리즘 구현을 위해서는 각각의 새로운 attention 구현에 대해 새로운 CUDA 커널 작성이 필요합니다. PyTorch보다 훨씬 더 low-level에서 알고리즘을 작성해야 하기 때문에 engineering effort가 상당히 많이 요구되죠. 다른 GPU 아키텍처로 옮기는 것도 불가능 할 수 있습니다. PyTorch같은 high-level에서 이러한 attention 알고리즘을 작성할 수 있고 CUDA로 컴파일할 수 있도록 지원할 방법이 필요합니다.
IO-Aware Deep Learning
Attention 레이어가 Transformer에서 가장 memory-intensive한 계산이지만 사실 모든 레이어가 HBM을 경유하기 때문에 다른 모듈들에도 IO-aware한 방법론이 적용되면 좋을 것입니다.
Multi-GPU IO-Aware Methods
여기서 제시된 FlashAttention은 단일 GPU에 최적화되어 있지만 attention 계산은 여러 GPU에 parallelizable하기 때문에 이에 대해 GPU 간 데이터 전송을 고려한 IO 분석도 따로 필요합니다.