본문 바로가기
  • 책상 밖 세상을 경험할 수 있는 Playground를 제공하고, 수동적 학습에서 창조의 삶으로의 전환을 위한 새로운 라이프 스타일을 제시합니다.
Computer Vision

[2025-1] 김유현 - Improved Training of Wasserstein GANs

by rdg126 2025. 3. 22.

https://arxiv.org/abs/1704.00028

 

Improved Training of Wasserstein GANs

Generative Adversarial Networks (GANs) are powerful generative models, but suffer from training instability. The recently proposed Wasserstein GAN (WGAN) makes progress toward stable training of GANs, but sometimes can still generate only low-quality sampl

arxiv.org

 

0. Abstract

GAN은 강력한 생성 모델이지만 학습 불안정성이 문제이다. WGAN은 학습 안정성을 개선했지만, 여전히 샘플 품질이 낮거나 수렴 실패 문제가 발생할 수 있다. 이는 WGAN에서 Lipschitz 제약을 적용하기 위해 weight clipping 하는 방식이 원인이 될 수 있다. 이를 해결하기 위해 논문에서는 critic의 입력에 대한 gradient norm을 penalty로 부과하는 방식을 제안한다. 이 방법은 WGAN보다 안정적인 학습을 가능하게 하며, 101층 ResNet 및 연속적 생성자를 사용하는 언어 모델 등 다양한 GAN architecture에서 별다른 hyperparameter tuning 없이도 효과적으로 동작한다. 또한 CIFAR-10 및 LSUN bedroom 데이터셋에서 높은 품질의 샘플을 생성할 수 있다. 

 

1. Introduction

GAN은 생성자와 판별자가 경쟁하는 방식으로 작동하는 강력한 생성 모델이지만, 학습이 불안정한 문제가 있다. 이를 해결하기 위해 WGAN이 Wasserstein 거리를 활용하여 더 안정적인 가치 함수를 도입했지만, 1-Lipschitz 조건을 만족시키기 위해 weight clipping을 사용하면서 새로운 문제가 발생했다. 이를 해결하기 위해 논문에서는 WGAN-GP를 도입하며 WGAN-GP의 목적은 아래와 같다. 

  • 간단한 데이터셋에서 weight clipping이 문제를 일으킬 수 있음을 보인다.
  • Gradient penalty를 적용한 WGAN-GP를 제안하여 이러한 문제를 해결한다.
  • 다양한 GAN 구조에서 안정적인 학습을 가능하게 하고, 이미지 생성 및 문자 기반 GAN 모델에서도 성능을 향상시킨다.

 

2. Background

A. Generative adversarial networks 

GAN의 학습은 Generator와 Discriminator 사이의 경쟁적인 방식으로 정의된다. Generator는 노이즈를 입력 공간으로 매핑하여 가짜 데이터를 생성하고 Discriminator를 속이는 방향으로 학습시킨다. Discriminator는 입력 데이터가 실제 데이터인지 생성된 데이터인지 구별하는 역할을 한다. 이는 minimax 목적 함수로 표현된다. 

여기서 $P_r$은 실제 데이터 분포, $P_g$는 생성된 데이터 분포로 정의된다. 

이론적으로, Discriminator가 최적화된 상태에서 Generator가 위 식을 최소화하면 Jensen-Shannon 발산(JSD)을 최소화하는 것이 된다. 하지만, 이는 Discriminator가 너무 강할 경우 gradient vanishing 문제를 초래할 수 있다. 이를 해결하기 위해 기존 연구에서는 Generator가 가짜 데이터를 통해 Discriminator를 속이는 방향으로 최대화하도록 학습하는 방법을 제안했으나, 강력한 Discriminator가 존재하면 이 방식도 불안정해질 수 있다.

 

B. Wasserstein GANs

기존 GAN이 최적화하는 발산 함수들은 생성자의 파라미터에 대해 연속적이지 않을 수 있어 학습이 어려운 문제가 있다. 이를 해결하기 위해 WGAN에서는 Wasserstein-1 distance를 활용하며, 이는 한 분포를 다른 분포로 변환하는 데 필요한 최소 비용으로 정의된다. Wasserstein 거리는 연속적이고 거의 모든 곳에서 미분 가능하여 학습 안정성이 높다.

