[RLlib] Fix AlphaStar for tf2+tracing; smaller cleanups to avoid wrapping a TFPolicy with as_eager() or with_tracing() more than once. #24271

Merged · 7 commits · Apr 28, 2022
13 changes: 11 additions & 2 deletions rllib/agents/alpha_star/distributed_learners.py
@@ -6,6 +6,7 @@
from ray.rllib.agents.trainer import Trainer
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
from ray.rllib.utils.typing import PolicyID, TrainerConfigDict


@@ -173,6 +174,10 @@ def add_policy(self, policy_id: PolicyID, policy_spec: PolicySpec):

assert len(self.policy_actors) < self.max_num_policies

actual_policy_class = get_tf_eager_cls_if_necessary(
policy_spec.policy_class, cfg
)

colocated = create_colocated_actors(
actor_specs=[
(
@@ -181,7 +186,7 @@ def add_policy(self, policy_id: PolicyID, policy_spec: PolicySpec):
num_gpus=self.num_gpus_per_policy
if not cfg["_fake_gpus"]
else 0,
)(policy_spec.policy_class),
)(actual_policy_class),
# Policy c'tor args.
(policy_spec.observation_space, policy_spec.action_space, cfg),
# Policy c'tor kwargs={}.
@@ -207,6 +212,10 @@ def _add_replay_buffer_and_policy(
assert self.replay_actor is None
assert len(self.policy_actors) == 0

actual_policy_class = get_tf_eager_cls_if_necessary(
policy_spec.policy_class, config
)

colocated = create_colocated_actors(
actor_specs=[
(self.replay_actor_class, self.replay_actor_args, {}, 1),
@@ -218,7 +227,7 @@ def _add_replay_buffer_and_policy(
num_gpus=self.num_gpus_per_policy
if not config["_fake_gpus"]
else 0,
)(policy_spec.policy_class),
)(actual_policy_class),
# Policy c'tor args.
(policy_spec.observation_space, policy_spec.action_space, config),
# Policy c'tor kwargs={}.
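For context, both hunks above make the same change: before the colocated policy actors are created, the class stored in the PolicySpec is first resolved to its eager (and, if configured, traced) variant via get_tf_eager_cls_if_necessary(). A minimal sketch of that resolution step, with the surrounding actor setup omitted (resolve_policy_class is a hypothetical helper name; policy_spec and cfg follow the diff):

from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary

def resolve_policy_class(policy_spec, cfg):
    # Hypothetical helper, shown only to isolate the new step in the diff.
    # For framework=tf2/tfe this returns the eagerized TFPolicy class
    # (additionally traced when cfg["eager_tracing"] is True); for
    # static-graph tf or non-TF policies the class is returned unchanged.
    return get_tf_eager_cls_if_necessary(policy_spec.policy_class, cfg)

# The resolved class, not policy_spec.policy_class, is then used as the
# actor class when building the colocated actors via create_colocated_actors().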
23 changes: 19 additions & 4 deletions rllib/policy/eager_tf_policy.py
@@ -129,6 +129,12 @@ def _func(self_, *args, **kwargs):
return _func


class EagerTFPolicy(Policy):
"""Dummy class to recognize any eagerized TFPolicy by its inheritance."""

pass


def traced_eager_policy(eager_policy_cls):
"""Wrapper class that enables tracing for all eager policy methods.

@@ -237,6 +243,11 @@ def apply_gradients(self, grads: ModelGradients) -> None:
# `apply_gradients()` (which will call the traced helper).
return super(TracedEagerPolicy, self).apply_gradients(grads)

@classmethod
def with_tracing(cls):
# Already traced -> Return same class.
return cls

TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced"
TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + "_traced"
return TracedEagerPolicy
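The new EagerTFPolicy marker base class plus the with_tracing() classmethod added to TracedEagerPolicy make the wrapping idempotent: an already-traced class simply returns itself. A short usage sketch of what this buys (MyEagerPolicy is a hypothetical, already-eagerized policy class, not defined here):

# MyEagerPolicy: hypothetical eagerized policy class (subclass of EagerTFPolicy).
TracedCls = MyEagerPolicy.with_tracing()        # wraps once -> "MyEagerPolicy_traced"
assert TracedCls.with_tracing() is TracedCls    # calling it again is a no-op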
@@ -287,7 +298,7 @@ def build_eager_tf_policy(

This has the same signature as build_tf_policy()."""

base = add_mixins(Policy, mixins)
base = add_mixins(EagerTFPolicy, mixins)

if obs_include_prev_action_reward != DEPRECATED_VALUE:
deprecation_warning(old="obs_include_prev_action_reward", error=False)
@@ -309,7 +320,7 @@ def __init__(self, observation_space, action_space, config):
if not tf1.executing_eagerly():
tf1.enable_eager_execution()
self.framework = config.get("framework", "tfe")
Policy.__init__(self, observation_space, action_space, config)
EagerTFPolicy.__init__(self, observation_space, action_space, config)

# Global timestep should be a tensor.
self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
@@ -594,7 +605,7 @@ def postprocess_trajectory(
):
assert tf.executing_eagerly()
# Call super's postprocess_trajectory first.
sample_batch = Policy.postprocess_trajectory(self, sample_batch)
sample_batch = EagerTFPolicy.postprocess_trajectory(self, sample_batch)
if postprocess_fn:
return postprocess_fn(self, sample_batch, other_agent_batches, episode)
return sample_batch
@@ -848,7 +859,11 @@ def _compute_actions_helper(

return actions, state_out, extra_fetches

def _learn_on_batch_helper(self, samples):
# TODO: Figure out why `_ray_trace_ctx=None` helps prevent a crash in
# AlphaStar with framework=tf2 and eager_tracing=True on the policy learner actors.
# It seems there may be a clash between the traced-by-tf function and the
# traced-by-ray functions (used to make the policy class a Ray actor).
def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
# Increase the tracing counter to make sure we don't re-trace too
# often. If eager_tracing=True, this counter should only get
# incremented during the @tf.function trace operations, never when
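The extra _ray_trace_ctx=None keyword on _learn_on_batch_helper() is a workaround for the crash mentioned in the TODO: when the policy runs as a Ray actor with tracing enabled, an additional _ray_trace_ctx keyword can appear in the call, and the tf.function-compiled helper has to be able to absorb it. A minimal sketch of the pattern, assuming a directly tf.function-wrapped helper (in RLlib the wrapping happens through the eager-tracing machinery, not a bare decorator):

import tensorflow as tf

class SketchPolicy:
    # Illustrative stand-in for the eager policy class; not RLlib's signature.
    @tf.function
    def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
        # The unused keyword absorbs the extra argument that Ray's tracing
        # layer appears to inject, so tf's trace signature does not clash.
        del _ray_trace_ctx
        return samples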
7 changes: 5 additions & 2 deletions rllib/policy/policy.py
@@ -737,8 +737,11 @@ def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None:
"""
# Store the current global time step (sum over all policies' sample
# steps).
# Make sure, we keep global_timestep as a Tensor.
if self.framework in ["tf2", "tfe"]:
# Make sure we keep global_timestep as a Tensor for tf-eager
# policies (not doing so leads to memory leaks).
from ray.rllib.policy.eager_tf_policy import EagerTFPolicy

if self.framework in ["tf2", "tfe"] and isinstance(self, EagerTFPolicy):
self.global_timestep.assign(global_vars["timestep"])
else:
self.global_timestep = global_vars["timestep"]
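For reference, the in-place update the eager branch now relies on, shown in plain TensorFlow (independent of RLlib; eager_tf_policy.py above creates global_timestep exactly this way):

import tensorflow as tf

# Eager policies store the global timestep as a tf.Variable.
global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
global_timestep.assign(1000)            # updates the Variable in place
assert int(global_timestep) == 1000
# Non-eager policies keep a plain int, which has no .assign(); the added
# isinstance(self, EagerTFPolicy) check routes those through the ordinary
# `self.global_timestep = global_vars["timestep"]` assignment instead.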
47 changes: 25 additions & 22 deletions rllib/utils/tf_utils.py
@@ -233,28 +233,31 @@ class for.
"""
cls = orig_cls
framework = config.get("framework", "tf")
if framework in ["tf2", "tf", "tfe"]:
if not tf1:
raise ImportError("Could not import tensorflow!")
if framework in ["tf2", "tfe"]:
assert tf1.executing_eagerly()

from ray.rllib.policy.tf_policy import TFPolicy

# Create eager-class.
if hasattr(orig_cls, "as_eager"):
cls = orig_cls.as_eager()
if config.get("eager_tracing"):
cls = cls.with_tracing()
# Could be some other type of policy or already
# eager-ized.
elif not issubclass(orig_cls, TFPolicy):
pass
else:
raise ValueError(
"This policy does not support eager "
"execution: {}".format(orig_cls)
)

if framework in ["tf2", "tf", "tfe"] and not tf1:
raise ImportError("Could not import tensorflow!")

if framework in ["tf2", "tfe"]:
assert tf1.executing_eagerly()

from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.eager_tf_policy import EagerTFPolicy

# Create eager-class (if not already one).
if hasattr(orig_cls, "as_eager") and not issubclass(orig_cls, EagerTFPolicy):
cls = orig_cls.as_eager()
# Could be some other type of policy or already
# eager-ized.
elif not issubclass(orig_cls, TFPolicy):
pass
else:
raise ValueError(
"This policy does not support eager " "execution: {}".format(orig_cls)
)

# Now that we know, policy is an eager one, add tracing, if necessary.
if config.get("eager_tracing") and issubclass(cls, EagerTFPolicy):
cls = cls.with_tracing()
return cls


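Taken together, the restructured get_tf_eager_cls_if_necessary() can now be called more than once on the same class without double-wrapping: eagerization is skipped for classes that already derive from EagerTFPolicy, and with_tracing() on an already-traced class returns the class unchanged. A hedged usage sketch (MyTFPolicy stands in for any graph-mode policy class built with build_tf_policy(); the name is illustrative and not defined here):

from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary

config = {"framework": "tf2", "eager_tracing": True}

# First call: a TFPolicy subclass gets eagerized via as_eager(), then traced.
cls1 = get_tf_eager_cls_if_necessary(MyTFPolicy, config)

# Second call on the already-wrapped class: per the logic above, both the
# as_eager() and the with_tracing() steps become no-ops.
cls2 = get_tf_eager_cls_if_necessary(cls1, config)
assert cls1 is cls2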