
[RLlib] Fix AlphaStar for tf2+tracing; smaller cleanups around avoiding to wrap a TFPolicy as_eager() or with_tracing more than once. #24271

Merged (7 commits) on Apr 28, 2022
9 changes: 7 additions & 2 deletions rllib/policy/eager_tf_policy.py

```diff
@@ -250,6 +250,11 @@ def compute_gradients(self, loss, var_list):
         return list(zip(self.tape.gradient(loss, var_list), var_list))
 
 
+class EagerTFPolicy(Policy):
+    """Dummy class to recognize any eagerized TFPolicy by its inheritance."""
+    pass
+
+
 def build_eager_tf_policy(
     name,
     loss_fn,
@@ -309,7 +314,7 @@ def __init__(self, observation_space, action_space, config):
             if not tf1.executing_eagerly():
                 tf1.enable_eager_execution()
             self.framework = config.get("framework", "tfe")
-            Policy.__init__(self, observation_space, action_space, config)
+            EagerTFPolicy.__init__(self, observation_space, action_space, config)
 
             # Global timestep should be a tensor.
             self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
@@ -594,7 +599,7 @@ def postprocess_trajectory(
         ):
             assert tf.executing_eagerly()
             # Call super's postprocess_trajectory first.
-            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
+            sample_batch = EagerTFPolicy.postprocess_trajectory(self, sample_batch)
             if postprocess_fn:
                 return postprocess_fn(self, sample_batch, other_agent_batches, episode)
             return sample_batch
```
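To illustrate the point of the new marker base class, here is a minimal, self-contained sketch (stand-in classes and a hypothetical `eagerize` helper, not RLlib's real implementations) of detecting an already-eagerized policy purely by inheritance, so a class never gets wrapped with `as_eager()` twice:

```python
class Policy:
    """Stand-in for RLlib's Policy base class."""


class EagerTFPolicy(Policy):
    """Marker base: any eagerized TFPolicy inherits from this."""


class TFPolicy(Policy):
    @classmethod
    def as_eager(cls):
        # Build an eager variant that carries the marker base.
        return type("Eager" + cls.__name__, (EagerTFPolicy,), {})


def eagerize(orig_cls):
    # Only wrap classes that are not already eagerized: the
    # issubclass() check is what the marker base enables.
    if hasattr(orig_cls, "as_eager") and not issubclass(orig_cls, EagerTFPolicy):
        return orig_cls.as_eager()
    return orig_cls


once = eagerize(TFPolicy)   # gets wrapped
twice = eagerize(once)      # unchanged: already an EagerTFPolicy subclass
assert once is twice
```

Without the marker base, the wrapper would need a side flag on every generated class to know it had already run; the inheritance check keeps that state in the type itself.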
6 changes: 4 additions & 2 deletions rllib/utils/tf_utils.py

```diff
@@ -240,9 +240,11 @@ class for.
     assert tf1.executing_eagerly()
 
     from ray.rllib.policy.tf_policy import TFPolicy
+    from ray.rllib.policy.eager_tf_policy import EagerTFPolicy
 
-    # Create eager-class.
-    if hasattr(orig_cls, "as_eager"):
+    # Create eager-class (if not already one).
+    if hasattr(orig_cls, "as_eager") and \
+            not issubclass(orig_cls, EagerTFPolicy):
        cls = orig_cls.as_eager()
        if config.get("eager_tracing"):
            cls = cls.with_tracing()
```

Review comment (Member), on the line `if config.get("eager_tracing"):`:

Per our offline message, I think the real fix should actually be to pull this

```python
if config.get("eager_tracing"):
    cls = cls.with_tracing()
```

block out of the `as_eager()` block, so that we try to enable tracing regardless of whether we call `as_eager()` here or somewhere else. Just FYI.

Reply (Contributor Author):

You are right. Thanks for the catch.
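The reviewer's suggested follow-up, hoisting the tracing step out of the `as_eager()` branch, could look roughly like this. This is a simplified sketch with stand-in classes; the `_traced` guard is a hypothetical illustration of once-only wrapping, not RLlib's actual mechanism:

```python
class EagerTFPolicy:
    """Marker for already-eagerized policies (stand-in)."""


def get_eager_class(orig_cls, config):
    """Eagerize if needed, then apply tracing regardless of where
    eagerization happened, per the review suggestion above."""
    cls = orig_cls
    # Step 1: eagerize only if the class is not already eagerized.
    if hasattr(cls, "as_eager") and not issubclass(cls, EagerTFPolicy):
        cls = cls.as_eager()
    # Step 2: tracing now sits outside the as_eager() branch, guarded by
    # a hypothetical `_traced` flag so it is applied at most once.
    if config.get("eager_tracing") and not getattr(cls, "_traced", False):
        cls = cls.with_tracing()
    return cls


class DemoEagerPolicy(EagerTFPolicy):
    """Already-eager stand-in that supports tracing."""
    _traced = False

    @classmethod
    def with_tracing(cls):
        return type("Traced" + cls.__name__, (cls,), {"_traced": True})


traced = get_eager_class(DemoEagerPolicy, {"eager_tracing": True})
retraced = get_eager_class(traced, {"eager_tracing": True})
assert traced._traced and retraced is traced
```

With this shape, a policy that was eagerized elsewhere (and so skips step 1) still gets tracing enabled in step 2, which is the behavior the reviewer asked for.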