Super Resolution using Generative Adversarial Networks

This is a Keras implementation of the SRGAN model proposed in the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. Note that this project is a work in progress.

A simplified view of the model can be seen as below: <br> <img src="https://github.com/titu1994/Super-Resolution-using-Generative-Adversarial-Networks/blob/master/architecture/SRGAN-simple-architecture.jpg?raw=true" width=40% height=50%>

Implementation Details

The SRGAN model is built in stages within models.py. Initially, only the SR-ResNet model is created; the VGG network is then appended to it to create the pre-training model. The VGG weights are frozen, since they are not updated during training.
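Freezing can be sketched as follows. In Keras, setting a layer's `trainable` attribute to `False` before the model is compiled excludes that layer's weights from updates. The helper name `freeze_layers` is illustrative and is not taken from models.py.

```python
def freeze_layers(layers):
    """Mark every layer as non-trainable and return how many were changed.

    Works on any objects exposing a boolean `trainable` attribute,
    which is how Keras layers behave.
    """
    changed = 0
    for layer in layers:
        if layer.trainable:
            layer.trainable = False
            changed += 1
    return changed
```

With a Keras model this would be called as `freeze_layers(vgg_model.layers)` before compiling, where `vgg_model` is a hypothetical variable holding the VGG part of the network.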

In the pre-train mode:

  1. The discriminator model is not attached to the network, so only the SR + VGG model is pretrained at first.

  2. During pretraining, only the VGG perceptual loss (using the ContentVGGRegularizer) and the total variation loss (using the TVRegularizer) are used. No other loss (MSE, binary cross entropy, discriminator) is applied.

  3. The content regularizer loss is applied to the VGG Convolution 2-2 layer.

  4. After pretraining the SR + VGG model, the discriminator model is pretrained.

  5. During discriminator pretraining, the model is Generator + Discriminator. Only binary cross entropy loss is used to train the discriminator network.
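The two generator pretraining losses can be illustrated with a small NumPy sketch. This is not the code of ContentVGGRegularizer or TVRegularizer: the function names, the squared-difference form of the total variation term, and the weights are all assumptions chosen for illustration.

```python
import numpy as np


def content_loss(features_generated, features_target):
    """Mean squared difference between two feature maps of equal shape."""
    a = np.asarray(features_generated, dtype=np.float64)
    b = np.asarray(features_target, dtype=np.float64)
    return float(np.mean((a - b) ** 2))


def total_variation_loss(image):
    """Sum of squared differences between neighbouring pixels (H x W x C image)."""
    x = np.asarray(image, dtype=np.float64)
    dh = x[1:, :-1, :] - x[:-1, :-1, :]  # vertical neighbours
    dw = x[:-1, 1:, :] - x[:-1, :-1, :]  # horizontal neighbours
    return float(np.sum(dh ** 2 + dw ** 2))


def pretrain_loss(features_generated, features_target, image,
                  content_weight=1.0, tv_weight=1e-4):
    """Weighted sum of both terms; the default weights are placeholder values."""
    return (content_weight * content_loss(features_generated, features_target)
            + tv_weight * total_variation_loss(image))
```

In this sketch the feature maps would come from the chosen VGG layer, and the total variation term penalises noisy, high-frequency output from the generator.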

In the full train mode:

  1. The discriminator model is attached to the entire network, which creates the SR + GAN + VGG model (SRGAN).
  2. Discriminator loss is added alongside the VGGContentLoss and TVLoss.
  3. The content regularizer loss is applied to the VGG Convolution 5-3 layer. (VGG 16 is used instead of VGG 19 for now.)
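The differences between the modes can be summarised as a small lookup table. This is only a restatement of the lists above in code form. The dictionary, its keys, and the layer labels `conv2_2` and `conv5_3` are hypothetical names, not identifiers from models.py.

```python
# Which model and losses are active in each training mode (illustrative names).
TRAINING_MODES = {
    "pretrain_generator": {
        "model": "SR + VGG",
        "losses": ("vgg_content", "total_variation"),
        "content_layer": "conv2_2",
    },
    "pretrain_discriminator": {
        "model": "Generator + Discriminator",
        "losses": ("binary_crossentropy",),
        "content_layer": None,
    },
    "full_train": {
        "model": "SR + GAN + VGG",
        "losses": ("vgg_content", "total_variation", "discriminator"),
        "content_layer": "conv5_3",
    },
}


def losses_for(mode):
    """Return the tuple of loss names used in the given mode."""
    if mode not in TRAINING_MODES:
        raise ValueError("unknown mode: %r" % (mode,))
    return TRAINING_MODES[mode]["losses"]
```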

Usage

Currently, models.py contains most of the code to train and create the models. To use different modes, uncomment the parts of the code that you need.

Note the difference between the *_network objects and *_model objects.

Note: The training images need to be stored in a subdirectory. Assume the path to the images is /path-to-dir/path-to-sub-dir/*.png, then simply write the path as coco_path = /path-to-dir. If this does not work, try coco_path = /path-to-dir/ with a trailing slash (/)
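A quick way to check the layout before training is to count the images exactly one level below coco_path. The helper below is a hypothetical convenience and is not part of the repository; the `*.png` pattern follows the example path above.

```python
from pathlib import Path


def count_training_images(coco_path, pattern="*.png"):
    """Count files matching `pattern` inside the immediate subdirectories of coco_path.

    Images placed directly in coco_path are ignored, since they are not
    in a subdirectory.
    """
    root = Path(coco_path)
    return sum(
        1
        for sub in root.iterdir()
        if sub.is_dir()
        for _ in sub.glob(pattern)
    )
```

A result of 0 suggests that the images sit directly in coco_path rather than in a subdirectory, or that they use a different file extension.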

To just create the pretrain model:

srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
srgan_model = srgan_network.build_srgan_pretrain_model()

# Plot the model
from keras.utils.visualize_util import plot
plot(srgan_model, to_file='SRGAN.png', show_shapes=True)

To pretrain the SR network:

srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
srgan_network.pre_train_srgan(images_path, nb_epochs=1, nb_images=50000)

**NOTE**: Some generator initializations lead to completely solid validation images. Check the validation images from the first few iterations to make sure they are not solid.

To counteract this, a pretrained generator model has been provided, from which you can restart training. Therefore the model can continue learning without hitting a bad initialization.

To pretrain the Discriminator network:

srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
srgan_network.pre_train_discriminator(images_path, nb_epochs=1, nb_images=50000, batchsize=16)

To train the full network (Does NOT work properly right now, Discriminator is not correctly trained):

srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
srgan_network.train_full_model(coco_path, nb_images=80000, nb_epochs=10)

Benchmarks

Currently supports validation against the Set5, Set14 and BSD 100 dataset images. Each of the 3 datasets has a script named download_*.py, which must be run to download the images before running benchmark_test.py.
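The metric behind the scores is not stated in this section. PSNR is commonly reported for Set5, Set14 and BSD 100, so the sketch below assumes it. The function name and the 8-bit peak value of 255 are assumptions, and benchmark_test.py may compute its scores differently.

```python
import numpy as np


def psnr(reference, restored, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two images of equal shape."""
    ref = np.asarray(reference, dtype=np.float64)
    out = np.asarray(restored, dtype=np.float64)
    mse = np.mean((ref - out) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_value ** 2 / mse))
```

Higher values mean the restored image is closer to the reference in a pixel-wise sense, which does not always match perceived quality.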

Current Scores (Due to RGB grid and Blurred restoration):

SR ResNet:

Drawbacks:

Plans

The codebase is currently very chaotic, since I am focusing on correct implementation before making the project better. Therefore, expect the code to drastically change over commits.

Some things I am currently trying out:

Discussion

There is an ongoing discussion at https://github.com/fchollet/keras/issues/3940 where I detail some of the outputs and attempts to correct the errors.

Requirements