Data Selection for Language Models via Importance Resampling
Selecting a suitable pretraining dataset is crucial for both general-domain (e.g., GPT-3) and domain-specific (e.g., Codex) language models (LMs). We formalize this problem as selecting a subset of a large raw unlabeled dataset to match a desired target di
arxiv.org
1. Method
DSIR Framework
Large raw dataset에서 target data의 distribution과 일치하는 데이터 추출
Step 1. Feature Extraction
- Raw data와 target data를 같은 feature space에 mapping
- m(논문에서는 10,000)차원의 feature vector로 변환
Step 2. Feature Distribution
- Raw data와 target data의 distribution 추정
Step 3. Computing Importance Weight
- Raw 데이터 x에 대해, 특징 z=h(x) (h는 feature extractor)
- 가중치 계산
Step 4. Resampling
- 가중치를 정규화하여 distribution을 구함
- Gumbel top‑k를 사용하여 k개의 데이터를 추출
- 각 문서의 중요도 가중치에 Gumbel 노이즈를 더한 후, 상위 k개의 값을 한번에 선택
2. Experiment
Selecting Data for Domain-Specific Continued Pretraining
- Raw Data: The Pile
- Target Data: 4개의 도메인에서 8개의 데이터셋 사용
- Computer Science papers (ACL-ARC, Sci-ERC)
- Biomedicine (ChemProt, RCT)
- News (AGNews, HyperPartisan)
- Reviews (Helpfulness, IMDB)
- Model: RoBERTa-base
- 비교:
- Random selection: 단순 무작위 샘플링
- Heuristic classification: GPT‑3나 PaLM에서 생성한 데이터로 학습한 fasttext classifer를 이용해 target distribution에 가까운 문서를 선택
- Manual curation (DAPT): 전문가 큐레이션
- Top‑k heuristic classification: Heuristic classification에서 상위 k개 문서를 선택하는 방식
- DSIR
- 결과:
- 자동화된 알고리즘을 통해 동일한 품질의 데이터를 빠르고 효율적으로 선택할 수 있음
- Discriminative(fasttext 기반 classifier를 사용하여 target data일 확률로 가중치 추정)이 generative(n‑gram 기반으로 한 확률 모델)을 사용할 때보다 성능이 낮음
- Hashed n‑gram(uni-gram + bi-gram)을 사용할 때, uni-gram만 사용할 때보다 더 나은 downstream task 성능(F1 score)
- Pretraining data의 target distribution가 downstream task와 잘 맞을수록 성능이 높아짐
- 잘못된 target distribution을 선택하면 F1 score 하락
- 같은 도메인 내의 여러 task에 대해 pretraining data를 선택했을 때, 다른 도메인에서 선택한 데이터를 사용해서 pretrain 했을 때보다 downstream 성능이 더 좋음
KL Reduction on Hashed N-grams Predicts Downstream Performance
- 무작위로 선택한 데이터셋과 타깃 데이터셋 간의 KL divergence와, DSIR로 선택된 데이터셋과 타깃 데이터셋 간의 KL divergence의 차이(KL reduction)를 계산
- 실험 결과 Hashed n‑gram feature space에서 계산된 KL reduction 값이 downstream task의 평균 성능과 높은 상관관계(KL reduction 값이 높을수록, 해당 데이터를 활용해서 pretrain한 model이 downstream task에서 더 좋은 성능을 냄)를 가짐
Selecting Data for Domain-Specific Continued Pretraining
- Raw Data: The Pile
- Target Data: 위키피디아, 책
- Model: BERT-base
- 비교:
- Random selection
- Heuristic classification
- Top‑k heuristic classification
- DSIR
- Top-k DSIR
- 결과:
- DSIR이 선택한 텍스트는 다른 방법들에 비해서 Wikipedia나 책과 유사한 형식
- DSIR은 GLUE에서 다른 방법 대비해서 높은 성능 향상
- Top-k DSIR에 비해 DSIR이 더 높음
3. Limitation
- n-gram 모델을 말고도 다양한 임베딩 방법 시도의 부재
- 다양한 평가지표의 부재
- 최근 와서는 다양한 open source 모델이 공개되었기 때문에, RoBERTa 말고도 다른 모델에서도 좋은 성능을 내는지도 궁금함