diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 5bb787a998f7..c071be0425c5 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -439,8 +439,18 @@ def _unpack_obs(obs: TensorType, space: Space, tensorlib: Any = tf) -> TensorStr ) offset = 0 if tensorlib == tf: - batch_dims = [v if isinstance(v, int) else v.value for v in obs.shape[:-1]] - batch_dims = [-1 if v is None else v for v in batch_dims] + + def get_value(v): + if v is None: + return -1 + elif isinstance(v, int): + return v + elif v.value is None: + return -1 + else: + return v.value + + batch_dims = [get_value(v) for v in obs.shape[:-1]] else: batch_dims = list(obs.shape[:-1]) if isinstance(space, gym.spaces.Tuple):