<img src="https://raw.githubusercontent.com/LukasHedegaard/continual-inference/main/figures/logo/logo_name.svg" style="width: 400px;">

A Python library for Continual Inference Networks in PyTorch

Quick-start • Docs • Principles • Paper • Examples • Modules • Model Zoo • Contribute • License

<div> <a href="https://pypi.org/project/continual-inference/" style="display:inline-block;"> <img src="https://img.shields.io/pypi/pyversions/continual-inference" height="20" > </a> <a href="https://badge.fury.io/py/continual-inference" style="display:inline-block;"> <img src="https://badge.fury.io/py/continual-inference.svg" height="20" > </a> <a href="https://continual-inference.readthedocs.io/en/latest/generated/README.html" style="display:inline-block;"> <img src="https://readthedocs.org/projects/continual-inference/badge/?version=latest" alt="Documentation Status" height="20"/> </a> <a href="https://pepy.tech/project/continual-inference" style="display:inline-block;"> <img src="https://static.pepy.tech/badge/continual-inference" height="20"> </a> <a href="https://codecov.io/gh/LukasHedegaard/continual-inference" style="display:inline-block;"> <img src="https://codecov.io/gh/LukasHedegaard/continual-inference/branch/main/graph/badge.svg?token=XW1UQZSEOG" height="20"/> </a> <a href="https://opensource.org/licenses/Apache-2.0" style="display:inline-block;"> <img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" height="20"> </a> <!-- <a href="https://arxiv.org/abs/2204.03418" style="display:inline-block;"> <img src="http://img.shields.io/badge/paper-arxiv.2204.03418-B31B1B.svg" height="20" > </a> --> <a href="https://github.com/psf/black" style="display:inline-block;"> <img src="https://img.shields.io/badge/code%20style-black-000000.svg" height="20"> </a> <a href="https://www.codefactor.io/repository/github/lukashedegaard/continual-inference/overview/main" style="display:inline-block;"> <img src="https://www.codefactor.io/repository/github/lukashedegaard/continual-inference/badge/main" alt="We match PyTorch interfaces exactly. Method arguments named 'input' reduce the codefactor to 'A-'" height="20" /> </a> </div>

Continual Inference Networks ensure efficient stream processing

Many of our favorite Deep Neural Network architectures (e.g., CNNs and Transformers) were built for offline processing. Rather than processing inputs one sequence element at a time, they require the whole (spatio-)temporal sequence to be passed as a single input. Yet, many important real-life applications need online predictions on a continual input stream. While CNNs and Transformers can be applied by re-assembling and passing sequences within a sliding window, this is inefficient due to the redundant intermediary computations from overlapping clips.

Continual Inference Networks (CINs) are built to ensure efficient stream processing by employing an alternative computational ordering, which allows sequential computations without the use of sliding window processing. In general, CINs require approximately L× fewer FLOPs per prediction than sliding-window inference with non-CINs, where L is the sequence length that the corresponding non-CIN network processes per prediction. For more details, check out the videos below describing Continual 3D CNNs [1] and Transformers [2].

<div align="center"> <a href="http://www.youtube.com/watch?feature=player_embedded&v=Jm2A7dVEaF4" target="_blank"> <img src="http://img.youtube.com/vi/Jm2A7dVEaF4/hqdefault.jpg" alt="Presentation of Continual 3D CNNs" style="width:240px;height:auto;" /> </a> <a href="http://www.youtube.com/watch?feature=player_embedded&v=gy802Tlp-eQ" target="_blank"> <img src="http://img.youtube.com/vi/gy802Tlp-eQ/hqdefault.jpg" alt="Presentation of Continual Transformers" style="width:240px;height:auto;" /> </a> </div>
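The redundancy argument can be illustrated without the library. The following is a minimal sketch in plain Python, assuming a single temporal convolution with no padding; `conv1d_window` and `ToyContinualConv1d` are hypothetical helpers written for this illustration and are not part of `continual`. For this single layer, the sliding-window baseline spends `L - k + 1` times more multiply-adds per prediction than the step-wise version, which approaches L× when the kernel size `k` is small relative to the window length `L`.

```python
from collections import deque


def conv1d_window(window, kernel):
    """Regular temporal convolution over a whole window (no padding)."""
    k = len(kernel)
    return [
        sum(w * x for w, x in zip(kernel, window[t:t + k]))
        for t in range(len(window) - k + 1)
    ]


class ToyContinualConv1d:
    """Hypothetical step-wise counterpart that caches the last k inputs."""

    def __init__(self, kernel):
        self.kernel = list(kernel)
        self.buffer = deque(maxlen=len(kernel))

    def forward_step(self, x):
        self.buffer.append(x)
        if len(self.buffer) < len(self.kernel):
            return None  # Not enough inputs seen yet
        return sum(w * v for w, v in zip(self.kernel, self.buffer))


kernel = [1, 2, 3]
stream = list(range(20))
L = 8  # Window length used by the sliding-window baseline

# Sliding window: re-run the full window for every new frame, keep the newest output
sliding = [
    conv1d_window(stream[t - L + 1:t + 1], kernel)[-1]
    for t in range(L - 1, len(stream))
]
macs_sliding = (L - len(kernel) + 1) * len(kernel)  # Multiply-adds per prediction

# Continual: one step per new frame
conv = ToyContinualConv1d(kernel)
continual = [conv.forward_step(x) for x in stream]
macs_continual = len(kernel)  # Multiply-adds per prediction

assert continual[L - 1:] == sliding  # Same predictions
reduction = macs_sliding / macs_continual  # 6.0, i.e. L - k + 1
```

The sketch caches raw inputs for simplicity; how the library modules organise their internal state may differ.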

Quick-start

Install

pip install continual-inference

Example

co modules are weight-compatible drop-in replacements for torch.nn modules, enhanced with the capability of efficient continual inference:

import torch
import continual as co
                                                           
#                      B, C, T, H, W
example = torch.randn((1, 1, 5, 3, 3))

conv = co.Conv3d(in_channels=1, out_channels=1, kernel_size=(3, 3, 3))

# Same exact computation as torch.nn.Conv3d ✅
output = conv(example)

# But can also perform online inference efficiently 🚀
firsts = conv.forward_steps(example[:, :, :4])
last = conv.forward_step(example[:, :, 4])

assert torch.allclose(output[:, :, : conv.delay], firsts)
assert torch.allclose(output[:, :, conv.delay], last)

# Temporal properties
assert conv.receptive_field == 3
assert conv.delay == 2

See the network composition and model zoo sections for additional examples.

Library principles

Forward modes

The library components feature three distinct forward modes, which are handy for different situations, namely forward, forward_step, and forward_steps:

forward(input)

Performs a forward computation over multiple time-steps. This function is identical to the corresponding module in torch.nn, ensuring cross-compatibility. Moreover, it's handy for efficient training on clip-based data.

         O            (O: output)
         ↑ 
         N            (N: network module)
         ↑ 
 -----------------    (-: aggregation)
 P   I   I   I   P    (I: input frame, P: padding)

forward_step(input, update_state=True)

Performs a forward computation for a single frame and (optionally) updates internal states accordingly. This function performs efficient continual inference.

O+S O+S O+S O+S   (O: output, S: updated internal state)
 ↑   ↑   ↑   ↑ 
 N   N   N   N    (N: network module)
 ↑   ↑   ↑   ↑ 
 I   I   I   I    (I: input frame)

forward_steps(input, pad_end=False, update_state=True)

Performs a forward computation across multiple time-steps while updating internal states for continual inference (if update_state=True). Start-padding is always accounted for, but end-padding is omitted by default in anticipation of the next input step. It can be added by specifying pad_end=True. If so, the output-input mapping is exactly the same as that of forward.

         O            (O: output)
         ↑ 
 -----------------    (-: aggregation)
 O  O+S O+S O+S  O    (O: output, S: updated internal state)
 ↑   ↑   ↑   ↑   ↑
 N   N   N   N   N    (N: network module)
 ↑   ↑   ↑   ↑   ↑
 P   I   I   I   P    (I: input frame, P: padding)
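The relationship between the three modes can be sketched with a toy module in plain Python. This is an illustration under stated assumptions, not the library's implementation: `ToyMovingSum` is a hypothetical temporal moving sum with zero padding, and the way it restores state after end-padding is a choice made for this sketch.

```python
class ToyMovingSum:
    """Hypothetical module: temporal moving sum with zero padding."""

    def __init__(self, kernel_size=3, padding=1):
        self.kernel_size = kernel_size
        self.padding = padding
        self.clean_state()

    def clean_state(self):
        # Start-padding is part of the initial state
        self.state = [0] * self.padding

    def forward(self, xs):
        # Clip-based computation with padding at both ends
        padded = [0] * self.padding + list(xs) + [0] * self.padding
        k = self.kernel_size
        return [sum(padded[t:t + k]) for t in range(len(padded) - k + 1)]

    def forward_step(self, x, update_state=True):
        window = (self.state + [x])[-self.kernel_size:]
        if update_state:
            self.state = window
        return sum(window) if len(window) == self.kernel_size else None

    def forward_steps(self, xs, pad_end=False, update_state=True):
        saved = list(self.state)
        outs = [self.forward_step(x) for x in xs]
        resume = list(self.state)  # State after the real inputs only
        if pad_end:
            outs += [self.forward_step(0) for _ in range(self.padding)]
        # Assumption of this sketch: end-padding never leaks into the kept state
        self.state = resume if update_state else saved
        return [o for o in outs if o is not None]


xs = [1, 2, 3, 4]
net = ToyMovingSum(kernel_size=3, padding=1)
full = net.forward(xs)  # [3, 6, 9, 7]

net.clean_state()
assert net.forward_steps(xs, pad_end=True) == full  # Same mapping as forward

net.clean_state()
assert net.forward_steps(xs) == full[:-1]  # Last output awaits the next input
assert net.forward_step(5) == 12           # Stream continues: 3 + 4 + 5
```

Without `pad_end`, the final output is withheld because it depends on a frame that has not arrived yet; once the next frame is supplied via `forward_step`, the stream continues as if all frames had been passed in one clip.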

__call__

By default, the __call__ function operates identically to torch.nn and executes forward. We supply two options for changing this behavior: the call_mode property and the call_mode context manager. An example of their use follows:

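import torch
import continual as co

batch, channel, time = 1, 3, 8  # Illustrative sizes
net = co.Conv1d(channel, channel, kernel_size=3)  # Any co module could be used here
timeseries = torch.randn(batch, channel, time)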
timestep = timeseries[:, :, 0]

net(timeseries)  # Invokes net.forward(timeseries)

# Assign permanent call_mode property
net.call_mode = "forward_step"
net(timestep)  # Invokes net.forward_step(timestep)

# Assign temporary call_mode with context manager
with co.call_mode("forward_steps"):
    net(timeseries)  # Invokes net.forward_steps(timeseries)

net(timestep)  # Invokes net.forward_step(timestep) again

Composition

Continual Inference Networks require strict handling of internal data delays to guarantee correspondence between forward modes. While it is possible to compose neural networks by defining forward, forward_step, and forward_steps manually, correct handling of delays is cumbersome and time-consuming. Instead, we provide a rich interface of container modules, which handle delays automatically. On top of co.Sequential (which is a drop-in replacement for torch.nn.Sequential), we provide modules for handling parallel and conditional dataflow.

Composition examples:

<details> <summary><b>Residual module</b></summary>

Short-hand:

residual = co.Residual(co.Conv3d(32, 32, kernel_size=3, padding=1))

Explicit:

residual = co.Sequential(
    co.Broadcast(2),
    co.Parallel(
        co.Conv3d(32, 32, kernel_size=3, padding=1),
        co.Delay(2),
    ),
    co.Reduce("sum"),
)
</details> <details> <summary><b>3D MobileNetV2 Inverted residual block</b></summary>

Continual 3D version of the MobileNetV2 Inverted residual block.

<div align="center"> <img src="https://raw.githubusercontent.com/LukasHedegaard/continual-inference/main/figures/examples/mb_conv.png" style="width: 15vw; min-width: 200px;"> <br> MobileNetV2 Inverted residual block. Source: https://arxiv.org/pdf/1801.04381.pdf </div>
mb_conv = co.Residual(
    co.Sequential(
        co.Conv3d(32, 64, kernel_size=(1, 1, 1)),
        nn.BatchNorm3d(64),
        nn.ReLU6(),
        co.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1), groups=64),
        nn.ReLU6(),
        co.Conv3d(64, 32, kernel_size=(1, 1, 1)),
        nn.BatchNorm3d(32),
    )
)
</details> <details> <summary><b>3D Squeeze-and-Excitation module</b></summary>

Continual 3D version of the Squeeze-and-Excitation module

<div align="center"> <img src="https://raw.githubusercontent.com/LukasHedegaard/continual-inference/main/figures/examples/se_block.png" style="width: 15vw; min-width: 200px;"> <br> Squeeze-and-Excitation block. Scale refers to a broadcasted element-wise multiplication. Adapted from: https://arxiv.org/pdf/1709.01507.pdf </div>
se = co.Residual(
    co.Sequential(
        OrderedDict([
            ("pool", co.AdaptiveAvgPool3d((1, 1, 1), kernel_size=7)),
            ("down", co.Conv3d(256, 16, kernel_size=1)),
            ("act1", nn.ReLU()),
            ("up", co.Conv3d(16, 256, kernel_size=1)),
            ("act2", nn.Sigmoid()),
        ])
    ),
    reduce="mul",
)
</details> <details> <summary><b>3D Inception module</b></summary>

Continual 3D version of the Inception module:

<div align="center"> <img src="https://raw.githubusercontent.com/LukasHedegaard/continual-inference/main/figures/examples/inception_block.png" style="width: 25vw; min-width: 350px;"> <br> Inception module. Source: https://arxiv.org/pdf/1409.4842v1.pdf </div>
def norm_relu(module, channels):
    return co.Sequential(
        module,
        nn.BatchNorm3d(channels),
        nn.ReLU(),
    )

inception_module = co.BroadcastReduce(
    co.Conv3d(192, 64, kernel_size=1),
    co.Sequential(
        norm_relu(co.Conv3d(192, 96, kernel_size=1), 96),
        norm_relu(co.Conv3d(96, 128, kernel_size=3, padding=1), 128),
    ),
    co.Sequential(
        norm_relu(co.Conv3d(192, 16, kernel_size=1), 16),
        norm_relu(co.Conv3d(16, 32, kernel_size=5, padding=2), 32),
    ),
    co.Sequential(
        co.MaxPool3d(kernel_size=(1, 3, 3), padding=(0, 1, 1), stride=1),
        norm_relu(co.Conv3d(192, 32, kernel_size=1), 32),
    ),
    reduce="concat",
)
</details>
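To see why the explicit residual composition pairs the main branch with a delay, consider a toy step-wise model in plain Python. It is a sketch under assumptions: `ToyConv` is a hypothetical centered moving sum with kernel size 5 and padding 2 (so its step-wise output lags the input by 2 steps), and `ToyDelay` is a hypothetical FIFO. Neither is a library module, and the numbers are chosen for the illustration only.

```python
from collections import deque


class ToyDelay:
    """Hypothetical FIFO delay: returns the input from `steps` steps ago."""

    def __init__(self, steps):
        self.steps = steps
        self.queue = deque()

    def forward_step(self, x):
        self.queue.append(x)
        if len(self.queue) <= self.steps:
            return None  # Nothing old enough to emit yet
        return self.queue.popleft()


class ToyConv:
    """Hypothetical moving sum; start-padding is pre-loaded as zeros."""

    def __init__(self, kernel_size, padding):
        self.kernel_size = kernel_size
        self.window = deque([0] * padding, maxlen=kernel_size)

    def forward_step(self, x):
        self.window.append(x)
        full = len(self.window) == self.kernel_size
        return sum(self.window) if full else None


def residual_forward(xs, kernel_size=5, padding=2):
    """Clip-based reference: x[t] plus the centered moving sum around t."""
    padded = [0] * padding + list(xs) + [0] * padding
    return [xs[t] + sum(padded[t:t + kernel_size]) for t in range(len(xs))]


def residual_steps(xs, delay_steps):
    conv, delay = ToyConv(kernel_size=5, padding=2), ToyDelay(delay_steps)
    outs = []
    for x in xs:
        main, skip = conv.forward_step(x), delay.forward_step(x)
        if main is not None and skip is not None:
            outs.append(main + skip)
    return outs


xs = [1, 2, 3, 4, 5, 6]
reference = residual_forward(xs)  # [7, 12, 18, 24, 23, 21]

# A skip branch delayed by the main branch's delay reproduces the clip-based result
assert residual_steps(xs, delay_steps=2) == reference[:4]

# Without the delay, each sum mixes features from different time steps
assert residual_steps(xs, delay_steps=0) != reference[:4]
```

The last two reference values are missing from the step-wise output because they depend on end-padding that a live stream has not reached yet.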

Input shapes

We enforce a unified ordering of input dimensions for all library modules, namely:

(batch, channel, time, optional_dim2, optional_dim3)

Outputs

The outputs produced by forward_step and forward_steps are identical to those of forward, provided the same data was input beforehand and state update was enabled. Note that input and output shapes are not necessarily the same when using forward in PyTorch; they generally depend on the padding, stride, and receptive field of a module.

For the forward_step function, this shows up as None-valued outputs. Specifically, modules with a delay (i.e., with receptive fields larger than the padding + 1) will produce None until the input count exceeds the delay. Moreover, modules with stride > 1 will produce Tensor outputs every stride steps and None on the remaining steps. A visual example is shown below:

<div align="center"> <img src="https://raw.githubusercontent.com/LukasHedegaard/continual-inference/main/figures/continual/continual-stride.png" style="width:300px;height:auto;"/> </br> A mixed example of delay and outputs under padding and stride. Here, we illustrate the step-wise operation of two co module layers, l1 with receptive_field = 3, padding = 2, and stride = 2 and l2 with receptive_field = 3, no padding and stride = 1. ⧇ denotes a padded zero, ■ is a non-zero step-feature, and ☒ is an empty output. </div>

For more information, please see the library paper.
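The pattern of empty outputs for the two layers in the figure can be reproduced with a toy model in plain Python. This is a sketch, assuming each layer is a simple moving sum that skips empty inputs; `ToyLayer` is a hypothetical class for illustration and not a library module.

```python
from collections import deque


class ToyLayer:
    """Hypothetical step-wise layer: a moving sum with padding and stride."""

    def __init__(self, receptive_field, padding=0, stride=1):
        self.receptive_field = receptive_field
        self.stride = stride
        # Start-padding is pre-loaded into the window as zeros
        self.window = deque([0] * padding, maxlen=receptive_field)
        self.full_steps = 0  # Steps seen with a completely filled window

    def forward_step(self, x):
        if x is None:
            return None  # Empty input: nothing to compute, state is untouched
        self.window.append(x)
        if len(self.window) < self.receptive_field:
            return None  # Still within the delay
        emit = self.full_steps % self.stride == 0
        self.full_steps += 1
        return sum(self.window) if emit else None


l1 = ToyLayer(receptive_field=3, padding=2, stride=2)
l2 = ToyLayer(receptive_field=3, padding=0, stride=1)

out1, out2 = [], []
for x in [1] * 8:
    y = l1.forward_step(x)
    out1.append(y)
    out2.append(l2.forward_step(y))

print(out1)  # [1, None, 3, None, 3, None, 3, None]
print(out2)  # [None, None, None, None, 7, None, 9, None]
```

In this toy, l1 has no delay (receptive field 3 equals padding + 1) but emits only every second step because of its stride, while l2 has to collect three non-empty inputs before its first output.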

Handling state

During stream processing, network modules which operate over multiple time-steps, e.g., a convolution with kernel_size > 1 in the temporal dimension, will aggregate and cache state internally. Each module has its own local state, which can be inspected using module.get_state(). During forward_step and forward_steps, the state is updated unless the function is invoked with update_state=False.

A state cleanup can be accomplished via module.clean_state().
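The state life cycle described in this section can be sketched with a small stand-in class in plain Python. `ToyStatefulSum` is hypothetical and only borrows the method names `get_state`, `clean_state`, and the `update_state` argument from the text above; the format of its state is an assumption of the sketch and says nothing about what the library modules store.

```python
class ToyStatefulSum:
    """Hypothetical module caching the two previous inputs (kernel size 3)."""

    def __init__(self):
        self.clean_state()

    def get_state(self):
        return tuple(self._state)

    def clean_state(self):
        self._state = []

    def forward_step(self, x, update_state=True):
        window = self._state + [x]
        if update_state:
            self._state = window[-2:]  # Keep only what the next step needs
        return sum(window) if len(window) == 3 else None


net = ToyStatefulSum()
net.forward_step(1)
net.forward_step(2)
assert net.get_state() == (1, 2)

# Peek at a prediction without committing the input to the stream
assert net.forward_step(10, update_state=False) == 13
assert net.get_state() == (1, 2)

# A regular step updates the cached state
assert net.forward_step(3) == 6
assert net.get_state() == (2, 3)

# Cleaning the state starts a fresh stream
net.clean_state()
assert net.get_state() == ()
assert net.forward_step(4) is None
```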

Module library

Continual Inference features a rich collection of modules for defining Continual Inference Networks. Specific care was taken to create CIN versions of the PyTorch modules found in torch.nn:

<details> <summary><b>Convolutions</b></summary> </details> <details> <summary><b>Pooling</b></summary> </details> <details> <summary><b>Linear</b></summary> </details> <details> <summary><b>Recurrent</b></summary> </details> <details> <summary><b>Transformers</b></summary> </details>

In addition, we provide modules for composing and converting networks. Both composition and utility modules can also be used when defining regular PyTorch modules.

<details> <summary><b>Composition modules</b></summary> </details> <details> <summary><b>Utility modules</b></summary> </details> <details> <summary><b>Converters</b></summary> </details>

We support drop-in interoperability with the following torch.nn modules:

<details> <summary><b>Activation</b></summary> </details> <details> <summary><b>Normalization</b></summary> </details> <details> <summary><b>Dropout</b></summary> </details>

Model Zoo and Benchmarks

Continual 3D CNNs

Benchmark results for 1-view testing on Kinetics400. For reference, X3D-L scores 69.3% top-1 acc with 19.2 GFLOPs per prediction.

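| Arch | Avg. pool size | Top 1 (%) | FLOPs (G) per step | FLOPs reduction | Params (M) | Code | Weights |
| --- | --- | --- | --- | --- | --- | --- | --- |
| CoX3D-L | 64 | 71.6 | 1.25 | 15.3x | 6.2 | link | link |
| CoX3D-M | 64 | 71.0 | 0.33 | 15.1x | 3.8 | link | link |
| CoX3D-S | 64 | 64.7 | 0.17 | 12.1x | 3.8 | link | link |
| CoSlow | 64 | 73.1 | 6.90 | 8.0x | 32.5 | link | link |
| CoI3D | 64 | 64.0 | 5.68 | 5.0x | 28.0 | link | link |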

FLOPs reduction is noted relative to non-continual inference. Note that on-hardware inference doesn't reach the same speedups as "FLOPs reductions" might suggest due to overhead of state reads and writes. This overhead is less important for large batch sizes. This applies to all models in the model zoo.
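Since the reduction is relative to non-continual inference, multiplying the per-step cost by the reduction factor gives a rough estimate of the per-prediction cost of the corresponding regular model. The short calculation below uses only the values from the table above; because those values are rounded, the results are approximate.

```python
# (FLOPs (G) per step, FLOPs reduction) copied from the Continual 3D CNN table
models = {
    "CoX3D-L": (1.25, 15.3),
    "CoX3D-M": (0.33, 15.1),
    "CoX3D-S": (0.17, 12.1),
    "CoSlow": (6.90, 8.0),
    "CoI3D": (5.68, 5.0),
}

# Implied GFLOPs per prediction of the non-continual counterpart
implied = {
    name: round(per_step * reduction, 1)
    for name, (per_step, reduction) in models.items()
}

for name, gflops in implied.items():
    print(f"{name}: ~{gflops} GFLOPs per prediction without continual inference")
```

For CoX3D-L this yields about 19.1 GFLOPs, in line with the 19.2 GFLOPs stated for X3D-L up to rounding of the tabulated values.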

Continual ST-GCNs

Benchmark results on NTU RGB+D 60 for the joint modality. For reference, ST-GCN achieves 86% X-Sub and 93.4% X-View accuracy with 16.73 GFLOPs per prediction.

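| Arch | Receptive field | X-Sub Acc (%) | X-View Acc (%) | FLOPs (G) per step | FLOPs reduction | Params (M) | Code |
| --- | --- | --- | --- | --- | --- | --- | --- |
| CoST-GCN | 300 | 86.3 | 93.8 | 0.16 | 107.7x | 3.1 | link |
| CoA-GCN | 300 | 84.1 | 92.6 | 0.17 | 108.7x | 3.5 | link |
| CoST-GCN | 300 | 86.3 | 92.4 | 0.15 | 107.6x | 3.1 | link |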

Here, you can download pre-trained model weights for the above architectures on NTU RGB+D 60, NTU RGB+D 120, and Kinetics-400 on joint and bone modalities.

Continual Transformers

Benchmark results on THUMOS14 on top of features extracted using a TSN-ResNet50 backbone pre-trained on Kinetics400. For reference, OadTR achieves 64.4% mAP with 2.5 GFLOPs per prediction.

| Arch | Receptive field | mAP (%) | FLOPs (G) per step | Params (M) | Code |
| --- | --- | --- | --- | --- | --- |
| CoOadTR-b1 | 64 | 64.2 | 0.41 | 15.9 | link |
| CoOadTR-b2 | 64 | 64.4 | 0.01 | 9.6 | link |

The library features complete implementations of the one- and two-block continual transformer encoders as well.

Compatibility

The library modules are built to integrate seamlessly with other PyTorch projects. Specifically, extra care was taken to ensure out-of-the-box compatibility with:

<!-- - [onnxruntime](https://github.com/microsoft/onnxruntime) -->

Citation

<a href="https://arxiv.org/abs/2204.03418" style="display:inline-block;"> <img src="http://img.shields.io/badge/paper-arxiv.2204.03418-B31B1B.svg" height="20" > </a>
@inproceedings{hedegaard2022colib,
  title={Continual Inference: A Library for Efficient Online Inference with Deep Neural Networks in PyTorch},
  author={Lukas Hedegaard and Alexandros Iosifidis},
  booktitle={European Conference on Computer Vision Workshops (ECCVW)},
  year={2022}
}

Acknowledgement

This work has received funding from the European Union’s Horizon 2020 research and innovation programme under grant agreement No 871449 (OpenDR).