diff --git a/src/garage/_environment.py b/src/garage/_environment.py index 78bf7b0f61..0425610e64 100644 --- a/src/garage/_environment.py +++ b/src/garage/_environment.py @@ -159,6 +159,8 @@ class Environment(abc.ABC): +-----------------------+ | visualize() | +-----------------------+ + | seed() | + +-----------------------+ | close() | +-----------------------+ @@ -350,6 +352,16 @@ def _validate_render_mode(self, mode): 'got render mode {} instead.'.format( self.render_modes, mode)) + @abc.abstractmethod + def seed(self, seed): + """Sets environment seeds. + + This method should set all seeds specific to the environment library. + + Args: + seed (int): The seed value to set + """ + def __del__(self): """Environment destructor.""" self.close() @@ -452,6 +464,14 @@ def visualize(self): """Creates a visualization of the wrapped environment.""" self._env.visualize() + def seed(self, seed): + """Sets all environment seeds. + + Args: + seed (int): The seed value to set + """ + self._env.seed() + def close(self): """Close the wrapped env.""" self._env.close() diff --git a/src/garage/envs/dm_control/dm_control_env.py b/src/garage/envs/dm_control/dm_control_env.py index 5c47b4634f..69764f5968 100644 --- a/src/garage/envs/dm_control/dm_control_env.py +++ b/src/garage/envs/dm_control/dm_control_env.py @@ -184,6 +184,16 @@ def visualize(self): self._viewer = DmControlViewer(title=title) self._viewer.launch(self._env) + def seed(self, seed): + """Sets all environment seeds. + + Args: + seed (int): The seed value to set + """ + # pylint: disable=protected-access + self._env._task._random = np.random.RandomState(seed) + self.action_space.seed(seed) + def close(self): """Close the environment.""" if self._viewer: diff --git a/src/garage/envs/grid_world_env.py b/src/garage/envs/grid_world_env.py index af624e7b76..40d8f12559 100644 --- a/src/garage/envs/grid_world_env.py +++ b/src/garage/envs/grid_world_env.py @@ -184,6 +184,13 @@ def render(self, mode): def visualize(self): """Creates a visualization of the environment.""" + def seed(self, seed): + """Sets all environment seeds. + + Args: + seed (int): The seed value to set + """ + def close(self): """Close the env.""" diff --git a/src/garage/envs/gym_env.py b/src/garage/envs/gym_env.py index 321fe0ecaa..2940aee3e7 100644 --- a/src/garage/envs/gym_env.py +++ b/src/garage/envs/gym_env.py @@ -288,6 +288,15 @@ def visualize(self): self._env.render(mode='human') self._visualize = True + def seed(self, seed): + """Sets all environment seeds. + + Args: + seed (int): The seed value to set + """ + self._env.seed(seed) + self.action_space.seed(seed) + def close(self): """Close the wrapped env.""" self._close_viewer_window() diff --git a/src/garage/envs/metaworld_set_task_env.py b/src/garage/envs/metaworld_set_task_env.py index e5d5ad5fe1..5d9f6be1c8 100644 --- a/src/garage/envs/metaworld_set_task_env.py +++ b/src/garage/envs/metaworld_set_task_env.py @@ -251,6 +251,13 @@ def visualize(self): """Creates a visualization of the wrapped environment.""" self._current_env.visualize() + def seed(self, seed): + """Sets all environment seeds. + + Args: + seed (int): The seed value to set + """ + def close(self): """Close the wrapped env.""" for env in self._envs.values(): diff --git a/src/garage/envs/point_env.py b/src/garage/envs/point_env.py index 8d4f2fc0ad..47262db4fe 100644 --- a/src/garage/envs/point_env.py +++ b/src/garage/envs/point_env.py @@ -182,6 +182,14 @@ def visualize(self): def close(self): """Close the env.""" + def seed(self, seed): + """Sets all environment seeds. + + Args: + seed (int): The seed value to set + + """ + # pylint: disable=no-self-use def sample_tasks(self, num_tasks): """Sample a list of `num_tasks` tasks. diff --git a/src/garage/sampler/_functions.py b/src/garage/sampler/_functions.py index ac6545b063..d8140859f0 100644 --- a/src/garage/sampler/_functions.py +++ b/src/garage/sampler/_functions.py @@ -1,5 +1,6 @@ """Functions used by multiple Samplers or Workers.""" from garage import Environment +from garage.experiment import deterministic from garage.sampler.env_update import EnvUpdate @@ -33,6 +34,7 @@ def _apply_env_update(old_env, env_update): elif isinstance(env_update, Environment): if old_env is not None: old_env.close() + env_update.seed(deterministic.get_seed()) return env_update, True else: raise TypeError('Unknown environment update type.') diff --git a/tests/garage/sampler/test_local_sampler.py b/tests/garage/sampler/test_local_sampler.py index 653fa46cfa..d5ed17d502 100644 --- a/tests/garage/sampler/test_local_sampler.py +++ b/tests/garage/sampler/test_local_sampler.py @@ -1,7 +1,8 @@ import numpy as np import pytest -from garage.envs import PointEnv +from garage.envs import GymEnv, PointEnv +from garage.envs.dm_control import DMControlEnv from garage.experiment.task_sampler import SetTaskSampler from garage.np.policies import FixedPolicy, ScriptedPolicy from garage.sampler import LocalSampler, WorkerFactory @@ -103,3 +104,69 @@ def test_no_seed(): sampler = LocalSampler.from_worker_factory(workers, policy, env) episodes = sampler.obtain_samples(0, 160, policy) assert sum(episodes.lengths) >= 160 + + +def test_deterministic_on_policy_sampling_gym_env(): + max_episode_length = 10 + env1 = GymEnv('LunarLander-v2') + env2 = GymEnv('LunarLander-v2') + # Fix the action sequence + env1.action_space.seed(10) + env2.action_space.seed(10) + policy1 = FixedPolicy(env1.spec, + scripted_actions=[ + env1.action_space.sample() + for _ in range(max_episode_length) + ]) + policy2 = FixedPolicy(env2.spec, + scripted_actions=[ + env2.action_space.sample() + for _ in range(max_episode_length) + ]) + n_workers = 2 + worker1 = WorkerFactory(seed=10, + max_episode_length=max_episode_length, + n_workers=n_workers) + worker2 = WorkerFactory(seed=10, + max_episode_length=max_episode_length, + n_workers=n_workers) + sampler1 = LocalSampler.from_worker_factory(worker1, policy1, env1) + sampler2 = LocalSampler.from_worker_factory(worker2, policy2, env2) + episodes1 = sampler1.obtain_samples(0, 10, policy1) + episodes2 = sampler2.obtain_samples(0, 10, policy2) + assert np.array_equal(episodes1.observations, episodes2.observations) + assert np.array_equal(episodes1.next_observations, + episodes2.next_observations) + + +def test_deterministic_on_policy_sampling_dm_env(): + max_episode_length = 10 + env1 = DMControlEnv.from_suite('cartpole', 'balance') + env2 = DMControlEnv.from_suite('cartpole', 'balance') + # Fix the action sequence + env1.action_space.seed(10) + env2.action_space.seed(10) + policy1 = FixedPolicy(env1.spec, + scripted_actions=[ + env1.action_space.sample() + for _ in range(max_episode_length) + ]) + policy2 = FixedPolicy(env2.spec, + scripted_actions=[ + env2.action_space.sample() + for _ in range(max_episode_length) + ]) + n_workers = 2 + worker1 = WorkerFactory(seed=10, + max_episode_length=max_episode_length, + n_workers=n_workers) + worker2 = WorkerFactory(seed=10, + max_episode_length=max_episode_length, + n_workers=n_workers) + sampler1 = LocalSampler.from_worker_factory(worker1, policy1, env1) + sampler2 = LocalSampler.from_worker_factory(worker2, policy2, env2) + episodes1 = sampler1.obtain_samples(0, 10, policy1) + episodes2 = sampler2.obtain_samples(0, 10, policy2) + assert np.array_equal(episodes1.observations, episodes2.observations) + assert np.array_equal(episodes1.next_observations, + episodes2.next_observations)