본문 바로가기
Computer Science/Deep Learning

[논문] Generative Adversarial Nets (1)

by 리코더@typing4life 2018. 4. 18.

Generative Adversarial Nets - Ian Goodfellow et al. 2014. 논문 리뷰

이 논문은 2014년에 Ian Goodfellow가 NIPS에서 발표한 paper로, 최근 몇 년 동안 인기몰이를 하고 있는 논문입니다. 가장 기본이 되는 이 논문에서는 노이즈로부터 이미지를 생성하는 아이디어를 제시했지만, 다양한 논문이 나오면서 더 발달된 이미지 생성(Image Generation), 자연어처리 (NLP, Natural Language Processing) 등 다양한 분야에서 엄청난 응용 가능성을 보여주고 있습니다.


이 논문에서 핵심 내용은 다음과 같습니다.

  1. GAN의 기본 아이디어 제시
  2. GAN 모델이 global minimum으로 최적화되어 있고, 주어진 문제는 unique한 답을 가지고 있을 경우 -log4로 수렴함을 증명.

1. GAN의 기본 개념

먼저 GAN의 이름에서부터 대략적인 개념을 파악할 수 있습니다.


'Generative'은 발생의, 생성의 라는 뜻의 단어입니다. 무엇인가를 만드는 역할을 하는 모델이란 것을 확인할 수 있습니다. 'Adversarial'은 대립 관계에 있는, 적대적인 이라는 뜻을 가지고 있습니다. 따라서 무엇인가를 생성한 다음 누군가와 대립하는 Network라는 것을 알 수 있습니다. 무엇인가 대립한다는 것은 상대방이 있다는 뜻이니 GAN은 다른 모델들과는 다르게 주인공이 두 녀석이라는 것을 직관적으로 알 수 있습니다.


이 논문에서는 Image를 만들어내는 모델(Generator)과, 이 모델에서 만들어진 이미지를 판별하고자 시도하는 모델(Discriminator)이라는 두 모델을 서로 대립(Adversarial)시켜서 서로의 성능을 올릴 수 있다는 개념을 제시했습니다. 즉, Generator는 진짜 같은 이미지를 생성해서 Discriminator를 속이는 것을 목표로 학습하고, Discriminator은 Generator한테 속지 않고 진짜 데이터를 판별하는 것을 목표로 학습하게 됩니다.


논문에서는 이를 지폐위조범과 경찰을 예로들어 설명하였습니다.


지폐위조 팀(Generator)은 진짜 지폐와 최대한 비슷하게 만들어서 될 수 있는한 경찰에게 걸리지 않으려고 한다. 반면에 경찰(Discriminator)은 가짜 지폐를 구별(Classify)하고자 노력한다.

이러한 경쟁이 계속되면 두 팀 모두 각자의 능력이 향상되고, 결국에는 경찰이 진짜와 가짜를 구별할 수 없을때까지 진행된다. (경찰이 찾을 확률 = 0.5)


원문

The generative model can be thought of as analogous to a team of counterfeiters, trying to produce fake currency and use it without detection, while the discriminative model is analogous to the police, trying to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistinguishable from the genuine articles.

조금 더 이론적으로 풀어나가자면, Generative model G는 트레이닝 데이터 x의 분포(distribution)을 흉내내는 방향으로 training됩니다. 만약에 G가 흉내내고자 하는 x의 분포를 정확하게 모사할 경우, Discriminator model D는 G가 생성한 sample과 진짜 데이터 x를 구별할 수 없게 됩니다.

Discriminator model D는 현재 판별하고자 하는 sample이 training data로 부터 온 것인지(real_data), G로부터 만들어진 것인지(fake_data)를 판별하여 real_data일 확률을 추측(estimate)합니다.


그림 1. GAN model _ Generator과 Discriminator

*출처 : https://towardsdatascience.com/understanding-generative-adversarial-networks-4dafc963f2efsdf


위 그림 1을 보면 GAN의 학습이 어떻게 이루어지는지 알 수 있습니다.

training data에서 나온 sample x는 D(x) = 1이 되고, 임의의 noise distribution인 input z를 G에 넣고 만들어진 sample에 대해서는 D(G(z)) = 0이 되도록 학습됩니다. 다시 말해서 D는 잘못 판단할 확률을 최소화하도록(min), G는 D가 잘못 판단할 확률을 최대화하도록(max) 학습이 진행되는데, 이를 하나의 문제로 놓고 바라보면 'two-player minimax game'라고 할 수 있습니다.



2. Adversarial nets

이 논문은 기본적으로 G, D 모델을 multilayer perceptrons를 사용해서 구성했습니다. 

물론 추후에 DCGAN paper 이후에 본격적으로 Convolution 기법을 사용한 GAN모델들이 나오기 시작하지만, 이번 리뷰에서는 original GAN 모델에 대해서만 작성하겠습니다.


