Skip to content

Commit

Permalink
test: modify smoke tests to sample from action spec
Browse files Browse the repository at this point in the history
  • Loading branch information
aar65537 committed Apr 8, 2023
1 parent 7bd1f93 commit 275fbc7
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 90 deletions.
36 changes: 7 additions & 29 deletions jumanji/environments/packing/bin_pack/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import chex
import jax
import jax.numpy as jnp
import numpy as np
import pytest

from jumanji import tree_utils
Expand All @@ -33,29 +32,11 @@
item_from_space,
location_from_space,
)
from jumanji.testing.env_not_smoke import SelectActionFn, check_env_does_not_smoke
from jumanji.testing.env_not_smoke import check_env_does_not_smoke
from jumanji.testing.pytrees import assert_is_jax_array_tree
from jumanji.types import TimeStep


@pytest.fixture
def bin_pack_random_select_action(bin_pack: BinPack) -> SelectActionFn:
    """Build a jitted random action selector restricted by the observation's action mask.

    The two entries of the action spec's `num_values` are read once, as the number of
    EMSs and the number of items, and closed over by the returned function.
    """
    num_ems, num_items = np.asarray(bin_pack.action_spec().num_values)

    def select_action(key: chex.PRNGKey, observation: Observation) -> chex.Array:
        """Randomly sample valid actions, as determined by `observation.action_mask`."""
        # Draw one index over the flattened (ems, item) grid, weighted by the mask.
        flat_mask = observation.action_mask.flatten()
        flat_index = jax.random.choice(key=key, a=num_ems * num_items, p=flat_mask)
        # Recover the (ems, item) pair from the flat index.
        ems_id, item_id = jnp.divmod(flat_index, num_items)
        return jnp.array([ems_id, item_id], jnp.int32)

    return jax.jit(select_action)  # type: ignore


@pytest.fixture(scope="function")
def normalize_dimensions(request: pytest.FixtureRequest) -> bool:
    """Expose the value this fixture was parametrized with (`request.param`)."""
    # `pytest.FixtureRequest` is the request type; `pytest.mark.FixtureRequest` would
    # instead create an unregistered mark object when the annotation is evaluated.
    return request.param  # type: ignore
Expand Down Expand Up @@ -160,25 +141,22 @@ def test_bin_pack__render_does_not_smoke(bin_pack: BinPack, dummy_state: State)
bin_pack.close()


def test_bin_pack__does_not_smoke(bin_pack: BinPack) -> None:
    """Test that we can run an episode without any errors.

    No `select_action` argument is passed, so action selection is left to
    `check_env_does_not_smoke`.
    """
    check_env_does_not_smoke(bin_pack)


def test_bin_pack__pack_all_items_dummy_instance(bin_pack: BinPack) -> None:
    """Functional test to check that the dummy instance can be completed with a random agent."""
    step_fn = jax.jit(bin_pack.step)
    # Build the action spec once rather than on every loop iteration.
    action_spec = bin_pack.action_spec()
    key = jax.random.PRNGKey(0)
    state, timestep = bin_pack.reset(key)

    while not timestep.last():
        action_key, key = jax.random.split(key)
        # The observation's action mask is forwarded so sampling can respect it.
        action = action_spec.sample(action_key, timestep.observation.action_mask)
        state, timestep = step_fn(state, action)

    assert jnp.array_equal(state.items_placed, state.items_mask)
Expand Down
16 changes: 2 additions & 14 deletions jumanji/environments/routing/cleaner/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from jumanji.environments.routing.cleaner.constants import CLEAN, DIRTY, WALL
from jumanji.environments.routing.cleaner.env import Cleaner
from jumanji.environments.routing.cleaner.generator import Generator
from jumanji.environments.routing.cleaner.types import Observation, State
from jumanji.environments.routing.cleaner.types import State
from jumanji.testing.env_not_smoke import check_env_does_not_smoke
from jumanji.testing.pytrees import assert_is_jax_array_tree
from jumanji.types import StepType, TimeStep
Expand Down Expand Up @@ -176,19 +176,7 @@ def test_cleaner__action_mask(self, cleaner: Cleaner, key: chex.PRNGKey) -> None
assert jnp.all(action_mask[2] == jnp.array([False, False, False, True]))

def test_cleaner__does_not_smoke(self, cleaner: Cleaner) -> None:
    """Check that an episode of the cleaner environment runs without errors.

    No `select_action` argument is passed, so action selection is left to
    `check_env_does_not_smoke`.
    """
    check_env_does_not_smoke(cleaner)

def test_cleaner__compute_extras(self, cleaner: Cleaner, key: chex.PRNGKey) -> None:
state, _ = cleaner.reset(key)
Expand Down
55 changes: 8 additions & 47 deletions jumanji/testing/env_not_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, TypeVar, Union
from typing import Callable, Optional, TypeVar

import chex
import jax
import jax.numpy as jnp

from jumanji import specs
from jumanji.env import Environment
Expand All @@ -26,42 +25,13 @@
SelectActionFn = Callable[[chex.PRNGKey, Observation], Action]


def make_random_select_action_fn(action_spec: specs.Spec) -> SelectActionFn:
    """Create select action function that chooses random actions.

    Args:
        action_spec: spec whose `sample` method is used to draw the actions.

    Returns:
        A function that takes a PRNG key and an observation and returns a sampled action.
    """

    def select_action(key: chex.PRNGKey, observation: chex.ArrayTree) -> chex.ArrayTree:
        # Observations exposing an `action_mask` have it forwarded to `sample`;
        # any other observation is sampled without a mask.
        if hasattr(observation, "action_mask"):
            return action_spec.sample(key, observation.action_mask)
        return action_spec.sample(key)

    return select_action

Expand All @@ -74,17 +44,8 @@ def check_env_does_not_smoke(
"""Run an episode of the environment, with a jitted step function to check no errors occur."""
action_spec = env.action_spec()
if select_action is None:
if isinstance(action_spec, specs.BoundedArray) or isinstance(
action_spec, specs.DiscreteArray
):
select_action = make_random_select_action_fn(action_spec)
else:
raise NotImplementedError(
f"Currently the `make_random_select_action_fn` only works for environments with "
f"either discrete actions or bounded continuous actions. The input environment to "
f"this test has an action spec of type {action_spec}, and therefore requires "
f"a custom `SelectActionFn` to be provided to this test."
)
select_action = make_random_select_action_fn(action_spec)

key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
state, timestep = env.reset(reset_key)
Expand Down

0 comments on commit 275fbc7

Please sign in to comment.