[RLlib] Created action_dist_v2 for RLModule examples, RLModule PR 2/N #29600
Conversation
rllib/models/action_dist_v2.py (outdated):

    @ExperimentalAPI
    class ActionDistributionV2(abc.ABC):
sven1977: We should really take this opportunity and get rid of the restriction to actions. Can we just call this Distribution? That would then cover models that predict next states, rewards, actions, etc., and could also serve inside variational autoencoders.

kouroshHakha: Yeah, I agree. Fixed.
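(A minimal sketch of what the renamed base class could look like, based purely on the method names discussed in this thread; the signatures here are illustrative, not the merged API.)

```python
import abc


class Distribution(abc.ABC):
    """Base class for distributions over model outputs.

    Not restricted to actions: subclasses may model next states, rewards,
    or latents (e.g. inside a variational autoencoder).
    """

    @abc.abstractmethod
    def sample(self, *, sample_shape=None, return_logp: bool = False):
        """Draws a (non-differentiable) sample from the distribution."""

    @abc.abstractmethod
    def rsample(self, *, sample_shape=None, return_logp: bool = False):
        """Draws a re-parameterized, differentiable sample."""

    @abc.abstractmethod
    def logp(self, value):
        """Returns the log-likelihood of `value` under this distribution."""
```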
rllib/models/action_dist_v2.py (outdated):

        options.

        Args:
            action_space: The action space this distribution will be used for,
sven1977: action_space -> space

kouroshHakha: Fixed.
rllib/models/action_dist_v2.py (outdated):

    def required_model_output_shape(
        action_space: gym.Space, model_config: ModelConfigDict
    ) -> Tuple[int, ...]:
        """Returns the required shape of an input parameter tensor for a
sven1977: Can we give two examples here of how this method will be used (for Categorical and DiagGaussian)?
- Does the shape include the batch dim (time dim, etc.)?
- What are typical cases where the model config would be used, and how would it be used?

kouroshHakha: I am not sure this will actually get used, to be honest. I have it here just as a reminder to the next developer that this was part of the old distribution, which we may or may not keep depending on how the catalog gets written. Added a comment to reflect this.
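(A hedged sketch of the two requested examples, assuming the same convention as the old ActionDistribution: the returned shape excludes batch/time dimensions, and model_config goes unused in these simple cases. The function names are illustrative only.)

```python
import gym
import numpy as np


def categorical_required_shape(action_space: gym.Space) -> tuple:
    # One logit per discrete action: Discrete(4) -> (4,).
    assert isinstance(action_space, gym.spaces.Discrete)
    return (action_space.n,)


def diag_gaussian_required_shape(action_space: gym.Space) -> tuple:
    # Mean and log-std per action dimension: Box(shape=(3,)) -> (6,).
    assert isinstance(action_space, gym.spaces.Box)
    return (2 * int(np.prod(action_space.shape)),)
```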
rllib/models/action_dist_v2.py (outdated):

        return_logp: bool = False,
        **kwargs
    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
        """Draw a re-parameterized sample from the action distribution.
sven1977: Can we quickly explain the difference to sample? Basically that rsample is backprop'able and sample is not (if I understand correctly)?

kouroshHakha: Yep, added.
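(The distinction, illustrated with a plain torch.distributions.Normal rather than this PR's classes:)

```python
import torch

mu = torch.zeros(3, requires_grad=True)
dist = torch.distributions.Normal(loc=mu, scale=torch.ones(3))

s = dist.sample()    # detached draw: gradients do NOT flow back to mu
r = dist.rsample()   # reparameterized draw (mu + sigma * eps): gradients DO flow
r.sum().backward()   # fine; s.sum().backward() would raise instead
print(mu.grad)       # populated via the reparameterization trick
```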
rllib/models/action_dist_v2.py (outdated):

        self,
        *,
        sample_shape: Tuple[int, ...] = None,
        return_logp: bool = False,
sven1977: Great choice to add this option to the signature!
rllib/models/action_dist_v2.py (outdated):

        """The policy action distribution of an agent.

        Args:
            inputs: input vector to define the distribution over.
sven1977: Is this always a vector? Or could it also be a dict of tensors (multi-distribution) or a >1D tensor (multi-variate diag Gaussian)?

kouroshHakha: Oh, this is a residue of the old stuff; it doesn't mean anything anymore. Removed. :)
rllib/models/action_dist_v2.py (outdated):

        """Draw a sample from the action distribution.

        Args:
            sample_shape: The shape of the sample to draw.
sven1977: Can we be more specific here? Doesn't the input (currently self.inputs) already determine the batch/time dimensions, with the rest being "fixed"? If that is not the case, can we add examples to this docstring?

kouroshHakha: This should be addressed in the docstrings of each individual distribution.
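(For concreteness, this is how plain torch.distributions handles it: batch dimensions come from the parameters, i.e. the inputs, and sample_shape only prepends extra independent draws. Presumably the per-distribution docstrings would document the same convention.)

```python
import torch

# A batch of 4 univariate Gaussians: batch dims come from the parameters.
dist = torch.distributions.Normal(loc=torch.zeros(4), scale=torch.ones(4))

print(dist.sample().shape)                  # torch.Size([4])
print(dist.sample(torch.Size([10])).shape)  # torch.Size([10, 4]): 10 extra draws
```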
sven1977: Awesome PR @kouroshHakha! Just a few questions and nits on the docstrings, and one renaming request. Exciting to see all these things being done over! :)
    def rsample(
        self, *, sample_shape=torch.Size(), return_logp: bool = False
    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
        sample = self.dist.rsample(sample_shape)
sven1977: Nit: rename the local var sample to rsample for clarity?
        self.dist = self._get_distribution(*args, **kwargs)

    @abc.abstractmethod
    def _get_distribution(self, *args, **kwargs) -> torch.distributions.Distribution:
sven1977: Hmm, the more I think about this: we should rename this to _get_torch_distribution for clarity (vs. _get_tf_distribution), or even _get_underlying_torch_distribution, because self is already a TorchDistribution, albeit an RLlib one. :)

kouroshHakha: I agree. Fixed.
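(A minimal sketch of the wrapper pattern under discussion, combining the snippets above with the agreed rename and the local-variable nit; the merged code may differ in details.)

```python
import abc

import torch


class TorchDistribution(abc.ABC):
    """RLlib-side wrapper around a torch.distributions.Distribution."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Subclasses build the underlying torch distribution.
        self.dist = self._get_torch_distribution(*args, **kwargs)

    @abc.abstractmethod
    def _get_torch_distribution(
        self, *args, **kwargs
    ) -> torch.distributions.Distribution:
        """Returns the wrapped torch.distributions object."""

    def sample(self, *, sample_shape=torch.Size(), return_logp: bool = False):
        sample = self.dist.sample(sample_shape)
        if return_logp:
            return sample, self.dist.log_prob(sample)
        return sample

    def rsample(self, *, sample_shape=torch.Size(), return_logp: bool = False):
        # Local var named `rsample`, per the review nit above.
        rsample = self.dist.rsample(sample_shape)
        if return_logp:
            return rsample, self.dist.log_prob(rsample)
        return rsample
```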
        logits: torch.Tensor = None,
        temperature: float = 1.0,
    ) -> None:
        super().__init__(probs=probs, logits=logits, temperature=temperature)
sven1977: Docstring?

kouroshHakha: Added.
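(Building on the wrapper sketched above, the categorical subclass could look roughly like this; the temperature handling shown here, dividing the logits, is a common convention and an assumption, not necessarily the merged behavior.)

```python
import torch


class TorchCategorical(TorchDistribution):  # wrapper class sketched above
    """Categorical distribution over a fixed set of discrete outcomes.

    Exactly one of `probs` or `logits` should be provided; `temperature`
    flattens (>1.0) or sharpens (<1.0) the distribution.
    """

    def _get_torch_distribution(
        self,
        probs: torch.Tensor = None,
        logits: torch.Tensor = None,
        temperature: float = 1.0,
    ) -> torch.distributions.Distribution:
        if logits is not None:
            logits = logits / temperature  # assumed temperature convention
        return torch.distributions.Categorical(probs=probs, logits=logits)
```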
        loc: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
    ):
        super().__init__(loc=loc, scale=scale)
sven1977: Docstring?

kouroshHakha: Added.
""" | ||
|
||
def __init__(self, loc: torch.Tensor) -> None: | ||
super().__init__() |
sven1977: Docstring?

kouroshHakha: Added.
    @DeveloperAPI
    class TorchDeterministic(ActionDistributionV2):
        """Action distribution that returns the input values directly.
sven1977: "Action distribution" -> "Distribution"
kouroshHakha: @sven1977 Please re-review.

sven1977: LGTM!
gjoliver: I like this new base class! I am just a little worried that we are not adding a TF version of the implementation yet. It should be pretty easy, actually? But this is good regardless.

kouroshHakha: The goal is to get to the torch PPO POC on RLModules first, then come back and modify for TF.
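(To gjoliver's point, a hypothetical TF twin via tensorflow_probability would indeed look nearly identical. Everything below, including the class name, is an assumption for illustration, not code from this PR.)

```python
import tensorflow_probability as tfp


class TfCategorical:
    """Hypothetical TF twin of TorchCategorical, via tensorflow_probability."""

    def __init__(self, probs=None, logits=None):
        self.dist = tfp.distributions.Categorical(logits=logits, probs=probs)

    def sample(self, *, sample_shape=(), return_logp: bool = False):
        sample = self.dist.sample(sample_shape)
        if return_logp:
            return sample, self.dist.log_prob(sample)
        return sample
```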
kouroshHakha: @gjoliver This is ready for merge. The failing tests are not relevant. Thanks.
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Why are these changes needed?
Submitting this PR in pieces.
These action distributions have a simpler interface and are more explicit. Previously, users had to pass in a magical inputs argument plus a ModelV2 instance (the latter only for auto-regressive distributions); now the interface is explicit and more familiar, since it looks like PyTorch distributions. An auto-regressive distribution can still be created by subclassing this base class. A rough usage sketch follows.
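(To illustrate the "explicit, torch-like" point; class and method names are as sketched in the review threads above, not necessarily the merged API.)

```python
import torch

# New style: construct the distribution from explicit parameters,
# exactly like torch.distributions - no magical `inputs` + ModelV2 pair.
logits = torch.randn(32, 4)  # e.g. a batch of 32 policy-head outputs
dist = TorchCategorical(logits=logits)  # class as sketched above

action, logp = dist.sample(return_logp=True)
print(action.shape, logp.shape)  # torch.Size([32]) torch.Size([32])
```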
Related issue number

Checks

- I've signed off every commit (by using git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.