[RLlib; Offline RL] Implement CQL algorithm logic in new API stack. #47000

Merged: 41 commits, Aug 19, 2024
Changes from 16 commits

Commits (41)
360bde7
Added CQLLearner and CQLTorchLearner.
simonsays1980 Aug 2, 2024
1aeacea
Merge branch 'cql-learner' into cql-algorithm
simonsays1980 Aug 5, 2024
505a011
Added basic functionality for CQL in new API stack.
simonsays1980 Aug 5, 2024
4f02de1
LINTER.
simonsays1980 Aug 5, 2024
72acecf
Fixed two linter errors.
simonsays1980 Aug 5, 2024
d32b9a8
Added all functionalities to train CQL on the new API stack and teste…
simonsays1980 Aug 7, 2024
d79a8a7
Added tuned example for CQL to BUILD file and added data.
simonsays1980 Aug 7, 2024
03345c1
Merge branch 'master' into cql-learner
simonsays1980 Aug 7, 2024
ec109e1
Modified loss calculation, specifically Q-value sampling in the CQL l…
simonsays1980 Aug 7, 2024
e599d5d
Merge branch 'cql-learner' into cql-algorithm
simonsays1980 Aug 7, 2024
3d7e353
Moved optimizing from 'compute_loss_for_module' to 'compute_gradients'.
simonsays1980 Aug 12, 2024
af3ac60
Changed logging training iterations to a simpler logic proposed by @s…
simonsays1980 Aug 12, 2024
afc7098
Merge branch 'master' into cql-learner
simonsays1980 Aug 12, 2024
f162bbc
Merge branch 'master' into cql-algorithm
simonsays1980 Aug 12, 2024
88fe775
Merge branch 'cql-learner' into cql-algorithm
simonsays1980 Aug 12, 2024
1f197ce
Added '__init__.py' to 'cql/torch'.
simonsays1980 Aug 12, 2024
a22c2b8
Moved connector pipeline setup to 'build_learner_connector' from lear…
simonsays1980 Aug 12, 2024
0aadee6
Changed 'joinpath' to '/' in paths for examples of BC, MARWIL, CQL.
simonsays1980 Aug 12, 2024
8a824b0
Merged master.
simonsays1980 Aug 13, 2024
c77eab7
Merge branch 'master' into cql-algorithm
simonsays1980 Aug 13, 2024
38d8178
[RLlib; Offline RL] Implement twin-Q net option for CQL. (#47105)
simonsays1980 Aug 13, 2024
cc66a74
[core] remove unused GcsAio(Publisher|Subscriber) methods and subclas…
rynewang Aug 13, 2024
9227f8b
[Core] Fix a bug where we submit the actor creation task to the wrong…
jjyao Aug 13, 2024
f1eaddd
[doc][build] Update all changed files timestamp to latest (#47115)
khluu Aug 13, 2024
cc8bbbf
[serve] split `test_proxy.py` into unit and e2e tests (#47112)
zcin Aug 13, 2024
8824246
[Utility] add `env_float` utility into `ray._private.ray_constants` (…
hongpeng-guo Aug 13, 2024
07abc38
[Data] Fix progress bars not showing % progress (#47120)
scottjlee Aug 13, 2024
92b5b6b
[data] change data17 to datal (#47082)
aslonnie Aug 14, 2024
b63fe92
[ci] change data build for all python versions to arrow 17 (#47121)
can-anyscale Aug 14, 2024
f900ab3
Fixed a small bug that came with a change in validation.
simonsays1980 Aug 14, 2024
f119e8c
Merge branch 'master' into cql-algorithm
simonsays1980 Aug 15, 2024
805ac91
Merge branch 'master' into cql-algorithm
simonsays1980 Aug 15, 2024
ee21446
[core] make a note GetThreadContext() can't be used in fiber context …
hongchaodeng Aug 15, 2024
3d3bd24
docs: introduce quickstart button for image classification (#46715)
saihaj Aug 15, 2024
12c413d
[Data] Fix validation bug when size=0 in ActorPoolStrategy (#47072)
xingyu-long Aug 15, 2024
750b1ec
[Data] Fix exception in async map (#47110)
Bye-legumes Aug 15, 2024
8e0fd3d
[Data] Progress Bar: Sort sample in "rows" and remove the duplicate S…
Bye-legumes Aug 15, 2024
a7d4adf
Add iceberg datasource (#46889)
dev-goyal Aug 15, 2024
4a57617
[observability][export-api] Add event data schema for nodes (#47086)
nikitavemuri Aug 15, 2024
2bc0fd5
[autoscaler] fix import not used lint error on typing (#47163)
aslonnie Aug 16, 2024
7ff39ad
Merge branch 'master' into cql-algorithm
simonsays1980 Aug 16, 2024

13 changes: 13 additions & 0 deletions rllib/BUILD

@@ -338,6 +338,19 @@ py_test(
args = ["--dir=tuned_examples/cql"]
)

py_test(
name = "learning_tests_pendulum_cql",
main = "tuned_examples/cql/pendulum_cql.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_cartpole", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
size = "medium",

Contributor:
size="large" is better, no?

Collaborator Author:
Yeah, maybe. It takes around 50 iterations.

Collaborator Author:
Will tune this in another PR.


srcs = ["tuned_examples/cql/pendulum_cql.py"],
# Include the zipped json data file as well.
data = [
"tests/data/pendulum/pendulum-v1_enormous",
],
args = ["--as-test", "--enable-new-api-stack"]
)
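
For reference, the bazel target above mirrors running the tuned example script directly; a hypothetical invocation from the rllib/ directory (path and flags taken from the rule's main and args, not verified against this revision) would be:

python tuned_examples/cql/pendulum_cql.py --as-test --enable-new-api-stack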

# DQN
py_test(
name = "learning_tests_cartpole_dqn",

126 changes: 125 additions & 1 deletion rllib/algorithms/cql/cql.py

@@ -1,7 +1,14 @@
import logging
from typing import Optional, Type
from typing import Optional, Type, Union

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.algorithms.sac.sac import (
@@ -23,15 +30,23 @@
)
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.metrics import (
ALL_MODULES,
LEARNER_RESULTS,
LEARNER_UPDATE_TIMER,
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
NUM_MODULE_STEPS_TRAINED_LIFETIME,
NUM_TARGET_UPDATES,
OFFLINE_SAMPLING_TIMER,
TARGET_NET_UPDATE_TIMER,
SYNCH_WORKER_WEIGHTS_TIMER,
SAMPLE_TIMER,
TIMERS,
)
from ray.rllib.utils.typing import ResultDict

@@ -122,6 +137,38 @@ def training(

return self

@override(SACConfig)
def get_default_learner_class(self) -> Union[Type["Learner"], str]:
if self.framework_str == "torch":
from ray.rllib.algorithms.cql.torch.cql_torch_learner import CQLTorchLearner

return CQLTorchLearner
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use `'torch'` instead."
)

@override(AlgorithmConfig)
def build_learner_connector(
self,

Contributor:
See my comment in CQLLearner. I think we should move this there. Same as for DQN. Or, if there are good arguments to leave these in the Config classes, we should also change it in DQN/SAC.

Collaborator Author:
As mentioned above. In rare cases the class inheritance prevents the change from taking effect there, whereas changing it in build_learner_connector still works.

You are right that we need to unify the way of adding them in all algorithms (also MARWIL).

Contributor:
It just doesn't feel clean, doing it here. The config classes should contain as little (actually implemented) algo logic as possible and just store settings.

I know I sound like a broken record, but: "separation of concerns" principle :)

  • Whatever an algo-specific Learner must have, it should create in its build() method, for example an lr-schedule or a kl-coeff-schedule.
  • Whatever an algo-specific Learner must have, it should check for in its build() method, e.g. "does my RLModule implement API xyz?"
  • Algo configs should NOT implement any algo logic. If they still do somewhere for some algos, we should add a TODO to move the logic into a more appropriate location.

Suggestion: Can we try creating a CQLLearner class that inherits from SACLearner and overrides the build() method just to insert that connector piece? Then do class CQLTorchLearner(SACTorchLearner, CQLLearner):. That should work, no? (A hedged sketch of this idea follows the method below.)


input_observation_space,
input_action_space,
device=None,
):
pipeline = super().build_learner_connector(
input_observation_space=input_observation_space,
input_action_space=input_action_space,
device=device,
)

pipeline.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)

return pipeline
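
To make the suggestion above concrete, here is a minimal, hypothetical sketch of moving the connector insertion into a CQLLearner.build() override. It assumes the built pipeline is reachable on the Learner as self._learner_connector (as in other new-API-stack learners) and reuses the two connector pieces imported at the top of this file; it sketches the reviewer's idea, not the code merged in this PR.

from ray.rllib.algorithms.sac.sac_learner import SACLearner
from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
    AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import (  # noqa
    AddNextObservationsFromEpisodesToTrainBatch,
)


class CQLLearner(SACLearner):
    def build(self) -> None:
        # Set up alpha, target entropy, etc. first.
        super().build()
        # Insert the NEXT_OBS piece right after the observations piece, so
        # every train batch carries next observations for the TD target.
        self._learner_connector.insert_after(
            AddObservationsFromEpisodesToBatch,
            AddNextObservationsFromEpisodesToTrainBatch(),
        )


# The torch learner would pick up the connector insertion via the MRO.
class CQLTorchLearner(SACTorchLearner, CQLLearner):
    pass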

@override(SACConfig)
def validate(self) -> None:
# First check, whether old `timesteps_per_iteration` is used.
@@ -150,6 +197,12 @@ def validate(self) -> None:
)
try_import_tfp(error=True)

# For a local learner, default the number of iterations to 1. Note,
# this is needed because we have no iterators, but instead a single
# batch returned directly from the `OfflineData.sample` method.
if self.num_learners == 0 and not self.dataset_num_iters_per_learner:
self.dataset_num_iters_per_learner = 1
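
For illustration, an untested, hypothetical config along these lines would hit that fallback (data path borrowed from the BUILD target above):

from ray.rllib.algorithms.cql.cql import CQLConfig

config = (
    CQLConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .framework("torch")
    .environment("Pendulum-v1")
    .offline_data(input_="tests/data/pendulum/pendulum-v1_enormous")
    .learners(num_learners=0)  # Local learner only.
)
config.validate()
# With num_learners=0 and dataset_num_iters_per_learner left unset,
# validate() falls back to a single pass per sampled batch.
assert config.dataset_num_iters_per_learner == 1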


class CQL(SAC):
"""CQL (derived from SAC)."""
@@ -171,6 +224,77 @@ def get_default_policy_class(

@override(SAC)
def training_step(self) -> ResultDict:
if self.config.enable_env_runner_and_connector_v2:
return self._training_step_new_api_stack()
elif self.config.enable_rl_module_and_learner:
raise ValueError(
"Hybrid API stack is not supported. Either set "
"`enable_rl_module_and_learner=True` and "
"`enable_env_runner_and_connector_v2=True` or set both "
"attributed to `False`."
)
else:
return self._training_step_old_api_stack()

def _training_step_new_api_stack(self) -> ResultDict:

with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
# Sampling from offline data.
batch = self.offline_data.sample(
num_samples=self.config.train_batch_size_per_learner,
num_shards=self.config.num_learners,
return_iterator=True if self.config.num_learners > 1 else False,

Contributor:
super nit: return_iterator=self.config.num_learners > 1

Collaborator Author:
Great catch :)

)

with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
# Updating the policy.
# TODO (simon, sven): Check, if we should execute directly s.th. like
# update_from_iterator.
learner_results = self.learner_group.update_from_batch(
batch,
minibatch_size=self.config.train_batch_size_per_learner,
num_iters=self.config.dataset_num_iters_per_learner,
)

# Log training results.
self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
self.metrics.log_value(
NUM_ENV_STEPS_TRAINED_LIFETIME,
self.metrics.peek(
(LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
),
reduce="sum",
)
self.metrics.log_dict(
{
(LEARNER_RESULTS, mid, NUM_MODULE_STEPS_TRAINED_LIFETIME): (
stats[NUM_MODULE_STEPS_TRAINED]
)
for mid, stats in self.metrics.peek(LEARNER_RESULTS).items()
},
reduce="sum",
)

# Synchronize weights.
# The results contain the loss for each policy and, in addition,
# the total loss over all policies; this total-loss key has to
# be removed.
modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}

# Update weights - after learning on the local worker -
# on all remote workers. Note, we only have the local `EnvRunner`,
# but from this `EnvRunner` the evaluation `EnvRunner`s get updated.
with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
self.env_runner_group.sync_weights(
# Sync weights from learner_group to all EnvRunners.
from_worker_or_learner_group=self.learner_group,
policies=modules_to_update,
inference_only=True,
)

return self.metrics.reduce()

def _training_step_old_api_stack(self) -> ResultDict:
# Collect SampleBatches from sample workers.
with self._timers[SAMPLE_TIMER]:
train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group)

20 changes: 20 additions & 0 deletions rllib/algorithms/cql/cql_learner.py

@@ -0,0 +1,20 @@
from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.sac.sac_learner import SACLearner
from ray.rllib.core.learner.learner import Learner
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import ALL_MODULES


class CQLLearner(SACLearner):
@override(Learner)
def build(self) -> None:
# We need to call the `super()`'s `build` method here to have the variables
# for `alpha` and the target entropy defined.
super().build()


Contributor:
I feel like it's better to add the NEXT_OBS LearnerConnector piece here, what do you think?

The CQLLearner is the component that has this (hard) requirement (meaning it won't work w/o this connector piece), so it should take care of adding it here. Similar to how it would check whether the used RLModule fits a certain API, if required.

Collaborator Author:
I agree that the Learner is the component that needs the connectors. On the other side, it uses the algorithm's build_learner_connector to build its connector pipeline. So I am a bit unsure where we should generally pack it.

In this specific case I needed to add it in build_learner_connector b/c the class inheritance was somehow preventing it otherwise (i.e., the CQLLearner had it, but the TorchCQLLearner did not).


# Add a metric to keep track of training iterations to
# determine when switching the actor loss from behavior
# cloning to SAC.
self.metrics.log_value(
(ALL_MODULES, TRAINING_ITERATION), float("nan"), window=1
)

Contributor:
Suggested change:
# Add a metric to keep track of training iterations to
# determine when switching the actor loss from behavior
# cloning to SAC.
self.metrics.log_value(
(ALL_MODULES, TRAINING_ITERATION), float("nan"), window=1
)


Contributor:
This can all go, b/c we now use the default=0 arg.
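
A hedged sketch of what that comment refers to: instead of pre-logging a NaN value in build(), the loss code could read the counter back with a default, assuming MetricsLogger.peek() accepts a default argument and that config.bc_iters is the switch point:

# Hypothetical read inside the loss computation:
train_iter = self.metrics.peek((ALL_MODULES, TRAINING_ITERATION), default=0)
if train_iter < config.bc_iters:
    ...  # Behavior-cloning actor loss.
else:
    ...  # Regular SAC actor loss.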
