

SparK: the first successful BERT/MAE-style pretraining on any convolutional network

This is the official implementation of the ICLR paper Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling, which can pretrain any CNN (e.g., ResNet) in a BERT-style self-supervised manner. We've tried our best to keep the codebase clean, short, easy to read, and state-of-the-art, and to rely only on minimal dependencies.
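BERT-style pretraining on a CNN starts by hiding random image patches. As a rough illustration only (a toy NumPy sketch, not this repo's code), the snippet below draws one random patch mask on a coarse grid and re-expresses the same mask at several feature-map resolutions, which is the intuition behind "hierarchical" masked modeling. The 224-pixel input, total stride 32, stage strides (4, 8, 16, 32), and 60% mask ratio are our assumptions; the helper names are hypothetical.

```python
import numpy as np


def random_patch_mask(grid: int, mask_ratio: float, rng: np.random.Generator) -> np.ndarray:
    """Boolean (grid, grid) map: True = visible patch, False = masked patch."""
    n = grid * grid
    n_keep = round(n * (1 - mask_ratio))       # number of patches left visible
    keep = rng.permutation(n)[:n_keep]          # pick visible patches uniformly at random
    active = np.zeros(n, dtype=bool)
    active[keep] = True
    return active.reshape(grid, grid)


def upsample_mask(active: np.ndarray, factor: int) -> np.ndarray:
    """Nearest-neighbour upsampling: each patch decision covers factor x factor cells."""
    return active.repeat(factor, axis=0).repeat(factor, axis=1)


rng = np.random.default_rng(0)
img_size, max_stride = 224, 32   # assumption: a CNN whose deepest stage has stride 32
grid = img_size // max_stride    # 7 x 7 patches of 32 x 32 pixels
active = random_patch_mask(grid, mask_ratio=0.6, rng=rng)  # assumed mask ratio

# The same mask expressed at each (assumed) stage resolution.
for stride in (4, 8, 16, 32):
    m = upsample_mask(active, max_stride // stride)
    print(f"stride {stride:2d}: mask shape {m.shape}, visible fraction {m.mean():.2f}")
```

Because every scale is derived from one coarse mask, a patch that is hidden at the input stays hidden at every stage, so the encoder never sees a partially masked patch.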


https://user-images.githubusercontent.com/39692511/226858919-dd4ccf7e-a5ba-4a33-ab21-4785b8a7833c.mp4


🔥 News


🕹️ Colab Visualization Demo

See pretrain/viz_reconstruction.ipynb to visualize reconstructions from SparK-pretrained models, for example:

<p align="center"> <img src="https://user-images.githubusercontent.com/39692511/226376648-3f28a1a6-275d-4f88-8f3e-cd1219882488.png" width=50%> </p>

We also provide pretrain/viz_spconv.ipynb, which shows the "mask pattern vanishing" issue of dense convolutional layers.
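The "mask pattern vanishing" issue can be reproduced in miniature. The toy single-channel NumPy simulation below is not code from this repo: a fixed 3x3 averaging filter stands in for a dense conv layer, and a "sparse" conv is emulated by running that filter and then re-zeroing every masked position. This is one simple way to get sparse-conv behavior and may differ in detail from SparK's implementation; all helper names are hypothetical.

```python
import numpy as np


def dense_conv3x3(x: np.ndarray) -> np.ndarray:
    """3x3 averaging conv with zero padding: a stand-in for an ordinary dense conv layer."""
    h, w = x.shape
    p = np.pad(x, 1)  # zero padding keeps the spatial size unchanged
    out = np.zeros_like(x)
    for dy in range(3):
        for dx in range(3):
            out += p[dy:dy + h, dx:dx + w]
    return out / 9.0


def sparse_conv3x3(x: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Emulated sparse conv: dense conv, then force masked positions back to zero."""
    return dense_conv3x3(x) * mask


def still_zero_fraction(x: np.ndarray, mask: np.ndarray) -> float:
    """Fraction of masked positions (mask == 0) whose value is still exactly zero."""
    return float(np.mean(x[mask == 0] == 0.0))


rng = np.random.default_rng(0)
img = rng.random((16, 16)) + 0.1  # strictly positive "pixels"
# Checkerboard of 4x4 patches: 1 = visible, 0 = masked.
patch_mask = (np.indices((4, 4)).sum(axis=0) % 2 == 0).astype(float)
mask = np.kron(patch_mask, np.ones((4, 4)))  # upsample to pixel resolution (16 x 16)

dense = img * mask   # masked pixels start at exactly zero
sparse = img * mask
for _ in range(4):   # four stacked conv layers
    dense = dense_conv3x3(dense)
    sparse = sparse_conv3x3(sparse, mask)

print("dense :", still_zero_fraction(dense, mask))
print("sparse:", still_zero_fraction(sparse, mask))
```

With 4x4 masked patches, four stacked 3x3 layers are enough for the dense path to spread visible values into every masked pixel, so the masked/visible pattern disappears; the re-masked path keeps it exactly.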

What's new here?

🔥 Pretrained CNNs beat pretrained Swin Transformers:

<p align="center"> <img src="https://user-images.githubusercontent.com/39692511/226844278-1dc1e13c-1f07-4b8f-9843-8c47fca47253.jpg" width=66%> </p>

🔥 After SparK pretraining, smaller models can beat un-pretrained larger models:

<p align="center"> <img src="https://user-images.githubusercontent.com/39692511/226861835-77e43c07-0a00-4020-9395-03e81bfe6959.jpg" width=72%> </p>

🔥 All models can benefit, showing a scaling behavior:

<p align="center"> <img src="https://user-images.githubusercontent.com/39692511/211705760-de15f4a1-0508-4690-981e-5640f4516d2a.png" width=65%> </p>

🔥 Generative self-supervised pretraining surpasses contrastive learning:

<p align="center"> <img src="https://user-images.githubusercontent.com/39692511/211497479-0563e891-f2ad-4cf1-b682-a21c2be1442d.png" width=65%> </p>

See our paper for more analysis, discussions, and evaluations.

Todo list

<details> <summary>catalog</summary> </details>

Pretrained weights (self-supervised; w/o decoder; can be directly finetuned)

Note: for network definitions, we directly use timm.models.ResNet and the official ConvNeXt implementation.

reso.: input image resolution; acc@1: top-1 accuracy (%) on ImageNet-1K after finetuning

| arch. | reso. | acc@1 | #params | flops | weights (self-supervised, without SparK's decoder) |
|:---:|:---:|:---:|:---:|:---:|:---:|
| ResNet50 | 224 | 80.6 | 26M | 4.1G | resnet50_1kpretrained_timm_style.pth |
| ResNet101 | 224 | 82.2 | 45M | 7.9G | resnet101_1kpretrained_timm_style.pth |
| ResNet152 | 224 | 82.7 | 60M | 11.6G | resnet152_1kpretrained_timm_style.pth |
| ResNet200 | 224 | 83.1 | 65M | 15.1G | resnet200_1kpretrained_timm_style.pth |
| ConvNeXt-S | 224 | 84.1 | 50M | 8.7G | convnextS_1kpretrained_official_style.pth |
| ConvNeXt-B | 224 | 84.8 | 89M | 15.4G | convnextB_1kpretrained_official_style.pth |
| ConvNeXt-L | 224 | 85.4 | 198M | 34.4G | convnextL_1kpretrained_official_style.pth |
| ConvNeXt-L | 384 | 86.0 | 198M | 101.0G | convnextL_384_1kpretrained_official_style.pth |
<details> <summary> <b> Pretrained weights (with SparK's UNet-style decoder; can be used to reconstruct images) </b> </summary> <br>

| arch. | reso. | acc@1 | #params | flops | weights (self-supervised, with SparK's decoder) |
|:---:|:---:|:---:|:---:|:---:|:---:|
| ResNet50 | 224 | 80.6 | 26M | 4.1G | res50_withdecoder_1kpretrained_spark_style.pth |
| ResNet101 | 224 | 82.2 | 45M | 7.9G | res101_withdecoder_1kpretrained_spark_style.pth |
| ResNet152 | 224 | 82.7 | 60M | 11.6G | res152_withdecoder_1kpretrained_spark_style.pth |
| ResNet200 | 224 | 83.1 | 65M | 15.1G | res200_withdecoder_1kpretrained_spark_style.pth |
| ConvNeXt-S | 224 | 84.1 | 50M | 8.7G | cnxS224_withdecoder_1kpretrained_spark_style.pth |
| ConvNeXt-L | 384 | 86.0 | 198M | 101.0G | cnxL384_withdecoder_1kpretrained_spark_style.pth |

</details> <br>

Installation & Running

We highly recommend using torch==1.10.0, torchvision==0.11.1, and timm==0.5.4 for reproduction. See INSTALL.md for installing all pip dependencies.
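To confirm that an environment matches these pinned versions, a small standard-library check like the one below may help. It only reads installed package metadata; the helper is hypothetical and not part of this repo or INSTALL.md.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions recommended above for reproduction.
RECOMMENDED = {"torch": "1.10.0", "torchvision": "0.11.1", "timm": "0.5.4"}


def check_versions(recommended=RECOMMENDED):
    """Return {package: (installed_version_or_None, matches_recommended)}."""
    report = {}
    for pkg, wanted in recommended.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            installed = None
        # torch wheels may carry a local suffix such as "+cu113"; compare only the public part.
        public = installed.split("+")[0] if installed else None
        report[pkg] = (installed, public == wanted)
    return report


if __name__ == "__main__":
    for pkg, (installed, ok) in check_versions().items():
        status = "OK" if ok else "MISMATCH"
        print(f"{pkg:12s} installed={installed} recommended={RECOMMENDED[pkg]} {status}")
```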

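# download our weights `resnet50_1kpretrained_timm_style.pth` first
import torch, timm
res50 = timm.create_model('resnet50')
state = torch.load('resnet50_1kpretrained_timm_style.pth', map_location='cpu')
state = state.get('module', state)  # just in case the model weights are actually saved in state['module']
missing, unexpected = res50.load_state_dict(state, strict=False)  # strict=False: the classifier head (fc) may be absent
print('missing keys:', missing, 'unexpected keys:', unexpected)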
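If you load a checkpoint saved from a model wrapped in DistributedDataParallel, its keys may also carry a "module." prefix, which load_state_dict(strict=False) would silently skip. A minimal sketch of a helper that normalizes both layouts (the nested state['module'] case the snippet above guards against, and prefixed keys) is shown below; unwrap_state_dict is hypothetical and not part of SparK.

```python
def unwrap_state_dict(ckpt: dict) -> dict:
    """Return a plain parameter dict from a checkpoint (hypothetical helper, not part of SparK).

    Handles two common layouts:
      * weights nested under ckpt['module'];
      * every key prefixed with 'module.' (typical of DistributedDataParallel checkpoints).
    """
    state = ckpt.get('module', ckpt)
    prefix = 'module.'
    if state and all(k.startswith(prefix) for k in state):
        state = {k[len(prefix):]: v for k, v in state.items()}
    return state
```

Usage would look like `res50.load_state_dict(unwrap_state_dict(torch.load(path, map_location='cpu')), strict=False)`; checking the returned missing keys afterwards is still a good habit.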

Acknowledgement

We referred to these useful codebases:

License

This project is under the MIT license. See LICENSE for more details.

Citation

If you find this project useful, please consider giving us a star ⭐ or citing our work 📖:

@Article{tian2023designing,
  author  = {Keyu Tian and Yi Jiang and Qishuai Diao and Chen Lin and Liwei Wang and Zehuan Yuan},
  title   = {Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling},
  journal = {arXiv:2301.03580},
  year    = {2023},
}