학습 초반에는 그림 1과 같이 G가 생성한 이미지는 누가 봐도 fake일 만큼 형편이 없습니다. 따라서 학습 초반에는 D(G(z))의 결과는 0에 가깝습니다. 하지만 학습이 진행 될 수록 D(G(z))의 값이 1이 되게끔 fake_image의 결과물이 발전하게 됩니다. 이 내용을 수식으로 나타내면 다음과 같습니다.




참고

  • Ex~pdata(x)[log D(x)] : training data(real data) x를 Discriminator에 넣었을 때 나오는 결과를 log를 취했을 때 얻는 기댓값을 의미합니다.
  • Ez~pz(z)[log (1-D(G(z)))] : noise distribution(fake data) z를 Generator에 넣었을 때 나오는 결과(fake image)를 Discriminator에 넣습니다. 그 결과를 log (1 - 결과)를 했을 때 얻는 기댓값을 의미합니다.
  • θg, θd : G와 D모델은 위에서 말씀드렸듯이 multilayer perceptrons을 기반으로 작동합니다. 세타는 이 모델에서 쓰이는 parameter들을 의미합니다.
  • Px~Pdata(x) : G와 D에 들어가는 input이 무엇을 바탕으로 나왔는지 알려주는 표기(notation)입니다. 다시 말해서, 이 표현은 x가 p_data(x), 즉 x는 training data(real_data)에서 나온 분포라는 것을 의미합니다. 두 번째 항에 있는 Pz~Pz(z)는 noise distribution(fake data)에서 나온 분포라는 의미입니다.

결론부터 말씀드리자면, D의 입장에서 value function V(D,G)의 이상적인 결과(최댓값)은 '0' 입니다. 또한 G의 입장에서 V(D,G)의 이상적인 결과(최솟값)는 '-∞' 입니다. 값을 넣어보면 식을 직관적으로 이해할 수 있습니다.


만약에 D가 매우 뛰어난 성능을 보이고 있다고 가정합니다. D에 들어온 sample x가 실제로 real_data에서 온 sample일 경우 D(x) = 1이 되므로 첫 번째 항에서 'log 1 = 0'이 되어 사라집니다. 그리고 G(z)가 생성한 fake_image를 잘 판별할 수 있어서 D(G(z)) = 0이라는 결과를 나타내어 두 번째 항에서 'log (1 - 0) = log 1 = 0'이 되어 식 전체 값은 0이 됩니다. 따라서 V(D,G) = 0이 D의 입장에서 얻을 수 있는 '최댓값'이라는 것을 알 수 있습니다.


반대로 이번에는 G가 아무도 못알아보게끔 image를 잘 생성한다고 가정합니다. 일단 G는 첫 번째 항에는 관여할 수 없습니다. 오로지 D가 real_data로 부터 나온 sample x를 얼마나 잘 판별하는지에 대한 항이기 때문입니다. G의 입장에서 확인해야할 항은 두 번째 항입니다. G가 매우 뛰어난 성능을 보여 D를 무조건 속인다고 가정했기 때문에 D(G(z))에서 D는 real이라고 판단해서 1의 결론을 내놓습니다. 따라서 'log (1 - 1) = -' 라는 결과가 나오게 됩니다. 즉, - G의 입장에서 얻을 수 있는 V(D,G)의 '최솟값'이라는 것을 알 수 있습니다.


그림 2. examples of generative models distribution and discriminative models distribution

*출처 : Generative Adversarial Nets(NIPS 2014)


GAN은 discriminative distribution(파란색 점선)을 동시에 업데이트하면서 학습됩니다. 따라서 D는 data generating distribution(검은색 점선)으로 비롯된 sample을 generative distribution(녹색 실선)으로 부터 나온 sample으로부터 판별하도록 학습됩니다. 그림 2에서 아래 수평선 x, z는 각각의 domain을 의미합니다. z 수평선에서 x 수평선으로 향하는 화살표의 의미는 x = G(z)의 mapping입니다. Generator에 z를 넣었을 때의 분포 변화를 그림으로 표현했다고 해석할 수 있습니다.


이 그림이 말하고자 하는 것은 (a)처럼 real과 fake의 분포가 전혀 다르게 생긴 것을 볼 수 있고, 현재 generator를 대상으로 discriminator을 학습시킨 결과가 (b)입니다. (a)처럼 들쑥날쑥하게 확률을 판단하는 것이 아니라, 흔들리지 않고 나름 분명하게 확률을 결과로 내놓는 것을 알 수 있습니다. (a)보다 성능이 올라갔다고 표현할 수도 있습니다. 어느정도 D가 학습이 이루어지면, G는 D가 구별하기 어려운 방향으로 학습을 하게되어 (c)와 같이 변합니다. 이 과정을 반복하게 되면 real과 fake가 점점 비슷해지고, 결국에는 (d)와 같이 구분할 수 없게 되어 D가 확률을 0.5로 계산하게 됩니다.


다음 포스팅에서 계속 작성하겠습니다.


반응형

댓글