-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Asp movement #36525
Asp movement #36525
Conversation
…modified examples in doc.
@@ -66,7 +66,7 @@ def decorate(optimizer): | |||
|
|||
import paddle | |||
import paddle.fluid as fluid | |||
from paddle.fluid.contrib import sparsity |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要对用户暴露的接口的示例代码中,不要再import fluid
,也不要用任何fluid
的API。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的 已經全面修正
@@ -158,6 +154,8 @@ def prune_model(place, | |||
optimizer = sparsity.decorate(optimizer) | |||
optimizer.minimize(loss, startup_program) | |||
|
|||
place = paddle.CPUPlace() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么要设置CPUPlace
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
修改成使用paddle.get_device來判斷Place的裝置
from ...fluid.contrib.sparsity import create_mask #noqa: F401 | ||
from ...fluid.contrib.sparsity import check_sparsity #noqa: F401 | ||
from ...fluid.contrib.sparsity import MaskAlgo #noqa: F401 | ||
from ...fluid.contrib.sparsity import CheckMethod #noqa: F401 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
再确认一下,以上API都需要向用户暴露吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
一定需要暴露給用戶的API僅為MaskAlgo,目的是要讓用戶能夠決定Pruning的方式 (1D, 2D_greedy, 2D_best)等。
其餘依照建議移除
from ...fluid.contrib.sparsity import decorate #noqa: F401 | ||
from ...fluid.contrib.sparsity import prune_model #noqa: F401 | ||
from ...fluid.contrib.sparsity import set_excluded_layers #noqa: F401 | ||
from ...fluid.contrib.sparsity import reset_excluded_layers #noqa: F401 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
将需要对外公开的API 放到__all__列表,paddle会根据__all__列表区别公开API和内部使用API,
公开API需要提供使用文档和示例
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# limitations under the License. | ||
|
||
from ...fluid.contrib.sparsity import calculate_density #noqa: F401 | ||
from ...fluid.contrib.sparsity import MaskAlgo #noqa: F401 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MaskAlgo这个类型可以作为内部使用的数据类型使用,不建议对外公开。
比如这种写法暴露了较多内部实现的细节,书写和阅读都不是很方便
prune_model(... func_name=sparsity.MaskAlgo.MASK_1D)
建议改成这种形式更简洁,减少对外公开的数据类型,采用小写方便输入
prune_model(... mask_algo='mask_1d')
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
採用建議使用小寫string作為輸入選項,並隱藏 MaskAlgo
@@ -114,7 +112,7 @@ def prune_model(place, | |||
inference only. To obtain OptimizerWithSparsityGuarantee, please see `sparsity.decoreate()`. | |||
|
|||
Args: | |||
place (fluid.CPUPlace()|fluid.CUDAPlace(N)): Device place for pruned parameter and mask Variables, and N means the GPU's id. It should be the same as created instance of Executor. | |||
place (paddle.CPUPlace()|paddle.CUDAPlace(N)): Device place for pruned parameter and mask Variables, and N means the GPU's id. It should be the same as created instance of Executor. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
确认下,place这个参数是否是必须的?
是否可以用这两组参数来修改全局的place信息?
paddle.device.get_device()
paddle.device.set_device()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以的,依照建議移除place argument 採用paddle.device.get_device(), paddle.device.set_device()來自動判斷
1. Do not expose MaskAlgo, instead using string as data type of func_name. 2. Removed 'place' arg in prune_model() 3. Added __all__ in paddle.static.sparsity
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
n=2, | ||
m=4, | ||
func_name=sparsity.MaskAlgo.MASK_1D, | ||
func_name='mask_1d', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里api的参数名称func_name是否指mask algorithm?建议再改一下,更清晰一些
func_name -> mask_algo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
依照建議修改完成。
main_program (Program, optional): Program with model definition and its parameters. Default is `paddle.static.default_main_program() | ||
n (int): n of `n:m` sparse pattern. | ||
m (int): m of `n:m` sparse pattern. | ||
func_name (MaskAlgo, optional): The function name to generate spase mask. Default is `MaskAlgo.MASK_1D`. All options please refer to `MaskAlgo`. | ||
func_name (string, optional): The function name to generate spase mask. Default is `mask_1d`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
APIs
Describe
Added Automatic SParsity (ASP)'s alias into
paddel.static
.Since the folder organization changes in v2.2, the
paddle.fluid.xxx
is not the official way to use modules. We move ASP module to paddle.static to make its usage be consistent with official ways, also modified code examples and related unittests.Example: