(Generic) EfficientNets for PyTorch

-- **NOTE** This repo is not being maintained --

Please use timm instead. It includes all of these model definitions (compatible weights) and much much more.
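
For reference, a minimal example of the timm equivalent of the usage shown below (assuming a current timm install; timm.create_model is its standard entrypoint):

>>> import timm
>>> m = timm.create_model('efficientnet_b0', pretrained=True)
>>> m.eval()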

A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter efficient architectures derived from the MobileNet V1/V2 block sequence, including those found via automated neural architecture search.

All models are implemented by the GenEfficientNet or MobileNetV3 classes, with string-based architecture definitions to configure the block layouts (idea from here).
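
As a rough illustration of that format (the real definitions live in the geffnet builder code; the strings below are excerpted in style only and should not be read as any particular model's exact layout):

# Illustrative sketch of the string-based block definitions. Each string
# describes one stage of blocks via underscore-separated tokens:
#   ds/ir  -> depthwise-separable / inverted-residual block type
#   r2     -> repeat the block twice in this stage
#   k3     -> 3x3 depthwise kernel
#   s2     -> stride 2 for the first block of the stage
#   e6     -> channel expansion ratio 6
#   c24    -> 24 output channels
#   se0.25 -> squeeze-and-excitation ratio 0.25
arch_def = [
    ['ds_r1_k3_s1_e1_c16_se0.25'],
    ['ir_r2_k3_s2_e6_c24_se0.25'],
    ['ir_r2_k5_s2_e6_c40_se0.25'],
]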

What's New

Aug 19, 2020

April 5, 2020

March 23, 2020

Feb 12, 2020

Jan 22, 2020

Nov 22, 2019

Nov 15, 2019

Oct 30, 2019

Oct 27, 2019

Models

Implemented models include:

- EfficientNet (B0-B8, plus AdvProp and NoisyStudent variants)
- EfficientNet-Lite
- EfficientNet-EdgeTPU (efficientnet_es/em/el)
- EfficientNet-CondConv (efficientnet_cc_*)
- MixNet
- MNASNet (A1, B1)
- MobileNet-V2
- MobileNet-V3
- FBNet-C (fbnetc_100)
- Single-Path NAS (spnasnet_100)

I originally implemented and trained some of these models with code here; this repository contains just the GenEfficientNet models, validation, and associated ONNX/Caffe2 export code.

Pretrained

I've managed to train several of the models to accuracies close to or above those of the originating papers and official implementations. My training code is here: https://github.com/rwightman/pytorch-image-models

|Model|Prec@1 (Err)|Prec@5 (Err)|Param # (M)|MAdds (M)|Image Scaling|Resolution|Crop|
|---|---|---|---|---|---|---|---|
|efficientnet_b3|82.240 (17.760)|96.116 (3.884)|12.23|TBD|bicubic|320|1.0|
|efficientnet_b3|82.076 (17.924)|96.020 (3.980)|12.23|TBD|bicubic|300|0.904|
|mixnet_xl|81.074 (18.926)|95.282 (4.718)|11.90|TBD|bicubic|256|1.0|
|efficientnet_b2|80.612 (19.388)|95.318 (4.682)|9.1|TBD|bicubic|288|1.0|
|mixnet_xl|80.476 (19.524)|94.936 (5.064)|11.90|TBD|bicubic|224|0.875|
|efficientnet_b2|80.288 (19.712)|95.166 (4.834)|9.1|1003|bicubic|260|0.890|
|mixnet_l|78.976 (21.024)|94.184 (5.816)|7.33|TBD|bicubic|224|0.875|
|efficientnet_b1|78.692 (21.308)|94.086 (5.914)|7.8|694|bicubic|240|0.882|
|efficientnet_es|78.066 (21.934)|93.926 (6.074)|5.44|TBD|bicubic|224|0.875|
|efficientnet_b0|77.698 (22.302)|93.532 (6.468)|5.3|390|bicubic|224|0.875|
|mobilenetv2_120d|77.294 (22.706)|93.502 (6.498)|5.8|TBD|bicubic|224|0.875|
|mixnet_m|77.256 (22.744)|93.418 (6.582)|5.01|353|bicubic|224|0.875|
|mobilenetv2_140|76.524 (23.476)|92.990 (7.010)|6.1|TBD|bicubic|224|0.875|
|mixnet_s|75.988 (24.012)|92.794 (7.206)|4.13|TBD|bicubic|224|0.875|
|mobilenetv3_large_100|75.766 (24.234)|92.542 (7.458)|5.5|TBD|bicubic|224|0.875|
|mobilenetv3_rw|75.634 (24.366)|92.708 (7.292)|5.5|219|bicubic|224|0.875|
|efficientnet_lite0|75.472 (24.528)|92.520 (7.480)|4.65|TBD|bicubic|224|0.875|
|mnasnet_a1|75.448 (24.552)|92.604 (7.396)|3.9|312|bicubic|224|0.875|
|fbnetc_100|75.124 (24.876)|92.386 (7.614)|5.6|385|bilinear|224|0.875|
|mobilenetv2_110d|75.052 (24.948)|92.180 (7.820)|4.5|TBD|bicubic|224|0.875|
|mnasnet_b1|74.658 (25.342)|92.114 (7.886)|4.4|315|bicubic|224|0.875|
|spnasnet_100|74.084 (25.916)|91.818 (8.182)|4.4|TBD|bilinear|224|0.875|
|mobilenetv2_100|72.978 (27.022)|91.016 (8.984)|3.5|TBD|bicubic|224|0.875|

More pretrained models to come...

Ported Weights

The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match the accuracy in Tensorflow once an equivalent of the SAME convolution padding is added, and the same crop factors, image scaling, etc. (see table) are used via command-line args.
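
As a rough illustration of what the SAME-padding equivalent involves (the repo's actual implementation lives in its conv layer helpers; this standalone sketch simply pads the input before a standard convolution, and the function name is mine):

import math
import torch
import torch.nn.functional as F

def conv2d_same(x, weight, bias=None, stride=(1, 1), dilation=(1, 1)):
    # TF 'SAME' padding: compute per-dim total padding so the output size is
    # ceil(input / stride); the split can be asymmetric (one extra pixel on
    # the bottom/right), which a static PyTorch Conv2d padding can't express.
    ih, iw = x.shape[-2:]
    kh, kw = weight.shape[-2:]
    pad_h = max((math.ceil(ih / stride[0]) - 1) * stride[0] + (kh - 1) * dilation[0] + 1 - ih, 0)
    pad_w = max((math.ceil(iw / stride[1]) - 1) * stride[1] + (kw - 1) * dilation[1] + 1 - iw, 0)
    x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
    return F.conv2d(x, weight, bias, stride, (0, 0), dilation)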

IMPORTANT: the crop percentage, image size, interpolation, and mean/std must match the table for the ported weights to produce the listed accuracies. For example:

To run validation for tf_efficientnet_b5:

python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --crop-pct 0.934 --interpolation bicubic

To run validation w/ TF preprocessing for tf_efficientnet_b5:

python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --tf-preprocessing

To run validation for a model with Inception preprocessing, i.e. EfficientNet-B8 AdvProp:

python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b8_ap -b 48 --num-gpu 2 --img-size 672 --crop-pct 0.954 --mean 0.5 --std 0.5

|Model|Prec@1 (Err)|Prec@5 (Err)|Param # (M)|Image Scaling|Image Size|Crop|
|---|---|---|---|---|---|---|
|tf_efficientnet_l2_ns *tfp|88.352 (11.648)|98.652 (1.348)|480|bicubic|800|N/A|
|tf_efficientnet_l2_ns|TBD|TBD|480|bicubic|800|0.961|
|tf_efficientnet_l2_ns_475|88.234 (11.766)|98.546 (1.454)|480|bicubic|475|0.936|
|tf_efficientnet_l2_ns_475 *tfp|88.172 (11.828)|98.566 (1.434)|480|bicubic|475|N/A|
|tf_efficientnet_b7_ns *tfp|86.844 (13.156)|98.084 (1.916)|66.35|bicubic|600|N/A|
|tf_efficientnet_b7_ns|86.840 (13.160)|98.094 (1.906)|66.35|bicubic|600|N/A|
|tf_efficientnet_b6_ns|86.452 (13.548)|97.882 (2.118)|43.04|bicubic|528|N/A|
|tf_efficientnet_b6_ns *tfp|86.444 (13.556)|97.880 (2.120)|43.04|bicubic|528|N/A|
|tf_efficientnet_b5_ns *tfp|86.064 (13.936)|97.746 (2.254)|30.39|bicubic|456|N/A|
|tf_efficientnet_b5_ns|86.088 (13.912)|97.752 (2.248)|30.39|bicubic|456|N/A|
|tf_efficientnet_b8_ap *tfp|85.436 (14.564)|97.272 (2.728)|87.4|bicubic|672|N/A|
|tf_efficientnet_b8 *tfp|85.384 (14.616)|97.394 (2.606)|87.4|bicubic|672|N/A|
|tf_efficientnet_b8|85.370 (14.630)|97.390 (2.610)|87.4|bicubic|672|0.954|
|tf_efficientnet_b8_ap|85.368 (14.632)|97.294 (2.706)|87.4|bicubic|672|0.954|
|tf_efficientnet_b4_ns *tfp|85.298 (14.702)|97.504 (2.496)|19.34|bicubic|380|N/A|
|tf_efficientnet_b4_ns|85.162 (14.838)|97.470 (2.530)|19.34|bicubic|380|0.922|
|tf_efficientnet_b7_ap *tfp|85.154 (14.846)|97.244 (2.756)|66.35|bicubic|600|N/A|
|tf_efficientnet_b7_ap|85.118 (14.882)|97.252 (2.748)|66.35|bicubic|600|0.949|
|tf_efficientnet_b7 *tfp|84.940 (15.060)|97.214 (2.786)|66.35|bicubic|600|N/A|
|tf_efficientnet_b7|84.932 (15.068)|97.208 (2.792)|66.35|bicubic|600|0.949|
|tf_efficientnet_b6_ap|84.786 (15.214)|97.138 (2.862)|43.04|bicubic|528|0.942|
|tf_efficientnet_b6_ap *tfp|84.760 (15.240)|97.124 (2.876)|43.04|bicubic|528|N/A|
|tf_efficientnet_b5_ap *tfp|84.276 (15.724)|96.932 (3.068)|30.39|bicubic|456|N/A|
|tf_efficientnet_b5_ap|84.254 (15.746)|96.976 (3.024)|30.39|bicubic|456|0.934|
|tf_efficientnet_b6 *tfp|84.140 (15.860)|96.852 (3.148)|43.04|bicubic|528|N/A|
|tf_efficientnet_b6|84.110 (15.890)|96.886 (3.114)|43.04|bicubic|528|0.942|
|tf_efficientnet_b3_ns *tfp|84.054 (15.946)|96.918 (3.082)|12.23|bicubic|300|N/A|
|tf_efficientnet_b3_ns|84.048 (15.952)|96.910 (3.090)|12.23|bicubic|300|0.904|
|tf_efficientnet_b5 *tfp|83.822 (16.178)|96.756 (3.244)|30.39|bicubic|456|N/A|
|tf_efficientnet_b5|83.812 (16.188)|96.748 (3.252)|30.39|bicubic|456|0.934|
|tf_efficientnet_b4_ap *tfp|83.278 (16.722)|96.376 (3.624)|19.34|bicubic|380|N/A|
|tf_efficientnet_b4_ap|83.248 (16.752)|96.388 (3.612)|19.34|bicubic|380|0.922|
|tf_efficientnet_b4|83.022 (16.978)|96.300 (3.700)|19.34|bicubic|380|0.922|
|tf_efficientnet_b4 *tfp|82.948 (17.052)|96.308 (3.692)|19.34|bicubic|380|N/A|
|tf_efficientnet_b2_ns *tfp|82.436 (17.564)|96.268 (3.732)|9.11|bicubic|260|N/A|
|tf_efficientnet_b2_ns|82.380 (17.620)|96.248 (3.752)|9.11|bicubic|260|0.89|
|tf_efficientnet_b3_ap *tfp|81.882 (18.118)|95.662 (4.338)|12.23|bicubic|300|N/A|
|tf_efficientnet_b3_ap|81.828 (18.172)|95.624 (4.376)|12.23|bicubic|300|0.904|
|tf_efficientnet_b3|81.636 (18.364)|95.718 (4.282)|12.23|bicubic|300|0.904|
|tf_efficientnet_b3 *tfp|81.576 (18.424)|95.662 (4.338)|12.23|bicubic|300|N/A|
|tf_efficientnet_lite4|81.528 (18.472)|95.668 (4.332)|13.00|bilinear|380|0.92|
|tf_efficientnet_b1_ns *tfp|81.514 (18.486)|95.776 (4.224)|7.79|bicubic|240|N/A|
|tf_efficientnet_lite4 *tfp|81.502 (18.498)|95.676 (4.324)|13.00|bilinear|380|N/A|
|tf_efficientnet_b1_ns|81.388 (18.612)|95.738 (4.262)|7.79|bicubic|240|0.88|
|tf_efficientnet_el|80.534 (19.466)|95.190 (4.810)|10.59|bicubic|300|0.904|
|tf_efficientnet_el *tfp|80.476 (19.524)|95.200 (4.800)|10.59|bicubic|300|N/A|
|tf_efficientnet_b2_ap *tfp|80.420 (19.580)|95.040 (4.960)|9.11|bicubic|260|N/A|
|tf_efficientnet_b2_ap|80.306 (19.694)|95.028 (4.972)|9.11|bicubic|260|0.890|
|tf_efficientnet_b2 *tfp|80.188 (19.812)|94.974 (5.026)|9.11|bicubic|260|N/A|
|tf_efficientnet_b2|80.086 (19.914)|94.908 (5.092)|9.11|bicubic|260|0.890|
|tf_efficientnet_lite3|79.812 (20.188)|94.914 (5.086)|8.20|bilinear|300|0.904|
|tf_efficientnet_lite3 *tfp|79.734 (20.266)|94.838 (5.162)|8.20|bilinear|300|N/A|
|tf_efficientnet_b1_ap *tfp|79.532 (20.468)|94.378 (5.622)|7.79|bicubic|240|N/A|
|tf_efficientnet_cc_b1_8e *tfp|79.464 (20.536)|94.492 (5.508)|39.7|bicubic|240|0.88|
|tf_efficientnet_cc_b1_8e|79.298 (20.702)|94.364 (5.636)|39.7|bicubic|240|0.88|
|tf_efficientnet_b1_ap|79.278 (20.722)|94.308 (5.692)|7.79|bicubic|240|0.88|
|tf_efficientnet_b1 *tfp|79.172 (20.828)|94.450 (5.550)|7.79|bicubic|240|N/A|
|tf_efficientnet_em *tfp|78.958 (21.042)|94.458 (5.542)|6.90|bicubic|240|N/A|
|tf_efficientnet_b0_ns *tfp|78.806 (21.194)|94.496 (5.504)|5.29|bicubic|224|N/A|
|tf_mixnet_l *tfp|78.846 (21.154)|94.212 (5.788)|7.33|bilinear|224|N/A|
|tf_efficientnet_b1|78.826 (21.174)|94.198 (5.802)|7.79|bicubic|240|0.88|
|tf_mixnet_l|78.770 (21.230)|94.004 (5.996)|7.33|bicubic|224|0.875|
|tf_efficientnet_em|78.742 (21.258)|94.332 (5.668)|6.90|bicubic|240|0.875|
|tf_efficientnet_b0_ns|78.658 (21.342)|94.376 (5.624)|5.29|bicubic|224|0.875|
|tf_efficientnet_cc_b0_8e *tfp|78.314 (21.686)|93.790 (6.210)|24.0|bicubic|224|0.875|
|tf_efficientnet_cc_b0_8e|77.908 (22.092)|93.656 (6.344)|24.0|bicubic|224|0.875|
|tf_efficientnet_cc_b0_4e *tfp|77.746 (22.254)|93.552 (6.448)|13.3|bicubic|224|0.875|
|tf_efficientnet_cc_b0_4e|77.304 (22.696)|93.332 (6.668)|13.3|bicubic|224|0.875|
|tf_efficientnet_es *tfp|77.616 (22.384)|93.750 (6.250)|5.44|bicubic|224|N/A|
|tf_efficientnet_lite2 *tfp|77.544 (22.456)|93.800 (6.200)|6.09|bilinear|260|N/A|
|tf_efficientnet_lite2|77.460 (22.540)|93.746 (6.254)|6.09|bicubic|260|0.89|
|tf_efficientnet_b0_ap *tfp|77.514 (22.486)|93.576 (6.424)|5.29|bicubic|224|N/A|
|tf_efficientnet_es|77.264 (22.736)|93.600 (6.400)|5.44|bicubic|224|N/A|
|tf_efficientnet_b0 *tfp|77.258 (22.742)|93.478 (6.522)|5.29|bicubic|224|N/A|
|tf_efficientnet_b0_ap|77.084 (22.916)|93.254 (6.746)|5.29|bicubic|224|0.875|
|tf_mixnet_m *tfp|77.072 (22.928)|93.368 (6.632)|5.01|bilinear|224|N/A|
|tf_mixnet_m|76.950 (23.050)|93.156 (6.844)|5.01|bicubic|224|0.875|
|tf_efficientnet_b0|76.848 (23.152)|93.228 (6.772)|5.29|bicubic|224|0.875|
|tf_efficientnet_lite1 *tfp|76.764 (23.236)|93.326 (6.674)|5.42|bilinear|240|N/A|
|tf_efficientnet_lite1|76.638 (23.362)|93.232 (6.768)|5.42|bicubic|240|0.882|
|tf_mixnet_s *tfp|75.800 (24.200)|92.788 (7.212)|4.13|bilinear|224|N/A|
|tf_mobilenetv3_large_100 *tfp|75.768 (24.232)|92.710 (7.290)|5.48|bilinear|224|N/A|
|tf_mixnet_s|75.648 (24.352)|92.636 (7.364)|4.13|bicubic|224|0.875|
|tf_mobilenetv3_large_100|75.516 (24.484)|92.600 (7.400)|5.48|bilinear|224|0.875|
|tf_efficientnet_lite0 *tfp|75.074 (24.926)|92.314 (7.686)|4.65|bilinear|224|N/A|
|tf_efficientnet_lite0|74.842 (25.158)|92.170 (7.830)|4.65|bicubic|224|0.875|
|tf_mobilenetv3_large_075 *tfp|73.730 (26.270)|91.616 (8.384)|3.99|bilinear|224|N/A|
|tf_mobilenetv3_large_075|73.442 (26.558)|91.352 (8.648)|3.99|bilinear|224|0.875|
|tf_mobilenetv3_large_minimal_100 *tfp|72.678 (27.322)|90.860 (9.140)|3.92|bilinear|224|N/A|
|tf_mobilenetv3_large_minimal_100|72.244 (27.756)|90.636 (9.364)|3.92|bilinear|224|0.875|
|tf_mobilenetv3_small_100 *tfp|67.918 (32.082)|87.958 (12.042)|2.54|bilinear|224|N/A|
|tf_mobilenetv3_small_100|67.918 (32.082)|87.662 (12.338)|2.54|bilinear|224|0.875|
|tf_mobilenetv3_small_075 *tfp|66.142 (33.858)|86.498 (13.502)|2.04|bilinear|224|N/A|
|tf_mobilenetv3_small_075|65.718 (34.282)|86.136 (13.864)|2.04|bilinear|224|0.875|
|tf_mobilenetv3_small_minimal_100 *tfp|63.378 (36.622)|84.802 (15.198)|2.04|bilinear|224|N/A|
|tf_mobilenetv3_small_minimal_100|62.898 (37.102)|84.230 (15.770)|2.04|bilinear|224|0.875|

*tfp models validated with tf-preprocessing pipeline

Google tf and tflite weights ported from official Tensorflow repositories

Usage

Environment

All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, 3.8.x.

Users have reported that a Python 3 Anaconda install in Windows works. I have not verified this myself.

PyTorch versions 1.4, 1.5, 1.6 have been tested with this code.

I've tried to keep the dependencies minimal; the setup is as per the PyTorch default install instructions for Conda:

conda create -n torch-env
conda activate torch-env
conda install -c pytorch pytorch torchvision cudatoolkit=10.2

PyTorch Hub

Models can be accessed via the PyTorch Hub API:

>>> import torch
>>> torch.hub.list('rwightman/gen-efficientnet-pytorch')
['efficientnet_b0', ...]
>>> model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True)
>>> model.eval()
>>> output = model(torch.randn(1, 3, 224, 224))
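
The output is a (1, 1000) tensor of ImageNet logits; turning it into top-5 predictions is plain PyTorch, nothing repo-specific:

>>> probs = torch.nn.functional.softmax(output, dim=1)
>>> top5_prob, top5_idx = torch.topk(probs, 5)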

Pip

This package can be installed via pip.

Install (after conda env/install):

pip install geffnet

Eval use:

>>> import geffnet
>>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()
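
A fuller eval sketch with preprocessing, assuming Pillow and torchvision are installed. The mean/std are the standard ImageNet values used by most of these models (AdvProp models use 0.5/0.5, per the validation examples above), the 224 size / 0.875 crop follows the table, and 'dog.jpg' is a placeholder:

>>> import torch
>>> import geffnet
>>> from PIL import Image
>>> from torchvision import transforms
>>> tfm = transforms.Compose([
...     transforms.Resize(256),  # 224 / 0.875 crop-pct; the table lists bicubic scaling
...     transforms.CenterCrop(224),
...     transforms.ToTensor(),
...     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
... ])
>>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True).eval()
>>> with torch.no_grad():
...     logits = m(tfm(Image.open('dog.jpg').convert('RGB')).unsqueeze(0))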

Train use:

>>> import geffnet
>>> # models can also be created by using the entrypoint directly
>>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2)
>>> m.train()
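
To fine-tune on a different number of classes, the head can be swapped before training; a sketch assuming the final linear layer is exposed as .classifier (as it is for the models in this package):

>>> import torch.nn as nn
>>> import geffnet
>>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2)
>>> m.classifier = nn.Linear(m.classifier.in_features, 10)  # e.g. 10 target classes
>>> m.train()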

Create in a nn.Sequential container, for fast.ai, etc:

>>> import geffnet
>>> m = geffnet.mixnet_l(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2, as_sequential=True)

Exporting

Scripts are included to export models to ONNX and to validate the exported models with ONNX runtime or Caffe2.

As an example, to export the MobileNet-V3 pretrained model and then run an Imagenet validation:

python onnx_export.py --model mobilenetv3_large_100 ./mobilenetv3_100.onnx
python onnx_validate.py /imagenet/validation/ --onnx-input ./mobilenetv3_100.onnx 
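
A minimal onnxruntime smoke test of the exported file (assumes the onnxruntime package is installed; the input name is queried from the graph rather than assumed):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('./mobilenetv3_100.onnx')
input_name = sess.get_inputs()[0].name  # query rather than hard-code the graph input
dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)
logits = sess.run(None, {input_name: dummy})[0]
print(logits.shape)  # expect (1, 1000) for an ImageNet classifier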

These scripts were tested and working as of PyTorch 1.6 and ONNX 1.7 w/ ONNX runtime 1.4. Caffe2 compatible export now requires additional args mentioned in the export script (not needed in earlier versions).

Export Notes

  1. The TF ported weights with the 'SAME' conv padding activated cannot be exported to ONNX unless the _EXPORTABLE flag in config.py is set to True. Use config.set_exportable(True) as in the onnx_export.py script (a sketch follows this list).
  2. TF ported models with 'SAME' padding will have the padding fixed at export time to the resolution used for export. Even though dynamic padding is supported in opset >= 11, I can't get it working.
  3. The ONNX optimize facility doesn't work reliably in PyTorch 1.6 / ONNX 1.7. Fortunately, onnxruntime-based inference is working very well now and includes on-the-fly optimization.
  4. ONNX / Caffe2 export/import frequently breaks with different PyTorch and ONNX version releases. Please check their respective issue trackers before filing issues here.
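
For note 1, a sketch of the exportable-flag flow (it mirrors what onnx_export.py does, but the model name, dummy resolution, and output path here are placeholders):

import torch
import geffnet

geffnet.config.set_exportable(True)  # must be set before creating a 'SAME'-padded TF port
model = geffnet.create_model('tf_efficientnet_b0', pretrained=True)
model.eval()
dummy = torch.randn(1, 3, 224, 224)  # per note 2, SAME padding is fixed at this resolution
torch.onnx.export(model, dummy, './tf_efficientnet_b0.onnx',
                  input_names=['input0'], output_names=['output0'])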