논문 링크: 2305.13245
1. Attention 개요
GQA를 이해하기 위해 Transformer 모델에서 사용되는 주요 Attention 기법을 정리한다.
Multi-Head Attention (MHA)
- Transformer 모델의 핵심 구조로, Attention Is All You Need (2017) 논문에서 제안됨.
- Query(Q), Key(K), Value(V)를 여러 개의 Head로 나누어 병렬 연산 수행.
- 장점: 다양한 의미 표현을 학습할 수 있어 모델 성능 향상.
- 단점: 메모리 사용량이 많고, 연산량이 크며, 병목 현상이 발생할 가능성이 있음.
KV Cached Attention
- AutoRegressive Inference에서 이전 token에 대한 Key-Value(KV) 연산을 저장하는 방식.
- Without Cache:
- 매번 과거 Query-Key 연산을 반복 수행해야 하므로 비효율적.
- With Cache:
- 이전에 계산된 Key를 저장해두고, 새로운 Query만 연산하여 속도를 최적화.
- 단점: Sequence 길이가 길어질수록 KV Cache를 저장해야 하는 메모리 부담 증가.
Multi-Query Attention (MQA) - Fast Transformer Decoding (2019, Google)
- Query는 Multi-Head로 유지하되, Key와 Value는 하나의 Head만 사용하는 방식.
- 장점:
- Inference 속도 대폭 향상.
- 기존 MHA 대비 11배 빠른 처리량, 30% 더 낮은 Latency.
- 단점:
- MHA 대비 성능 저하 가능성.
- 학습이 불안정할 위험이 있음.
2. GQA
GQA (Grouped-Query Attention)
- MHA와 MQA의 중간 단계로 볼 수 있는 Attention 구조.
- Query Head를 여러 개의 Group으로 나누고, 각 Group이 하나의 Key-Value Head를 공유.
- GQA-G: Group 개수를 G로 설정한 GQA.
- GQA-1 = MQA (Key-Value가 하나만 있음)
- GQA-H = MHA (Head 개수와 동일한 Group 수)
- 중간 개수의 Group을 설정하면, MQA 수준의 속도를 유지하면서도 MHA에 가까운 품질 달성.
3. Method
Uptraining: Multi-Head 모델을 Multi-Query 모델로 변환
- Checkpoint 변환
- 기존 MHA 체크포인트를 MQA 구조로 변환.
- 기존 MHA 모델의 Key-Value Projection Matrix를 Mean Pooling하여 하나의 Key-Value Head로 변환.
- 기존 Head 중 하나를 선택하거나, 새로 초기화하는 방식보다 Mean Pooling 방식이 더 높은 성능을 보임.
- 사전 학습 (Pre-training with α parameter)
- 변환된 체크포인트를 원래 학습 방식과 동일하게 추가 학습.
- 전체 학습 단계의 일부인 α(5%)만큼만 추가 학습하여 MQA 구조에 적응.
Grouped-Query Attention (GQA) 정리
- Query Head를 G개의 그룹으로 나누고, 각 그룹이 하나의 Key-Value Head를 공유.
- GQA의 그룹 수에 따라 성능과 속도 trade-off 가능.
- GQA-1 = MQA, GQA-H = MHA.
- GQA 체크포인트 변환:
- 기존 MHA 체크포인트에서 각 그룹 내 Head를 Mean Pooling하여 새로운 Key-Value Head 생성.
- 성능 및 속도 trade-off 조절 가능:
- MQA보다 높은 품질, MHA보다 빠른 속도 제공.
- Memory bandwidth와 key-value cache 사용량이 모델 크기에 비례하여 감소.
4. Experiments
- 데이터셋
- 요약: CNN/Daily Mail, arXiv/PubMed, MediaSum, Multi-News.
- 번역: WMT 2014 English-to-German.
- 질의응답: TriviaQA.
- 모델 구조
- T5.1.1 아키텍처 기반 (JAX, Flax, Flaxformer 구현).
- T5-Large, T5-XXL 모델 비교.
- Uptraining 설정
- 기존 T5 체크포인트를 초기화 후 α = 0.05 (5% 추가 학습).
Main Results
- Inference 속도 & 성능 비교
- GQA는 MQA 수준의 속도를 유지하면서도, MHA에 가까운 성능을 유지.
- MQA는 가장 빠르지만, 성능 저하 발생.
- MHA는 가장 높은 성능을 유지하지만, inference 속도가 가장 느림.
- Checkpoint 변환 방식 비교
- Mean Pooling 방식이 기존 Head 중 하나를 선택하는 방법보다 성능이 우수.
- Uptraining step 수 비교
- 전체 학습 단계의 5%만 재학습해도 좋은 성능 유지.
- GQA 그룹 수 비교
- GQA의 그룹 수를 조정하면 속도와 성능을 균형 있게 조절 가능.
- G가 작을수록(MQA에 가까울수록) 속도 향상, G가 클수록(MHA에 가까울수록) 성능 향상.
5. Conclusion
- GQA는 기존 MHA 모델을 Uptraining하여 MQA와 유사한 속도 향상과 MHA에 가까운 성능을 유지하는 방법을 제안.
- MQA 대비 성능 저하를 완화하면서도, memory bandwidth와 key-value cache 크기를 줄이는 효과 제공.
- 특히, 대형 모델에서는 inference 성능을 유지하면서도 메모리 절약이 가능하여 효율적인 trade-off를 제공.
- 현재 LLaMA 2 등 최신 대형 언어 모델에서 표준적으로 적용되고 있음.
6. Key Takeaways
- GQA는 MHA 대비 훨씬 빠른 속도를 제공하면서도, MQA 대비 성능 저하를 완화할 수 있는 효율적인 Attention 기법.
- Inference 성능과 속도 간 최적의 Trade-off 제공.
- 현재 대규모 Transformer 모델에서 필수적인 기술로 자리 잡음.