Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inference utilities to transform between unconstrained and constrained space #1564

Merged
merged 1 commit into from
Mar 30, 2023

Conversation

aymgal
Copy link
Contributor

@aymgal aymgal commented Mar 27, 2023

This is a proposal to fix #1554

names.
:return: `dict` of transformation keyed by site names.
"""
transforms = get_transforms(model, model_args, model_kwargs, params)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here params is in unconstrained space, so we can't substitute natively. If you want to deal with params, please feel free to adjust constrain_fn for it. Maybe also rename unconstrain_values to unconstrain_fn?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed you're right. Let me know if your points have been addressed by the latest commit :)

@aymgal aymgal requested a review from fehiepsi March 27, 2023 21:47
@@ -157,7 +157,8 @@ def transform_fn(transforms, params, invert=False):
return {k: transforms[k](v) if k in transforms else v for k, v in params.items()}


def constrain_fn(model, model_args, model_kwargs, params, return_deterministic=False):
def constrain_fn(model, model_args, model_kwargs, params,
include_param_sites=False, return_deterministic=False):
Copy link
Member

@fehiepsi fehiepsi Mar 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can remove those include_param_sites flags and keep the True behavior at all functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was for ensuring backward compatibility. I removed these in the last comit

Copy link
Member

@fehiepsi fehiepsi Mar 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I guess it won't affect the old behavior. If users provide param sites, your solution will return expected results, while the current master branch will raise an error or skip them - so this is an improvement.

substituted_model = substitute(model, data=params)
transforms, _, _, _ = _get_model_transforms(substituted_model, model_args, model_kwargs)
return transforms

Copy link
Member

@fehiepsi fehiepsi Mar 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you run make format to fix lint issue? I guess you need to add a new line here

…ained space

Improve and simplify constrain_fn and unconstrain_fn implementation

Add missing doctstrings

Constrain/unconstrain functions now always consider param sites

Fix syntax for lint tests

Fix syntax for lint tests

Fix syntax for lint tests
@aymgal
Copy link
Contributor Author

aymgal commented Mar 30, 2023

@fehiepsi fyi I squashed the commits of this PR as I had to do lots of syntax fixing recently... should be ok now

@fehiepsi
Copy link
Member

Thanks @aymgal! It's great to have this utility available.

@fehiepsi fehiepsi merged commit abe456c into pyro-ppl:master Mar 30, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Inverse bjiector transformation (from constrained to unconstrained space)
2 participants