카테고리 없음

[26-1] 김효민 - GQA: Training Generalized Multi-Query Transformer Models fromMulti-Head Checkpoints

ohne-reue 2026. 2. 28. 11:59

[Paper]

GQA : https://arxiv.org/abs/2305.13245

 

GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

Multi-query attention (MQA), which only uses a single key-value head, drastically speeds up decoder inference. However, MQA can lead to quality degradation, and moreover it may not be desirable to train a separate model just for faster inference. We (1) pr

arxiv.org

 

MQA : https://arxiv.org/abs/1911.02150

 

Fast Transformer Decoding: One Write-Head is All You Need

Multi-head attention layers, as used in the Transformer neural sequence model, are a powerful alternative to RNNs for moving information across and between sequences. While training these layers is generally fast and simple, due to parallelizability across

arxiv.org

 

 

 

오늘 읽어볼 논문에서는 Muti-Head Attention(MHA)를 추론 속도 측면에서 개선한 Grouped-Query Attention(GQA) 구조를 제안하고 있다. 논문의 흐름을 파악하기 전에 Muti-Head Attention(MHA), Muti-Query Attention(MQA), Grouped-Query Attention(GQA) 구조를 각각 짚어보도록 하자.

 

왼쪽부터 MHA, GQA, MQA의 구조

 

 

MHA는 우리가 흔히 알고 있는 Head 별로 각각의 Q,K,V를 사용하는 방식이다. h개의 헤드가 있다면 h개의 Q와, h개의 K와, h개의 V가 존재한다. 

 

MQA는 전체 Head에 대해 같은 K와 V를 attention 계산에 사용하는 것이다. 여기서 같은 K와 V란 Head 마다 같은 K와 V의 projection matrix를 사용했음을 의미한다.

 

GQA는 MHA과 MQA의 중간지점이라고 볼 수 있다. Head를 그룹지어서 같은 그룹 내의 Head끼리만 같은 K와 V를 사용하는 것이 GQA 구조이다. GQA-G (eg. GQA-8)은 G개의 그룹을 갖는 Grouped-Query Attention 구조이다. 이때, GQA-1은 MQA와 동일하며, GQA-h (헤드 수=그룹 수)는 MHA와 동일하다고 볼 수 있다.

 

 

 

1. Introduction

Transformer 모델의 autoregressive decoder의 attention 연산 과정에서는 KV-cache 값을 메모리에서 읽어오는 데에서 그 속도가 느려 병목 현상이 발생한다.

 

KV-cache의 작동방식 (이미지 출처 : https://medium.com/@joaolages/kv-caching-explained-276520203249)

※ KV-cache란?
KV-cache란 토큰 생성 시 계산되는 Key&Value 텐서를 메모리에 저장해 두었다가, 다음 토큰 생성 시에 불러와서 재사용하는 방법으로 이전 토큰의 Key&Value 텐서가 매 step마다 다시 계산되지 않게 하여 연산량을 줄이는 방법이다. KV-cahce를 이용해서 전반적인 연산량을 줄인다고 하더라도, Key와 Value의 값을 메모리에서 계속 읽어와야 하기 때문에 지연이 발생한다. 

 

GQA보다 앞서 등장한 MQA 구조를 활용한 모델에서는 Head마다 같은 Key와 Value를 사용해서  KV-cache로 인한 지연을 줄였다. 그러나, MQA에서는 성능 저하 및 훈련 불안정성의 문제가 있었다. 이는 Head마다 서로다른 Key와 Value를 사용하는 MHA와 속도 및 성능 관점에서 trade-off 관계에 있음을 시사한다.

 

 

 

2. Method

본 논문에서는 MQA를 두 가지 방법으로 개선하여 MQA에 비해서는 성능 품질이 좋고, MHA에 비해서는 추론 속도가 빠른 구조 GQA를 제안하였다.

 

1) Uptraining

