From 7ceb761330cb3edaaf457b4ea53ea9fb66dd3611 Mon Sep 17 00:00:00 2001 From: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Date: Mon, 10 Apr 2023 11:31:28 -0700 Subject: [PATCH] [RLlib] Add a flag to allow disabling initialize_loss_from_dummy_batch logic. (#34208) Signed-off-by: Kourosh Hakhamaneshi --- rllib/algorithms/algorithm_config.py | 6 ++++++ rllib/policy/policy.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index af6c82e0e46a..2b721670f71b 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -427,6 +427,7 @@ def __init__(self, algo_class=None): self._disable_preprocessor_api = False self._disable_action_flattening = False self._disable_execution_plan_api = True + self._disable_initialize_loss_from_dummy_batch = False # Has this config object been frozen (cannot alter its attributes anymore). self._is_frozen = False @@ -2437,6 +2438,7 @@ def experimental( _disable_preprocessor_api: Optional[bool] = NotProvided, _disable_action_flattening: Optional[bool] = NotProvided, _disable_execution_plan_api: Optional[bool] = NotProvided, + _disable_initialize_loss_from_dummy_batch: Optional[bool] = NotProvided, ) -> "AlgorithmConfig": """Sets the config's experimental settings. 
@@ -2476,6 +2478,10 @@ def experimental( self._disable_action_flattening = _disable_action_flattening if _disable_execution_plan_api is not NotProvided: self._disable_execution_plan_api = _disable_execution_plan_api + if _disable_initialize_loss_from_dummy_batch is not NotProvided: + self._disable_initialize_loss_from_dummy_batch = ( + _disable_initialize_loss_from_dummy_batch + ) return self diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 2fa9e46e0863..bb534497e3d2 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -1383,6 +1383,9 @@ def _initialize_loss_from_dummy_batch( TensorType]]]): An optional stats function to be called after the loss. """ + + if self.config.get("_disable_initialize_loss_from_dummy_batch", False): + return # Signal Policy that currently we do not like to eager/jit trace # any function calls. This is to be able to track, which columns # in the dummy batch are accessed by the different function (e.g.