GANSeg: Learning to Segment by Unsupervised Hierarchical Image Generation (CVPR 2022)
Xingzhe He, Bastian Wandt, and Helge Rhodin
IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) 2022
[Paper]
Setup
Setup environment
conda create -n ganseg python=3.8
conda activate ganseg
pip install -r requirements.txt
Download datasets
The CelebA-in-the-wild, Taichi, CUB and Flower datasets can be downloaded from their respective websites. We provide pre-processing code that converts CelebA-in-the-wild, CUB and Flower into h5 files; Taichi can be used directly.
Download pre-trained models
The pre-trained models (GAN and Segmenter) can be downloaded from Google Drive.
Testing
Segmentation
You can use gen_mask.py to generate the segmentation masks.
python gen_mask.py --segmenter_log log/seg_celeba_wild_k8 --test_class_name mafl_wild_test --data_root data/celeba_wild --save_root saved_mask/celeba_wild_k8 --checkpoint 10
where
--segmenter_log specifies the log folder of the segmentation network,
--test_class_name specifies the particular dataset to test,
--data_root specifies the location of the dataset (the folder containing the h5 file),
--save_root defines the location of the saved images, and
--checkpoint specifies the index of the checkpoint.
The command above therefore generates masks for CelebA-in-the-wild.
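If you script several runs, it can help to build the gen_mask.py argument list in code. The following is a minimal sketch that uses only the flags documented above. The helper names gen_mask_command and run are hypothetical and not part of the repository. By default the sketch only prints the command (a dry run).

```python
import shlex
import subprocess
from typing import List


def gen_mask_command(segmenter_log: str, test_class_name: str, data_root: str,
                     save_root: str, checkpoint: int) -> List[str]:
    """Build the gen_mask.py argument list from the documented flags."""
    return [
        "python", "gen_mask.py",
        "--segmenter_log", segmenter_log,      # log folder of the segmentation network
        "--test_class_name", test_class_name,  # dataset to test
        "--data_root", data_root,              # folder containing the h5 file
        "--save_root", save_root,              # where the masks are saved
        "--checkpoint", str(checkpoint),       # checkpoint index to load
    ]


def run(cmd: List[str], dry_run: bool = True) -> None:
    """Print the command; execute it only when dry_run is False."""
    print(shlex.join(cmd))  # shlex.join requires Python 3.8+
    if not dry_run:
        subprocess.run(cmd, check=True)  # raises CalledProcessError on failure


# Reproduces the CelebA-in-the-wild example command as a dry run.
run(gen_mask_command("log/seg_celeba_wild_k8", "mafl_wild_test",
                     "data/celeba_wild", "saved_mask/celeba_wild_k8", 10))
```

Pass dry_run=False to actually launch the script from the repository root.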
You can also quantitatively test the segmentation.
python test_seg.py --segmenter_log log/seg_celeba_wild_k8 --test_class_name mafl_wild_test --data_root data/celeba_wild --checkpoint 10
GAN
You can use gen_img.py to generate images with our GAN.
python gen_img.py --generator_log log/gan_celeba_wild_k8 --save_root saved_image/celeba_wild_k8 --checkpoint 30
Training
GAN
To train our GAN on CelebA-in-the-wild, run
python train_gan.py --class_name celeba_wild --data_root data/celeba_wild --n_keypoints 8
The trained weights and log can be found in log/gan_celeba_wild_k8.
We also provide a custom choice for class_name. Set --class_name custom and point --data_root to your own image folder to train our GAN on your own images.
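Before training on your own images, you may want to confirm that the folder passed to --data_root actually contains image files. The sketch below makes several assumptions that do not come from the repository: the helper name find_images, the accepted extensions, and the recursive scan. The custom loader may read the folder differently.

```python
from pathlib import Path
from typing import List

# Assumed set of extensions; the repository's custom loader may accept others.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def find_images(root: str) -> List[Path]:
    """Recursively collect image files under root, sorted for a stable order."""
    root_path = Path(root)
    if not root_path.is_dir():
        raise NotADirectoryError(f"{root} is not a directory")
    return sorted(
        p for p in root_path.rglob("*")
        # keep regular files whose extension (case-insensitive) looks like an image
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    )
```

For example, `len(find_images("data/my_images"))` (a hypothetical path) reports how many candidate images the folder holds. An empty result usually means the path or the file extensions are wrong.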
To finetune the learned model, run
python train_gan.py --class_name celeba_wild --data_root data/celeba_wild --n_keypoints 8 --checkpoint [the epoch index to start]
The example parameters can be found in the log folder.
Segmentation
To train the segmenter on the pre-trained GAN (CelebA-in-the-wild), run
python train_seg.py --generator_log log/gan_celeba_wild_k8 --data_root data/celeba_wild --checkpoint 30
where
--generator_log specifies the generator log folder (used to generate image-mask pairs),
--data_root specifies the location of the dataset, and
--checkpoint specifies the index of the checkpoint of the GAN.
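The segmenter depends on a trained GAN. --generator_log must point at the folder written by train_gan.py, and --checkpoint presumably has to name a GAN checkpoint that exists there. Below is a minimal sketch that chains the two training steps. It assumes the log/gan_<class_name>_k<n_keypoints> naming seen in the example commands; that convention is inferred, not confirmed by the code. training_commands and run_all are hypothetical helpers.

```python
import shlex
import subprocess
from typing import List


def training_commands(class_name: str, data_root: str, n_keypoints: int,
                      gan_checkpoint: int) -> List[List[str]]:
    """Return the train_gan.py and train_seg.py commands, in the order they must run."""
    # Assumed naming convention, inferred from the README examples.
    gan_log = f"log/gan_{class_name}_k{n_keypoints}"
    return [
        ["python", "train_gan.py", "--class_name", class_name,
         "--data_root", data_root, "--n_keypoints", str(n_keypoints)],
        # The segmenter is trained on image-mask pairs generated by this GAN checkpoint.
        ["python", "train_seg.py", "--generator_log", gan_log,
         "--data_root", data_root, "--checkpoint", str(gan_checkpoint)],
    ]


def run_all(cmds: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; when dry_run is False, run them in order and stop at the first failure."""
    for cmd in cmds:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


# Dry run of the CelebA-in-the-wild example: GAN first, then the segmenter.
run_all(training_commands("celeba_wild", "data/celeba_wild", 8, 30))
```

The commands run strictly in sequence with check=True. A failed GAN run therefore never starts segmenter training on a missing checkpoint.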
Citation
@inproceedings{he2022ganseg,
title={GANSeg: Learning to Segment by Unsupervised Hierarchical Image Generation},
author={He, Xingzhe and Wandt, Bastian and Rhodin, Helge},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={1225--1235},
year={2022}
}
Acknowledgements
The code is built upon LatentKeypointGAN.