
A selection of neural network models ported from torchvision for JAX & Flax.


rolandgvc/flaxvision


flaxvision

The flaxvision package contains a selection of neural network models ported from torchvision to be used with JAX & Flax.

Note: flaxvision is currently in active development. The API and functionality may change between releases.

Roadmap to v0.1.0

Planned features for the first release:

  • Update models to the Linen API
  • Add support for transfer learning
  • Add dilated convolution support to ResNet
  • Port DeepLabv3 model for image segmentation
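Dilated (atrous) convolutions, the roadmap item for ResNet and a core building block of DeepLabv3, enlarge a kernel's receptive field without adding parameters: a k×k kernel with dilation d covers k + (k − 1)(d − 1) pixels per side. As an illustration only (this is not flaxvision code), here is a minimal sketch in plain JAX using `jax.lax.conv_general_dilated`; the helper names `dilated_conv2d` and `effective_kernel_size` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def dilated_conv2d(x, kernel, dilation):
    """Stride-1 convolution with a dilated kernel and SAME padding.

    x: (N, H, W, C_in) input, kernel: (kH, kW, C_in, C_out) weights.
    """
    return lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding="SAME",  # padding is computed from the dilated kernel size
        rhs_dilation=(dilation, dilation),  # spread kernel taps apart
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


def effective_kernel_size(k, dilation):
    # Spatial extent covered by a k-tap kernel with the given dilation.
    return k + (k - 1) * (dilation - 1)


x = jnp.ones((1, 32, 32, 3))
kernel = jax.random.normal(jax.random.PRNGKey(0), (3, 3, 3, 8))

y1 = dilated_conv2d(x, kernel, 1)  # ordinary 3x3 convolution
y2 = dilated_conv2d(x, kernel, 2)  # same 27 weights, 5x5 receptive field

print(y1.shape, y2.shape)  # (1, 32, 32, 8) (1, 32, 32, 8)
print(effective_kernel_size(3, 2))  # 5
```

With stride 1 and SAME padding the spatial resolution is preserved, which is why segmentation models like DeepLabv3 use dilation instead of striding in their later stages.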

Quickstart

Transfer Learning Example

from jax import random
from flaxvision import models

# PRNG key passed to the model constructor
rng = random.PRNGKey(0)

# Load VGG16 with pretrained weights
pretrained_model = models.vgg16(rng, pretrained=True)
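Transfer learning commonly means keeping a pretrained feature extractor frozen and training only a new task-specific head. The sketch below illustrates that pattern in plain JAX; it does not use the flaxvision API, and the tiny dense `backbone`, the `head`, the learning rate `LR`, and all shapes are hypothetical stand-ins for a pretrained model such as VGG16.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # hypothetical learning rate for the new head


def backbone(params, x):
    # Hypothetical stand-in for a pretrained feature extractor.
    return jax.nn.relu(x @ params["w"] + params["b"])


def head(params, feats):
    # New linear classifier trained for the target task.
    return feats @ params["w"] + params["b"]


def loss_fn(head_params, backbone_params, x, y):
    logits = head(head_params, backbone(backbone_params, x))
    log_probs = jax.nn.log_softmax(logits)
    # Mean cross-entropy against integer labels y.
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))


@jax.jit
def train_step(head_params, backbone_params, x, y):
    # jax.grad differentiates only the first argument (head_params),
    # so the backbone parameters stay frozen.
    grads = jax.grad(loss_fn)(head_params, backbone_params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, head_params, grads)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
backbone_params = {"w": 0.1 * jax.random.normal(k1, (16, 32)), "b": jnp.zeros(32)}
head_params = {"w": jnp.zeros((32, 3)), "b": jnp.zeros(3)}
x = jax.random.normal(k2, (64, 16))
y = jax.random.randint(k3, (64,), 0, 3)

initial_loss = loss_fn(head_params, backbone_params, x, y)
for _ in range(100):
    head_params = train_step(head_params, backbone_params, x, y)
final_loss = loss_fn(head_params, backbone_params, x, y)
print(initial_loss, final_loss)
```

In practice you would replace `backbone` and `backbone_params` with the pretrained model's forward pass and weights, and size the head to the number of target classes.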

How To Contribute

If you are interested in adding models or improving existing ones, please start by opening an issue describing your idea.

Acknowledgments

The initial work for flaxvision started during the Google Summer of Code program at Google AI under Avital Oliver's mentorship.
