Computer Vision

[2025-2] 전연주 - DiT: Scalable Diffusion Models with Transformers

YeonJuJeon 2025. 8. 16. 09:12

논문 링크: 2212.09748

깃허브 링크: facebookresearch/DiT: Official PyTorch Implementation of "Scalable Diffusion Models with Transformers"

 

GitHub - facebookresearch/DiT: Official PyTorch Implementation of "Scalable Diffusion Models with Transformers"

Official PyTorch Implementation of "Scalable Diffusion Models with Transformers" - facebookresearch/DiT

github.com


Abstract

  • 아이디어: 기존 Diffusion 모델은 U-Net을 backbone으로 사용했으나, 본 논문은 이를 Transformer로 교체한 Diffusion Transformer (DiT) 를 제안.
  • 방법: Latent Diffusion 환경(LDM)에서 patch 단위 latent를 Transformer로 처리.
  • 분석: 모델 복잡도(연산량)를 Gflops로 측정하여 scaling law 확인.
  • 결과:
    • Gflops ↑ → FID ↓ (scaling law 성립)
    • DiT-XL/2 모델: ImageNet 256×256에서 FID 2.27 (SOTA)
    • 512×512에서도 SOTA 달성

DiT로 생성한 샘플 이미지 (ImageNet 512, 256)


1. Introduction

  • 지난 5년간 Transformer는 NLP, Vision 등 모든 영역에서 주류 모델로 자리 잡음.
  • 하지만 이미지 생성(특히 Diffusion)에서는 여전히 U-Net이 표준 backbone으로 사용됨.
  • 문제 제기: U-Net이 반드시 필요할까? Transformer로도 충분히 대체 가능할까?
  • 기여:
    • U-Net inductive bias가 없어도 Diffusion은 잘 동작함을 입증.
    • DiT는 Transformer의 scaling 특성을 그대로 계승.
    • Compute-efficiency에서도 경쟁 우위 확보.

FID vs Gflops 관계

  • Bubble 크기: 모델 FLOPs
  • 왼쪽: 모델 FLOPs ↑ → FID ↓
  • 오른쪽: DiT-XL/2가 기존 ADM, LDM보다 효율적임

2. Related Work

 

  • DDPMs: 최근 GAN보다 더 좋은 성능, 안정적 학습 가능.
  • 개선 방향: classifier-free guidance, noise prediction, cascaded pipeline.
  • U-Net: Ho et al.의 초기 구조에서 거의 변화 없이 사용됨.
  • 문제: parameter count는 모델 복잡도를 제대로 설명하지 못함.
  • 해결책: FLOPs 단위로 복잡도를 정의 → Transformer scaling 분석.

 


3. Method

3.1 Latent Diffusion Models

  • 픽셀 공간 직접 학습은 너무 비용이 큼.
  • LDM은 이미지를 VAE encoder로 latent z (예: 32×32×4) 로 압축 후 Diffusion 학습.
  • Sampling 시: latent에서 샘플 생성 후 VAE decoder로 이미지 복원.
  • 본 논문은 Stable Diffusion의 pre-trained VAE 사용.
    • 입력 256×256×3 → latent 32×32×4
    • Downsampling factor = 8

Diffusion 세팅

  • 1000 steps
  • Linear noise schedule
  • ε-prediction (noise 예측)
  • Classifier-free guidance 사용

3.2 Diffusion Transformer (DiT)

입력 처리 (Patchify)

z → patchify → tokens → DiT blocks

  • latent z ∈ ℝ^{I×I×C}
  • p×p 크기로 잘라 T = (I/p)^2 개 토큰 생성
  • 각 패치를 Linear projection → hidden dim d
  • Positional embedding: 2D sine-cosine

Patch 크기 trade-off

  • p ↓ → T ↑ (토큰 수 증가) → 성능↑ but 연산량 급증
  • FLOPs는 최소 4배 증가 가능

