
Commit

[RLlib] Issue 17900: Set seed in single vectorized sub-envs properly, if `num_envs_per_worker > 1` (#18110)

* In case a worker runs multiple envs, make sure a different, deterministic seed can be set on each of them (see the sketch below).

* Revert a couple of whitespace changes.

* Fix a few style errors.

Co-authored-by: Jun Gong <[email protected]>
gjoliver and Jun Gong authored Aug 26, 2021
1 parent edac59f commit a881367
Showing 5 changed files with 112 additions and 47 deletions.
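
At the core of the change is a deterministic per-sub-env seed derivation in rollout_worker.py (first diff below). A minimal standalone sketch of that derivation (the helper name here is illustrative, not part of the commit; the stride constant mirrors `max_num_envs_per_worker` in the diff):

def compute_sub_env_seed(seed: int, worker_idx: int, vector_idx: int) -> int:
    # Seeds of consecutive workers are spaced 1000 apart, so they can
    # only collide if one worker vectorizes more than 1000 sub-envs.
    max_num_envs_per_worker = 1000
    return worker_idx * max_num_envs_per_worker + vector_idx + seed

# Worker 0, three sub-envs, seed=1 -> [1, 2, 3] (cf. the new unit test).
print([compute_sub_env_seed(1, 0, i) for i in range(3)])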
122 changes: 81 additions & 41 deletions rllib/evaluation/rollout_worker.py
@@ -74,6 +74,72 @@ def get_global_worker() -> "RolloutWorker":
    return _global_worker


def _update_global_seed(policy_config: TrainerConfigDict, seed: int):
    """Seed Python's random module, numpy, and torch/tf.

    Useful for debugging and testing. The envs themselves are seeded
    separately (see `_update_env_seed` below).
    """
    if not seed:
        return

    # Python random module.
    random.seed(seed)
    # Numpy.
    np.random.seed(seed)

    # Torch.
    if torch and policy_config.get("framework") == "torch":
        torch.manual_seed(seed)
        # See https://github.com/pytorch/pytorch/issues/47672.
        cuda_version = torch.version.cuda
        if cuda_version is not None and float(cuda_version) >= 10.2:
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8"
        else:
            from distutils.version import LooseVersion

            if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"):
                # Not all operations support deterministic mode.
                torch.use_deterministic_algorithms(True)
            else:
                torch.set_deterministic(True)
        # Make cuDNN convolutions deterministic as well.
        torch.backends.cudnn.deterministic = True
    # Tf2.x.
    elif tf and policy_config.get("framework") == "tf2":
        tf.random.set_seed(seed)
    # Tf-eager.
    elif tf1 and policy_config.get("framework") == "tfe":
        tf1.set_random_seed(seed)


def _update_env_seed(env: EnvType, seed: int, worker_idx: int,
                     vector_idx: int):
    """Set a deterministic random seed on the given environment.

    TODO: Do remote envs have a seed() function?
    TODO: If a gym env is wrapped in a gym.wrappers.Monitor,
        is there still a seed() function?
    """
    if not seed:
        return

    # Seeds of consecutive workers are spaced 1000 apart, so they can
    # only collide if a single worker vectorizes more than 1000 envs.
    max_num_envs_per_worker: int = 1000
    assert vector_idx < max_num_envs_per_worker, \
        "Too many envs per worker. Random seeds may collide."
    computed_seed: int = (
        worker_idx * max_num_envs_per_worker + vector_idx + seed)

    # Gym env.
    # This will silently fail for most OpenAI gyms
    # (they do nothing and return None by default).
    if not hasattr(env, "seed"):
        logger.info("Env doesn't support env.seed(): {}".format(env))
    else:
        env.seed(computed_seed)

@DeveloperAPI
class RolloutWorker(ParallelIteratorWorker):
"""Common experience collection class.
Expand Down Expand Up @@ -306,6 +372,7 @@ class to use.
gym.spaces.Space]]]): An optional space dict mapping policy IDs
to (obs_space, action_space)-tuples. This is used in case no
Env is created on this RolloutWorker.
policy: Obsoleted arg. Use `policy_spec` instead.
monitor_path: Obsoleted arg. Use `record_env` instead.
"""

Expand Down Expand Up @@ -394,6 +461,8 @@ def gen_rollouts():

self.env = None

_update_global_seed(policy_config, seed)

        # Create an env for this worker.
        if not (worker_index == 0 and num_workers > 0
                and not policy_config.get("create_env_on_driver")):
Expand Down Expand Up @@ -480,14 +549,25 @@ def wrap(env):

# Wrap env through the correct wrapper.
self.env: EnvType = wrap(self.env)
# Ideally, we would use the same make_env() function below
# to create self.env, but wrap(env) and self.env has a cyclic
# dependency on each other right now, so we would settle on
# duplicating the random seed setting logic for now.
_update_env_seed(self.env, seed, worker_index, 0)

def make_env(vector_index):
return wrap(
env = wrap(
env_creator(
env_context.copy_with_overrides(
worker_index=worker_index,
vector_index=vector_index,
remote=remote_worker_envs)))
# make_env() is used to created additional environments
# during environment vectorization below.
# So we make sure a deterministic random seed is set on
# all the environments if specified.
_update_env_seed(env, seed, worker_index, vector_index)
return env

self.make_env_fn = make_env
self.spaces = spaces
@@ -507,46 +587,6 @@ def make_env(vector_index):
        self.policy_map: PolicyMap = None
        self.preprocessors: Dict[PolicyID, Preprocessor] = None

        # Set Python random, numpy, env, and torch/tf seeds.
        if seed is not None:
            # Python random module.
            random.seed(seed)
            # Numpy.
            np.random.seed(seed)
            # Gym.env.
            # This will silently fail for most OpenAI gyms
            # (they do nothing and return None per default)
            if not hasattr(self.env, "seed"):
                logger.info("Env doesn't support env.seed(): {}".format(
                    self.env))
            else:
                self.env.seed(seed)

            # Torch.
            if torch and policy_config.get("framework") == "torch":
                torch.manual_seed(seed)
                # See https://github.com/pytorch/pytorch/issues/47672.
                cuda_version = torch.version.cuda
                if cuda_version is not None and float(
                        torch.version.cuda) >= 10.2:
                    os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8"
                else:
                    from distutils.version import LooseVersion
                    if LooseVersion(
                            torch.__version__) >= LooseVersion("1.8.0"):
                        # Not all Operations support this.
                        torch.use_deterministic_algorithms(True)
                    else:
                        torch.set_deterministic(True)
                # This is only for Convolution no problem.
                torch.backends.cudnn.deterministic = True
            # Tf2.x.
            elif tf and policy_config.get("framework") == "tf2":
                tf.random.set_seed(seed)
            # Tf-eager.
            elif tf1 and policy_config.get("framework") == "tfe":
                tf1.set_random_seed(seed)

        # Check available number of GPUs.
        num_gpus = policy_config.get("num_gpus", 0) if \
            self.worker_index == 0 else \
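For context, the torch branch of _update_global_seed above relies on re-seeding restoring the generator state exactly. A quick standalone sanity check (a sketch, assuming torch >= 1.8):

import torch

torch.manual_seed(42)
a = torch.randn(4)
torch.manual_seed(42)
b = torch.randn(4)
# Re-seeding resets the generator, so the two draws match exactly.
assert torch.equal(a, b)

# Opt in to deterministic kernels; ops without a deterministic
# implementation will raise instead of silently diverging.
torch.use_deterministic_algorithms(True)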
11 changes: 11 additions & 0 deletions rllib/evaluation/tests/test_rollout_worker.py
@@ -670,6 +670,17 @@ def test_no_env_seed(self):
        assert not hasattr(ev.env, "seed")
        ev.stop()

    def test_multi_env_seed(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv2(100),
            num_envs=3,
            policy_spec=MockPolicy,
            seed=1)
        seeds = ev.foreach_env(lambda env: env.rng_seed)
        # Make sure all sub-envs get different, deterministic seeds:
        # worker_index is 0 here, so the computed seeds are
        # 0 * 1000 + vector_idx + 1 = [1, 2, 3].
        assert seeds == [1, 2, 3]
        ev.stop()

    def sample_and_flush(self, ev):
        time.sleep(2)
        ev.sample()
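A variation on test_multi_env_seed above (a sketch, assuming the same MockEnv2/MockPolicy helpers): a non-zero worker_index shifts all sub-env seeds by the 1000-per-worker stride.

ev = RolloutWorker(
    env_creator=lambda _: MockEnv2(100),
    num_envs=3,
    policy_spec=MockPolicy,
    worker_index=2,
    seed=1)
# Computed seeds: 2 * 1000 + vector_idx + 1.
assert ev.foreach_env(lambda env: env.rng_seed) == [2001, 2002, 2003]
ev.stop()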
2 changes: 2 additions & 0 deletions rllib/examples/deterministic_training.py
@@ -37,6 +37,8 @@
    # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
    "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
    "num_workers": 2,  # parallelism
    # Make sure every environment gets a fixed seed.
    "num_envs_per_worker": 2,
    "framework": args.framework,
    "seed": args.seed,
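The example's intent, sketched end to end (the trainer, env, and config values here are assumptions for illustration, not part of the diff): two runs of the same experiment with one fixed seed should report identical metrics.

import ray
from ray import tune

config = {
    "env": "CartPole-v0",
    "num_workers": 1,
    "num_envs_per_worker": 2,  # Exercise the new per-sub-env seeding.
    "framework": "tf",
    "seed": 42,
}
stop = {"training_iteration": 2}

ray.init()
run1 = tune.run("PG", config=config, stop=stop).trials[0].last_result
run2 = tune.run("PG", config=config, stop=stop).trials[0].last_result
# With global and per-env seeding in place, both runs should match.
assert run1["episode_reward_mean"] == run2["episode_reward_mean"]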
20 changes: 14 additions & 6 deletions rllib/examples/env/env_using_remote_actor.py
@@ -27,12 +27,17 @@ def __init__(self, env_config):
        # Get our param server (remote actor) by name.
        self._handler = ray.get_actor(
            env_config.get("param_server", "param-server"))
        self.rng_seed = None

    # def seed(self, seed=None):
    #     print(f"Seeding env (worker={self.env_config.worker_index}) "
    #           f"with {seed}")
    #     self.np_random, seed = seeding.np_random(seed)
    #     return [seed]
    def seed(self, rng_seed: int = None):
        if not rng_seed:
            return

        print(f"Seeding env (worker={self.env_config.worker_index}) "
              f"with {rng_seed}")

        self.rng_seed = rng_seed
        self.np_random, _ = seeding.np_random(rng_seed)

    def reset(self):
        # Pass in our RNG to guarantee no race conditions.
@@ -43,7 +48,10 @@ def reset(self):
        # IMPORTANT: Advance the state of our RNG (self._rng was passed
        # above via ray (serialized) and thus not altered locally here!).
        # Or create a new RNG from another random number:
        # Seed the RNG with the deterministic seed if one was set;
        # otherwise, draw a new random one.
        new_seed = (self.np_random.randint(0, 1000000)
                    if not self.rng_seed else self.rng_seed)
        self.np_random, _ = seeding.np_random(new_seed)

        print(f"Env worker-idx={self.env_config.worker_index} "
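For reference, seeding.np_random as used above returns a freshly seeded RNG plus the seed actually used (a sketch against gym's pre-0.21 API, where the returned RNG is a np.random.RandomState):

from gym.utils import seeding

rng_a, used_seed = seeding.np_random(7)
rng_b, _ = seeding.np_random(7)
assert used_seed == 7
# Same seed -> identical random streams.
assert rng_a.randint(0, 1000000) == rng_b.randint(0, 1000000)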
4 changes: 4 additions & 0 deletions rllib/examples/env/mock_env.py
@@ -39,6 +39,7 @@ def __init__(self, episode_length):
        self.i = 0
        self.observation_space = gym.spaces.Discrete(100)
        self.action_space = gym.spaces.Discrete(2)
        self.rng_seed = None

    def reset(self):
        self.i = 0
@@ -48,6 +49,9 @@ def step(self, action):
        self.i += 1
        return self.i, 100.0, self.i >= self.episode_length, {}

    def seed(self, rng_seed):
        self.rng_seed = rng_seed


class VectorizedMockEnv(VectorEnv):
    """Vectorized version of the MockEnv.
