

Pytorch GAN Zoo

A GAN toolbox for researchers and developers with:

Picture: Generated samples from GANs trained on celebaHQ, fashionGen, DTD.

Picture: Fake faces with celebaHQ.

This code also implements diverse tools:

Requirements

This project requires:

Optional:

If you don't already have PyTorch or torchvision, please have a look at https://pytorch.org/, as the installation command may vary depending on your OS and your version of CUDA.

You can install all other dependencies with pip by running:

pip install -r requirements.txt

Recommended datasets

Quick training

The datasets.py script allows you to prepare your datasets and build their corresponding configuration files.

If you want to waste no time and just launch a training session on celeba cropped, run:

python datasets.py celeba_cropped $PATH_TO_CELEBA/img_align_celeba/ -o $OUTPUT_DATASET
python train.py PGAN -c config_celeba_cropped.json --restart -n celeba_cropped

And wait for a few days. Your checkpoints will be dumped in output_networks/celeba_cropped. You should get 128x128 generations at the end.

For celebaHQ:

python datasets.py celebaHQ $PATH_TO_CELEBAHQ -o $OUTPUT_DATASET -f
python train.py PGAN -c config_celebaHQ.json --restart -n celebaHQ

Your checkpoints will be dumped in output_networks/celebaHQ. You should get 1024x1024 generations at the end.

For fashionGen:

python datasets.py fashionGen $PATH_TO_FASHIONGEN_RES_256 -o $OUTPUT_DIR
python train.py PGAN -c config_fashionGen.json --restart -n fashionGen

The above command will train the fashionGen model up to resolution 256x256. If you want to train fashionGen on a specific sub-dataset, for example CLOTHING, run:

python train.py PGAN -c config_fashionGen.json --restart -n fashionGen -v CLOTHING

Four sub-datasets are available: CLOTHING, SHOES, BAGS and ACCESSORIES.

For the DTD texture dataset:

python datasets.py dtd $PATH_TO_DTD
python train.py PGAN -c config_dtd.json --restart -n dtd

For cifar10:

python datasets.py cifar10 $PATH_TO_CIFAR10 -o $OUTPUT_DATASET
python train.py PGAN -c config_cifar10.json --restart -n cifar10

Load a pretrained model with torch.hub

Models trained on celebaHQ, fashionGen, cifar10 and celeba cropped are available with torch.hub.

Checkpoints:

See hubconf.py for how to load a checkpoint.

GDPP

To apply the GDPP loss to your model, just add the option --GDPP true to your training command.

(beta) StyleGAN

To run StyleGAN, use the model name StyleGAN when running train.py. You can also use the pre-computed configurations for celeba and celebaHQ. For example:

python train.py StyleGAN -c config_celebaHQ.json --restart -n style_gan_celeba

Advanced guidelines

How to run a training session?

python train.py $MODEL_NAME -c $CONFIGURATION_FILE [-n $RUN_NAME] [-d $OUTPUT_DIRECTORY] [OVERRIDES]

Where:

1 - MODEL_NAME is the name of the model you want to run. Currently, two models are available:
    - PGAN (progressive growing of GANs)
    - PPGAN (decoupled version of PGAN)

2 - CONFIGURATION_FILE (mandatory): path to a training configuration file. This file is a JSON file containing at least a pathDB entry with the path to the training dataset. See below for more information about this file.

3 - RUN_NAME is the name you want to give to your training session. All checkpoints will be saved in $OUTPUT_DIRECTORY/$RUN_NAME. The default value is default.

4 - OUTPUT_DIRECTORY is the directory where all training sessions are saved. The default value is output_networks.

5 - OVERRIDES: you can override some of the model's parameters defined in the "config" field of the configuration file (see below) from the command line. For example:

python train.py PPGAN -c coin.json -n PAN --learningRate 0.2

This will force the learning rate to be 0.2 during training, whatever the configuration file coin.json specifies.

To get all the possible override options, please type:

python train.py $MODEL_NAME --overrides
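If you launch many runs from a script, the syntax above can be assembled programmatically. The following is a minimal sketch, assuming train.py is called from the repository root; build_train_command and checkpoint_dir are hypothetical helpers, not part of the repository. checkpoint_dir only restates the documented defaults (output_networks and default).

```python
import os
import shlex
from typing import Dict, List, Optional


def build_train_command(model_name: str,
                        config_file: str,
                        run_name: Optional[str] = None,
                        output_dir: Optional[str] = None,
                        overrides: Optional[Dict[str, object]] = None) -> List[str]:
    """Hypothetical helper: assemble the argument list for train.py."""
    cmd = ["python", "train.py", model_name, "-c", config_file]
    if run_name is not None:
        cmd += ["-n", run_name]
    if output_dir is not None:
        cmd += ["-d", output_dir]
    # Each override becomes a "--name value" pair, e.g. --learningRate 0.2
    for name, value in (overrides or {}).items():
        cmd += ["--" + name, str(value)]
    return cmd


def checkpoint_dir(run_name: Optional[str] = None,
                   output_dir: Optional[str] = None) -> str:
    """Hypothetical helper: $OUTPUT_DIRECTORY/$RUN_NAME with the documented
    defaults output_networks and default."""
    return os.path.join(output_dir or "output_networks", run_name or "default")


cmd = build_train_command("PPGAN", "coin.json", run_name="PAN",
                          overrides={"learningRate": 0.2})
print(shlex.join(cmd))  # python train.py PPGAN -c coin.json -n PAN --learningRate 0.2
print(checkpoint_dir("PAN"))  # output_networks/PAN on POSIX systems
```

The list can be passed directly to subprocess.run(cmd, check=True). Keeping it as a list rather than a single string avoids shell-quoting problems with paths that contain spaces. shlex.join requires Python 3.8 or later.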

Configuration file of a training session

The minimum configuration file for a training session is a JSON file with the following content:

{
    "pathDB": "PATH_TO_YOUR_DATASET"
}

Where a dataset can be:

To this, you can add a "config" entry giving overrides to the standard configuration. See models/trainer/standard_configurations for all possible options. For example:

{
    "pathDB": "PATH_TO_YOUR_DATASET",
    "config": {"baseLearningRate": 0.1,
               "miniBatchSize": 22}
}

This will override the learning rate and the mini-batch size. Please note that if you specify a --baseLearningRate option on the command line, the command line will prevail. Depending on how you work, you might prefer to have a specific configuration file for each run, or to rely on a single configuration file and pass your training parameters via the command line.
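If you generate configuration files from a script, Python's json module writes valid JSON, and the precedence rule described above (standard configuration, then the "config" entry of the file, then command-line options) amounts to successive dictionary updates. The sketch below illustrates both. The helper names make_config and resolve, the dataset path and the "standard" values are hypothetical, and this is not the repository's own loading code.

```python
import json
import os
import tempfile
from typing import Any, Dict, Optional


def make_config(path_db: str,
                overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper: build a training configuration (pathDB is required)."""
    cfg: Dict[str, Any] = {"pathDB": path_db}
    if overrides:
        # Optional overrides of the standard configuration
        cfg["config"] = dict(overrides)
    return cfg


def resolve(standard: Dict[str, Any], file_cfg: Dict[str, Any],
            cli: Dict[str, Optional[Any]]) -> Dict[str, Any]:
    """Illustrate precedence: standard < "config" entry < command line."""
    resolved = dict(standard)
    resolved.update(file_cfg.get("config", {}))
    # Only options actually given on the command line (not None) take effect
    resolved.update({k: v for k, v in cli.items() if v is not None})
    return resolved


cfg = make_config("/path/to/dataset", {"baseLearningRate": 0.1, "miniBatchSize": 22})
out_file = os.path.join(tempfile.gettempdir(), "my_config.json")  # placeholder name
with open(out_file, "w") as f:
    json.dump(cfg, f, indent=4)

standard = {"baseLearningRate": 0.001, "miniBatchSize": 16}  # made-up defaults
print(resolve(standard, cfg, {"baseLearningRate": 0.2, "miniBatchSize": None}))
# {'baseLearningRate': 0.2, 'miniBatchSize': 22}
```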

Other fields are available in the configuration file, like:

{
    image_name1.jpg: {attribute1: label, attribute2: label, ...}
    image_name2.jpg: {attribute1: label, attribute2: label, ...}
    ...
}

With a dataset in the fashionGen format (.h5), it is a dictionary summing up statistics on the classes to be sampled.
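For a dataset that is not in the .h5 format, a per-image labels file in the layout shown above could be written with Python's json module. This is a sketch only: the image names, the Color attribute and the output path are hypothetical ("Class" and T_SHIRT are borrowed from the fashionGen example further down), and it does not show which configuration key points to this file.

```python
import json
import os
import tempfile

# Hypothetical per-image labels following the layout
# {image_name.jpg: {attribute1: label, attribute2: label, ...}}
attributes = {
    "image_name1.jpg": {"Class": "T_SHIRT", "Color": "RED"},
    "image_name2.jpg": {"Class": "SHOES", "Color": "BLACK"},
}

out_file = os.path.join(tempfile.gettempdir(), "attributes.json")  # hypothetical path
with open(out_file, "w") as f:
    json.dump(attributes, f, indent=2)

# Collect the set of labels seen for each attribute (a quick sanity check)
labels_per_attribute = {}
for labels in attributes.values():
    for attribute, label in labels.items():
        labels_per_attribute.setdefault(attribute, set()).add(label)
print(sorted(labels_per_attribute["Class"]))  # ['SHOES', 'T_SHIRT']
```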

How to run an evaluation of the results of your training session?

You need to use the eval.py script.

Image generation

You can generate more images from an existing checkpoint using:

python eval.py visualization -n $modelName -m $modelType

Where modelType is in [PGAN, PPGAN, DCGAN] and modelName is the name given to your model. This script will load the last checkpoint detected at testNets/$modelName. If you want to load a specific iteration, please call:

python eval.py visualization -n $modelName -m $modelType -s $SCALE -i $ITER
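The -s $SCALE option selects one scale of a progressively grown model. In progressive growing, training starts at 4x4 and each new scale doubles the resolution. Assuming scale indices start at 0 for 4x4 (an assumption; check your checkpoint names or configuration), the scale index and the output resolution are related as in this sketch. Both helper names are hypothetical.

```python
def resolution_at_scale(scale: int, base: int = 4) -> int:
    """Output side length at a scale index, ASSUMING scale 0 is base x base
    and each scale doubles the resolution."""
    return base * 2 ** scale


def scale_for_resolution(resolution: int, base: int = 4) -> int:
    """Inverse of resolution_at_scale under the same assumption."""
    scale, size = 0, base
    while size < resolution:
        size *= 2
        scale += 1
    if size != resolution:
        raise ValueError(f"{resolution} is not {base} times a power of two")
    return scale


print(scale_for_resolution(128))   # 5  (celeba cropped, 128x128)
print(scale_for_resolution(1024))  # 8  (celebaHQ, 1024x1024)
```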

If your model is conditioned, you can ask the visualizer to print out some conditioned generations. First, use --showLabels to see all the available categories and their labels.

python eval.py visualization -n $modelName -m $modelType --showLabels

Then, run your generation with:

python eval.py visualization -n $modelName -m $modelType --$CATEGORY_NAME $LABEL_NAME

For example with a model trained on fashionGen:

python eval.py visualization -n $modelName -m $modelType --Class T_SHIRT

This will plot a batch of T_SHIRT images in visdom.

Fake dataset generation

To save a randomly generated fake dataset from a checkpoint please use:

python eval.py visualization -n $modelName -m $modelType --save_dataset $PATH_TO_THE_OUTPUT_DATASET --size_dataset $SIZE_OF_THE_OUTPUT

SWD metric

Using the same kind of configuration file as above, just launch:

python eval.py laplacian_SWD -c $CONFIGURATION_FILE -n $modelName -m $modelType

Where $CONFIGURATION_FILE is the training configuration file called by train.py (see above): it must contain a "pathDB" field pointing to the dataset's directory. For example, if you followed the instructions of the Quick training section to launch a training session on celebaHQ, your configuration file will be config_celebaHQ.json.
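The name laplacian_SWD suggests the metric is computed on descriptors extracted from Laplacian pyramid levels, as in the Progressive GAN paper. The sketch below shows only the core sliced Wasserstein computation between two equally sized sets of descriptor vectors, using random projections and sorting. It is an illustration of the technique under those assumptions, not the repository's implementation, and the function name is hypothetical.

```python
import numpy as np


def sliced_wasserstein(a: np.ndarray, b: np.ndarray,
                       n_projections: int = 128, seed: int = 0) -> float:
    """Approximate sliced Wasserstein-1 distance between two descriptor
    sets a and b, each of shape (n, d) with the same n."""
    if a.shape != b.shape:
        raise ValueError("this sketch assumes equally sized descriptor sets")
    rng = np.random.default_rng(seed)
    # Random unit directions in descriptor space, one per column
    dirs = rng.standard_normal((a.shape[1], n_projections))
    dirs /= np.linalg.norm(dirs, axis=0, keepdims=True)
    # In 1-D, optimal transport matches sorted values, so project and sort
    proj_a = np.sort(a @ dirs, axis=0)
    proj_b = np.sort(b @ dirs, axis=0)
    # Average the absolute gaps over samples and directions
    return float(np.mean(np.abs(proj_a - proj_b)))


rng = np.random.default_rng(1)
real = rng.standard_normal((500, 16))
shifted = real + 1.0
print(sliced_wasserstein(real, real))         # 0.0
print(sliced_wasserstein(real, shifted) > 0)  # True
```

Using a fixed seed keeps the projection directions identical across calls, so scores computed for different checkpoints remain comparable.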

You can add optional arguments:

Inspirational generation

To make an inspirational generation, you first need to build a feature extractor:

python save_feature_extractor.py {vgg16, vgg19} $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR --layers 3 4 5

Then run your model:

python eval.py inspirational_generation -n $modelName -m $modelType --inputImage $pathTotheInputImage -f $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR

I have generated my metrics. How can I plot them on visdom?

Just run

python eval.py metric_plot -n $modelName

LICENSE

This project is under the BSD-3 license.