/ AI

Multi-head Attention, GQA 개념 정리

NPU를 연구하게 되면서 등한시했던 AI 공부를 하는 수밖에 없게 되었습니다.

ㅠ _ ㅠ

특히 LLM의 핵심인 Attention 및 Transformer는 학부 수업때도 어렵고 헷갈리는 개념이었죠… 이번 글에서는 세부 메커니즘인 Multi-head Attention과 GQA (Grouped Query Attention)을 정리해보고자 합니다.

Multi-head Attention

Attention에서 head가 여러 개 있는 걸 쉽게 비유하면, 주어진 문장 하나를 각각 다른 관점에서 해석하는 머리(head)가 여러 개 있는 것과 같습니다.

문법을 따지는 머리, 감정을 따지는 머리, 주어-동사 관계를 따지는 머리가 따로 있는 것이죠.

예를 들어 모델의 model dimension $d_{model}$이 4096이라고 합시다. 그러면 인풋 토큰(의 임베딩 벡터) $X$의 길이는 4096이 됩니다. 그리고 $W_Q$의 차원은 $4096 \times 4096$이 됩니다. (GQA의 출력 복원과 Transformer 모델의 Residual Connection을 고려하여 일반적으로 $W_Q$는 Square Matrix입니다.) 이걸 통째로 써서 $X$와 가중치 $W_Q$를 곱해서 $Q$가 계산이 됩니다. 그 결과, $Q$는 $X$와 길이가 같습니다. Multi-head Attention은 이 $1 \times 4096$ 크기의 벡터를 예를 들면 128차원씩 32조각으로 쪼개서 이걸 32개의 머리가 따로 처리하도록 하는 것입니다.

Decode 과정에서 $X$와 가중치 $W_Q, W_K, W_V$를 곱해서 실시간으로 생성되는 $Q$, $K$, $V$ 텐서의 차원을 $d_q$, $d_k$, $d_v$라고 할 때, 토큰 하나당 생성되는 $Q, K, V$는 각각 $1 \times d_q$, $1 \times d_k$, $1 \times d_v$ 크기의 벡터입니다. 여기서 1은 방금 들어온 토큰 1개를 의미합니다.

그리고 여기서 헤드는 각각 row vector에 해당하는 $Q, K, V$ 텐서를 헤드 개수 $h$만큼 등분(column-wise로 split)을 해서 만들어진 것입니다. Attention score 계산 과정에서 $Q$와 $K$ 헤드 벡터의 내적이 필요하므로, $Q$와 $K$ 헤드의 차원은 반드시 동일해야 합니다.

한편 여기서 토큰 하나당 생성되는 $K, V$ 벡터를 과거 생성한 토큰 $n$개에 대해 모두 이어 붙여서 만든 $K, V$ 텐서의 차원은 $n \times d_k$, $n \times d_v$가 됩니다. 이것이 바로 Attention 계산에 사용되는 $K, V$ 텐서이며, 이 텐서를 저장한 것이 바로 KV Cache 입니다. 참고로 $d_k$와 $d_v$는 일반적으로 같습니다. 수학적으로는 달라도 되나, 메모리 파편화 등을 방지하기 위해서입니다.

여기서 Attention 계산 식을 다시 살펴보고 갑시다.

\[softmax (\frac{QK^T} {\sqrt{d_k}}) V\]

여기서 Attention score는 $QK^T$ 의 값입니다.

$V$ 헤드의 차원은 수학적으로 달라도 되지만, 현실에서는 하드웨어 최적화 등을 통해 무조건 동일하게 맞춥니다. 따라서 통일된 헤드의 차원을 $d_h$라고 하면, Multi-head가 된 후에는 $Q, K, V$ 텐서가 각각 $1 \times h_q \times d_h$, $1 \times h_{kv} \times d_h$ 크기의 텐서로 구성됩니다. 여기서 $h_q$는 $Q$헤드의 개수, $h_{kv}$는 $K, V$ 헤드의 개수입니다.

또한 헤드 하나에 대해 Attention을 계산하면 텐서의 차원은 $1 \times d_h$ 가 됩니다.

