Skip to content

Commit

Permalink
build a Paddle Graph from CINN compiled program for execution with PE (
Browse files Browse the repository at this point in the history
…#39724)

* build a Paddle Graph from CINN compiled program for execution with PE

* update names of some variables

* fix random fail in build_cinn_pass_test and update some comments

* fix compiler error by merging phi pr
  • Loading branch information
CtfGo authored Feb 24, 2022
1 parent df0b443 commit 4d042a8
Show file tree
Hide file tree
Showing 10 changed files with 477 additions and 223 deletions.
9 changes: 1 addition & 8 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ DECLARE_string(deny_cinn_ops);

namespace paddle {
namespace framework {

namespace ir {
class MemOptVarInfo;
} // namespace ir

namespace paddle2cinn {

using framework::ir::Graph;
Expand Down Expand Up @@ -398,9 +393,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
kNoNeedBufferFeeds, no_need_buffer_feeds.release());
// initialize empty map for kMemOptVarInfoFromMainGraph attribute,
// it will be filled on the share_mem_opt_info_to_subgraph pass
subgraph->GetOrInit<std::unordered_map<
std::string, std::shared_ptr<framework::ir::MemOptVarInfo>>>(
kMemOptVarInfoFromMainGraph);
subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
return subgraph;
}

Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ limitations under the License. */

namespace paddle {
namespace framework {
namespace ir {
class MemOptVarInfo;
} // namespace ir

namespace paddle2cinn {

constexpr char kCinnLaunchOp[] = "cinn_launch";
Expand All @@ -27,6 +31,9 @@ constexpr char kInternalVars[] = "InternalVars";
constexpr char kOutputVars[] = "OutputVars";
constexpr char kMemOptVarInfoFromMainGraph[] =
"mem_opt_var_info_from_main_graph";
using Name2VarInfoMap =
std::unordered_map<std::string,
std::shared_ptr<framework::ir::MemOptVarInfo>>;

// A pass named BuildCinnPass, the function of this pass is:
//
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
ASSERT_EQ(
std::unordered_set<Node*>(cinn_op->inputs.begin(), cinn_op->inputs.end()),
std::unordered_set<Node*>({v0, v1, v2, v4}));
ASSERT_EQ(cinn_op->outputs, std::vector<Node*>({v6, v7}));
ASSERT_EQ(std::unordered_set<Node*>(cinn_op->outputs.begin(),
cinn_op->outputs.end()),
std::unordered_set<Node*>({v6, v7}));
ASSERT_EQ(v1->outputs, std::vector<Node*>({cinn_op}));
ASSERT_EQ(v6->inputs, std::vector<Node*>({cinn_op}));

Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,10 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
*compiled_obj = {std::move(graph_compiler),
std::move(compiled_res.runtime_program), scope,
symbol.var_model_to_program_map()};
compiled_obj->launch_context =
std::make_unique<operators::details::CinnLaunchContext>(
compiled_obj->paddle2cinn_varmap, compiled_obj->scope);
compiled_obj->cached_index = compiled_num;
compiled_obj->launch_context =
std::make_unique<operators::details::CinnLaunchContext>(graph,
*compiled_obj);
return compiled_obj;
}

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
include(operators)

cc_library(cinn_op_helper SRCS cinn_op_helper.cc DEPS operator device_context)
cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope cinn)
cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope proto_desc graph build_strategy parallel_executor cinn)

SET(CINN_OP_DEPS string_helper cinn cinn_compiler cinn_op_helper cinn_launch_context)
register_operators(DEPS ${CINN_OP_DEPS})

if (WITH_TESTING)
cc_test(cinn_launch_context_test SRCS cinn_launch_context_test.cc DEPS ddim lod_tensor scope cinn_launch_context)
cc_test(cinn_launch_context_test SRCS cinn_launch_context_test.cc DEPS ddim lod_tensor scope proto_desc graph cinn_launch_context cinn_instruction_run_op cinn)
set_tests_properties(cinn_launch_context_test PROPERTIES LABELS "RUN_TYPE=CINN")

SET(CINN_RUN_ENVIRONMENT "OMP_NUM_THREADS=1;runtime_include_dir=${PADDLE_BINARY_DIR}/third_party/CINN/src/external_cinn/cinn/runtime/cuda")
Expand Down
Loading

0 comments on commit 4d042a8

Please sign in to comment.