flaxvision
The flaxvision package contains a selection of neural network models ported from torchvision for use with JAX and Flax.
Note: flaxvision is under active development. The API and functionality may change between releases.
Roadmap to v0.1.0
Planned features for the first release:
- Update models to the Flax Linen API
- Add support for transfer learning
- Add dilated convolution support to ResNet
- Port the DeepLabv3 model for image segmentation
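Dilated (atrous) convolutions, which the ResNet item refers to and which DeepLabv3 builds on, space the kernel taps `d` pixels apart. A `k`×`k` kernel then covers `k + (k - 1)(d - 1)` pixels per side with no extra parameters. The sketch below illustrates the idea in plain JAX with `jax.lax.conv_general_dilated` and its `rhs_dilation` argument. It is not flaxvision code, and the helper names `dilated_conv2d` and `effective_kernel_size` are hypothetical.

```python
import jax.numpy as jnp
from jax import lax

def dilated_conv2d(x, kernel, dilation, padding="VALID"):
    """2D convolution with a dilated kernel (hypothetical helper).

    x: (N, C_in, H, W), kernel: (C_out, C_in, kH, kW), i.e. lax's default
    NCHW / OIHW layout.
    """
    return lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding=padding,
        rhs_dilation=(dilation, dilation),  # insert gaps between kernel taps
    )

def effective_kernel_size(k, d):
    # Spatial extent covered by a k-tap kernel with dilation d.
    return k + (k - 1) * (d - 1)

# 5x5 input holding 0..24 and a 3x3 all-ones kernel.
x = jnp.arange(25, dtype=jnp.float32).reshape(1, 1, 5, 5)
kernel = jnp.ones((1, 1, 3, 3), dtype=jnp.float32)

# With dilation 2 the 3x3 kernel spans 5x5, so VALID padding yields one output:
# the sum of x at rows {0, 2, 4} and columns {0, 2, 4}.
y = dilated_conv2d(x, kernel, dilation=2)
print(y.shape, float(y[0, 0, 0, 0]))  # (1, 1, 1, 1) 108.0

# SAME padding keeps the 5x5 resolution, which is why segmentation models
# such as DeepLabv3 use dilation instead of extra downsampling.
y_same = dilated_conv2d(x, kernel, dilation=2, padding="SAME")
print(y_same.shape)  # (1, 1, 5, 5)
```

Dilation enlarges the receptive field without adding weights and, unlike striding, without reducing spatial resolution.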
Quickstart
Transfer Learning Example
```python
from jax import random
from flaxvision import models

rng = random.PRNGKey(0)
pretrained_model = models.vgg16(rng, pretrained=True)
```
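A common transfer-learning pattern is to freeze a pretrained backbone and train only a new task-specific head on its features. The sketch below shows that pattern in plain JAX. It assumes nothing about flaxvision's API. The `backbone` function is a hypothetical stand-in (a fixed random projection) for a pretrained model such as the VGG16 loaded above, and `head`, `loss_fn`, `train_step` and the toy data are likewise illustrative.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1

def backbone(params, x):
    # Hypothetical stand-in for a pretrained feature extractor (not flaxvision code).
    return jnp.tanh(x @ params["w"])

def head(params, features):
    # New task-specific linear classifier, trained from scratch.
    return features @ params["w"] + params["b"]

def loss_fn(head_params, backbone_params, x, y_onehot):
    # stop_gradient freezes the backbone: no gradient flows into its weights.
    features = jax.lax.stop_gradient(backbone(backbone_params, x))
    log_probs = jax.nn.log_softmax(head(head_params, features))
    return -jnp.mean(jnp.sum(y_onehot * log_probs, axis=-1))

@jax.jit
def train_step(head_params, backbone_params, x, y_onehot):
    # value_and_grad differentiates w.r.t. the first argument (the head) only.
    loss, grads = jax.value_and_grad(loss_fn)(head_params, backbone_params, x, y_onehot)
    new_head = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, head_params, grads
    )
    return new_head, loss

# Toy binary task on random inputs.
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (64, 8))
labels = (x[:, 0] > 0).astype(jnp.int32)
y_onehot = jax.nn.one_hot(labels, 2)

backbone_params = {"w": jax.random.normal(key_w, (8, 16))}  # stays fixed
head_params = {"w": jnp.zeros((16, 2)), "b": jnp.zeros(2)}

losses = []
for _ in range(200):
    head_params, loss = train_step(head_params, backbone_params, x, y_onehot)
    losses.append(float(loss))

# Even when differentiating w.r.t. the backbone, its gradient is exactly zero.
backbone_grads = jax.grad(loss_fn, argnums=1)(head_params, backbone_params, x, y_onehot)
print(losses[0], losses[-1])
```

Differentiating only the head parameters is already enough to keep the backbone fixed here. The `stop_gradient` call makes the freeze explicit, so it still holds if the code is later changed to differentiate the full parameter tree.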
How To Contribute
If you are interested in adding models or improving existing ones, please start by opening an issue describing your idea.
Acknowledgments
Work on flaxvision began during the Google Summer of Code program at Google AI, under Avital Oliver's mentorship.