DCGAN (Deep Convolutional GAN)
Generates MNIST-like Images with Dramatically Better Quality
In this article, we incorporate ideas from DCGAN to improve the simple GAN model that we trained in the previous article. Just like before, we will implement DCGAN step by step.
1 DCGAN - Our Reference Model
We refer to PyTorch’s DCGAN tutorial for the DCGAN model implementation. We are especially interested in the convolutional (Conv2d) layers, as we believe they will improve how the discriminator extracts features. DCGAN also uses transposed convolution (ConvTranspose2d) layers to improve how the generator generates images.
DCGAN generates RGB color images, and its image size (64x64) is much bigger than MNIST images. We must adjust these to generate greyscale (1-channel) images at the MNIST image size (28x28).
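For a quick sense of the difference, here is a minimal sketch of the two batch shapes involved (the only assumptions are the tutorial's RGB 64x64 default and the MNIST 28x28 greyscale format):

import torch

# Batch shape the PyTorch DCGAN tutorial works with: RGB, 64x64
dcgan_batch = torch.zeros(64, 3, 64, 64)   # (N, C, H, W)

# Batch shape our networks must handle: greyscale, 28x28
mnist_batch = torch.zeros(64, 1, 28, 28)   # (N, C, H, W)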
2 Generator Network with Transposed Convolutions
The generator network from the previous article was very simple.
# Generator network
class Generator(nn.Sequential):
    def __init__(self, sample_size: int):
        super().__init__(
            nn.Linear(sample_size, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 784),
            nn.Sigmoid())

        # Random value vector size
        self.sample_size = sample_size

    def forward(self, batch_size: int):
        # Generate random values
        z = torch.randn(batch_size, self.sample_size)

        # Generator output
        output = super().forward(z)

        # Convert the output into a greyscale image (1x28x28)
        generated_images = output.reshape(batch_size, 1, 28, 28)
        return generated_images
In the above model, we reshape the generator output into the MNIST image shape. In the updated model (below), the DCGAN generator architecture applies transposed convolutions after the reshaping, since ConvTranspose2d operates on image-shaped data rather than flattened data.
# Generator network with transposed convolutions
class Generator(nn.Module):
    def __init__(self, sample_size: int, alpha: float):
        super().__init__()

        # sample_size => 784
        self.fc = nn.Sequential(
            nn.Linear(sample_size, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

        # 784 => 16 x 7 x 7
        self.reshape = Reshape(16, 7, 7)

        # 16 x 7 x 7 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(16, 32,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 1 x 28 x 28
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(32, 1,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.Sigmoid())

        # Random value vector size
        self.sample_size = sample_size

    def forward(self, batch_size: int):
        # Random value generation
        z = torch.randn(batch_size, self.sample_size)

        x = self.fc(z)        # => 784
        x = self.reshape(x)   # => 16 x 7 x 7
        x = self.conv1(x)     # => 32 x 14 x 14
        x = self.conv2(x)     # => 1 x 28 x 28
        return x
Like DCGAN, we use ConvTranspose2d to expand the image size from 7x7 to 28x28. ConvTranspose2d layers have learnable parameters that we train through GAN training, so the transposed convolutions learn to expand the image size and generate better-quality images. We also add Batch Normalization to speed up the learning process. For reshaping, we prepare the following helper class.
# Reshape helper
class Reshape(nn.Module):
    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(-1, *self.shape)
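For example, Reshape(16, 7, 7) turns a batch of flattened 784-value vectors into a batch of 16-channel 7x7 feature maps:

# Quick sanity check of the Reshape helper
x = torch.zeros(64, 784)   # batch of flattened vectors
y = Reshape(16, 7, 7)(x)
print(y.shape)             # torch.Size([64, 16, 7, 7])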
The data shape changes as follows, starting with the random value vector size of 100:
100
=> 784
=> 16 x 7 x 7 # Reshape
=> 32 x 14 x 14 # nn.ConvTranspose2d
=> 1 x 28 x 28 # nn.ConvTranspose2d
With these arrangements, the updated generator generates greyscale images of 28x28 size.
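We can double-check the 7 => 14 => 28 expansion with the output-size formula for transposed convolutions, out = (in - 1) * stride - 2 * padding + kernel_size + output_padding (with dilation 1). The helper below is just an illustration, not part of the model:

def conv_transpose_out(size, kernel_size=5, stride=2, padding=2, output_padding=1):
    # Output size of nn.ConvTranspose2d with dilation=1
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding

print(conv_transpose_out(7))                      # 14
print(conv_transpose_out(conv_transpose_out(7)))  # 28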
3 Discriminator Network with Convolutions
The discriminator network from the previous article was very simple.
# Discriminator network
class Discriminator(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Linear(784, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1))

    def forward(self, images: torch.Tensor, targets: torch.Tensor):
        prediction = super().forward(images.reshape(-1, 784))
        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss
We feed flattened image data through fully-connected linear layers to output one value per image, which scores how likely the input image is real (i.e., coming from MNIST). Finally, the discriminator network outputs loss values computed against the given targets.
The updated discriminator network incorporates convolutional layers.
# Discriminator network with convolutions
class Discriminator(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()

        # 1 x 28 x 28 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32,
                      kernel_size=5, stride=2, padding=2, bias=False),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 16 x 7 x 7
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16,
                      kernel_size=5, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(alpha))

        # 16 x 7 x 7 => 784
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha),
            nn.Linear(784, 1))

    def forward(self, images: torch.Tensor, targets: torch.Tensor):
        x = self.conv1(images)   # => 32 x 14 x 14
        x = self.conv2(x)        # => 16 x 7 x 7
        prediction = self.fc(x)  # => 1

        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss
We use Conv2d to shrink the image size from 1x28x28 to 16x7x7 while extracting features (channels). After that, we feed the flattened data into fully-connected linear layers for classification, just like the previous version of the discriminator. As in the updated generator, the updated discriminator incorporates Batch Normalization to make the learning process more efficient.
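As a similar illustrative check (again, not part of the model), the Conv2d output size follows out = floor((in + 2 * padding - kernel_size) / stride) + 1, confirming the 28 => 14 => 7 shrinkage:

def conv_out(size, kernel_size=5, stride=2, padding=2):
    # Output size of nn.Conv2d with dilation=1
    return (size + 2 * padding - kernel_size) // stride + 1

print(conv_out(28))            # 14
print(conv_out(conv_out(28)))  # 7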
4 The Entire DCGAN Code
The DCGAN implementation is mostly the same as in the previous article, except for the Generator and Discriminator definitions. I also set the generator's learning rate slightly higher this time, which seems to work better.
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm
# General config
batch_size = 64

# Generator config
sample_size = 100    # Random value sample size
g_alpha     = 0.01   # LeakyReLU alpha
g_lr        = 1.0e-3 # Learning rate (higher than previous version)

# Discriminator config
d_alpha = 0.01   # LeakyReLU alpha
d_lr    = 1.0e-4 # Learning rate

# DataLoader for MNIST
transform = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True)

# Reshape helper
class Reshape(nn.Module):
    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(-1, *self.shape)

# Generator network
class Generator(nn.Module):
    def __init__(self, sample_size: int, alpha: float):
        super().__init__()

        # sample_size => 784
        self.fc = nn.Sequential(
            nn.Linear(sample_size, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

        # 784 => 16 x 7 x 7
        self.reshape = Reshape(16, 7, 7)

        # 16 x 7 x 7 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(16, 32,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 1 x 28 x 28
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(32, 1,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.Sigmoid())

        # Random value sample size
        self.sample_size = sample_size

    def forward(self, batch_size: int):
        # Generate random input values
        z = torch.randn(batch_size, self.sample_size)

        # Use transposed convolutions
        x = self.fc(z)        # => 784
        x = self.reshape(x)   # => 16 x 7 x 7
        x = self.conv1(x)     # => 32 x 14 x 14
        x = self.conv2(x)     # => 1 x 28 x 28
        return x

# Discriminator network
class Discriminator(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()

        # 1 x 28 x 28 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32,
                      kernel_size=5, stride=2, padding=2, bias=False),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 16 x 7 x 7
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16,
                      kernel_size=5, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(alpha))

        # 16 x 7 x 7 => 784
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha),
            nn.Linear(784, 1))

    def forward(self, images: torch.Tensor, targets: torch.Tensor):
        # Extract image features using convolutions
        x = self.conv1(images)   # => 32 x 14 x 14
        x = self.conv2(x)        # => 16 x 7 x 7
        prediction = self.fc(x)  # => 1

        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss

# Save image grid
def save_image_grid(epoch: int, images: torch.Tensor, ncol: int):
    image_grid = make_grid(images, ncol)     # Images in a grid
    image_grid = image_grid.permute(1, 2, 0) # Move channel last
    image_grid = image_grid.cpu().numpy()    # To Numpy

    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    plt.savefig(f'generated_{epoch:03d}.jpg')
    plt.close()

# Real and fake labels
real_targets = torch.ones(batch_size, 1)
fake_targets = torch.zeros(batch_size, 1)

# Generator and discriminator networks
generator = Generator(sample_size, g_alpha)
discriminator = Discriminator(d_alpha)

# Optimizers
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)

# Training loop
for epoch in range(100):

    d_losses = []
    g_losses = []

    for images, labels in tqdm(dataloader):

        #===============================
        # Discriminator training
        #===============================

        # Loss with MNIST image inputs and real_targets as labels
        discriminator.train()
        d_loss = discriminator(images, real_targets)

        # Generate images in eval mode
        generator.eval()
        with torch.no_grad():
            generated_images = generator(batch_size)

        # Loss with generated image inputs and fake_targets as labels
        d_loss += discriminator(generated_images, fake_targets)

        # Optimizer updates the discriminator parameters
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #===============================
        # Generator Network Training
        #===============================

        # Generate images in train mode
        generator.train()
        generated_images = generator(batch_size)

        # Batch norm is unstable in eval mode because generated images
        # change drastically every epoch. We won't use eval here.
        # discriminator.eval()

        # Loss with generated image inputs and real_targets as labels
        g_loss = discriminator(generated_images, real_targets)

        # Optimizer updates the generator parameters
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Keep losses for logging
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    # Print average losses
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # Save images
    save_image_grid(epoch, generator(batch_size), ncol=8)
It takes longer to train than the previous version. Incorporating GPU support would improve the speed.
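As a minimal sketch of what GPU support could look like (this is an assumption on my part, not code from the article; the generator would also need to create z on the same device):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = Generator(sample_size, g_alpha).to(device)
discriminator = Discriminator(d_alpha).to(device)

real_targets = torch.ones(batch_size, 1, device=device)
fake_targets = torch.zeros(batch_size, 1, device=device)

# Inside the training loop, move each MNIST batch to the device:
#     images = images.to(device)
# And in Generator.forward, create z on the model's device:
#     z = torch.randn(batch_size, self.sample_size,
#                     device=next(self.parameters()).device)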
4.1 One Caveat about Discriminator’s BatchNorm in Eval Mode
In the above source code, I commented out the line that enables the discriminator's eval mode. The batch norm's running averages are unstable because the generated images change drastically in every batch. We should keep the discriminator in train mode so it constantly adjusts the batch norm's parameters. In later epochs, we could perhaps enable eval mode for the discriminator, but there is no need: we can keep both the discriminator and generator networks in train mode, and the GAN training will work fine. The DCGAN sample code from PyTorch does that, too. Also, there is an explanation of the issue by Soumith Chintala in this link.
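To see why, here is a small standalone illustration (not from the article's code) of how batch norm's running statistics behave in each mode:

# Batch norm updates its running statistics only in train mode
bn = nn.BatchNorm1d(1)

bn.train()
bn(torch.randn(64, 1) * 5 + 10)  # train mode: running stats updated
print(bn.running_mean)           # nudged toward the batch mean (~10)

bn.eval()
bn(torch.randn(64, 1))           # eval mode: uses the stored running
                                 # stats; no update happens
print(bn.running_mean)           # unchanged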
5 Before and After
5.1 Epoch 1
The previous version generated the below images after the first epoch.
The updated version generated the below images after the first epoch.
It already looks promising.
5.2 Epoch 50
The previous version generated the below images after the 50th epoch.
The updated version generated the below images after the 50th epoch.
They already look a lot better than the final outputs of the previous version.
5.3 Epoch 100
The previous version generated the below images after the 100th epoch.
The updated version generated the below images after the 100th epoch.
The quality of the images dramatically improved. I cannot tell whether the above images are real MNIST samples or generated ones.
5.4 Real MNIST images for comparison
Below are real MNIST images for comparison. Do they look real or fake to you?