From 5e1770927684b3a2f841b636bb613686258e657f Mon Sep 17 00:00:00 2001 From: qingshui Date: Wed, 9 Sep 2020 14:27:02 +0800 Subject: [PATCH] Add fuse mixallgather op, Add sclice tensor op, Fix gcc82 error (#39) --- paddle/fluid/framework/data_feed.h | 2 +- paddle/fluid/framework/fleet/box_wrapper.cu | 7 +- paddle/fluid/framework/fleet/box_wrapper.h | 5 +- .../operators/collective/c_mixallgather_op.cc | 317 ++++++++++++++++++ .../fluid/operators/slice_multi_tensor_op.cc | 147 ++++++++ 5 files changed, 473 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/operators/collective/c_mixallgather_op.cc create mode 100644 paddle/fluid/operators/slice_multi_tensor_op.cc diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 166a7b7f118d1..74c576c61d1fb 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -840,7 +840,7 @@ class SlotObjAllocator { delete tmp; --capacity_; } - CHECK_EQ(capacity_, 0); + CHECK_EQ(capacity_, static_cast(0)); } T* acquire(void) { T* x = NULL; diff --git a/paddle/fluid/framework/fleet/box_wrapper.cu b/paddle/fluid/framework/fleet/box_wrapper.cu index 62d2910d46db6..5b7a955e963c7 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cu +++ b/paddle/fluid/framework/fleet/box_wrapper.cu @@ -19,6 +19,7 @@ #include #include "paddle/fluid/framework/fleet/box_wrapper.h" #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/gpu_info.h" namespace paddle { @@ -149,9 +150,11 @@ __device__ void add_calculator_value(const int table_size, const float pred, pos = table_size - 1; } if (label == 0) { - atomicAdd(negative + pos, 1.0); + // atomicAdd(negative + pos, 1.0); + paddle::platform::CudaAtomicAdd(negative + pos, 1.0); } else { - atomicAdd(positive + pos, 1.0); + // atomicAdd(positive + pos, 1.0); + paddle::platform::CudaAtomicAdd(positive + pos, 1.0); } double err = pred - label; abs_error[idx] += fabs(err); diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index 6c2a927808b6b..b4db3671ae0ef 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -235,7 +235,8 @@ class BoxWrapper { void InitializeGPUAndLoadModel( const char* conf_file, const std::vector& slot_vector, const std::vector& slot_omit_in_feedpass, - const std::string& model_path, const std::map &lr_map) { + const std::string& model_path, + const std::map& lr_map) { if (nullptr != s_instance_) { VLOG(3) << "Begin InitializeGPU"; std::vector stream_list; @@ -262,7 +263,7 @@ class BoxWrapper { device_caches_ = new DeviceBoxData[gpu_num]; VLOG(0) << "lr_map.size(): " << lr_map.size(); - for (const auto e: lr_map) { + for (const auto e : lr_map) { VLOG(0) << e.first << "'s lr is " << e.second; if (e.first.find("param") != std::string::npos) { lr_map_[e.first + ".w_0"] = e.second; diff --git a/paddle/fluid/operators/collective/c_mixallgather_op.cc b/paddle/fluid/operators/collective/c_mixallgather_op.cc new file mode 100644 index 0000000000000..bc449b4153815 --- /dev/null +++ b/paddle/fluid/operators/collective/c_mixallgather_op.cc @@ -0,0 +1,317 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device_memory_aligment.h" +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif +#include "paddle/fluid/operators/tensor_formatter.h" + +namespace paddle { +namespace operators { + +class CMixAllGatherOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override {} + + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + expected_kernel_type.place_, + tensor.layout()); + } +}; + +template +class CMixAllGatherOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_THROW("CMixAllGather op do not support CPUKernel for now."); + } +}; + +// template +// int check_illegal_count(const T* a, int size, char *buf) { +// int zero = 0; +// int nan = 0; +// int inf = 0; +// for (int i = 0; i < size; ++i) { +// if (a[i] == 0) { +// zero = zero + 1; +// } else if (isnan(a[i])) { +// nan = nan + 1; +// } else if (isinf(a[i])) { +// inf = inf + 1; +// } +// } +// return snprintf(buf, 2048, "(SIZE:%d,NA:%d,INF:%d,ZERO:%d),", size, nan, +// inf, zero); +//} +// void print_cpu_data(const char *name, int device, const void *address, const +// paddle::platform::float16 *a, int size) { +// +//} +// void print_cpu_data(const char *name, int device, const void *address, const +// double *a, int size) { +// +//} +// void print_cpu_data(const char *name, int device, const void *address, const +// int *a, int size) { +// +//} +// void print_cpu_data(const char *name, int device, const void *address, const +// int64_t *a, int size) { +// +//} +// template +// void print_cpu_data(const char *name, int device, const void *address, const +// T *a, int size) { +// char szbuf[8193] = {0}; +// int offset = check_illegal_count(a, size, szbuf); +// if (size > 100) { +// int step = size / 100; +// for (int i = 0; i < size; i = i + step) { +// offset += snprintf(szbuf + offset, 8192 - offset, "%f,", a[i]); +// } +// } else { +// for (int i = 0; i < size; ++ i) { +// offset += snprintf(szbuf + offset, 8192 - offset, "%f,", a[i]); +// } +// } +// fprintf(stdout, "[%d]%s(%p):%s\n", device, name, address, szbuf); +//} +// +// template +// void print_gpu_data(const char *name, const T *a, int size, int device, +// cudaStream_t stream) { +// T *buf = 0; +// cudaHostAlloc((void **)&buf, sizeof(T) * size, cudaHostAllocDefault); +// cudaMemcpyAsync(buf, a, size * sizeof(float), cudaMemcpyDeviceToHost, +// stream); +// cudaStreamSynchronize(stream); +// print_cpu_data(name, device, a, buf, size); +// cudaFreeHost(buf); +//} + +template +class CMixAllGatherOpCUDAKernel : 
public framework::OpKernel { + static const int NCCL_MIXALLGATHER = 1; + static const int NCCL_ALLGATHER = 2; + + public: + void Compute(const framework::ExecutionContext &ctx) const override { +#if defined(PADDLE_WITH_NCCL) + auto in_tensors = ctx.MultiInput("Input"); + auto fused_tensor = ctx.Output("Output"); + // auto in_var_names = ctx.InputNames("Input"); + + int nranks = ctx.Attr("nranks"); + int rank_id = ctx.Attr("rankid"); + int nccl_mode = ctx.Attr("nccl_mode"); + int ring_id = ctx.Attr("ring_id"); + + auto place = ctx.GetPlace(); + + int device_id = boost::get(place).GetDeviceId(); + + size_t numel = 0; + auto dtype = + static_cast(in_tensors[0]->type()); + GetTensorMemSize(in_tensors, &numel); + + int64_t offset = 0; + size_t recv_len = 0; + T *recvbuff = nullptr; + T *sendbuff = nullptr; + + auto comm = platform::NCCLCommContext::Instance().Get(0, device_id); + int device_num = comm->nranks(); + + if (nccl_mode == NCCL_MIXALLGATHER) { // mixallgather + offset = numel * rank_id; + recvbuff = fused_tensor->mutable_data( + {static_cast(numel * nranks), 1}, place); + sendbuff = &recvbuff[offset]; + recv_len = numel * nranks; + } else if (nccl_mode == NCCL_ALLGATHER) { // allgather + offset = numel * (device_num * rank_id + device_id); + recvbuff = fused_tensor->mutable_data( + {static_cast(numel * nranks * device_num), 1}, place); + sendbuff = &recvbuff[offset]; + recv_len = numel * nranks * device_num; + } else { // allreduce + recvbuff = + fused_tensor->mutable_data({static_cast(numel), 1}, place); + sendbuff = recvbuff; + recv_len = numel; + } + CHECK(static_cast(recv_len) == fused_tensor->numel()); + + auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place); + // copy input datas + for (size_t i = 0; i < in_tensors.size(); ++i) { + size_t len = static_cast(in_tensors[i]->numel()); + auto sub_tensor = fused_tensor->Slice(static_cast(offset), + static_cast(offset + len)); + framework::TensorCopy(*in_tensors[i], place, *dev_ctx, &sub_tensor); + offset += len; + } + + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + stream = static_cast(dev_ctx)->stream(); + } else { + stream = static_cast(dev_ctx)->stream(); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + stream = comm->stream(); + } + + ncclDataType_t nccl_dtype = platform::ToNCCLDataType(dtype); + // reduce device 0 + if (nranks > 1) { // multi node + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); + if (nccl_mode == NCCL_ALLGATHER) { // allgather + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( + sendbuff, &recvbuff[numel * device_num * rank_id], numel, + nccl_dtype, comm->comm(), stream)); + if (device_id == 0) { + // node allgather + auto node_comm = + platform::NCCLCommContext::Instance().Get(ring_id, 0); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( + &recvbuff[numel * device_num * rank_id], recvbuff, + numel * device_num, nccl_dtype, node_comm->comm(), stream)); + } + } else { // mixallgather allreduce + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclReduce(sendbuff, sendbuff, numel, nccl_dtype, + ncclSum, 0, comm->comm(), stream)); + if (device_id == 0) { + auto node_comm = + platform::NCCLCommContext::Instance().Get(ring_id, 0); + if (nccl_mode == NCCL_MIXALLGATHER) { + // allgather + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( + sendbuff, recvbuff, numel, nccl_dtype, node_comm->comm(), + stream)); + } else { + // allreduce + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( + 
sendbuff, recvbuff, numel, nccl_dtype, ncclSum, + node_comm->comm(), stream)); + } + } + } + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); + // broadcast to all device + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast( + recvbuff, recv_len, nccl_dtype, 0, comm->comm(), stream)); + } else { // single node + if (nccl_mode == NCCL_ALLGATHER) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( + sendbuff, recvbuff, numel, nccl_dtype, comm->comm(), stream)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( + sendbuff, recvbuff, numel, nccl_dtype, ncclSum, comm->comm(), + stream)); + } + } + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); +// print_gpu_data("fuse_nccl", recvbuff, static_cast(recv_len), +// device_id, stream); +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } + + protected: + void GetTensorMemSize( + const std::vector &lod_tensors, + size_t *numel) const { + *numel = 0; + for (size_t i = 0; i < lod_tensors.size(); ++i) { + CHECK(lod_tensors[i]->IsInitialized()); + *numel += lod_tensors[i]->numel(); + } + } +}; + +class CMixAllGatherOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("Input", + "(vector) The input tensors of mixallgather_tensor " + "operator.") + .AsDuplicable(); + AddOutput("Output", + "(LoDTensor) The output tensor " + "of mixallgather_tensor operator. And the tensors of" + " Output is sliced from the tensor of FusedOutput."); + AddAttr("rankid", "(int default 0) communication node id.") + .SetDefault(0); + AddAttr("nranks", "(int default 1) communication node num.") + .SetDefault(1); + AddAttr("nccl_mode", + "(int default 0) one node 0 allreduce, 1 mixallgather mode , " + "2 allgather mode.") + .SetDefault(0); + AddAttr("ring_id", "(int default -1) nccl ring id num.") + .SetDefault(-1); + AddAttr( + "use_calc_stream", + "(bool default false) eject CUDA operations to calculation stream.") + .SetDefault(true); + AddComment(string::Sprintf(R"DOC( +MixAllGather %s Operator + +Call collective MixAllGather with reduce type %s. If input and output are +the same variable, in-place allreduce will be used. +Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#allreduce +)DOC", + GetName(), GetName())); + } + + protected: + virtual std::string GetName() { return "MixAllGather"; } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(c_mixallgather, ops::CMixAllGatherOp, + ops::CMixAllGatherOpMaker); +REGISTER_OP_CPU_KERNEL(c_mixallgather, ops::CMixAllGatherOpCPUKernel, + ops::CMixAllGatherOpCPUKernel, + ops::CMixAllGatherOpCPUKernel, + ops::CMixAllGatherOpCPUKernel, + ops::CMixAllGatherOpCPUKernel); +#ifdef PADDLE_WITH_CUDA +REGISTER_OP_CUDA_KERNEL(c_mixallgather, ops::CMixAllGatherOpCUDAKernel, + ops::CMixAllGatherOpCUDAKernel, + ops::CMixAllGatherOpCUDAKernel, + ops::CMixAllGatherOpCUDAKernel, + ops::CMixAllGatherOpCUDAKernel); +#endif diff --git a/paddle/fluid/operators/slice_multi_tensor_op.cc b/paddle/fluid/operators/slice_multi_tensor_op.cc new file mode 100644 index 0000000000000..8b6c678612901 --- /dev/null +++ b/paddle/fluid/operators/slice_multi_tensor_op.cc @@ -0,0 +1,147 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/device_memory_aligment.h" + +namespace paddle { +namespace operators { + +template +class SliceMultiTensorOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto fuse_tensor = context.Input("Input"); + auto in_tensors = context.MultiInput("X"); + // Init the continuous space + auto out_tensors = context.MultiOutput("Output"); + + int id = context.Attr("id"); + int num = context.Attr("num"); + + size_t in_size = in_tensors.size(); + size_t out_size = out_tensors.size(); + // num data + CHECK(in_size == out_size || out_size / num == in_size); + + // Make the outputs point to the continuous space. + int64_t numel = fuse_tensor->numel(); + int64_t offset = (id * numel) / num; + + // fprintf(stdout, "fuse length: %d(dim: %s), in size: %d(dim: %s), + // offset: %d\n", + // int(fused_tensor->numel()), + // fused_tensor->dims().to_str().c_str(), + // int(in_tensors[0]->numel()), + // in_tensors[0]->dims().to_str().c_str(), + // int(offset)); + + auto &fuse_dim = fuse_tensor->dims(); + // adjust fuse + if (fuse_dim.size() > 1 && fuse_dim[0] != numel) { + paddle::framework::DDim dim(fuse_dim); + dim[0] = numel; + dim[1] = 1; + const_cast(fuse_tensor)->Resize(dim); + } + + for (size_t i = 0; i < out_tensors.size(); ++i) { + size_t idx = i % in_size; + auto dim = in_tensors[idx]->dims(); + size_t len = static_cast(in_tensors[idx]->numel()); + CHECK(static_cast(offset + len) <= numel) + << "fuse dim: " << fuse_dim.to_str() << ", dim:" << dim.to_str() + << ", offset:" << offset << ", len:" << len; + // slice tensor + out_tensors[i] + ->ShareDataWith(fuse_tensor->Slice( + static_cast(offset), static_cast(offset + len))) + .Resize(dim); + offset += len; + } + } +}; + +class SliceMultiTensorOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override {} + + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + expected_kernel_type.place_, + tensor.layout()); + } +}; + +class SliceMultiTensorOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", + "(LoDTensor) The input tensor of" + " slice_multi_tensor operator.") + .AsDuplicable(); + AddInput("X", + "(vector) The input tensor of" + " slice_multi_tensor operator.") + .AsDuplicable(); + AddOutput("Output", + "(vector) The output " + "tensors of slice_multi_tensor operator. 
And the addresses "
+              "of the output tensors are contiguous; they are sliced "
+              "from the fused Input tensor.")
+        .AsDuplicable();
+    AddAttr<int>("id", "index of the partition to slice from.").SetDefault(0);
+    AddAttr<int>("num", "number of parts the fused input is split into.")
+        .SetDefault(1);
+    AddComment(R"DOC(
+SliceMultiTensor Operator.
+
+slice_multi_tensor is used to split one tensor into multiple child tensors.
+
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OPERATOR(slice_multi_tensor, paddle::operators::SliceMultiTensorOp,
+                  paddle::operators::SliceMultiTensorOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    slice_multi_tensor,
+    ops::SliceMultiTensorOpKernel,
+    ops::SliceMultiTensorOpKernel,
+    ops::SliceMultiTensorOpKernel);
+
+#ifdef PADDLE_WITH_CUDA
+REGISTER_OP_CUDA_KERNEL(
+    slice_multi_tensor,
+    ops::SliceMultiTensorOpKernel,
+    ops::SliceMultiTensorOpKernel,
+    ops::SliceMultiTensorOpKernel,
+    ops::SliceMultiTensorOpKernel);
+#endif
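
The Output size and the copy offset that CMixAllGatherOpCUDAKernel computes depend on nccl_mode (0: allreduce, 1: mixallgather, 2: allgather). The standalone C++ sketch below mirrors that arithmetic from the kernel; MixLayout and ComputeLayout are illustrative names only and are not part of the patch.

// Standalone sketch (not Paddle code) of the buffer layout that
// CMixAllGatherOpCUDAKernel sets up for its three nccl_mode values.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct MixLayout {
  int64_t output_len;   // number of elements in the fused Output tensor
  int64_t send_offset;  // where this rank/device copies its fused inputs
};

// numel: total element count of the fused inputs on one device
// nranks: number of nodes, rank_id: this node's id
// device_num: GPUs per node, device_id: this GPU's id
MixLayout ComputeLayout(int nccl_mode, int64_t numel, int nranks, int rank_id,
                        int device_num, int device_id) {
  const int kMixAllGather = 1;  // NCCL_MIXALLGATHER in the kernel
  const int kAllGather = 2;     // NCCL_ALLGATHER in the kernel
  if (nccl_mode == kMixAllGather) {
    // every node contributes one numel-sized segment
    return {numel * nranks, numel * rank_id};
  }
  if (nccl_mode == kAllGather) {
    // every GPU of every node contributes one numel-sized segment
    return {numel * nranks * device_num,
            numel * (static_cast<int64_t>(device_num) * rank_id + device_id)};
  }
  // default (0): plain allreduce over a numel-sized buffer, in place
  return {numel, 0};
}

int main() {
  // e.g. 2 nodes x 4 GPUs, 10 elements per device, allgather mode on GPU 2 of node 1
  MixLayout l = ComputeLayout(2, 10, 2, 1, 4, 2);
  assert(l.output_len == 80 && l.send_offset == 60);
  std::printf("output_len=%lld send_offset=%lld\n",
              static_cast<long long>(l.output_len),
              static_cast<long long>(l.send_offset));
  return 0;
}

In the multi-node branch of the kernel this layout is filled by an intra-node ncclReduce (or ncclAllGather) to device 0, an inter-node collective on the ring_id communicator issued only by device 0, and a final ncclBcast of the whole recv buffer to every device of the node.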
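The companion slice_multi_tensor op does the inverse bookkeeping: it starts at the id-th of num equal partitions of the fused Input tensor and hands out one span per Output, with span sizes taken from the X tensors. A minimal C++ sketch of that offset arithmetic follows; SliceRanges and its parameter names are illustrative only.

// Standalone sketch (not Paddle code) of how SliceMultiTensorOpKernel maps
// each Output onto a region of the fused Input tensor.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// fused_numel: element count of the fused Input tensor
// id, num: the op's attributes; slicing starts at the id-th of num
//          equal partitions of the fused tensor
// x_numels: element counts of the "X" tensors (cycled over the outputs)
// n_outputs: number of Output tensors (x_numels.size(), or num times that)
std::vector<std::pair<int64_t, int64_t>> SliceRanges(
    int64_t fused_numel, int id, int num,
    const std::vector<int64_t>& x_numels, size_t n_outputs) {
  std::vector<std::pair<int64_t, int64_t>> ranges;
  int64_t offset = id * fused_numel / num;
  for (size_t i = 0; i < n_outputs; ++i) {
    int64_t len = x_numels[i % x_numels.size()];
    assert(offset + len <= fused_numel);        // mirrors the CHECK in the kernel
    ranges.emplace_back(offset, offset + len);  // Output[i] covers this span
    offset += len;
  }
  return ranges;
}

int main() {
  // e.g. a 200-element fused tensor split into two halves (num = 2, id = 1),
  // sliced back into three parameter-sized outputs of 40/10/50 elements
  auto r = SliceRanges(200, 1, 2, {40, 10, 50}, 3);
  for (const auto& p : r) {
    std::printf("[%lld, %lld)\n", static_cast<long long>(p.first),
                static_cast<long long>(p.second));
  }
  return 0;
}

In the kernel itself each Output then calls ShareDataWith on the corresponding Slice of the fused tensor and is Resize'd back to the matching X tensor's dims, so the outputs alias the fused buffer rather than copying it.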