Skip to content

Commit

Permalink
fix: issue with can_sample not being jittable
Browse files Browse the repository at this point in the history
  • Loading branch information
EdanToledo committed Sep 20, 2024
1 parent 57f77ca commit 989976c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
22 changes: 13 additions & 9 deletions flashbax/buffers/mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax
import jax.numpy as jnp
from chex import Numeric, dataclass
from jax import Array
from jax.tree_util import tree_map

from flashbax.buffers.flat_buffer import TransitionSample
Expand Down Expand Up @@ -140,25 +141,28 @@ def sample_mixer_fn(


def can_sample_mixer_fn(
states: Sequence[StateTypes], can_sample_fns: Sequence[Callable[[StateTypes], bool]]
) -> bool:
states: Sequence[StateTypes],
can_sample_fns: Sequence[Callable[[StateTypes], Array]],
) -> Array:
"""Check if all buffers can sample.
Args:
states (Sequence[StateTypes]): list of buffer states
can_sample_fns (Sequence[Callable[[StateTypes], bool]]): list of can_sample functions
can_sample_fns (Sequence[Callable[[StateTypes], Array]]): list of can_sample functions
from each buffer
Returns:
bool: whether all buffers can sample
"""
each_can_sample = tree_map(
lambda state, can_sample: can_sample(state),
states,
can_sample_fns,
is_leaf=lambda leaf: type(leaf) in state_types,
each_can_sample = jnp.asarray(
tree_map(
lambda state, can_sample: can_sample(state),
states,
can_sample_fns,
is_leaf=lambda leaf: type(leaf) in state_types,
)
)
return all(each_can_sample)
return jnp.all(each_can_sample)


def make_mixer(
Expand Down
4 changes: 4 additions & 0 deletions flashbax/buffers/mixer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ def test_mixed_buffer_does_not_smoke(
proportions=proportions,
sample_batch_size=sample_batch_size,
)

can_sample = jax.jit(mixer.can_sample)(buffer_states)
assert can_sample

samples = jax.jit(mixer.sample)(buffer_states, rng_key)

assert samples is not None
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ dev = [
'mkdocs-mermaid2-plugin==1.1.1',
'mkdocstrings[python]==0.23.0',
'mknotebooks==0.8.0',
'mypy>=0.982',
'mypy>=1.8.0',
'pre-commit>=2.20.0',
'pytest>=7.4.2',
'pytest-cov>=4.00',
Expand Down

0 comments on commit 989976c

Please sign in to comment.