본문 바로가기
  • 책상 밖 세상을 경험할 수 있는 Playground를 제공하고, 수동적 학습에서 창조의 삶으로의 전환을 위한 새로운 라이프 스타일을 제시합니다.
Computer Vision

[2025-1] 전연주 - GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

by YeonJuJeon 2025. 1. 31.

논문 링크: 2305.13245


1. Attention 개요

GQA를 이해하기 위해 Transformer 모델에서 사용되는 주요 Attention 기법을 정리한다.

Multi-Head Attention (MHA)

  • Transformer 모델의 핵심 구조로, Attention Is All You Need (2017) 논문에서 제안됨.
  • Query(Q), Key(K), Value(V)를 여러 개의 Head로 나누어 병렬 연산 수행.
  • 장점: 다양한 의미 표현을 학습할 수 있어 모델 성능 향상.
  • 단점: 메모리 사용량이 많고, 연산량이 크며, 병목 현상이 발생할 가능성이 있음.

KV Cached Attention

https://medium.com/@joaolages/kv-caching-explained-276520203249

  • AutoRegressive Inference에서 이전 token에 대한 Key-Value(KV) 연산을 저장하는 방식.
  • Without Cache:
    • 매번 과거 Query-Key 연산을 반복 수행해야 하므로 비효율적.
  • With Cache:
    • 이전에 계산된 Key를 저장해두고, 새로운 Query만 연산하여 속도를 최적화.
  • 단점: Sequence 길이가 길어질수록 KV Cache를 저장해야 하는 메모리 부담 증가.

Multi-Query Attention (MQA) - Fast Transformer Decoding (2019, Google)

  • Query는 Multi-Head로 유지하되, Key와 Value는 하나의 Head만 사용하는 방식.
  • 장점:
    • Inference 속도 대폭 향상.
    • 기존 MHA 대비 11배 빠른 처리량, 30% 더 낮은 Latency.
  • 단점:
    • MHA 대비 성능 저하 가능성.
    • 학습이 불안정할 위험이 있음.

2. GQA

GQA (Grouped-Query Attention)

  • MHA와 MQA의 중간 단계로 볼 수 있는 Attention 구조.
  • Query Head를 여러 개의 Group으로 나누고, 각 Group이 하나의 Key-Value Head를 공유.
  • GQA-G: Group 개수를 G로 설정한 GQA.
    • GQA-1 = MQA (Key-Value가 하나만 있음)
    • GQA-H = MHA (Head 개수와 동일한 Group 수)
  • 중간 개수의 Group을 설정하면, MQA 수준의 속도를 유지하면서도 MHA에 가까운 품질 달성.
  1.  

3. Method

Uptraining: Multi-Head 모델을 Multi-Query 모델로 변환

  1. Checkpoint 변환
    • 기존 MHA 체크포인트를 MQA 구조로 변환.
    • 기존 MHA 모델의 Key-Value Projection Matrix를 Mean Pooling하여 하나의 Key-Value Head로 변환.
    • 기존 Head 중 하나를 선택하거나, 새로 초기화하는 방식보다 Mean Pooling 방식이 더 높은 성능을 보임.
  2. 사전 학습 (Pre-training with α parameter)
    • 변환된 체크포인트를 원래 학습 방식과 동일하게 추가 학습.
    • 전체 학습 단계의 일부인 α(5%)만큼만 추가 학습하여 MQA 구조에 적응.

Grouped-Query Attention (GQA) 정리

  1. Query Head를 G개의 그룹으로 나누고, 각 그룹이 하나의 Key-Value Head를 공유.
  2. GQA의 그룹 수에 따라 성능과 속도 trade-off 가능.
  3. GQA-1 = MQA, GQA-H = MHA.
  4. GQA 체크포인트 변환:
    • 기존 MHA 체크포인트에서 각 그룹 내 Head를 Mean Pooling하여 새로운 Key-Value Head 생성.
  5. 성능 및 속도 trade-off 조절 가능:
    • MQA보다 높은 품질, MHA보다 빠른 속도 제공.
    • Memory bandwidth와 key-value cache 사용량이 모델 크기에 비례하여 감소.

4. Experiments

  • 데이터셋
    • 요약: CNN/Daily Mail, arXiv/PubMed, MediaSum, Multi-News.
    • 번역: WMT 2014 English-to-German.
    • 질의응답: TriviaQA.
  • 모델 구조
    • T5.1.1 아키텍처 기반 (JAX, Flax, Flaxformer 구현).
    • T5-Large, T5-XXL 모델 비교.
  • Uptraining 설정
    • 기존 T5 체크포인트를 초기화 후 α = 0.05 (5% 추가 학습).

Main Results

  1. Inference 속도 & 성능 비교
    • GQA는 MQA 수준의 속도를 유지하면서도, MHA에 가까운 성능을 유지.
    • MQA는 가장 빠르지만, 성능 저하 발생.
    • MHA는 가장 높은 성능을 유지하지만, inference 속도가 가장 느림.
  2. Checkpoint 변환 방식 비교
    • Mean Pooling 방식이 기존 Head 중 하나를 선택하는 방법보다 성능이 우수.
  3. Uptraining step 수 비교
    • 전체 학습 단계의 5%만 재학습해도 좋은 성능 유지.
  4. GQA 그룹 수 비교
    • GQA의 그룹 수를 조정하면 속도와 성능을 균형 있게 조절 가능.
    • G가 작을수록(MQA에 가까울수록) 속도 향상, G가 클수록(MHA에 가까울수록) 성능 향상.

5. Conclusion

  • GQA는 기존 MHA 모델을 Uptraining하여 MQA와 유사한 속도 향상과 MHA에 가까운 성능을 유지하는 방법을 제안.
  • MQA 대비 성능 저하를 완화하면서도, memory bandwidth와 key-value cache 크기를 줄이는 효과 제공.
  • 특히, 대형 모델에서는 inference 성능을 유지하면서도 메모리 절약이 가능하여 효율적인 trade-off를 제공.
  • 현재 LLaMA 2 등 최신 대형 언어 모델에서 표준적으로 적용되고 있음.

6. Key Takeaways

  • GQA는 MHA 대비 훨씬 빠른 속도를 제공하면서도, MQA 대비 성능 저하를 완화할 수 있는 효율적인 Attention 기법.
  • Inference 성능과 속도 간 최적의 Trade-off 제공.
  • 현재 대규모 Transformer 모델에서 필수적인 기술로 자리 잡음.