-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add not_equal NPU op #34560
add not_equal NPU op #34560
Conversation
Thanks for your contribution! |
ops::LessEqualNPUKernel<plat::NPUDeviceContext, int8_t>, | ||
ops::LessEqualNPUKernel<plat::NPUDeviceContext, uint8_t>, | ||
ops::LessEqualNPUKernel<plat::NPUDeviceContext, int16_t>, | ||
// ops::LessEqualNPUKernel<plat::NPUDeviceContext, uint16_t>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要支持数据类型的问题请 @qili93 关注下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的数据类型对应CANN的以下数据类型,都可以支持
static TensorType RealNumberType() {
return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64,
DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8, DT_BF16};
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已注释的数据类型paddle不支持,取消注释编译会报错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释掉的数据类型都从代码里面删除一下,Paddle尽量不包含冗余代码,如果需要有需要说明的地方可以卸载注释里面
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
ops::LessEqualNPUKernel<plat::NPUDeviceContext, int8_t>, | ||
ops::LessEqualNPUKernel<plat::NPUDeviceContext, uint8_t>, | ||
ops::LessEqualNPUKernel<plat::NPUDeviceContext, int16_t>, | ||
// ops::LessEqualNPUKernel<plat::NPUDeviceContext, uint16_t>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的数据类型对应CANN的以下数据类型,都可以支持
static TensorType RealNumberType() {
return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64,
DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8, DT_BF16};
}
@@ -142,11 +142,12 @@ def test_attr_name(self): | |||
globals()[cls_name] = Cls | |||
|
|||
|
|||
for _type_name in {'float16', 'float32', 'int32'}: | |||
for _type_name in {'float16', 'float32', 'int32', 'int64'}: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在C++算子中支持的所有数据类型都需要在单测里面跑一下验证一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
单测这边只支持bool、float32、float64、int32、int64,因此只添加了equal算子支持的bool类型
ops::LessEqualNPUKernel<plat::NPUDeviceContext, int8_t>, | ||
ops::LessEqualNPUKernel<plat::NPUDeviceContext, uint8_t>, | ||
ops::LessEqualNPUKernel<plat::NPUDeviceContext, int16_t>, | ||
// ops::LessEqualNPUKernel<plat::NPUDeviceContext, uint16_t>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释掉的数据类型都从代码里面删除一下,Paddle尽量不包含冗余代码,如果需要有需要说明的地方可以卸载注释里面
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
OPs
Describe
add not_equal NPU op