Improve pruning module #2354
Conversation
@@ -131,6 +134,73 @@ class StaticPruningHook : public IParameterUpdaterHook {
  std::vector<bool> mask_;
};

class DynamicPruningHook : public IParameterUpdaterHook {
Add the DynamicPruningHook, which calculates the mask according to sparsity_ratio.
/**
 * ParameterUpdaterHook actually factory method.
 */
static IParameterUpdaterHook* createImpl(
    const ParameterUpdaterHookConfig& config) {
  auto& type = config.type();
  if (type == "pruning") {
  if (config.has_purning_mask_filename()) {
  if (type == "pruning_static") {
The plain 'pruning' type is the dynamic one, and 'pruning_static' is the static one, which reads the mask from a file.
The difference between the new DynamicPruningHook and the original StaticPruningHook is:
DynamicPruningHook generates a mask according to the sparsity_ratio in the config.
StaticPruningHook reads the mask from a file.
The two operate on Parameter in the same way; only the way the mask is initialized differs. I don't think this implementation is Dynamic Pruning. It is still Static Pruning, and it can replace the original Static Pruning implementation.
Dynamic Pruning is the kind of method in the paper Dynamic Network Surgery for Efficient DNNs: it does not require specifying sparsity_ratio, and it can set and adjust sparsity_ratio automatically until the maximum compression rate is reached.
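For clarity, the ratio-driven mask construction that both hooks share could be sketched roughly like this (a standalone illustrative function, not the PR's actual Parameter/Vector code; the name `generateMaskFromRatio` is hypothetical):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Keep the largest-magnitude (1 - sparsityRatio) fraction of the weights;
// everything else is masked out. sparsityRatio is the fraction of ZEROS.
std::vector<bool> generateMaskFromRatio(const std::vector<float>& weights,
                                        double sparsityRatio) {
  size_t nonZeroNum =
      static_cast<size_t>(weights.size() * (1.0 - sparsityRatio));
  std::vector<std::pair<float, size_t>> param;
  param.reserve(weights.size());
  for (size_t i = 0; i < weights.size(); ++i)
    param.emplace_back(std::fabs(weights[i]), i);
  // Sort by descending magnitude; only the top nonZeroNum entries matter.
  std::partial_sort(param.begin(), param.begin() + nonZeroNum, param.end(),
                    [](const std::pair<float, size_t>& a,
                       const std::pair<float, size_t>& b) {
                      return a.first > b.first;
                    });
  std::vector<bool> mask(weights.size(), false);
  for (size_t i = 0; i < nonZeroNum; ++i) mask[param[i].second] = true;
  return mask;
}
```

A file-loaded mask (the static hook) would skip the sort entirely and just read the 0/1 vector from disk; the Parameter update step is identical in both cases.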
@@ -25,6 +25,9 @@ limitations under the License. */
#include "paddle/utils/Flags.h"
#include "paddle/utils/Util.h"

using std::vector;
using std::pair;
Do not use `using` statements.
@Xreki OK, then only the sparsity_ratio one will be kept.
Keeping only that kind is enough; the original ...
 * define which link/weight between neural is disabled.
 * Static means user specific a sparsity_ratio map before training started. The
 * network will
 * hold the sparsity_ratio maximum numbers of parameters, and cut off the rest.
Line 33: "map" is redundant. Line 35: the sentence does not read smoothly.
  SameThreadChecker updateThreadChecker_;
  std::atomic<size_t> initCount_;
  VectorPtr maskVec_;
  std::vector<bool> mask_;
  VectorPtr maskTemp_;
See whether maskTemp can avoid being a member variable.
}

LOG(FATAL) << "Unknown Hook type: " << type;
Just to confirm: if no hook is configured, this function will not be called, right?
This is to guarantee that an error is raised if the specified hook type is not among the ones we provide. In python v2, /python/paddle/trainer/config_parser.py has a check on the hook type at ParameterHook, but in the future there may be other ways to invoke this that do not go through Python.
proto/ParameterConfig.proto
Outdated
@@ -26,7 +26,8 @@ enum ParameterInitStrategy {

message ParameterUpdaterHookConfig {
  required string type = 1;
  optional string purning_mask_filename = 2;
  //hook type such as 'pruning'
  optional double sparsity_ratio = 3;
Since the purning_mask_filename field was removed, sparsity_ratio should be renumbered to 2. Also, if sparsity_ratio is optional, it should have a default value.
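With both suggestions applied, the message might look like this (an illustrative sketch, not the final diff):

```proto
message ParameterUpdaterHookConfig {
  // hook type such as 'pruning'
  required string type = 1;
  // ratio of zero elements after pruning
  optional double sparsity_ratio = 2 [default = 0.8];
}
```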
proto/ParameterConfig.proto
Outdated
@@ -26,7 +26,8 @@ enum ParameterInitStrategy {

message ParameterUpdaterHookConfig {
  required string type = 1;
  optional string purning_mask_filename = 2;
  //hook type such as 'pruning'
Comments usually go before the code they describe, and there should be a space after `//`.
    Hook Attribute object. The hook is an auxiliary operation that occurs
    during network propagation. Such as pruning operation, It will cut off
    redundant parameters in the network before training. More detail can see
    here paddle/parameter/ParameterUpdaterHook.cpp
This could be changed to cite the paper instead. Also, check the grammar of the last sentence.
    :param sparsity_ratio: Must be specified if hook type is 'pruning',
        the network will hold the sparsity_ratio maximum parameters, and cut off the rest.
    :type sparsity_ratio: float number between 0 and 1
`:type xxx:` should be followed only by the type; the range constraint belongs after `:param xxx:`.
  for (size_t i = 0; i < para->getSize(); i++)
    param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
  std::sort(param.begin(), param.end(), sortPairAscend);
You could use std::partial_sort.
@hedaoyuan OK, that sort would be better here; I will change it.
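The point of the suggestion: std::partial_sort orders only the first k elements, which is all that threshold selection needs, and is cheaper than a full std::sort. A minimal standalone sketch (`smallestK` is a hypothetical helper, not the PR's code):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Order only the first k elements (O(n log k)) instead of sorting the
// whole range (O(n log n)); the tail is left in unspecified order.
std::vector<int> smallestK(std::vector<int> v, std::size_t k) {
  std::partial_sort(v.begin(), v.begin() + k, v.end());
  v.resize(k);  // keep just the k smallest, now sorted
  return v;
}
```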
    dataPtr[i++] = m ? 1.0 : 0.0;
  }
}

  // Currently just use a mask vector for hack.
  // @TODO(yuyang18): Implemented the mask operation in vector.
  if (para->useGpu()) {
Lines 86-91 here can be moved into generateMask, right?
@hedaoyuan That code has already been deleted.
Oh, I mean lines 86-91 of the modified file, the `if (para->useGpu())` logic.
Yes, it can be moved in there.
  if (config.has_purning_mask_filename()) {
    return new StaticPruningHook(config.purning_mask_filename());
  }
  if (config.has_sparsity_ratio())
Move the config.has_sparsity_ratio() check into the StaticPruningHook constructor. Also, I see that Python has a default value for this; why raise an error here instead of applying the default?
StaticPruningHook currently only assigns the value; it does no checking. What I mean is that the "There must be sparsity_ratio parameter for pruning Hook." logic itself belongs to StaticPruningHook.
The Python side will be changed: if sparsity_ratio is not specified, the default value is used automatically, and the check on the C++ side will be removed.
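Moving the validation into the hook itself, as suggested above, might look roughly like this (a hypothetical sketch with illustrative names, not the PR's actual StaticPruningHook):

```cpp
#include <stdexcept>

// Hypothetical sketch: the hook that consumes sparsity_ratio also
// validates it, so the factory function needs no separate check.
class PruningHookSketch {
 public:
  explicit PruningHookSketch(double sparsityRatio)
      : sparsityRatio_(sparsityRatio) {
    if (sparsityRatio_ < 0.0 || sparsityRatio_ > 1.0)
      throw std::invalid_argument("sparsity_ratio must be in [0, 1]");
  }
  double sparsityRatio() const { return sparsityRatio_; }

 private:
  double sparsityRatio_;
};
```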
@@ -60,17 +61,28 @@ class StaticPruningHook : public IParameterUpdaterHook {
    maskTemp_ = Vector::create(para->getSize(), false);
    maskTemp_->zeroMem();
    real* dataPtr = maskTemp_->getData();
    size_t sparsityNum = para->getSize() * (1 - sparsityRatio_);
Does sparsityRatio_ refer to the ratio of non-zero elements or of zero elements? Looking at line 72 on the left, the original definition appears to be the non-zero ratio; why is it changed here?

```
for (size_t i = 0; i < para->getSize() * sparsityRatio_; i++)
```
It is the ratio of zero elements. The mask is initialized to all zeros; what we do here is set the mask of the non-zero elements to 1, hence (1 - sparsityRatio_). The name sparsityNum is indeed not great; I will change it.
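In numbers, under the zero-element convention described here, the count of weights whose mask is set to 1 is size * (1 - ratio). A tiny sketch of that arithmetic (illustrative helper name, values chosen to be exact in floating point):

```cpp
#include <cstddef>

// sparsityRatio counts ZERO elements, so the number of weights whose
// mask is set to 1 is paramSize * (1 - sparsityRatio).
std::size_t nonZeroCount(std::size_t paramSize, double sparsityRatio) {
  return static_cast<std::size_t>(paramSize * (1.0 - sparsityRatio));
}
```

Note that the truncating cast makes the result sensitive to floating-point rounding for ratios that are not exactly representable, which is another reason to name and document this quantity carefully.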
proto/ParameterConfig.proto
Outdated
  required string type = 1;
  optional string purning_mask_filename = 2;
  optional double sparsity_ratio = 2 [default = 0.8];
The comment needs to make clear whether sparsity_ratio refers to non-zero or zero elements. For example, does default=0.8 mean 80% zeros? My first reaction on seeing sparsity_ratio was that it was the fraction of non-zero elements.
@hedaoyuan OK, will do.
 * Static means user load a mask map before training started. This map will
 * define which link/weight between neural is disabled.
 * Static means user specific a sparsity_ratio before training start, and the
 * network will prune the parameters based on the sparsity_ratio. More deatils
typo: specific -> specify, start -> started
More deatils can see -> More details can be found
    SetDevice device(para->getDeviceId());
  void generateMask(Parameter* para) {
    VectorPtr vec = para->getBuf(PARAMETER_VALUE);
    maskTemp_ = Vector::create(para->getSize(), false);
Change maskTemp_ to a local variable.
  std::partial_sort(
      param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
  for (size_t i = 0; i < nonZeroNum; i++) dataPtr[param[i].second] = 1.0;
See whether some variable names can be improved. For example dataPtr: I had to go back up to line 63 to find out what this variable holds. It is best to use self-explanatory names; the same goes for vec, vecCpu, etc.
    during network propagation.
    NOTE: IT IS A HIGH LEVEL USER INTERFACE.

    :param type: Hook type, eg: 'pruning'
This comment will be used to generate the API docs, so it is best to be more detailed, for example listing all supported type values and the papers referenced.
    """
    Hook Attribute object. The hook is an auxiliary operation that occurs
    during network propagation.
    NOTE: IT IS A HIGH LEVEL USER INTERFACE.
It would be best to explain what a Hook is for and what it acts on. Also, I think the NOTE is unnecessary.
        assert is_compatible_with(
            self.sparsity_ratio,
            float), 'sparisity_ratio must be float type'
        assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] '
'sparisity must be a flaot between [0, 1] ' -> 'sparsity_ratio must be a float between [0, 1]': keep the error message consistent with the variable name; there are also typos.
ok
Resolves #2284.
Add DynamicPruningHook in ParameterUpdaterHook.cpp.
Improve the Python v2 API.