[RLlib] Preparatory PR: Make EnvRunners use (enhanced) Connector API (#01: mostly cleanups and small fixes) #41074
Signed-off-by: sven1977 <[email protected]>
@@ -1932,7 +1932,10 @@ def compute_actions(
     filtered_obs, filtered_state = [], []
     for agent_id, ob in observations.items():
         worker = self.workers.local_worker()
-        preprocessed = worker.preprocessors[policy_id].transform(ob)
+        if worker.preprocessors.get(policy_id) is not None:
This is a bug fix.
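A minimal sketch of the guard this fix introduces, outside of RLlib. The helper `preprocess_observation` and the class `DoublingPreprocessor` are hypothetical names used only for illustration, and passing the raw observation through when no preprocessor exists is an assumption about the fallback branch, which the hunk above does not show.

```python
from typing import Any, Dict, List, Optional


class DoublingPreprocessor:
    """Hypothetical preprocessor exposing the `transform()` method seen in the diff."""

    def transform(self, ob: List[float]) -> List[float]:
        return [2 * x for x in ob]


def preprocess_observation(
    preprocessors: Dict[str, Optional[DoublingPreprocessor]],
    policy_id: str,
    ob: Any,
) -> Any:
    # `.get()` returns None both for a missing key and for an explicit None
    # entry, so neither case raises (indexing with [] would fail on both).
    preprocessor = preprocessors.get(policy_id)
    if preprocessor is not None:
        return preprocessor.transform(ob)
    # Assumed fallback: hand the observation on unchanged.
    return ob
```

With this shape, a policy that has no preprocessor registered no longer triggers a `KeyError` or an `AttributeError` on `None`.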
@@ -319,26 +319,37 @@ def __init__(self, algo_class=None):
         # If not specified, we will try to auto-detect this.
         self._is_atari = None

         # TODO (sven): Rename this method into `AlgorithmConfig.sampling()`
Now that we are aiming for the EnvRunner API as the default, we should rename/clarify some of these config settings and methods.
Please consider loading a checkpoint here. Are these renamings backward compatible?
Is there even a story around this? Can people even move from RLlib 2+ to 3?
rllib/core/models/torch/encoder.py
Outdated
@@ -285,30 +285,31 @@ def __init__(self, config: RecurrentEncoderConfig) -> None:
             bias=config.use_bias,
         )

         self.state_in_out_spec = {
Simplified (repetitive) code.
Make this a private attribute?
done
@@ -212,75 +212,75 @@ def get_observations(

         return self._getattr_by_index("observations", indices, global_ts)

-    def get_actions(
+    def get_infos(
Reordered:
- obs, infos (<- env.reset data)
- action, reward, terminated/truncated (<- other env.step results)
- extra model outs
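The grouping above follows where each piece of data comes from. A rough sketch, assuming an environment with Gymnasium-style `reset()` and `step()` return signatures; `CountingEnv` and `rollout` are hypothetical names and are not part of RLlib:

```python
from typing import Any, Dict, List, Tuple


class CountingEnv:
    """Hypothetical toy env that mimics the Gymnasium reset/step return shapes."""

    def reset(self) -> Tuple[int, Dict[str, Any]]:
        self.t = 0
        return self.t, {"t": self.t}

    def step(self, action: int) -> Tuple[int, float, bool, bool, Dict[str, Any]]:
        self.t += 1
        terminated = self.t >= 3
        return self.t, float(action), terminated, False, {"t": self.t}


def rollout(env: CountingEnv) -> Dict[str, List[Any]]:
    # obs and infos already exist after reset(), before any action is taken.
    obs, info = env.reset()
    episode: Dict[str, List[Any]] = {
        "observations": [obs],
        "infos": [info],
        "actions": [],
        "rewards": [],
    }
    terminated = truncated = False
    while not (terminated or truncated):
        action = 1  # fixed action, enough for the illustration
        obs, reward, terminated, truncated, info = env.step(action)
        episode["observations"].append(obs)
        episode["infos"].append(info)
        # actions and rewards only exist once step() has been called.
        episode["actions"].append(action)
        episode["rewards"].append(reward)
    return episode
```

In this sketch the observation and info lists end up one element longer than the action and reward lists, which is one reason to keep the two groups next to each other in the API.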
rllib/env/single_agent_env_runner.py
Outdated
gym.register(
"custom-env-v0",
partial(
if (
Bug fix.
@@ -690,6 +690,9 @@ def foreach_worker(
         if local_worker and self.local_worker() is not None:
             local_result = [func(self.local_worker())]

+        if not self.__worker_manager.actor_ids():
Shortcut for local-worker only case.
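The early return can be sketched roughly as follows. `TinyWorkerSet` is a hypothetical stand-in, not RLlib's `WorkerSet`: it applies `func` synchronously to plain objects, and returning `local_result` directly is an assumption about what the shortcut does.

```python
from typing import Any, Callable, List, Optional, Sequence


class TinyWorkerSet:
    """Hypothetical, synchronous stand-in for a set of local + remote workers."""

    def __init__(self, local_worker: Optional[Any], remote_workers: Sequence[Any] = ()):
        self._local_worker = local_worker
        self._remote_workers = list(remote_workers)

    def foreach_worker(
        self, func: Callable[[Any], Any], local_worker: bool = True
    ) -> List[Any]:
        local_result: List[Any] = []
        if local_worker and self._local_worker is not None:
            local_result = [func(self._local_worker)]

        # Shortcut: with no remote workers there is nothing to dispatch,
        # so skip the remote-call machinery entirely.
        if not self._remote_workers:
            return local_result

        remote_results = [func(w) for w in self._remote_workers]
        return local_result + remote_results
```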
@@ -30,7 +30,7 @@ multi-agent-cartpole-crashing-appo:
         # Switch on resiliency for failed sub environments (within a vectorized stack).
         restart_failed_sub_environments: true

-        # Switch on evaluation workers being managed by AsyncRequestsManager object.
+        # Switch on asynchronous handling of evaluation workers.
AsyncRequestsManager doesn't exist anymore.
@@ -205,6 +205,56 @@ def flatten_to_single_ndarray(input_):
     return input_


+@DeveloperAPI
Very useful new utility. Inverse of the already existing unbatch utility.
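To illustrate what "inverse" means here, a dependency-free sketch of the round trip. `batch_structs` and `unbatch_structs` are hypothetical names, not the functions in `space_utils.py`, and the sketch collects leaves into plain Python lists rather than arrays. It assumes all structs share the same non-empty dict/tuple layout.

```python
from typing import Any, List


def batch_structs(list_of_structs: List[Any]) -> Any:
    """Turn a list of identically shaped structs into one struct of lists."""
    first = list_of_structs[0]
    if isinstance(first, dict):
        return {k: batch_structs([s[k] for s in list_of_structs]) for k in first}
    if isinstance(first, tuple):
        return tuple(
            batch_structs([s[i] for s in list_of_structs]) for i in range(len(first))
        )
    # Leaf level: the batch is simply the list of primitives.
    return list(list_of_structs)


def unbatch_structs(batched: Any) -> List[Any]:
    """Inverse of `batch_structs`: one struct of lists back to a list of structs."""
    if isinstance(batched, dict):
        per_key = {k: unbatch_structs(v) for k, v in batched.items()}
        size = len(next(iter(per_key.values())))
        return [{k: v[i] for k, v in per_key.items()} for i in range(size)]
    if isinstance(batched, tuple):
        per_slot = [unbatch_structs(v) for v in batched]
        return [tuple(slot[i] for slot in per_slot) for i in range(len(per_slot[0]))]
    return list(batched)


structs = [{"a": 1, "b": (2, 3.0)}, {"a": 4, "b": (5, 6.0)}]
batched = batch_structs(structs)
# batched == {"a": [1, 4], "b": ([2, 5], [3.0, 6.0])}
restored = unbatch_structs(batched)
# restored == structs
```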
rllib/utils/spaces/space_utils.py
Outdated
@@ -205,6 +205,56 @@ def flatten_to_single_ndarray(input_):
     return input_


+@DeveloperAPI
+def batch(list_of_structs, individual_items_already_have_batch_1: bool = False):
Data types please (for input and output).
Can we have a unit test for this?
done and done
Also enhanced the docstring to make the example and explanations clearer.
flat = [[] for _ in range(len(flattened_item))]
for i, value in enumerate(flattened_item):
    flat[i].append(value)
add:
    if item is None:
        raise ValueError("Input list_of_structs does not contain valid structs.")
done
in this struct represents the batch for a single component
(in case struct is tuple/dict). Alternatively, a simple batch of
primitives (non tuple/dict) might be returned.
"""
add:
    if not list_of_structs:
        raise ValueError("Input list_of_structs is empty.")
done
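Taken together, the two requested checks could look roughly like the sketch below. The standalone helper `validate_structs` is a hypothetical name; in the PR the checks were suggested as additions inside `batch()` itself, and exactly where they ended up is not shown in this thread.

```python
from typing import Any, Sequence


def validate_structs(list_of_structs: Sequence[Any]) -> None:
    """Fail early with a clear message instead of deep inside the batching loop."""
    # An empty input has no first element to infer the struct layout from.
    if not list_of_structs:
        raise ValueError("Input list_of_structs is empty.")
    # A None entry cannot be flattened into components.
    for item in list_of_structs:
        if item is None:
            raise ValueError("Input list_of_structs does not contain valid structs.")


def error_message(list_of_structs: Sequence[Any]) -> str:
    """Small usage helper: return the validation error text, or "" if valid."""
    try:
        validate_structs(list_of_structs)
    except ValueError as exc:
        return str(exc)
    return ""
```

Raising `ValueError` in both cases keeps the failure close to the caller's mistake, which is usually easier to debug than an `IndexError` from indexing an empty list further down.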
Signed-off-by: sven1977 <[email protected]>
Thanks for the review, @kouroshHakha! Waiting for tests to pass ...
Signed-off-by: sven1977 <[email protected]>