From 3dadc7411e33688b4d2918c212258539ba92eeb1 Mon Sep 17 00:00:00 2001
From: Olaf Lipinski <5785856+olipinski@users.noreply.github.com>
Date: Fri, 9 Sep 2022 14:51:49 +0100
Subject: [PATCH] [RLlib] Fix OneHotPreprocessor, use gym.spaces.utils.flatten. (#27540)

---
 rllib/models/preprocessors.py            |  8 +---
 rllib/models/tests/test_preprocessors.py | 56 ++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 7 deletions(-)

diff --git a/rllib/models/preprocessors.py b/rllib/models/preprocessors.py
index 5398a342a55b..f30c1d560bb3 100644
--- a/rllib/models/preprocessors.py
+++ b/rllib/models/preprocessors.py
@@ -186,13 +186,7 @@ def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
     @override(Preprocessor)
     def transform(self, observation: TensorType) -> np.ndarray:
         self.check_shape(observation)
-        arr = np.zeros(self._init_shape(self._obs_space, {}), dtype=np.float32)
-        if isinstance(self._obs_space, gym.spaces.Discrete):
-            arr[observation] = 1
-        else:
-            for i, o in enumerate(observation):
-                arr[np.sum(self._obs_space.nvec[:i]) + o] = 1
-        return arr
+        return gym.spaces.utils.flatten(self._obs_space, observation).astype(np.float32)
 
     @override(Preprocessor)
     def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
diff --git a/rllib/models/tests/test_preprocessors.py b/rllib/models/tests/test_preprocessors.py
index ca2f869665e6..1d97f00aaeab 100644
--- a/rllib/models/tests/test_preprocessors.py
+++ b/rllib/models/tests/test_preprocessors.py
@@ -154,6 +154,62 @@ def test_nested_multidiscrete_one_hot_preprocessor(self):
             [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
         )
 
+    def test_multidimensional_multidiscrete_one_hot_preprocessor(self):
+        space2d = MultiDiscrete([[2, 2], [3, 3]])
+        space3d = MultiDiscrete([[[2, 2], [3, 4]], [[5, 6], [7, 8]]])
+        pp2d = get_preprocessor(space2d)(space2d)
+        pp3d = get_preprocessor(space3d)(space3d)
+        self.assertTrue(isinstance(pp2d, OneHotPreprocessor))
+        self.assertTrue(isinstance(pp3d, OneHotPreprocessor))
+        self.assertTrue(pp2d.shape == (10,))
+        self.assertTrue(pp3d.shape == (37,))
+        check(
+            pp2d.transform(np.array([[1, 0], [2, 1]])),
+            [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
+        )
+        check(
+            pp3d.transform(np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])),
+            [
+                1.0,
+                0.0,
+                0.0,
+                1.0,
+                0.0,
+                0.0,
+                1.0,
+                0.0,
+                0.0,
+                0.0,
+                1.0,
+                0.0,
+                0.0,
+                0.0,
+                0.0,
+                1.0,
+                0.0,
+                0.0,
+                0.0,
+                0.0,
+                0.0,
+                1.0,
+                0.0,
+                0.0,
+                0.0,
+                0.0,
+                0.0,
+                0.0,
+                1.0,
+                0.0,
+                0.0,
+                0.0,
+                0.0,
+                0.0,
+                0.0,
+                0.0,
+                1.0,
+            ],
+        )
+
 
 if __name__ == "__main__":
     import pytest