-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[rllib] Support prev_state/prev_action in rollout and fix multiagent #4565
Conversation
Can one of the admins verify this patch? |
Test FAILed. |
policy_map = agent.local_evaluator.policy_map | ||
state_init = {p: m.get_initial_state() for p, m in policy_map.items()} | ||
use_lstm = {p: len(s) > 0 for p, s in state_init.items()} | ||
action_init = { | ||
p: m.action_space.sample() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note that in the actual code we feed the all-zeros action initially
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does that make sense? All zeros might not even be in the action space?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that's a good question.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is fine for now, though it would be better to set it to zeros for consistency (or switch the sampler to stick in a random initial action, but that might be weird).
python/ray/rllib/rollout.py
Outdated
@@ -124,37 +140,46 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True): | |||
while steps < (num_steps or steps + 1): | |||
if out is not None: | |||
rollout = [] | |||
state = env.reset() | |||
obs = env.reset() | |||
multi_obs = obs if multiagent else {0: obs} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could use https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/base_env.py#L196 instead of 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
@@ -124,37 +140,46 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True): | |||
while steps < (num_steps or steps + 1): | |||
if out is not None: | |||
rollout = [] | |||
state = env.reset() | |||
obs = env.reset() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should the mapping cache be reset as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I think that's appropriate.
python/ray/rllib/rollout.py
Outdated
policy_agent_mapping = agent.config["multiagent"][ | ||
"policy_mapping_fn"] | ||
mapping_cache = {} | ||
else: | ||
policy_agent_mapping = lambda _: DEFAULT_POLICY_ID |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Reset the mapping cache at the start of each episode.
Test FAILed. |
Test FAILed. |
Test FAILed. |
Seems this causes rllib/tests/test_rollout.sh to raise an error:
|
I believe step() needs to also be contingent on multiagent, i.e. and line #172
|
Single-agent envs should work now (test_rollout.py passes). |
Test FAILed. |
Test failure doesn't seem related to this PR:
|
FileNotFoundError was added in Python 3. In Python 2 use OSError instead. |
Latest version still is updating |
@waldroje You are correct, good catch (if only python had types...). Fixed now. |
Test FAILed. |
Lint unrelated. |
Thanks! |
What do these changes do?
Fixes a few issues with rollout.py:
The code has also been simplified and should be more readable.
Closes #4573
Linter
scripts/format.sh
to lint the changes in this PR.