Conditioning Mechanisms

  1. In-context conditioning
    • timestep t, class label c를 embedding 후 토큰으로 추가
    • 구현 간단, 연산 오버헤드 거의 없음
    • 성능은 약간 떨어짐
  2. Cross-Attention block
    • t, c를 별도 시퀀스로 두고 cross-attn 수행
    • 성능은 괜찮지만 연산 오버헤드 큼 (약 15%)
  3. adaLN
    • LayerNorm을 adaptive LayerNorm으로 교체
    • t, c embedding에서 γ, β를 회귀해 토큰에 동일하게 적용
    • compute-efficient
  4. adaLN-Zero
    • adaLN에 residual scale α를 추가
    • α는 0으로 초기화하여 블록 전체가 identity로 시작
    • 안정적 학습 가능, 성능/효율 모두 최고

Decoder Head

  • 마지막 DiT block 후 LayerNorm
  • Linear projection으로 각 토큰 → p×p×2C
    • (C는 latent channel 수, 여기서는 4)
    • 2C는 noise와 variance를 동시에 예측
  • Reshape → 공간 구조 복원

3.3 Model Variants


4. Experimental Setup

  • Dataset: ImageNet 256×256, 512×512
  • Optimizer: AdamW (lr=1e-4, wd=0)
  • Batch size: 256
  • Augmentation: horizontal flip only
  • EMA decay: 0.9999
  • Evaluation: FID-50K (50K samples, 250 DDPM steps)
  • Compute: JAX 구현, TPU v3 pods에서 학습

5. Experiments

5.1 Conditioning Comparison

  • 동일한 DiT-XL/2 모델에서 4가지 conditioning 비교
  • adaLN-Zero가 압도적
    • 학습 안정성↑
    • FID↓
    • 연산 효율↑

FID-50K 곡선 비교: adaLN-Zero > adaLN > Cross-Attn > In-context

5.2 Scaling Experiments

모델 크기/patch 크기에 따른 학습 곡선
Training compute vs FID 곡선: 큰 모델이 compute 효율 우위

  • 모델 크기↑ (S → B → L → XL) → 성능 지속 개선
  • Patch 크기↓ (토큰 수 증가) → 성능 개선, FLOPs 증가
  • Scaling law 확인: FLOPs와 FID 간 강한 상관관계
  • 작은 모델을 오래 돌리기보다 큰 모델을 적당히 학습하는 게 효율적

5.3 State-of-the-Art Results

  • ImageNet 256×256
    • DiT-XL/2 + classifier-free guidance
    • FID 2.27 → 기존 LDM보다 큰 폭 개선
  • ImageNet 512×512
    • DiT-XL/2
    • FID 3.04 → 기존 ADM보다 개선

6. Classifier-Free Guidance

$$\hat\varepsilon_\theta(x_t, c) = \varepsilon_\theta(x_t, ∅) + s \cdot (\varepsilon_\theta(x_t, c) - \varepsilon_\theta(x_t, ∅)), \quad s > 1$$

  • ε(x_t, ∅): label dropout으로 학습한 null embedding
  • s: guidance scale (크면 품질↑, 다양성↓ trade-off)

7. Implementation Details

  • VAE: Stable Diffusion의 pre-trained VAE
  • Diffusion 세팅: ADM과 동일 (1000 steps, linear schedule)
  • Decoder head: p×p×2C 텐서로 noise와 variance 복원
  • Hyperparameters: 전 모델 동일하게 적용
  • Compute: TPU v3 pods, 대규모 학습

8. Additional Analysis

  • 모든 메트릭(FID, IS, Precision/Recall 등)에서 scaling이 compute-efficiency를 보임
  • Training loss도 모델 크기↑ → 더 빠른 감소

9. Conclusion

  • DiT는 U-Net을 Transformer로 대체해도 latent diffusion에서 안정적이고 강력함을 입증
  • adaLN-Zero가 핵심: 안정적 학습 + 높은 성능
  • Scaling law 성립: 모델 FLOPs와 FID 간 강한 상관, 큰 모델이 compute 효율에서도 우위

 DiT는 Diffusion의 미래를 Transformer 쪽으로 크게 열어준 연구