
Latent Subspace Optimization

The official PyTorch implementation of our paper Where is My Spot? Few-shot Image Generation via Latent Subspace Optimization, CVPR 2023.

[Figure: framework overview]

Where is My Spot? Few-shot Image Generation via Latent Subspace Optimization

Chenxi Zheng, Bangzhen Liu, Xuemiao Xu, Huaidong Zhang, and Shengfeng He

Paper

Environment

Prepare the environment using either conda or pip.

Option 1: conda

cd envs
conda env create -f environment.yaml
cd ..

Option 2: pip (inside a new conda environment)

cd envs
conda create -n LSO python=3.7
conda activate LSO
pip install -r requirements.txt
cd ..

If the PyTorch installation fails, or if you need a different PyTorch version, install PyTorch manually by following the official installation instructions.
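To confirm that PyTorch ended up installed in the active environment, a quick check such as the following can help. This is a minimal sketch rather than a script shipped with the repository; check_environment is a hypothetical helper name.

```python
import importlib.util
import sys


def check_environment():
    """Hypothetical helper: report the Python version and, if PyTorch can be
    imported, its version and whether a CUDA device is visible."""
    report = {"python": sys.version.split()[0], "torch": None, "cuda": None}
    if importlib.util.find_spec("torch") is not None:
        import torch  # imported lazily so the check also runs without PyTorch

        report["torch"] = torch.__version__
        report["cuda"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    print(check_environment())
```

If "torch" is reported as None, PyTorch is not importable from the active environment and needs to be installed manually.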

Getting started

Before training, prepare the pretrained checkpoints (ckpts) optimized on the seen categories.

Note that the latent codes cover only the subset of images used for image generation, $\mathbb{S}_{gen}^{c}$ (Sec. 4.2 of the paper). For example, the tensor in flowers_unseen17_0-10_step1300.npy has shape $[17 \times 10, 12, 512]$, which corresponds to the slice $[85:102, 0:10, :, :, :]$ of the images in flower_c8189_s128_data_rgb.npy, whose shape is $[102, 40, 128, 128, 3]$.
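The index correspondence in this example can be checked with a short NumPy sketch. It uses only the shapes quoted above and stand-in arrays instead of the real files; the class-major row order of the latent array (all 10 codes of class 85, then class 86, and so on) is an assumption, and latent_to_image_index is a hypothetical helper.

```python
import numpy as np

# Shapes quoted in this README for the Flowers files.
NUM_CLASSES, IMGS_PER_CLASS, RES = 102, 40, 128
UNSEEN_START = 85   # unseen classes occupy indices 85..101 (17 classes)
GEN_IMGS = 10       # images 0..9 of each unseen class form S_gen

# Stand-ins with the documented shapes. broadcast_to creates a read-only view
# without allocating ~200 MB; use np.load(...) on the real .npy files instead.
images = np.broadcast_to(np.zeros((), dtype=np.uint8),
                         (NUM_CLASSES, IMGS_PER_CLASS, RES, RES, 3))
latents = np.zeros(((NUM_CLASSES - UNSEEN_START) * GEN_IMGS, 12, 512),
                   dtype=np.float32)

# The image subset covered by the latent codes: slice [85:102, 0:10, :, :, :].
gen_images = images[UNSEEN_START:NUM_CLASSES, 0:GEN_IMGS]


def latent_to_image_index(row):
    """Hypothetical helper: map a row of the latent array to (class, image)
    indices of the full image array, ASSUMING class-major row order."""
    class_offset, image_idx = divmod(row, GEN_IMGS)
    return UNSEEN_START + class_offset, image_idx


if __name__ == "__main__":
    print(gen_images.shape)            # (17, 10, 128, 128, 3)
    print(latents.shape)               # (170, 12, 512)
    print(latent_to_image_index(169))  # (101, 9)
```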

Unzip the files and set up the paths in configs/default_configs.py.

DATA_PATH = <PATH_TO_DATASET_NPY>
CKPT_PATH = <PATH_TO_STYLEGAN2_CKPT>
WS_PATH = <PATH_TO_WS>
IDCKPT_PATH = <PATH_TO_IDWEIGHTS>
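After editing configs/default_configs.py, a quick existence check can catch path typos before a long run. The sketch below is not part of the repository: missing_paths is a hypothetical helper, and the dictionary simply mirrors the four entries above with their placeholders.

```python
from pathlib import Path


def missing_paths(config):
    """Hypothetical helper (not part of the repository): return the names of
    config entries whose file or directory does not exist on disk."""
    return [name for name, value in config.items() if not Path(value).exists()]


if __name__ == "__main__":
    # Placeholder values mirroring configs/default_configs.py; substitute the
    # locations of your unzipped files.
    config = {
        "DATA_PATH": "<PATH_TO_DATASET_NPY>",
        "CKPT_PATH": "<PATH_TO_STYLEGAN2_CKPT>",
        "WS_PATH": "<PATH_TO_WS>",
        "IDCKPT_PATH": "<PATH_TO_IDWEIGHTS>",
    }
    for name in missing_paths(config):
        print("{} does not exist: {}".format(name, config[name]))
```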

Training

Multi-task for quantitative evaluation

python train_unseen.py \
    --outdir <output_dir> \
    --k_shot <k> \
    --dataset_name <dataset_name>
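To evaluate several shot counts, one option is to generate one command per k. This minimal sketch only builds the command line shown above; the dataset name "flowers", the k values, and the output directory layout are illustrative assumptions, and multi_task_commands is a hypothetical helper.

```python
import shlex


def multi_task_commands(dataset_name, k_values, outdir_root):
    """Hypothetical helper: build one train_unseen.py command per k-shot
    setting, each writing to its own output directory."""
    commands = []
    for k in k_values:
        commands.append([
            "python", "train_unseen.py",
            "--outdir", "{}/{}_{}shot".format(outdir_root, dataset_name, k),
            "--k_shot", str(k),
            "--dataset_name", dataset_name,
        ])
    return commands


if __name__ == "__main__":
    # Illustrative values only; check which dataset names train_unseen.py accepts.
    for cmd in multi_task_commands("flowers", [1, 3, 5], "runs"):
        # To launch a run instead of printing it, call
        # subprocess.run(cmd, check=True) from the repository root.
        print(" ".join(shlex.quote(arg) for arg in cmd))
```

Keeping each k in its own output directory avoids runs overwriting each other's results.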

Single-task for detailed visualization

We also provide single-task optimization for visualization and for a more detailed evaluation of the optimization process.

python train_unseen.py \
    --outdir <output_dir> \
    --k_shot <k> \
    --single_task <cidx> <idx_1,...,idx_k> \
    --dataset_name <dataset_name>
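Judging from the usage line, the two values after --single_task are a class index and a comma-separated list of k image indices. The sketch below formats them and selects the corresponding reference images; treating cidx as an index into the class axis and the idx values as indices into the per-class image axis of the dataset .npy array is an assumption, and both helpers are hypothetical.

```python
import numpy as np


def single_task_args(cidx, image_indices, k_shot):
    """Hypothetical helper: format the two values that follow --single_task,
    a class index and a comma-separated list of exactly k image indices."""
    if len(image_indices) != k_shot:
        raise ValueError("expected {} image indices, got {}".format(k_shot, len(image_indices)))
    return ["--single_task", str(cidx), ",".join(str(i) for i in image_indices)]


def reference_images(images, cidx, image_indices):
    """Hypothetical helper: pick the k reference images, ASSUMING cidx indexes
    the class axis and the idx values index the per-class image axis of an
    array shaped [num_classes, images_per_class, H, W, 3]."""
    return images[cidx, list(image_indices)]


if __name__ == "__main__":
    # Small stand-in array; the Flowers file in this README is [102, 40, 128, 128, 3].
    images = np.zeros((102, 40, 8, 8, 3), dtype=np.uint8)
    args = single_task_args(90, [3, 7, 12], k_shot=3)
    print(" ".join(args))                                   # --single_task 90 3,7,12
    print(reference_images(images, 90, [3, 7, 12]).shape)   # (3, 8, 8, 3)
```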

Evaluation

(Optional) If the images were generated in separate runs, use merger.py to combine them.

python merger.py \
    --path <output_dir> \
    --idx <runidx_1>,...,<runidx_n>

Quantitative evaluation of the generated images.

python main_metric_calculate.py \
    --real_dir <real_directory> \
    --fake_dir <fake_directory> \
    --dataset_name <dataset_name>

LPIPS computation is significantly accelerated by extracting the features of each image once in advance. See metrics/lpips_fs/lpips_fs.py for details.
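The idea behind this kind of speed-up is that each image passes through the feature network only once, and pairwise distances are then computed from cached features. For N images that means N forward passes instead of two per pair (N(N-1) in total over all unordered pairs). The NumPy sketch below illustrates this with the standard LPIPS distance: channel-normalized features, squared differences weighted per channel, averaged over spatial positions and summed over layers. It is not the code in metrics/lpips_fs/lpips_fs.py: feature extraction is left out, the inputs are assumed to be per-layer arrays of shape (N, C, H, W), and the channel weights default to ones instead of the learned LPIPS weights.

```python
import numpy as np


def normalize_channels(feat, eps=1e-10):
    """Unit-normalize every spatial feature vector along the channel axis."""
    norm = np.sqrt(np.sum(feat ** 2, axis=1, keepdims=True))
    return feat / (norm + eps)


def pairwise_lpips_from_features(features, weights=None):
    """LPIPS-style distance matrix from features extracted ONCE per image.

    features: list with one array per layer, each of shape (N, C_l, H_l, W_l).
    weights:  optional list of non-negative (C_l,) channel weights; defaults to
              ones (an assumption, real LPIPS uses learned weights).
    Returns an (N, N) symmetric matrix with zeros on the diagonal.
    """
    n = features[0].shape[0]
    dist = np.zeros((n, n))
    for layer, feat in enumerate(features):
        f = normalize_channels(feat)
        w = np.ones(f.shape[1]) if weights is None else np.asarray(weights[layer])
        for i in range(n):
            sq = (f[i][None] - f) ** 2  # (N, C, H, W) squared differences to image i
            # Weighted sum over channels, then average over spatial positions.
            dist[i] += np.einsum("c,nchw->nhw", w, sq).mean(axis=(1, 2))
    return dist


def mean_pairwise_distance(dist):
    """Average distance over all unordered pairs i < j."""
    rows, cols = np.triu_indices(dist.shape[0], k=1)
    return dist[rows, cols].mean()


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Random stand-ins for two layers of backbone features of 5 images.
    feats = [rng.standard_normal((5, 8, 6, 6)), rng.standard_normal((5, 16, 3, 3))]
    d = pairwise_lpips_from_features(feats)
    print(d.shape, mean_pairwise_distance(d))
```

Looping over one image at a time keeps memory at O(N·C·H·W) per layer rather than materializing all N² feature differences at once.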

Citation

If you use this code for your research, please cite our paper.

@inproceedings{zheng2023my,
title={Where Is My Spot? Few-Shot Image Generation via Latent Subspace Optimization},
author={Zheng, Chenxi and Liu, Bangzhen and Zhang, Huaidong and Xu, Xuemiao and He, Shengfeng},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={3272--3281},
year={2023}
}

Acknowledgement

This project builds upon and has been inspired by the following repositories:

We would like to thank the entire open-source community for fostering an environment of collaboration and knowledge sharing.

License

This repository is released under the MIT License.