FLatten Transformer
This repo contains the official PyTorch code and pre-trained models for FLatten Transformer (ICCV 2023).
Updates
- May 28 2024: Fixed a numerical instability problem. FLatten Transformers can now be trained with automatic mixed precision (AMP) or float16.
Introduction
Motivation
<p align="center"> <img src="figures/attention.png" width= "600"> </p>

The quadratic computational complexity of self-attention, $\mathcal{O}(N^2)$, has been a long-standing problem when applying Transformer models to vision tasks. Apart from restricting attention to smaller regions, linear attention is also considered an effective way to avoid excessive computation costs. By approximating Softmax with carefully designed mapping functions, linear attention can switch the computation order in the self-attention operation and achieve linear complexity $\mathcal{O}(N)$. Nevertheless, current linear attention approaches either suffer from a severe performance drop or introduce additional computation overhead through the mapping function. In this paper, we propose a novel Focused Linear Attention module that achieves both high efficiency and expressiveness.
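The change of computation order described above can be checked with a small NumPy sketch. This is not code from this repository. The ReLU feature map, the shapes, and the function names are illustrative assumptions. Linear attention replaces Softmax with a similarity $\phi(Q)\phi(K)^T$. Computing $\phi(K)^T V$ first gives the same output as the $N \times N$ order, but costs $\mathcal{O}(Nd^2)$ instead of $\mathcal{O}(N^2 d)$.

```python
import numpy as np

def softmax_attention(q, k, v):
    # Standard attention: builds the N x N score matrix, so cost grows as O(N^2 d).
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def linear_attention(q, k, v, phi, eps=1e-6):
    # Kernelised attention with Sim(q, k) = phi(q) . phi(k).
    # Computing phi(K)^T V first (a d x d matrix) avoids the N x N matrix: O(N d^2).
    fq, fk = phi(q), phi(k)
    kv = fk.T @ v                     # (d, d_v)
    denom = fq @ fk.sum(axis=0)       # (N,) row sums of phi(Q) phi(K)^T
    return (fq @ kv) / (denom[:, None] + eps)

def linear_attention_naive(q, k, v, phi, eps=1e-6):
    # Same similarity, evaluated in the (phi(Q) phi(K)^T) V order: O(N^2 d).
    sim = phi(q) @ phi(k).T           # (N, N)
    return (sim @ v) / (sim.sum(axis=-1, keepdims=True) + eps)

def relu(x):
    # Illustrative non-negative feature map (an assumption for this sketch).
    return np.maximum(x, 0.0)

rng = np.random.default_rng(0)
N, d = 196, 32                        # e.g. a 14 x 14 token grid
q, k, v = (rng.standard_normal((N, d)) for _ in range(3))
out_fast = linear_attention(q, k, v, relu)
out_naive = linear_attention_naive(q, k, v, relu)
out_softmax = softmax_attention(q, k, v)
print(np.allclose(out_fast, out_naive))  # True: both orders give the same result
```

Only the linear form ever builds a $d \times d$ intermediate. This is why its cost scales linearly with the number of tokens $N$.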
Method
<p align="center"> <img src="figures/fp.png" width= "600"> </p> <p align="center"> <img src="figures/rank.png" width= "600"> </p>

In this paper, we first analyze in detail why linear attention underperforms, from two perspectives: focus ability and feature diversity. We then introduce a simple yet effective mapping function and an efficient rank restoration module. Together these form our Focused Linear Attention (FLatten), which adequately addresses these concerns while achieving high efficiency and expressive capability.
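The following NumPy sketch illustrates the two components. It reflects our reading of the paper's formulation and is not code from this repository. The focused mapping is $\phi_p(x) = f_p(\mathrm{ReLU}(x))$ with $f_p(x) = \frac{\lVert x \rVert}{\lVert x^{p} \rVert} x^{p}$ (element-wise power). Rank restoration adds a depthwise convolution (DWC) on $V$ to the attention output.

The following are simplifying assumptions:
- the focusing factor $p = 3$;
- the kernel weights;
- a 1-D convolution over the token sequence, standing in for the 2-D convolution over the token grid;
- all function names.

```python
import numpy as np

def focused_map(x, p=3, eps=1e-6):
    # phi_p(x) = f_p(ReLU(x)), f_p(x) = ||x|| / ||x**p|| * x**p (element-wise power).
    # The power sharpens each vector towards its dominant dimensions, while the
    # rescaling keeps the original norm (up to eps).
    x = np.maximum(x, 0.0)
    xp = x ** p
    scale = np.linalg.norm(x, axis=-1, keepdims=True) / (
        np.linalg.norm(xp, axis=-1, keepdims=True) + eps)
    return scale * xp

def attention_matrix(q, k, eps=1e-6):
    # Row-normalised weights phi(Q) phi(K)^T, materialised here only for analysis.
    sim = focused_map(q) @ focused_map(k).T
    return sim / (sim.sum(axis=-1, keepdims=True) + eps)

def depthwise_conv_matrix(n, kernel):
    # 1-D stand-in for one channel of the DWC applied to V:
    # token i mixes its neighbours with the given kernel (zero padding).
    m = np.zeros((n, n))
    half = len(kernel) // 2
    for i in range(n):
        for j, w in enumerate(kernel):
            t = i + j - half
            if 0 <= t < n:
                m[i, t] = w
    return m

rng = np.random.default_rng(0)
N, d = 64, 8
q, k = rng.standard_normal((N, d)), rng.standard_normal((N, d))
A = attention_matrix(q, k)
C = depthwise_conv_matrix(N, kernel=np.array([0.1, 0.8, 0.1]))
print(np.linalg.matrix_rank(A))      # at most d = 8: linear attention is low-rank
print(np.linalg.matrix_rank(A + C))  # at least N - d = 56 once the DWC term is added
```

Because $\phi(Q)\phi(K)^T$ has rank at most $d$, plain linear attention can produce at most $d$ linearly independent output rows per channel. The convolution matrix used here is full rank. Adding it therefore guarantees a rank of at least $N - d$, which illustrates the feature-diversity argument.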
Results
- Comparison of different models on ImageNet-1K.
- Accuracy-Runtime curve on ImageNet.
Dependencies
- Python 3.9
- PyTorch == 1.11.0
- torchvision == 0.12.0
- numpy
- timm == 0.4.12
- einops
- yacs
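Before training, it can help to confirm that the installed packages match the list above. The hypothetical helper below checks the pinned versions with the standard library's importlib.metadata. It assumes the installed distribution names are torch, torchvision, timm, numpy, einops and yacs. It treats local build tags such as "+cu113" as matching, which is a convenience choice rather than a requirement of this repo.

```python
import sys
from importlib import metadata

# Versions pinned in the dependency list above; the other packages are unpinned.
PINNED = {"torch": "1.11.0", "torchvision": "0.12.0", "timm": "0.4.12"}
UNPINNED = ("numpy", "einops", "yacs")

def check_environment():
    """Return a list of human-readable mismatches (empty if everything matches)."""
    problems = []
    if sys.version_info[:2] != (3, 9):
        problems.append(
            f"python: found {sys.version_info.major}.{sys.version_info.minor}, want 3.9")
    for pkg, want in PINNED.items():
        try:
            have = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            problems.append(f"{pkg}: not installed (want {want})")
            continue
        # Strip local build tags such as "+cu113" before comparing.
        if have.split("+")[0] != want:
            problems.append(f"{pkg}: found {have}, want {want}")
    for pkg in UNPINNED:
        try:
            metadata.version(pkg)
        except metadata.PackageNotFoundError:
            problems.append(f"{pkg}: not installed")
    return problems

problems = check_environment()
for line in problems or ["environment matches the listed dependencies"]:
    print(line)
```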
Data preparation
The ImageNet dataset should be prepared as follows:
```
$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
```
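The hypothetical helper below sanity-checks a dataset root against the layout above before a long training run. It is not part of the repository. It assumes an ImageFolder-style tree in which train/ and val/ contain identically named class folders. It counts only .jpeg, .jpg and .png files. For the full ImageNet-1K you would expect 1000 classes and 50,000 validation images.

```python
import tempfile
from pathlib import Path

IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}

def check_imagenet_layout(root):
    """Check an ImageFolder-style tree with train/ and val/ splits.

    Returns (num_classes, {"train": n_images, "val": n_images}) and raises
    ValueError if a split is missing or the two splits list different classes.
    """
    root = Path(root)
    classes, counts = {}, {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise ValueError(f"missing split directory: {split_dir}")
        class_dirs = sorted(p for p in split_dir.iterdir() if p.is_dir())
        classes[split] = [p.name for p in class_dirs]
        counts[split] = sum(
            1
            for class_dir in class_dirs
            for f in class_dir.iterdir()
            if f.suffix.lower() in IMAGE_SUFFIXES  # also accepts ".JPEG"
        )
    if classes["train"] != classes["val"]:
        raise ValueError("train/ and val/ contain different class folders")
    return len(classes["train"]), counts

# Demo on a tiny tree mirroring the layout above.
with tempfile.TemporaryDirectory() as tmp:
    for rel in ("train/class1/img1.jpeg", "train/class1/img2.jpeg",
                "train/class2/img3.jpeg", "val/class1/img4.jpeg",
                "val/class1/img5.jpeg", "val/class2/img6.jpeg"):
        path = Path(tmp) / rel
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()
    num_classes, counts = check_imagenet_layout(tmp)
    print(num_classes, counts)
```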
Pretrained Models
We provide pretrained models for several backbone architectures, listed below.
Model | Resolution | Top-1 acc. (%) | Config | Pretrained weights |
---|---|---|---|---|
FLatten-PVT-T | $224^2$ | 77.8 (+2.7) | config | TsinghuaCloud |
FLatten-PVTv2-B0 | $224^2$ | 71.1 (+0.6) | config | TsinghuaCloud |
FLatten-Swin-T | $224^2$ | 82.1 (+0.8) | config | TsinghuaCloud |
FLatten-Swin-S | $224^2$ | 83.5 (+0.5) | config | TsinghuaCloud |
FLatten-Swin-B | $224^2$ | 83.8 (+0.3) | config | TsinghuaCloud |
FLatten-Swin-B | $384^2$ | 85.0 (+0.5) | config | TsinghuaCloud |
FLatten-CSwin-T | $224^2$ | 83.1 (+0.4) | config | TsinghuaCloud |
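The implied baseline accuracies can be recovered from the table by subtraction. This assumes each parenthesised number is the top-1 gain of the FLatten model over its original softmax-attention counterpart. The snippet only re-derives numbers from the table above. The data layout and the function name are illustrative.

```python
# (model, resolution, FLatten top-1 %, gain over baseline), copied from the table above.
RESULTS = [
    ("FLatten-PVT-T", 224, 77.8, 2.7),
    ("FLatten-PVTv2-B0", 224, 71.1, 0.6),
    ("FLatten-Swin-T", 224, 82.1, 0.8),
    ("FLatten-Swin-S", 224, 83.5, 0.5),
    ("FLatten-Swin-B", 224, 83.8, 0.3),
    ("FLatten-Swin-B", 384, 85.0, 0.5),
    ("FLatten-CSwin-T", 224, 83.1, 0.4),
]

def implied_baselines(rows):
    # Assumption: baseline top-1 = FLatten top-1 - reported gain.
    # Rounded to one decimal, matching the precision of the table.
    return {f"{name}@{res}": round(acc - gain, 1) for name, res, acc, gain in rows}

baselines = implied_baselines(RESULTS)
for key, base in baselines.items():
    print(f"{key}: implied baseline {base:.1f}%")
```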
To evaluate a pretrained model on ImageNet, run:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>
```
Evaluation outputs of the four T/B0 pretrained models:
```
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 294): INFO * Acc@1 77.758 Acc@5 93.910
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 149): INFO Accuracy of the network on the 50000 test images: 77.8%
[2023-07-21 07:51:36 flatten_pvt_v2_b0] (main.py 294): INFO * Acc@1 71.098 Acc@5 90.596
[2023-07-21 07:51:36 flatten_pvt_v2_b0] (main.py 149): INFO Accuracy of the network on the 50000 test images: 71.1%
[2023-07-21 07:46:13 flatten_swin_tiny_patch4_224] (main.py 294): INFO * Acc@1 82.106 Acc@5 95.900
[2023-07-21 07:46:13 flatten_swin_tiny_patch4_224] (main.py 149): INFO Accuracy of the network on the 50000 test images: 82.1%
[2023-07-21 07:52:46 FLatten_CSWin_tiny](main.py 294): INFO * Acc@1 83.130 Acc@5 96.376
[2023-07-21 07:52:46 FLatten_CSWin_tiny](main.py 149): INFO Accuracy of the network on the 50000 test images: 83.1%
```
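To compare several evaluation runs, the "* Acc@1 ... Acc@5 ..." summary lines can be pulled out with a regular expression. The sketch below is written against the sample log lines shown above. It assumes other runs print the same format. The pattern and the function name are hypothetical, not part of main.py.

```python
import re

# Matches the "* Acc@1 ... Acc@5 ..." summary lines in the logs above,
# tolerating the missing space before "(main.py" seen in the CSwin log.
ACC_LINE = re.compile(
    r"\[(?P<time>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) (?P<model>\w+)\]\s*"
    r"\(main\.py \d+\): INFO \* Acc@1 (?P<acc1>\d+\.\d+) Acc@5 (?P<acc5>\d+\.\d+)"
)

def parse_eval_log(text):
    """Map model name -> (top-1, top-5) accuracy from an evaluation log."""
    return {
        m.group("model"): (float(m.group("acc1")), float(m.group("acc5")))
        for m in ACC_LINE.finditer(text)
    }

sample = """\
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 294): INFO * Acc@1 77.758 Acc@5 93.910
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 149): INFO Accuracy of the network on the 50000 test images: 77.8%
[2023-07-21 07:52:46 FLatten_CSWin_tiny](main.py 294): INFO * Acc@1 83.130 Acc@5 96.376
"""
results = parse_eval_log(sample)
for model, (top1, top5) in results.items():
    print(f"{model}: top-1 {top1:.1f}, top-5 {top5:.1f}")
```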
Train Models from Scratch
- To train FLatten-PVT-T/S/M/B on ImageNet from scratch, run:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_t.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_s.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_m.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_b.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
```
- To train FLatten-PVT-v2-b0/1/2/3/4 on ImageNet from scratch, run:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b0.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b1.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b2.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b3.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b4.yaml --data-path <imagenet-path> --output <output-path>
```
- To train FLatten-Swin-T/S/B on ImageNet from scratch, run:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_t.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_s.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_b.yaml --data-path <imagenet-path> --output <output-path>
```
- To train FLatten-CSwin-T/S/B on ImageNet from scratch, run:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_t.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99984
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_s.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99984
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_b.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99982
```
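The --model-ema-decay values above control how slowly the averaged weights track the training weights. This sketch makes the following assumptions:
- main_ema.py follows the standard exponential-moving-average rule $\theta_{\text{ema}} \leftarrow d\,\theta_{\text{ema}} + (1 - d)\,\theta$, as in timm's ModelEma;
- it updates once per optimizer step.

Under those assumptions, the averaging horizon is roughly $1/(1 - d)$ steps. The helper names are hypothetical.

```python
def ema_update(ema_params, params, decay):
    # theta_ema <- decay * theta_ema + (1 - decay) * theta, applied per parameter.
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params

def ema_horizon(decay):
    # Rule of thumb: the EMA mostly reflects the last ~1 / (1 - decay) updates.
    return 1.0 / (1.0 - decay)

for decay in (0.99984, 0.99982):
    print(f"decay={decay}: ~{ema_horizon(decay):.0f} steps")

# Toy check: starting from 0 and tracking a constant 1.0 with decay 0.5,
# two updates give 0.5 and then 0.75.
ema = {"w": 0.0}
for _ in range(2):
    ema_update(ema, {"w": 1.0}, decay=0.5)
print(ema["w"])
```

Under this reading, the slightly lower decay used for CSwin-B (0.99982, about 5,556 steps) averages over a shorter window than the T/S setting (0.99984, about 6,250 steps).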
Fine-tuning on higher resolution
Fine-tune a FLatten-Swin-B model pre-trained at 224x224 resolution to 384x384 resolution:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights>
```
Fine-tune a FLatten-CSwin-B model pre-trained at 224x224 resolution to 384x384 resolution:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights> --model-ema --model-ema-decay 0.99982
```
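The training and fine-tuning commands in this README share one pattern:
- CSwin models go through main_ema.py with the EMA flags.
- PVT models add --find-unused-params.
- Fine-tuning adds --pretrained.

The hypothetical helper below assembles such commands, for example to script a sweep. It only reproduces flags that appear in this README and does not validate them against main.py. The example paths are placeholders.

```python
import shlex

def launch_command(cfg, data_path, output, *, ema_decay=None, find_unused=False,
                   pretrained=None, nproc=8):
    """Build a torch.distributed.launch command following this README's pattern."""
    # EMA runs use main_ema.py; everything else uses main.py.
    script = "main_ema.py" if ema_decay is not None else "main.py"
    cmd = ["python", "-m", "torch.distributed.launch", f"--nproc_per_node={nproc}",
           script, "--cfg", cfg, "--data-path", data_path, "--output", output]
    if pretrained is not None:
        cmd += ["--pretrained", pretrained]
    if find_unused:
        cmd.append("--find-unused-params")
    if ema_decay is not None:
        cmd += ["--model-ema", "--model-ema-decay", str(ema_decay)]
    return shlex.join(cmd)  # quotes any argument that needs shell escaping

cmd = launch_command("./cfgs/flatten_cswin_b_384.yaml", "/data/imagenet", "./output",
                     ema_decay=0.99982, pretrained="./flatten_cswin_b_224.pth")
print(cmd)
```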
Visualization
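We provide code for visualizing flatten attention. For example, to visualize flatten attention in FLatten-Swin-T, add the following lines to the forward pass of its attention module, at the point where q, k and self.dwc are available:
```python
# Inside the FLatten-Swin-T attention forward(), after q and k have been computed.
from visualize import AttnVisualizer

# Pass the query/key tensors and the depthwise-convolution kernel of the module.
visualizer = AttnVisualizer(qk=[q, k], kernel=self.dwc.weight, name='flatten_swin_t')
# Render the attention maps for the example image.
visualizer.visualize_all_attn(max_num=196, image='./visualize/img_ori_00809.png')
```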
Then run:
```bash
python visualize.py
```
Note: remember to set the path to the FLatten-Swin-T pretrained weights in visualize.py.
Acknowledgements
This code is developed on top of Swin Transformer. The computational resources supporting this work were provided by Hangzhou High-Flyer AI Fundamental Research Co., Ltd.
Citation
If you find this repo helpful, please consider citing us.
```bibtex
@InProceedings{han2023flatten,
  title={FLatten Transformer: Vision Transformer using Focused Linear Attention},
  author={Han, Dongchen and Pan, Xuran and Han, Yizeng and Song, Shiji and Huang, Gao},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year={2023}
}
```
Contact
If you have any questions, please feel free to contact the authors.
Dongchen Han: hdc23@mails.tsinghua.edu.cn
Xuran Pan: pxr18@mails.tsinghua.edu.cn