
add log_softmax_op_npu #35006

Merged

qili93 merged 3 commits into PaddlePaddle:develop on Sep 3, 2021
Conversation

juneweng (Contributor)

PR types

New features

PR changes

OPs

Describe

Adapts the log_softmax operator for the Ascend NPU.
Screenshot of a successful operator call:
log_softmax
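As context for the review below, here is a minimal sketch of what such a fluid NPU kernel typically looks like. It assumes the CANN operator name "LogSoftmaxV2", its "axes" attribute, and the CanonicalAxis helper from log_softmax_op.h; none of these details are confirmed in this thread, so treat it as illustrative and check the merged log_softmax_op_npu.cc for the actual code.

```cpp
// Sketch of an NPU log_softmax kernel (illustrative only; the CANN op name
// "LogSoftmaxV2" and the "axes" attribute are assumptions, not confirmed here).
#include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class LogSoftmaxNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* X = ctx.Input<framework::Tensor>("X");
    auto* Out = ctx.Output<framework::Tensor>("Out");
    const int rank = X->dims().size();
    // CanonicalAxis (from log_softmax_op.h) maps a negative axis into [0, rank).
    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
    Out->mutable_data<T>(ctx.GetPlace());

    // Dispatch to the Ascend runtime through NpuOpRunner on the device stream.
    auto stream =
        ctx.template device_context<platform::NPUDeviceContext>().stream();
    const auto& runner = NpuOpRunner("LogSoftmaxV2", {*X}, {*Out},
                                     {{"axes", std::vector<int>{axis}}});
    runner.Run(stream);
  }
};

}  // namespace operators
}  // namespace paddle
```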

@CLAassistant

CLAassistant commented Aug 19, 2021

CLA assistant check
All committers have signed the CLA.

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

firestonelib previously approved these changes Aug 19, 2021

#include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/npu_op_runner.h"
Contributor

The tensor_util.h include on line 16 is not needed.

Contributor Author

Done

```cpp
ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext, double>,
// ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext, int>, //
// used to debug
```
Contributor

Delete the comments on lines 48-49.

Contributor Author

Done

```python
dx = dout - np.exp(out) * dout.copy().sum(axis=axis, keepdims=True).repeat(
    axis_dim, axis=axis)
return dx
```
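For reference, the backward rule in this reference implementation is the standard log_softmax gradient. A short derivation (the notation here is mine, not from the PR):

```latex
% With y_i = x_i - \log\sum_j e^{x_j} along the chosen axis,
% \partial y_j / \partial x_i = \delta_{ij} - \operatorname{softmax}(x)_i
%                             = \delta_{ij} - e^{y_i}, so
\[
\frac{\partial L}{\partial x_i}
  = \sum_j \frac{\partial L}{\partial y_j}\left(\delta_{ij} - e^{y_i}\right)
  = \frac{\partial L}{\partial y_i} - e^{y_i}\sum_j \frac{\partial L}{\partial y_j},
\]
% i.e. dx = dout - np.exp(out) * dout.sum(axis=axis, keepdims=True).
```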

Contributor

The two functions above are identical to the ones in test_log_softmax; rather than reimplementing them, simply replace them with the following import:

```python
from test_log_softmax import ref_log_softmax, test_log_softmax
```

Contributor Author

Done

```python
globals()[cls_name] = TestLogSoftmaxAxis


for _typename in {'float32'}:
```
Contributor
@qili93 Aug 31, 2021

Given that fp64 and fp16 can be supported, add unit tests for the float64 / float16 data types here.

Contributor Author

float64 / float16 have precision problems, so only the fp32 kernel is registered here.

Contributor Author

The fp16/fp64 registrations have been removed.
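That presumably leaves an fp32-only registration. A hypothetical sketch of the resulting macro (the exact merged form is not shown in this thread):

```cpp
// Hypothetical fp32-only registration after this change (illustrative;
// see the merged log_softmax_op_npu.cc for the authoritative version).
namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(
    log_softmax,
    ops::LogSoftmaxNPUKernel<plat::NPUDeviceContext, float>);
```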

qili93 previously approved these changes Sep 1, 2021

@qili93 (Contributor) left a comment

LGTM

@qili93 qili93 merged commit ba6a312 into PaddlePaddle:develop Sep 3, 2021