-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add probability distribution transformation APIs #40536
Add probability distribution transformation APIs #40536
Conversation
✅ This PR's description meets the template requirements! |
Thanks for your contribution! |
e7850b7
to
141c642
Compare
141c642
to
ae14d10
Compare
ae14d10
to
540de48
Compare
540de48
to
b14dae1
Compare
f37b54c
to
9fcd9e8
Compare
9fcd9e8
to
983c1fb
Compare
983c1fb
to
b96f2e1
Compare
b96f2e1
to
3421235
Compare
3421235
to
6985966
Compare
""" | ||
BIJECTION = 'bijection' # bijective(injective and surjective) | ||
INJECTION = 'injection' # injective-only | ||
SURJECTION = 'surjection' # surjective-inly |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
TODO:Add Chinese documentation
print(affine.forward(x)) | ||
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, | ||
# [1., 2.]) | ||
print(affine.inverse(power.forward(x))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"power" means paddle.distribution.PowerTransform?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已更新,粘贴错误,粘贴了一个旧的代码示例
Examples in Describe above: print(affine.inverse(power.forward(x))), "power" means paddle.distribution.PowerTransform? or should be affine? |
是 ''affine'',粘贴错误,文档和PR描述均已更新 |
@@ -96,7 +96,7 @@ def prob(self, value): | |||
Args: | |||
value (Tensor): value which will be evaluated | |||
""" | |||
raise NotImplementedError | |||
return self.log_prob(value).exp() | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we add some abstract methods in base class, e.g. mean, variance and rsample
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已添加,这样能保持不同子类持有的方法一致;存量 normal, uniform, categorical目前是NotImplementedError
,按照设计文档计划后续统一更新
6985966
to
176b413
Compare
176b413
to
4af8502
Compare
4af8502
to
723a279
Compare
723a279
to
c8ef422
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
from paddle.distribution.kl import kl_divergence, register_kl | ||
from paddle.distribution.multinomial import Multinomial | ||
from paddle.distribution.normal import Normal | ||
from paddle.distribution.transform import * # noqa: F403 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里需要import *的原因是什么呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
distribution/transform.py
文件中定义了需要公开API的__ALL__列表,在distribution/__init__
中用import *
全部导出,并添加到__init__.py
的__all__列表中, 可以通过paddle.distribution.xxx
访问,访问路径和竞品保持一致
PR types
New features
PR changes
APIs
Describe
Adds 13 transformation APIs and 2 distribution APIs :
new transformation APIs:
new distribution APIs:
Examples: