Weakly Supervised Data Augmentation Network

This is the official TensorFlow implementation of WS-DAN.

See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification.

Compatibility

Requirements

$ git clone git@github.com:tau-yihouxiang/WS_DAN.git
$ cd WS_DAN
$ python setup.py install
$ conda install -c menpo opencv
$ pip install tqdm

Datasets and Pre-trained models

Datasets         #Attention   Pre-trained model
CUB-200-2011     32           WS-DAN
Stanford-Cars    32           WS-DAN
FGVC-Aircraft    32           WS-DAN

Inspiration

The code is based on the TensorFlow-Slim library.

Preparing Datasets

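Download the images and labels, then convert them to TFRecords.

convert_data.py generates a ./tfrecords folder below the provided $dataset_dir. Each dataset is expected to be laid out as shown before running the command that follows its tree.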

-Bird
   └── Data
         └─── tfrecords
         └─── images.txt
         └─── image_class_labels.txt
         └─── train_test_split.txt
         └─── images
                 └─── ****.jpg
$ python convert_data.py --dataset_name=Bird --dataset_dir=./Bird/Data

-Car
  └── Data
        └─── tfrecords
        └─── devkit
        |         └─── cars_train_annos.mat
        |         └─── cars_test_annos_withlabels.mat
        └─── cars_train
        |        └─── ****.jpg
        └─── cars_test
                 └─── ****.jpg
$ python convert_data.py --dataset_name=Car --dataset_dir=./Car/Data

-Aircraft
    └── Data
          └─── tfrecords
          └─── fgvc-aircraft-2013b
                       └─── ***
$ python convert_data.py --dataset_name=Aircraft --dataset_dir=./Aircraft/Data
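
Before running convert_data.py, it can help to confirm that the raw files are where the trees above place them. The following is a minimal sketch, not part of this repository: the names EXPECTED_LAYOUT and missing_paths are hypothetical, and the list of required entries is copied from the trees above (the generated tfrecords folder is left out, and the contents of fgvc-aircraft-2013b are not checked because they are not listed here).

```python
import os

# Entries each dataset needs under its --dataset_dir, copied from the trees
# in this README. This helper is an illustrative assumption, not repo code.
EXPECTED_LAYOUT = {
    "Bird": [
        "images.txt",
        "image_class_labels.txt",
        "train_test_split.txt",
        "images",
    ],
    "Car": [
        os.path.join("devkit", "cars_train_annos.mat"),
        os.path.join("devkit", "cars_test_annos_withlabels.mat"),
        "cars_train",
        "cars_test",
    ],
    "Aircraft": ["fgvc-aircraft-2013b"],
}


def missing_paths(dataset_name, dataset_dir):
    """Return the expected entries that are absent from dataset_dir."""
    return [
        rel
        for rel in EXPECTED_LAYOUT[dataset_name]
        if not os.path.exists(os.path.join(dataset_dir, rel))
    ]


if __name__ == "__main__":
    # Check the default locations used by the commands in this README.
    for name in EXPECTED_LAYOUT:
        data_dir = os.path.join(".", name, "Data")
        print(name, "missing:", missing_paths(name, data_dir) or "nothing")
```

An empty list from missing_paths means every entry shown in the corresponding tree was found; it says nothing about whether the files themselves are valid.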

Running training

ImageNet pre-trained model

Download the ImageNet pre-trained model inception_v3.ckpt and put it in the folder ./pre_trained/

DATASET="Bird"
TRAIN_DIR="./$DATASET/WS_DAN/TRAIN/ws_dan_part_32"
MODEL_PATH='./pre_trained/inception_v3.ckpt'

python train_sample.py --learning_rate=0.001 \
                            --dataset_name=$DATASET \
                            --dataset_dir="./$DATASET/Data/tfrecords" \
                            --train_dir=$TRAIN_DIR \
                            --checkpoint_path=$MODEL_PATH \
                            --max_number_of_steps=80000 \
                            --weight_decay=1e-5 \
                            --model_name='inception_v3_bap' \
                            --checkpoint_exclude_scopes="InceptionV3/bilinear_attention_pooling" \
                            --batch_size=12 \
                            --train_image_size=448 \
                            --num_clones=1 \
                            --gpus="3" \
                            --feature_maps="Mixed_6e" \
                            --attention_maps="Mixed_7a_b0" \
                            --num_parts=32

Running testing

DATASET="Bird"
TRAIN_DIR="./$DATASET/WS_DAN/TRAIN/ws_dan_part_32"
TEST_DIR="./$DATASET/WS_DAN/TEST/ws_dan_part_32"

python eval_sample.py --checkpoint_path=$TRAIN_DIR \
                         --dataset_name=$DATASET \
                         --dataset_split_name='test' \
                         --dataset_dir="./$DATASET/Data/tfrecords" \
                         --eval_dir=$TEST_DIR \
                         --model_name='inception_v3_bap' \
                         --batch_size=16 \
                         --eval_image_size=448 \
                         --gpus="2" \
                         --feature_maps="Mixed_6e" \
                         --attention_maps="Mixed_7a_b0" \
                         --num_parts=32

Visualization

$ tensorboard --logdir=/path/to/model_dir --port=8081

Contact

Email: yihouxiang@gmail.com

Other Re-implementations

WS-DAN.PyTorch

License

MIT