WGAN의 목적 함수는 Kantorovich-Rubinstein 이중성을 이용하여 아래와 같이 정의한다.

 

여기서 critic은 1-Lipschitz 함수를 만족해야 하며, 이상적인 Discriminator 하에서는 Generator의 최적화가 Wasserstein 거리를 최소화하는 것과 동일하게 된다.

WGAN 목적 함수는 기존 GAN보다 Discriminator의 gradient가 더 안정적으로 동작하도록 하여 생성자의 학습을 쉽게 만든다. 또한, WGAN의 목적 함수 값이 생성된 샘플 품질과 더 잘 연관됨이 실험에서 나타났다.

하지만, 1-Lipschitz 조건을 강제하기 위해 Discriminator의 gradient를 [$−c,c$] 범위 내로 clipping하는 방식이 사용되었으며, 이는 특정 $k$-Lipschitz 함수의 부분집합을 구성한다. 논문에서는 이 방식의 문제점을 보이고, 이를 개선하기 위한 새로운 접근 방식인 WGAN-GP을 제안한다. 

 

3) Properties of the optimal WGAN critic

WGAN의 critic에서 weight clipping이 문제가 되는 이유를 제시하고 WGAN의 구조에서 최적의 critic의 몇 가지 특성을 보여준다. 

 

C. Difficulties with weight constraints

WGAN에서 weight clipping이 최적화에 어려움을 보이고, 최적화가 성공하더라도 critic의 값이 비정상적인 분포를 가질 수 있다. WGAN 논문에서 사용한 가중치 절댓값을 특정 범위 [$-c, c$] 내로 강제하는 방식 뿐만 아니라, 아래와 같은 방법을 도입했다. 

  • L2 norm clipping
  • Weight Normalization
  • Weight decay, Soft Constraints 

Critic 네트워크에서 batch normalization을 사용하면 어느 정도 문제를 완화할 수 있었지만 깊은 WGAN critic network의 경우 여전히 수렴하지 않는 문제가 발생한다.

 

A. Capacity Underuse

Weight clipping을 이용한 $k$-Lipschitz constraint를 부과하는 것은 critic을 더 간단한 함수로 편향시키는 문제를 초래한다. 최적의 WGAN critic은 데이터 분포 $P_r$과 생성 분포 $P_g$에서 gradient norm이 거의 모든 곳에서 1이 되어야 하지만, weight clipping을 적용하면 critic이 최대 gradient norm $k$를 달성하려고 하면서 극단적으로 단순한 함수만 학습하는 경향을 가진다. 이것을 증명하기 위해서, 논문에서는 고정된 생성 분포인 $P_g$에서 여러 toy distribution을 사용하여 WGAN critic을 최적화한다. 

  • 실험 결과, weight clipping을 사용한 critic은 데이터 분포의 고차원 정보를 무시하고 단순한 함수만 학습하게 된다
  • Batch normalization을 제외한 상태에서 실험을 진행했다. 
  • Figure 1a에서 clipping을 적용한 critic은 최적 함수의 단순한 근사만 수행하며 논문에서는 WGAN-GP는 이런 문제를 겪지 않는다고 주장한다. 

 

 

B. Exploding and Vanishing Gradients

WGAN에서는 weight clipping과 cost function의 상호작용 때문에 최적화 과정이 어렵다. 

  • Clipping 임계값 $c$를 적절히 조정하지 않으면 gradient가 소실되거나 폭발할 위험이 있다. 
  • Swiss Roll 데이터셋을 이용해 WGAN을 학습하며, clipping 임계값 $c$를 [$10^{(-1)}, 10^{(-2)}, 10^{(-3)}$]로 변경하면서 gradient 변화를 분석한다. 
  • Critic과 Generator 모두 12-ReLU MLP 모델을 사용하며, batch normalization은 사용하지 않는다. 

실험 결과는 아래와 같다. 

  • Figure 1b에서 확인한 결과, clipping 임계값에 따라 네트워크 깊이가 깊어질수록 gradient가 급격하게 증가하거나 감소한다. 
  • WGAN-GP 방식은 gradient가 폭발하거나 소실하지 않고 안정적으로 유지된다. 

따라서 WGAN-GP 방식은 gradient 안정성을 유지하여 보다 복잡한 네트워크 학습이 가능한 것을 볼 수 있다. 

 

4. Gradient Penalty

A. Algorithm 1

<Gradient Penalty>

1) 입력 값

  • Gradient Penalty 계수 λ
  • Critic 학습 반복 횟수 $n$
  • 배치 크기 $m$
  • Adam 최적화 하이퍼파라미터 α, β1​, β2
  • Critic 초기 파라미터 $w_0$, Generator 초기 파라미터 $θ_0$

 

2) 학습 과정

  • Critic 업데이트
    • 실제 데이터 $x$ ~ $P_r$ 및 잠재 변수 $z$ ~ $p(z)$ 샘플링
    • 생성된 샘플 $G_θ(z)$ 계산
    • 실제 샘플과 생성 샘플을 보간해 새로운 랜덤 샘플 생성
    • Critic의 손실 함수 $L$ 계산
    • Adam 최적화로 critic 업데이트
  • Generator 업데이트
    • 잠재 변수 $z$ ~ $p(z)$ 샘플링
    • 생성자의 손실 함수 $-D_w(G_θ(z))$를 Adam으로 최적화

 

 

B. Gradient Penalty

논문에서는 Lipschitz 조건을 강제하는 새로운 방식을 제안한다. 

 

1) 기존 WGAN의 문제점

  • 기존 WGAN에서는 weight clipping을 사용하여 Lipschitz 제약을 적용했으나, 최적의 critic 학습을 방해하고 gradient가 소실 또는 폭발하는 문제를 초래한다. 

 

2) 새로운 Lipschitz 제약 방식

  • 1-Lipschitz 함수는 gradient norm이 1 이하일 때 성립
  • Critic의 출력에 대한 입력 gradient 크기를 직접 제한하는 방식을 제안
  • 임의의 샘플에 대해 gradient의 크기가 1에서 벗어나는 정도를 패널티로 추가

 

3) 세부 구현 방법

  • Sampling distribution
    • 기존에는 단순히 실제 데이터와 가짜 데이터를 critic에 입력했지만, WGAN-GP는 실제 데이터 분포 $P_r$와 생성된 데이터 분포 $P_g$ 사이의 직선 경로에서 랜덤하게 샘플링
    • 최적 critic의 gradient가 1이 되는 경향이 있기 때문에, 해당 경로에서 gradient를 제약하는 것만으로도 Lipschitz 조건을 강제하는데 효과적
  • Penalty coefficient =  λ
    • 모든 실험에서 λ = 10을 사용
    • 다양한 구조 및 데이터셋에서 일관된 성능 향상을 보인다.
  • No critic batch normalization
    • 기존 GAN에서는 Batch Normalization을 사용하여 학습을 안정화했으나, critic의 출력이 개별 샘플이 아닌 batch 전체에 의해 영향을 받는 문제가 발생
    • WGAN-GP는 개별 샘플에 대해 gradient penalty를 적용하기 때문에 batch normalization이 적절하지 않다
    • Layer Normalization을 대체 방안으로 추천
  • Two-sided penalty
    • 기존 방식에서는 gradient norm이 1을 초과히지 않도록 제한
    • WGAN-GP에서는 gradient가 1에 수렴하도록 강제
    • Gradient가 너무 크거나 너무 작아지는 것을 모두 방지하여 Critic의 성능을 지나치게 제한하지 않아 더 나은 성능 도출
더보기

Layer Normalization

  • 하나의 샘플 내에서 각 레이어별로 정규화 진행
  • 독립적인 진행이므로 GAN 학습에도 사용 가능

 

Batch Normalization

  • 배치 전체에서 평균과 분산을 계산
  • 작은 배치와 시퀀스 길이가 다른 경우 불안정하고 동작하기 어렵다

