Code Analysis of Cycle-GAN

This is a code analysis of CycleGAN. The original code is here; it is the PyTorch implementation of the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.

Introduction

CycleGAN is a generative adversarial network (GAN) that can translate images from one domain to another without paired examples. This means that you can use CycleGAN to translate images from a source domain, such as photos of horses, to a target domain, such as photos of zebras, even if you don't have any paired images of horses and zebras.

Model Architecture

CycleGAN works by training two generators and two discriminators:

  • One generator, $G_A$, that translates images from domain A to domain B.
  • Another generator, $G_B$, that translates images from domain B to domain A.
  • One discriminator, $D_A$, that discriminates between real images from domain B and translated images $G_A(A)$.
  • Another discriminator, $D_B$, that discriminates between real images from domain A and translated images $G_B(B)$.
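
Together, the two generators form two cycles that should approximately reconstruct their inputs; the cycle-consistency losses introduced later enforce exactly this:

  • Forward cycle: $A \xrightarrow{G_{A}} G_{A}(A) \xrightarrow{G_{B}} G_{B}(G_{A}(A)) \approx A$
  • Backward cycle: $B \xrightarrow{G_{B}} G_{B}(B) \xrightarrow{G_{A}} G_{A}(G_{B}(B)) \approx B$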

The following code shows how the four networks $G_A$, $G_B$, $D_A$, and $D_B$ are defined.

Note that this article briefly explains CycleGAN from a code-reading perspective. Implementation details of the network definitions may be omitted; however, I provide the line numbers in the code that you can refer to.
# models/cycle_gan_model.py (line 73-81)
class CycleGANModel(BaseModel):
    def __init__(self, opt):
        # ...
        # Generator A and B
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            # Discriminator A and B (note: netD_A takes opt.output_nc because it judges domain-B images)
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

    # ...

    def forward(self):
        """Run forward pass; called by both training and testing"""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

# models/networks.py (line 120-204)
def define_G(...):
    # ...
    # default netG ('resnet_9blocks'): a ResNet-based generator with 9 residual blocks
    net_generator = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    return net_generator

def define_D(...):
    # ...
    # default netD ('basic'): a 70x70 PatchGAN discriminator
    net_discriminator = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    return net_discriminator


# models/networks.py (line 361, line 539)
class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """
    # ...


class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        # ...
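
The bodies of these two classes are elided above. Per the paper's appendix, the ResnetGenerator follows the layout c7s1-64, d128, d256, nine residual blocks, u128, u64, c7s1-3. For the discriminator, the sketch below is my own minimal reconstruction of how NLayerDiscriminator stacks its convolutions (it simplifies details such as bias handling with instance norm); each output value scores one overlapping patch of the input, which is why it is called a PatchGAN:

import torch
import torch.nn as nn

def build_patchgan(input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
    """Minimal sketch of an n-layer PatchGAN discriminator (simplified NLayerDiscriminator)."""
    layers = [nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
              nn.LeakyReLU(0.2, True)]
    nf_mult = 1
    for n in range(1, n_layers):
        # double the filter count (capped at 8x) while halving the spatial size
        nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
        layers += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 4, stride=2, padding=1),
                   norm_layer(ndf * nf_mult),
                   nn.LeakyReLU(0.2, True)]
    nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
    layers += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 4, stride=1, padding=1),
               norm_layer(ndf * nf_mult),
               nn.LeakyReLU(0.2, True),
               # 1-channel map: each value classifies one patch of the input as real/fake
               nn.Conv2d(ndf * nf_mult, 1, 4, stride=1, padding=1)]
    return nn.Sequential(*layers)

disc = build_patchgan(input_nc=3)                  # defaults mirror define_D above
print(disc(torch.randn(1, 3, 256, 256)).shape)     # torch.Size([1, 1, 30, 30])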

Loss Functions

For CycleGAN, in addition to the GAN losses, the paper introduces the weights $\lambda_{A}$, $\lambda_{B}$, and $\lambda_{identity}$ for the following losses, where A is the source domain and B is the target domain.

The generators from the last section are denoted $G_{A}$ and $G_{B}$, and the discriminators $D_{A}$ and $D_{B}$, which respectively do the following:

  • $G_{A}$: A $\rightarrow$ B
  • $G_{B}$: B $\rightarrow$ A
  • $D_{A}$: distinguishes $G_{A}(A)$ from real B images
  • $D_{B}$: distinguishes $G_{B}(B)$ from real A images

Besides the standard GAN losses, the loss functions for optimizing each network fall into three categories: forward cycle loss, backward cycle loss, and identity loss.

  • Forward cycle loss: $\lambda_{A} \cdot \|G_{B}(G_{A}(A)) - A\|_{1}$ (Eqn. (2) in the paper)
  • Backward cycle loss: $\lambda_{B} \cdot \|G_{A}(G_{B}(B)) - B\|_{1}$ (Eqn. (2) in the paper)
  • Identity loss (optional): $\lambda_{identity} \cdot (\lambda_{B} \cdot \|G_{A}(B) - B\|_{1} + \lambda_{A} \cdot \|G_{B}(A) - A\|_{1})$ (Sec. 5.2 "Photo generation from paintings" in the paper)

The forward and backward cycle losses measure the difference between an original image and the image obtained by translating it to the other domain and back. The identity loss works differently: when $G_{A}$ (which maps domain A to domain B) is fed an image that is already in domain B, it should leave the image unchanged, since no translation is needed. We therefore penalize $\|G_{A}(B) - B\|_{1}$, and symmetrically $\|G_{B}(A) - A\|_{1}$ for $G_{B}$. As Sec. 5.2 of the paper notes, this regularizer helps preserve the color composition between input and output.
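
Putting these together, the total generator objective computed in backward_G() below is

$\mathcal{L}_{G} = \mathcal{L}_{GAN}(D_{A}(G_{A}(A)), \text{real}) + \mathcal{L}_{GAN}(D_{B}(G_{B}(B)), \text{real}) + \lambda_{A} \cdot \|G_{B}(G_{A}(A)) - A\|_{1} + \lambda_{B} \cdot \|G_{A}(G_{B}(B)) - B\|_{1} + \lambda_{identity} \cdot (\lambda_{B} \cdot \|G_{A}(B) - B\|_{1} + \lambda_{A} \cdot \|G_{B}(A) - A\|_{1})$

which corresponds term by term to self.loss_G in the code.
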
# models/cycle_gan_model.py (line 47-89)

class CycleGANModel(BaseModel):
    # ...
    def __init__(self, opt):
        # ...
        self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
        self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
        # define loss functions
        self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        # ...
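
The GAN criterion deserves a note: networks.GANLoss wraps the chosen gan_mode behind a common interface that takes the discriminator's prediction map and a boolean target. Below is my simplified sketch of what it does for the default gan_mode='lsgan'; the real class also supports 'vanilla' and 'wgangp':

import torch
import torch.nn as nn

class SimpleGANLoss(nn.Module):
    """Simplified sketch of networks.GANLoss for gan_mode='lsgan'."""
    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        super().__init__()
        # store the scalar labels as buffers so they move with the module's device
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss()  # least-squares GAN: regress predictions toward the label

    def forward(self, prediction, target_is_real):
        # expand the scalar label to the discriminator's patch-shaped prediction map
        target = (self.real_label if target_is_real else self.fake_label).expand_as(prediction)
        return self.loss(prediction, target)

# usage, mirroring self.criterionGAN(self.netD_A(self.fake_B), True):
criterion = SimpleGANLoss()
print(criterion(torch.randn(1, 1, 30, 30), True))  # MSE against an all-ones map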


# models/cycle_gan_model.py (line 141-178)
def backward_D_A(self):
    """Calculate GAN loss for discriminator D_A"""
    fake_B = self.fake_B_pool.query(self.fake_B)
    self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

def backward_D_B(self):
    """Calculate GAN loss for discriminator D_B"""
    fake_A = self.fake_A_pool.query(self.fake_A)
    self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

def backward_G(self):
    """Calculate the loss for generators G_A and G_B"""
    lambda_idt = self.opt.lambda_identity
    lambda_A = self.opt.lambda_A
    lambda_B = self.opt.lambda_B
    # Identity loss
    if lambda_idt > 0:
        # G_A should be identity if real_B is fed: ||G_A(B) - B||
        self.idt_A = self.netG_A(self.real_B)
        self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
        # G_B should be identity if real_A is fed: ||G_B(A) - A||
        self.idt_B = self.netG_B(self.real_A)
        self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
    else:
        self.loss_idt_A = 0
        self.loss_idt_B = 0

    # GAN loss D_A(G_A(A))
    self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
    # GAN loss D_B(G_B(B))
    self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
    # Forward cycle loss || G_B(G_A(A)) - A||
    self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
    # Backward cycle loss || G_A(G_B(B)) - B||
    self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
    # combined loss and calculate gradients
    self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
    self.loss_G.backward()
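
Both backward_D_A and backward_D_B delegate to the shared helper backward_D_basic, which is elided above. Note also the ImagePool buffers created in __init__: query() returns, with probability 0.5, a previously generated fake image instead of the newest one, a trick (from Shrivastava et al., cited in the paper) that stabilizes discriminator training. For reference, a sketch of the helper following the same file; the 0.5 factor slows the discriminator relative to the generators, and detach() keeps generator gradients out of the discriminator update:

def backward_D_basic(self, netD, real, fake):
    """Calculate GAN loss for the discriminator netD"""
    # real images should be classified as real
    pred_real = netD(real)
    loss_D_real = self.criterionGAN(pred_real, True)
    # fake images, detached so no gradients flow back into the generator
    pred_fake = netD(fake.detach())
    loss_D_fake = self.criterionGAN(pred_fake, False)
    # average both terms and backpropagate through the discriminator only
    loss_D = (loss_D_real + loss_D_fake) * 0.5
    loss_D.backward()
    return loss_D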

Training

For the forward and backward passes, the essential lines of code were shown in the previous sections (the forward() and backward_*() functions). The following is the training loop for the whole model.

# train.py (line 51-52)
for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
    for i, data in enumerate(dataset):
        # ...
        model.set_input(data)         # unpack data from the data loader
        model.optimize_parameters()   # forward pass, loss computation, and parameter updates
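
optimize_parameters() is where one training iteration happens. For reference, a sketch of its logic as implemented in models/cycle_gan_model.py: the generators are updated first with the discriminators frozen, then the discriminators are updated:

def optimize_parameters(self):
    """One training iteration: forward, update G_A and G_B, then D_A and D_B"""
    self.forward()   # compute fake images and reconstructions
    # G_A and G_B: freeze the discriminators, they need no gradients here
    self.set_requires_grad([self.netD_A, self.netD_B], False)
    self.optimizer_G.zero_grad()
    self.backward_G()                 # gradients for both generators
    self.optimizer_G.step()
    # D_A and D_B: unfreeze and update
    self.set_requires_grad([self.netD_A, self.netD_B], True)
    self.optimizer_D.zero_grad()
    self.backward_D_A()
    self.backward_D_B()
    self.optimizer_D.step()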

Testing

For testing, the model is switched to eval() mode. Note that eval() does not disable gradient computation; it changes the behavior of layers such as batch normalization and dropout, while gradients are turned off separately inside model.test() via torch.no_grad(). The following is the testing process for the whole model.

# test.py
# For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment with and without eval() mode.
# For [CycleGAN]: it should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
if opt.eval:
    model.eval()
for i, data in enumerate(dataset):
    # ...
    model.set_input(data)  # unpack data from data loader
    model.test()           # run inference
    visuals = model.get_current_visuals()  # get image results
    img_path = model.get_image_paths()     # get image paths
    if i % 5 == 0:  # save images to an HTML file
        print('processing (%04d)-th image... %s' % (i, img_path))
    save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize, use_wandb=opt.use_wandb)
webpage.save()  # save the HTML
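
Internally, model.test() wraps the forward pass in torch.no_grad(), which is what actually skips gradient bookkeeping at inference time. A sketch following models/base_model.py:

def test(self):
    """Forward pass used at test time, with gradient tracking disabled"""
    with torch.no_grad():
        self.forward()           # run both generators
        self.compute_visuals()   # prepare images for display/saving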

Conclusion

In this post, we have briefly introduced the CycleGAN model and walked through its PyTorch implementation alongside the original paper. We have also covered the training and testing procedures for the model.