Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the GetFetchNames method in CinnGraphSymbolization. #37218

Merged
merged 7 commits into from
Nov 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn_launch_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"

Expand Down Expand Up @@ -381,16 +382,16 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
input_names.emplace_back(n->Name());
}
});
cinn_op_desc.SetInput("X", input_names);
cinn_op_desc.SetInput(operators::kX, input_names);
std::vector<std::string> output_names;
std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
[&output_names, &deny_var_set](Node* n) {
if (n->Var() != nullptr && !deny_var_set.count(n->Name())) {
output_names.emplace_back(n->Name());
}
});
cinn_op_desc.SetOutput("Out", output_names);
cinn_op_desc.SetAttr(kCompilationKey, compilation_key);
cinn_op_desc.SetOutput(operators::kOutputs, output_names);
cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key);
cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(cluster));
cinn_op_desc.Flush();
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/framework/paddle2cinn/build_cinn_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ namespace framework {
namespace paddle2cinn {

constexpr char kCinnLaunchOp[] = "cinn_launch";
constexpr char kCompilationKey[] = "compilation_key";

// A pass named BuildCinnPass, the function of this pass is:
//
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/operators/cinn_launch_op.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -91,8 +92,8 @@ std::vector<std::string> GetCompilationKeys(const Graph& graph) {
std::vector<std::string> compilation_keys;
for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) {
compilation_keys.emplace_back(
BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
compilation_keys.emplace_back(BOOST_GET_CONST(
std::string, node->Op()->GetAttr(operators::kCompilationKey)));
}
}
return compilation_keys;
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,15 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
ApplyPass(cinn_graph.get(), "OpFusion");
auto scope = BuildScope(target, cinn_graph);

auto fetch_ids = symbol.GetFetchIds();
VLOG(4) << "All fetch var ids in CINN: "
<< string::join_strings(fetch_ids, ',');

auto graph_compiler =
std::make_unique<GraphCompiler>(target, scope, cinn_graph);
GraphCompiler::CompileOptions options;
options.with_instantiate_variables = false;
auto compiled_res = graph_compiler->Build(options);
auto compiled_res = graph_compiler->Build(options, std::move(fetch_ids));
auto compiled_obj = std::make_unique<CinnCompiledObject>();
*compiled_obj = {std::move(graph_compiler),
std::move(compiled_res.runtime_program), scope,
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/cinn_launch_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"

Expand Down Expand Up @@ -62,8 +63,8 @@ std::vector<std::string> GetCompilationKeys(const Graph& graph) {
std::vector<std::string> compilation_keys;
for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) {
compilation_keys.emplace_back(
BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
compilation_keys.emplace_back(BOOST_GET_CONST(
std::string, node->Op()->GetAttr(operators::kCompilationKey)));
}
}
return compilation_keys;
Expand All @@ -86,7 +87,8 @@ std::unordered_map<std::string, std::vector<int64_t>> GetInputsInfo(
std::unordered_set<std::string> inputs;
for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) {
if (BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)) !=
if (BOOST_GET_CONST(std::string,
node->Op()->GetAttr(operators::kCompilationKey)) !=
key) {
continue;
}
Expand Down
18 changes: 17 additions & 1 deletion paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */

#include <algorithm>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down Expand Up @@ -225,6 +226,21 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const {
}
}

std::unordered_set<std::string> CinnGraphSymbolization::GetFetchIds() const {
std::unordered_set<std::string> fetch_names;
fetch_names.reserve(fetch_var_names_.size());
std::for_each(
fetch_var_names_.begin(), fetch_var_names_.end(),
[this, &fetch_names](const std::string& name) {
PADDLE_ENFORCE_EQ(
var_model_to_program_map_.count(name), 1,
platform::errors::PreconditionNotMet(
"Cannot find %s in var_model_to_program_map_", name.c_str()));
fetch_names.insert(var_model_to_program_map_.at(name));
});
return fetch_names;
}

::cinn::frontend::Program CinnGraphSymbolization::operator()() {
std::string builder_name = "NetBuilder_of_graph_" + std::to_string(graph_id_);
VLOG(4) << "NetBuilder Name " << builder_name;
Expand All @@ -235,7 +251,7 @@ ::cinn::frontend::Program CinnGraphSymbolization::operator()() {
auto cinn_scope = CreateCinnScope(feed_map);

OpMapperContext ctx(*cinn_scope, target_, &builder, &var_map_,
&var_model_to_program_map_);
&var_model_to_program_map_, &fetch_var_names_);
// add all tensor's feed info into context
for (auto& feed_pair : feed_map) {
ctx.AddFeedInfo(feed_pair.first, feed_pair.second);
Expand Down
10 changes: 9 additions & 1 deletion paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ limitations under the License. */
#pragma once

#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
Expand Down Expand Up @@ -84,6 +86,9 @@ class CinnGraphSymbolization {
return var_model_to_program_map_;
}

// get fetch var ids used in CINN
std::unordered_set<std::string> GetFetchIds() const;

using OpMapperContext = ::cinn::frontend::OpMapperContext;
using FeedInfoMap =
std::unordered_map<std::string, OpMapperContext::FeedInfo>;
Expand All @@ -95,10 +100,13 @@ class CinnGraphSymbolization {
const ::cinn::common::Target& target_;
const std::map<std::string, const LoDTensor*>& input_tensors_;

// preserve local variable map
// preserve cinn variable map
std::unordered_map<std::string, ::cinn::frontend::Variable> var_map_;
std::unordered_map<std::string, std::string> var_model_to_program_map_;

// fetch var names used in paddle
std::unordered_set<std::string> fetch_var_names_;

// transform all paddle var desc in feed list into cinn_var_descs_
FeedInfoMap GetFeedInfoMapFromInput() const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class CinnGraphSymbolizationForTest {
return OpMapperContext(*cinn_symbol_->CreateCinnScope(feed_map),
cinn_symbol_->target_, builder,
&cinn_symbol_->var_map_,
&cinn_symbol_->var_model_to_program_map_);
&cinn_symbol_->var_model_to_program_map_,
&cinn_symbol_->fetch_var_names_);
}

FeedInfoMap GetFeedInfoMapFromInput() {
Expand Down Expand Up @@ -292,6 +293,7 @@ TEST_F(CinnGraphSymbolizationTest, basic) {
ASSERT_NO_THROW((*symbol_)());
ASSERT_FALSE(symbol_->var_map().empty());
ASSERT_FALSE(symbol_->var_model_to_program_map().empty());
ASSERT_TRUE(symbol_->GetFetchIds().empty());
}

} // namespace paddle2cinn
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/operators/cinn_launch_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
namespace paddle {
namespace operators {

static constexpr char kX[] = "X";
static constexpr char kOutputs[] = "Out";
static constexpr char kCompilationKey[] = "compilation_key";
constexpr char kX[] = "X";
constexpr char kOutputs[] = "Out";
constexpr char kCompilationKey[] = "compilation_key";

using LoDTensor = framework::LoDTensor;
using CinnTensor = ::cinn::hlir::framework::Tensor;
Expand Down