This repo contains a PyTorch implementation of the vanilla GAN architecture.
GAN stands for Generative Adversarial Network, a type of deep learning model that consists of two networks: a generator and a discriminator. The generator network learns to produce realistic-looking fake data (e.g. images, audio, text) from random noise, while the discriminator network learns to distinguish fake data from real data. The two networks are trained simultaneously in an adversarial manner: the generator tries to fool the discriminator by generating ever more realistic fake data, while the discriminator tries to correctly classify samples as real or fake.
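The alternating update scheme described above can be sketched as follows. This is a minimal toy example, not the training code from this repo: it uses tiny single-hidden-layer networks on 1-D data to show the two-step loop (discriminator step, then generator step).

```python
import torch
import torch.nn as nn

# Toy 1-D "real" data: samples drawn from N(4, 1).
def sample_real(n):
    return torch.randn(n, 1) + 4.0

# Minimal generator and discriminator (single hidden layer each).
G = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
D = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid())

opt_g = torch.optim.Adam(G.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(D.parameters(), lr=1e-3)
bce = nn.BCELoss()

for step in range(200):
    real = sample_real(32)
    noise = torch.randn(32, 8)
    fake = G(noise)

    # 1) Discriminator step: push D(real) -> 1 and D(fake) -> 0.
    #    fake is detached so this step does not update the generator.
    opt_d.zero_grad()
    d_loss = (bce(D(real), torch.ones(32, 1))
              + bce(D(fake.detach()), torch.zeros(32, 1)))
    d_loss.backward()
    opt_d.step()

    # 2) Generator step: try to make the discriminator output 1 on fakes.
    opt_g.zero_grad()
    g_loss = bce(D(fake), torch.ones(32, 1))
    g_loss.backward()
    opt_g.step()
```

The `detach()` call in the discriminator step is the key design choice: it stops gradients from the discriminator's loss flowing back into the generator, keeping the two updates separate.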
The original paper introducing GANs is titled Generative Adversarial Networks and was authored by Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. It was published in 2014 at the Conference on Neural Information Processing Systems (NIPS).
GANs have two components:
- Generator Network: The generator network samples from an isotropic Gaussian distribution and applies a transformation so that the resulting distribution mimics the data distribution.
- Discriminator Network: The discriminator network is a classifier trained to discriminate between real and generated samples.
Vanilla GAN is my implementation of the original GAN paper with certain modifications, mostly in the model architecture, such as the use of LeakyReLU activations and 1D batch normalization.
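A sketch of what such a modified architecture might look like for MNIST (images flattened to 784 dimensions, latent dimension 100). The layer widths here are illustrative assumptions, not necessarily the ones used in this repo; the point is where LeakyReLU and `BatchNorm1d` slot in relative to the original paper's design.

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Maps latent noise z to a flattened 28x28 image."""
    def __init__(self, latent_dim=100, img_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.BatchNorm1d(256),      # 1D batch norm on fully connected features
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, img_dim),
            nn.Tanh(),                # outputs in [-1, 1] for normalized images
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    """Classifies a flattened image as real (1) or fake (0)."""
    def __init__(self, img_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),             # probability that the input is real
        )

    def forward(self, x):
        return self.net(x)
```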
The GAN loss is a min-max optimization problem, which is why it is also known as an adversarial loss.
The optimization problem for the discriminator is:

$$\max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

The loss for the generator is simplified, as in the original paper, by maximizing $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$:

$$\max_G \; \mathbb{E}_{z \sim p_z}[\log D(G(z))]$$
The expectations in the above equations are computed using Monte Carlo approximations by taking sample averages.
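As a concrete illustration of the Monte Carlo approximation: the expectation $\mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)]$ is replaced by the average of $\log D(x_i)$ over a minibatch. The discriminator below is a stand-in sigmoid function, not the repo's model.

```python
import torch

torch.manual_seed(0)

# Stand-in "discriminator": any map from inputs to probabilities in (0, 1).
D = lambda x: torch.sigmoid(x.sum(dim=1, keepdim=True))

# A minibatch of 64 "real" samples stands in for draws from p_data.
real_batch = torch.randn(64, 2)

# Sample average approximates the expectation E[log D(x)];
# the estimate tightens as the batch size grows.
estimate = torch.log(D(real_batch)).mean()
```

Since `D` outputs probabilities below 1, each `log D(x_i)` is negative, and so is the batch estimate, just like the true expectation.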
The GAN was trained on the MNIST dataset. Here is what the generated digits look like:
I've used the following repositories as references for my implementation:
- pytorch-GANs (PyTorch)
- research_implementations (PyTorch)