5. Experiments

1) Training random architectures within a set

A. Table 1

DCGAN 구조를 기반으로 다양한 하이퍼파라미터를 조합하여 새로운 아키텍처를 생성한다. Table 1은 실험에 사용된 GAN 구조의 하이퍼파라미터 범위를 보여준다.

  • Nonlinearity G & D
    • Generator와 Discriminator에 적용된 활성 함수
    • ReLU, LeakyReLU, SoftPlus, Tanh
  • Depth G & D
    • Generator와 Discriminator의 네트워크 깊이: 4, 8, 12, 20
    • 네트워크의 총 개수를 다양하게 설정하여 성능 비교
  • Batch norm G & D
    • Generator에서는 일반적인 Batch Normalization 사용 여부 테스트
    • Discriminator에서는 Layer Normalization을 적용
  • Base filter count G & D
    • Generator와 Discriminator의 필터 개수: 32, 64, 128
    • 작은 필터에서부터 큰 필터까지 다양한 설정으로 테스트하여 성능 비교

 

 

B. Table 2

 

WGAN-GP의 안정성 실험을 위해 200개의 sample architecture를 생성하여 32x32 ImageNet 데이터셋에서 학습을 진행한다. 표준 GAN과 WGAN-GP를 비교하며 Inception Score를 기준으로 학습 성공 여부를 판단한다. 일부 표준 GAN architecture도 성공하지만, WGAN-GP는 더 많은 architecture에서 학습을 성공적으로 수행한다. 또한 점수 기준이 높아질수록 성공하는 네트워크 개수가 급감됨을 알 수 있다. 

예를 들어, Min.score=5.0일 때 Only WGAN-GP = 147의 경우 147개의 네트워크가 기존 GAN은 실패했지만 WGAN-GP에서는 성공했다는 의미이다. 

 

2) Training varied architectures on LSUN bedrooms

A. 평가 Architecture

디양한 GAN architecture에서 WGAN-GP의 학습 성능을 평가하는데 평가한 architecture는 아래와 같다. 

  • Batch normalization 없이 일정한 필터 개수를 가진 Generator
  • 4-layer 512 차원 ReLU MLP Generator
  • Generator & Discriminator에 normalization을 적용하지 않은 구조
  • Gated Multiplicative Nonlinearities 적용 - 특정한 비선형 변환 방식
  • Tanh 활성화 함수 적용
  • 101-layer ResNet 기반 Generator & Discriminator
더보기

Gated Multiplicative Nonlinearities

게이트 구조를 사용하여 뉴런의 활성화 정도를 조절하는 비선형 변환 방식

  • 뉴런의 출력을 조절하는 곱셈 연산을 사용
  • 일반적인 뉴런은 선형 변환 후 활성화 함수를 거치지만 Gated Multiplicative Nonlinearities는 게이트를 추가로 곱해서 동작을 조절
  • 모델이 더 풍부한 표현력을 가질 수 있도록 도움

 

B. GAN 학습 방법

WGAN-GP와 기존 방법을 비교하는데 비교한 GAN 학습 방법은 아래와 같다. 각 방법에서는 해당 논문의 기본 최적화 하이퍼파라미터를 사용해 20만 번의 학습을 반복해서 진행한다. 

  • WGAN-GP - Gradient Penalty 
  • WGAN - Weight clipping
  • DCGAN
  • Least-squares GAN

 

Figure 2에서는 WGAN-GP를 제외한 다른 방법들은 특정 architecture에서 학습이 불안정하거나 실패함을 보인다. 또한 WGAN-GP는 다양한 architecture에서 공통된 하이퍼파라미터로 일관된 성능을 보이며, 안정적인 학습이 가능함을 알 수 있다. 논문에서는 101-layer ResNet을 사용한 GAN이 성공적으로 학습된 것이 처음이라고 보고하고 있다. 

 

3) Improved performance over weight clipping

