PyTorch GANs 💻 🎨

This repo contains a PyTorch implementation of the Vanilla GAN architecture.

Table of Contents

  1. Understanding GANs
  2. Vanilla GAN
  3. GAN Loss Function
  4. Sample Outputs
  5. Acknowledgements

Understanding GANs

GAN stands for Generative Adversarial Network, a type of deep learning model that consists of two networks: a generator and a discriminator. The generator network learns to generate realistic-looking fake data (e.g., images, audio, text) from random noise, while the discriminator network learns to distinguish the fake data from the real data. The two networks are trained simultaneously in an adversarial manner: the generator tries to fool the discriminator by generating increasingly realistic fake data, while the discriminator tries to correctly tell real and fake data apart.

The original paper introducing GANs is titled Generative Adversarial Networks and was authored by Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. It was published in 2014 at the Conference on Neural Information Processing Systems (NIPS).

GANs have two components:

  1. Generator Network: The generator network samples from an isotropic Gaussian distribution and applies a learned transformation so that the resulting distribution mimics the data distribution.
  2. Discriminator Network: The discriminator network is a classifier trained to discriminate between real and generated samples. A minimal PyTorch sketch of both networks follows this list.
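
Below is a minimal sketch of what these two networks can look like for flattened MNIST images. The latent dimension, layer widths, and activations are illustrative assumptions, not necessarily the ones used in this repo.

```python
import torch
import torch.nn as nn

latent_dim = 100   # dimension of the isotropic Gaussian noise z (assumed)
img_dim = 28 * 28  # flattened MNIST image

class Generator(nn.Module):
    """Maps noise z ~ N(0, I) to a fake sample."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, img_dim), nn.Tanh(),   # outputs in [-1, 1]
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    """Binary classifier: probability that the input is a real sample."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 256), nn.ReLU(),
            nn.Linear(256, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)

# Sample from the isotropic Gaussian prior and transform it into fake images.
z = torch.randn(64, latent_dim)
fake_images = Generator()(z)
```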

Vanilla GAN

Vanilla GAN is my implementation of the original GAN paper with certain modifications, mostly in the model architecture, such as the use of LeakyReLU activations and 1D batch normalization.
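
As a rough illustration of those modifications, one hidden block of such a generator might look like the snippet below; the layer widths and the LeakyReLU negative slope are assumptions, not values taken from this repo.

```python
import torch.nn as nn

def generator_block(in_features, out_features):
    # Linear layer followed by 1D batch normalization and a LeakyReLU
    # activation, instead of the plain ReLU used in the sketch above.
    return nn.Sequential(
        nn.Linear(in_features, out_features),
        nn.BatchNorm1d(out_features),
        nn.LeakyReLU(0.2),
    )
```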

GAN Loss Function

Total GAN Loss

The GAN loss is a min-max optimization problem, which is why it is also known as an adversarial loss. $p_{\text{data}}(x)$ is the data distribution and $p_z(z)$ is the prior over the latent noise; pushing $p_z(z)$ through the generator $G_{\phi}$ induces the model distribution. Like any other generative model, the goal is to minimize some divergence between the model distribution and the data distribution. At the optimal discriminator, the GAN objective corresponds to minimizing the Jensen–Shannon divergence, a member of the general class of divergence metrics called f-divergences. The final loss is given as:

$$\mathcal{L}_{\text{GAN}} = \min_{\phi} \max_{\theta} \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D_{\theta}(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D_{\theta}(G_{\phi}(z)))]$$

Discriminator Loss

The optimization problem for the discriminator is:

$$\mathcal{L}_D = \max_{\theta} \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D_{\theta}(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D_{\theta}(G_{\phi}(z)))]$$

$$\mathcal{L}_D = \min_{\theta} - \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D_{\theta}(x)] - \mathbb{E}_{z \sim p_z(z)}[\log(1 - D_{\theta}(G_{\phi}(z)))]$$
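
A possible discriminator update step is sketched below. It assumes `D` ends in a sigmoid and that `G`, `real_images`, and the optimizer `opt_D` are defined elsewhere; binary cross-entropy with targets 1 for real and 0 for fake samples is exactly the negated objective above.

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()

def discriminator_step(D, G, real_images, opt_D, latent_dim=100):
    batch_size = real_images.size(0)
    z = torch.randn(batch_size, latent_dim)

    real_pred = D(real_images)       # D_theta(x)
    fake_pred = D(G(z).detach())     # D_theta(G_phi(z)), generator frozen

    # -E[log D(x)] - E[log(1 - D(G(z)))]
    loss_D = bce(real_pred, torch.ones_like(real_pred)) + \
             bce(fake_pred, torch.zeros_like(fake_pred))

    opt_D.zero_grad()
    loss_D.backward()
    opt_D.step()
    return loss_D.item()
```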

Generator Loss

The generator's objective from the min-max formulation is:

$$\mathcal{L}_G = \min_{\phi} \mathbb{E}_{z \sim p_z(z)}[\log(1 - D_{\theta}(G_{\phi}(z)))]$$

In practice this form saturates early in training, when the discriminator easily rejects the generator's samples, so the original paper proposes the non-saturating alternative of maximizing $\log D_{\theta}(G_{\phi}(z))$ instead:

$$\mathcal{L}_G = \max_{\phi} \mathbb{E}_{z \sim p_z(z)}[\log D_{\theta}(G_{\phi}(z))]$$

$$\mathcal{L}_G = \min_{\phi} - \mathbb{E}_{z \sim p_z(z)}[\log D_{\theta}(G_{\phi}(z))]$$
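
The non-saturating form can be implemented by labelling the generated samples as "real" in a binary cross-entropy loss, which gives $-\mathbb{E}[\log D_{\theta}(G_{\phi}(z))]$. The sketch below makes the same assumptions about `D`, `G`, and the optimizer names as the discriminator step above.

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()

def generator_step(D, G, batch_size, opt_G, latent_dim=100):
    z = torch.randn(batch_size, latent_dim)
    fake_pred = D(G(z))              # gradients flow back into G here

    # -E[log D(G(z))]  (non-saturating generator loss)
    loss_G = bce(fake_pred, torch.ones_like(fake_pred))

    opt_G.zero_grad()
    loss_G.backward()
    opt_G.step()
    return loss_G.item()
```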

The expectations in the above equations are computed using Monte Carlo approximations by taking sample averages.
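
Concretely, each expectation is replaced by a mean over a mini-batch, as in the sketch below (names follow the earlier sketches and are assumptions). Note that `nn.BCELoss` with its default mean reduction performs the same averaging.

```python
import torch

def estimate_gan_value(D, G, real_images, latent_dim=100):
    # Monte Carlo estimate of the GAN value function on one mini-batch.
    z = torch.randn(real_images.size(0), latent_dim)        # z ~ p_z(z)
    return (torch.log(D(real_images)).mean()                # E_x[log D(x)]
            + torch.log(1 - D(G(z))).mean())                # E_z[log(1 - D(G(z)))]
```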

Sample Outputs

The GAN was trained on the MNIST dataset. Here is what the generated digits look like:


Acknowledgements

I've used the following repositories as references for implementing my version: