Generative Adversarial Networks are a special type of Neural Network that can learn the probability distribution of a dataset. Wait, what?!
Imagine you want to create a “procedural” marble texture generator with a Neural Network. You have a set of images you want to use as examples, those are the outputs of your network; but what would be your inputs? Using a GAN, you can correlate those outputs to any set of random variables and, thus, generate an infinite number of marble textures that look similar to that original dataset. This is one of GANs’ many applications.
You will need these resources to follow this tutorial:
- Link to the dataset
- Code in Jupyter notebook format
- Trained model
Subscribe to receive them via e-mail.
How does a GAN work?
A GAN is composed of two networks, one is called the Generator, and the other is the Discriminator (as named in the original paper).
The generator creates images from random inputs (or any other arbitrary set of inputs). The discriminator receives as inputs real images from the dataset and fake images created by the generator; it is then trained to label real images as real and fake images as fake.
So far so good. The trick is, how to transfer that knowledge from the discriminator to the generator?
For that, the discriminator network is stacked on top of the generator network, the weights of the discriminator are frozen, and the stacked network is trained using random values as inputs and the “real image” label as the target output. The result: training this stacked network updates the weights of the generator so that it creates real-looking outputs from those random inputs!
Since we don’t know the “correct answer” from the outset, training GANs can be time-consuming; it is also prone to instability and to a phenomenon called mode collapse. Because of this instability, in this tutorial we use a variation of the original GAN formulation known as WGAN.
In a nutshell, WGANs treat the outputs of the discriminator as a distance instead of discrete labels. Hence some changes are made to the transfer function of the output of the discriminator as well as to the training’s loss functions. For a more detailed description of these differences, I refer you to this summary created by the WGAN paper’s authors.
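To make the “distance instead of labels” idea a bit more concrete, here is a sketch of the objectives in the notation commonly used for WGANs. The symbols are assumptions introduced for illustration: f is the discriminator (often called the critic), G the generator, P_r the distribution of real images, and p(z) the distribution of the random inputs.

```latex
% Kantorovich-Rubinstein form of the Wasserstein-1 (Earth-Mover) distance:
% the supremum runs over all 1-Lipschitz functions f
W(P_r, P_g) = \sup_{\|f\|_L \le 1} \; \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{z \sim p(z)}[f(G(z))]

% Critic (discriminator) loss, minimized with respect to f
L_D = -\left( \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{z \sim p(z)}[f(G(z))] \right)

% Generator loss, minimized with respect to G
L_G = -\mathbb{E}_{z \sim p(z)}[f(G(z))]
```

Since -f is 1-Lipschitz whenever f is, an implementation is free to flip the sign and train the critic to score real images low and fake images high; what matters is that the generator is pushed toward whichever side the critic assigns to real images.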
You can download the original code from the main author’s GitHub page. The code I provide in the resources and throughout this post is based on that source. I tried to remove most bells and whistles, so the code is easier to read.
Before we start, PyTorch
Up until now all tutorials in this blog have in some way or another used Keras as the main API/Framework to train Deep Learning models. Keras is great! It is easy to use, and it acts as an abstraction layer to other frameworks.
BUT, I have one minor issue with it: Keras makes it a bit tricky to implement loss functions that are out of the ordinary. This is the case with GANs and with Reinforcement Learning as well. So, I am using this as an excuse to start using PyTorch more and more in the blog. I explain the pros and cons of PyTorch, how to install it, and how to use it with Maya in another post.
Creating WGAN Texture Generator
To create our texture generator, we need a good texture dataset. The Pixar dataset I used in the Normal Generator tutorial is fine, but the number of images per individual class is a bit on the low side. So, for this tutorial, I’m using the Describable Textures Dataset (DTD) from VGG, specifically the Marbled class of this dataset. Here is a small sample:
Loading the data
So, the first thing is loading that dataset. You can do it the old way, using scikit-image or some other library to load images into a Numpy array and convert that array into a Torch Tensor (PyTorch’s own array format). But fortunately, PyTorch comes with utils for that task:

    import torchvision.datasets as dset  # (1)
    from torch.utils.data import DataLoader
    import torchvision.transforms as transforms  # (2)

    # img_size and batch_size are assumed to be defined earlier in the notebook
    dataset = dset.ImageFolder(root='../data/pixar128/wood/',  # (1a)
                               transform=transforms.Compose([  # (2a)
                                   # Resize was called transforms.Scale in older torchvision releases
                                   transforms.Resize(img_size),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # (2b)
                               ]))
    dataset_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)  # (3)

(1) The dset module has tools to load data and wrap it in a PyTorch dataset type object.
(1a) The ImageFolder tool loads images from folders using a naming scheme: the root folder should have child folders, whose names are used as the class names for the images they contain.
(2) Transforms are tools to edit (crop, rescale, grade, and so on) images.
(2a) Transforms can be chained using transforms.Compose.
(2b) The Normalize operation is used here to rescale images to the [-1, 1] range; the first tuple is the per-channel mean and the second tuple the per-channel standard deviation.
(3) The DataLoader creates an iterator according to the chosen batch_size; this makes it very easy to load the data from within the training loop, as we’ll see later.
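To see why a mean and standard deviation of 0.5 lead to a [-1, 1] range, here is a minimal sketch of the arithmetic behind Normalize, written in plain Python for a single pixel value. The denormalize helper is hypothetical, added only to show the inverse mapping you would need to display generated images.

```python
def normalize(value, mean=0.5, std=0.5):
    """Same arithmetic as transforms.Normalize for one channel: (value - mean) / std."""
    return (value - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Hypothetical inverse: maps a value in [-1, 1] back to [0, 1]."""
    return value * std + mean


# ToTensor() produces values in [0, 1]; with mean=0.5 and std=0.5 they land in [-1, 1].
darkest, mid_grey, brightest = normalize(0.0), normalize(0.5), normalize(1.0)
print(darkest, mid_grey, brightest)
```

This range matches the Tanh output of the generator defined further down, so real and generated images live on the same scale.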
Creating the models
The GAN model is composed of two sub-models, the generator and the discriminator. The generator is a network that starts with a small, arbitrary number of inputs and grows at every new layer; its output layer must be the same size as the images in the training dataset. The discriminator is a network that starts with exactly as many inputs as the generator has outputs and shrinks the resolution at every layer, until it outputs a single scalar value: a binary classification in the case of GANs, or a distance in the case of WGANs. I’ll be using Convolutional layers since we are dealing with images.
Here is the definition for the Generator:
    import torch.nn as nn

    # n_z (number of random inputs) and device are assumed to be defined earlier in the notebook

    # Define generator model
    generator = nn.Sequential(
        nn.ConvTranspose2d(n_z, 512, 4, stride=1, bias=False),  # (1)
        nn.BatchNorm2d(512, 0.1),  # note: the second positional argument is eps
        nn.ReLU(),
        nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1, bias=False),  # (2)
        nn.BatchNorm2d(256, 0.1),
        nn.ReLU(),
        nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(128, 0.1),
        nn.ReLU(),
        nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(64, 0.1),
        nn.ReLU(),
        nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1, bias=False),  # (3)
        nn.Tanh(),  # (4)
    ).to(device)
We use the simplest way to declare the topology of a network in PyTorch, the sequential model, which works much like Keras’ sequential model. Each added layer is connected to the next one.
Note that I’m using a layer called ConvTranspose2d, often (if loosely) called a deconvolution layer. This is a convolution where the input feature map is smaller than the output feature map and the missing values are filled with zeros; it effectively acts as an up-sampling operation. Click here to get a visual intuition.
(1) In the first layer, we have an arbitrary number of single-pixel inputs (n_z) and we up-scale them to 4x4 feature maps using a stride of one. We then normalize the outputs and apply a ReLU transfer function.
(2) From then on, we apply other ConvTranspose operations, keeping the kernel size at 4 but upping the stride to 2 and applying a padding of 1, effectively doubling the resolution of the output at each layer.
(3) The last ConvTranspose layer reduces the output to 3 channels, one per RGB channel.
(4) We end the network with a Tanh transfer function to keep values between -1 and 1, as per the recommendation of the authors of the WGAN paper.
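The resolution at each layer of the generator can be checked by hand. The sketch below uses the output-size formula for ConvTranspose2d (assuming the default dilation of 1 and output_padding of 0) and walks through the five layers above; the helper name is made up for this example.

```python
def conv_transpose_out(size, kernel, stride=1, padding=0):
    """Output size of ConvTranspose2d along one dimension (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) of the five ConvTranspose2d layers in the generator
generator_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # each of the n_z inputs is a single pixel
sizes = [size]
for kernel, stride, padding in generator_layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [1, 4, 8, 16, 32, 64]
```

So with this topology the generator produces 64x64 images, which is the resolution the training images need to be resized to.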
Now for the discriminator:
    # n_channels is assumed to be defined earlier in the notebook (3 for RGB images)
    class Discriminator(nn.Module):
        def __init__(self):  # (1)
            super(Discriminator, self).__init__()
            main = nn.Sequential()
            main.add_module('in', nn.Conv2d(n_channels, 64, 4, stride=2, padding=1, bias=False))
            main.add_module('in_tf', nn.LeakyReLU(0.2, inplace=True))
            main.add_module('h0', nn.Conv2d(64, 128, 4, stride=2, padding=1, bias=False))
            main.add_module('h0_bn', nn.BatchNorm2d(128))
            main.add_module('h0_tf', nn.LeakyReLU(0.2, inplace=True))
            main.add_module('h1', nn.Conv2d(128, 256, 4, stride=2, padding=1, bias=False))
            main.add_module('h1_bn', nn.BatchNorm2d(256))
            main.add_module('h1_tf', nn.LeakyReLU(0.2, inplace=True))
            main.add_module('h2', nn.Conv2d(256, 512, 4, stride=2, padding=1, bias=False))
            main.add_module('h2_bn', nn.BatchNorm2d(512))
            main.add_module('h2_tf', nn.LeakyReLU(0.2, inplace=True))
            main.add_module('out', nn.Conv2d(512, 1, 4, stride=1, bias=False))
            self.main = main

        def forward(self, x):  # (2)
            output = self.main(x)  # (3)
            return output.mean(0).view(1)  # (3a)
Note that we create the discriminator differently, using nn.Module, which is analogous to Keras’ Model API. With it, you can (1) declare any number of layers in the model definition and then (2) use the forward definition to declare how these layers are connected. In our case there are not many customizations: the only thing we do is (3a) output the mean over all samples in the batch. So, whether we feed the network 1 or 100 samples, it will output one single number; this is a particularity of the WGAN implementation.
As for the topology, we use Conv2d layers with a stride of 2 and a padding of 1, effectively halving the resolution at each layer, until the final layer reduces the output to a single channel of a single pixel. Leaky ReLU is a variation of ReLU that lets a fraction of the values below zero leak through; it is said to outperform ReLU in most cases and is used here per the suggestion of the original WGAN paper.
Training is where PyTorch differs the most from other frameworks, as I discuss here. PyTorch allows for more customization, but that means there is a bit more code for us to deal with. Instead of calling a fit function, we need to implement our own training loop, which looks like this:

    for epoch in range(epochs):
        for batch_real, _ in dataset_loader:
            # weight updates
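As a sanity check on the discriminator topology, here is a small sketch in plain Python that applies the Conv2d output-size formula (assuming the default dilation of 1) to a 64x64 input, together with a scalar version of Leaky ReLU. Both helper names are made up for this example.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of Conv2d along one dimension (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def leaky_relu(x, negative_slope=0.2):
    """Scalar Leaky ReLU: positive values pass through, negative ones are scaled down."""
    return x if x >= 0 else negative_slope * x


# (kernel, stride, padding) of the five Conv2d layers in the discriminator
discriminator_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # assumed input resolution, matching the generator output
sizes = [size]
for kernel, stride, padding in discriminator_layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)             # [64, 32, 16, 8, 4, 1]
print(leaky_relu(2.0))   # 2.0
print(leaky_relu(-1.0))  # -0.2
```

The last layer has no padding and a kernel as large as its 4x4 input, which is why it collapses each sample to a single value.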
Not rocket science. Note that the dataset_loader is that iterator we talked about earlier, which provides us with the proper sized, shuffled batches. We retrieve only the images and not the images’ classes from the loader, hence the ‘_’ in ‘batch_real,_’.
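The weight update relies on two optimizers, optimizer_d and optimizer_g, and on a disc_iters count. As a minimal sketch, this is how they could be set up with the values reported in the WGAN paper (RMSprop, a learning rate of 0.00005, five discriminator updates per generator update, and a clipping value of 0.01). These values are assumptions and are not necessarily the ones used in the notebook; the two tiny stand-in models and the clip_value name are there only so the snippet runs on its own.

```python
import torch
import torch.nn as nn

# Stand-in models so this snippet is self-contained; in the tutorial these would be
# the generator and an instance of the Discriminator class defined earlier.
generator = nn.Sequential(nn.ConvTranspose2d(100, 3, 4, bias=False), nn.Tanh())
discriminator = nn.Sequential(nn.Conv2d(3, 1, 4, bias=False))

# Hyperparameters as reported in the WGAN paper (assumed, not taken from the notebook)
learning_rate = 5e-5  # RMSprop learning rate
disc_iters = 5        # discriminator updates per generator update
clip_value = 0.01     # bound used when clamping the discriminator weights

# One optimizer per network, so each update only touches its own weights
optimizer_d = torch.optim.RMSprop(discriminator.parameters(), lr=learning_rate)
optimizer_g = torch.optim.RMSprop(generator.parameters(), lr=learning_rate)
```

Keeping separate optimizers is what lets the generator step leave the discriminator weights untouched, and vice versa.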
And here is what the actual weight update looks like:
    for i in range(disc_iters):  # (1)
        # train discriminator
        discriminator.zero_grad()
        loss_real = discriminator(batch_real.to(device))  # (2)
        loss_real.backward()

        with torch.no_grad():  # (3)
            z = torch.randn(batch_real.shape[0], n_z, 1, 1, device=device)
            fake_imgs = generator(z)  # no graph is built, so nothing flows back to the generator
        loss_fake = discriminator(fake_imgs)  # (4)
        loss_fake.backward(torch.tensor([-1.]).to(device))

        loss_d = loss_real - loss_fake  # (5) for monitoring only, gradients were accumulated above
        optimizer_d.step()  # (6)

        for p in discriminator.parameters():  # (7)
            p.data.clamp_(-.01, .01)

    # train generator
    for p in discriminator.parameters():  # (8)
        p.requires_grad = False  # freeze discriminator updates

    z = torch.randn(batch_real.shape[0], n_z, 1, 1, device=device)  # (9)
    generator.zero_grad()
    loss_g = discriminator(generator(z))  # (10)
    loss_g.backward()
    optimizer_g.step()

    for p in discriminator.parameters():  # (11)
        p.requires_grad = True  # unfreeze discriminator updates

We start by updating the discriminator: (1) the WGAN authors recommend updating it N times for every generator update. We then (2) feed forward and backpropagate the discriminator with images from the original dataset. (3) We generate a fake batch of images with the generator, as many as there are real images in the batch; mind you that we use torch.no_grad so we don’t end up backpropagating all the way to the generator. (4) We forward the fake batch through the discriminator and, when calling the .backward() function, pass a -1 tensor through it, which makes PyTorch backpropagate gradients in the opposite direction. In effect, this is what we are doing: minimizing the discriminator outputs for real images and increasing the discriminator outputs for fake images. (5) Subtracting the fake loss from the real loss gives us an estimate of what is called the Earth-Mover or Wasserstein distance. Its gradients were already accumulated by the two .backward() calls, so we must not backpropagate it again (that would count every gradient twice); we keep it only to monitor training and (6) step the optimizer. The last step (7) is to clamp the weights of the discriminator network. This is a crude way to enforce a Lipschitz constraint on the discriminator, a requirement for the use of the Earth-Mover distance.
For the update of the generator weights, we first (8) freeze the weights of the discriminator network, which is the network we stack on top of the generator. (9) We generate a new batch of random input variables, (10) forward pass those through the generator and then through the discriminator, backpropagate, and step the generator’s optimizer. Note that we are backpropagating in the same direction in which we backpropagated the samples of the real images. At the end (11) we unfreeze the discriminator for the following updates.
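To get an intuition for why clamping weights limits how fast the discriminator output can change, here is a toy sketch in plain Python. It uses a hypothetical one-weight linear critic f(x) = w * x, for which the Lipschitz constant is simply |w|; a real convolutional discriminator is of course more involved, so treat this as an illustration rather than a proof.

```python
def clamp(value, low, high):
    """Scalar equivalent of tensor.clamp_(low, high)."""
    return max(low, min(high, value))


def linear_critic(x, w):
    """Hypothetical one-weight critic, used only for this illustration."""
    return w * x


clip_value = 0.01
weights = [-0.5, -0.004, 0.0, 0.02]
clipped = [clamp(w, -clip_value, clip_value) for w in weights]
print(clipped)  # [-0.01, -0.004, 0.0, 0.01]

# For f(x) = w * x we have |f(a) - f(b)| = |w| * |a - b|, so after clipping the
# output can never change faster than clip_value per unit of input.
a, b = 3.0, 1.0
for w in clipped:
    change = abs(linear_critic(a, w) - linear_critic(b, w))
    assert change <= clip_value * abs(a - b) + 1e-12
```

The weights that were already inside the allowed range are left untouched, while the others are pulled back to the boundary after every discriminator update.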
In the final code (resources), I have included some heuristics proposed by the authors of the WGAN paper that make the training process smoother.
After some 3000 epochs we should get some very interesting results. Like this:
In the resources for this article I have included an interactive widget where you can play with the values in the input variable and see results in real time:
GANs are a very powerful type of learning: instead of correlating one dataset to another, they correlate their probability distributions. In the case of our textures, the GAN correlates a Gaussian distribution to the distribution of images that look like marble. But this may have many other applications in our fields, like generating poses from sparse pose markers or generating photo-realistic images from outlines and roto-masks.
GANs are also not restricted to image generation and CNNs; you can use them with other types of data and topologies.