Skip to content

Commit

Permalink
jax.random.PRNGKey(...) -> jax.random.key(..)
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Sep 24, 2023
1 parent 89f3d94 commit a17be41
Show file tree
Hide file tree
Showing 12 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
for _ in range(100):
rng_key, nuts_key = jax.random.split(rng_key)
state, _ = nuts.step(nuts_key, state)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/howto_custom_gradients.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def logdensity_fn(y):
hmc = blackjax.hmc(logdensity_fn,1e-3, jnp.ones(1), 10)
state = hmc.init(1.)
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
new_state, info = hmc.step(rng_key, state)
```

Expand Down
4 changes: 2 additions & 2 deletions docs/examples/howto_metropolis_within_gibbs.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def sampling_loop(rng_key, initial_state, parameters, num_samples):

```{code-cell} ipython3
%%time
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
positions = sampling_loop(rng_key, initial_state, parameters, 10_000)
```

Expand Down Expand Up @@ -305,7 +305,7 @@ def sampling_loop_general(rng_key, initial_state, logdensity_fn, step_fn, init,

```{code-cell} ipython3
%%time
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
positions_general = sampling_loop_general(
rng_key=rng_key,
initial_state=initial_state,
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/howto_other_frameworks.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ step_size=1e-3
nuts = blackjax.nuts(numba_logpdf, step_size, inverse_mass_matrix)
init = nuts.init(0.)
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
state, info = nuts.step(rng_key, init)
for _ in range(10):
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/howto_sample_multiple_chains.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ And finally, to put `jax.vmap` and `jax.pmap` on an equal foot we sample as many
import multiprocessing
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
num_chains = multiprocessing.cpu_count()
```

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/howto_use_aesara.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def init_param_fn(seed):
"thetas": jax.random.uniform(seed, (n_rat_tumors,), "float64", minval=0, maxval=1),
}
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
init_position = init_param_fn(rng_key)
```

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/howto_use_numpyro.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ import jax
from numpyro.infer.util import initialize_model
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
init_params, potential_fn_gen, *_ = initialize_model(
rng_key,
eight_schools_noncentered,
Expand Down
6 changes: 3 additions & 3 deletions docs/examples/howto_use_oryx.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ from oryx.core.ppl import joint_sample
bnn = mlp([50, 50], num_classes)
initial_weights = joint_sample(bnn)(jax.random.PRNGKey(0), jnp.ones(num_features))
initial_weights = joint_sample(bnn)(jax.random.key(0), jnp.ones(num_features))
print(initial_weights.keys())
```
Expand All @@ -136,7 +136,7 @@ We can now run the window adaptation to get good values for the parameters of th
%%time
import blackjax
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(rng_key, initial_weights, 100)
kernel = blackjax.nuts(logdensity_fn, **parameters).step
Expand Down Expand Up @@ -173,7 +173,7 @@ posterior_weights = states.position
output_logits = jax.vmap(
lambda weights: jax.vmap(lambda x: intervene(bnn, **weights)(
jax.random.PRNGKey(0), x)
jax.random.key(0), x)
)(features)
)(posterior_weights)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/howto_use_pymc.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ import jax
init_position_dict = model.initial_point()
init_position = [init_position_dict[rv] for rv in rvs]
rng_key = jax.random.PRNGKey(1234)
rng_key = jax.random.key(1234)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(rng_key, init_position, 1000)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/howto_use_tfp.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ initial_position = {
}
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
adapt = blackjax.window_adaptation(
blackjax.hmc, logdensity_fn, num_integration_steps=3
)
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def inference_loop(rng_key, kernel, initial_state, num_samples):

```{code-cell} python
%%time
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 10_000)
loc_samples = states.position["loc"].block_until_ready()
Expand Down Expand Up @@ -136,7 +136,7 @@ initial_state

```{code-cell} python
%%time
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
states = inference_loop(rng_key, nuts.step, initial_state, 4_000)
loc_samples = states.position["loc"].block_until_ready()
Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)
# Iterate
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
step = jax.jit(nuts.step)
for _ in range(1_000):
rng_key, nuts_key = jax.random.split(rng_key)
Expand Down

0 comments on commit a17be41

Please sign in to comment.