WGAN-GP는 기존 weight clipping 방법보다 학습 속도 및 샘플 품질이 향상됨을 확인할 수 있다. 이를 입증하기 위해 CIFAR-10 데이터셋에서 WGAN과 WGAN-GP를 비교해 Inception Score 변화를 분석한다. 

 

A. 실험조건

  • 동일한 RMSProp optimizer 및 learning rate로 WGAN-GP와 WGAN-Clip을 비교
  • Adam Optimizer와 더 높은 learning rate를 적용한 WGAN-GP 추가 실험

 

B. 결과

  • 동일한 optimizer 사용 시에도 WGAN-GP가 더 빠르게 수렴하고 더 높은 Inception Score를 기록
  • Adam optimizer 적용 시 성능이 더욱 향상
  • DCGAN과 비교 시, WGAN-GP는 수렴속도가 더 느리지만 최종적으로 더 안정적인 성능 도출

 

4) Sample quality on CIFAR-10 and LSUN bedrooms

동일한 네트워크 architecture를 사용할 경우, WGAN-GP는 기존 GAN과 비슷한 수준의 샘플 품질을 달성한다. 하지만 학습 안정성이 더 뛰어나므로 더 다양한 architecture를 실험할 수 있으며, 이를 통해 샘플 품질을 더욱 개선할 수 있다. 

 

A. 결과

  • CIFAR-10
    • 비지도 학습: 새로운 네트워크 architecture를 적용하여 최고 수준의 Inception Score를 기록
    • 지도 학습, label 사용: 동일한 네트워크에 label 정보를 추가하면 기존 모델보다 뛰어나며, SGAN을 제외한 모든 모델보다 우수한 성능을 기록

 

  • LSUN Bedrooms
    • 128x128 해상도 이미지 생성 실
    • 깊은 ResNet을 학습하여 생성된 샘플을 제시
    • 기존 연구와 비교해도 경쟁력 있는 결과를 도출
더보기

Inception Score

: GAN이 생성한 이미지의 품질을 측정하는 지표

  • Inception Score 증가
    • 생성된 이미지의 품질이 좋을 때
    • 생성된 이미지의 다양성이 증가할 때
    • Inception 네트워크가 각 이미지를 특정 클래스로 확신할 때
  • Inception Score 감소
    • 이미지가 흐려지거나 노이즈가 심할 때
    • 모델이 특정패턴에만 과적합되어 다양성이 줄어들 때
    • Inception 모델이 확신을 갖고 분류하지 못할 때 

 

5) Modeling discrete data with a continuous generator

A. 실험 

WGAN-GP가 언어 데이터 같은 이산 데이터 분포를 학습할 수 있는지 검증하기 위해 문자 단위의 GAN 언어 모델을 학습한다. 데이터셋으로는 Google Billion Word 데이터셋을 사용하여 실험한다. 

  • Generator
    • 1D CNN 구조 사용
    • Latent vector를 32개의 one-hot 문자 벡터로 변환
    • 출력층에서 Softmax 함수 적용 = 샘플링 없이 그대로 사용
  • Critic
    • 1D CNN 구조 사용
    • 생성된 문자를 직접 입력 받아 평가
  • 샘플링 과정
    • Softmax 벡터에서 Argmax를 확인해 최종 문자를 선택

 

B. 결과

  • 모델의 철자 오류는 있지만, 언어 통계를 학습하는 데는 성공
  • 기존 GAN으로는 유사한 결과를 얻지 못하지만 WGAN-GP는 문장 구조와 통계를 어느 정도 학습했음을 확인 가능

 

기존 GAN은 JS Divergence가 포화되는 문제인 Discriminator가 가짜 샘플을 쉽게 구별하고 Generator에 의미 있는 gradient를 제공하지 못하는 현상이 발생한다. 이와 반대로 WGAN-GP의 Wasserstein 거리는 연속적이며 거의 모든 곳에서 미분 가능하므로 Generator에 의미 있는 gradient를 제공한다. 또한 Lipschitz 제약을 사용한 critic은 모든 공간에서 일관된 gradient를 제공하여 one-hot vector 공간으로 수렴하도록 유도한다. 

