[RLlib; Offline RL] Implement CQL algorithm logic in new API stack. #47000
@@ -1,7 +1,14 @@
import logging
from typing import Optional, Type
from typing import Optional, Type, Union

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
    AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import (  # noqa
    AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.algorithms.sac.sac import (
@@ -23,15 +30,23 @@
)
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.metrics import (
    ALL_MODULES,
    LEARNER_RESULTS,
    LEARNER_UPDATE_TIMER,
    LAST_TARGET_UPDATE_TS,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_AGENT_STEPS_TRAINED,
    NUM_ENV_STEPS_SAMPLED,
    NUM_ENV_STEPS_TRAINED,
    NUM_ENV_STEPS_TRAINED_LIFETIME,
    NUM_MODULE_STEPS_TRAINED,
    NUM_MODULE_STEPS_TRAINED_LIFETIME,
    NUM_TARGET_UPDATES,
    OFFLINE_SAMPLING_TIMER,
    TARGET_NET_UPDATE_TIMER,
    SYNCH_WORKER_WEIGHTS_TIMER,
    SAMPLE_TIMER,
    TIMERS,
)
from ray.rllib.utils.typing import ResultDict
@@ -122,6 +137,38 @@ def training(

        return self

    @override(SACConfig)
    def get_default_learner_class(self) -> Union[Type["Learner"], str]:
        if self.framework_str == "torch":
            from ray.rllib.algorithms.cql.torch.cql_torch_learner import CQLTorchLearner

            return CQLTorchLearner
        else:
            raise ValueError(
                f"The framework {self.framework_str} is not supported. "
                "Use `'torch'` instead."
            )

    @override(AlgorithmConfig)
    def build_learner_connector(
        self,
        input_observation_space,
        input_action_space,
        device=None,
    ):
        pipeline = super().build_learner_connector(
            input_observation_space=input_observation_space,
            input_action_space=input_action_space,
            device=device,
        )

        pipeline.insert_after(
            AddObservationsFromEpisodesToBatch,
            AddNextObservationsFromEpisodesToTrainBatch(),
        )

        return pipeline

Review comment (on `build_learner_connector`): See my comment in CQLLearner. I think we should move this there. Same as for DQN. Or, if there are good arguments to leave these in the config classes, we should also change it in DQN/SAC.

Reply: As mentioned above. In rare cases the class inheritance avoids it being changed there and changing it in … You are right that we need to unify the way of adding them in all algorithms (also MARWIL).

Reply: It just doesn't feel clean, doing it here. The config classes should contain as little (actually implemented) algo logic as possible and just store settings. I know I sound like a broken record, but "separation of concerns" principle :)
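For reference, a rough sketch of the alternative the reviewers discuss (not part of this diff): letting the CQLLearner add the connector piece itself inside its `build()` method. It assumes the base Learner exposes the built pipeline as `self._learner_connector` after `super().build()`; treat it as an illustration, not the PR's implementation.

from ray.rllib.algorithms.sac.sac_learner import SACLearner
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
    AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import (  # noqa
    AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.utils.annotations import override


class CQLLearner(SACLearner):
    @override(Learner)
    def build(self) -> None:
        # Set up the SAC learner first (alpha, target entropy, etc.).
        super().build()

        # Add the NEXT_OBS piece right after the default observation piece, so
        # that the hard requirement lives with the component that depends on it.
        # Assumption: the Learner stores its pipeline as `self._learner_connector`.
        self._learner_connector.insert_after(
            AddObservationsFromEpisodesToBatch,
            AddNextObservationsFromEpisodesToTrainBatch(),
        )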
    @override(SACConfig)
    def validate(self) -> None:
        # First check, whether old `timesteps_per_iteration` is used.
@@ -150,6 +197,12 @@ def validate(self) -> None:

            )
            try_import_tfp(error=True)

        # Ensure that for a local learner the number of iterations is 1. Note,
        # this is needed because we have no iterators, but instead a single
        # batch returned directly from the `OfflineData.sample` method.
        if self.num_learners == 0 and not self.dataset_num_iters_per_learner:
            self.dataset_num_iters_per_learner = 1

class CQL(SAC):
    """CQL (derived from SAC)."""
@@ -171,6 +224,77 @@ def get_default_policy_class(

    @override(SAC)
    def training_step(self) -> ResultDict:
        if self.config.enable_env_runner_and_connector_v2:
            return self._training_step_new_api_stack()
        elif self.config.enable_rl_module_and_learner:
            raise ValueError(
                "Hybrid API stack is not supported. Either set "
                "`enable_rl_module_and_learner=True` and "
                "`enable_env_runner_and_connector_v2=True` or set both "
                "attributes to `False`."
            )
        else:
            return self._training_step_old_api_stack()

    def _training_step_new_api_stack(self) -> ResultDict:

        with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
            # Sampling from offline data.
            batch = self.offline_data.sample(
                num_samples=self.config.train_batch_size_per_learner,
                num_shards=self.config.num_learners,
                return_iterator=True if self.config.num_learners > 1 else False,
            )

Review comment (on the `return_iterator` argument): super nit: …

Reply: Great catch :)

        with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
            # Updating the policy.
            # TODO (simon, sven): Check, if we should execute directly s.th. like
            # update_from_iterator.
            learner_results = self.learner_group.update_from_batch(
                batch,
                minibatch_size=self.config.train_batch_size_per_learner,
                num_iters=self.config.dataset_num_iters_per_learner,
            )

            # Log training results.
            self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
            self.metrics.log_value(
                NUM_ENV_STEPS_TRAINED_LIFETIME,
                self.metrics.peek(
                    (LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
                ),
                reduce="sum",
            )
            self.metrics.log_dict(
                {
                    (LEARNER_RESULTS, mid, NUM_MODULE_STEPS_TRAINED_LIFETIME): (
                        stats[NUM_MODULE_STEPS_TRAINED]
                    )
                    for mid, stats in self.metrics.peek(LEARNER_RESULTS).items()
                },
                reduce="sum",
            )

        # Synchronize weights.
        # As the results contain the loss for each policy and, in addition, the
        # total loss over all policies, this total loss has to be removed.
        modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}

        # Update weights - after learning on the local worker -
        # on all remote workers. Note, we only have the local `EnvRunner`,
        # but from this `EnvRunner` the evaluation `EnvRunner`s get updated.
        with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
            self.env_runner_group.sync_weights(
                # Sync weights from learner_group to all EnvRunners.
                from_worker_or_learner_group=self.learner_group,
                policies=modules_to_update,
                inference_only=True,
            )

        return self.metrics.reduce()

    def _training_step_old_api_stack(self) -> ResultDict:
        # Collect SampleBatches from sample workers.
        with self._timers[SAMPLE_TIMER]:
            train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group)
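For orientation, a minimal usage sketch of how this new-stack CQL training path would be driven (not taken from this PR). The environment, data path, and numeric values are placeholders, and the exact builder-method keywords may differ from the final API.

from ray.rllib.algorithms.cql import CQLConfig

config = (
    CQLConfig()
    # New API stack: RLModule/Learner plus EnvRunner/ConnectorV2.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("Pendulum-v1")
    # Placeholder path; point this at recorded offline episodes.
    .offline_data(input_="/path/to/offline/data")
    .training(train_batch_size_per_learner=1024)
    # num_learners=0 -> a single local Learner; `validate()` above then
    # defaults `dataset_num_iters_per_learner` to 1.
    .learners(num_learners=0)
)

algo = config.build()
for _ in range(10):
    results = algo.train()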
New file (20 lines added), defining the CQLLearner class:
@@ -0,0 +1,20 @@
from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.sac.sac_learner import SACLearner
from ray.rllib.core.learner.learner import Learner
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import ALL_MODULES


class CQLLearner(SACLearner):
    @override(Learner)
    def build(self) -> None:
        # We need to call the `super()`'s `build` method here to have the variables
        # for `alpha` and the target entropy defined.
        super().build()

        # Add a metric to keep track of training iterations to
        # determine when switching the actor loss from behavior
        # cloning to SAC.
        self.metrics.log_value(
            (ALL_MODULES, TRAINING_ITERATION), float("nan"), window=1
        )

Review comment (on `super().build()`): I feel like it's better to add the NEXT_OBS LearnerConnector piece here, what do you think? The CQLLearner is the component that has this (hard) requirement (meaning it won't work w/o this connector piece), so it should take care of adding it here.

Reply: I agree that the Learner is the component that needs the connectors. On the other side it uses the … In this specific case I needed to add it in the …
Review comment (on the `log_value` call): This can all go, b/c we now use the …
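For context on why the TRAINING_ITERATION metric is logged: CQL warms up its actor with behavior cloning for a configurable number of iterations (`bc_iters`) before switching to the SAC-style actor loss. Below is only an illustrative, self-contained sketch of that switch; the function and its tensor arguments are hypothetical and not the loss code of this PR.

import torch


def cql_actor_loss_sketch(
    logps_resampled: torch.Tensor,  # log-probs of actions re-sampled from the policy
    q_resampled: torch.Tensor,  # Q-values for those re-sampled actions
    logps_batch_actions: torch.Tensor,  # log-probs of the offline dataset actions
    alpha: torch.Tensor,  # entropy temperature
    training_iteration: int,  # value tracked via (ALL_MODULES, TRAINING_ITERATION)
    bc_iters: int,  # number of initial behavior-cloning iterations
) -> torch.Tensor:
    if training_iteration < bc_iters:
        # Warm-up phase: behavior-clone the dataset actions.
        return torch.mean(alpha.detach() * logps_resampled - logps_batch_actions)
    # Afterwards: standard SAC-style actor loss.
    return torch.mean(alpha.detach() * logps_resampled - q_resampled)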
Review comment: size="large" is better, no?

Reply: Yeah, maybe. It takes around 50 iterations.

Reply: Will tune this in another PR.