BiT-jax2tf
This repository hosts the code to port the NumPy model weights of BiT-ResNets [1] to the TensorFlow SavedModel format. These models are the result of [2]. The original model weights come from [3].
Huge thanks to Willi Gierke (of Google) for helping with the porting.
The TensorFlow SavedModels are available on TensorFlow Hub as a collection: https://tfhub.dev/sayakpaul/collections/bit-resnet/1. A total of 8 models are available, a classifier and a feature extractor for each of the four configurations below:
Model<br>Name | Input<br>Resolution | Classifier | Feature<br>Extractor |
---|---|---|---|
BiT-ResNet152x2 | 384 | Link | Link |
BiT-ResNet152x2 | 224 | Link | Link |
BiT-ResNet50x1 | 224 | Link | Link |
BiT-ResNet50x1 | 160 | Link | Link |
You can use the convert_jax_weights_tf.ipynb
notebook to see how model porting works between JAX and TensorFlow. There
is also an experimental tool called jax2tf
from the JAX team.
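To make the idea of porting NumPy weights more concrete, here is a minimal sketch of the name-and-shape matching step. It is not the code from the notebook: the function `port_weights`, the weight names, and the shapes are all hypothetical, and a tiny in-memory `.npz` archive stands in for a real checkpoint.

```python
import io

import numpy as np


def port_weights(source, target_shapes, name_map):
    """Map source arrays onto target variable names, checking shapes.

    source:        dict of source weight name -> np.ndarray
    target_shapes: dict of target variable name -> expected shape
    name_map:      dict of target variable name -> source weight name
                   (hypothetical; a real mapping depends on both models)
    """
    ported = {}
    for target_name, expected_shape in target_shapes.items():
        source_name = name_map[target_name]
        if source_name not in source:
            raise KeyError(f"missing source weight: {source_name}")
        value = np.asarray(source[source_name])
        if value.shape != tuple(expected_shape):
            raise ValueError(
                f"{target_name}: expected {tuple(expected_shape)}, "
                f"got {value.shape}"
            )
        ported[target_name] = value
    return ported


# Toy stand-in for a checkpoint file (names and shapes are made up).
buffer = io.BytesIO()
np.savez(
    buffer,
    **{
        "stem/kernel": np.ones((3, 3, 3, 8), dtype=np.float32),
        "head/bias": np.zeros((10,), dtype=np.float32),
    },
)
buffer.seek(0)
archive = np.load(buffer)
checkpoint = {name: archive[name] for name in archive.files}

# Shapes the (hypothetical) target model expects for each variable.
target_shapes = {
    "conv_root/kernel:0": (3, 3, 3, 8),
    "logits/bias:0": (10,),
}
name_map = {
    "conv_root/kernel:0": "stem/kernel",
    "logits/bias:0": "head/bias",
}

ported = port_weights(checkpoint, target_shapes, name_map)
for name, value in ported.items():
    print(name, value.shape, value.dtype)
```

In a real port, each array in `ported` would then be assigned to the matching variable of the TensorFlow model before exporting it as a SavedModel. Failing early on a shape mismatch is a deliberate choice here, since a silently misassigned weight is much harder to debug after export.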
References
[1] Big Transfer (BiT): General Visual Representation Learning by Kolesnikov et al.
[2] Knowledge distillation: A good teacher is patient and consistent by Beyer et al.
[3] BiT GitHub