-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix bug in navigation task #47
Conversation
When the goal probability is one the whole array should be true instead of false to allow for rewards
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, @jysdoran, for spotting the issue! Really appreciated the time you dedicated to contribute.
I found another bug while reviewing this -- added some comments.
Would you mind adding also a unittest to make sure this remains fixed in the future?
I put together a quick one but feel free to improve in it if you find a better one.
I can pick it up if you don't have time!
def test_navigation():
"""Unittest for https://github.com/epignatelli/navix/pull/47"""
height = 10
width = 10
grid = jnp.zeros((height - 2, width - 2), dtype=jnp.int32)
grid = jnp.pad(grid, 1, mode="constant", constant_values=-1)
players = Player(
position=jnp.asarray((1, 1)), direction=jnp.asarray(0), pocket=EMPTY_POCKET_ID
)
goals = Goal(position=jnp.asarray([(1, 1), (1, 1)]), probability=jnp.asarray([0.0, 0.0]))
keys = Key(position=jnp.asarray((2, 2)), id=jnp.asarray(0))
doors = Door(
position=jnp.asarray([(1, 5), (1, 6)]),
direction=jnp.asarray((0, 2)),
requires=jnp.asarray((0, 0)),
open=jnp.asarray((False, True)),
)
entities = {
Entities.PLAYER: players[None],
Entities.GOAL: goals,
Entities.KEY: keys[None],
Entities.DOOR: doors,
}
state = nx.entities.State(
key=jax.random.PRNGKey(0),
grid=grid,
cache=nx.graphics.RenderingCache.init(grid),
entities=entities,
)
action = jnp.asarray(0)
reward = nx.tasks.navigation(state, action, state)
assert jnp.array_equal(reward, jnp.asarray(0.0))
Thanks @jysdoran for the PR again, I just pushed the other fixes, add the unitest and bumped the version to The fix should come online in a few minutes. |
When the goal probability is one the whole array should be true instead of false to allow for rewards