논문 링크: 2212.09748
GitHub - facebookresearch/DiT: Official PyTorch Implementation of "Scalable Diffusion Models with Transformers"
Official PyTorch Implementation of "Scalable Diffusion Models with Transformers" - facebookresearch/DiT
github.com
Abstract
- 아이디어: 기존 Diffusion 모델은 U-Net을 backbone으로 사용했으나, 본 논문은 이를 Transformer로 교체한 Diffusion Transformer (DiT) 를 제안.
- 방법: Latent Diffusion 환경(LDM)에서 patch 단위 latent를 Transformer로 처리.
- 분석: 모델 복잡도(연산량)를 Gflops로 측정하여 scaling law 확인.
- 결과:
- Gflops ↑ → FID ↓ (scaling law 성립)
- DiT-XL/2 모델: ImageNet 256×256에서 FID 2.27 (SOTA)
- 512×512에서도 SOTA 달성
1. Introduction
- 지난 5년간 Transformer는 NLP, Vision 등 모든 영역에서 주류 모델로 자리 잡음.
- 하지만 이미지 생성(특히 Diffusion)에서는 여전히 U-Net이 표준 backbone으로 사용됨.
- 문제 제기: U-Net이 반드시 필요할까? Transformer로도 충분히 대체 가능할까?
- 기여:
- U-Net inductive bias가 없어도 Diffusion은 잘 동작함을 입증.
- DiT는 Transformer의 scaling 특성을 그대로 계승.
- Compute-efficiency에서도 경쟁 우위 확보.
- Bubble 크기: 모델 FLOPs
- 왼쪽: 모델 FLOPs ↑ → FID ↓
- 오른쪽: DiT-XL/2가 기존 ADM, LDM보다 효율적임
2. Related Work
- DDPMs: 최근 GAN보다 더 좋은 성능, 안정적 학습 가능.
- 개선 방향: classifier-free guidance, noise prediction, cascaded pipeline.
- U-Net: Ho et al.의 초기 구조에서 거의 변화 없이 사용됨.
- 문제: parameter count는 모델 복잡도를 제대로 설명하지 못함.
- 해결책: FLOPs 단위로 복잡도를 정의 → Transformer scaling 분석.
3. Method
3.1 Latent Diffusion Models
- 픽셀 공간 직접 학습은 너무 비용이 큼.
- LDM은 이미지를 VAE encoder로 latent z (예: 32×32×4) 로 압축 후 Diffusion 학습.
- Sampling 시: latent에서 샘플 생성 후 VAE decoder로 이미지 복원.
- 본 논문은 Stable Diffusion의 pre-trained VAE 사용.
- 입력 256×256×3 → latent 32×32×4
- Downsampling factor = 8
Diffusion 세팅
- 1000 steps
- Linear noise schedule
- ε-prediction (noise 예측)
- Classifier-free guidance 사용
3.2 Diffusion Transformer (DiT)
입력 처리 (Patchify)
- latent z ∈ ℝ^{I×I×C}
- p×p 크기로 잘라 T = (I/p)^2 개 토큰 생성
- 각 패치를 Linear projection → hidden dim d
- Positional embedding: 2D sine-cosine
Patch 크기 trade-off
- p ↓ → T ↑ (토큰 수 증가) → 성능↑ but 연산량 급증
- FLOPs는 최소 4배 증가 가능
Conditioning Mechanisms
- In-context conditioning
- timestep t, class label c를 embedding 후 토큰으로 추가
- 구현 간단, 연산 오버헤드 거의 없음
- 성능은 약간 떨어짐
- Cross-Attention block
- t, c를 별도 시퀀스로 두고 cross-attn 수행
- 성능은 괜찮지만 연산 오버헤드 큼 (약 15%)
- adaLN
- LayerNorm을 adaptive LayerNorm으로 교체
- t, c embedding에서 γ, β를 회귀해 토큰에 동일하게 적용
- compute-efficient
- adaLN-Zero
- adaLN에 residual scale α를 추가
- α는 0으로 초기화하여 블록 전체가 identity로 시작
- 안정적 학습 가능, 성능/효율 모두 최고
Decoder Head
- 마지막 DiT block 후 LayerNorm
- Linear projection으로 각 토큰 → p×p×2C
- (C는 latent channel 수, 여기서는 4)
- 2C는 noise와 variance를 동시에 예측
- Reshape → 공간 구조 복원
3.3 Model Variants
4. Experimental Setup
- Dataset: ImageNet 256×256, 512×512
- Optimizer: AdamW (lr=1e-4, wd=0)
- Batch size: 256
- Augmentation: horizontal flip only
- EMA decay: 0.9999
- Evaluation: FID-50K (50K samples, 250 DDPM steps)
- Compute: JAX 구현, TPU v3 pods에서 학습
5. Experiments
5.1 Conditioning Comparison
- 동일한 DiT-XL/2 모델에서 4가지 conditioning 비교
- adaLN-Zero가 압도적
- 학습 안정성↑
- FID↓
- 연산 효율↑
5.2 Scaling Experiments
- 모델 크기↑ (S → B → L → XL) → 성능 지속 개선
- Patch 크기↓ (토큰 수 증가) → 성능 개선, FLOPs 증가
- Scaling law 확인: FLOPs와 FID 간 강한 상관관계
- 작은 모델을 오래 돌리기보다 큰 모델을 적당히 학습하는 게 효율적
5.3 State-of-the-Art Results
- ImageNet 256×256
- DiT-XL/2 + classifier-free guidance
- FID 2.27 → 기존 LDM보다 큰 폭 개선
- ImageNet 512×512
- DiT-XL/2
- FID 3.04 → 기존 ADM보다 개선
6. Classifier-Free Guidance
$$\hat\varepsilon_\theta(x_t, c) = \varepsilon_\theta(x_t, ∅) + s \cdot (\varepsilon_\theta(x_t, c) - \varepsilon_\theta(x_t, ∅)), \quad s > 1$$
- ε(x_t, ∅): label dropout으로 학습한 null embedding
- s: guidance scale (크면 품질↑, 다양성↓ trade-off)
7. Implementation Details
- VAE: Stable Diffusion의 pre-trained VAE
- Diffusion 세팅: ADM과 동일 (1000 steps, linear schedule)
- Decoder head: p×p×2C 텐서로 noise와 variance 복원
- Hyperparameters: 전 모델 동일하게 적용
- Compute: TPU v3 pods, 대규모 학습
8. Additional Analysis
- 모든 메트릭(FID, IS, Precision/Recall 등)에서 scaling이 compute-efficiency를 보임
- Training loss도 모델 크기↑ → 더 빠른 감소
9. Conclusion
- DiT는 U-Net을 Transformer로 대체해도 latent diffusion에서 안정적이고 강력함을 입증
- adaLN-Zero가 핵심: 안정적 학습 + 높은 성능
- Scaling law 성립: 모델 FLOPs와 FID 간 강한 상관, 큰 모델이 compute 효율에서도 우위
→ DiT는 Diffusion의 미래를 Transformer 쪽으로 크게 열어준 연구