SoT: Delving Deeper into Classification Head for Transformer

<div> &emsp;&emsp;&emsp;&emsp;&emsp;&emsp;<img src="images/overview.jpg" width="100%"/> </div>

Contents

  1. Introduction
  2. Installation
  3. Usage
  4. Classification results on CV tasks
  5. Classification results on NLP tasks
  6. Visualization
  7. Change log
  8. Acknowledgments
  9. Contact

Introduction

This repository is the official implementation of "SoT: Delving Deeper into Classification Head for Transformer". It contains the PyTorch source code and models for image classification and text classification tasks.

Citation

Please consider citing the paper if it is useful for you.

```bibtex
@article{SoT,
    author  = {Jiangtao Xie and Ruiren Zeng and Qilong Wang and Ziqi Zhou and Peihua Li},
    title   = {SoT: Delving Deeper into Classification Head for Transformer},
    journal = {arXiv preprint arXiv:2104.10935},
    year    = {2021}
}
```

Motivation and Contributions

For classification tasks in both CV and NLP, current works based on pure transformer architectures pay little attention to the classification head: they feed only the classification token (ClassT) to the classifier and neglect the word tokens (WordT), which contain rich information. Our experiments show that ClassT and WordT are highly complementary, and that fusing all tokens further boosts performance. We therefore propose a novel classification paradigm that jointly utilizes ClassT and WordT, in which multi-headed global cross-covariance pooling with singular value power normalization is introduced to effectively harness the rich information of WordT. We evaluate the proposed classification scheme on both CV and NLP tasks, achieving very competitive performance compared with the counterparts.

Installation

```sh
git clone https://github.com/jiangtaoxie/SoT.git
cd SoT/
pip install -r requirments.txt
```

Main dependencies: torch (>=1.7.0) | timm (==0.3.4) | apex (optional)

```sh
python setup.py install
```
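
A quick way to confirm the core dependencies are in place (a minimal sanity check; it only assumes torch and timm import correctly):

```python
# Sanity check for the main dependencies listed above.
import torch
import timm

print("torch:", torch.__version__)             # expected >= 1.7.0
print("timm:", timm.__version__)               # expected == 0.3.4
print("CUDA available:", torch.cuda.is_available())
```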

Usage

Prepare dataset

Please prepare the dataset in the following file structure:

```
.
├── train
│   ├── class1
│   │   ├── class1_001.jpg
│   │   ├── class1_002.jpg
│   │   └── ...
│   ├── class2
│   ├── class3
│   ├── ...
│   └── classN
└── val
    ├── class1
    │   ├── class1_001.jpg
    │   ├── class1_002.jpg
    │   └── ...
    ├── class2
    ├── class3
    ├── ...
    └── classN
```
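
This layout follows the standard ImageFolder convention, so you can sanity-check it before training. The sketch below assumes torchvision is available alongside PyTorch; "/path/to/dataset" is a placeholder for your $DATA_ROOT.

```python
# Minimal sketch: verify that the train/val directories are readable as
# ImageFolder-style datasets. "/path/to/dataset" is a placeholder for $DATA_ROOT.
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

train_set = datasets.ImageFolder("/path/to/dataset/train", transform=transform)
val_set = datasets.ImageFolder("/path/to/dataset/val", transform=transform)

print(len(train_set), "training images across", len(train_set.classes), "classes")
print(len(val_set), "validation images across", len(val_set.classes), "classes")
```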

Using our proposed SoT model

You can train models of the SoT family using the following command:

```sh
sh ./distributed_train.sh $NODE_NUM $DATA_ROOT --model $MODEL_NAME -b $BATCH_SIZE --lr $INIT_LR \
    --weight-decay $WEIGHT_DECAY \
    --img-size $RESOLUTION \
    --amp
```

Basic hyper-parameters of our SoT models:

| Hyper-parameter | SoT-Tiny | SoT-Small | SoT-Base |
| --- | --- | --- | --- |
| Batch size | 1024 | 1024 | 512 |
| Init. LR | 1e-3 | 1e-3 | 5e-4 |
| Weight decay | 3e-2 | 3e-2 | 6.5e-2 |

We also provide shell scripts in ./scripts for convenient reproduction; you can run:

```sh
sh ./scripts/train_SoT_Tiny.sh  # reproduce SoT-Tiny
sh ./scripts/train_SoT_Small.sh # reproduce SoT-Small
sh ./scripts/train_SoT_Base.sh  # reproduce SoT-Base
```

To evaluate on the validation set of ImageNet-1K:

```sh
python main.py $DATA_ROOT $MODEL_NAME --b 256 --eval_checkpoint $CHECKPOINT_PATH
```

To evaluate on ImageNet-A:

```sh
python main.py $DATA_ROOT $MODEL_NAME --b 256 --eval_checkpoint $CHECKPOINT_PATH --IN_A
```

$MODEL_NAME can be one of SoT_Tiny, SoT_Small, or SoT_Base.
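
Since the codebase builds on timm, the SoT models are presumably registered with timm's model registry; if so, they can also be created programmatically. This is a hedged sketch: the assumption that importing sot_src triggers registration (and that the registry names match $MODEL_NAME) should be checked against the repository.

```python
# Hedged sketch: assumes SoT_Tiny / SoT_Small / SoT_Base are registered with
# timm's model registry when the package is imported (an assumption, not a
# documented guarantee of this repository).
import timm
import sot_src  # noqa: F401 -- importing should register the SoT models (assumption)

model = timm.create_model("SoT_Tiny", pretrained=False, num_classes=1000)
model.eval()
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
```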

Using our proposed classification head in your architecture

```python
from sot_src.model import Classifier, OnlyVisualTokensClassifier

classification_head_config = dict(
    type='MGCrP',          # multi-headed global cross-covariance pooling head
    fusion_type='sum_fc',  # token fusion strategy
    args=dict(
        dim=256,
        num_heads=6,
        wr_dim=14,
        normalization=dict(
            type='svPN',   # singular value power normalization
            alpha=0.5,
            iterNum=1,
            svNum=1,
            regular=None,  # or nn.Dropout(0.5)
            input_dim=14,
        ),
    ),
)

classifier = Classifier(classification_head_config)
```
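
For orientation, here is where such a head would sit in a ViT-style model. The tensor shapes and the call `classifier(class_token, word_tokens)` are illustrative assumptions only, not the repository's documented interface; check sot_src/model.py for the actual forward signature.

```python
# Illustration only: the forward call below is an assumed signature, since the
# head is described as fusing the classification token (ClassT) and word tokens (WordT).
import torch

batch, num_word_tokens, dim = 8, 196, 256               # dim matches `dim=256` in the config above
tokens = torch.randn(batch, num_word_tokens + 1, dim)   # [ClassT | WordT] from a transformer encoder

class_token = tokens[:, 0]    # ClassT, shape (batch, dim)
word_tokens = tokens[:, 1:]   # WordT,  shape (batch, num_word_tokens, dim)

logits = classifier(class_token, word_tokens)  # hypothetical call; verify against sot_src
```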

Notes:

We also provide reference implementations based on DeiT and Swin Transformer for CV tasks and on BERT for NLP tasks.

Using the proposed visual tokens in your architecture

You can also use the proposed TokenEmbedding module, implemented with DenseNet blocks, as follows:

```python
from sot_src import TokenEmbed

patch_embed_config = dict(
    type='DenseNet',
    embedding_dim=64,
    large_output=False,  # for a 224x224 input: True yields a 56x56 token map, False yields 14x14
)

patch_embed = TokenEmbed(patch_embed_config)
```
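
As a rough usage sketch (the shapes are assumptions derived from the comment above, i.e. a 14x14 token map for a 224x224 input when large_output=False):

```python
# Rough sketch with assumed shapes: a 224x224 input is expected to yield a
# 14x14 token map (196 tokens) when large_output=False; verify against sot_src.
import torch

images = torch.randn(2, 3, 224, 224)   # a small batch of 224x224 RGB images
tokens = patch_embed(images)            # hypothetical call; output shape assumed
print(tokens.shape)                     # e.g. (2, 196, 64) with embedding_dim=64 (assumption)
```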

Classification results on CV tasks

Accuracy (single crop 224x224, %) on the validation set of ImageNet-1K and on ImageNet-A.

Our SoT family

| Backbone | ImageNet Top-1 Acc. | ImageNet-A Top-1 Acc. | #Params (M) | GFLOPs | Weight |
| --- | --- | --- | --- | --- | --- |
| SoT-Tiny | 80.3 | 21.5 | 7.7 | 2.5 | Coming soon |
| SoT-Small | 82.7 | 31.8 | 26.9 | 5.8 | Coming soon |
| SoT-Base | 83.5 | 34.6 | 76.8 | 14.5 | Coming soon |

DeiT family

| Backbone | ImageNet Top-1 Acc. | ImageNet-A Top-1 Acc. | #Params (M) | GFLOPs | Weight |
| --- | --- | --- | --- | --- | --- |
| DeiT-T | 72.2 | 7.3 | 5.7 | 1.3 | model |
| DeiT-T + ours | 78.6 | 17.5 | 7.0 | 2.3 | Coming soon |
| DeiT-S | 79.8 | 18.9 | 22.1 | 4.6 | model |
| DeiT-S + ours | 82.7 | 31.8 | 26.9 | 5.8 | Coming soon |
| DeiT-B | 81.8 | 27.4 | 86.6 | 17.6 | model |
| DeiT-B + ours | 82.9 | 29.1 | 94.9 | 18.2 | Coming soon |

Swin Transformer family

| Backbone | ImageNet Top-1 Acc. | ImageNet-A Top-1 Acc. | #Params (M) | GFLOPs | Weight |
| --- | --- | --- | --- | --- | --- |
| Swin-T | 81.3 | 21.6 | 28.3 | 4.5 | model |
| Swin-T + ours | 83.0 | 33.5 | 31.6 | 6.0 | Coming soon |
| Swin-B | 83.5 | 35.8 | 87.8 | 15.4 | model |
| Swin-B + ours | 84.0 | 42.9 | 95.9 | 16.9 | Coming soon |


Classification results on NLP tasks

Accuracy (Top-1, %) on four tasks selected from the General Language Understanding Evaluation (GLUE) benchmark.

| Backbone | CoLA | RTE | MNLI | QNLI | Weight |
| --- | --- | --- | --- | --- | --- |
| GPT | 54.32 | 63.17 | 82.10 | 86.36 | model |
| GPT + ours | 57.25 | 65.35 | 82.41 | 87.13 | Coming soon |
| BERT-base | 54.82 | 67.15 | 83.47 | 90.11 | model |
| BERT-base + ours | 58.03 | 69.31 | 84.20 | 90.78 | Coming soon |
| BERT-large | 60.63 | 73.65 | 85.90 | 91.82 | model |
| BERT-large + ours | 61.82 | 75.09 | 86.46 | 92.37 | Coming soon |
| SpanBERT-base | 57.48 | 73.65 | 85.53 | 92.71 | model |
| SpanBERT-base + ours | 63.77 | 77.26 | 86.13 | 93.31 | Coming soon |
| SpanBERT-large | 64.32 | 78.34 | 87.89 | 94.22 | model |
| SpanBERT-large + ours | 65.94 | 79.79 | 88.16 | 94.49 | Coming soon |
| RoBERTa-base | 61.58 | 77.60 | 87.50 | 92.70 | model |
| RoBERTa-base + ours | 65.28 | 80.50 | 87.90 | 93.10 | Coming soon |
| RoBERTa-large | 67.98 | 86.60 | 90.20 | 94.70 | model |
| RoBERTa-large + ours | 70.90 | 88.10 | 90.50 | 95.00 | Coming soon |

Visualization

We further analyze the models by visualization on the CV and NLP tasks, using SoT-Tiny and BERT-base as the respective backbones. We compare three variants based on SoT-Tiny and BERT-base as follows:

<p align="center" style="color:rgb(255,0,0);">&radic;:<font color="black"> correct prediction;</font> &#10007;: <font color="black">incorrect prediction</font></p> <div> &emsp;&emsp;&emsp;&emsp;&emsp;&emsp;<img src="images/vis.png" width="100%"/> </div>

We can see that ClassT is better suited to classifying categories associated with the background and the whole context, while WordT performs classification primarily based on local discriminative regions. Our ClassT+WordT makes full use of the merits of both the word tokens and the classification token: by exploiting both local and global information, it focuses on the most important regions for better classification.

<div> &emsp;&emsp;&emsp;&emsp;&emsp;&emsp;<img src="images/nlp_vis.png" width="100%"/> </div>

We select some examples from the CoLA task, which aims to judge whether an English sentence is grammatical. A greener background color denotes a stronger impact of the word on the classification, while a bluer one implies a weaker impact. We can see that the proposed ClassT+WordT highlights all the important words in the sentence while the other two variants fail to do so, which helps boost classification performance.

Change log

Acknowledgments

pytorch: https://github.com/pytorch/pytorch

timm: https://github.com/rwightman/pytorch-image-models

T2T-ViT: https://github.com/yitu-opensource/T2T-ViT

Contact

If you have any questions or suggestions, please contact me:

jiangtaoxie@mail.dlut.edu.cn; coke990921@mail.dlut.edu.cn