https://arxiv.org/abs/1411.1784
0. Abstract
Generative Adversarial Nets의 조건을 추가한 conditional 버전이다. 기존 GAN은 우리가 원하는 조건의 데이터를 생성할 수 없지만 새로운 CGAN은 기존 GAN에 조건을 추가로 넣어 우리가 원하는 모습으로 제어할 수 있다.
위 논문에서는 CGAN이 class label에 맞는 MNIST 숫자 이미지를 생성하는 것을 보인다. 추가적으로 Multi-Modal model을 학습하는 방법에 대해 설명하며 descriptive tags를 생성하는 방법을 보여주는 image tagging에 대한 예시도 설명한다.
1. Introduction
기존 GAN의 장점으로는 Markov chain의 불필요성, backpropagation을 통한 gradient 확보, 다양한 factor와 model의 결합 가능성이 있다.
Unconditioned generative model에서는 Generator가 생성하는 데이터에 대한 제어할 방법이 없는데 condition을 추가함으로 데이터 생성 과정을 조절할 수 있다. 일부 데이터나 다른 modality를 가지는 데이터를 condition으로 사용할 수 있다. 이 방식으로 Conditional GAN을 구성하고 class label에 따라 조절된 MNIST dataset과 multi-modal learning을 위한 테스트를 진행한다.
2. Related Work
A. Multi-Modal learning for image labeling
- Supervised Neural Network 문제점
- 굉장히 많은 예측 타겟 출력 범주 수의 수용이 어려움
- 해결하기 위한 방법으로는 다른 modality에서 추가 정보를 불러와 활용한다.
- Ex) Geometric 관계가 의미 있는 label에 대한 vector 표현을 학습하기 위해 자연어 단어 사용한다. 이를 통해 예측에 실패하더라도 실제 값에 가깝고, training 동안 없었던 label에 대해 generalization할 수 있게 된다. Image feature space에서 word representation space로 단순한 선형 mapping을 해주는 것 만으로도 분류 성능을 크게 높여주었다.
- Input -> Output으로 one-to-one mapping(일대일 매핑)을 학습하기 때문에 one-to-many mapping(일대다 매핑)의 어려움
- 해결하기 위한 방법으로는 Conditional probabilistic generative model을 사용한다. 입력은 conditional 변수로 간주되고 일대다 mapping은 조건부 분포를 예측하는 것으로 이루어진다.
3. Conditional Adversarial Nets
A. Generative Adversarial Nets
GAN은 데이터 생성 모델을 훈련시키는 방법으로 실제와 유사한 데이터 분포를 구축한 Generative model G와 데이터가 G가 아닌 실제 훈련 데이터에서 나올 확률을 추정하는 Discriminative model D로 구성된다.
데이터 $x$에 대한 Generator 분포인 $Pg$를 학습하기 위해 Generator는 prior noise distribution $Pz(z)$에서 데이터 공간으로 $G(z;θg)$를 구축한다. Discriminator는 sample이 $Pg$가 아닌 실제 학습 데이터에서 나올 확률을 0~1 사이 값으로 출력한다.
기존 GAN은 random noise vector 값에 따라 무작위 출력물을 생성하기 때문에 생성되는 데이터의 속성(이미지 클래스나 스타일)를 제어 및 조정할 수 없다는 단점이 있다.
※ $log$ 함수의 사용 이유 ※
- Cross-Entropy Loss와의 관계: 기본적으로 GAN은 G와 D의 loss를 최적화하며 이를 Cross-Entropy Loss로 표현하기에 사용된다. 이를 통해 확률의 차이를 명확하게 구분하고 확률 분포의 적합성을 보여준다.
- 안정적인 학습: 학습 초기에 G가 잘못된 샘플을 생성하더라도 Loss가 과도하게 커지지 않아 Gradient Explosion을 방지하고 D가 진짜 데이터를 구별하지 못할 때, 이를 개선하는 방향으로 Gradient를 제공한다.
- 수학적 의미(Likelihood Maximization): Generator는 G가 진짜 데이터일 확률인 $D(G(z))$를 최대화하려고 하므로, 로그를 사용한다.
B. Conditional Adversarial Nets
Conditional GAN은 기존 GAN에 $y$라는 추가정보와 함께 conditional model로 확장시킨다. 최종적으로 두 네트워크의 학습이 균형을 이루면 G는 조건 $y$에 기반한 사실적인 데이터를 생성할 수 있게 된다.
- Generator
기존 GAN의 입력 $z$(노이즈 벡터) 외에 조건 $y$를 함께 입력으로 받아(joint hidden representation으로 결합) 조건부 데이터를 생성한다.
- Discriminator
입력 데이터가 실제 데이터 $x$인지, Generator가 생성한 데이터인지 판별할 때, 조건 $y$를 함께 고려한다.
Ex)
- Generator
Noise $z$와 함께 $y=3$와 입력되어 숫자 3을 생성하도록 이미지가 생성되고, $y=5$일 때는 숫자 5를 생성하도록 학습한다. Noise $z$는 잠재 공간에서 샘플링 된 값으로서 다양한 데이터를 생성하는 데 필요한 다양성을 제공한다. $y=3$의 경우 one-hot vector (0, 0, 1, 0, 0, 0, 0, 0, 0, 0), $y=5$의 경우 (0, 0, 0, 0, 1, 0, 0, 0, 0, 0)로 입력값에 들어간다.
- Discriminator
생성된 이미지와 함께 실제 이미지와 label을 one-hot vector로 입력 값으로 부여한다. 예를 들어 $y$=3, $x$=3일 때 $D(3, 3) = 0.6$로 출력되면, Discriminator는 데이터가 60%의 확률로 '진짜'라고 판단한다. 반대로, $y=8$, $x=3$이면 Discriminator는 '가짜'로 판단한다.
※ one-hot vector: 단어 집합의 크기를 벡터의 차원으로 하고, 표현하고 싶은 단어의 index에 1의 값을 부여하고, 다른 인덱스에는 0을 부여하는 단어의 벡터 표현 방식이다.
C. Cost Function
기존의 GAN과 모양이 같지만 Discriminator와 Generator에 조건 $y$가 추가된다.
- $Pdata(x)$: 실제 데이터 x의 분포
- $Pz(z)$: 잠재 공간 z의 분포(보통 정규분포 $N(0, 1)$ 또는 균등분포 $U(-1, 1)$로 설정)
- $z$: 잠재 공간(Latent space)에서 샘플링된 random noise vector. $z$ ~ $Pz(z)$
- $y$: Generator와 Discriminator에 주어진 추가 조건(e.g., 클래스 레이블, 추가적인 정보)
- $G(z|y)$: 조건 $y$에 기반해 생성된 데이터 샘플
- $x$: 실제 데이터 샘플
- $D(x|y)$: 입력 $x$가 조건 $y$에 기반한 실제 데이터일 확률
- $logD(x|y)$: 실제 데이터 $x$와 조건 $y$에 대해 Discriminator가 '진짜'라고 판단한 확률의 로그
- $log(1-D(G(z|y)))$: Generator가 생성한 샘플 $G(z|y)$와 조건 $y$에 대해 Discriminator가 '가짜'라고 판단한 확률의 로그
4. Experimental Results
A. Unimodal
One-hot vector로 구성된 class lablel의 MNIST로 CGAN 학습을 진행한다. CGAN은 4개의 linear layer로 구성된다.
- Generator
- Uniform Distribution $z$ size = 100
- 변수 $z$ size = 200 & 변수 $y$ size = 1000의 hidden layer(ReLU)로 mapping 되어 hidden ReLU layer로 합쳐진다.
- 784($1*28*28$)차원의 MNIST sample을 생성하기 위해 sigmoid layer를 거친다.
- Discriminator
- $x$는 240 unit, 5 piece maxout layer, $y$는 50 unit, 5 piece maxout layer로 mapping된다.
- 240 unit, 5 piece maxout layer로 합쳐진 후 sigmoid layer를 통해 최종 출력값을 출력한다.
Table 1은 MNIST dataset의 Gaussian Parzen Window-based log-likelihood estimate를 보여준다. 각 10개 class에서 1000개의 sample이 추출되었고, Gaussian Parzen window가 sample들에 적용되었다. 이후, Parzen window distribution 기반으로 density function을 추정하여 test set log-likelihood를 추정한다. Hyeperparamter 최적화와 아키텍처 구조에 대한 추가 연구를 통해 성능 향상이 가능하다.
※ Gaussian Parzen Window
데이터 샘플의 분포를 기반으로 새로운 데이터 포인트 $x$가 나올 확률을 추정하고 가우시안 분포를 각 데이터 포인트 $x$에 생성하고, 이 분포의 합으로 전체 밀도를 계산하는 방법이다.
※ Log-Likelihood
데이터 $x$가 주어진 모델에서 관찰될 가능성을 측정하는 지표로서 추정된 분포가 실제 데이터 분포를 얼마나 잘 설명하는지 측정한다. 높은 값일수록 모델이 테스트 데이터를 잘 설명한다는 것을 의미한다.
B. Multimodal
입력 정보는 image domain이지만, 출력은 word token이라는 점이 Multimodal Learning으로 볼 수 있다.
MIR Flickr 25,000 dataset(한 이미지 당 평균 9개 tag 가지고 있는 데이터 셋)를 활용한다. CGAN과 language model을 사용해서 tag 이미지와 user generated metadata(사람이 이미지의 객체를 자연어로 설명하는 방법에 근접하고 각 사용자마다 표현하는 어휘가 다르기에 유용한 실험)를 포함한다.
이미지의 feature를 추출하기 위해 2,100 label의 ImageNet에 pre-trained 된 CNN 모델을 사용해 마지막 fc layer의 출력값을 사용한다. 텍스트 표현을 위해 YFCC100M 2 dataset 메타데이터에서 사용자 생성 메타 데이터(태그, 제목, 설명)를 포함하여 text를 수집한다.
텍스트를 정제한 후 단어 벡터 크기가 200인 skip-gram model(주어진 단어로 주변 단어 예측)을 훈련시켰다. 그리고 단어에서 200번 미만의 단어를 생략하고 247,465개 단어의 dictionary를 만들었다. Conceptual word embeddings는 비슷한 개념의 단어들이 비슷한 vector로 표현되게 만든다.
- Generator
- 100 dim noise는 500 dim ReLU layer에, 4096 dim 이미지 feature는 2000 dim ReLU layer에 매핑한다.
- 두 ReLU layer는 linear layer에 연결되어 200 dim의 word vector를 생성한다.
- Discriminator
- $(x, y)$ pair가 $(G(z), y)$ pair와 matching이 되는지를 판단하게 된다.
- Word vector와 image를 input으로 받아서 하나의 sigmoid를 출력해 확률을 계산한다.
평가를 위해 각 이미지에 대해 100개의 sample을 생성하고 각 sample에 대한 단어 vector 표현의 cosine similarity를 사용하여 가장 가까운 상위 20개의 단어를 찾는다. 그리고 100개의 sample 가운데 가장 일반적인 상위 10개의 단어를 선택한다.
5. Future Work
Conditional GAN의 잠재력을 보여주며 다양한 응용의 가능성을 보여준다.
- 여러 태그 동시에 사용시 더 좋은 결과 도출 기대
- 언어 모델을 학습시키기 위한 계획 구축
- 데이터 증강, 콘텐츠 생성, 예술 작품 제작 등 다양한 분야로의 확장 기대
Reference
https://velog.io/@wilko97/Conditional-GAN
https://ddongwon.tistory.com/126
https://blog.naver.com/winddori2002/222222304740
https://nowolver.tistory.com/144