diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 0e920692b669..4493f3eb7028 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -728,7 +728,10 @@ def _initialize_loss_from_dummy_batch( logger.info("Adding extra-action-fetch `{}` to view-reqs.".format(key)) self.view_requirements[key] = ViewRequirement( space=gym.spaces.Box( - -1.0, 1.0, shape=value.shape[1:], dtype=value.dtype.name + -1.0, + 1.0, + shape=value.shape.as_list()[1:], + dtype=value.dtype.name, ), used_for_compute_actions=False, ) diff --git a/rllib/policy/dynamic_tf_policy_v2.py b/rllib/policy/dynamic_tf_policy_v2.py index 0aecf006dbf6..6ff080415adf 100644 --- a/rllib/policy/dynamic_tf_policy_v2.py +++ b/rllib/policy/dynamic_tf_policy_v2.py @@ -712,7 +712,10 @@ def _initialize_loss_from_dummy_batch( logger.info("Adding extra-action-fetch `{}` to view-reqs.".format(key)) self.view_requirements[key] = ViewRequirement( space=gym.spaces.Box( - -1.0, 1.0, shape=value.shape[1:], dtype=value.dtype.name + -1.0, + 1.0, + shape=value.shape.as_list()[1:], + dtype=value.dtype.name, ), used_for_compute_actions=False, )