
Flops counting tool for neural networks in pytorch framework

This tool is designed to compute the theoretical amount of multiply-add operations in neural networks. It can also compute the number of parameters and print per-layer computational cost of a given network.

ptflops has two backends: pytorch and aten. The pytorch backend is the legacy one; it considers only nn.Modules. It is still useful, however, because it provides better per-layer analytics for CNNs. In all other cases the aten backend is recommended: it counts aten operations and therefore covers more model architectures (including transformers). The default backend is aten. Please don't use the pytorch backend for transformer architectures.

aten backend

Operations considered:

Usage tips

pytorch backend

Supported layers:

Experimental support:

Usage tips

Requirements

PyTorch >= 2.0. To work with torch 1.x, use pip install ptflops==0.7.2.2.

Install the latest version

From PyPI:

pip install ptflops

From this repository:

pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git

Example

import torchvision.models as models
import torch
from ptflops import get_model_complexity_info

with torch.cuda.device(0):
  net = models.densenet161()
  macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
                                           backend='pytorch',
                                           print_per_layer_stat=True, verbose=True)
  print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
  print('{:<30}  {:<8}'.format('Number of parameters: ', params))

  macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
                                           backend='aten',
                                           print_per_layer_stat=True, verbose=True)
  print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
  print('{:<30}  {:<8}'.format('Number of parameters: ', params))

Citation

If ptflops was useful for your paper or tech report, please cite me:

@online{ptflops,
  author = {Vladislav Sovrasov},
  title = {ptflops: a flops counting tool for neural networks in pytorch framework},
  year = 2018-2024,
  url = {https://github.com/sovrasov/flops-counter.pytorch},
}

Credits

Thanks to @warmspringwinds and Horace He for the initial version of the script.

Benchmark

torchvision

| Model | Input Resolution | Params(M) | MACs(G) (pytorch) | MACs(G) (aten) |
|---|---|---|---|---|
| alexnet | 224x224 | 61.10 | 0.72 | 0.71 |
| convnext_base | 224x224 | 88.59 | 15.43 | 15.38 |
| densenet121 | 224x224 | 7.98 | 2.90 | |
| efficientnet_b0 | 224x224 | 5.29 | 0.41 | |
| efficientnet_v2_m | 224x224 | 54.14 | 5.43 | |
| googlenet | 224x224 | 13.00 | 1.51 | |
| inception_v3 | 224x224 | 27.16 | 5.75 | 5.71 |
| maxvit_t | 224x224 | 30.92 | 5.48 | |
| mnasnet1_0 | 224x224 | 4.38 | 0.33 | |
| mobilenet_v2 | 224x224 | 3.50 | 0.32 | |
| mobilenet_v3_large | 224x224 | 5.48 | 0.23 | |
| regnet_y_1_6gf | 224x224 | 11.20 | 1.65 | |
| resnet18 | 224x224 | 11.69 | 1.83 | 1.81 |
| resnet50 | 224x224 | 25.56 | 4.13 | 4.09 |
| resnext50_32x4d | 224x224 | 25.03 | 4.29 | |
| shufflenet_v2_x1_0 | 224x224 | 2.28 | 0.15 | |
| squeezenet1_0 | 224x224 | 1.25 | 0.84 | 0.82 |
| vgg16 | 224x224 | 138.36 | 15.52 | 15.48 |
| vit_b_16 | 224x224 | 86.57 | 17.61 (wrong) | 16.86 |
| wide_resnet50_2 | 224x224 | 68.88 | 11.45 | |

timm

| Model | Input Resolution | Params(M) | MACs(G) |
|---|---|---|---|