본문 바로가기
  • 책상 밖 세상을 경험할 수 있는 Playground를 제공하고, 수동적 학습에서 창조의 삶으로의 전환을 위한 새로운 라이프 스타일을 제시합니다.
카테고리 없음

[2025-1] 박서형 - Gradient Episodic Memory for Continual Learning

by ejrwlfgksms skffkddl 2025. 3. 8.

[1706.08840] Gradient Episodic Memory for Continual Learning

 

Gradient Episodic Memory for Continual Learning

One major obstacle towards AI is the poor ability of models to solve new problems quicker, and without forgetting previously acquired knowledge. To better understand this issue, we study the problem of continual learning, where the model observes, once and

arxiv.org

 

 

0. Abstract

 

AI는 새로운 과제를 수행하는 데 어려움을 보이는데 이걸 해결하기 위해 이전에 학습했던 지식을 잊지 않고 유지하도록 continual learning을 이용할 수 있다. 본 논문에서는 continual learning에서 사용할 수 있는 새로운 metrics와 새로운 모델인 Gradient Episodic Memory (GEM)을 제시한다. 

 

 

 

 

 

1. Introduction

 

1)

Supervised learning은 iid ( independently and identically distributed ) 데이터셋에서 feature vector x가 주어졌을 때 target vector y를 예측하는 함수 f를 학습한다. 이때 학습은 Empirical Risk Minimization (ERM)을 이용하여 손실함수 l을 최소화하는 방식으로 진행된다. 그리고 ERM은 데이터셋을 여러 번 반복하여 학습하는 것을 전제로 한다. 

 

 

그러나 이는 인간의 학습 방식과 중대한 차이가 있는데 우선 인간은 ordered sequence로 데이터를 관찰하며 같은 데이터를 두 번 이상 반복해서 보지 않고 한 번에 볼 수 있는 데이터 양도 제한적이다. 이러한 차이들로 인해 iid와 ERM을 동시에 적용하면 continual learning에서는 catastrophic forgetting 문제가 발생하게 된다. 이는 새로운 task를 학습할 때 이전에 학습한 내용을 잊어버리는 현상이다. 

 

 

 

2) continual learning

GEM 모델은 데이터를 하나씩 순차적으로 관찰하며 이를 continuum of data로 표현할 수 있다. 

이때 ti는 어떤 task t에 속하는지 표시하는 식별자로 task별로 데이터가 하나씩 들어오기 때문에 랜덤하게 샘플링되는 iid 가정을 따르지 않게 된다. 

 

 

 

 

 

 

2. A Framework for Continual Learning

 

1)

Continula learning에서는 전체 데이터셋으로 보면 iid를 따르지 않지만 task별로 보면 iid를 따르게 한다. 그리고 데이터를 학습할 때 현재의 task 뿐만 아니라 과거와 미래의 task에 대해서도 잘 수행하도록 학습하게 된다. 이때 task descriptor t가 풍부한 정보를 담고 있을수록 새로운 task를 zero shot으로 학습할 가능성이 높아지지만 본 논문에서는 제로샷보다는 Catastrophic Forgetting 문제 해결에 더 초점을 맞춘다. 

 

논문에서 제시하는 모델은 기존 방식과 다음과 같은 차이점을 가진다.

 

 

  • task 갯수가 작다. -> task 갯수가 크다.
  • 각 task 당 데이터 수가 많다. -> 각 task에서 주어지는 데이터 수가 작다.
  • 모델이 같은 데이터를 여러 번 반복 학습한다. -> 각 데이터를 한번만 학습한다.
  • 평가 지표로 전체 과업에서의 평균 정확도(accuracy)만을 사용한다. -> forgetting과 transfer learning 정도도 함께 평가한다. 

 

2) metrics

각 task에 대한 성능 평가 행렬을 R이라고 정의했을 때 다음 3개의 metrics를 정의할 수 있다.

  • Average Accuracy, ACC : 평균 정확도 
  • Backward Transfer, BWT : 이후에 학습한 task가 과거에 학습한 task에 미치는 영향 -> 부정적일수록 과거의 task에 대한 능력이 감소하는데 이걸 Catastrophic Forgetting이라고 한다.
  • Forward Transfer, FWT : 이전에 학습한 task가 미래의 task 학습에 도움을 주는 정도 -> 긍정적일수록 Zero-shot Learning이 가능하다고 본다. 

 

 

 

3. Gradient of Episodic Memory (GEM)

 

GEM에서는 continual learning에서 발생하는  Catastrophic Forgetting 문제를 해결하기 위해 Episodic Memory를 도입하여 과거의 task의 데이터 중 일부만 저장하여 모델을 업데이트 한다. 이때 총 메모리 용량 M이 주어졌고 Task 개수 를 알고 있다면 각 Task에 m=M/T 개의 샘플을 저장한다. 하지만 Task 개수를 모른다면, 새로운 Task가 등장할 때마다 각 Task의 샘플 개수를 점점 줄이는 방식을 사용하는데 이때 각 Task에서 가장 마지막에 등장한 개의 샘플을 저장하게 된다. 

 

 

GEM은 이전 task의 loss가 증가하지 않도록 하면서 현재 task를 학습하는 방식으로 작동한다. 이때 task k에서의 loss는 다음과 같이 정의된다. 

 

이때 이전 task의 loss들도 증가하지 않아야 하므로 다음을 만족해야 한다. 

 

이를 보장하기 위해 loss의 gradient 관계를 다음과 같이 변형할 수 있다.

 

위와 같은 조건들을 만족하면서 현재 손실을 최소화하기 위해 Quadratic Programming (QP) 문제를 정의할 수 있다. 

풀어야 할 문제
Quadratic programming
GEM QP

 

위의 식을 풀어서 이전 task의 성능을 유지하면서 현재의 task를 학습하는 최적의 gradient 업데이트 방향을 찾을 수 있다. 

위와 같은 방식의 gradient 업데이트 방식과 위에서 정의했던 평가방식들을 이용하여 최종적으로 GEM을 이용하여 학습하는 전체 알고리즘 구조는 다음과 같다. 

 

 

 

 

4. Conclusion

 

novelty  :

  • 연속 학습(Continual Learning) 시나리오를 공식화
  • 모델의 평가 기준으로 정확도(ACC)뿐만 아니라, 순전이(FWT)와 역전이(BWT) 개념을 정의하고 적용.
  • GEM (Gradient Episodic Memory)을 제안하여, 에피소드 메모리를 활용해 망각을 방지하고 긍정적인 역전이(Positive BWT)를 유도.
  • 실험 결과, GEM이 기존 최첨단(SoTA) 방법들과 비교해 경쟁력 있는 성능을 보임.

 

limitation:

 

Zero-Shot Learning, 고급 메모리 관리, 연산 최적화와 같은 부분에서 추가적인 연구가 필요함