Mamba: Linear-Time Sequence Modeling with Selective State Spaces (2023)
인용수: 2256 (25.02.23 기준)
논문 링크 : https://arxiv.org/pdf/2312.00752
https://blog.outta.ai/169
[2025-1] 김지원 - Efficiently Modeling Long Sequences with Structured State Spaces
논문 링크 Efficiently Modeling Long Sequences with Structured State Spaces특징 : ICRL 2022 Outstanding Paper, 인용 수 1578회 (2025-01-25 기준)코드: https://github.com/state-spaces/s4 GitHub - state-spaces/s4: Structured state space se
blog.outta.ai
Mamba 계열(Linear RNN 계열 or SSM 계열)의 모델의 기초가 되는 논문은 위 링크를 통해 간략히 이해할 수 있다.
1. 초록
SSM 계열의 모델은 content-based reasoning과 같은 태스크에서 약한 모습을 보인다.
이때 저자는 입력 값에 맞게 선택적으로 파라미터를 바꾸는 selective mechanism을 SSM 모델에 도입한다.
이어서 계산 효율적인 방식을 적용하기 위해 recurrent mode 내에서 병렬 연산 구조를 만들었다. (Scanning Method)
이러한 알고리즘(S6)을 어텐션이나 MLP block 없이 End-to-End로 만들었고 이를 Mamba 구조라고 한다.
결과적으로 Mamba는 Transformer와 동일한 크기에서 높은 계산 효율성과 성능을 보였다.
(참고로 기존 S4 모델에 selection mechanism과 scanning computation을 추가하였기 때문에 S6 알고리즘으로 불린다.)
(S4- Simplified Structured State Space Model)
2. 소개
기존 시퀀스 모델들은 단 하나의 백본 모델을 사용한다. 바로 Transformer이다.
하지만 Transformer 모델은 Context window의 길이가 제한적이며 Quadratic(제곱 형태)의 계산 복잡성을 가진다는 단점이 있다.
최근 Structured State Space Sequence Models (SSMs)가 새롭게 유망한 아키텍쳐로 등장했다.
(이때 Structured의 의미는 LTI 수식에서 A행렬의 HiPPO 행렬로의 초기화를 의미함)
이 유형의 모델들은 recurrence나 convolution 모드에서 매우 효율적으로 계산되며 장기 의존성 문제를 잘 다룬다.
하지만 이전 SSMs에 대한 연구들은 텍스트와 같은 이산적이고 정보 압축적인 데이터를 모델링할 때 덜 효율적이다.
따라서 저자는 다음 요소들은 기존 SSMs에서 추가하였다.
1) Selection Mechanism
- 입력 의존적인 방식으로 효율적으로 데이터를 선택하는 능력
- 관련 없는 데이터를 무시하고 관련 있는 데이터에 집중하는 것
- 트랜스포머의 경우 이러한 패턴 인식을 아주 잘함 (예시 - Induction Heads)
2) Hardware-aware Algorithm
- SSMs은 계산적 효율성을 위해 시간 불변성과 입력 불변성을 가질 수 밖에 없다.
- 이를 해결하기 위해 저자는 컨볼루션 대신 scan으로 hardware-aware algorithm를 구현하였고 따라서 입력과 시간에 의존적이면서 효율적으로 계산할 수 있게 되었다.
3) Architecture
- 저자는 이전 SSM 구조들과 Transformers의 MLP block을 혼합함으로써 모델 구조를 단순화시켰고 이는 simple & homogenous 구조인 Mamba로 이어진다.
즉 정리하면 이 모델은 3가지 장점이 있는데 selectivity로 인한 (1) High Quality 그리고 hardware-aware 알고리즘으로 인한 (2) Fast training and inference, 마지막으로 (3) Long context 에서의 성능 향상이다.
3. SSMs
SSM에 대해 다시 한번 복습해보자면 상태 공간 방정식은 일반적으로 수식 (1a)와 (1b)의 형태로 존재한다(연속적인 형태).
이를 여러가지 이산화 방법 (테일러, ZOH, Bilinear.. 등)을 통해 이산화하면 작은 시간 스텝 delta t에 대해 수식 (2a)와 (2b)로 이산화된다.
다양한 이산화 방법 중 ZOH로 한 결과는 수식 (4)와 같다.
이때 주목할 점은 이산화된 수식은 실제로 RNN의 게이트 메커니즘과 유사하다. $h_t=(1-\sigma(Linear(x_t))h_{t-1}+\sigma(Linear(x_t))x_t$
이제, 계산의 과정을 보면 첫 번째로, $(\Delta, A, B, C) \to (\bar A, \bar B, C)$로 변환하는 계산이 있고 이후에는 Linear Recurrence와 Global convolution이라는 두 가지 계산 방법이 있다.
일반적으로는 Global Convolution이 학습 시 병렬 계산을 할 수 있기 때문에 학습 시에 많이 사용되고 추론 시에는 Linear Recurrence가 효율적이기 때문에 이 방법을 사용한다.
(Global Convolution 수식은 3a, 3b 수식을 참고하면 된다.)
생각해보면 이 수식은 입력과 시간 순서에 따라 파라미터가 바뀌는 것이 아닌 계속 같은 가중치를 계산하고 있다.
이는 시간과 입력에 따라 파라미터가 달라지게 되면 효율적인 계산을 하기 어렵기 때문이다.
따라서 저자는 입력에 의해 Delta와 파라미터 B, C가 결정되도록 알고리즘을 수정했다.
(A는 왜 바꾸지 않을까? - 1) HiPPO 행렬로 초기화해야 하기 때문에 다른 행렬처럼 Sampling하면 안됨 2) Delta에 의해서 결국 이산화 과정에서 기존 S4 알고리즘과 달리 입력에 따른 변화가 생김)
이를 통해 저자의 첫 번째 Contribution인 selective mechanism을 기존 SSMs에 추가하였다.
저자의 두 번째 Contribution이었던 scan method의 경우, 기존 SSMs와 달리 selective mechanism으로 인해 항상 같은 행렬이 보장되지 않으므로 컨볼루션 연산을 적용할 수 없다.
이 경우에 Scan 방법을 사용하면 되지만 위 S6 알고리즘을 보면 이산화된 행렬 A, B가 (B, L, D, N) 차원으로 굉장히 메모리 소요가 크다는 것을 알 수 있다.
따라서 저자는 GPU의 HBM이 아닌 SRAM에서 이산화를 진행하도록 하여 메모리 소요를 FlashAttention을 실행한 Transformer와 동일하도록 낮출 수 있었다.
세 번째 Contribution인 Mamba Architecture에 대해, 저자는 H3라는 SSM 계열 모델의 가장 많이 쓰이는 아키텍쳐를 비교 대상으로 가져왔다.
H3는 선형 어텐션(SSM 계열 모델)과 MLP을 쌓아올린 형태이다.
저자는 GAU(Gated Attention Unit)에서 영감을 받아 이 둘을 쌓아 올리기 보단 하나의 모듈로 만들고자 했다.
이때 Mamba의 대부분의 연산은 Linear Projection에서 나오는 데 Linear Projection은 하이퍼파라미터인 expansion factor E에 의해 차원이 확장되고 축소된다.
즉 정리하면 Mamba는 H3와 달리 Gating Mechanism을 도입하여 모델의 동작을 단순화시켰으며 곱셈 연산에 비해 더 효율적인 계산이 진행된다.
4. Experiments
실험 결과 S6 알고리즘은 기존 알고리즘에 비해 월등한 성능을 보이며 같은 S6 알고리즘에 대해 Mamba라는 아키텍쳐가 H3보다 더 단순하게 만들었음에도 비슷한 성능을 보이므로 Mamba 아키텍쳐의 우수성을 보인다(Table 1).
놀라운 점은 입력에서 패턴을 찾는 Induction Heads 태스크에서 기존 학습 시퀀스보다 훨씬 긴 시퀀스에서도 일정한 성능을 보인다는 점이며(Table 2) 모델 크기에 따른 Perplexity 성능이 Transformer(GPT-3)를 능가하였고 Transformer++(Llama 등)과 비슷한 성능을 보였다(Table3).
즉 Mamba는 Attention을 활용하지 않은 아키텍쳐 중 Transformer++와 유사한 성능을 내는 첫 번째 아키텍쳐이다.
5. Downstream Tasks(Language, DNA, Audio)
자연어 처리 관련 Downstream task에서 동일 크기 Transformer 모델들보다 뛰어난 성능을 보였으며 자신보다 2배 더 큰 Transformer 모델에 대해서 비등한 성능을 보였다.
DNA 태스크에서는 Mamba가 Transformer++보다 뛰어난 성능을 보였으며 3배 더 많은 파라미터를 가진 Transformer++와 비슷한 성능을 보였다. 또한 매우 긴 파라미터에서도 일관적인 성능을 보였다.
시퀀스 분류 태스크에서는 시퀀스의 길이가 길어질 수록 압도적인 성능을 보였다.
오디오 분야에서도 작은 Mamba 모델이 훨씬 큰 SOTA 모델인 GAN 또는 Diffusion 기반 모델보다 더 높은 성능을 보였다(Table 4).
6. Ablation
모델의 성능은 S6라는 알고리즘에서 주로 오고 Mamba 아키텍쳐에서도 미세한 성능 향상을 보인다(Table 6).
또 S6 알고리즘 내에서 Delta에 대한 샘플링 방법이 가장 큰 성능 향상을 가져오는데 이는 Delta의 샘플링이 결과적으로 A, B, C 행렬에 모두 영향을 미칠 수 있기 때문일 것이다.
또한 HiPPO 행렬의 S4D 초기화 방법 중 허수 및 실수 방법들이 있는데 허수보다는 실수 초기화 방법이 대체적으로 좋은 모습을 보였다.
7. Personal Review
우선 언어 모델링을 비교하기에는 아직 그 파라미터의 규모가 너무 작고 Scanning 방법의 한계로 크게 활용하기 힘든 것 같다.
또, 시계열 예측은 잘하지만 분류에서는 약한 모습을 보이는 데 이에 대한 추가적인 연구가 필요해 보인다.