Skip to content

Commit

Permalink
[RLlib] Move Learner Hp assignment to validate (ray-project#33392)
Browse files Browse the repository at this point in the history
* Move adding params to learner hps to validate in order to be compatible with rllib yaml files
* Move learner_hp assignment from builder functions to validate

Signed-off-by: Avnish <[email protected]>
Signed-off-by: Jack He <[email protected]>
  • Loading branch information
avnishn authored and ProjectsByJackHe committed Mar 21, 2023
1 parent 2d86091 commit 5dea598
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 25 deletions.
28 changes: 11 additions & 17 deletions rllib/algorithms/impala/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,6 @@ def __init__(self, algo_class=None):
self._tf_policy_handles_more_than_one_loss = True
# __sphinx_doc_end__
# fmt: on
self._learner_hps.discount_factor = self.gamma
self._learner_hps.entropy_coeff = self.entropy_coeff
self._learner_hps.vf_loss_coeff = self.vf_loss_coeff
self._learner_hps.vtrace_drop_last_ts = self.vtrace_drop_last_ts
self._learner_hps.vtrace_clip_rho_threshold = self.vtrace_clip_rho_threshold
self._learner_hps.vtrace_clip_pg_rho_threshold = (
self.vtrace_clip_pg_rho_threshold
)
self._learner_hps.rollout_frag_or_episode_len = self.rollout_fragment_length

# Deprecated value.
self.num_data_loader_buffers = DEPRECATED_VALUE
Expand Down Expand Up @@ -282,15 +273,10 @@ def training(
self.vtrace = vtrace
if vtrace_clip_rho_threshold is not NotProvided:
self.vtrace_clip_rho_threshold = vtrace_clip_rho_threshold
self._learner_hps.vtrace_clip_rho_threshold = vtrace_clip_rho_threshold
if vtrace_clip_pg_rho_threshold is not NotProvided:
self.vtrace_clip_pg_rho_threshold = vtrace_clip_pg_rho_threshold
self._learner_hps.vtrace_clip_pg_rho_threshold = (
vtrace_clip_pg_rho_threshold
)
if vtrace_drop_last_ts is not NotProvided:
self.vtrace_drop_last_ts = vtrace_drop_last_ts
self._learner_hps.vtrace_drop_last_ts = vtrace_drop_last_ts
if num_multi_gpu_tower_stacks is not NotProvided:
self.num_multi_gpu_tower_stacks = num_multi_gpu_tower_stacks
if minibatch_buffer_size is not NotProvided:
Expand Down Expand Up @@ -331,10 +317,8 @@ def training(
self.epsilon = epsilon
if vf_loss_coeff is not NotProvided:
self.vf_loss_coeff = vf_loss_coeff
self._learner_hps.vf_loss_coeff = vf_loss_coeff
if entropy_coeff is not NotProvided:
self.entropy_coeff = entropy_coeff
self._learner_hps.entropy_coeff = entropy_coeff
if entropy_coeff_schedule is not NotProvided:
self.entropy_coeff_schedule = entropy_coeff_schedule
if _separate_vf_optimizer is not NotProvided:
Expand All @@ -345,7 +329,6 @@ def training(
self.after_train_step = after_train_step
if gamma is not NotProvided:
self.gamma = gamma
self._learner_hps.discount_factor = self.gamma

return self

Expand Down Expand Up @@ -394,9 +377,20 @@ def validate(self) -> None:
"term/optimizer! Try setting config.training("
"_tf_policy_handles_more_than_one_loss=True)."
)
# learner hps need to be updated inside of config.validate in order to have
# the correct values for when a user starts an experiment from a dict. This is
# as opposed to assigning the values in the builder functions such as `training`
self._learner_hps.rollout_frag_or_episode_len = (
self.get_rollout_fragment_length()
)
self._learner_hps.discount_factor = self.gamma
self._learner_hps.entropy_coeff = self.entropy_coeff
self._learner_hps.vf_loss_coeff = self.vf_loss_coeff
self._learner_hps.vtrace_drop_last_ts = self.vtrace_drop_last_ts
self._learner_hps.vtrace_clip_rho_threshold = self.vtrace_clip_rho_threshold
self._learner_hps.vtrace_clip_pg_rho_threshold = (
self.vtrace_clip_pg_rho_threshold
)

@override(AlgorithmConfig)
def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfig:
Expand Down
19 changes: 11 additions & 8 deletions rllib/algorithms/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,12 @@ def training(
self.use_critic = use_critic
# TODO (Kourosh) This is experimental. Set learner_hps parameters as
# well. Don't forget to remove .use_critic from algorithm config.
self._learner_hps.use_critic = use_critic
if use_gae is not NotProvided:
self.use_gae = use_gae
if lambda_ is not NotProvided:
self.lambda_ = lambda_
if kl_coeff is not NotProvided:
self.kl_coeff = kl_coeff
self._learner_hps.kl_coeff = kl_coeff
if sgd_minibatch_size is not NotProvided:
self.sgd_minibatch_size = sgd_minibatch_size
if num_sgd_iter is not NotProvided:
Expand All @@ -244,24 +242,18 @@ def training(
self.shuffle_sequences = shuffle_sequences
if vf_loss_coeff is not NotProvided:
self.vf_loss_coeff = vf_loss_coeff
self._learner_hps.vf_loss_coeff = vf_loss_coeff
if entropy_coeff is not NotProvided:
self.entropy_coeff = entropy_coeff
self._learner_hps.entropy_coeff = entropy_coeff
if entropy_coeff_schedule is not NotProvided:
self.entropy_coeff_schedule = entropy_coeff_schedule
self._learner_hps.entropy_coeff_schedule = entropy_coeff_schedule
if clip_param is not NotProvided:
self.clip_param = clip_param
self._learner_hps.clip_param = clip_param
if vf_clip_param is not NotProvided:
self.vf_clip_param = vf_clip_param
self._learner_hps.vf_clip_param = vf_clip_param
if grad_clip is not NotProvided:
self.grad_clip = grad_clip
if kl_target is not NotProvided:
self.kl_target = kl_target
self._learner_hps.kl_target = kl_target

return self

Expand Down Expand Up @@ -304,6 +296,17 @@ def validate(self) -> None:
# Check `entropy_coeff` for correctness.
if self.entropy_coeff < 0.0:
raise ValueError("`entropy_coeff` must be >= 0.0")
# learner hps need to be updated inside of config.validate in order to have
# the correct values for when a user starts an experiment from a dict. This is
# as opposed to assigning the values in the builder functions such as `training`
self._learner_hps.use_critic = self.use_critic
self._learner_hps.kl_coeff = self.kl_coeff
self._learner_hps.vf_loss_coeff = self.vf_loss_coeff
self._learner_hps.entropy_coeff = self.entropy_coeff
self._learner_hps.entropy_coeff_schedule = self.entropy_coeff_schedule
self._learner_hps.clip_param = self.clip_param
self._learner_hps.vf_clip_param = self.vf_clip_param
self._learner_hps.kl_target = self.kl_target


class UpdateKL:
Expand Down

0 comments on commit 5dea598

Please sign in to comment.