본문 바로가기
  • 책상 밖 세상을 경험할 수 있는 Playground를 제공하고, 수동적 학습에서 창조의 삶으로의 전환을 위한 새로운 라이프 스타일을 제시합니다.
Miscellaneous

[2025-1] 노하림 - Beyond Scalar Reward Model: Learning Generative Judge from Preference Data

by 리미61 2025. 3. 10.

https://arxiv.org/html/2410.03742v2

 

Beyond Scalar Reward Model: Learning Generative Judge from Preference Data

Beyond Scalar Reward Model: Learning Generative Judge from Preference Data Ziyi Ye1, Xiangsheng Li2, Qiuchi Li3, Qingyao Ai1, Yujia Zhou1, Wei Shen2, Dong Yan2, Yiqun Liu1 1Department of Computer Science and Technology, Tsinghua University 2Baichuan AI  

arxiv.org

Abstract

기존 방식에서는 preference feedback data로 학습하여 인코딩한다. 이때 LLM에 연결된 value head를 통해 단일 스칼라 점수를 출력하는 방식으로 평가된다. 하지만 스칼라 모델은 해석 가능성이 낮고 데이터셋의 편향에 취약하다는 문제가 있다. 본 논문에서는 LLM의 생성 능력을 활용해 위의 문제를 해결한다. 구체적으로, 사전 학습된 LLM에 프롬프트를 제공하여 긍정적 및 부정적 판단을 생성하도록 하고, 이를 각각 자연어 기반의 근거(rationale)로 뒷받침한다.

 

Introudction

일반적으로 스칼라 모델이 사용되며, LLM에 value head를 추가하여 단일 score로 응답을 평가한다. 하지만 (1) 해석 불가능 : 해당 답변이 더 나은지 설명 불가능 (2) 데이터 편향에 취약 : 학습 데이터의 편향을 그대로 반영함의 한계가 있다. 따라서 본 논문에서는 Con-J (Contrastive Judgment 기반 생성적 판별 모델)를 제안한다. LLM 자체가 직접 긍/부정적 판단을 생성하도록 유도하며, 이를 자연어 기반 근거와 제공한다. 

Con-J의 주요 단계는

  1. Judgment Sampling (판단 샘플링) : 사전 학습된 LLM에 질문 + 두 개의 답변을 주고 여러 개의 판단을 샘플링
  2. Judgment Filtering (판단 필터링) : 실제 선호 데이터를 사용하여 대조적 판단 쌍을 구성
  3. Con-J Training (Con-J 학습) : DPO 기법을 사용해 Con-J 학습. LLM이 선호도를 예측하면서 동시에 이유를 생성하도록 유도

으로 구성된다. 기존 방식은 외부 모델이나 고품질 데이터에 의존하지만 Con-J는 자체적으로 학습 데이털르 생성하여 추가적인 데이터나 보상 모델 없이 학습 가능하다. 

 

Improving Generative Judge by Training on Contrastive Judgments

SM(scalar model) 대신 LLM 자체가 선호 판단을 수행한다. 질문 \( q \) 와 두 개의 답변 \( a_1 \), \( a_2 \) 가 주어졌을 때 사전 안내문 preamble, 질문 $q$, 답변 $a_1$, $a_2$를 하나의 프롬프트 $p$로 구성한다. 이 사전 안내문(preamble) 은 LLM $\pi$ 가 판사 역할을 수행하도록 설명하는 지침을 포함한다.

LLM에게 instruction을 포함한 프롬프르 $p$를 제공한다. LLM이 자연어로 판단 \( j = \pi(p) \) 을 생성하며, 이 판단은 JSON 형식으로 구성된다.

 

  • "rationale" 키에는 답변에 대한 단계별 설명 및 검증 과정 이 포함
  • "better answer" 키에는 LLM이 결정한 선호 답변(이진 판단) 이 저장

 

Judgment Sampling

Repeated Sampling

동일한 입력 프롬프트를 여러 번 실행해서 다양한 판단을 얻는다. 단, 모델이 특정 답변만 지속적으로 선호하면 대조 판단을 만들 수 없다.

Hint-Driven Sampling

모델이 특장 답변을 선호하도록 힌트를 제공하여 다양한 판단을 생성한다.이를 통해 contrastive judgment pair를 생성 가능하다. 

 

Judgement Filtering

프롬프트 생성 및 판단 생성 (5~6행) 올바른 판단과 틀린 판단 필터링 (7행) Contrastive Judgment Pairs 저장 (8행) 힌트 기반 추가 샘플링 (9~11행)

반복 샘플링을 통해 긍/부정적 판단을 포함한 $M(p)$를 생성한다.

  • 긍정적 판단 ($j^+$): "better answer" 키의 값이 올바르게 판단된 경우
  • 부정적 판단 ($j^-$): 판단이 잘못되었거나 ($j^$), 모델이 명확한 선호도를 나타내지 않은 경우 ($j_n$)

대조적 판단 쌍 (contrastive judgment pairs) {(\( j^+, j^- \))}은 긍정적 판단 집합 \( M(p)^+ \) 와 부정적 판단 집합 \( M(p)^- \) 의 직접 곱(direct product)을 통해 구성된다. 반복 샘플링 횟수는 8회로 설정되며, 최적의 경우 4개의 긍정적 판단과 4개의 부정적 판단이 포함되어 최대 4쌍의 판단 쌍을 생성할 수 있다. 힌트 기반 샘플링 (hint-driven sampling) 에서는 LLM에게 하나는 올바른 힌트, 하나는 잘못된 힌트를 제공하여 하나의 대조적 판단 쌍을 추가로 생성한다.

Con-J Training

