

BiT-jax2tf

This repository hosts the code to port the NumPy weights of BiT-ResNet models [1] to the TensorFlow SavedModel format. The models were produced in [2], and the original weights come from [3].
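The following is a minimal sketch of what porting NumPy weights to a SavedModel involves. It is not the notebook's actual code. It wraps a single convolution kernel in a `tf.Module`, exports it with `tf.saved_model.save`, and checks that the reloaded model reproduces the original outputs. The dictionary keys, the shapes, and the `PortedConv` class are hypothetical. The sketch also assumes the checkpoint stores kernels in (height, width, in_channels, out_channels) order, which is the layout `tf.nn.conv2d` expects, so no transposition is needed.

```python
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for a checkpoint loaded with np.load(...):
# one conv kernel in (height, width, in_channels, out_channels) layout plus a bias.
rng = np.random.default_rng(0)
numpy_weights = {
    "conv/kernel": rng.standard_normal((3, 3, 3, 8)).astype(np.float32),
    "conv/bias": np.zeros((8,), dtype=np.float32),
}


class PortedConv(tf.Module):
    """Hypothetical module that holds NumPy weights as TF variables."""

    def __init__(self, weights):
        super().__init__()
        # Copy the arrays as-is; the HWIO layout already matches tf.nn.conv2d.
        self.kernel = tf.Variable(weights["conv/kernel"], trainable=False)
        self.bias = tf.Variable(weights["conv/bias"], trainable=False)

    @tf.function(input_signature=[tf.TensorSpec([None, None, None, 3], tf.float32)])
    def __call__(self, images):
        y = tf.nn.conv2d(images, self.kernel, strides=1, padding="SAME")
        return tf.nn.bias_add(y, self.bias)


module = PortedConv(numpy_weights)

# Export to the SavedModel format and load it back.
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)
reloaded = tf.saved_model.load(export_dir)

# The reloaded SavedModel should reproduce the in-memory module's outputs.
x = rng.standard_normal((1, 16, 16, 3)).astype(np.float32)
reloaded_out = reloaded(x).numpy()
max_abs_diff = float(np.max(np.abs(module(x).numpy() - reloaded_out)))
```

A full BiT-ResNet port would repeat this name-and-shape mapping for every layer, including the parameters of the weight-standardized convolutions and group normalization layers that BiT uses.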

Huge thanks to Willi Gierke (of Google) for helping with the porting.

The TensorFlow SavedModels are available as a collection on TensorFlow Hub: https://tfhub.dev/sayakpaul/collections/bit-resnet/1. There are 8 models in total: each of the four configurations below is provided both as a classifier and as a feature extractor.

| Model Name | Input Resolution | Classifier | Feature Extractor |
|---|---|---|---|
| BiT-ResNet152x2 | 384 | Link | Link |
| BiT-ResNet152x2 | 224 | Link | Link |
| BiT-ResNet50x1 | 224 | Link | Link |
| BiT-ResNet50x1 | 160 | Link | Link |

You can use the convert_jax_weights_tf.ipynb notebook to understand how the model porting between JAX and TensorFlow works. The JAX team also provides an experimental conversion tool, jax2tf, which lives in the JAX repository as `jax.experimental.jax2tf`.

References

[1] Big Transfer (BiT): General Visual Representation Learning by Kolesnikov et al.

[2] Knowledge distillation: A good teacher is patient and consistent by Beyer et al.

[3] BiT GitHub repository