[RLlib] Created action_dist_v2 for RLModule examples, RLModule PR 2/N #29600

Merged
merged 10 commits into from
Oct 24, 2022

Conversation

kouroshHakha
Contributor

@kouroshHakha kouroshHakha commented Oct 24, 2022

Signed-off-by: Kourosh Hakhamaneshi [email protected]

Why are these changes needed?

Submitting this PR in pieces.
These action distributions have a simpler and more explicit interface. Previously, users had to pass in a "magical" input and a ModelV2 instance (needed only for auto-regressive distributions). Now the interface is explicit and more familiar, since it looks like PyTorch distributions. An auto-regressive distribution can still be created by subclassing this base class.

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@kouroshHakha kouroshHakha changed the title from "[RLlib] Created action_dist_v2 for RLModule examples" to "[RLlib] Created action_dist_v2 for RLModule examples, RLModule PR 2/N" on Oct 24, 2022


@ExperimentalAPI
class ActionDistributionV2(abc.ABC):
Contributor

We should really take this opportunity and get rid of the restriction to actions. Can we just call this Distribution?
This would then cover models that predict next states, rewards, actions, etc., as well as distributions used inside variational autoencoders.

Contributor Author

Yeah, I agree. fixed.

options.

Args:
    action_space: The action space this distribution will be used for,
Contributor

action_space -> space

Contributor Author

fixed.

def required_model_output_shape(
    action_space: gym.Space, model_config: ModelConfigDict
) -> Tuple[int, ...]:
    """Returns the required shape of an input parameter tensor for a
Contributor

Can we give 2 examples here of how this method will be used (for Categorical and DiagGaussian)?

  • Does the shape include the batch dim (time dim, etc.)?
  • What are typical cases, where the model config would be used and how would it be used?

Contributor Author

I am not sure if this will actually get used, tbh. I have it here just as a reminder to the next developer that this was part of the old distribution API, which we may or may not keep depending on how the catalog gets written. Added a comment to reflect this.
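For reference, a minimal sketch of what the two requested examples might look like. It assumes the returned shape excludes the batch and time dimensions and that `model_config` is not needed for these two cases. `DiscreteSpace` and `BoxSpace` are hypothetical stand-ins for the corresponding gym spaces, and both helper functions are illustrative names, not part of this PR.

```python
from dataclasses import dataclass
from math import prod
from typing import Tuple


@dataclass
class DiscreteSpace:
    """Hypothetical stand-in for gym.spaces.Discrete."""
    n: int


@dataclass
class BoxSpace:
    """Hypothetical stand-in for gym.spaces.Box (only the shape matters here)."""
    shape: Tuple[int, ...]


def categorical_output_shape(space: DiscreteSpace) -> Tuple[int, ...]:
    # A categorical distribution needs one logit per discrete action.
    return (space.n,)


def diag_gaussian_output_shape(space: BoxSpace) -> Tuple[int, ...]:
    # A diagonal Gaussian needs a mean and a (log-)std per action dimension.
    return (2 * prod(space.shape),)


print(categorical_output_shape(DiscreteSpace(n=4)))      # (4,)
print(diag_gaussian_output_shape(BoxSpace(shape=(3,))))  # (6,)
```

Under these assumptions, a model feeding a Categorical over 4 actions would emit tensors of shape `[batch, 4]`, and one feeding a 3-dimensional DiagGaussian would emit `[batch, 6]`.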

    return_logp: bool = False,
    **kwargs
) -> Union[TensorType, Tuple[TensorType, TensorType]]:
    """Draw a re-parameterized sample from the action distribution.
Contributor

Can we quickly explain how this differs from sample? Basically, that rsample is backprop'able and sample is not (iiuc)?

Contributor Author

yep added.
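The distinction raised in this thread can be illustrated directly with `torch.distributions`, which the new classes wrap. This is a small sketch using plain PyTorch rather than the RLlib classes from this PR; the variable names are illustrative.

```python
import torch
from torch.distributions import Normal

# Learnable parameters of a Gaussian.
loc = torch.tensor([0.5], requires_grad=True)
scale = torch.tensor([1.0], requires_grad=True)
dist = Normal(loc, scale)

# sample() draws without tracking gradients, so the result is
# disconnected from loc and scale.
plain = dist.sample()

# rsample() uses the reparameterization trick (loc + scale * eps),
# so gradients can flow back into loc and scale.
reparam = dist.rsample()

print(plain.requires_grad)    # False
print(reparam.requires_grad)  # True

# d(loc + scale * eps) / d(loc) is exactly 1.
reparam.sum().backward()
print(loc.grad)
```

Not every distribution supports this: in PyTorch, `Categorical` has no `rsample`, which is one reason the two methods are kept separate.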

    self,
    *,
    sample_shape: Tuple[int, ...] = None,
    return_logp: bool = False,
Contributor

@sven1977 sven1977 Oct 24, 2022

Great choice to add this option to the signature!

"""The policy action distribution of an agent.

Args:
    inputs: input vector to define the distribution over.
Contributor

Is this always a vector? Or could this also be a dict of tensors (multi-distribution) or a >1D tensor (multi-variate diag gaussian)?

Contributor Author

Oh this is a residue of the old stuff. Doesn't mean anything anymore. Removed. :)

"""Draw a sample from the action distribution.

Args:
    sample_shape: The shape of the sample to draw.
Contributor

Can we be more specific here? Doesn't the input (currently self.inputs) already determine the batch/time dimensions and the rest is "fixed"? If this is not the case, can we add examples to this docstring?

Contributor Author

This should be addressed in the docstrings of each individual distribution.
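As background for this question, here is how `sample_shape` behaves in plain `torch.distributions`: the parameters fix the batch dimensions, and `sample_shape` only prepends extra leading dimensions. This sketch assumes the RLlib wrappers forward `sample_shape` unchanged to the underlying torch distribution, as the `rsample` snippet in this PR suggests.

```python
import torch
from torch.distributions import Categorical, Normal

batch = 4

# Logits of shape [4, 3]: a batch of 4 categoricals over 3 classes.
cat = Categorical(logits=torch.zeros(batch, 3))
print(cat.sample().shape)      # torch.Size([4])
print(cat.sample((2,)).shape)  # torch.Size([2, 4])

# Parameters of shape [4, 2]: a batch of 4 two-dimensional diagonal Gaussians.
normal = Normal(torch.zeros(batch, 2), torch.ones(batch, 2))
print(normal.sample((5,)).shape)  # torch.Size([5, 4, 2])
```

With the default empty `sample_shape`, one sample is drawn per batch element, which is the common case when sampling actions.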

Contributor

@sven1977 sven1977 left a comment

Awesome PR @kouroshHakha ! Just a few questions and nits on the docstrings and one renaming request.

@sven1977
Contributor

Exciting to see all these things being done-over! :)
@kouroshHakha

def rsample(
    self, *, sample_shape=torch.Size(), return_logp: bool = False
) -> Union[TensorType, Tuple[TensorType, TensorType]]:
    sample = self.dist.rsample(sample_shape)
Contributor

Nit: rename local var sample to rsample for clarity?

    self.dist = self._get_distribution(*args, **kwargs)

@abc.abstractmethod
def _get_distribution(self, *args, **kwargs) -> torch.distributions.Distribution:
Contributor

Hmm, the more I think about this, the more I think we should rename this to _get_torch_distribution for clarity (vs. _get_tf_distribution).
Or even _get_underlying_torch_distribution, because self is already a TorchDistribution, albeit an RLlib one :)

Contributor Author

I agree. fixed.
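The wrapper pattern discussed in this thread can be sketched roughly as follows. This is an illustrative reconstruction, not the merged code: the class names `TorchDistributionSketch` and `CategoricalSketch` are hypothetical, the `temperature` argument seen in the diff is omitted, and only `sample` and `logp` are shown.

```python
import abc

import torch


class TorchDistributionSketch(abc.ABC):
    """Hypothetical wrapper around a torch.distributions.Distribution."""

    def __init__(self, *args, **kwargs):
        # The subclass decides which torch distribution backs this object.
        self.dist = self._get_torch_distribution(*args, **kwargs)

    @abc.abstractmethod
    def _get_torch_distribution(
        self, *args, **kwargs
    ) -> torch.distributions.Distribution:
        """Builds the underlying torch distribution."""

    def logp(self, value: torch.Tensor) -> torch.Tensor:
        return self.dist.log_prob(value)

    def sample(self, *, sample_shape=torch.Size(), return_logp: bool = False):
        sample = self.dist.sample(sample_shape)
        if return_logp:
            # Return the log-probability alongside the sample.
            return sample, self.logp(sample)
        return sample


class CategoricalSketch(TorchDistributionSketch):
    def __init__(self, probs=None, logits=None):
        super().__init__(probs=probs, logits=logits)

    def _get_torch_distribution(self, probs=None, logits=None):
        return torch.distributions.Categorical(probs=probs, logits=logits)


# Explicit, named parameters, as with PyTorch distributions.
dist = CategoricalSketch(logits=torch.zeros(2, 3))
action, logp = dist.sample(return_logp=True)
print(action.shape, logp.shape)
```

Naming the hook after the framework keeps it clear that `self` is the RLlib-level distribution while `self.dist` is the framework-level one.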

    logits: torch.Tensor = None,
    temperature: float = 1.0,
) -> None:
    super().__init__(probs=probs, logits=logits, temperature=temperature)
Contributor

docstring?

Contributor Author

added

    loc: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
):
    super().__init__(loc=loc, scale=scale)
Contributor

docstring?

Contributor Author

added

"""

def __init__(self, loc: torch.Tensor) -> None:
    super().__init__()
Contributor

docstring?

Contributor Author

added


@DeveloperAPI
class TorchDeterministic(ActionDistributionV2):
    """Action distribution that returns the input values directly.
Contributor

"Action distribution" -> "Distribution"
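A rough sketch of what a distribution that "returns the input values directly" could look like. The class name `DeterministicSketch` is hypothetical, and how `sample_shape` is handled (broadcasting `loc` over extra leading dimensions) is an assumption, not something confirmed by the snippets in this thread.

```python
import torch


class DeterministicSketch:
    """Hypothetical distribution that always returns its input values."""

    def __init__(self, loc: torch.Tensor) -> None:
        self.loc = loc

    def sample(self, *, sample_shape=torch.Size()) -> torch.Tensor:
        # Assumption: extra sample dimensions are prepended, and loc is
        # broadcast over them without any randomness.
        shape = tuple(sample_shape) + tuple(self.loc.shape)
        return self.loc.expand(shape)


loc = torch.tensor([1.0, 2.0, 3.0])
dist = DeterministicSketch(loc)
print(dist.sample())                         # tensor([1., 2., 3.])
print(dist.sample(sample_shape=(2,)).shape)  # torch.Size([2, 3])
```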

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
@kouroshHakha
Contributor Author

@sven1977 please re-review.

Contributor

@sven1977 sven1977 left a comment

LGTM!

@sven1977
Contributor

test_distributions_v2 is failing.
cc: @kouroshHakha

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Member

@gjoliver gjoliver left a comment

I like this new base class!
Just a little worried that we are not adding a TF version of the implementation yet.
Should be pretty easy, actually?
But this is good regardless.

@kouroshHakha
Contributor Author

> I like this new base class! just a little worried that we are not adding TF version of the implementation yet. should be pretty easy actually? but this is good regardless.

The goal is to get to the Torch PPO POC on RLModules first, then come back and adapt this for TF.

@kouroshHakha
Contributor Author

@gjoliver This is ready for merge. The failing tests are not relevant. Thanks.

@gjoliver gjoliver merged commit 3562cb4 into ray-project:master Oct 24, 2022
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request Dec 19, 2022
…ray-project#29600)

* created action_dist_v2 for RLModule examples

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>