[Phi] Migrate Adam and AdamW into Phi #40351

Merged on Mar 25, 2022 (23 commits). The diff below shows changes from 12 of the 23 commits, not the final merged state.

Commits:
- 358af3d  [Phi] Migrate Adam and Adamw into Phi (Aurelius84, Mar 8, 2022)
- 231e607  fix compile error and unittest ok (Aurelius84, Mar 9, 2022)
- 30c10ca  Merge remote-tracking branch 'upstream/develop' into move_adamw_grad (Aurelius84, Mar 9, 2022)
- 5ad1710  fix compile error and unittest ok (Aurelius84, Mar 9, 2022)
- 3e74332  Merge branch 'develop' into move_adamw_grad (Aurelius84, Mar 9, 2022)
- 1f62ee3  fix undefined reference to fLI::FLAGS (Aurelius84, Mar 9, 2022)
- 7e012b2  test depend on operator (Aurelius84, Mar 9, 2022)
- 70fd083  fix cmake (Aurelius84, Mar 10, 2022)
- 31321a9  Merge branch 'develop' into move_adamw_grad (Aurelius84, Mar 10, 2022)
- 97eaddf  fix xpu compile (Aurelius84, Mar 11, 2022)
- 75e76d9  Merge remote-tracking branch 'upstream/develop' into move_adamw_grad (Aurelius84, Mar 11, 2022)
- cc4020c  fix infrt (Aurelius84, Mar 11, 2022)
- 6eed365  fix amp_type_traits (Aurelius84, Mar 11, 2022)
- d0da6e4  Merge remote-tracking branch 'upstream/develop' into move_adamw_grad (Aurelius84, Mar 11, 2022)
- f11a13b  fix amp_type_traits (Aurelius84, Mar 11, 2022)
- 042324c  modify according reviewer (Aurelius84, Mar 14, 2022)
- 0ee26f6  modify according reviewer (Aurelius84, Mar 14, 2022)
- cd50e49  Merge remote-tracking branch 'upstream/develop' into move_adamw_grad (Aurelius84, Mar 14, 2022)
- 21be3e3  fix dtype float16 (Aurelius84, Mar 15, 2022)
- c657b59  fix typo (Aurelius84, Mar 15, 2022)
- c5acc09  Merge remote-tracking branch 'upstream/develop' into move_adamw_grad (Aurelius84, Mar 24, 2022)
- 42aa15e  fix Cmake (Aurelius84, Mar 24, 2022)
- c5e9cc1  fix code style (Aurelius84, Mar 24, 2022)
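Overview: this PR migrates the Adam and AdamW optimizer kernels from fluid into the Phi kernel library. The diff excerpted below contains the supporting refactor: the SelectedRows scatter functors (elementwise_add_to, add_sparse_inputs, MergeAdd) are generalized over the device context type, so that one shared implementation serves both the legacy fluid contexts (platform::CPUDeviceContext, platform::CUDADeviceContext) and the new Phi contexts (phi::CPUContext, phi::GPUContext).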
paddle/fluid/framework/operator.cc (0 additions & 2 deletions)

@@ -58,8 +58,6 @@ class DenseTensor;
 DECLARE_bool(benchmark);
 DECLARE_bool(check_nan_inf);
 DECLARE_bool(enable_unused_var_check);
-PADDLE_DEFINE_EXPORTED_int32(inner_op_parallelism, 0,
-                             "number of threads for inner op");
 DECLARE_bool(run_kp_kernel);

 namespace paddle {
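Note: only the definition of the inner_op_parallelism flag is removed here. Since this PR moves the Adam/AdamW kernels into Phi, the flag definition presumably moves with them; its new location is not shown in this excerpt.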
paddle/fluid/operators/math/selected_rows_functor.cc (92 additions & 40 deletions)

@@ -294,30 +294,30 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext,
 // add or mul.
 namespace scatter {

-template <typename T>
+template <typename T, typename DeviceContext>
 typename std::enable_if<!std::is_integral<T>::value>::type elementwise_add_to(
-    phi::funcs::BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
-    const T* in, T* out) {
+    phi::funcs::BlasT<DeviceContext, T>* blas, size_t data_len, const T* in,
+    T* out) {
   blas->AXPY(data_len, T(1.f), in, out);
 }

-template <typename T>
+template <typename T, typename DeviceContext>
 typename std::enable_if<std::is_integral<T>::value>::type elementwise_add_to(
-    phi::funcs::BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
-    const T* in, T* out) {
+    phi::funcs::BlasT<DeviceContext, T>* blas, size_t data_len, const T* in,
+    T* out) {
   for (size_t i = 0; i < data_len; i++) {
     out[i] += in[i];
   }
 }

-template <typename T>
+template <typename T, typename DeviceContext>
 typename std::enable_if<std::is_same<T, platform::bfloat16>::value>::type
 add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
                   const std::unordered_map<int64_t, size_t>& rows_to_id,
-                  int64_t input_width,
-                  const platform::CPUDeviceContext& context, T* out_data) {
+                  int64_t input_width, const DeviceContext& context,
+                  T* out_data) {
 #ifndef PADDLE_WITH_MKLDNN
-  auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(context);
+  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
 #endif
   for (auto* input : inputs) {
     if (input->rows().size() == 0) {

@@ -336,22 +336,22 @@ add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
 #else
     for (size_t i = 0; i < input_rows.size(); i++) {
       size_t out_i = rows_to_id.at(input_rows[i]);
-      elementwise_add_to<T>(&blas, static_cast<size_t>(input_width),
-                            &input_data[i * input_width],
-                            &out_data[out_i * input_width]);
+      elementwise_add_to<T, DeviceContext>(
+          &blas, static_cast<size_t>(input_width), &input_data[i * input_width],
+          &out_data[out_i * input_width]);
     }
 #endif
   }
 }

-template <typename T>
+template <typename T, typename DeviceContext>
 typename std::enable_if<!std::is_same<T, platform::bfloat16>::value>::type
 add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
                   const std::unordered_map<int64_t, size_t>& rows_to_id,
-                  int64_t input_width,
-                  const platform::CPUDeviceContext& context, T* out_data) {
+                  int64_t input_width, const DeviceContext& context,
+                  T* out_data) {
   VLOG(4) << "[CPU] add_sparse_inputs <" << typeid(T).name();
-  auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(context);
+  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
   for (auto* input : inputs) {
     if (input->rows().size() == 0) {
       continue;
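The pair of elementwise_add_to overloads above is a small SFINAE dispatch: std::enable_if makes exactly one overload viable for a given element type T, so floating-point types go through BLAS AXPY while integral types (which BLAS does not cover) fall back to a plain loop. A minimal standalone sketch of the same dispatch, with simplified names and no Paddle dependencies (the BLAS call is replaced by a loop here):

#include <cstddef>
#include <iostream>
#include <type_traits>

// Non-integral path: in the real code this forwards to blas->AXPY(...),
// i.e. out += 1.0 * in; a plain loop stands in for BLAS in this sketch.
template <typename T>
typename std::enable_if<!std::is_integral<T>::value>::type
elementwise_add_to(std::size_t data_len, const T* in, T* out) {
  for (std::size_t i = 0; i < data_len; ++i) out[i] += in[i];
}

// Integral path: BLAS has no integer AXPY, so the real code also loops.
template <typename T>
typename std::enable_if<std::is_integral<T>::value>::type
elementwise_add_to(std::size_t data_len, const T* in, T* out) {
  for (std::size_t i = 0; i < data_len; ++i) out[i] += in[i];
}

int main() {
  float fin[2] = {1.f, 2.f}, fout[2] = {0.f, 0.f};
  int iin[2] = {1, 2}, iout[2] = {0, 0};
  elementwise_add_to(2, fin, fout);  // picks the !is_integral overload
  elementwise_add_to(2, iin, iout);  // picks the is_integral overload
  std::cout << fout[1] << " " << iout[1] << "\n";  // prints: 2 2
  return 0;
}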
@@ -361,32 +361,31 @@ add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,

     for (size_t i = 0; i < input_rows.size(); i++) {
       size_t out_i = rows_to_id.at(input_rows[i]);
-      elementwise_add_to<T>(&blas, static_cast<size_t>(input_width),
-                            &input_data[i * input_width],
-                            &out_data[out_i * input_width]);
+      elementwise_add_to<T, DeviceContext>(
+          &blas, static_cast<size_t>(input_width), &input_data[i * input_width],
+          &out_data[out_i * input_width]);
     }
   }
 }

-template <typename T>
-struct MergeAdd<platform::CPUDeviceContext, T> {
-  phi::SelectedRows operator()(const platform::CPUDeviceContext& context,
+template <typename DeviceContext, typename T>
+struct MergeAddImpl {
+  phi::SelectedRows operator()(const DeviceContext& context,
                                const phi::SelectedRows& input,
                                const bool sorted_result = false) {
     phi::SelectedRows out;
     (*this)(context, input, &out, sorted_result);
     return out;
   }

-  void operator()(const platform::CPUDeviceContext& context,
-                  const phi::SelectedRows& input, phi::SelectedRows* output,
-                  const bool sorted_result = false) {
+  void operator()(const DeviceContext& context, const phi::SelectedRows& input,
+                  phi::SelectedRows* output, const bool sorted_result = false) {
     std::vector<const phi::SelectedRows*> inputs;
     inputs.push_back(&input);
     (*this)(context, inputs, output, sorted_result);
   }

-  void operator()(const platform::CPUDeviceContext& context,
+  void operator()(const DeviceContext& context,
                   const std::vector<const phi::SelectedRows*>& inputs,
                   phi::SelectedRows* output, const bool sorted_result = false) {
     if (inputs.size() == 0) {

@@ -461,19 +460,83 @@ struct MergeAdd<platform::CPUDeviceContext, T> {

       out.set_rows(merge_rows);

-      phi::funcs::SetConstant<platform::CPUDeviceContext, T> constant_functor;
+      phi::funcs::SetConstant<DeviceContext, T> constant_functor;
       constant_functor(context, out.mutable_value(), static_cast<T>(0.f));

       std::unordered_map<int64_t, size_t> rows_to_id;
       for (size_t i = 0; i < merge_rows.size(); ++i) {
         rows_to_id[merge_rows[i]] = i;
       }

-      add_sparse_inputs<T>(inputs, rows_to_id, input_width, context, out_data);
+      add_sparse_inputs<T, DeviceContext>(inputs, rows_to_id, input_width,
+                                          context, out_data);
     }
   }
 };

+template <typename T>
+struct MergeAdd<platform::CPUDeviceContext, T> {
+  // unary functor, merge by adding duplicated rows in
+  // the input SelectedRows object.
+  phi::SelectedRows operator()(const platform::CPUDeviceContext& context,
+                               const phi::SelectedRows& input,
+                               const bool sorted_result) {
+    return MergeAddImpl<platform::CPUDeviceContext, T>()(context, input,
+                                                         sorted_result);
+  }
+
+  void operator()(const platform::CPUDeviceContext& context,
+                  const phi::SelectedRows& input, phi::SelectedRows* output,
+                  const bool sorted_result) {
+    MergeAddImpl<platform::CPUDeviceContext, T>()(context, input, output,
+                                                  sorted_result);
+  }
+
+  void operator()(const platform::CPUDeviceContext& context,
+                  const std::vector<const phi::SelectedRows*>& inputs,
+                  phi::SelectedRows* output, const bool sorted_result) {
+    MergeAddImpl<platform::CPUDeviceContext, T>()(context, inputs, output,
+                                                  sorted_result);
+  }
+};
+
+template <typename T>
+struct MergeAdd<phi::CPUContext, T> {
+  // unary functor, merge by adding duplicated rows in
+  // the input SelectedRows object.
+  phi::SelectedRows operator()(const phi::CPUContext& context,
+                               const phi::SelectedRows& input,
+                               const bool sorted_result) {
+    return MergeAddImpl<phi::CPUContext, T>()(context, input, sorted_result);
+  }
+
+  void operator()(const phi::CPUContext& context,
+                  const phi::SelectedRows& input, phi::SelectedRows* output,
+                  const bool sorted_result) {
+    MergeAddImpl<phi::CPUContext, T>()(context, input, output, sorted_result);
+  }
+
+  void operator()(const phi::CPUContext& context,
+                  const std::vector<const phi::SelectedRows*>& inputs,
+                  phi::SelectedRows* output, const bool sorted_result) {
+    MergeAddImpl<phi::CPUContext, T>()(context, inputs, output, sorted_result);
+  }
+};
+
+#define TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(dtype)                \
+  template struct MergeAddImpl<platform::CPUDeviceContext, dtype>;  \
+  template struct MergeAddImpl<phi::CPUContext, dtype>;             \
+  template struct MergeAdd<platform::CPUDeviceContext, dtype>;      \
+  template struct MergeAdd<phi::CPUContext, dtype>;
+
+TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(float)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(double)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(int64_t)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::bfloat16)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::complex<float>)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::complex<double>)

 #ifdef PADDLE_WITH_XPU
 template <typename T>
 struct MergeAdd<platform::XPUDeviceContext, T> {

@@ -714,17 +777,6 @@ struct MergeAverage<platform::CPUDeviceContext, T> {
   }
 };

-template struct MergeAdd<platform::CPUDeviceContext, int>;
-template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
-template struct MergeAdd<platform::CPUDeviceContext, float>;
-template struct MergeAdd<platform::CPUDeviceContext, double>;
-template struct MergeAdd<platform::CPUDeviceContext,
-                         paddle::platform::complex<float>>;
-template struct MergeAdd<platform::CPUDeviceContext,
-                         paddle::platform::complex<double>>;
-template struct MergeAdd<platform::CPUDeviceContext,
-                         paddle::platform::bfloat16>;
-
 #ifdef PADDLE_WITH_XPU
 template struct MergeAdd<platform::XPUDeviceContext, float>;
 #endif
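The shape of this refactor: the body of MergeAdd moves into a context-generic MergeAddImpl<DeviceContext, T>, and MergeAdd keeps its old public interface through thin specializations, one per supported device context, that simply forward to the shared implementation. This lets the legacy fluid context (platform::CPUDeviceContext) and the new Phi context (phi::CPUContext) reuse a single implementation during the fluid-to-Phi transition. A minimal standalone sketch of the pattern, with hypothetical names (LegacyCPUContext, PhiCPUContext) standing in for the real context classes:

#include <iostream>

struct LegacyCPUContext {};  // hypothetical stand-in for platform::CPUDeviceContext
struct PhiCPUContext {};     // hypothetical stand-in for phi::CPUContext

// One context-generic implementation holds all of the logic.
template <typename DeviceContext, typename T>
struct MergeAddImpl {
  void operator()(const DeviceContext& /*ctx*/, T x) const {
    std::cout << "shared MergeAddImpl body, x = " << x << "\n";
  }
};

// The public functor keeps its old interface; the primary template is only
// declared, so unsupported contexts fail at compile time.
template <typename DeviceContext, typename T>
struct MergeAdd;

template <typename T>
struct MergeAdd<LegacyCPUContext, T> {
  void operator()(const LegacyCPUContext& ctx, T x) const {
    MergeAddImpl<LegacyCPUContext, T>()(ctx, x);  // forward to the shared body
  }
};

template <typename T>
struct MergeAdd<PhiCPUContext, T> {
  void operator()(const PhiCPUContext& ctx, T x) const {
    MergeAddImpl<PhiCPUContext, T>()(ctx, x);  // forward to the shared body
  }
};

int main() {
  MergeAdd<LegacyCPUContext, float>()(LegacyCPUContext{}, 1.f);  // legacy path
  MergeAdd<PhiCPUContext, float>()(PhiCPUContext{}, 2.f);        // Phi path
  return 0;
}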
paddle/fluid/operators/math/selected_rows_functor.cu (71 additions & 18 deletions)

@@ -319,19 +319,18 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
   }
 }

-template <typename T>
-struct MergeAdd<platform::CUDADeviceContext, T> {
-  phi::SelectedRows operator()(const platform::CUDADeviceContext& context,
+template <typename DeviceContext, typename T>
+struct MergeAddImpl {
+  phi::SelectedRows operator()(const DeviceContext& context,
                                const phi::SelectedRows& input,
                                const bool sorted_result = false) {
     phi::SelectedRows out;
     (*this)(context, input, &out);
     return out;
   }

-  void operator()(const platform::CUDADeviceContext& context,
-                  const phi::SelectedRows& input, phi::SelectedRows* output,
-                  const bool sorted_result = false) {
+  void operator()(const DeviceContext& context, const phi::SelectedRows& input,
+                  phi::SelectedRows* output, const bool sorted_result = false) {
     framework::Vector<int64_t> input_rows(input.rows());
     if (input_rows.size() == 0) {
       return;

@@ -350,7 +349,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
         phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
         context.GetPlace());

-    phi::funcs::SetConstant<platform::CUDADeviceContext, T> constant_functor;
+    phi::funcs::SetConstant<DeviceContext, T> constant_functor;
     constant_functor(context, out.mutable_value(), static_cast<T>(0));

     auto* out_data = out.mutable_value()->data<T>();

@@ -369,7 +368,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
     mix_vector_out.CopyToCPU();
   }

-  void operator()(const platform::CUDADeviceContext& context,
+  void operator()(const DeviceContext& context,
                   const std::vector<const phi::SelectedRows*>& inputs,
                   phi::SelectedRows* output, const bool sorted_result = false) {
     if (inputs.size() == 0) {

@@ -414,7 +413,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
         phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
         context.GetPlace());

-    phi::funcs::SetConstant<platform::CUDADeviceContext, T> constant_functor;
+    phi::funcs::SetConstant<DeviceContext, T> constant_functor;
     constant_functor(context, out.mutable_value(), static_cast<T>(0));

     auto* out_data = out.mutable_value()->data<T>();

@@ -441,15 +440,69 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
   }
 };

-template struct MergeAdd<platform::CUDADeviceContext, float>;
-template struct MergeAdd<platform::CUDADeviceContext, double>;
-template struct MergeAdd<platform::CUDADeviceContext, int>;
-template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
-template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;
-template struct MergeAdd<platform::CUDADeviceContext, platform::bfloat16>;
-template struct MergeAdd<platform::CUDADeviceContext, platform::complex<float>>;
-template struct MergeAdd<platform::CUDADeviceContext,
-                         platform::complex<double>>;
+template <typename T>
+struct MergeAdd<platform::CUDADeviceContext, T> {
+  // unary functor, merge by adding duplicated rows in
+  // the input SelectedRows object.
+  phi::SelectedRows operator()(const platform::CUDADeviceContext& context,
+                               const phi::SelectedRows& input,
+                               const bool sorted_result) {
+    return MergeAddImpl<platform::CUDADeviceContext, T>()(context, input,
+                                                          sorted_result);
+  }
+
+  void operator()(const platform::CUDADeviceContext& context,
+                  const phi::SelectedRows& input, phi::SelectedRows* output,
+                  const bool sorted_result) {
+    MergeAddImpl<platform::CUDADeviceContext, T>()(context, input, output,
+                                                   sorted_result);
+  }
+
+  void operator()(const platform::CUDADeviceContext& context,
+                  const std::vector<const phi::SelectedRows*>& inputs,
+                  phi::SelectedRows* output, const bool sorted_result) {
+    MergeAddImpl<platform::CUDADeviceContext, T>()(context, inputs, output,
+                                                   sorted_result);
+  }
+};
+
+template <typename T>
+struct MergeAdd<phi::GPUContext, T> {
+  // unary functor, merge by adding duplicated rows in
+  // the input SelectedRows object.
+  phi::SelectedRows operator()(const phi::GPUContext& context,
+                               const phi::SelectedRows& input,
+                               const bool sorted_result) {
+    return MergeAddImpl<phi::GPUContext, T>()(context, input, sorted_result);
+  }
+
+  void operator()(const phi::GPUContext& context,
+                  const phi::SelectedRows& input, phi::SelectedRows* output,
+                  const bool sorted_result) {
+    MergeAddImpl<phi::GPUContext, T>()(context, input, output, sorted_result);
+  }
+
+  void operator()(const phi::GPUContext& context,
+                  const std::vector<const phi::SelectedRows*>& inputs,
+                  phi::SelectedRows* output, const bool sorted_result) {
+    MergeAddImpl<phi::GPUContext, T>()(context, inputs, output, sorted_result);
+  }
+};
+
+#define TEMPLATE_SPECIALIZED_FOR_MERGEADD(dtype)                     \
+  template struct MergeAddImpl<platform::CUDADeviceContext, dtype>;  \
+  template struct MergeAddImpl<phi::GPUContext, dtype>;              \
+  template struct MergeAdd<platform::CUDADeviceContext, dtype>;      \
+  template struct MergeAdd<phi::GPUContext, dtype>;
+
+TEMPLATE_SPECIALIZED_FOR_MERGEADD(float)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD(double)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD(int)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD(int64_t)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::float16)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::bfloat16)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex<float>)
+TEMPLATE_SPECIALIZED_FOR_MERGEADD(platform::complex<double>)

 template <typename T, int block_size>
 __global__ void UpdateToTensorKernel(const T* selected_rows,
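Because the MergeAddImpl member definitions live in these .cc/.cu translation units rather than in a header, every (device context, dtype) combination used elsewhere must be explicitly instantiated here; the TEMPLATE_SPECIALIZED_FOR_MERGEADD macros stamp out those instantiations for both the fluid and Phi contexts in one place, replacing the old per-dtype instantiation lists. A minimal standalone sketch of the idea (hypothetical names, not Paddle's API):

// All member definitions live in this translation unit, mirroring the
// .cc/.cu files above.
template <typename DeviceContext, typename T>
struct MergeAddImpl {
  T twice(T x) const { return x + x; }
};

struct CPUCtx {};  // hypothetical context types
struct GPUCtx {};

// The macro emits one explicit instantiation per context for each dtype,
// so the linker can resolve uses of these templates from other files.
#define INSTANTIATE_FOR_DTYPE(dtype)           \
  template struct MergeAddImpl<CPUCtx, dtype>; \
  template struct MergeAddImpl<GPUCtx, dtype>;

INSTANTIATE_FOR_DTYPE(float)
INSTANTIATE_FOR_DTYPE(double)
INSTANTIATE_FOR_DTYPE(int)

int main() { return MergeAddImpl<CPUCtx, int>().twice(1) == 2 ? 0 : 1; }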