[RLlib; Offline RL] Implement CQL algorithm logic in new API stack. #47000
Conversation
…d it. Made some smaller typing changes in learners and added a tuned example for CQL in the new API stack.
…oss and switched in actor loss from selected actions to sampled actions from the current policy.
…ven1977.
rllib/BUILD (outdated)
    name = "learning_tests_pendulum_cql",
    main = "tuned_examples/cql/pendulum_cql.py",
    tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_cartpole", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
    size = "medium",
size="large" is better, no?
Yeah, maybe. It takes around 50 iterations.
Will tune this in another PR.
rllib/algorithms/cql/cql_learner.py (outdated)
    # Add a metric to keep track of training iterations to
    # determine when switching the actor loss from behavior
    # cloning to SAC.
    self.metrics.log_value(
        (ALL_MODULES, TRAINING_ITERATION), float("nan"), window=1
    )
This can all go, b/c we now use the `default=0` arg.
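The cleanup suggested here relies on reads taking a default value instead of pre-registering a NaN. A minimal stand-in (a hypothetical, simplified API, not RLlib's actual MetricsLogger) sketches the idea:

```python
# Minimal stand-in for a metrics logger, illustrating why pre-registering a
# NaN value becomes unnecessary once reads accept a `default` argument.
# Hypothetical simplified API; RLlib's real MetricsLogger is more involved.

class MiniMetrics:
    def __init__(self):
        self._values = {}

    def log_value(self, key, value):
        # Record the latest value under `key`.
        self._values[key] = value

    def peek(self, key, default=None):
        # Return the logged value, or `default` if the key was never logged.
        return self._values.get(key, default)

metrics = MiniMetrics()
# No need to pre-register: metrics.log_value(key, float("nan")) up front.
iters = metrics.peek(("all_modules", "training_iteration"), default=0)
```

Reading with `default=0` returns 0 before the first `log_value` call, so the NaN-seeding block in `build()` can be dropped entirely.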
rllib/algorithms/cql/cql_learner.py (outdated)
    # We need to call the `super()`'s `build` method here to have the variables
    # for `alpha` and the target entropy defined.
    super().build()
I feel like it's better to add the `NEXT_OBS` LearnerConnector piece here, what do you think?
The `CQLLearner` is the component that has this (hard) requirement (meaning it won't work w/o this connector piece), so it should take care of adding it here.
Similar to how it would check whether the used `RLModule` fits a certain API, if required.
I agree that the Learner is the component that needs the connectors. On the other side, it uses the algorithm's `build_learner_connector` to build its connector pipeline. So, I am a bit unsure where we should generally pack it.
In this specific case I needed to add it in `build_learner_connector` b/c the class inheritance was somehow avoiding it otherwise (i.e. the `CQLLearner` had it, but the `TorchCQLLearner` did not).
@@ -320,7 +320,7 @@ def get_default_policy_class(
    @override(Algorithm)
    def training_step(self) -> ResultDict:
        if self.config.enable_env_runner_and_connector_v2:
-           return self._training_step_new_stack()
+           return self._training_step_new_api_stack()
Thanks for clarifying this! Always better to be expressive :)
@@ -278,7 +278,7 @@ def compute_loss_for_module(

    @override(DQNRainbowTorchLearner)
    def compute_gradients(
-       self, loss_per_module: Dict[str, TensorType], **kwargs
+       self, loss_per_module: Dict[ModuleID, TensorType], **kwargs
👍
@@ -442,7 +442,7 @@ def configure_optimizers_for_module(
    @OverrideToImplementCustomLogic
    @abc.abstractmethod
    def compute_gradients(
-       self, loss_per_module: Dict[str, TensorType], **kwargs
+       self, loss_per_module: Dict[ModuleID, TensorType], **kwargs
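The tightened hint replaces the bare `str` key type with the more expressive `ModuleID` alias. A minimal sketch (assuming, as in RLlib, that `ModuleID` is a `str`-based alias; `TensorType` is stood in for by `float` here for illustration):

```python
from typing import Dict

ModuleID = str      # alias identifying one RLModule inside a multi-module setup
TensorType = float  # stand-in for an actual framework tensor type

def compute_gradients(loss_per_module: Dict[ModuleID, TensorType], **kwargs):
    # Summing per-module losses is only a placeholder for real gradient logic;
    # the point is the signature: keys are module IDs, not arbitrary strings.
    return sum(loss_per_module.values())

total = compute_gradients({"default_policy": 1.5, "opponent": 0.5})
```

At runtime nothing changes (the alias is still a `str`), but type checkers and readers now see what the dict keys mean.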
👍
    stop = {
        f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -700.0,
If we change the test to "size=large", would the return get a bit higher?
I haven't tried this. Will do.
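The stop criterion keys into the nested result dict via a flattened, slash-separated path. A sketch of how that key is assembled (the literal constant values below are assumptions matching RLlib's result-dict layout, not copied from the source):

```python
# Assumed values of RLlib's result-dict constants; treat these as
# illustrative stand-ins, not authoritative definitions.
EVALUATION_RESULTS = "evaluation"
ENV_RUNNER_RESULTS = "env_runners"
EPISODE_RETURN_MEAN = "episode_return_mean"

# Stop once the mean evaluation episode return reaches -700.0. With
# size="large" (more training iterations), this threshold could likely be
# raised, i.e. made less negative, per the discussion above.
stop = {
    f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -700.0,
}
key = next(iter(stop))
```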
    @override(AlgorithmConfig)
    def build_learner_connector(
        self,
See my comment in `CQLLearner`. I think we should move this there. Same as for DQN. Or, if there are good arguments to leave these in the Config classes, we should also change it in DQN/SAC.
As mentioned above, in rare cases the class inheritance avoids it being changed there, and changing it in `build_learner_connector` would then still work.
You are right that we need to unify the way of adding them in all algorithms (also MARWIL).
It just doesn't feel clean, doing it here. The config classes should contain as little (actually implemented) algo logic as possible and just store settings.
I know I sound like a broken record, but "separation of concerns" principle :)
- Whatever an algo-specific Learner must have, it should create in its `build()` method, for example a lr-schedule, or a kl-coeff-schedule, etc.
- Whatever an algo-specific Learner must have, it should check that it has in its `build()` method, e.g. "does my RLModule implement API xyz?"
- Algo configs should NOT implement any algo logic. If they still do somewhere for some algos, we should add a TODO to move the logic into a more appropriate location.

Suggestion:
Can we try creating a `CQLLearner` class that inherits from `SACLearner` and overrides the `build()` method just to insert that connector piece? Then do `class CQLTorchLearner(SACTorchLearner, CQLLearner):`. That should work, no?
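The suggested mixin pattern can be sketched with minimal stand-in classes (the class names mirror the discussion, but none of this is RLlib code). The key is that Python's C3 linearization places `CQLLearner` in the MRO of `CQLTorchLearner`, so its `build()` runs even though `CQLTorchLearner` inherits from `SACTorchLearner` first:

```python
# Stand-in class hierarchy demonstrating the proposed mixin and its MRO.
class Learner:
    def build(self):
        self.connector_pieces = []

class SACLearner(Learner):
    pass

class SACTorchLearner(SACLearner):
    pass

class CQLLearner(SACLearner):
    def build(self):
        # Cooperative super() call walks the MRO, so base setup still runs.
        super().build()
        # Insert the (hard) requirement, e.g. a NEXT_OBS connector piece.
        self.connector_pieces.append("add_next_obs_to_batch")

class CQLTorchLearner(SACTorchLearner, CQLLearner):
    pass

learner = CQLTorchLearner()
learner.build()
# MRO: CQLTorchLearner -> SACTorchLearner -> CQLLearner -> SACLearner -> Learner
mro_names = [cls.__name__ for cls in CQLTorchLearner.__mro__]
```

One caveat of this design: it only works as long as any `build()` override along the MRO (here, a future `SACTorchLearner.build()`) keeps calling `super().build()` cooperatively.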
Why are these changes needed?
This PR proposes the training logic for CQL in our new API stack, using `RLModule`s, the Learner API, the Offline RL API, and the EnvRunner API. CQL on the new stack uses `OfflineData` to read and batch training data, and `RLModule`s together with the Learner API to define and train a policy. It inherits most of its model logic from SAC, as it implements the entropy version of CQL. To add `NEXT_OBS` to the batch, it overrides the `AlgorithmConfig`'s `build_learner_connector` and thereby proposes a new form of how the learner connector should be modified (in contrast to adding more connectors to the learner's pipeline in the learner's `build` method).
Furthermore, this PR adds a "tuned example" using `Pendulum-v1` that shows how CQL can be used on the new API stack.
This PR is part of a sequence of PRs proposed and coming.
Related issue number
Relates to #46969 and closes #37779
Checks
- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.