Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove inf's in TruncatedNormal log_prob & sample (#1492) #1581

Merged
merged 1 commit into from
May 5, 2023

Conversation

nikmich1
Copy link
Contributor

This commit fixes #1492 by removing inf return values in log_prob and sample method of the TruncatedNormal distribution, when it is truncated in the tail of the distribution.

Changes made to remove inf's in log_prob method of TruncatedNormal distribution:

  • Added log_cdf method to the Normal distribution
    • Calls the corresponding JAX implementation of the logarithmic cdf
  • Modified log_prob calculation for the truncated Normal
    • New method is more stable in tail of distribution, since it uses the new log_cdf, if it is available
    • Uses logsumexp function to subtract required log_cdf values more numerically stable
    • If log_cdf method is not available, it falls back to original implementation with cdf

Changes made to remove inf's in sample method of TruncatedNormal distribution:

  • Clamped input to inverse cdf icdf to open interval (0,1) using the clamp_probs utility

Tests added:

  • Test implementation of log_cdf method by comparing to JAX implementation and to cdf method (test_normal_log_cdf)
  • Test log_prob implementation by comparing to JAX (test_truncated_normal_log_prob_in_tail)
  • Test sample method in tail of distribution to check, if any inf's are returned (test_sample_truncated_normal_in_tail)

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me, thanks @nikmich1!

@fehiepsi fehiepsi merged commit d63dae4 into pyro-ppl:master May 5, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

inf's with TruncatedNormal
2 participants