diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index fe29792b6e75c..2a5b158d315c3 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -29,15 +29,11 @@ namespace paddle { namespace framework { +/* --- Static maps to handle corner cases --- */ static std::unordered_map operators_with_attrs = {}; -static std::unordered_set operators_to_skip = { - "minus", -}; - static std::unordered_set operators_to_codegen = {}; -static std::unordered_set skipped_operators = {}; static std::string LegalizeVariableName(const std::string& var_name) { std::string ret = var_name; @@ -45,6 +41,132 @@ static std::string LegalizeVariableName(const std::string& var_name) { return ret; } +/* --- Helper Objects --- */ +class ForwardGenerationInfo { + public: + const std::string& GetOpType() const { return op_type_; } + void SetOpType(const std::string& op_type) { op_type_ = op_type; } + + const std::unordered_map& GetFwdInputsNamePosMap() + const { + return fwd_inputs_name_pos_map_; + } + std::unordered_map* GetMutableFwdInputsNamePosMap() { + return &fwd_inputs_name_pos_map_; + } + + const std::unordered_map& GetFwdOutputsNamePosMap() + const { + return fwd_outputs_name_pos_map_; + } + std::unordered_map* GetMutableFwdOutputsNamePosMap() { + return &fwd_outputs_name_pos_map_; + } + + const std::vector& GetInVars() const { return in_vars_; } + std::vector* GetMutableInVars() { return &in_vars_; } + + const std::vector& GetOutVars() const { + return out_vars_; + } + std::vector* GetMutableOutVars() { return &out_vars_; } + + private: + std::string op_type_; + std::unordered_map fwd_inputs_name_pos_map_; + std::unordered_map fwd_outputs_name_pos_map_; + std::vector in_vars_; + std::vector out_vars_; +}; + +class GradNodeGenerationInfo { + class OpBaseGenerationInfo { + public: + const std::string& GetOpBaseType() const { return op_base_type_; } + void SetOpBaseType(const std::string& op_type) { op_base_type_ = op_type; } + + const std::map& GetGradOutsSlotnameMap() const { + return grad_outs_slotname_map_; + } + std::map* GetMutableGradOutsSlotnameMap() { + return &grad_outs_slotname_map_; + } + + const std::map& GetGradInsFwdSlotnameMap() const { + return grad_ins_fwd_slotname_map_; + } + std::map* GetMutableGradInsFwdSlotnameMap() { + return &grad_ins_fwd_slotname_map_; + } + + const std::map& GetGradInsGradSlotnameMap() + const { + return grad_ins_grad_slotname_map_; + } + std::map* GetMutableGradInsGradSlotnameMap() { + return &grad_ins_grad_slotname_map_; + } + + const std::map< + std::string, + std::vector>>& + GetGradIns() const { + return grad_ins_; + } + std::map>>* + GetMutableGradIns() { + return &grad_ins_; + } + + const std::map< + std::string, + std::vector>>& + GetGradOuts() const { + return grad_outs_; + } + std::map>>* + GetMutableGradOuts() { + return &grad_outs_; + } + + private: + std::string op_base_type_; + std::map grad_outs_slotname_map_; + std::map grad_ins_fwd_slotname_map_; + std::map grad_ins_grad_slotname_map_; + std::map>> + grad_ins_; + std::map>> + grad_outs_; + }; + + public: + const std::string& GetFwdOpType() const { return fwd_op_type_; } + void SetFwdOpType(const std::string& op_type) { fwd_op_type_ = op_type; } + + bool GenerateForwardOnly() const { return generate_forward_only_; } + void SetGenerateForwardOnly(bool generate_forward_only) { + generate_forward_only_ = generate_forward_only; + } + + const std::vector& GetOpBaseInfos() const { + return op_base_infos_; + } + std::vector* GetMutableOpBaseInfos() { + return &op_base_infos_; + } + + private: + std::string fwd_op_type_; + bool generate_forward_only_ = false; + std::vector op_base_infos_; +}; + +/* --- Helper Functions --- */ static std::string AttrTypeToString(const proto::AttrType& type) { std::string ret; switch (type) { @@ -348,7 +470,6 @@ static bool CheckOpProto(proto::OpProto* op_proto) { VLOG(1) << "------ Analyzing Op ------: " << op_type; if (!operators_to_codegen.count(op_type)) return false; - if (operators_to_skip.count(op_type)) return false; return true; } @@ -356,15 +477,16 @@ static bool CheckOpProto(proto::OpProto* op_proto) { /* --------------------------------------- */ /* --------- Preprocess Ins/Outs --------- */ /* --------------------------------------- */ -static void PurifyForwardOpProto( - const proto::OpProto& op_proto, - std::unordered_map* fwd_inputs_name_pos_map, - std::unordered_map* fwd_outputs_name_pos_map, - std::vector* in_vars, - std::vector* out_vars) { +static void PurifyForwardOpProto(const proto::OpProto& op_proto, + ForwardGenerationInfo* fwd_info) { // Op Name const std::string op_name = op_proto.type(); + auto* in_vars = fwd_info->GetMutableInVars(); + auto* out_vars = fwd_info->GetMutableOutVars(); + auto* fwd_inputs_name_pos_map = fwd_info->GetMutableFwdInputsNamePosMap(); + auto* fwd_outputs_name_pos_map = fwd_info->GetMutableFwdOutputsNamePosMap(); + // Handle dispensable inputs for (const proto::OpProto::Var& input : op_proto.inputs()) { std::string input_name = input.name(); @@ -426,6 +548,104 @@ static void PurifyForwardOpProto( } } +static void PurifyGradNodeGenerationInfo(const proto::OpProto& op_proto, + GradNodeGenerationInfo* bwd_info) { + auto* op_base_infos = bwd_info->GetMutableOpBaseInfos(); + for (auto& iter : *op_base_infos) { + std::map* grad_outs_slotname_map = + iter.GetMutableGradOutsSlotnameMap(); + std::map* grad_ins_fwd_slotname_map = + iter.GetMutableGradInsFwdSlotnameMap(); + std::map* grad_ins_grad_slotname_map = + iter.GetMutableGradInsGradSlotnameMap(); + std::map>>* + grad_ins = iter.GetMutableGradIns(); + std::map>>* + grad_outs = iter.GetMutableGradOuts(); + + // Op Name + const std::string op_name = op_proto.type(); + + // Handle dispensable inputs + for (const proto::OpProto::Var& input : op_proto.inputs()) { + std::string input_name = input.name(); + + // Delete dispensable tensor unless specified in op_ins_map + if (input.dispensable()) { + if (!op_ins_map.count(op_name) || + !op_ins_map[op_name].count(input_name)) { + VLOG(6) << "Removing Dispensable Input: " << input_name; + + // grad_outs_slotname_map + auto grad_outs_slotname_map_purified = *grad_outs_slotname_map; + for (const auto& iter : *grad_outs_slotname_map) { + const std::string& grad_output_name = iter.first; + const std::string& matched_input_name = iter.second; + if (matched_input_name == input_name) { + grad_outs_slotname_map_purified.erase(grad_output_name); + + PADDLE_ENFORCE( + grad_outs->count(grad_output_name) > 0, + paddle::platform::errors::Fatal( + "Unable to find gradient output name in grad_outs.")); + // grad_outs + grad_outs->erase(grad_output_name); + } + } + *grad_outs_slotname_map = grad_outs_slotname_map_purified; + + // grad_ins_fwd_slotname_map: output as tensorwrapper + if (grad_ins_fwd_slotname_map->count(input_name)) + grad_ins_fwd_slotname_map->erase(input_name); + + // grad_ins: output as tensorwrapper + if (grad_ins->count(input_name)) grad_ins->erase(input_name); + } + } + } + + for (const proto::OpProto::Var& output : op_proto.outputs()) { + std::string output_name = output.name(); + + // Delete dispensable tensor unless specified in op_outs_map + if (output.dispensable()) { + if (!op_outs_map.count(op_name) || + !op_outs_map[op_name].count(output_name)) { + VLOG(6) << "Removing Dispensable Output: " << output_name; + + // grad_ins_grad_slotname_map + auto grad_ins_grad_slotname_map_purified = + *grad_ins_grad_slotname_map; + for (const auto& iter : *grad_ins_grad_slotname_map) { + const std::string& grad_input_name = iter.first; + const std::string& matched_output_name = iter.second; + if (matched_output_name == output_name) { + grad_ins_grad_slotname_map_purified.erase(grad_input_name); + + PADDLE_ENFORCE( + grad_ins->count(grad_input_name) > 0, + paddle::platform::errors::Fatal( + "Unable to find gradient input name in grad_ins.")); + // grad_ins + grad_ins->erase(grad_input_name); + } + } + *grad_ins_grad_slotname_map = grad_ins_grad_slotname_map_purified; + + // grad_ins_fwd_slotname_map: output as tensorwrapper + if (grad_ins_fwd_slotname_map->count(output_name)) + grad_ins_fwd_slotname_map->erase(output_name); + + // grad_ins: output as tensorwrapper + if (grad_ins->count(output_name)) grad_ins->erase(output_name); + } + } + } + } +} + static void PurifyGradOpProto( const proto::OpProto& op_proto, std::map* grad_outs_slotname_map, @@ -520,31 +740,22 @@ static void PurifyGradOpProto( /* --------- Collect Info --------- */ /* -------------------------------- */ static void CollectForwardInformationFromOpInfo( - const paddle::framework::OpInfo& op_info, - std::vector* in_vars, - std::vector* out_vars) { + const paddle::framework::OpInfo& op_info, ForwardGenerationInfo* fwd_info) { const proto::OpProto& op_proto = *op_info.proto_; + + fwd_info->SetOpType(op_proto.type()); + for (const proto::OpProto::Var& input : op_proto.inputs()) { - in_vars->push_back(input); + fwd_info->GetMutableInVars()->push_back(input); } for (const proto::OpProto::Var& output : op_proto.outputs()) { - out_vars->push_back(output); + fwd_info->GetMutableOutVars()->push_back(output); } } static bool CollectGradInformationFromOpInfo( - const paddle::framework::OpInfo& op_info, bool* generate_forward_only, - std::vector* grad_op_types, // grad - std::map* grad_outs_slotname_map, // grad - std::map* grad_ins_fwd_slotname_map, // grad - std::map* grad_ins_grad_slotname_map, // grad - std::map>>* - grad_ins, // grad - std::map>>* - grad_outs // grad - ) { + const paddle::framework::OpInfo& op_info, + GradNodeGenerationInfo* bwd_info) { const proto::OpProto& op_proto = *op_info.proto_; const std::string& op_type = op_proto.type(); std::vector dims = {1, 1, 1, 1}; @@ -645,7 +856,7 @@ static bool CollectGradInformationFromOpInfo( /* ------ Run GradOpMaker ------ */ if (!op_info.dygraph_grad_op_maker_) { VLOG(6) << op_type << " has no GradOpMaker"; - *generate_forward_only = true; + bwd_info->SetGenerateForwardOnly(true); return false; } @@ -656,32 +867,31 @@ static bool CollectGradInformationFromOpInfo( if (!grad_node) { VLOG(6) << "Got nullptr GradOpNode for " << op_type << " likely registered EmptyGradOpMaker"; - *generate_forward_only = true; + bwd_info->SetGenerateForwardOnly(true); return false; } - /* - if (grad_node->size() > 1) { - // Backward attributes can be super complicated - VLOG(6) << "Skip GradOpNode with multiple OpBases for now: " << op_type; - skipped_operators.insert(op_type); - return false; - } - */ - VLOG(6) << "Prepared GradOpNode"; - /* ---- Collect Default Attr Map ---- */ + /* ---- Collect OpBase's op_types ---- */ + bwd_info->SetFwdOpType(op_type); + auto* op_base_infos = bwd_info->GetMutableOpBaseInfos(); + op_base_infos->resize(grad_node->size()); for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) { // Each OpBase + int index = std::distance(grad_node->begin(), iter); paddle::imperative::OpBase& op_base = *iter; - grad_op_types->push_back(op_base.Type()); + (*op_base_infos)[index].SetOpBaseType(op_base.Type()); } /* ------ Get Grad ins/outs ---- */ // In case of multiple OpBase, stitch all the respective ins/outs into one VLOG(6) << "In function size: " << grad_node->size(); for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) { + int index = std::distance(grad_node->begin(), iter); + auto* op_base_grad_ins = (*op_base_infos)[index].GetMutableGradIns(); + auto* op_base_grad_outs = (*op_base_infos)[index].GetMutableGradOuts(); + const paddle::imperative::OpBase& op_base = *iter; const std::map& g_ins = op_base.GetInsMap(); @@ -689,34 +899,47 @@ static bool CollectGradInformationFromOpInfo( g_outs = op_base.GetOutsMap(); for (const auto& it : g_ins) { - if (!grad_ins->count(it.first)) (*grad_ins)[it.first] = {}; + if (!op_base_grad_ins->count(it.first)) + (*op_base_grad_ins)[it.first] = {}; + for (auto vw_iter = it.second.begin(); vw_iter != it.second.end(); vw_iter++) { std::shared_ptr vw = *vw_iter; - (*grad_ins)[it.first].push_back(vw); + + (*op_base_grad_ins)[it.first].push_back(vw); + + VLOG(6) << "GradIns Name: " << it.first; } } for (const auto& it : g_outs) { - if (!grad_outs->count(it.first)) (*grad_outs)[it.first] = {}; + if (!op_base_grad_outs->count(it.first)) + (*op_base_grad_outs)[it.first] = {}; + for (auto vw_iter = it.second.begin(); vw_iter != it.second.end(); vw_iter++) { std::shared_ptr vw = *vw_iter; - (*grad_outs)[it.first].push_back(vw); + + (*op_base_grad_outs)[it.first].push_back(vw); + + VLOG(6) << "GradOuts Name: " << it.first; } } } /* ------ Slot Name Matching ---- */ - // grad_ins -> fwd_ins, fwd_outs - SlotNameMatching(*grad_ins, fwd_ins, fwd_outs, grad_ins_fwd_slotname_map, - grad_ins_grad_slotname_map); - VLOG(6) << "Finished Slotname Matching for Grad_Ins"; - - // grad_outs -> fwd_ins, fwd_outs - SlotNameMatching(*grad_outs, fwd_ins, fwd_outs, grad_outs_slotname_map, - grad_outs_slotname_map); - VLOG(6) << "Finished Slotname Matching for Grad_Outs"; + for (auto& iter : *op_base_infos) { + // grad_ins -> fwd_ins, fwd_outs + SlotNameMatching(iter.GetGradIns(), fwd_ins, fwd_outs, + iter.GetMutableGradInsFwdSlotnameMap(), + iter.GetMutableGradInsGradSlotnameMap()); + + // grad_outs -> fwd_ins, fwd_outs + SlotNameMatching(iter.GetGradOuts(), fwd_ins, fwd_outs, + iter.GetMutableGradOutsSlotnameMap(), + iter.GetMutableGradOutsSlotnameMap()); + } + VLOG(6) << "Finished Slotname Matching"; return true; } @@ -725,13 +948,20 @@ static bool CollectGradInformationFromOpInfo( /* --------- CodeGen: Forward GradNode Creation ------ */ /* --------------------------------------------------- */ static std::string GenerateGradNodeCreationContent( - const std::unordered_map& fwd_inputs_name_pos_map, - const std::unordered_map& fwd_outputs_name_pos_map, - const std::map& grad_ins_fwd_slotname_map, - const std::string& op_type, const std::vector& in_vars, - const std::vector& out_vars) { + const ForwardGenerationInfo& fwd_info, + const GradNodeGenerationInfo& bwd_info) { VLOG(6) << "Generating GradNode Creation codes"; + const std::string& op_type = fwd_info.GetOpType(); + const std::unordered_map& fwd_inputs_name_pos_map = + fwd_info.GetFwdInputsNamePosMap(); + const std::unordered_map& fwd_outputs_name_pos_map = + fwd_info.GetFwdOutputsNamePosMap(); + const std::vector& in_vars = fwd_info.GetInVars(); + const std::vector& out_vars = fwd_info.GetOutVars(); + + const auto& op_base_infos = bwd_info.GetOpBaseInfos(); + // [Generation] Construct GradOpNode // Run ComputeRequiredGrad @@ -817,12 +1047,17 @@ static std::string GenerateGradNodeCreationContent( // [GradOpNode] Set TensorWrappers grad_node_creation_str += " // Set Tensor Wrappers\n"; - for (auto& kv : grad_ins_fwd_slotname_map) { - const std::string& tensor_wrapper_name = kv.second; - const char* SET_TENSOR_WRAPPER_TEMPLATE = - " grad_node->SetTensorWrapper%s(%s);\n"; - grad_node_creation_str += paddle::string::Sprintf( - SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name, tensor_wrapper_name); + for (const auto& iter : op_base_infos) { + const std::map& grad_ins_fwd_slotname_map = + iter.GetGradInsFwdSlotnameMap(); + for (auto& kv : grad_ins_fwd_slotname_map) { + const std::string& tensor_wrapper_name = kv.second; + const char* SET_TENSOR_WRAPPER_TEMPLATE = + " grad_node->SetTensorWrapper%s(%s);\n"; + grad_node_creation_str += + paddle::string::Sprintf(SET_TENSOR_WRAPPER_TEMPLATE, + tensor_wrapper_name, tensor_wrapper_name); + } } grad_node_creation_str += "\n"; VLOG(6) << "Generated SetTensorWrapper"; @@ -892,22 +1127,17 @@ static std::string GenerateGradNodeCreationContent( /* --------- CodeGen: Forward ----- */ /* -------------------------------- */ static std::pair GenerateForwardFunctionContents( - bool generate_forward_only, - const std::unordered_map& fwd_inputs_name_pos_map, - const std::unordered_map& fwd_outputs_name_pos_map, - const std::map& grad_ins_fwd_slotname_map, - const std::map& grad_ins_grad_slotname_map, - const std::map& grad_outs_slotname_map, - const std::map< - std::string, - std::vector>>& - grad_ins, - const std::map< - std::string, - std::vector>>& - grad_outs, - const std::string& op_type, const std::vector& in_vars, - const std::vector& out_vars) { + const ForwardGenerationInfo& fwd_info, + const GradNodeGenerationInfo& bwd_info) { + /* --- Process Forward Info ---*/ + const std::string& op_type = fwd_info.GetOpType(); + const std::unordered_map& fwd_inputs_name_pos_map = + fwd_info.GetFwdInputsNamePosMap(); + const std::unordered_map& fwd_outputs_name_pos_map = + fwd_info.GetFwdOutputsNamePosMap(); + const std::vector& in_vars = fwd_info.GetInVars(); + const std::vector& out_vars = fwd_info.GetOutVars(); + /* // Forward Function Example: std::tuple, Tensor, vector> @@ -999,24 +1229,53 @@ static std::pair GenerateForwardFunctionContents( for (const proto::OpProto::Var& output : out_vars) { const std::string& output_name = output.name(); std::string outnum = "1"; - if (output.duplicable()) { - outnum = output_name + "Num"; - - const char* FWD_NUM_ARG_TEMPLATE = ", size_t %s"; - std::string arg_str = - paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum); - dygraph_function_args_str += arg_str; - const char* FWD_OUTS_CONTENT_TEMPLATE = - "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput(%s) },"; - outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, - output_name, outnum); + if (op_passing_outs_map[op_type].count(output_name)) { + const std::string output_var_name = output_name + "Var"; + + // Pass Output from function argument, + // in form of shared_ptr/vector> + if (output.duplicable()) { + const char* FWD_NUM_ARG_TEMPLATE = + ", std::vector>& %s"; + std::string arg_str = + paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); + dygraph_function_args_str += arg_str; + + const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", %s },"; + outs_contents_str += paddle::string::Sprintf( + FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name); + } else { + const char* FWD_NUM_ARG_TEMPLATE = + ", std::shared_ptr& %s"; + std::string arg_str = + paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); + dygraph_function_args_str += arg_str; + + const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", {%s} },"; + outs_contents_str += paddle::string::Sprintf( + FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name); + } + } else { - const char* FWD_OUTS_CONTENT_TEMPLATE = - "{ \"%s\", " - "{std::make_shared(egr::Controller::Instance()." - "GenerateUniqueName())}},"; - outs_contents_str += - paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name); + if (output.duplicable()) { + outnum = output_name + "Num"; + + const char* FWD_NUM_ARG_TEMPLATE = ", size_t %s"; + std::string arg_str = + paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum); + dygraph_function_args_str += arg_str; + const char* FWD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput(%s) },"; + outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, + output_name, outnum); + } else { + const char* FWD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", " + "{std::make_shared(egr::Controller::Instance()." + "GenerateUniqueName())}},"; + outs_contents_str += + paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name); + } } } if (outs_contents_str.size() > 0) @@ -1084,10 +1343,9 @@ static std::pair GenerateForwardFunctionContents( VLOG(6) << "Converted Output VarBase to EagerTensor(s)"; // [Generation] ComputeRequireGrad -> GradNodeCreation - if (!generate_forward_only) { - std::string grad_node_creation_body_str = GenerateGradNodeCreationContent( - fwd_inputs_name_pos_map, fwd_outputs_name_pos_map, - grad_ins_fwd_slotname_map, op_type, in_vars, out_vars); + if (!bwd_info.GenerateForwardOnly()) { + std::string grad_node_creation_body_str = + GenerateGradNodeCreationContent(fwd_info, bwd_info); generated_function_body += grad_node_creation_body_str; generated_function_body += "\n"; VLOG(6) << "Generated GradNode Creation codes"; @@ -1162,22 +1420,16 @@ static std::pair GenerateForwardFunctionContents( /* --------- CodeGen: GradNode::operator() ------ */ /* ---------------------------------------------- */ static std::string GenerateGradNodeCCContents( - const std::vector& grad_op_types, - const std::unordered_map& fwd_inputs_name_pos_map, - const std::unordered_map& fwd_outputs_name_pos_map, - const std::map& grad_ins_fwd_slotname_map, - const std::map& grad_ins_grad_slotname_map, - const std::map& grad_outs_slotname_map, - const std::map< - std::string, - std::vector>>& - grad_ins, - const std::map< - std::string, - std::vector>>& - grad_outs, - const std::string& op_type, const std::vector& in_vars, - const std::vector& out_vars) { + const ForwardGenerationInfo& fwd_info, + const GradNodeGenerationInfo& bwd_info) { + /* --- Process Forward Info --- */ + const std::string& fwd_op_type = fwd_info.GetOpType(); + const std::unordered_map& fwd_inputs_name_pos_map = + fwd_info.GetFwdInputsNamePosMap(); + const std::unordered_map& fwd_outputs_name_pos_map = + fwd_info.GetFwdOutputsNamePosMap(); + const std::vector& in_vars = fwd_info.GetInVars(); + VLOG(6) << "Generating Grad Node CC"; /* [Outline] @@ -1224,227 +1476,247 @@ static std::string GenerateGradNodeCCContents( */ std::string generated_grad_function_body = ""; + size_t outs_size = 0; + const auto& op_base_infos = bwd_info.GetOpBaseInfos(); + for (size_t i = 0; i < op_base_infos.size(); i++) { + const auto& op_base_info = op_base_infos[i]; + + const auto& grad_ins_fwd_slotname_map = + op_base_info.GetGradInsFwdSlotnameMap(); + const auto& grad_ins_grad_slotname_map = + op_base_info.GetGradInsGradSlotnameMap(); + const auto& grad_outs_slotname_map = op_base_info.GetGradOutsSlotnameMap(); + const auto& grad_ins = op_base_info.GetGradIns(); + const auto& grad_outs = op_base_info.GetGradOuts(); + + const std::string& op_base_type = op_base_info.GetOpBaseType(); + const std::string& ins_name = "ins" + std::to_string(i); + const std::string& outs_name = "outs" + std::to_string(i); + + outs_size += grad_outs.size(); + + // [Generation] Get Ins Map + std::string ins_contents_str = ""; + for (auto iter : grad_ins) { + const std::string& grad_input_name = iter.first; + + if (grad_ins_fwd_slotname_map.count(grad_input_name)) { + // Fwd Tensor + std::string struct_fwd_input_name = + grad_ins_fwd_slotname_map.at(grad_input_name) + "_"; + const char* GRAD_INS_FWD_CONTENT_TEMPLATE = + "{ \"%s\", " + "egr::EagerUtils::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(" + "&" + "this->%s, " + "nullptr)) },"; + ins_contents_str += + paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE, + grad_input_name, struct_fwd_input_name); + + } else if (grad_ins_grad_slotname_map.count(grad_input_name)) { + // Fwd Tensor's Grad + size_t fwd_output_position = fwd_outputs_name_pos_map.at( + grad_ins_grad_slotname_map.at(grad_input_name)); + const char* GRAD_INS_GRAD_CONTENT_TEMPLATE = + "{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },"; + ins_contents_str += + paddle::string::Sprintf(GRAD_INS_GRAD_CONTENT_TEMPLATE, + grad_input_name, fwd_output_position); - // [Generation] Get Tracer - generated_grad_function_body += "\n"; - generated_grad_function_body += "\n"; - - // [Generation] Get Ins Map - std::string ins_contents_str = ""; - for (auto iter : grad_ins) { - const std::string& grad_input_name = iter.first; - - if (grad_ins_fwd_slotname_map.count(grad_input_name)) { - // Fwd Tensor - std::string struct_fwd_input_name = - grad_ins_fwd_slotname_map.at(grad_input_name) + "_"; - const char* GRAD_INS_FWD_CONTENT_TEMPLATE = - "{ \"%s\", " - "egr::EagerUtils::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(&" - "this->%s, " - "nullptr)) },"; - ins_contents_str += - paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE, - grad_input_name, struct_fwd_input_name); - - } else if (grad_ins_grad_slotname_map.count(grad_input_name)) { - // Fwd Tensor's Grad - size_t fwd_output_position = fwd_outputs_name_pos_map.at( - grad_ins_grad_slotname_map.at(grad_input_name)); - const char* GRAD_INS_GRAD_CONTENT_TEMPLATE = - "{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },"; - ins_contents_str += paddle::string::Sprintf( - GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position); - - } else { - PADDLE_THROW(platform::errors::Fatal( - "Detected mismatched slot names." - "Unable to find forward slot name that matches %s", - grad_input_name)); + } else { + PADDLE_THROW(platform::errors::Fatal( + "Detected mismatched slot names." + "Unable to find forward slot name that matches %s", + grad_input_name)); + } + } + if (ins_contents_str.size() > 0) + ins_contents_str.pop_back(); // // Remove trailing "," + + const char* BWD_INS_MAP_TEMPLATE = + " std::map>> %s = { " + "%s };\n"; + std::string ins_map_str = paddle::string::Sprintf( + BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str); + generated_grad_function_body += ins_map_str; + + VLOG(6) << "Generated Ins Map"; + + // [Generation] Get Outs Map + std::unordered_set duplicable_input_name_set; + for (const auto& in : in_vars) { + if (in.duplicable()) duplicable_input_name_set.insert(in.name()); } - } - if (ins_contents_str.size() > 0) - ins_contents_str.pop_back(); // // Remove trailing "," - - const char* BWD_INS_MAP_TEMPLATE = - " std::map>> ins = { " - "%s };\n"; - std::string ins_map_str = - paddle::string::Sprintf(BWD_INS_MAP_TEMPLATE, ins_contents_str); - generated_grad_function_body += ins_map_str; - - VLOG(6) << "Generated Ins Map"; - - // [Generation] Get Outs Map - std::unordered_set duplicable_input_name_set; - for (const auto& in : in_vars) { - if (in.duplicable()) duplicable_input_name_set.insert(in.name()); - } - - std::string outs_contents_str = ""; - for (auto iter : grad_outs) { - const std::string& grad_output_name = iter.first; - - if (grad_outs_slotname_map.count(grad_output_name)) { - // Fwd Tensor - const std::string& fwd_name = grad_outs_slotname_map.at(grad_output_name); - - /* Handle Special Case: "PullSparseOp", etc - - Forward: - - Ids W - | | - PullSparseOp - | - Out - - Backward: - - Ids GradOut W - | | | - PullSparseGradOp - | - GradOut - - Its grad output "GradOut" corresponds to forward output "Out", - where there is a hiden inplace involved. So we find "GradOut"'s index - in - grads, and perform the inplace operation by constructing outs = - {{"Out", grads[i]}} - - GradOut -> Out -> fwd_output_pos -> grads position -> grads[i] - outs = {{"Out", grads[i]}} - - For returns, append "GradOut" to the very end of return list. - */ - if (!fwd_inputs_name_pos_map.count(fwd_name)) { - PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name), - paddle::platform::errors::Fatal( - "fwd_name not found in fwd_inputs_name_pos_map nor " - "fwd_outputs_name_pos_map")); - - size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name); - std::string grad_ptr_name = fwd_name + "_ptrs"; - const char* GET_GRADS_PTR_TEMPLATE = - " std::vector> %s;\n" - " for(const auto& t : grads[%d]) {\n " - "%s.emplace_back(std::move(std::make_shared(t)));" - "\n }\n"; - std::string grads_ptr_str = - paddle::string::Sprintf(GET_GRADS_PTR_TEMPLATE, grad_ptr_name, - grads_position, grad_ptr_name); - generated_grad_function_body += grads_ptr_str; - generated_grad_function_body += "\n"; - - const char* GRAD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", %s },"; - outs_contents_str += paddle::string::Sprintf( - GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grad_ptr_name); - } else { - size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name); - if (duplicable_input_name_set.count(fwd_name)) { - const char* GRAD_OUTS_CONTENT_TEMPLATE = - "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( " - "this->OutputMeta()[%d].Size() ) },"; + std::string outs_contents_str = ""; + for (auto iter : grad_outs) { + const std::string& grad_output_name = iter.first; + + if (grad_outs_slotname_map.count(grad_output_name)) { + // Fwd Tensor + const std::string& fwd_name = + grad_outs_slotname_map.at(grad_output_name); + + /* Handle Special Case: "PullSparseOp", etc + + Forward: + + Ids W + | | + PullSparseOp + | + Out + + Backward: + + Ids GradOut W + | | | + PullSparseGradOp + | + GradOut + + Its grad output "GradOut" corresponds to forward output "Out", + where there is a hiden inplace involved. So we find "GradOut"'s + index + in + grads, and perform the inplace operation by constructing outs = + {{"Out", grads[i]}} + + GradOut -> Out -> fwd_output_pos -> grads position -> grads[i] + outs = {{"Out", grads[i]}} + + For returns, append "GradOut" to the very end of return list. + */ + if (!fwd_inputs_name_pos_map.count(fwd_name)) { + PADDLE_ENFORCE( + fwd_outputs_name_pos_map.count(fwd_name), + paddle::platform::errors::Fatal( + "fwd_name not found in fwd_inputs_name_pos_map nor " + "fwd_outputs_name_pos_map")); + + size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name); + std::string grad_ptr_name = fwd_name + "_ptrs"; + const char* GET_GRADS_PTR_TEMPLATE = + " std::vector> %s;\n" + " for(const auto& t : grads[%d]) {\n " + "%s.emplace_back(std::move(std::make_shared(t))" + ");" + "\n }\n"; + std::string grads_ptr_str = + paddle::string::Sprintf(GET_GRADS_PTR_TEMPLATE, grad_ptr_name, + grads_position, grad_ptr_name); + generated_grad_function_body += grads_ptr_str; + generated_grad_function_body += "\n"; + + const char* GRAD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", %s },"; outs_contents_str += paddle::string::Sprintf( - GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position); + GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grad_ptr_name); + } else { - const char* GRAD_OUTS_CONTENT_TEMPLATE = - "{ \"%s\", " - "{std::make_shared(egr::Controller::Instance()." - "GenerateUniqueName())}},"; - outs_contents_str += paddle::string::Sprintf( - GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name); + size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name); + if (duplicable_input_name_set.count(fwd_name)) { + const char* GRAD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( " + "this->OutputMeta()[%d].Size() ) },"; + outs_contents_str += + paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE, + grad_output_name, fwd_input_position); + } else { + const char* GRAD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", " + "{std::make_shared(egr::Controller::Instance(" + ")." + "GenerateUniqueName())}},"; + outs_contents_str += paddle::string::Sprintf( + GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name); + } } + } else { + PADDLE_THROW(platform::errors::Fatal( + "Detected mismatched slot names." + "Unable to find forward slot name that matches %s", + grad_output_name)); } - } else { - PADDLE_THROW(platform::errors::Fatal( - "Detected mismatched slot names." - "Unable to find forward slot name that matches %s", - grad_output_name)); } - } - if (outs_contents_str.size() > 0) - outs_contents_str.pop_back(); // // Remove trailing "," + if (outs_contents_str.size() > 0) + outs_contents_str.pop_back(); // // Remove trailing "," - const char* BWD_OUTS_MAP_TEMPLATE = - " std::map>> outs = { " - "%s };\n"; - std::string outs_map_str = - paddle::string::Sprintf(BWD_OUTS_MAP_TEMPLATE, outs_contents_str); - generated_grad_function_body += outs_map_str; - generated_grad_function_body += "\n"; - - VLOG(6) << "Generated Outs Map"; + const char* BWD_OUTS_MAP_TEMPLATE = + " std::map>> %s = { " + "%s };\n"; + std::string outs_map_str = paddle::string::Sprintf( + BWD_OUTS_MAP_TEMPLATE, outs_name, outs_contents_str); + generated_grad_function_body += outs_map_str; + generated_grad_function_body += "\n"; - // [Generation] Get Attrs Map - std::string trace_opbase_str = ""; - for (size_t i = 0; i < grad_op_types.size(); i++) { - const std::string& op_base_type = grad_op_types[i]; + VLOG(6) << "Generated Outs Map"; + // [Generation] Get Attrs Map const char* TRACE_OP_TEMPLATE = " // Pass the entire attribute map to TraceOp\n" " // The underlying kernel will pickup whatever attribute they need " "at runtime\n" - " egr::legacy::RunOp(\"%s\", ins, outs, this->attr_map_,\n" + " egr::legacy::RunOp(\"%s\", %s, %s, this->attr_map_,\n" " egr::Controller::Instance().GetExpectedPlace(),\n" " &this->default_attr_map_, false, {});\n"; - trace_opbase_str = paddle::string::Sprintf(TRACE_OP_TEMPLATE, op_base_type); - } + std::string trace_opbase_str = paddle::string::Sprintf( + TRACE_OP_TEMPLATE, op_base_type, ins_name, outs_name); - generated_grad_function_body += trace_opbase_str; + generated_grad_function_body += trace_opbase_str; - VLOG(6) << "Generated Attrs Map"; + VLOG(6) << "Generated Attrs Map"; - // [Generation] Get Return - std::string outputs_str = ""; - size_t num_appended_outputs = 0; - for (auto iter : grad_outs) { - const std::string& grad_out_name = iter.first; - const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name); + // [Generation] Get Return + std::string outputs_str = ""; + size_t num_appended_outputs = 0; + for (auto iter : grad_outs) { + const std::string& grad_out_name = iter.first; + const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name); - if (fwd_inputs_name_pos_map.count(fwd_name)) { - size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name); - const char* BWD_OUTPUT_TEMPLATE = - " outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n"; - outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, - fwd_input_position, grad_out_name); - num_appended_outputs++; - } else { - PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name), - paddle::platform::errors::Fatal( - "fwd_name not found in fwd_inputs_name_pos_map nor " - "fwd_outputs_name_pos_map")); + if (fwd_inputs_name_pos_map.count(fwd_name)) { + size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name); + const char* BWD_OUTPUT_TEMPLATE = + " outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n"; + outputs_str += paddle::string::Sprintf( + BWD_OUTPUT_TEMPLATE, fwd_input_position, outs_name, grad_out_name); + num_appended_outputs++; + } else { + PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name), + paddle::platform::errors::Fatal( + "fwd_name not found in fwd_inputs_name_pos_map nor " + "fwd_outputs_name_pos_map")); + } } - } - /* Handle Special Case: "PullSparseOp", etc - For returns, append "GradOut" to the very end of return list. */ - for (auto iter : grad_outs) { - const std::string& grad_out_name = iter.first; - const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name); - - if (fwd_outputs_name_pos_map.count(fwd_name)) { - const char* BWD_OUTPUT_TEMPLATE = - " outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n"; - outputs_str += paddle::string::Sprintf( - BWD_OUTPUT_TEMPLATE, num_appended_outputs, grad_out_name); - num_appended_outputs++; + /* Handle Special Case: "PullSparseOp", etc + For returns, append "GradOut" to the very end of return list. */ + for (auto iter : grad_outs) { + const std::string& grad_out_name = iter.first; + const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name); + + if (fwd_outputs_name_pos_map.count(fwd_name)) { + const char* BWD_OUTPUT_TEMPLATE = + " outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n"; + outputs_str += + paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, num_appended_outputs, + outs_name, grad_out_name); + num_appended_outputs++; + } } + + generated_grad_function_body += outputs_str; + generated_grad_function_body += "\n"; } const char* BWD_RETURN_TEMPLATE = - " std::vector> " - "outputs(outs.size());\n%s\n " - "return outputs;"; - std::string return_str = - paddle::string::Sprintf(BWD_RETURN_TEMPLATE, outputs_str); - - generated_grad_function_body += "\n"; - generated_grad_function_body += return_str; + " std::vector> outputs(%d);\n" + " %s\n" + " return outputs;\n"; + generated_grad_function_body = paddle::string::Sprintf( + BWD_RETURN_TEMPLATE, outs_size, generated_grad_function_body); // [Generation] Get Full Grad Function const char* GRAD_FUNCTION_TEMPLATE = @@ -1452,7 +1724,7 @@ static std::string GenerateGradNodeCCContents( "GradNode%s::operator()(const " "std::vector>& grads) {\n%s\n}"; std::string grad_function_str = paddle::string::Sprintf( - GRAD_FUNCTION_TEMPLATE, op_type, generated_grad_function_body); + GRAD_FUNCTION_TEMPLATE, fwd_op_type, generated_grad_function_body); VLOG(6) << "Generated returns"; @@ -1463,9 +1735,14 @@ static std::string GenerateGradNodeCCContents( /* --------- CodeGen: GradNode Header ------ */ /* ----------------------------------------- */ static std::string GenerateGradNodeHeaderContents( - const std::map& grad_ins_fwd_slotname_map, - const std::string& op_type, const std::vector& in_vars, - const std::vector& out_vars) { + const ForwardGenerationInfo& fwd_info, + const GradNodeGenerationInfo& bwd_info) { + const std::string& op_type = fwd_info.GetOpType(); + const std::vector& in_vars = fwd_info.GetInVars(); + const std::vector& out_vars = fwd_info.GetOutVars(); + + const auto& op_base_infos = bwd_info.GetOpBaseInfos(); + VLOG(6) << "Generating Grad Node Header"; const char* GRAD_NODE_TEMPLATE = @@ -1522,55 +1799,60 @@ static std::string GenerateGradNodeHeaderContents( std::string set_tensor_wrappers_str = ""; std::string tensor_wrapper_members_str = ""; - for (const auto& kv : grad_ins_fwd_slotname_map) { - const std::string& tensor_wrapper_name = kv.second; - const std::string& struct_tensor_wrapper_name = kv.second + "_"; - - std::string tensor_wrapper_arg_str; - std::string tensor_wrapper_body_str; - if (duplicable_tensors.count(tensor_wrapper_name)) { - const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE = - "const std::vector& %s"; - tensor_wrapper_arg_str = paddle::string::Sprintf( - ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name); - - const char* TENSOR_WRAPPER_MEMBER_TEMPLATE = - " std::vector %s;\n"; - tensor_wrapper_members_str += paddle::string::Sprintf( - TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name); - - const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE = - "for(const auto& eager_tensor : %s) {\n" - " %s.emplace_back( egr::TensorWrapper(eager_tensor, true " - "/*full_reserved*/) );\n" - " }\n"; - tensor_wrapper_body_str = paddle::string::Sprintf( - SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name, - struct_tensor_wrapper_name); + for (const auto& iter : op_base_infos) { + const std::map& grad_ins_fwd_slotname_map = + iter.GetGradInsFwdSlotnameMap(); + + for (const auto& kv : grad_ins_fwd_slotname_map) { + const std::string& tensor_wrapper_name = kv.second; + const std::string& struct_tensor_wrapper_name = kv.second + "_"; + + std::string tensor_wrapper_arg_str; + std::string tensor_wrapper_body_str; + if (duplicable_tensors.count(tensor_wrapper_name)) { + const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE = + "const std::vector& %s"; + tensor_wrapper_arg_str = paddle::string::Sprintf( + ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name); + + const char* TENSOR_WRAPPER_MEMBER_TEMPLATE = + " std::vector %s;\n"; + tensor_wrapper_members_str += paddle::string::Sprintf( + TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name); + + const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE = + "for(const auto& eager_tensor : %s) {\n" + " %s.emplace_back( egr::TensorWrapper(eager_tensor, true " + "/*full_reserved*/) );\n" + " }\n"; + tensor_wrapper_body_str = paddle::string::Sprintf( + SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name, + struct_tensor_wrapper_name); - } else { - const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE = - "const egr::EagerTensor& %s"; - tensor_wrapper_arg_str = paddle::string::Sprintf( - ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name); - - const char* TENSOR_WRAPPER_MEMBER_TEMPLATE = - " egr::TensorWrapper %s;\n"; - tensor_wrapper_members_str += paddle::string::Sprintf( - TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name); - - const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE = - "%s = egr::TensorWrapper(%s, true /*full_reserved*/);"; - tensor_wrapper_body_str = paddle::string::Sprintf( - SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name, - tensor_wrapper_name); - } - - const char* SET_TENSOR_WRAPPER_TEMPLATE = - " void SetTensorWrapper%s(%s) {\n %s\n }\n"; - set_tensor_wrappers_str += paddle::string::Sprintf( - SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name, - tensor_wrapper_arg_str, tensor_wrapper_body_str); + } else { + const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE = + "const egr::EagerTensor& %s"; + tensor_wrapper_arg_str = paddle::string::Sprintf( + ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name); + + const char* TENSOR_WRAPPER_MEMBER_TEMPLATE = + " egr::TensorWrapper %s;\n"; + tensor_wrapper_members_str += paddle::string::Sprintf( + TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name); + + const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE = + "%s = egr::TensorWrapper(%s, true /*full_reserved*/);"; + tensor_wrapper_body_str = paddle::string::Sprintf( + SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name, + tensor_wrapper_name); + } + + const char* SET_TENSOR_WRAPPER_TEMPLATE = + " void SetTensorWrapper%s(%s) {\n %s\n }\n"; + set_tensor_wrappers_str += paddle::string::Sprintf( + SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name, + tensor_wrapper_arg_str, tensor_wrapper_body_str); + } } VLOG(6) << "Generated TensorWrapper"; @@ -1682,97 +1964,62 @@ static void DygraphCodeGeneration(const std::string& output_dir) { /* ----------------------------- */ /* ---- Collect Information ---- */ /* ----------------------------- */ - std::vector grad_op_types; - std::vector in_vars; - std::vector out_vars; - std::map grad_outs_slotname_map; - std::map grad_ins_fwd_slotname_map; - std::map grad_ins_grad_slotname_map; - std::map>> - grad_ins; - std::map>> - grad_outs; + + ForwardGenerationInfo fwd_info; + GradNodeGenerationInfo bwd_info; VLOG(6) << "-------- CollectInformationFromOpInfo -------"; - CollectForwardInformationFromOpInfo(op_info, &in_vars, &out_vars); + CollectForwardInformationFromOpInfo(op_info, &fwd_info); - bool generate_forward_only = false; - bool is_available = CollectGradInformationFromOpInfo( - op_info, &generate_forward_only, &grad_op_types, - &grad_outs_slotname_map, &grad_ins_fwd_slotname_map, - &grad_ins_grad_slotname_map, &grad_ins, &grad_outs); + bool is_available = CollectGradInformationFromOpInfo(op_info, &bwd_info); - if (!is_available && !generate_forward_only) { + if (!is_available && !bwd_info.GenerateForwardOnly()) { VLOG(6) << "Skipped operator: " << op_type; continue; } VLOG(6) << "-------- PurifyOpProto -------"; - std::unordered_map fwd_inputs_name_pos_map; - std::unordered_map fwd_outputs_name_pos_map; - PurifyForwardOpProto(*op_proto, &fwd_inputs_name_pos_map, - &fwd_outputs_name_pos_map, &in_vars, &out_vars); - - if (!generate_forward_only) { - PurifyGradOpProto(*op_proto, &grad_outs_slotname_map, - &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map, - &grad_ins, &grad_outs); + PurifyForwardOpProto(*op_proto, &fwd_info); + if (!bwd_info.GenerateForwardOnly()) { + PurifyGradNodeGenerationInfo(*op_proto, &bwd_info); } /* --------------------------- */ /* --------- CodeGen --------- */ /* --------------------------- */ - /* ---- forward_dygraph_functions.cc ---- */ VLOG(6) << "-------- GenerateForwardFunctionContents -------"; std::pair body_and_declaration = - GenerateForwardFunctionContents( - generate_forward_only, fwd_inputs_name_pos_map, - fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map, - grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, - grad_outs, op_type, in_vars, out_vars); + GenerateForwardFunctionContents(fwd_info, bwd_info); fwd_function_str += body_and_declaration.first + "\n"; - /* ---- dygraph_forward_api.h ---- */ + VLOG(6) << "-------- GenerateDygraphForwardAPIContents -------"; std::string fwd_function_declare_str = body_and_declaration.second; dygraph_forward_api_str += fwd_function_declare_str; - if (generate_forward_only) continue; + if (bwd_info.GenerateForwardOnly()) continue; - /* ---- nodes.h ---- */ VLOG(6) << "-------- GenerateGradNodeHeaderContents -------"; - grad_node_h_str += - GenerateGradNodeHeaderContents(grad_ins_fwd_slotname_map, op_type, - in_vars, out_vars) + - "\n"; + grad_node_h_str += GenerateGradNodeHeaderContents(fwd_info, bwd_info); + grad_node_h_str += "\n"; - /* ---- nodes.cc ---- */ VLOG(6) << "-------- GenerateGradNodeCCContents -------"; - grad_node_cc_str += GenerateGradNodeCCContents( - grad_op_types, fwd_inputs_name_pos_map, - fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map, - grad_ins_grad_slotname_map, grad_outs_slotname_map, - grad_ins, grad_outs, op_type, in_vars, out_vars) + - "\n"; + grad_node_cc_str += GenerateGradNodeCCContents(fwd_info, bwd_info); + grad_node_cc_str += "\n"; VLOG(6) << op_type << ": Finished Generating Op: " << op_type; } - /* ---- dygraph_forward_function.cc ---- */ + VLOG(6) << "-------- GenerateDygraphForwardCCFile -------"; GenerateForwardDygraphFile(output_dir, fwd_function_str); - /* ---- dygraph_forward_api.h ---- */ VLOG(6) << "-------- GenerateForwardHFile -------"; GenerateForwardHFile(output_dir, dygraph_forward_api_str); - /* ---- nodes.h ---- */ VLOG(6) << "-------- GenerateNodeHFile -------"; GenerateNodeHFile(output_dir, grad_node_h_str); - /* ---- nodes.cc ---- */ VLOG(6) << "-------- GenerateNodeCCFile -------"; GenerateNodeCCFile(output_dir, grad_node_cc_str); } diff --git a/paddle/fluid/eager/auto_code_generator/op_list.txt b/paddle/fluid/eager/auto_code_generator/op_list.txt index 699a84169d700..d3e835a1d0355 100644 --- a/paddle/fluid/eager/auto_code_generator/op_list.txt +++ b/paddle/fluid/eager/auto_code_generator/op_list.txt @@ -237,6 +237,7 @@ spp floor gelu retinanet_detection_output +minus push_dense silu sequence_erase