카테고리 없음

[2025-1] 박서형 - Distilling the Knowledge in a Neural Network

ejrwlfgksms skffkddl 2025. 2. 1. 12:47

https://arxiv.org/abs/1503.02531

 

Distilling the Knowledge in a Neural Network

A very simple way to improve the performance of almost any machine learning algorithm is to train many different models on the same data and then to average their predictions. Unfortunately, making predictions using a whole ensemble of models is cumbersome

arxiv.org

 

1. Introduction

machine learning 알고리즘의 성능을 향상시키는 일반적인 방법 중 하나는 동일한 데이터에 대해 여러 모델을 훈련시키고, 그 예측을 평균화하는 것이다. 그러나 이러한 ensemble 방법은 예측 시 많은 계산 자원을 소모한다는 문제가 있다. 그래서 본 논문에서는 Knowledge Distillation이라는 기법을 통해 복잡한 모델이나 앙상블 모델의 지식을 단일하고 간결한 모델로 압축하여, 성능을 유지하면서도 계산 효율성 높이고자 한다. 

 

 

2. Distillation 

Knowledge Distillatio은 크고 복잡한 신경망( Teacher Model )이 학습한 지식을 더 작은 신경망( Student Model ) 에게 전달하는 방법론이다. 이때 단순히 교사 모델의 예측값( hard target )만을 이용하는 것이 아니라 출력 확률 분포( soft target )을 활용하여 학생 모델이 더 풍부한 정보를 학습할 수 있도록 한다. 

 

1) teacher model T(x)를 학습시킨다.  -> 이때 cross entropy loss 사용

2) teacher model의 출력을 softmax 함수를 이용해 확률 분포로 변환

3) student model S(x)를 학습시킨다. 이때 두 가지 loss를 결합해서 사용한다.

- soft target loss : temperature T를 적용한 teacher model의 soft target과 student model의 taget 값 사이의 cross entropy loss

-hard target loss : 실제 label 기반 cross entropy loss

4) 두 손실을 결합하여 student model 학습

 

 

 

3. Temperature T

 

  • 이면 확률 분포가 부드러워짐(Soft Targets) → 클래스 간의 차이를 줄여 약한 클래스도 고려할 수 있음
  • T=1이면 일반적인 소프트맥스와 동일
  • T<1이면 원-핫 레이블과 가까워짐

Knowledge Distillation에서는 주로 T>1을 사용하여 teacher model이 다양한 class 관계를 학습하도록 한다.

 

 

 

 

4. logit matching

logit이 softmax를 통과해서 나온 값을 soft target이라 한다. 일반적인 knowledge distillation은 soft target과 hard target을 이용해 loss function을 정의한다면 knowledge ditillation의 특별한 케이스인 logit matching은 teacher model의 logit과 student model의 logit 사의 MSE를 loss function으로 정의한다. 

위의 식을 편미분하면 다음과 같다. 

이때 T가 충분히 크고  student model과 teacher model의 logit을 zero-mean 시킨다고 가정하면 다음의 식이 성립한다. 

 

위 식에 의하면 logit matching 방식은 student model과 teacher model의 logit 자체를 가까워지게 하는 방식이므로 일반적인 방식보다 좀 더 직접적으로 정보 손실 없이 학습할 수 있게 된다.