기존 연구들은 주로 gradient 추정 기법을 사용하여 GAN 기반 언어 모델을 학습하는데, 이와 달리 WGAN-GP 방식은 Softmax 출력을 그대로 critic에 전달하는 방식을 사용해 구현이 더 간단하다. 하지만 논문에서는 대형 언어 모델에도 적용될 수 있을지에 대한 불확실성도 제시하고 있다. 

 

6) Meaningful loss curves and detecting overfitting

 

A. WGAN의 손실 곡선

WGAN의 loss는 샘플 품질과 상관관계를 가지며 최소값으로 수렴하는 특징이 있는데 WGAN-GP 또한 이 특성을 가지는지 검증하기 위해, LSUN Bedrooms 데이터 셋에서 WGAN-GP를 학습하고 critic의 손실을 음수로 변환하여 Figure 5a를 통해 시각화했다. 결과적으로 손실이 수렴하며 Generator가 Wasserstein 거리를 최소화 하는 방향으로 학습됨을 확인할 수 있다.

 

B. Overfitting 분석 실험

Overfitting 발생 시 Discriminator의 Wasserstein 거리 추정이 부정확해지고, 샘플 품질과의 상관관계가 깨질 수 있다. 이를 검증하기 위해 MNIST 데이터셋에서 1000장의 이미지만 사용하여 WGAN을 학습하고 Figure 5b를 통해 학습 데이터와 검증 데이터의 손실 곡선을 비교한다.

 

C. 결과

  • WGAN과 WGAN-GP 모두 critic의 training loss과 validation loss의 차이가 커지는 Discriminator의 overfitting 현상이 발생한다. 
  • 하지만 WGAN-GP에서는 training loss는 점차적으로 증가하면서 validation loss는 감소하는 경향을 보이는데, 이는 WGAN-GP가 critic의 손실 곡선을 통해 overfitting을 측정할 수 있음을 의미한다. 

 

6. Conclusion

WGAN-GP는 weight clipping 문제를 해결하면서도 Wasserstein 손실의 유용성을 유지하는 강력한 GAN 학습 기법을 제안한다. 다양한 네트워크에서 안정적인 학습 성능을 보였으며, overfitting 탐지 기능도 제공한다. 향후 연구 방향으로는 더 큰 규모의 데이터셋과 GAN 구조에 적용 가능성을 탐색할 필요가 있다고 제시한다. 

 

Reference


https://ysbsb.github.io/gan/2022/02/18/WGAN-GP.html

 

WGAN-GP 논문 리뷰 - Improved Training of Wassertein GANs (NIPS2017) | mocha's machine learning

안녕하세요. 모카의 머신러닝 입니다. 이번 포스팅은 “Improved Training of Wassertein GANs, NIPS 2017”에 대해 리뷰합니다. 영어로 된 논문을 한글로 같이 해석하며 논문에서 의미하는 것이 무엇인지 같

ysbsb.github.io

https://velog.io/@pabiya/Improved-Training-of-Wasserstein-GANs

 

Improved Training of Wasserstein GANs

오늘 리뷰할 논문은 WGAN을 보완한 WGAN-GP 논문이다.GANs는 강력한 generative model이지만 training instability에 시달린다. WGAN은 stable training에 성취를 이루었지만 여전히 poor samples만을 생성하거나

velog.io

https://jeongwooyeol0106.tistory.com/33

 

[논문] Improved Training of Wasserstein GANs(WGAN-gp)

https://arxiv.org/pdf/1704.00028.pdf 해당 논문을 보고 작성했습니다. Abstract GAN(Generative Adversarial Network)는 powerful 한 모델이지만 학습 불안정성을 보입니다. 최근 제시된 Wasserstein GAN(WGAN)은 안정된 GAN 모

jeongwooyeol0106.tistory.com

https://cl2020.tistory.com/32

 

Improved Training of Wasserstein GANs,2017

Improved Training of Wasserstein GANs Abstract Generative Adversarial Networks (GANs) are powerful generative models, but suffer from training instability. The recently proposed Wasserstein GAN (WGAN) makes progress toward stable training of GANs, but some

cl2020.tistory.com