MLP Can Be A Good Transformer Learner (CVPR2024)

Overview

Note

Requirements

Note that we use torch==1.7.1 for training. To incorporate ToMe, we use torch==1.12.1.
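If the two environments are set up separately, installation might look like the following (a sketch only; the torchvision pairings are the usual companions for these torch releases, but adjust for your CUDA setup):

```shell
# Training environment (torchvision 0.8.2 is the usual pairing for torch 1.7.1)
pip install torch==1.7.1 torchvision==0.8.2

# ToMe / evaluation environment (torchvision 0.13.1 pairs with torch 1.12.1)
pip install torch==1.12.1 torchvision==0.13.1
```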

Checkpoints

We provide some checkpoints for reference. Here the prefix indicates the architecture, while the suffix indicates which attention layers are removed.
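As a hypothetical illustration of this naming scheme (the released filenames may differ), a name such as `base_1_3_5.pth` could be parsed like so:

```python
def parse_checkpoint_name(name: str):
    """Parse a hypothetical checkpoint name of the form
    '<arch>_<idx>_<idx>_....pth', where the prefix is the architecture
    and the trailing indices are the removed attention layers.
    Illustrative only -- check the actual filenames in the release."""
    stem = name.rsplit(".", 1)[0]          # drop the .pth extension
    parts = stem.split("_")
    arch = parts[0]                        # e.g. 'base', 'small', 'tiny'
    removed = [int(p) for p in parts[1:]]  # removed attention-layer indices
    return arch, removed
```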

Performance

We found that the same code and checkpoint can produce different inference results under different PyTorch versions. We have not yet identified the cause and welcome discussion.

<table class="tg"> <thead> <tr> <th class="tg-0pky" rowspan="2">Arch</th> <th class="tg-0pky" rowspan="2">Baseline</th> <th class="tg-c3ow" colspan="2">25%</th> <th class="tg-c3ow" colspan="2">30%</th> <th class="tg-c3ow" colspan="2">40%</th> <th class="tg-c3ow" colspan="2">50%</th> </tr> <tr> <th class="tg-c3ow">1.7.1</th> <th class="tg-c3ow">1.12.1</th> <th class="tg-c3ow">1.7.1</th> <th class="tg-c3ow">1.12.1</th> <th class="tg-c3ow">1.7.1</th> <th class="tg-c3ow">1.12.1</th> <th class="tg-c3ow">1.7.1</th> <th class="tg-c3ow">1.12.1</th> </tr> </thead> <tbody> <tr> <td class="tg-0pky">Base</td> <td class="tg-c3ow">81.8</td> <td class="tg-c3ow">-</td> <td class="tg-c3ow">-</td> <td class="tg-c3ow">-</td> <td class="tg-c3ow">-</td> <td class="tg-c3ow"><span style="font-weight:400;font-style:normal">81.83</span></td> <td class="tg-c3ow"><span style="font-weight:400;font-style:normal">81.77</span></td> <td class="tg-c3ow">81.33</td> <td class="tg-c3ow">81.46</td> </tr> <tr> <td class="tg-0pky">Small</td> <td class="tg-c3ow">79.9</td> <td class="tg-c3ow">80.31</td> <td class="tg-c3ow">80.33</td> <td class="tg-c3ow">79.90</td> <td class="tg-c3ow">79.89</td> <td class="tg-c3ow">-</td> <td class="tg-c3ow">-</td> <td class="tg-c3ow">-</td> <td class="tg-c3ow">-</td> </tr> <tr> <td class="tg-0pky">Tiny</td> <td class="tg-c3ow">72.2</td> <td class="tg-c3ow">72.94</td> <td class="tg-c3ow">72.79</td> <td class="tg-c3ow">71.90</td> <td class="tg-c3ow">71.88</td> <td class="tg-c3ow">-</td> <td class="tg-c3ow">-</td> <td class="tg-c3ow">-</td> <td class="tg-c3ow">-</td> </tr> </tbody> </table>

We deploy ToMe over the normal blocks (indexed 0, 1, 2, ...). Typically, we apply this technique to the normal block starting at index 1 and its subsequent normal blocks. The model is evaluated with torch==1.12.1.
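The effect of the merging ratio r can be sketched with simple token arithmetic. This is a hedged sketch under the assumption that each ToMe-enabled block merges r tokens; the actual schedule in the ToMe implementation may differ:

```python
def tokens_per_block(num_tokens, num_blocks, start_idx, r):
    """Token count entering each block when ToMe merges r tokens in every
    block from start_idx onward. A simplified, uniform schedule for
    illustration only."""
    counts = []
    n = num_tokens
    for i in range(num_blocks):
        counts.append(n)
        if i >= start_idx:
            n = max(n - r, 1)  # merge r tokens, never below 1
    return counts

# e.g. a 12-block model with 197 input tokens, ToMe from index 1, r=24
print(tokens_per_block(197, 12, start_idx=1, r=24))
```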

<table class="tg"> <thead> <tr> <th class="tg-0pky">Arch</th> <th class="tg-0lax">Remove Ratio</th> <th class="tg-0pky">w/o ToMe</th> <th class="tg-c3ow">Started idx</th> <th class="tg-c3ow">r</th> <th class="tg-0lax">w ToMe</th> </tr> </thead> <tbody> <tr> <td class="tg-0pky" rowspan="3">Base </td> <td class="tg-baqh" rowspan="2">40%</td> <td class="tg-c3ow" rowspan="2"><span style="font-weight:400;font-style:normal">81.77</span></td> <td class="tg-c3ow">1</td> <td class="tg-c3ow">24</td> <td class="tg-baqh">81.58</td> </tr> <tr> <td class="tg-c3ow">1</td> <td class="tg-c3ow">28</td> <td class="tg-baqh">81.42</td> </tr> <tr> <td class="tg-baqh">50%</td> <td class="tg-baqh">81.46</td> <td class="tg-baqh">0</td> <td class="tg-baqh">14</td> <td class="tg-baqh">81.28</td> </tr> <tr> <td class="tg-0pky" rowspan="2">Small</td> <td class="tg-baqh">25%</td> <td class="tg-c3ow">80.33</td> <td class="tg-c3ow">1</td> <td class="tg-c3ow">22</td> <td class="tg-baqh">79.86</td> </tr> <tr> <td class="tg-baqh">30%</td> <td class="tg-baqh">79.89</td> <td class="tg-baqh">1</td> <td class="tg-baqh">19</td> <td class="tg-baqh">79.62</td> </tr> <tr> <td class="tg-0pky" rowspan="2">Tiny</td> <td class="tg-baqh">25%</td> <td class="tg-c3ow">72.79</td> <td class="tg-c3ow">1</td> <td class="tg-c3ow">19</td> <td class="tg-baqh">72.35</td> </tr> <tr> <td class="tg-baqh">30%</td> <td class="tg-baqh">71.88</td> <td class="tg-baqh">1</td> <td class="tg-baqh">14</td> <td class="tg-baqh">71.7</td> </tr> </tbody> </table>

Before getting started

Training

We use 8 GPUs with 256 images per GPU.

E.g.

./script/shrink_base.sh
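The effective batch size is therefore 8 × 256 = 2048. If the scripts follow the common DeiT convention of linearly scaling the learning rate against a base batch of 512 (an assumption; check the scripts for the actual values):

```python
def effective_batch_and_lr(gpus=8, per_gpu=256, base_lr=5e-4, base_batch=512):
    """Effective batch size and linearly scaled learning rate
    (DeiT-style convention; the repo's scripts may use other values)."""
    batch = gpus * per_gpu
    return batch, base_lr * batch / base_batch
```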

Testing

./script/test.sh

Speed, Params & FLOPs

Please refer to benchmark.py and run

python benchmark.py

To-Do

Issues / Contact

Feel free to create an issue if you have a question, or just drop me an email (sihao.lin@student.rmit.edu.au).

Acknowledgement

This work is built upon DeiT. Thanks for their awesome work.