첫 번째는 "Uptraining" 방식이다. 기존 MQA 구조를 제안한 논문에서는 모든 파라미터를 initialize하고 학습을 시작했다. 본 논문에서는 MHA의 checkpoint를 활용하여 MQA 및 GQA 구조의 학습을 이어나갈 수 있는 방법을 제안했다.

 

 

먼저, 기존 MHA에서 사용했던 여러개의 K와 V projection matrix들을 각각 Mean pooling하여 하나의 K와 V의 projection matrix로 만든다. 이렇게 만들어진 K와 V를 모든 Head에 적용하여 기존 학습 방식대로 α 비율만큼의 추가 학습을 진행한다 (eg. α=0.05면 기존 학습의 5%만큼 추가로 학습을 진행한다). GQA에서는 mean pooling이 그룹별로 진행되어 각 그룹별로 K와 V가 하나씩 남는다고 생각하면 된다.

 

K와 V의 projection matrix를 하나로 만드는 과정에 따른 성능

 

이 실험은 MQA에서 사용할 하나의 K와 V를 만드는 방식에 대한 ablation 실험 결과이다. Mean pooling 방식을 사용하면 기존 MQA를 제안했던 논문에서 사용한 Random한 방식보다 성능이 더 잘 나오는 것을 확인할 수가 있다. 또한 MHA에서 학습된 여러 헤드 중에 첫 번째 헤드의 K와 V를 골라서 사용하는 것보다 전체 헤드의 K와 V를 평균내어 사용하는 것의 성능이 더 좋음을 확인할 수 있다. 이 실험 결과는 MHA가 학습을 통해 가지게 된 정보량을 잃는 정도가 Random > First > Mean 순으로 크고, MHA의 정보량을 보존할수록 높은 성능을 보임을 시사한다.

 

 

2) GQA

두 번째는 "Grouped-Query Attention" 구조를 도입해서 MQA를 개선시켰다. 기존 MQA 논문에서는 전체 헤드가 모두 같은 K와 V를 사용했지만, GQA에서는 헤드를 그룹지어서 각 그룹별로 같은 K와 V를 사용하도록 했다. 

 

왼쪽부터 MHA, GQA, MQA의 구조

 

 

앞서 Fig.4에서 우리는 MHA의 정보를 어느 정도 보존하는 것이 유리함을 알 수 있었다. 이러한 맥락에서 MHA를 MQA 방식으로 Uptraining하는 것보다는 GQA 방식으로 Uptraining하는 것이 낫다고 볼 수 있다.

 

예를 들어 MHA → MQA로 전환할 때에는 H → 1 로 표현력이 감소한다.

그러나 MHA → GQA-G로 전환할 때에는 H → H/G 로 표현력이 감소한다.

 

즉, GQA 방식을 사용할 때 기존 MHA의 정보를 더 잘 유지할 수 있다고 해석할 수 있다. 추가로, 모델이 커질 때에 보통 Head의 수를 증가시키는 경우가 많은데, 그럴수록 H → 1가 더 극단적인 상황이 되기 때문에 모델이 클수록 MQA 보다는 GQA를 적용하는 것이 더 낫다.

 

다시말해 GQA는 MHA보다는 추론속도가 빠르고 MQA보다는 성능이 좋은, 성능과 속도의 trade-off에서 적절한 위치를 차지하는 모델이다.

                                                           [성능]   MHA   ↔   GQA   ↔   MQA   [속도]

 

 

 

3. Experiments

실험에서는 multi-head attention을 사용하는 T5 Large 와 T5 XXL와, T5 XXL을 바탕으로 Uptraining된 MQA, GQA 버전을 비교한다. 학습 조건은 기존 T5 학습 조건과 동일하며, MQA 및 GQA는 decoder의 MSA와 CA에만 적용하고 autoregressive하지 않는 encoder의 SA에는 적용하지 않는다.


