Gradients without Backpropagation - JAX Implementation

This repository contains a JAX implementation of the method described in the paper Gradients without Backpropagation.

Sometimes all we want is to get rid of backpropagation of errors and estimate an unbiased gradient of the loss function in a single inference pass :)

Overview

The code demonstrates how to train a simple MLP on MNIST using either forward gradients (the estimator $(\nabla f(\boldsymbol{\theta}) \cdot \boldsymbol{v}) \boldsymbol{v}$), computed with JVP (Jacobian-vector product, i.e. forward-mode AD), or the traditional VJP (vector-Jacobian product, i.e. reverse-mode AD). To investigate how stable and scalable the forward-gradient method is (the variance of the estimate grows with the number of parameters), you can increase the --num_layers parameter.

Note: the method does not seem to scale efficiently beyond roughly 10 layers, because the variance of the gradient estimate depends on the number of network parameters.
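Below is a minimal sketch of the forward-gradient estimator using jax.jvp, assuming a toy quadratic loss rather than the repository's MLP/MNIST setup; the function and variable names here are illustrative and not taken from train.py.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Toy quadratic loss for a linear model (assumption; the repo trains an MLP on MNIST).
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

def forward_gradient(loss_fn, params, key):
    # Sample a random tangent v ~ N(0, I) with the same pytree structure as params.
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(key, len(leaves))
    v = jax.tree_util.tree_unflatten(
        treedef, [jax.random.normal(k, p.shape) for k, p in zip(keys, leaves)])
    # One forward pass yields the loss and the directional derivative ∇f(θ)·v.
    loss_val, dir_deriv = jax.jvp(loss_fn, (params,), (v,))
    # The unbiased forward-gradient estimate is (∇f(θ)·v) v.
    grad_est = jax.tree_util.tree_map(lambda t: dir_deriv * t, v)
    return loss_val, grad_est

# Example usage with dummy data.
key = jax.random.PRNGKey(0)
w, b = jnp.zeros((3, 1)), jnp.zeros((1,))
x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
loss_val, g = forward_gradient(lambda p: loss(p, x, y), (w, b), key)
# Reverse-mode baseline for comparison: jax.grad(lambda p: loss(p, x, y))((w, b))
```

The key point is that jax.jvp evaluates the loss and the directional derivative in a single forward pass, with no reverse sweep; jax.grad provides the reverse-mode (VJP) baseline the training script compares against.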

Comparison

Requirements

  • JAX <3
  • optax (for learning rate scheduling)
  • wandb (optional, for logging)

Usage

To replicate MLP training with forward gradients on MNIST, simply run train.py:

python train.py
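
To probe the scaling behaviour noted above, you can increase the network depth via the --num_layers flag (the value below is just an example; other flags and defaults depend on train.py):

python train.py --num_layers 10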