Con-J 모델은 대조 학습(contrastive learning)과 지도 학습(SFT)을 결합하여 학습된다. DPO 손실을 적용하고 SFT 손실과 결합하여 과적합을 방지한다. 

 

DPO 손실 함수

\[\ell_{DPO} = - \sum_{(p,j+,j-)} \log \sigma \left( \eta \log \frac{\pi(j+ | p)}{\pi_0(j+ | p)} - \eta \log \frac{\pi(j- | p)}{\pi_0(j- | p)} \right)\]

  • 여기서 $π0$ 는 사전 학습된 LLM (변경되지 않음)
  • π(j+ | p)는 Con-J가 정답(j+)을 선택할 확률
  • DPO는 정답과 오답을 명확히 구분하도록 학습

SFT 손실 함수

\[\ell_{SFT} = - \sum_{(p,j+)} \log \pi(j+ | p)\]

  • 올바른 판단(j+)을 모방하도록 지도 학습 적용

최종 손실 함수

\[\ell_{final} = \ell_{DPO} + \alpha \cdot \ell_{SFT}\]

  • DPO 중심으로 학습하되, SFT를 추가하여 안정성 유지

DPO training promotes distinguishing between answers

기존 판별 모델은 SFT 방식으로 학습된다. 하지만 SFT만으로는 충분하지 않아 DPO로 모델이 정확한 선호 판단을 학습하도록 유도한다. LLM은 올바른 판단을 단순히 모방하는 것이 아니라 판단에서 중요한 요소를 파악해야 한다. 예를 들어, 다음 두 문장은 같은 패턴을 따르지만 서로 반대되는 의미를 가진다.

  1. "답변 1에는 논리적 오류가 있으므로, 더 나은 답변은 2이다."
  2. "답변 2에는 논리적 오류가 있으므로, 더 나은 답변은 1이다."

하지만 SFT 기반 학습에서는 이러한 패턴을 그대로 학습할 가능성이 크다. 

Rationales bring robustness against bias

Con-J 모델은 논리적 근거도 함께 생성할 수 있다. 판단 $j$를 선호 결과 ($j_y$)와 논리적 근거 ($j_r$)로 분리하여 데이터 편향을 감소한다. 모델이 \( j_y \) 뿐만 아니라 \( j_r \) 도 예측하도록 학습하여 편향을 줄이면서 신뢰할 수 있는 판단을 수행한다. 

 

 

기존 모델(SM)의 한계

1. 판단 근거가 없음 → 단순 점수($r$) 비교 방식

2. 데이터 편향에 취약 → 특정 데이터 패턴을 쉽게 학습하여 왜곡 가능

3. 판단 정확도가 낮음 → 선호 판단이 애매할 경우 신뢰도 문제 발생

 

Con-J의 장점

1. 자연어 기반 판단 생성 → 단순 점수 대신 이유(rationale)를 포함한 판단 제공

2. 편향 완화 효과 → 판단($j$)을 판단 근거($j_r$) + 선호 결과($j_y$) 로 분리하여 데이터 편향 감소

3. 기존 LLM 구조 활용 → 새로운 분류 헤드를 추가하지 않아 더 안정적이고 강력한 정규화 효과

 

Generative judge resists bias with a better prior

Con-J는 LLM의 사전 학습된 지식을 활용하여 편향 저항 효과를 가진다. 기존 SM은 새롭게 추가된 분류 헤드가 편향을 쉽게 학습하는 반면 Con-J는 기존 LLM의 사전 훈련된 가중치(θ)를 유지한다. 

 

Experiments

Setup

데이터셋 (Datasets) 훈련 데이터: Creation (120K), Math (50K), Code (50K)
추가 데이터: Skywork-Reward-Preference-80K
테스트 데이터: Infinity-Preference, UltraFeedback, PKU-SafeRLHF, Reward-Bench
모델 (Models) 기본 모델: Qwen2-7B-Instruct
훈련 모델:   • Scalar Model (SM) – Pairwise & Pointwise   • Generative Judge (Con-J)
비교 모델: GPT-4o, Auto-J, Prometheus 2, Llama 3.1-8B, Llama 3.1-70B, Qwen2-7B, Qwen2.5-72B
훈련 환경 (Hyperparameters & Training) 학습률: 9e-6 (Cosine Scheduler, 3% Warmup)
배치 크기: SM – 128, Con-J – 24 
손실 함수: SFT Loss + DPO Loss (α = 1e-6)

 

Main Results

  • Con-J가 모든 태스크에서 SM보다 높은 성능을 보임
  • GPT-4o보다도 우수한 성능을 발휘
  • Text Creation에서 가장 큰 성능 차이를 보이며, 판단 능력 향상 효과 확인

Con-J 모델 변형 실험 (Ablation Study)

  • DPO 학습이 성능 향상에 중요한 역할을 함
  • Hint-driven 샘플링이 없을 경우 정확도가 큰 폭으로 하락
  • DPO + Hint 사용 시 최상의 성능을 보임

 

Contrastive Judgment 학습량에 따른 성능 변화

 

  • 5개 체크포인트 (2k, 4k, 8k, 16k, 50k)에서 실험 진행
  • Judgment Accuracy가 증가할수록 판단 근거(rationale)도 향상됨
  • 그러나 정확성이 높아질수록 판단 근거의 일관성은 감소 (Binary Preference 예측 능력은 증가하지만, rationale과의 균형이 맞지 않음)

 

데이터 편향 실험

 

  • Format(응답 형식) & Verbosity(응답 길이) 편향을 다르게 적용하여 실험
  • Adversarial (도전적 문제) 및 General (일반 문제) 테스트 셋에서 성능 비교
  • Con-J가 Scalar Model 및 판단 근거 없이 훈련된 Con-J보다 우수한 성능

 

 

 

 

  •