이때 하드웨어가 각각의 헤드끼리 병렬로 Matmul을 수행하려면 헤드 축이 맨 앞으로 와야 합니다. 그래서 내부적으로 transpose를 수행하여 $K,V$는 다음과 같은 모양이 됩니다.

\[h_{kv} \times 1 \times d_h\]

자 이제 이런 모양을 가진 새롭게 생성된 K,V row vector을 기존의 KV Cache에 갖다 붙이면 KV Cache가 업데이트 되는 것입니다.

원래의 KV Cache는, 헤드 개수까지 고려했을 때, 다음과 같은 모양을 가지고 있습니다.

$1 \times h_{kv} \times n \times d_h$

처음에 붙는 1은 batch size 입니다.

여기에다가 방금 생성된 $K, V$를 갖다가 붙이면?

$1 \times h_{kv} \times (n+1) \times d_h$

이런 모양이 되겠죠.

사실 여기까지 다룬 ‘KV Cache’는 K Cache 또는 V Cache입니다. K Cache와 V Cache 따로따로 존재하는데 생긴게 똑같으니 KV Cache라고 퉁쳐서 부르며, 하드웨어적으로는 실제로 두개를 붙여서 다룹니다.

따라서 K Cache와 V Cache를 묶은 진짜 KV Cache는 다음과 같은 모양이 됩니다.

$2 \times 1 \times h_{kv} \times (n+1) \times d_h$

Original Transformer vs GQA (Grouped Query Attention)

Attention head에서는 Q헤드와 K, V헤드가 있습니다. 원래의 transformer 구현은 K, V헤드를 공유한다는 개념이 없고, 각각의 Q헤드가 서로 다른 K, V헤드랑 일대일대응이 됩니다.

하지만 GQA는 여러 개의 Q헤드가 하나의 K, V헤드를 공유하는 방식입니다. Q헤드와 각각 K,V 헤드가 n:1 대응이 되며, attention score 계산 시 각각의 K, V헤드는 Q헤드의 개수만큼 재사용이 되겠죠.

예를 들어 Q헤드 개수가 32인 모델에서 8개의 Q헤드가 1개의 K, V헤드를 공유한다고 하면 K, V헤드 개수는 4개만 필요하게 되어 필요한 KV Cache 크기가 8배로 줄어듭니다.

그렇다면 K, V 헤드 개수를 어떻게 줄일 수 있을까요? 답은 $K, V$를 만들 때 사용되는 weight인 $W_K$, $W_V$의 차원을 줄이는 것입니다. $W_Q$는 $4096 \times 4096$ 크기로 만들더라도, $W_K, W_V$는 $4096 \times 512$로 설계하는 것이죠. 이러면 결과로 나오는 $K, V$ row vector의 크기가 줄어드니, 똑같은 head dimension으로 쪼개더라도 $Q$ 헤드보다 더 적은 개수의 $K,V$ 헤드를 얻게 되죠.

flowchart TD
    X["Input X (1 × 4096)"]
    
    subgraph Query Process
        WQ["W_Q (4096 × 4096)"]
        Q["Q Vector (1 × 4096)"]
        QHeads["32 Q Heads<br>(32 × 128)"]
    end
    
    subgraph Key / Value Process
        WKV["W_K, W_V (4096 × 512)"]
        KV["K, V Vector (1 × 512)"]
        KVHeads["4 K, V Heads<br>(4 × 128)"]
    end

    X -->|"×"| WQ -->|"="| Q
    X -->|"×"| WKV -->|"="| KV
    
    Q -->|"Split (÷ 128)"| QHeads
    KV -->|"Split (÷ 128)"| KVHeads
    
    QHeads -.->|"Grouped Query Attention<br>(8 Q Heads : 1 K,V Head)"| KVHeads
    
    style X fill:#f9f9f9,stroke:#333,stroke-width:2px
    style QHeads fill:#e1f5fe,stroke:#01579b,stroke-width:2px
    style KVHeads fill:#e8f5e9,stroke:#1b5e20,stroke-width:2px

즉 현재의 정보 ($Q$) 는 크게 하고, 비슷한 문맥을 다루는 과거의 정보($K, V$)는 작게 압축하여 메모리 효율성을 높이는 기술이 GQA라고 할 수 있겠습니다.