The flaxvision package contains a selection of neural network models ported from torchvision for use with JAX and Flax.
Note: flaxvision is under active development; its API and functionality may change between releases.
Planned features for the first release:
- Update models to the Linen API
- Add support for transfer learning
- Add dilated-convolution support to ResNet
- Port DeepLabv3 model for image segmentation
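Loading a pretrained VGG-16 model:

```python
from jax import random
from flaxvision import models

# PRNG key passed to the model constructor
rng = random.PRNGKey(0)
# Build VGG-16 and request pretrained weights
pretrained_model = models.vgg16(rng, pretrained=True)
```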
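Pretrained models generally expect inputs preprocessed the same way as during training. The sketch below assumes that flaxvision's ported weights follow torchvision's ImageNet conventions (pixels scaled to [0, 1], then normalized per channel with the ImageNet mean and standard deviation) and that inputs use Flax's channels-last (NHWC) layout. Both are assumptions, not documented flaxvision behavior, and the `preprocess` helper is hypothetical.

```python
import jax.numpy as jnp

# ImageNet per-channel statistics used by torchvision's pretrained models.
# Assumption: flaxvision's ported weights expect the same normalization.
IMAGENET_MEAN = jnp.array([0.485, 0.456, 0.406])
IMAGENET_STD = jnp.array([0.229, 0.224, 0.225])


def preprocess(images_uint8):
    """Hypothetical helper: uint8 batch (N, H, W, 3) -> normalized float32.

    Assumes channels-last (NHWC) layout, Flax's default for convolutions.
    """
    x = images_uint8.astype(jnp.float32) / 255.0  # scale pixels to [0, 1]
    return (x - IMAGENET_MEAN) / IMAGENET_STD     # broadcasts over the last axis


# A dummy batch of two black 224x224 RGB images
batch = jnp.zeros((2, 224, 224, 3), dtype=jnp.uint8)
inputs = preprocess(batch)
print(inputs.shape, inputs.dtype)  # (2, 224, 224, 3) float32
```

Keeping normalization in a standalone function means it can be wrapped with `jax.jit` or applied inside a data pipeline independently of the model.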
If you are interested in adding models or improving existing ones, please start by opening an issue describing your idea.
The initial work for flaxvision started during the Google Summer of Code program at Google AI under Avital Oliver's mentorship.