Flax BNN is several times slower in JAX 0.4.33 compared to JAX 0.4.31 #1867

Open
ziatdinovmax opened this issue Sep 25, 2024 · 1 comment
Labels: discussion, jax (This issue is specific to JAX)

Comments

@ziatdinovmax

JAX 0.4.31: Runtime: 27.06 seconds
https://colab.research.google.com/drive/1EsFY1St8Y2ZNBZ9UXTa9FDWrjPDdTU4U?usp=sharing

JAX 0.4.33: Runtime: 84.91 seconds
https://colab.research.google.com/drive/1g7GkuK4-GloO6cywvDUf5BVU9qO2jf1W?usp=sharing
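
For scale, the two runtimes above correspond to roughly a threefold slowdown:

```math
\frac{84.91\,\text{s}}{27.06\,\text{s}} \approx 3.14
```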

I’m not sure whether this issue is specific to `random_flax_module` or a broader problem. I’ve primarily been using NumPyro for HMC BNNs, and the slowdown with the latest JAX release is quite dramatic.
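
Because the comparison hinges on which releases are installed, it helps to attach the exact package versions to each timing. The following is a minimal standard-library sketch. The helper name `package_versions` and the package list are illustrative assumptions. `jaxlib` is listed separately because it is a separate package from `jax` and ships the compiled backend.

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names):
    """Map each distribution name to its installed version, or None if it is absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    # Hypothetical list of packages relevant to this benchmark
    for name, ver in package_versions(["jax", "jaxlib", "flax", "numpyro"]).items():
        print(f"{name}: {ver or 'not installed'}")
```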

Code:

```python
import time

import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC
from numpyro.contrib.module import random_flax_module
import flax.linen as nn


# Set random seeds for reproducibility: a JAX key for MCMC and a NumPy generator for the data noise
rng_key = jax.random.PRNGKey(0)
np_rng = np.random.default_rng(0)

# Generate some dummy data: y = 3x + 2 plus Gaussian noise
def generate_data(n=100, noise_std=0.1):
    X = jnp.linspace(-1, 1, n)
    y = 3 * X + 2 + np_rng.normal(0, noise_std, size=X.shape)
    return X[:, None], y

# Define a simple neural network
class SimpleNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(10)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x.squeeze()

# Define the model
def model(X, y):
    module = SimpleNN()
    # Lift the Flax module into a Bayesian NN with a Normal(0, 1) prior on every parameter.
    # Named `net` so it does not shadow the `flax.linen as nn` import.
    net = random_flax_module("nn", module, input_shape=(1, X.shape[-1]), prior=dist.Normal(0, 1))

    with numpyro.plate("data", X.shape[0]):
        mean = net(X)
        numpyro.sample("obs", dist.Normal(mean, 0.1), obs=y)

# Generate data
X, y = generate_data()

# Initialize the NUTS sampler
nuts_kernel = NUTS(model)

# Run inference
num_warmup, num_samples = 500, 1000

start_time = time.perf_counter()

mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key, X, y)

end_time = time.perf_counter()

# Print runtime (this includes JIT compilation as well as sampling)
print(f"Runtime: {end_time - start_time:.2f} seconds")

# Print summary statistics (print_summary prints directly and returns None)
mcmc.print_summary()
```
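
The script above times `mcmc.run` as a whole, so the reported runtime mixes JIT compilation with the sampling itself. To see which part regressed, one could time a cold first call separately from repeated warm calls. This is a minimal stdlib sketch. `time_cold_and_warm` is a hypothetical helper and is not part of JAX or NumPyro.

```python
import time
from typing import Any, Callable, Tuple


def time_cold_and_warm(fn: Callable[[], Any], repeats: int = 3) -> Tuple[float, float]:
    """Return (cold_seconds, best_warm_seconds) for a zero-argument callable.

    In JAX-based code the first call usually includes tracing and compilation.
    If fn returns JAX arrays, it should call jax.block_until_ready on them so
    asynchronous dispatch does not hide work from the timer.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")

    # Cold call: may include compilation
    start = time.perf_counter()
    fn()
    cold = time.perf_counter() - start

    # Warm calls: keep the fastest to reduce noise
    warm_times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        warm_times.append(time.perf_counter() - start)
    return cold, min(warm_times)
```

A possible usage is `time_cold_and_warm(lambda: mcmc.run(rng_key, X, y))`. Whether later `mcmc.run` calls skip recompilation depends on NumPyro's internal caching, so the split is indicative rather than exact.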
@tillahoffmann
Contributor

Hey, this may be the same issue as jax-ml/jax#23822.

@fehiepsi added the discussion and jax (This issue is specific to JAX) labels on Sep 28, 2024