CGAN (Conditional GAN)
Specify What Images To Generate With 1 Simple Yet Powerful Change
GAN and DCGAN generate random images, so we have almost no control over what they produce. CGAN lets us specify a condition that tells the generator what images to generate. The trick is to convert label values into feature vectors using a learnable layer so the generator can learn what image to generate for each label. The discriminator also takes advantage of the label conditions. It may not be clear at this stage, but don’t worry. This article shows how the whole thing works with simple code.
1 Python Environment Setup
First of all, we create a Python environment. We’ll use venv as follows:
# Create a project folder and move there
mkdir cgan
cd cgan
# Create and activate a Python environment using venv
python3 -m venv venv
source venv/bin/activate
# Upgrade pip first, since the bundled version is usually old
# and may have outdated information about libraries
pip install --upgrade pip
# We install required libraries under the virtual environment
pip install torch torchvision matplotlib tqdm
The versions of installed libraries are as follows:
matplotlib==3.5.1
torch==1.11.0
torchvision==0.12.0
tqdm==4.64.0
If you prefer conda, you can create an environment with that instead. Please make sure to install the required libraries.
2 Conditions as Feature Vectors
2.1 Labels to One-hot Encoded Labels
We’ll be using the MNIST dataset, which contains many images of handwritten digits. Labels are integers between 0 and 9 inclusive. By converting labels into feature vectors, we can feed target labels (as conditions) into the generator together with random value vectors, so that the generated images match the labels and still vary.
First, we use PyTorch’s F.one_hot to convert digits into one-hot encodings.
import torch
from torch.nn import functional as F

# Labels (i.e., 1 and 3)
labels = torch.LongTensor([1, 3])

# Create one-hot encoded labels
encoded = F.one_hot(labels, num_classes=10)

print(encoded)
The output is:
tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])
So, we have one-hot encoded 1 and 3 into vectors with ten elements. We specified num_classes=10 since the labels are from 0 to 9 (10 numbers), and we need ten elements to uniquely identify each number in one-hot encoding. We use one-hot encoding because the value of a digit has no meaning like rank or order. It is a class (categorical) value, and using one-hot encoding to represent categorical values is a well-established practice in machine learning.
2.2 One-hot Encoded Labels to Feature Vectors
The generator will learn to extract features (conditions) from the one-hot encoded labels. We’ll create a class to convert one-hot encoded labels into feature vectors via a fully-connected layer. Both the generator and the discriminator will use the class.
from torch import nn

# Converts conditions into feature vectors
class Condition(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()

        # From one-hot encoding to features: 10 => 784
        self.fc = nn.Sequential(
            nn.Linear(10, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

    def forward(self, labels: torch.Tensor):
        # One-hot encode labels
        x = F.one_hot(labels, num_classes=10)

        # From Long to Float
        x = x.float()

        # To feature vectors
        return self.fc(x)
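Because the input is one-hot, a linear layer from 10 to 784 effectively looks up one learned 784-dimensional vector per digit. The sketch below is only an illustration, not part of the article's code. It uses a bare `nn.Linear(10, 784)` and made-up variable names to check that the layer output equals the matching column of the weight matrix plus the bias.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)

# A bare linear layer with the same shape as the first layer of Condition
fc = nn.Linear(10, 784)

labels = torch.LongTensor([1, 3])
one_hot = F.one_hot(labels, num_classes=10).float()

# Output of the linear layer for one-hot inputs
features = fc(one_hot)

# The same values read directly from the parameters:
# column `label` of the weight matrix (784 x 10) plus the bias
lookup = fc.weight[:, labels].T + fc.bias

print(features.shape)                               # torch.Size([2, 784])
print(torch.allclose(features, lookup, atol=1e-6))  # True
```

In other words, training this layer amounts to learning ten feature vectors, one per digit, which the batch normalization and LeakyReLU then transform.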
2.3 Reshape Helper
We define a helper class for reshaping encoded condition values in the generator and discriminator.
# Reshape helper
class Reshape(nn.Module):
    def __init__(self, *shape):
        super().__init__()

        self.shape = shape

    def forward(self, x):
        return x.reshape(-1, *self.shape)
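As a quick illustration of what the helper does, the following sketch repeats the class so it runs on its own and reshapes a dummy batch of 784-dimensional vectors into 16 x 7 x 7 feature maps. The batch size of 4 is an arbitrary choice for the example.

```python
import torch
from torch import nn

# Same helper as above, repeated so the snippet runs on its own
class Reshape(nn.Module):
    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        # -1 lets PyTorch infer the batch dimension
        return x.reshape(-1, *self.shape)

reshape = Reshape(16, 7, 7)

# A dummy batch of four 784-dimensional feature vectors
x = torch.zeros(4, 784)
y = reshape(x)

print(y.shape)  # torch.Size([4, 16, 7, 7])
```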
3 Generator Network Definition
The generator is similar to the DCGAN generator used in the earlier DCGAN article. As before, we generate random value vectors and use a fully-connected layer to increase the dimensions to 784. We also use the Condition layer to convert input labels into feature vectors of 784 dimensions. Then, we sum the random vectors and the label features with an element-wise addition. We can do that because the random value and label feature vectors have the same number of elements.
The generator network processes the sum of random value vectors and label feature vectors to generate random images for the specified label classes. Let’s see how we integrate the Condition class into the DCGAN generator class.
# Generator network
class Generator(nn.Module):
    def __init__(self, sample_size: int, alpha: float):
        super().__init__()

        # sample_size => 784
        self.fc = nn.Sequential(
            nn.Linear(sample_size, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

        # 784 => 16 x 7 x 7
        self.reshape = Reshape(16, 7, 7)

        # 16 x 7 x 7 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(16, 32,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 1 x 28 x 28
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(32, 1,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.Sigmoid())

        # Random value sample size
        self.sample_size = sample_size

        # To convert labels into feature vectors
        self.cond = Condition(alpha)

    def forward(self, labels: torch.Tensor):
        # Labels as feature vectors
        c = self.cond(labels)

        # Batch size is the number of labels
        batch_size = len(labels)

        # Generate random inputs
        z = torch.randn(batch_size, self.sample_size)

        # Inputs are the sum of random inputs and label features
        x = self.fc(z)         # => 784
        x = self.reshape(x+c)  # => 16 x 7 x 7
        x = self.conv1(x)      # => 32 x 14 x 14
        x = self.conv2(x)      # => 1 x 28 x 28
        return x
As you can see, the code converts labels into feature vectors of the same dimension as the random value vectors and performs an element-wise addition operation (x+c). In other words, random value and label feature vectors express the information required to generate images in the same space.
Let’s see the part of the code in detail.
# Label feature vectors (784)
c = self.cond(labels)

# Random value vectors (784)
z = torch.randn(batch_size, self.sample_size)
x = self.fc(z)

# Element-wise addition and reshape from 784 into 16x7x7
x = self.reshape(x+c)
If we don’t have random value vectors and use only the label feature vectors to train the generator, it would learn to generate one image per label input. Having random value vectors is crucial to adding variations to generated images.
In the above code, we use the element-wise addition operation, but that’s not the only way to combine random value vectors and label feature vectors. We could concatenate two vectors into one. In this case, we don’t need to make both vectors have the same number of elements. Alternatively, we could concatenate one-hot encoded labels and random value vectors and feed them through a fully-connected layer to generate input features. We would need to adjust the number of parameters in the fully-connected layer to accommodate different input vector sizes.
In this article, we use the element-wise addition operation since it’s simple to implement, but you may want to try the other methods to see how they work.
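To make the concatenation alternative concrete, here is a minimal sketch of an input block that concatenates one-hot encoded labels with random value vectors and feeds them through a fully-connected layer. The class name `ConcatInput` and its structure are hypothetical and are not part of this article's implementation; the sample size of 100 and alpha of 0.01 are just example values.

```python
import torch
from torch import nn
from torch.nn import functional as F

# Hypothetical input block: concatenates one-hot labels with random values
# and maps the result to 784 features (not the article's implementation)
class ConcatInput(nn.Module):
    def __init__(self, sample_size: int, alpha: float):
        super().__init__()
        self.sample_size = sample_size

        # (sample_size + 10) => 784
        self.fc = nn.Sequential(
            nn.Linear(sample_size + 10, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

    def forward(self, labels: torch.Tensor):
        # One-hot encoded labels as floats: batch x 10
        c = F.one_hot(labels, num_classes=10).float()

        # Random value vectors: batch x sample_size
        z = torch.randn(len(labels), self.sample_size)

        # Concatenate along the feature dimension: batch x (sample_size + 10)
        x = torch.cat([z, c], dim=1)
        return self.fc(x)  # => 784

block = ConcatInput(sample_size=100, alpha=0.01)
features = block(torch.LongTensor([1, 3, 5, 7]))
print(features.shape)  # torch.Size([4, 784])
```

The output has the same 784 elements as the summed input in the generator, so it could be reshaped to 16 x 7 x 7 in the same way. Only the first linear layer grows, from sample_size to sample_size + 10 inputs.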
4 Discriminator Network Definition
We use the Condition class inside the discriminator network to predict whether input images are real or fake as per the given condition. For example, when a condition indicates an image is for digit 3, the discriminator classifies whether the image is a real image of digit “3” or not. Like the generator, the discriminator has its own condition layer that learns to generate features for each label through training.
# Discriminator network
class Discriminator(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()

        # 1 x 28 x 28 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32,
                      kernel_size=5, stride=2, padding=2, bias=False),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 16 x 7 x 7
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16,
                      kernel_size=5, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(alpha))

        # 16 x 7 x 7 => 784
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha),
            nn.Linear(784, 1))

        # Reshape label features: 784 => 16 x 7 x 7
        self.cond = nn.Sequential(
            Condition(alpha),
            Reshape(16, 7, 7))

    def forward(self, images: torch.Tensor,
                labels: torch.Tensor,
                targets: torch.Tensor):
        # Label features
        c = self.cond(labels)

        # Image features + Label features => real or fake?
        x = self.conv1(images)     # => 32 x 14 x 14
        x = self.conv2(x)          # => 16 x 7 x 7
        prediction = self.fc(x+c)  # => 1

        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss
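The shape comments in the generator and discriminator (28 => 14 => 7 and back) can be checked with the standard output-size formulas for convolution and transposed convolution. The helper functions below are written for this illustration only, and they assume a dilation of 1, which is the PyTorch default.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Output size of a convolution (dilation = 1)
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size: int, kernel: int, stride: int,
                       padding: int, output_padding: int) -> int:
    # Output size of a transposed convolution (dilation = 1)
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Discriminator: kernel_size=5, stride=2, padding=2
print(conv_out(28, 5, 2, 2))  # 14
print(conv_out(14, 5, 2, 2))  # 7

# Generator: kernel_size=5, stride=2, padding=2, output_padding=1
print(conv_transpose_out(7, 5, 2, 2, 1))   # 14
print(conv_transpose_out(14, 5, 2, 2, 1))  # 28
```

Without output_padding=1, the transposed convolutions would produce 13 and 27 instead of 14 and 28, which is why the generator sets it.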
5 CGAN Training
While training, we feed labels to both the discriminator and the generator. Each network learns its own label features that suit its objective. The CGAN training loop trains the discriminator and the generator in turns.
5.1 Discriminator Training
It is the same as DCGAN discriminator training, except we are feeding labels.
# Train loop
for epoch in range(100):

    d_losses = []
    g_losses = []

    for images, labels in tqdm(dataloader):
        #===============================
        # Discriminator Network Training
        #===============================

        # Images from MNIST are considered as real
        d_loss = discriminator(images, labels, real_targets)

        # Images from Generator are considered as fake
        d_loss += discriminator(generator(labels), labels, fake_targets)

        # Discriminator parameter update
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        ...
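The discriminator returns a loss computed with F.binary_cross_entropy_with_logits, which takes raw outputs (logits) and applies the sigmoid internally. The sketch below uses made-up logits for three images to show how targets of ones (real) and zeros (fake) enter the loss, and compares the result against an explicit sigmoid followed by plain binary cross-entropy.

```python
import torch
from torch.nn import functional as F

# Made-up raw discriminator outputs (logits) for a batch of three images
prediction = torch.tensor([[2.0], [-1.0], [0.5]])

# Targets: 1 means "real", 0 means "fake"
real_targets = torch.ones(3, 1)
fake_targets = torch.zeros(3, 1)

# Loss on logits, as used inside the discriminator
loss_real = F.binary_cross_entropy_with_logits(prediction, real_targets)
loss_fake = F.binary_cross_entropy_with_logits(prediction, fake_targets)

# Reference: explicit sigmoid followed by plain binary cross-entropy
ref_real = F.binary_cross_entropy(torch.sigmoid(prediction), real_targets)
ref_fake = F.binary_cross_entropy(torch.sigmoid(prediction), fake_targets)

print(loss_real.item(), ref_real.item())
print(loss_fake.item(), ref_fake.item())
```

The two formulations should agree up to floating-point error. The logits version is generally preferred because it is more numerically stable, which fits the discriminator having no final sigmoid layer.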
5.2 Generator Training
It is the same as DCGAN generator training except that we feed labels.
# Train loop
for epoch in range(100):

    for images, labels in tqdm(dataloader):

        ...

        #===============================
        # Generator Network Training
        #===============================

        # Images from Generator should be as real as ones from MNIST
        g_loss = discriminator(generator(labels), labels, real_targets)

        ...
The entire source code is available at the end of this article. For now, let’s look at the results of the training.
6 Training Results
6.1 Test Image Generation
After each epoch, we generate eight images for each digit from 0 to 9 using the following code.
# 0 to 9 in a list
labels = list(range(10))

# Convert to long tensor
labels = torch.LongTensor(labels)

# Repeat the 0-9 sequence eight times
labels = labels.repeat(8)

# Flatten (the tensor already has 80 elements in one dimension)
labels = labels.flatten()

# Generate 80 images
generated_images = generator(labels)

# Save the results in a grid layout
save_image_grid(epoch, generated_images, ncol=10)
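To see the label order that this produces, the sketch below prints the first few labels. It also shows repeat_interleave for comparison, which is not used in this article and is included only to contrast the two orderings.

```python
import torch

labels = torch.LongTensor(list(range(10)))

# repeat(8) tiles the whole sequence: 0..9, 0..9, ... (eight times)
tiled = labels.repeat(8)
print(tiled.shape)          # torch.Size([80])
print(tiled[:12].tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1]

# For comparison, repeat_interleave(8) groups equal digits together
grouped = labels.repeat_interleave(8)
print(grouped[:12].tolist())  # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
```

Assuming the grid places ten images per row, as ncol=10 suggests, the tiled order puts the digits 0 to 9 in every row, so each column shows eight variations of one digit.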
6.2 Epoch 1
After the first epoch, the generated images look like digits per the conditions. It seems to me that adding conditions made it easier for the networks to learn.
6.3 Epoch 50
I’d say the outputs already look satisfactory.
6.4 Epoch 100
Not so different compared with Epoch 50. It may be slightly better looking. Hard to say. In any case, the training succeeded because the generator can produce various images per given conditions.
CGAN would be useful for generating synthetic training data that targets particular kinds of images.
7 Why CGAN Works
So, CGAN is the same as DCGAN, with label features added to the input vectors. That is all, yet it generates images as per the given label conditions. Why should that work so well?
The generator and the discriminator do not share their Condition layers, so each network learns independently and adversarially. The generator tries to generate images that are as real as possible so that the loss calculated by the discriminator becomes smaller. In this process, its Condition layer must learn to distinguish different label features as much as possible, because achieving a lower loss depends on knowing what digit it should generate. The discriminator’s Condition layer also learns to distinguish between different digits, making the binary classification (real or fake) decision easier. So, it is crucial for both the generator and the discriminator to differentiate label inputs.
For example, when the label is digit 1, the generator must generate an image of 1 as real as possible (as MNIST image-like as possible), and the discriminator needs to give a significant loss to the generator if the generated image does not look like digit “1” at all. From the generator’s point of view, the discriminator is a loss function that also learns from the inputs.
So, as long as we train both the generator and the discriminator for all 0 to 9 digits equally well, the generator can understand conditions to generate realistic (MNIST-like) images for given labels.
8 Source Code
The source code is pretty much the same as DCGAN except that we now have the condition handling code.
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm


# Common config
batch_size = 64

# Generator config
sample_size = 100  # Random sample size
g_alpha = 0.01     # LeakyReLU alpha
g_lr = 1.0e-4      # Learning rate

# Discriminator config
d_alpha = 0.01     # LeakyReLU alpha
d_lr = 1.0e-4      # Learning rate

# Data Loader for MNIST
transform = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True)
# Converts conditions into feature vectors
class Condition(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()

        # From one-hot encoding to features: 10 => 784
        self.fc = nn.Sequential(
            nn.Linear(10, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

    def forward(self, labels: torch.Tensor):
        # One-hot encode labels
        x = F.one_hot(labels, num_classes=10)

        # From Long to Float
        x = x.float()

        # To feature vectors
        return self.fc(x)


# Reshape helper
class Reshape(nn.Module):
    def __init__(self, *shape):
        super().__init__()

        self.shape = shape

    def forward(self, x):
        return x.reshape(-1, *self.shape)
# Generator network
class Generator(nn.Module):
    def __init__(self, sample_size: int, alpha: float):
        super().__init__()

        # sample_size => 784
        self.fc = nn.Sequential(
            nn.Linear(sample_size, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

        # 784 => 16 x 7 x 7
        self.reshape = Reshape(16, 7, 7)

        # 16 x 7 x 7 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(16, 32,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 1 x 28 x 28
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(32, 1,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.Sigmoid())

        # Random value sample size
        self.sample_size = sample_size

        # To convert labels into feature vectors
        self.cond = Condition(alpha)

    def forward(self, labels: torch.Tensor):
        # Labels as feature vectors
        c = self.cond(labels)

        # Batch size is the number of labels
        batch_size = len(labels)

        # Generate random inputs
        z = torch.randn(batch_size, self.sample_size)

        # Inputs are the sum of random inputs and label features
        x = self.fc(z)         # => 784
        x = self.reshape(x+c)  # => 16 x 7 x 7
        x = self.conv1(x)      # => 32 x 14 x 14
        x = self.conv2(x)      # => 1 x 28 x 28
        return x
# Discriminator network
class Discriminator(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()

        # 1 x 28 x 28 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32,
                      kernel_size=5, stride=2, padding=2, bias=False),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 16 x 7 x 7
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16,
                      kernel_size=5, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(alpha))

        # 16 x 7 x 7 => 784
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha),
            nn.Linear(784, 1))

        # Reshape label features: 784 => 16 x 7 x 7
        self.cond = nn.Sequential(
            Condition(alpha),
            Reshape(16, 7, 7))

    def forward(self, images: torch.Tensor,
                labels: torch.Tensor,
                targets: torch.Tensor):
        # Label features
        c = self.cond(labels)

        # Image features + Label features => real or fake?
        x = self.conv1(images)     # => 32 x 14 x 14
        x = self.conv2(x)          # => 16 x 7 x 7
        prediction = self.fc(x+c)  # => 1

        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss
# To save grid images
def save_image_grid(epoch: int, images: torch.Tensor, ncol: int):
    image_grid = make_grid(images, ncol)      # Into a grid
    image_grid = image_grid.permute(1, 2, 0)  # Channel to last
    image_grid = image_grid.cpu().numpy()     # Into Numpy

    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    plt.savefig(f'generated_{epoch:03d}.jpg')
    plt.close()


# Real / Fake targets
real_targets = torch.ones(batch_size, 1)
fake_targets = torch.zeros(batch_size, 1)

# Generator and discriminator
generator = Generator(sample_size, g_alpha)
discriminator = Discriminator(d_alpha)

# Optimizers
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)
# Train loop
for epoch in range(100):

    d_losses = []
    g_losses = []

    for images, labels in tqdm(dataloader):
        #===============================
        # Discriminator Network Training
        #===============================

        # Images from MNIST are considered as real
        d_loss = discriminator(images, labels, real_targets)

        # Images from Generator are considered as fake
        d_loss += discriminator(generator(labels), labels, fake_targets)

        # Discriminator parameter update
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #===============================
        # Generator Network Training
        #===============================

        # Images from Generator should be as real as ones from MNIST
        g_loss = discriminator(generator(labels), labels, real_targets)

        # Generator parameter update
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Keep losses for logging
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    # Print loss
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # Save generated images
    labels = torch.LongTensor(list(range(10))).repeat(8).flatten()
    save_image_grid(epoch, generator(labels), ncol=10)
9 References
- Conditional Generative Adversarial Nets
Mehdi Mirza, Simon Osindero