add log_softmax_op_npu #35006
Conversation
Thanks for your contribution!
The branch was updated from cd38d83 to f27a0dc (Compare).
#include "paddle/fluid/operators/log_softmax_op.h" | ||
#include "paddle/fluid/framework/tensor_util.h" | ||
#include "paddle/fluid/operators/npu_op_runner.h" |
The tensor_util.h include on line 16 is not needed.
Done
ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext, double>,
// ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext, int>,  // used to debug
Delete the comment on lines 48-49.
Done
dx = dout - np.exp(out) * dout.copy().sum(axis=axis, keepdims=True).repeat(
    axis_dim, axis=axis)
return dx
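For reference: since `out = log_softmax(x)`, `np.exp(out)` equals `softmax(x)`, so this line implements the standard log-softmax backward rule `dx = dout - softmax(x) * sum(dout, axis)`; the `repeat(axis_dim, axis=axis)` simply broadcasts the reduced sum back to the input shape.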
The two functions above are identical to the ones in test_log_softmax, so there is no need to re-implement them; replace them with the following import:
from test_log_softmax import ref_log_softmax, test_log_softmax
Done
globals()[cls_name] = TestLogSoftmaxAxis

for _typename in {'float32'}:
If fp64 and fp16 can be supported, please add float64 / float16 unit tests here as well.
float64 / float16 have precision problems, so only the fp32 kernel is registered here.
The fp16/fp64 entries have been removed from the kernel registration.
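For context, a minimal sketch of what the final registration plausibly looks like after this change; the macro and namespace alias follow fluid conventions, and the exact merged file is not shown in this thread:

```cpp
// Sketch only: fp32-only NPU kernel registration for log_softmax,
// reflecting the fp16/fp64 removal discussed above.
namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
    log_softmax,
    ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext, float>);
```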
LGTM
PR types
New features
PR changes
OPs
Describe
Adapt the log_softmax operator for Ascend NPU.
Screenshot of a successful operator invocation:
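For readers unfamiliar with NPU adaptation, a minimal sketch of how such a kernel is typically implemented on top of `NpuOpRunner`, which dispatches to an Ascend CANN operator; the CANN operator name `LogSoftmaxV2`, its `axes` attribute, and the `CanonicalAxis` helper are assumptions, not confirmed by this thread:

```cpp
#include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class LogSoftmaxNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::Tensor>("X");
    auto* out = ctx.Output<framework::Tensor>("Out");
    const int rank = x->dims().size();
    // Normalize a possibly negative axis attribute into [0, rank).
    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
    out->mutable_data<T>(ctx.GetPlace());

    if (x->numel() != 0) {
      auto stream = ctx.template device_context<DeviceContext>().stream();
      // Delegate the computation to the Ascend operator; the name
      // "LogSoftmaxV2" and the "axes" attribute are assumptions.
      const auto& runner = NpuOpRunner("LogSoftmaxV2", {*x}, {*out},
                                       {{"axes", std::vector<int>{axis}}});
      runner.Run(stream);
    }
  }
};

}  // namespace operators
}  // namespace paddle
```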