<div align=center>

Generative Prompt Model for Weakly Supervised Object Localization

</div>

This is the official implementation of the paper Generative Prompt Model for Weakly Supervised Object Localization, accepted at ICCV 2023. This repository contains PyTorch training code, evaluation code, pre-trained models, and a visualization method.

<div align=center>

arXiv preprint Python 3.8 PyTorch 1.11 LICENSE


</div> <div align=center> <img src="assets/intro.png" width="69%"> </div>

1. Contents

2. Introduction

Weakly supervised object localization (WSOL) remains challenging when object localization models must be learned from image category labels alone. Conventional methods that discriminatively train activation models ignore representative yet less discriminative object parts. In this study, we propose a generative prompt model (GenPromp), defining the first generative pipeline to localize less discriminative object parts by formulating WSOL as a conditional image denoising procedure. During training, GenPromp converts image category labels to learnable prompt embeddings, which are fed to a generative model to conditionally recover the noise-corrupted input image and learn representative embeddings. During inference, GenPromp combines the representative embeddings with discriminative embeddings (queried from an off-the-shelf vision-language model) to obtain both representative and discriminative capacity. The combined embeddings are finally used to generate multi-scale, high-quality attention maps, which facilitate localizing the full object extent. Experiments on CUB-200-2011 and ILSVRC show that GenPromp outperforms the best discriminative models on both datasets, setting a solid baseline for WSOL with generative models.
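
The "conditional image denoising" objective can be illustrated with the standard noise-prediction loss used by latent diffusion models such as Stable Diffusion. The sketch below is a toy, stdlib-only illustration, not the repository's training code: toy_denoiser is a hypothetical stand-in for the conditional UNet, alpha_bar is an assumed noise-schedule value, and the condition vector loosely plays the role of the learnable prompt embedding. It only shows why a condition that "explains" the image yields a lower loss, which is what optimizing the prompt embedding exploits.

```python
import math
import random

def add_noise(x0, eps, alpha_bar):
    """Forward diffusion: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps."""
    a, b = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [a * x + b * e for x, e in zip(x0, eps)]

def toy_denoiser(x_t, alpha_bar, cond):
    """Hypothetical stand-in for the conditional UNet: predicts the noise in x_t,
    treating the condition vector as its guess of the clean signal."""
    a, b = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [(xt - a * c) / b for xt, c in zip(x_t, cond)]

def denoising_loss(x0, cond, alpha_bar, rng):
    """Mean squared error between the sampled noise and the predicted noise."""
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = add_noise(x0, eps, alpha_bar)
    eps_hat = toy_denoiser(x_t, alpha_bar, cond)
    return sum((e - p) ** 2 for e, p in zip(eps, eps_hat)) / len(x0)

image = [0.5, -0.2, 0.8, 0.1]       # toy "latent image"
good_prompt = list(image)           # a condition that fully explains the image
bad_prompt = [0.0, 0.0, 0.0, 0.0]   # an uninformative condition
print(denoising_loss(image, good_prompt, 0.5, random.Random(0)))  # close to 0
print(denoising_loss(image, bad_prompt, 0.5, random.Random(0)))   # about 0.235
```

In the real model the condition enters the UNet through cross-attention rather than as a direct guess of the image, so this is only an intuition for the training signal.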

3. Results

<div align=center> <img src="assets/results.png" width="99%"> </div>

We re-trained GenPromp on 4 × A100 GPUs with an improved learning schedule, which further improves its performance on CUB-200-2011.

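| Method | Dataset | Cls. Backbone | Top-1 Loc (%) | Top-5 Loc (%) | GT-known Loc (%) |
|---|---|---|---|---|---|
| GenPromp | CUB-200-2011 | EfficientNet-B7 | 87.0 | 96.1 | 98.0 |
| GenPromp (Re-train) | CUB-200-2011 | EfficientNet-B7 | 87.2 (+0.2) | 96.3 (+0.2) | 98.3 (+0.3) |
| GenPromp | ImageNet | EfficientNet-B7 | 65.2 | 73.4 | 75.0 |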
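
The three metrics follow the usual WSOL protocol: a predicted box counts as correctly localized when its IoU with a ground-truth box is at least 0.5. GT-known Loc checks only the box, Top-1 Loc additionally requires the top-1 predicted class to be correct, and Top-5 Loc requires the ground-truth class among the top-5 predictions. The sketch below is a simplified illustration assuming one predicted box per image; the repository's evaluation code may differ in details (for example, per-class boxes for Top-5).

```python
from typing import Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)

def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def loc_metrics(samples, iou_thr=0.5):
    """Each sample is (pred_box, gt_boxes, ranked_labels, gt_label).
    Returns (Top-1 Loc, Top-5 Loc, GT-known Loc) as percentages."""
    n = top1 = top5 = gt_known = 0
    for pred_box, gt_boxes, ranked_labels, gt_label in samples:
        n += 1
        # The box is correct if it matches any ground-truth box with IoU >= iou_thr.
        box_ok = max(iou(pred_box, g) for g in gt_boxes) >= iou_thr
        gt_known += int(box_ok)
        top1 += int(box_ok and ranked_labels[0] == gt_label)
        top5 += int(box_ok and gt_label in ranked_labels[:5])
    if n == 0:
        raise ValueError("no samples")
    return tuple(100.0 * c / n for c in (top1, top5, gt_known))

samples = [
    ((0, 0, 10, 10), [(0, 0, 10, 10)], [3, 1, 2, 4, 5], 3),    # box and top-1 correct
    ((0, 0, 10, 10), [(0, 0, 10, 12)], [1, 3, 2, 4, 5], 3),    # IoU 0.83, class only in top-5
    ((0, 0, 10, 10), [(20, 20, 30, 30)], [3, 1, 2, 4, 5], 3),  # box misses the object
]
print([round(v, 1) for v in loc_metrics(samples)])  # [33.3, 66.7, 66.7]
```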

4. Getting Started

4.1 Installation

To set up the GenPromp environment, we use conda to manage dependencies. Our experiments were run with CUDA 11.3. Run the following commands to install GenPromp:

conda create -n gpm python=3.8 -y && conda activate gpm
pip install --upgrade pip
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install --upgrade diffusers[torch]==0.13.1
pip install transformers==4.29.2 accelerate==0.19.0
pip install matplotlib opencv-python OmegaConf tqdm
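
After installation, a quick check can confirm that the pinned versions were actually installed in the active environment. This is a small stdlib-only sketch; the package list simply mirrors the pins in the commands above, and the check_versions helper is hypothetical (not part of the repository).

```python
from importlib import metadata

# Versions pinned in the installation commands above.
# Local build suffixes such as "+cu113" are stripped before comparing.
PINNED = {
    "torch": "1.11.0",
    "torchvision": "0.12.0",
    "torchaudio": "0.11.0",
    "diffusers": "0.13.1",
    "transformers": "4.29.2",
    "accelerate": "0.19.0",
}

def check_versions(pinned):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    problems = {}
    for name, expected in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed is None or installed.split("+")[0] != expected:
            problems[name] = (expected, installed)
    return problems

if __name__ == "__main__":
    for pkg, (want, got) in check_versions(PINNED).items():
        print(f"{pkg}: expected {want}, found {got or 'not installed'}")
```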

4.2 Dataset and Files Preparation

To train GenPromp from the pre-trained weights, or to run inference with the released weights, download the files listed in the table and arrange them according to the file tree below. (Uploading)

| Dataset & Files | Download | Usage |
|---|---|---|
| data/ImageNet_ILSVRC2012 (146GB) | Official Link | Benchmark dataset |
| data/CUB_200_2011 (1.2GB) | Official Link | Benchmark dataset |
| ckpts/pretrains (5.2GB) | Official Link, Google Drive, Baidu Drive (code: o9ei) | Stable Diffusion pre-trained weights |
| ckpts/classifications (2.3GB) | Google Drive, Baidu Drive (code: o9ei) | Classification results on benchmark datasets |
| ckpts/imagenet750 (3.3GB) | Google Drive, Baidu Drive (code: o9ei) | Weights that achieve 75.0% GT-known Loc on ImageNet |
| ckpts/cub983 (3.3GB) | Google Drive, Baidu Drive (code: o9ei) | Weights that achieve 98.3% GT-known Loc on CUB |

    |--GenPromp/
      |--data/
        |--ImageNet_ILSVRC2012/
           |--ILSVRC2012_list/
           |--train/
           |--val/
        |--CUB_200_2011
           |--attributes/
           |--images/
           ...
      |--ckpts/
        |--pretrains/
          |--stable-diffusion-v1-4/
        |--classifications/
          |--cub_efficientnetb7.json
          |--imagenet_efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k.json
        |--imagenet750/
          |--tokens/
             |--49408.bin
             |--49409.bin
             ...
          |--unet/
        |--cub983/
          |--tokens/
             |--49408.bin
             |--49409.bin
             ...
          |--unet/
      |--configs/
      |--datasets
      |--models
      |--main.py
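
Before launching training or evaluation, it can help to confirm that the files are where the configs expect them. A minimal stdlib sketch: the missing_paths helper is hypothetical, and the path list is copied from the tree above (drop entries you do not need, for example the ImageNet folders when working only on CUB).

```python
from pathlib import Path

# Paths from the file tree above, relative to the GenPromp/ root directory.
REQUIRED = [
    "data/ImageNet_ILSVRC2012/ILSVRC2012_list",
    "data/ImageNet_ILSVRC2012/train",
    "data/ImageNet_ILSVRC2012/val",
    "data/CUB_200_2011/attributes",
    "data/CUB_200_2011/images",
    "ckpts/pretrains/stable-diffusion-v1-4",
    "ckpts/classifications/cub_efficientnetb7.json",
    "ckpts/classifications/imagenet_efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k.json",
    "ckpts/imagenet750/tokens",
    "ckpts/imagenet750/unet",
    "ckpts/cub983/tokens",
    "ckpts/cub983/unet",
]

def missing_paths(root, required=REQUIRED):
    """Return the entries of `required` that do not exist under `root`."""
    base = Path(root)
    return [rel for rel in required if not (base / rel).exists()]

if __name__ == "__main__":
    missing = missing_paths(".")
    print("All files in place." if not missing else "Missing:\n  " + "\n  ".join(missing))
```

Run it from the GenPromp/ directory; each reported path corresponds to a row of the download table above.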

4.3 Training

Here is a training example of GenPromp on ImageNet.

accelerate config
accelerate launch main.py --function train_token --config configs/imagenet.yml --opt "{'train': {'save_path': 'ckpts/imagenet/'}}"
accelerate launch main.py --function train_unet --config configs/imagenet_stage2.yml --opt "{'train': {'load_token_path': 'ckpts/imagenet/tokens/', 'save_path': 'ckpts/imagenet/'}}"

accelerate is used for multi-GPU training. In the first training stage, the weights of the concept tokens of the representative embeddings are learned and saved to ckpts/imagenet/. In the second training stage, the learned concept tokens are loaded from ckpts/imagenet/tokens/, and the UNet is then fine-tuned and saved to ckpts/imagenet/. Other settings are defined in the config files (i.e., configs/imagenet.yml and configs/imagenet_stage2.yml) and can be overridden by passing a parameter dict to --opt (see 4.5 Extra Options for details).
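
The first stage resembles textual inversion: new concept tokens are appended to the text encoder's vocabulary and only their embedding rows are optimized, while the pre-trained vocabulary stays frozen. CLIP's tokenizer (used by Stable Diffusion v1.x) has 49,408 entries, which is consistent with the saved token files starting at 49408.bin. The toy sketch below shows such a masked update on a dict-based embedding table; the class-to-token mapping, function names, and learning rate are hypothetical, and the actual train_token implementation may differ.

```python
# Size of CLIP's original vocabulary. Newly added concept tokens get ids from here on,
# matching the 49408.bin, 49409.bin, ... files in ckpts/*/tokens/.
FIRST_NEW_TOKEN_ID = 49408

def concept_token_id(class_id):
    """Hypothetical mapping: one concept token per category, in category order."""
    return FIRST_NEW_TOKEN_ID + class_id

def masked_sgd_step(embeddings, grads, lr):
    """One SGD step that updates only the newly added token rows.
    embeddings, grads: dict {token_id: list of floats}; embeddings is updated in place."""
    for token_id, grad in grads.items():
        if token_id < FIRST_NEW_TOKEN_ID:
            continue  # frozen: part of the pre-trained CLIP vocabulary
        embeddings[token_id] = [w - lr * g for w, g in zip(embeddings[token_id], grad)]
    return embeddings

emb = {0: [1.0, 1.0], concept_token_id(0): [0.5, 0.5]}
grads = {0: [1.0, 1.0], concept_token_id(0): [1.0, -1.0]}
masked_sgd_step(emb, grads, lr=0.1)
print(emb)  # row 0 is unchanged; row 49408 moves to roughly [0.4, 0.6]
```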

Here is a training example of GenPromp on CUB_200_2011.

accelerate config
accelerate launch main.py --function train_token --config configs/cub.yml --opt "{'train': {'save_path': 'ckpts/cub/'}}"
accelerate launch main.py --function train_unet --config configs/cub_stage2.yml --opt "{'train': {'load_token_path': 'ckpts/cub/tokens/', 'save_path': 'ckpts/cub/'}}"

4.4 Inference

Here is an inference example of GenPromp on ImageNet.

python main.py --function test --config configs/imagenet_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/imagenet750/tokens/', 'load_unet_path': 'ckpts/imagenet750/unet/', 'save_log_path': 'ckpts/imagenet750/log.txt'}}"

In the inference stage, the learned concept tokens are loaded from ckpts/imagenet750/tokens/, the fine-tuned UNet weights are loaded from ckpts/imagenet750/unet/, and the log file is saved to ckpts/imagenet750/log.txt. Because random noise is added to the test image, the results may fluctuate within a small range ($\pm$0.1).
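
Since each test run draws fresh noise, one way to report the fluctuation is to repeat the evaluation with a few seeds and give the mean and standard deviation. A minimal sketch: run_eval is a hypothetical callable standing in for one call of the test function (it is not part of the repository's CLI), and the seeded random.Random stands in for whatever RNG the test pipeline uses (with PyTorch one would typically also call torch.manual_seed).

```python
import random
import statistics

def summarize_over_seeds(run_eval, seeds):
    """Run a (hypothetical) evaluation once per seed; return (mean, stdev) of the metric."""
    scores = [run_eval(seed) for seed in seeds]
    spread = statistics.stdev(scores) if len(scores) > 1 else 0.0
    return statistics.mean(scores), spread

def fake_eval(seed):
    """Stand-in for one test run: a metric of 75.0 perturbed by seed-dependent noise."""
    rng = random.Random(seed)
    return 75.0 + rng.uniform(-0.1, 0.1)

mean, std = summarize_over_seeds(fake_eval, seeds=[0, 1, 2, 3, 4])
print(f"GT-known Loc: {mean:.2f} ± {std:.2f}")
```

Fixing the seed makes a single run repeatable; averaging over seeds describes the spread rather than hiding it.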

Here is an inference example of GenPromp on CUB_200_2011.

python main.py --function test --config configs/cub_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/cub983/tokens/', 'load_unet_path': 'ckpts/cub983/unet/', 'save_log_path': 'ckpts/cub983/log.txt'}}"

4.5 Extra Options

Many extra options are available for training and inference. Their default values are set in the yml config file, and --opt adds or overrides them with a parameter dict. The most commonly used options are listed below.
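
Conceptually, the dict passed to --opt is layered on top of the yml defaults: keys present in the dict replace the defaults, nested dicts are merged key by key, and everything else keeps its default value. The sketch below illustrates this under the assumption that --opt is a Python dict literal (as in the examples in this README); the helper names and default values are made up for illustration and are not the contents of configs/*.yml. The repository may implement the merge differently, for example with OmegaConf.merge.

```python
import ast

def parse_opt(opt_string):
    """Parse an --opt value such as "{'train': {'batch_size': 2}}" into a dict."""
    opt = ast.literal_eval(opt_string)
    if not isinstance(opt, dict):
        raise ValueError("--opt must be a dict literal")
    return opt

def deep_merge(base, override):
    """Return a copy of base where keys from override are added or replace existing ones;
    nested dicts are merged recursively instead of being replaced wholesale."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

# Illustrative defaults only; the real values live in the yml config files.
defaults = {"train": {"batch_size": 4, "learning_rate": 5.0e-8}, "test": {"cam_thr": 0.25}}
cfg = deep_merge(defaults, parse_opt("{'train': {'batch_size': 2}, 'test': {'combine_ratio': 0.6}}"))
print(cfg)
```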

| Option | Scope | Usage |
|---|---|---|
| {'data': {'keep_class': [0, 9]}} | data | Keep only the data whose category id is in [0, 1, 2, 3, ..., 9]. |
| {'train': {'batch_size': 2}} | train | Train with batch size 2. |
| {'train': {'num_train_epochs': 1}} | train | Train the model for 1 epoch. |
| {'train': {'save_steps': 200}} | train_unet | Save the trained UNet every 200 steps. |
| {'train': {'max_train_steps': 600}} | train_unet | Stop training after at most 600 steps. |
| {'train': {'gradient_accumulation_steps': 2}} | train | Accumulate gradients over 2 steps (effective batch size ×2) when GPU memory is limited. |
| {'train': {'learning_rate': 5.0e-08}} | train | Set the learning rate to 5.0e-8. |
| {'train': {'scale_lr': True}} | train | If True, the learning rate is multiplied by the batch size. |
| {'train': {'load_pretrain_path': 'stable-diffusion/'}} | train | Load the pre-trained model from stable-diffusion/. |
| {'train': {'load_token_path': 'ckpt/tokens/'}} | train | Load the trained concept tokens from ckpt/tokens/. |
| {'train': {'save_path': 'ckpt/'}} | train | Save the trained weights to ckpt/. |
| {'test': {'batch_size': 2}} | test | Test with batch size 2. |
| {'test': {'cam_thr': 0.25}} | test | Test with a CAM threshold of 0.25. |
| {'test': {'combine_ratio': 0.6}} | test | Set the combine ratio between the representative embedding $f_r$ and the discriminative embedding $f_d$ to 0.6. |
| {'test': {'load_class_path': 'imagenet_efficientnet.json'}} | test | Load classification results from imagenet_efficientnet.json. |
| {'test': {'load_pretrain_path': 'stable-diffusion/'}} | test | Load the pre-trained model from stable-diffusion/. |
| {'test': {'load_token_path': 'ckpt/tokens/'}} | test | Load the trained concept tokens from ckpt/tokens/. |
| {'test': {'load_unet_path': 'ckpt/unet/'}} | test | Load the trained UNet from ckpt/unet/. |
| {'test': {'save_vis_path': 'ckpt/vis/'}} | test | Save the visualized predictions to ckpt/vis/. |
| {'test': {'save_log_path': 'ckpt/log.txt'}} | test | Save the log file to ckpt/log.txt. |
| {'test': {'eval_mode': 'top1'}} | test | top1 evaluates with the predicted top-1 category of the test image; top5 evaluates with the predicted top-5 categories; gtk evaluates with the ground-truth category and can therefore be run without classification results. top1 is the default eval mode. |
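
The cam_thr option controls how an attention map is turned into a box. A common WSOL recipe, assumed here rather than taken from the repository, is to min-max normalize the map, keep the pixels at or above the threshold, and take their bounding box. Many implementations keep only the largest connected component instead of all pixels; this sketch uses the simpler variant, and its helper names are hypothetical.

```python
def normalize(cam):
    """Min-max normalize a 2-D attention map (list of rows) to [0, 1]."""
    flat = [v for row in cam for v in row]
    lo, hi = min(flat), max(flat)
    scale = (hi - lo) or 1.0  # avoid dividing by zero on a constant map
    return [[(v - lo) / scale for v in row] for row in cam]

def cam_to_box(cam, cam_thr=0.25):
    """Return (x1, y1, x2, y2), inclusive pixel indices, enclosing every pixel whose
    normalized value is >= cam_thr, or None if no pixel passes the threshold."""
    norm = normalize(cam)
    xs = [x for row in norm for x, v in enumerate(row) if v >= cam_thr]
    ys = [y for y, row in enumerate(norm) for v in row if v >= cam_thr]
    if not xs:
        return None
    return min(xs), min(ys), max(xs), max(ys)

cam = [
    [0.0, 0.1, 0.0, 0.0],
    [0.0, 0.9, 1.0, 0.0],
    [0.0, 0.3, 0.2, 0.0],
]
print(cam_to_box(cam, cam_thr=0.25))  # (1, 1, 2, 2)
print(cam_to_box(cam, cam_thr=0.5))   # (1, 1, 2, 1)
```

A lower threshold grows the box to cover weaker activations, which is why the threshold interacts with how complete the attention map is.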

These options can be combined by simply merging the dicts. For example, to evaluate GenPromp with the config file configs/imagenet_stage2.yml on categories [0, 1, 2, ..., 9], loading the concept tokens from ckpts/imagenet750/tokens/ and the UNet from ckpts/imagenet750/unet/, saving the log of the evaluated metrics to ckpts/imagenet750/log0-9.txt, setting the combine ratio to 0, and saving the visualization results to ckpts/imagenet750/vis, use the following command:

python main.py --function test --config configs/imagenet_stage2.yml --opt "{'data': {'keep_class': [0, 9]}, 'test': {'load_token_path': 'ckpts/imagenet750/tokens/', 'load_unet_path': 'ckpts/imagenet750/unet/', 'save_log_path': 'ckpts/imagenet750/log0-9.txt', 'combine_ratio': 0, 'save_vis_path': 'ckpts/imagenet750/vis'}}"
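
When scripting many evaluations, the option dicts can be merged programmatically and passed as an argument list, which avoids hand-quoting the nested dict. A minimal sketch, assuming --opt accepts a Python dict literal as in the command above; build_test_argv is a hypothetical helper, and later dicts overwrite earlier keys within the same scope.

```python
import shlex

def build_test_argv(config, *option_dicts):
    """Merge option dicts scope by scope and return an argv list for `main.py --function test`."""
    merged = {}
    for opts in option_dicts:
        for scope, values in opts.items():
            merged.setdefault(scope, {}).update(values)
    # repr() yields the single-quoted dict literal style used in the examples above.
    return ["python", "main.py", "--function", "test", "--config", config, "--opt", repr(merged)]

argv = build_test_argv(
    "configs/imagenet_stage2.yml",
    {"data": {"keep_class": [0, 9]}},
    {"test": {"load_token_path": "ckpts/imagenet750/tokens/", "load_unet_path": "ckpts/imagenet750/unet/"}},
    {"test": {"save_log_path": "ckpts/imagenet750/log0-9.txt", "combine_ratio": 0,
              "save_vis_path": "ckpts/imagenet750/vis"}},
)
print(shlex.join(argv))  # shell-quoted version of the command
# subprocess.run(argv, check=True) would run it without any manual quoting.
```
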
<div align=center> <img src="assets/visualize.png" width="99%"> </div>

5. Contacts

If you have any questions about our work or this repository, please don't hesitate to contact us by email or open an issue in this project.

6. Acknowledgment

7. Citation

@article{zhao2023generative,
  title={Generative Prompt Model for Weakly Supervised Object Localization},
  author={Zhao, Yuzhong and Ye, Qixiang and Wu, Weijia and Shen, Chunhua and Wan, Fang},
  journal={arXiv preprint arXiv:2307.09756},
  year={2023}
}
@InProceedings{Zhao_2023_ICCV,
    author    = {Zhao, Yuzhong and Ye, Qixiang and Wu, Weijia and Shen, Chunhua and Wan, Fang},
    title     = {Generative Prompt Model for Weakly Supervised Object Localization},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
    pages     = {6351-6361}
}