Pix2Pix - Image-to-Image Translation with Conditional Adversarial Networks
A tutorial on Pix2Pix Conditional GANs and implementation with PyTorch
GANs
PyTorch
Published
February 13, 2021
Author Introduction
Hi! My name is Aniket Maurya. I am a Machine Learning Engineer at Quinbay Technologies, India. I research and build ML products for an e-commerce giant. I like to share my limited knowledge of Machine Learning and Deep Learning on my blog or YouTube channel. You can connect with me on LinkedIn/Twitter.
Introduction to Conditional Adversarial Networks
Image-to-image translation means transforming a given source image into a different image. Grayscale-to-colour conversion is one such example of image-to-image translation.
In this tutorial we will discuss GANs, cover a few points from the Pix2Pix paper, and implement the Pix2Pix network to translate segmented facades into realistic building pictures. We will create the Pix2Pix model in PyTorch and use PyTorch Lightning to avoid boilerplate.
GANs are generative models that learn a mapping from a random noise vector to an output image: G(z) → y.
For example, a GAN can learn a mapping from random normal vectors to smiley images. To train such a GAN we just need a set of smiley images and an adversarial loss 🙂. After the model is trained, we can feed it random normal noise vectors to generate smiley images that were not in the training dataset.
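As a toy sketch to make the mapping G(z) → y concrete (this is not the Pix2Pix model, just an illustrative untrained generator mapping 100-dimensional noise to 28×28 images):

```python
import torch
from torch import nn

# toy generator: maps 100-dim random normal noise vectors to 1x28x28 images
G = nn.Sequential(
    nn.Linear(100, 28 * 28),
    nn.Sigmoid(),
    nn.Unflatten(1, (1, 28, 28)),
)
z = torch.randn(16, 100)  # a batch of 16 noise vectors
images = G(z)             # shape: (16, 1, 28, 28)
```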
But what if we want to control what the model generates? In our case, we want the model to generate a laughing smiley.
Conditional GANs are generative networks that learn a mapping from a random noise vector together with a conditional vector to an output image. Suppose we have 4 types of smileys - smile, laugh, sad and angry (🙂 😂 😔 😡). Then the class vector for smile 🙂 can be (1,0,0,0), for laugh 😂 (0,1,0,0), and similarly for the others. Here the conditional vector is the smiley embedding.
During training, the conditional input is passed to the generator and a fake image is generated. The fake image is then concatenated with the conditional image and passed through the discriminator. The discriminator penalizes the generator if it correctly classifies the fake image as fake.
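To make this concrete, here is a minimal sketch (the shapes are made-up illustrations) of how a condition can be combined with the generator input and with the discriminator input:

```python
import torch
import torch.nn.functional as F

# condition the generator: concatenate a one-hot class vector with the noise
num_classes = 4                                                # smile, laugh, sad, angry
condition = F.one_hot(torch.tensor([1]), num_classes).float()  # laugh -> (0,1,0,0)
z = torch.randn(1, 100)                                        # random noise vector
gen_input = torch.cat([z, condition], dim=1)                   # (1, 104)

# condition the discriminator: concatenate the (fake or real) image with the
# conditional image along the channel dimension
fake = torch.randn(1, 3, 256, 256)        # stand-in for the generator's output
cond_image = torch.randn(1, 3, 256, 256)  # stand-in for the conditional image
disc_input = torch.cat([fake, cond_image], dim=1)  # (1, 6, 256, 256)
```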
Pix2Pix
Pix2Pix is an image-to-image translation Generative Adversarial Network that learns a mapping from an input image x and a random noise vector z to an output image y - in simple language, it learns to translate the source image into a different image.
At the time Pix2Pix was released, several other works were also using Conditional GANs on discrete labels. Pix2Pix uses a U-Net based architecture for the Generator, and a PatchGAN classifier for the Discriminator.
The Pix2Pix Generator is a U-Net based architecture: an encoder-decoder network with skip connections. Both the generator and the discriminator use Convolution-BatchNorm-ReLU modules; in simple words, this is the unit block of the generator and discriminator. Skip connections are added between each layer i and layer n − i, where n is the total number of layers: at each skip connection, all the channels from the current layer i are concatenated with all the channels at layer n − i.
We create the unit module (Convolution → BatchNorm → ReLU) that will be used in both the Generator and the Discriminator. We also keep the option open to add a Dropout layer when we need it.
```python
import torch
from torch import nn


class ConvBlock(nn.Module):
    """Unit block of the Pix2Pix generator and discriminator."""

    def __init__(self, in_channels, out_channels, use_dropout=False, use_bn=True):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        if use_bn:
            self.batchnorm = nn.BatchNorm2d(out_channels)
        self.use_bn = use_bn
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.conv1(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x
```
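A quick sanity check of the block (the shapes here are just an example): the 3×3 convolution with padding 1 keeps the spatial size while changing the number of channels.

```python
# example: ConvBlock keeps the spatial size and changes the channel count
block = ConvBlock(in_channels=3, out_channels=64)
out = block(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 64, 256, 256])
```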
In the first part of the U-Net the feature map size decreases, so we create a DownSampleConv module for this. This module contains the unit block ConvBlock that we just created.
```python
class DownSampleConv(nn.Module):
    def __init__(self, in_channels, use_dropout=False, use_bn=False):
        super().__init__()
        self.conv_block1 = ConvBlock(in_channels, in_channels * 2, use_dropout, use_bn)
        self.conv_block2 = ConvBlock(
            in_channels * 2, in_channels * 2, use_dropout, use_bn
        )
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.maxpool(x)
        return x
```
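For example (assuming a 64-channel, 256×256 feature map), each DownSampleConv doubles the channels and halves the spatial size:

```python
down = DownSampleConv(in_channels=64)
out = down(torch.randn(1, 64, 256, 256))
print(out.shape)  # torch.Size([1, 128, 128, 128])
```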
In the second part the network expands, so we create UpSampleConv.
```python
from torchvision.transforms.functional import center_crop


class UpSampleConv(nn.Module):
    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv1 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=2)
        self.conv2 = nn.Conv2d(
            input_channels, input_channels // 2, kernel_size=3, padding=1
        )
        self.conv3 = nn.Conv2d(
            input_channels // 2, input_channels // 2, kernel_size=2, padding=1
        )
        if use_bn:
            self.batchnorm = nn.BatchNorm2d(input_channels // 2)
        self.use_bn = use_bn
        self.activation = nn.ReLU()
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout

    def forward(self, x, skip_con_x):
        x = self.upsample(x)
        x = self.conv1(x)
        # crop the skip connection to match the upsampled feature map,
        # then concatenate along the channel dimension
        skip_con_x = center_crop(skip_con_x, x.shape[-2:])
        x = torch.cat([x, skip_con_x], dim=1)
        x = self.conv2(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        x = self.conv3(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x
```
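A quick shape check (example values) shows how the module merges a skip connection from the matching encoder layer and restores the encoder's spatial size:

```python
up = UpSampleConv(input_channels=128)
x = torch.randn(1, 128, 128, 128)    # features from the deeper layer
skip = torch.randn(1, 64, 256, 256)  # skip connection from the encoder
out = up(x, skip)
print(out.shape)  # torch.Size([1, 64, 256, 256])
```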
Now that the basic blocks of the Pix2Pix generator are created, we create the generator module. The generator is formed of contracting and expanding layers: the first part of the network contracts and then it expands again, i.e. we have an encoder block followed by a decoder block. In the encoder-decoder configuration of the official paper, C denotes the unit block we created (ConvBlock) and D denotes Dropout with rate 0.5. In the decoder, the output tensors from layer n − i of the encoder are concatenated with layer i of the decoder. Also, three of the decoder blocks have dropout layers.
```python
class Generator(nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels=32, depth=6):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.conv_final = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
        self.depth = depth
        self.contracting_layers = []
        self.expanding_layers = []
        self.sigmoid = nn.Sigmoid()

        # encoder/contracting path of the Generator
        for i in range(depth):
            down_sample_conv = DownSampleConv(hidden_channels * 2 ** i)
            self.contracting_layers.append(down_sample_conv)

        # decoder/expanding path of the Generator
        for i in range(depth):
            upsample_conv = UpSampleConv(
                hidden_channels * 2 ** (i + 1), use_dropout=(i < 3)
            )
            self.expanding_layers.append(upsample_conv)

        self.contracting_layers = nn.ModuleList(self.contracting_layers)
        self.expanding_layers = nn.ModuleList(self.expanding_layers)

    def forward(self, x):
        depth = self.depth
        contractive_x = []

        x = self.conv1(x)
        contractive_x.append(x)

        for i in range(depth):
            x = self.contracting_layers[i](x)
            contractive_x.append(x)

        # walk back up, merging the matching encoder output at each level
        for i in range(depth - 1, -1, -1):
            x = self.expanding_layers[i](x, contractive_x[i])

        x = self.conv_final(x)
        return self.sigmoid(x)
```
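A quick sanity check with an example 256×256 RGB input (256 = 2⁸, so six downsampling steps still leave 4×4 feature maps at the bottleneck):

```python
gen = Generator(in_channels=3, out_channels=3)
out = gen(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 3, 256, 256])
```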
Discriminator
A discriminator is a ConvNet that learns to classify images into discrete labels. In GANs, the discriminator learns to predict whether a given image is real or fake. PatchGAN is the discriminator used in Pix2Pix. Its architecture differs from a typical image-classification ConvNet in the output layer: in a classification ConvNet the output is a vector whose size equals the number of classes, while PatchGAN outputs a 2D matrix in which each element classifies a patch of the input image as real or fake.
Now we create our Discriminator - PatchGAN. In this network we reuse the same DownSampleConv module that we created for the generator, as shown in the sketch below.
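Here is a minimal sketch of such a PatchGAN, assuming the image and the conditional image are concatenated channel-wise and with illustrative channel widths (hidden_channels=8 and the number of DownSampleConv stages are assumptions, not the paper's exact configuration):

```python
class PatchGAN(nn.Module):
    """Sketch of a PatchGAN discriminator built from DownSampleConv blocks."""

    def __init__(self, in_channels, hidden_channels=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.contract1 = DownSampleConv(hidden_channels)
        self.contract2 = DownSampleConv(hidden_channels * 2)
        self.contract3 = DownSampleConv(hidden_channels * 4)
        self.contract4 = DownSampleConv(hidden_channels * 8)
        # 1x1 conv to a single channel: each output element classifies
        # one patch of the input as real or fake
        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)  # concatenate image and condition
        x = self.conv1(x)
        x = self.contract1(x)
        x = self.contract2(x)
        x = self.contract3(x)
        x = self.contract4(x)
        return self.final(x)


# example: two 3-channel inputs (image + condition) give a 16x16 patch map
disc = PatchGAN(in_channels=6)
logits = disc(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
print(logits.shape)  # torch.Size([1, 1, 16, 16])
```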
The loss functions used in Pix2Pix are an adversarial loss and a reconstruction loss. The adversarial loss pushes the generator to predict more realistic images. In conditional GANs, the generator's job is not only to produce a realistic image but also to stay close to the ground-truth output. The reconstruction loss helps the network produce a realistic image that is close to the ground-truth target.
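Putting the two together, here is a minimal sketch of the generator objective, using the Generator and PatchGAN above. The reconstruction weight lambda_recon is a hyperparameter (the paper uses a large weight such as 100), and since our PatchGAN outputs raw logits we use BCEWithLogitsLoss for the adversarial term:

```python
adversarial_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()
lambda_recon = 100  # reconstruction weight; a hyperparameter


def generator_loss(gen, disc, real, condition):
    fake = gen(condition)
    disc_logits = disc(fake, condition)
    # adversarial loss: the generator wants the discriminator to say "real"
    adversarial_loss = adversarial_criterion(
        disc_logits, torch.ones_like(disc_logits)
    )
    # reconstruction (L1) loss: stay close to the ground-truth target image
    recon_loss = recon_criterion(fake, real)
    return adversarial_loss + lambda_recon * recon_loss
```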