CNN, arXiv, PubMed, MediaSum, Multi-News 요약 데이터셋과 WMT 2014 영어-독일어 번역 데이터셋, TriviaQA 질의응답 데이터셋 등이 fine-tuning 및 성능 평가에 활용되었다.

 

Fig.3의 실험 결과를 살펴보면, Time per sample(=요약 task로 따지자면 하나의 input을 요약해서 output을 만들어내는 과정. 추론 속도 정도로 해석할 수 있다.)의 시간이 MHA-XXL 모델 대비 GQA-8-XXL과 MQA-XXL에서 크게 감소한 것을 확인할 수 있다. 또한 GQA-8-XXL은 빠른 추론 속도에도 MHA-XXL 모델에 필적할만한 성능을 유지하고 있음을 확인할 수 있다. 같은 추론 속도를 가지는 MHA-Large 모델까지 함께 비교해보자면, MHA-Large에서 모델의 규모를 키워 MHA-XXL과 같이 더 성능 좋은 모델을 만들었을 때에는 추론 속도를 어느 정도 내려놓아야 했었는데, GQA는 성능과 추론 속도를 함께 잡았다고 볼 수 있다. MQA-XXL에 비해서도 GQA-8-XXL에서 더 성능이 좋음을 확인할 수 있다.

 

실험 결과 그래프

 

실험 결과 표

 

위 표는 그래프의 수치를 담고 있으며 데이터셋 별 성능도 함께 보여주고 있다. 보면 GQA-8-XXL이 MHA-XXL과 정말 근소한 성능 차이를 보이고 있으며 특정 데이터셋에서는 더 나은 성능을 보일 때도 있음을 확인할 수 있다.

 

+) Ablation Study

논문에서는 Uptraining 비율과 Group 수에 대한 실험도 추가로 진행했다.

Uptraining 비율에 대한 실험

 

Fig.5의 실험 결과를 보면 α=0일 때 즉 구조만 변환했을 때는 GQA가 MQA보다 더 품질을 잘 유지함을 확인할 수 있었다. 또한 α=0.05(구조 변환 후 5% 추가 학습 진행)일 때에는 GQA와 MQA에서 모두 성능이 크게 상승했으며 특히 GQA는 거의 MHA와 비슷한 수준의 성능을 보임을 확인할 수 있었다. 논문에서는 α=0.1일 때에는 GQA와 MQA에서 모두 성능이 큰 변화가 없으므로 5% 정도의 α로 Uptraining을 진행하는 것이 유용할 것이라고 말한다.

 

group 수에 대한 실험

 

Fig.6는 GQA-G에서 G 즉 group의 개수를 어느 정도로 설정하는 것이 좋을지에 대한 실험이다. 보면 G=8정도까지 추론속도가 유의미하게 감소하므로, 논문에서는 G=8 정도가 적절하다고 보았다. (앞선 비교 실험들도 GQA-8을 기준으로 진행되었다.)

 

 

4. 의의 및 한계

MHA보다는 빠른 추론 속도를 가지며 MQA보다는 성능이 좋은 구조를 제안했다는 데에 의의가 있다. 다만 논문에서는 평가 지표로 ROUGE를 사용했다는 점에서 성능 측면의 trade-off가 제대로 측정되었는지에 대해 확신하기 어렵다고 말하고 있다. 또한, MQA 방식에서는 Fig.4에서처럼 projection parameter를 초기화한 버전과 mean pooling한 버전을 비교했지만 컴퓨팅 자원의 한계로 GQA에서는 진행하지 못했음을 밝히고 있다. 또한 최근의 디코더 기반 모델이 아닌 ENC-DEC 기반 모델인 T5 모델에 대해서만 실험을 진행한 한계가 있음을 언급하고 있지만, 이에 대해서는 오히려 DEC 기반이기 때문에 GQA가 가지는 효과가 더 클 것이라고 말하고 있다.