Make transforms stateless #4551
After reading the entire diff I'm now quite sure I got the purposes of the `rv_var` and `rv_value` args wrong.
pymc3/distributions/__init__.py
Outdated
    if transform is not None and rv_var is None:
        warnings.warn(
            f"A transform was found for {measure_var}" " but no corresponding random variable"
String is a bit messed up.
More importantly: The sentence is a bit incomplete - no variable corresponding to what?
That `measure_var` doesn't have a random variable associated with it, so there's really nothing else to print or say. If anything, this should probably be an error condition.
It might actually make more sense to associate the transform object with the `rv_var` (i.e. the random variable). I'll have to think about that.
pymc3/distributions/transforms.py
Outdated
    rv_var
        The random variable being transformed
    rv_value
        The parameters required for the transform.
`rv_value` doesn't sound very intuitive for something that holds the transform parameters. (I was confused by this above already.)
How about `rv` and `transform_params`?
Or `rv` and `params`?
We need to be clear about the `rv_var` and `rv_value[_var]` distinctions.
`rv_var`s are the "sample-space" variables that are produced by `RandomVariable` `Op`s.
`rv_value[_var]`s are the "measure-space" (or log-likelihood) variables that correspond to a specific value of an `rv_var`.
These are the same two types of variables described here, where the sloppy `P(X = x)` or `x ~ X` notation denotes the `rv_var` with `X` (i.e. the random variable), and the value variable with `rv_value[_var]`.
These transform methods are getting those two variables, so any new name that involves "params" would be inaccurate, because the `rv_value` variable does not provide parameters. The first argument, `rv_var`, does provide access to a random variable's parameters via `rv_var.owner.inputs`, and, again, `rv_value` is a value that's compatible with the random variable `rv_var` (i.e. a value that could've been a sample from it).
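The distinction can be illustrated with a small, dependency-free sketch. Everything in it is a hypothetical stand-in rather than PyMC3 or Aesara API: `ToyNormal` plays the role of a "sample-space" `rv_var` whose parameters are reachable through an `inputs` property (loosely mirroring `rv_var.owner.inputs`), and `toy_logp` evaluates the log-density at a plain "measure-space" value.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyNormal:
    """Hypothetical stand-in for a "sample-space" variable (an `rv_var`)."""

    mu: float
    sigma: float

    @property
    def inputs(self):
        # Loosely mirrors `rv_var.owner.inputs`: the parameters live on the
        # random variable, not on the value.
        return (self.mu, self.sigma)


def toy_logp(rv_var: ToyNormal, rv_value: float) -> float:
    """Normal log-density of `rv_var` evaluated at the value `rv_value`."""
    mu, sigma = rv_var.inputs
    z = (rv_value - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


X = ToyNormal(mu=0.0, sigma=1.0)  # the random variable, "X" in P(X = x)
x = 0.5  # a value that could have been sampled from X, "x" in P(X = x)
logp_at_x = toy_logp(X, x)
```

Note that the value `x` carries no parameters of its own; anything the log-density needs is read from the random variable.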
So `rv_var` is the tensor of the user-provided, observed values? (A `TensorConstant`?)
We might still want to copy parts of your explanation into the docstring.
b97dc37 to 0762608
0762608 to 43bd711
There are a few threads still open.
Nevertheless I'll say LGTM, but don't count too much on my judgement. Most of my trust comes from the facts that Brandon did this and that the CI Tests are now ✔.
    with pytest.warns(
        DeprecationWarning, match="The argument `eps` is deprecated and will not be used."
    ):
        tr.StickBreaking(eps=1e-9)
(Where) do we keep a list of these changes? We should mention them in the release notes. The alternative is to `raise` the `DeprecationWarning`, which saves users from complicated digging.
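For reference, a minimal sketch of the warning pattern being discussed, using only the standard library. `ToyStickBreaking` is a hypothetical class written for illustration, not the real `tr.StickBreaking`; the message text is the one matched in the test above.

```python
import warnings


class ToyStickBreaking:
    """Hypothetical transform class; not the real `tr.StickBreaking`."""

    def __init__(self, eps=None):
        if eps is not None:
            # The argument is still accepted for backwards compatibility,
            # but it is ignored and the caller is told so.
            warnings.warn(
                "The argument `eps` is deprecated and will not be used.",
                DeprecationWarning,
                stacklevel=2,
            )


# Record warnings instead of printing them, so they can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ToyStickBreaking(eps=1e-9)

messages = [
    str(w.message) for w in caught if issubclass(w.category, DeprecationWarning)
]
```

Users who prefer a hard failure can opt in themselves with `warnings.simplefilter("error", DeprecationWarning)`, which turns the warning into an exception without changing the library code.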
Sorry, been pretty busy, but I have another commit to push, and it's a big refactor that should address most/all of the open
This makes `aesara.graph.basic.clone_replace` work correctly when `Scan`s are included in a graph.
6d8c136 to 05d4e19
@@ -161,80 +157,119 @@ def rv_log_likelihood_args(
    variable).

    """
    if not var.owner:
        return None, None
Doesn't match with the return type hints and docstring.
Can you explain (maybe in the docstring) why and under what circumstances `None, None` is returned?
    rv_value = rv_var.type.filter_variable(rv_value.astype(rv_var.dtype))

    if rv_value_var is None:
        rv_value_var = rv_value
That's the case when `rv_value` has no observations, right?
    mean = alpha / (alpha + beta)
    variance = (alpha * beta) / ((alpha + beta) ** 2 * (alpha + beta + 1))
    # mean = alpha / (alpha + beta)
    # variance = (alpha * beta) / ((alpha + beta) ** 2 * (alpha + beta + 1))
can be removed?
    #
    # @logp_transform.register(rv_type)
    # def transform(op, *args, **kwargs):
    #     return class_transform(*args, **kwargs)
TODO
    super().__init__(shape, dtype, defaults=defaults, *args, **kwargs)
    if kwargs.get("transform", None):
        raise ValueError("Transformations for discrete distributions")
Shouldn't we keep the `dtype` checks? (Based on `intX`.)
Those are done at the Aesara `Op`-level now (i.e. within `RandomVariable.make_node`); although I'm not sure if float-to-int conversion is part of that. It might only raise an exception for the wrong dtype. If it's not, then we might need to add that at this level.
The failing test looks like the non-deterministic
05d4e19 to 00dcfad
00dcfad to 049c5f8
This PR addresses a few more transform changes/issues.
The primary change is that transforms are now stateless (i.e. they no longer carry their own parameters). Stateful transforms make it very easy to accidentally introduce old and/or irrelevant parameters into a graph, and they are a source of some extremely confusing and difficult bugs, which is why this change was made.
Now, transforms only take a "parameter extraction function" that, when applied to a random variable, will extract the required transform parameters.
In other words, transform objects are no longer random variable instance-specific, but random variable class-specific.
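A minimal, dependency-free sketch of the idea, under stated assumptions: `ToyUniform`, `ToyIntervalTransform`, and `args_fn` are hypothetical names chosen for illustration and are not the classes or signatures in this PR. The transform stores only a parameter extraction function, and its methods receive the random variable and the value on every call.

```python
import math
from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass(frozen=True)
class ToyUniform:
    """Hypothetical random variable that carries its own parameters."""

    lower: float
    upper: float


class ToyIntervalTransform:
    """Stateless: holds a parameter extraction function, never the parameters."""

    def __init__(self, args_fn: Callable[[ToyUniform], Tuple[float, float]]):
        self.args_fn = args_fn

    def forward(self, rv_var: ToyUniform, rv_value: float) -> float:
        # Map a value in (lower, upper) onto the real line.
        a, b = self.args_fn(rv_var)
        return math.log(rv_value - a) - math.log(b - rv_value)

    def backward(self, rv_var: ToyUniform, rv_value: float) -> float:
        # Map a real value back into (lower, upper).
        a, b = self.args_fn(rv_var)
        return a + (b - a) / (1.0 + math.exp(-rv_value))


# A single transform object serves every ToyUniform instance, because the
# bounds are looked up from whichever random variable is passed in.
interval = ToyIntervalTransform(lambda rv: (rv.lower, rv.upper))

u1 = ToyUniform(0.0, 1.0)
u2 = ToyUniform(-5.0, 5.0)
y1 = interval.forward(u1, 0.25)
y2 = interval.forward(u2, 0.25)
```

Because the bounds are re-read on each call, there is no stored copy that can go stale when a random variable's parameters are replaced in a graph, which is the class of bug the description attributes to stateful transforms.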