Flax Image Models

Contents: Introduction, Installation, Usage, Examples, Available Architectures, Contributing, Acknowledgements

Introduction

flaim is a library of state-of-the-art pre-trained vision models, plus common deep learning modules in computer vision, for Flax. It exposes a host of diverse image models through a straightforward interface with an emphasis on simplicity, leanness, and readability, and supplies lower-level modules for designing custom architectures.

Installation

flaim can be installed through pip install flaim. Beware that pip installs the CPU version of JAX; to run your programs on a GPU or TPU, you must install the appropriate JAX build yourself.

Usage

flaim.get_model is the central function of flaim and manages model retrieval. It takes a handful of arguments, including model_name, the name of the model; pretrained, the name of the pre-trained parameters (or False for no pre-trained parameters); and n_classes, the number of output classes.

flaim.get_model returns the model, its parameters, and, if pretrained is not False, the normalization statistics associated with the pre-trained parameters. The snippet below constructs an ImageNet1K-trained ResNet-50 with 10 output classes.

import flaim


model, vars, norm_stats = flaim.get_model(
    model_name='resnet50',
    pretrained='in1k_224',
    n_classes=10,
)

A forward pass with a flaim model is performed like with any other Flax model. However, networks that behave differently during training versus inference, e.g., due to batch normalization, receive a training argument indicating whether the model is in training mode. Furthermore, as with any other Flax module incorporating batch normalization, batch_stats must be passed to mutable to update batch normalization's running statistics during training.

from jax import numpy as jnp

# input should be normalized using norm_stats beforehand
input = jnp.ones((2, 224, 224, 3))

# Training
output, new_batch_stats = model.apply(vars, input, training=True, mutable=['batch_stats'])
# Inference
output = model.apply(vars, input, training=False, mutable=False)
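
When training, the updated running statistics returned by apply should be folded back into the variables before the next forward pass. A minimal sketch, assuming the variable collections behave like plain mappings:

# Merge the freshly returned batch statistics back into the variable
# collection so that later calls see the updated running averages.
vars = {**vars, **new_batch_stats}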

Finally, the model's intermediate activations can be captured by passing intermediates to mutable.

output, intermediates = model.apply(vars, input, training=False, mutable=['intermediates'])

If the model is hierarchical, the entries of intermediates are the outputs of each network stage and can be looked up through intermediates['intermediates']['stage_ind'], where ind is the index of the desired stage and 0 is reserved for the stem. For isotropic models, the output of every block is returned, accessible via intermediates['intermediates']['block_ind'], where ind is the index of the desired block and 0 is once again reserved for the stem.

It should be noted that Flax's sow API, which is used by flaim, appends the intermediate activations to a tuple; that is, if n forward passes are performed, intermediates['intermediates']['stage_ind'] or intermediates['intermediates']['block_ind'] would be tuples of length n, with the i<sup>th</sup> item corresponding to the i<sup>th</sup> forward pass.
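
For instance, for a hierarchical model such as the ResNet-50 above, the per-stage outputs from a single forward pass could be inspected as follows. This is only a sketch; the exact stage names and shapes depend on the architecture.

# Each entry is a tuple with one item per forward pass; index 0 picks
# the activations of the single pass performed above.
for name, activations in intermediates['intermediates'].items():
    print(name, activations[0].shape)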

Examples

examples/ includes a series of annotated notebooks that use flaim to solve various vision problems, such as object classification.

Available Architectures

All available architectures and their pre-trained parameters, plus short descriptions and references, are listed here.

flaim.list_models also returns a list of (model name, pre-trained parameter name) pairs, e.g., (resnet50, in1k_224), and accepts two arguments, model_pattern and params_pattern, which filter the results by model name and by pre-trained parameter name respectively.

This function is demonstrated below.

# Every model
print(flaim.list_models())

# ResNeXt-based networks of depth 50
print(flaim.list_models(model_pattern='resnext50'))

# Models trained on ImageNet22K
print(flaim.list_models(params_pattern='in22k'))

# ViTs of input size 384 x 384
print(flaim.list_models(model_pattern='^vit', params_pattern=384))
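
The returned pairs can be fed straight back into flaim.get_model. A minimal sketch, assuming each pair unpacks into a model name and the name of its pre-trained parameters:

# Pick the first ViT trained at an input size of 384 x 384 and load it.
name, params = flaim.list_models(model_pattern='^vit', params_pattern=384)[0]
model, vars, norm_stats = flaim.get_model(model_name=name, pretrained=params)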

Contributing

Code contributions are currently not accepted; however, there are three alternatives for those seeking to help flaim evolve:

Acknowledgements

Many thanks to Ross Wightman for the amazing timm package, which was an inspiration for flaim and has been an indispensable guide during development. Additionally, the pre-trained parameters are stored on Hugging Face Hub; big thanks to Hugging Face for offering this service gratis. Also thanks to Google's TPU Research Cloud (TRC) program for providing hardware used to accelerate the development of this project.

References for flaim.models can be found here, and those for flaim.layers are in the source code.