-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Pten】Support data transform in C++ API #39263
Conversation
… api_data_transform
… api_data_transform
Thanks for your contribution! |
auto* dev_ctx = static_cast<pten::CPUContext*>(pool.Get(tensor.place())); | ||
return pten::TransferLayout(*dev_ctx, tensor, layout); | ||
} else { | ||
PADDLE_THROW(platform::errors::PreconditionNotMet( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的报错也可以改成pten::errors
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done,thx
case DataType::UINT8: | ||
return pten::Cast<uint8_t>(dev_ctx, tensor, dtype); | ||
default: | ||
PADDLE_THROW(platform::errors::Unimplemented( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thx
} | ||
|
||
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
pten::DenseTensor CastDateType(const pten::GPUContext& dev_ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么需要单独重载一个函数
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GPU的Cast
不支持bfloat16类型,所以单独分出来处理了
… api_data_transform
PR types
New features
PR changes
Others
Describe
C++ API内部支持自动data transform