
Add JAX #703

Open
amalrkc opened this issue Jun 4, 2024 · 0 comments


amalrkc commented Jun 4, 2024

JAX

Google JAX is a machine learning framework for transforming numerical functions in Python. It is described as bringing together a modified version of Autograd (automatic differentiation of Python and NumPy functions) and TensorFlow's XLA (Accelerated Linear Algebra) compiler. It is designed to follow the structure and workflow of NumPy as closely as possible and works alongside existing frameworks such as TensorFlow and PyTorch. Its primary function transformations, illustrated in the sketch below the list, are:

  1. grad: automatic differentiation
  2. jit: just-in-time compilation via XLA
  3. vmap: automatic vectorization
  4. pmap: SPMD (single-program, multiple-data) parallel programming across devices
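
A minimal usage sketch of the first three transformations, assuming only that JAX is installed (`pip install jax`); the toy `loss` function is hypothetical and exists purely for illustration, and `pmap` is omitted because it requires multiple accelerator devices:

```python
# Minimal sketch of grad, jit, and vmap. The quadratic `loss` is a made-up
# example objective, not part of the JAX API.
import jax
import jax.numpy as jnp

def loss(w, x):
    # Toy quadratic objective, used only to have something to differentiate.
    return jnp.sum((w * x - 1.0) ** 2)

grad_loss = jax.grad(loss)                        # grad: d(loss)/d(w)
fast_grad = jax.jit(grad_loss)                    # jit: compile the gradient with XLA
batched_loss = jax.vmap(loss, in_axes=(None, 0))  # vmap: map loss over a batch of x

w = jnp.array([0.5, 2.0])
x = jnp.array([1.0, 1.0])
xs = jnp.stack([x, 2.0 * x, 3.0 * x])             # batch of three inputs

print(grad_loss(w, x))      # gradient with respect to w
print(fast_grad(w, x))      # same values, compiled on first call
print(batched_loss(w, xs))  # loss evaluated for each input in the batch
```

Each transformation returns a new function, so they compose freely, e.g. `jax.jit(jax.grad(loss))`.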

- Wikipedia
- GitHub repo
- "JAX in 100 Seconds" by Fireship
- Link to the vector file
