Skip to content

Commit

Permalink
[RLlib] Add a flag to allow disabling initialize_loss_from_dummy_batc…
Browse files Browse the repository at this point in the history
…h logit. (ray-project#34208)

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: elliottower <[email protected]>
  • Loading branch information
kouroshHakha authored and elliottower committed Apr 22, 2023
1 parent 07cb238 commit 746578e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
6 changes: 6 additions & 0 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def __init__(self, algo_class=None):
self._disable_preprocessor_api = False
self._disable_action_flattening = False
self._disable_execution_plan_api = True
self._disable_initialize_loss_from_dummy_batch = False

# Has this config object been frozen (cannot alter its attributes anymore).
self._is_frozen = False
Expand Down Expand Up @@ -2437,6 +2438,7 @@ def experimental(
_disable_preprocessor_api: Optional[bool] = NotProvided,
_disable_action_flattening: Optional[bool] = NotProvided,
_disable_execution_plan_api: Optional[bool] = NotProvided,
_disable_initialize_loss_from_dummy_batch: Optional[bool] = NotProvided,
) -> "AlgorithmConfig":
"""Sets the config's experimental settings.
Expand Down Expand Up @@ -2476,6 +2478,10 @@ def experimental(
self._disable_action_flattening = _disable_action_flattening
if _disable_execution_plan_api is not NotProvided:
self._disable_execution_plan_api = _disable_execution_plan_api
if _disable_initialize_loss_from_dummy_batch is not NotProvided:
self._disable_initialize_loss_from_dummy_batch = (
_disable_initialize_loss_from_dummy_batch
)

return self

Expand Down
3 changes: 3 additions & 0 deletions rllib/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,9 @@ def _initialize_loss_from_dummy_batch(
TensorType]]]): An optional stats function to be called after
the loss.
"""

if self.config.get("_disable_initialize_loss_from_dummy_batch", False):
return
# Signal Policy that currently we do not like to eager/jit trace
# any function calls. This is to be able to track, which columns
# in the dummy batch are accessed by the different function (e.g.
Expand Down

0 comments on commit 746578e

Please sign in to comment.