Disable Materializing Grads #6822

Merged
merged 11 commits on Mar 8, 2021
21 changes: 15 additions & 6 deletions onnxruntime/core/graph/graph_utils.cc
@@ -262,15 +262,24 @@ static void MoveAllNodeOutputs(Graph& graph, Node& src_node, Node& target_node)
//--- end of local helpers ---
//----------------------------

int GetNodeInputIndexFromInputName(const Node& node, const std::string& input_name) {
auto itr = std::find_if(node.InputDefs().begin(), node.InputDefs().end(),
[&input_name](const NodeArg* input) { return input->Name() == input_name; });
ORT_ENFORCE(itr != node.InputDefs().end(),
"Attempting to get index for an input which does not exist.");
auto index = std::distance(node.InputDefs().begin(), itr);
int GetIndexFromName(const Node& node, const std::string& name, bool is_input) {
const auto& node_args = is_input ? node.InputDefs() : node.OutputDefs();
auto itr = std::find_if(node_args.begin(), node_args.end(),
[&name](const NodeArg* node_arg) { return node_arg->Name() == name; });
ORT_ENFORCE(itr != node_args.end(),
"Attempting to get index by a name which does not exist.");
auto index = std::distance(node_args.begin(), itr);
return static_cast<int>(index);
}

int GetNodeInputIndexFromInputName(const Node& node, const std::string& input_name) {
return GetIndexFromName(node, input_name, true);
}

int GetNodeOutputIndexFromOutputName(const Node& node, const std::string& output_name) {
return GetIndexFromName(node, output_name, false);
}

const std::string& GetNodeInputName(const Node& node, int index) {
const auto& inputs = node.InputDefs();
ORT_ENFORCE(index >= 0 && static_cast<size_t>(index) < inputs.size(),
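For context, a minimal usage sketch of the refactored helpers; the node and arg names below ("X", "Y_grad") are hypothetical and not taken from this PR:

```cpp
#include "core/graph/graph_utils.h"

// Both lookups now funnel through the shared GetIndexFromName helper, which
// ORT_ENFORCEs that the requested name actually exists on the node.
void InspectNode(const onnxruntime::Node& node) {
  // Index of the input NodeArg named "X" within node.InputDefs().
  int x_in_idx = onnxruntime::graph_utils::GetNodeInputIndexFromInputName(node, "X");
  // Index of the output NodeArg named "Y_grad" within node.OutputDefs() (new in this PR).
  int y_grad_out_idx = onnxruntime::graph_utils::GetNodeOutputIndexFromOutputName(node, "Y_grad");
  (void)x_in_idx;
  (void)y_grad_out_idx;
}
```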
3 changes: 3 additions & 0 deletions onnxruntime/core/graph/graph_utils.h
@@ -76,6 +76,9 @@ const std::string& GetNodeInputName(const Node& node, int index);
/** Gets the index of an input arg with the specified input arg name. */
int GetNodeInputIndexFromInputName(const Node& node, const std::string& input_name);

/** Gets the index of an output arg with the specified output arg name. */
int GetNodeOutputIndexFromOutputName(const Node& node, const std::string& output_name);

/** Gets the name of the outgoing NodeArg with the specified index for the given node. */
const std::string& GetNodeOutputName(const Node& node, int index);

33 changes: 13 additions & 20 deletions onnxruntime/test/framework/execution_frame_test.cc
@@ -266,16 +266,15 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) {
onnxruntime::Graph& graph = model.MainGraph();
TypeProto tensor_float;
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
onnxruntime::NodeArg input_def("X", &tensor_float),
yield_out_def("T", &tensor_float),
onnxruntime::NodeArg input_def("X", &tensor_float), yield_out_def("T", &tensor_float),
gemm_out_def("Y", &tensor_float);

ONNX_NAMESPACE::AttributeProto required_grad;
const std::string attribute_name = "required_grad";
required_grad.set_name(attribute_name);
required_grad.set_type(ONNX_NAMESPACE::AttributeProto::INTS);
required_grad.add_ints(static_cast<int64_t>(0));
NodeAttributes attributes({{attribute_name, required_grad}});
ONNX_NAMESPACE::AttributeProto full_shape_outputs;
const std::string attribute_name = "full_shape_outputs";
full_shape_outputs.set_name(attribute_name);
full_shape_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS);
full_shape_outputs.add_ints(static_cast<int64_t>(0));
NodeAttributes attributes({{attribute_name, full_shape_outputs}});
graph.AddNode("node1", "YieldOp", "yield", ArgMap{&input_def}, ArgMap{&yield_out_def}, &attributes, kMSDomain)
.SetExecutionProviderType(xp_type);
// Add another node after YieldOp as YieldOp should not be graph output.
@@ -292,8 +291,8 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) {

DataTransferManager dtm;
profiling::Profiler profiler;
SessionState state(graph, execution_providers, true, &tp_, nullptr, dtm,
DefaultLoggingManager().DefaultLogger(), profiler);
SessionState state(graph, execution_providers, true, &tp_, nullptr, dtm, DefaultLoggingManager().DefaultLogger(),
profiler);

ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager));

@@ -307,12 +306,8 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) {
auto cpu_allocator = execution_providers.Get(xp_type)->GetAllocator(0, OrtMemTypeDefault);

OrtValue x_value, t_value;
CreateMLValue<float>(cpu_allocator,
std::vector<int64_t>{2, 2},
std::vector<float>(4, 2.0f), &x_value);
CreateMLValue<float>(cpu_allocator,
std::vector<int64_t>{2, 2},
std::vector<float>(4, 1.0f), &t_value);
CreateMLValue<float>(cpu_allocator, std::vector<int64_t>{2, 2}, std::vector<float>(4, 2.0f), &x_value);
CreateMLValue<float>(cpu_allocator, std::vector<int64_t>{2, 2}, std::vector<float>(4, 1.0f), &t_value);

vector<OrtValue> outputs;
ExecutionFrame frame({x_idx}, {x_value}, {y_idx}, outputs, {}, state);
@@ -322,10 +317,8 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) {
ASSERT_TRUE(frame.GetMutableNodeInputOrOutputMLValue(t_idx)->IsTensor());

OrtValue& y_value = *frame.GetMutableNodeInputOrOutputMLValue(y_idx);
ASSERT_STATUS_OK(frame.AllocateMLValueTensorSelfOwnBuffer(y_value, y_idx,
DataTypeImpl::GetType<float>(),
cpu_allocator->Info(),
TensorShape(std::vector<int64_t>{2, 2})));
ASSERT_STATUS_OK(frame.AllocateMLValueTensorSelfOwnBuffer(
y_value, y_idx, DataTypeImpl::GetType<float>(), cpu_allocator->Info(), TensorShape(std::vector<int64_t>{2, 2})));

MemoryPatternGroup pattern;
ASSERT_STATUS_OK(frame.GeneratePatterns(&pattern));
@@ -41,7 +41,7 @@ Status ModuleGradientGraphBuilder::Initialize(std::istream& model_istream,
}

training_graph_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(),
config.initializer_names_to_train.end());
config.initializer_names_to_train.end());

std::vector<const NodeArg*> input_args;
for (const auto& input_name : training_graph_info_.user_input_names) {
@@ -78,8 +78,8 @@ Status ModuleGradientGraphBuilder::Build(const std::vector<std::vector<int64_t>>
// Build the gradient graph.
ORT_RETURN_IF_ERROR(BuildGradientGraph());

// Add Yield Op.
AddYieldOp();
// Handle user outputs and output grads.
HandleOutputsAndGrads();

// Reorder outputs.
ReorderOutputs();
@@ -170,59 +170,70 @@ Status ModuleGradientGraphBuilder::BuildGradientGraph() {
return Status::OK();
}

void ModuleGradientGraphBuilder::AddYieldOp() {
void ModuleGradientGraphBuilder::HandleOutputsAndGrads() {
Graph& gradient_graph = gradient_model_->MainGraph();
GraphViewer gradient_graph_viewer(gradient_graph);
const auto& gradient_node_topology_list = gradient_graph_viewer.GetNodesInTopologicalOrder();
std::unordered_set<std::string> user_output_grad_names_set;
for (const auto& name : training_graph_info_.user_output_names) {
user_output_grad_names_set.insert(name + "_grad");
user_output_grad_names_set.insert(GradientBuilderBase::GradientName(name));
}

// If a NodeArg is the output of one of the nodes, it's not a user output gradient needed by the backward graph.
std::unordered_set<std::string> non_backward_user_output_grad_names;
// If an output gradient is also produced by one of the nodes, the output gradient coming from PT needs to be added to it.
std::unordered_set<std::string> internal_output_grad_names;
for (auto node_index : gradient_node_topology_list) {
auto& node = *gradient_graph.GetNode(node_index);
for (const auto& node_arg : node.OutputDefs()) {
if (user_output_grad_names_set.find(node_arg->Name()) != user_output_grad_names_set.end()) {
non_backward_user_output_grad_names.insert(node_arg->Name());
internal_output_grad_names.insert(node_arg->Name());
}
}
}

// YieldOp's required_grad attribute specifies the indices of the required gradients.
ONNX_NAMESPACE::AttributeProto required_grad;
const std::string attribute_name = "required_grad";
required_grad.set_name(attribute_name);
required_grad.set_type(ONNX_NAMESPACE::AttributeProto::INTS);

training_graph_info_.backward_output_grad_names_map.clear();
for (std::size_t i = 0; i < training_graph_info_.user_output_names.size(); ++i) {
const auto& name = training_graph_info_.user_output_names[i];
std::string grad_name = name + "_grad";
if (non_backward_user_output_grad_names.find(grad_name) == non_backward_user_output_grad_names.end()) {
training_graph_info_.backward_output_grad_names_map.insert(std::make_pair(grad_name, i));
required_grad.add_ints(static_cast<int64_t>(i));
}
for (const auto& output_grad_name : internal_output_grad_names) {
Node* producer_node = gradient_graph.GetMutableProducerNode(output_grad_name);
int producer_node_arg_index = graph_utils::GetNodeOutputIndexFromOutputName(*producer_node, output_grad_name);
const TypeProto* type_info = producer_node->MutableOutputDefs()[producer_node_arg_index]->TypeAsProto();
auto& external_node_arg = gradient_graph.GetOrCreateNodeArg(
gradient_graph.GenerateNodeArgName(GradientBuilderBase::ExternalOutputName(output_grad_name)), type_info);
auto& output_node_arg = gradient_graph.GetOrCreateNodeArg(
gradient_graph.GenerateNodeArgName(output_grad_name + "_add_output"), type_info);
Node& add_node = gradient_graph.AddNode(
output_grad_name + "_add", "Add", "",
{&external_node_arg, producer_node->MutableOutputDefs()[producer_node_arg_index]}, {&output_node_arg});
graph_utils::ReplaceDownstreamNodeInput(gradient_graph, *producer_node, producer_node_arg_index, add_node, 0);
}

// YieldOp's full_shape_outputs attribute specifies the indices of outputs that must be full shape.
// We need this info to make the TypeAndShapeInferenceFunction work properly.
ONNX_NAMESPACE::AttributeProto full_shape_outputs;
const std::string attribute_name = "full_shape_outputs";
full_shape_outputs.set_name(attribute_name);
full_shape_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS);

std::vector<NodeArg*> yield_input_node_args;
std::vector<NodeArg*> yield_output_node_args;
for (const auto& name : training_graph_info_.user_output_names) {
training_graph_info_.output_grad_indices_require_full_shape.clear();
for (size_t i = 0; i < training_graph_info_.user_output_names.size(); i++) {
std::string name = training_graph_info_.user_output_names[i];
yield_input_node_args.emplace_back(gradient_graph.GetNodeArg(name));
}
std::string grad_name = GradientBuilderBase::GradientName(name);
if (internal_output_grad_names.find(grad_name) != internal_output_grad_names.end()) {
grad_name = GradientBuilderBase::ExternalOutputName(grad_name);
} else {
// If the output grad is a direct input of the backward graph, we need to materialize it
// as an all-zero tensor with the same shape as the output; otherwise, since it will be an input of
// an Add node, it's OK to use a scalar zero tensor to save memory.
training_graph_info_.output_grad_indices_require_full_shape.emplace_back(i);
full_shape_outputs.add_ints(static_cast<int64_t>(i));
}

for (const auto& name : training_graph_info_.user_output_names) {
std::string grad_name = name + "_grad";
auto element = training_graph_info_.backward_output_grad_names_map.find(grad_name);
if (element != training_graph_info_.backward_output_grad_names_map.end()) {
yield_output_node_args.emplace_back(gradient_graph.GetNodeArg(element->first));
}
yield_output_node_args.emplace_back(gradient_graph.GetNodeArg(grad_name));
}

NodeAttributes attributes({{attribute_name, required_grad}});

gradient_graph.AddNode("YieldOp", "YieldOp", "Yield Op", yield_input_node_args, yield_output_node_args, &attributes, kMSDomain);
NodeAttributes attributes({{attribute_name, full_shape_outputs}});
gradient_graph.AddNode("YieldOp", "YieldOp", "Yield Op", yield_input_node_args, yield_output_node_args, &attributes,
kMSDomain);
}

void ModuleGradientGraphBuilder::ReorderOutputs() {
@@ -243,7 +254,7 @@ void ModuleGradientGraphBuilder::ReorderOutputs() {
training_graph_info_.user_input_grad_names.clear();
for (const auto& input_name : training_graph_info_.user_input_names) {
if (user_input_require_grad_set.find(input_name) != user_input_require_grad_set.end()) {
std::string input_gradient_name = input_name + "_grad";
std::string input_gradient_name = GradientBuilderBase::GradientName(input_name);
ORT_ENFORCE(gradient_output_arg_map.find(input_gradient_name) != gradient_output_arg_map.end(),
"Required user input grad is not found on gradient graph.");
training_graph_info_.user_input_grad_names[input_name] = input_gradient_name;
@@ -254,7 +265,7 @@
// Add initializer gradients to graph outputs.
training_graph_info_.initializer_grad_names_to_train.clear();
for (const auto& initializer_name : training_graph_info_.initializer_names_to_train) {
std::string initializer_gradient_name = initializer_name + "_grad";
std::string initializer_gradient_name = GradientBuilderBase::GradientName(initializer_name);
ORT_ENFORCE(gradient_output_arg_map.find(initializer_gradient_name) != gradient_output_arg_map.end(),
"Trainable initializer grad is not found on gradient graph.");
training_graph_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name);
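The comments above describe when an output grad must be materialized as a full-shape zero tensor versus a broadcastable scalar zero. A self-contained sketch of that decision, as a consumer of output_grad_indices_require_full_shape might apply it (the function and variable names are hypothetical, not part of this PR):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical frontend-side helper: choose the shape of the zero gradient to feed
// for a user output whose grad was not supplied. Grads consumed directly by the
// backward graph need the full output shape; grads that only feed the inserted
// Add node can be a scalar zero, which broadcasts and saves memory.
std::vector<int64_t> ZeroGradShape(size_t output_index,
                                   const std::vector<int64_t>& output_shape,
                                   const std::vector<size_t>& full_shape_indices) {
  const bool needs_full_shape =
      std::find(full_shape_indices.begin(), full_shape_indices.end(), output_index) !=
      full_shape_indices.end();
  return needs_full_shape ? output_shape : std::vector<int64_t>{};  // {} means scalar
}
```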
@@ -41,9 +41,9 @@ struct TrainingGraphInfo {
std::vector<std::string> initializer_grad_names_to_train{};
// The user outputs.
std::vector<std::string> user_output_names{};
// The user output grad names that are actually required by the backward graph,
// mapped to the index of the corresponding output of the inference graph.
std::unordered_map<std::string, size_t> backward_output_grad_names_map{};
// Indices of output grads that need to be materialized as full-size, all-zero tensors.
// Otherwise, a scalar zero tensor can be used.
std::vector<size_t> output_grad_indices_require_full_shape{};
};

class ModuleGradientGraphBuilder {
@@ -83,8 +83,8 @@ class ModuleGradientGraphBuilder {
// Build gradient graph.
Status BuildGradientGraph();

// Add Yield Op.
void AddYieldOp();
// Handle user outputs and output grads.
void HandleOutputsAndGrads();

// Reorder gradient graph outputs.
void ReorderOutputs();
4 changes: 4 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder_base.h
@@ -71,6 +71,10 @@ class GradientBuilderBase {
return name + "_grad";
}

static std::string ExternalOutputName(const std::string& name) {
return name + "_external";
}

protected:
virtual GradientDef GetGradientDefsImpl() const = 0;

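As a quick, self-contained illustration of the naming scheme these helpers encode (this mirrors the one-liners above; "loss" is just an example name):

```cpp
#include <iostream>
#include <string>

// Stand-alone mirror of GradientBuilderBase::GradientName / ExternalOutputName.
std::string GradientName(const std::string& name) { return name + "_grad"; }
std::string ExternalOutputName(const std::string& name) { return name + "_external"; }

int main() {
  std::cout << GradientName("loss") << "\n";                      // loss_grad
  std::cout << ExternalOutputName(GradientName("loss")) << "\n";  // loss_grad_external
  return 0;
}
```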
29 changes: 14 additions & 15 deletions orttraining/orttraining/core/graph/training_op_defs.cc
@@ -2220,27 +2220,26 @@ Return true if all elements are true and false otherwise.
.Output(0, "outputs_grad", "Gradient of outputs returned from pytorch.", "T", OpSchema::Variadic,
/*is_homogeneous*/ false,
/*min_arity*/ 1)
.Attr(
"required_grad",
"The indices of the outputs that require gradient outputs.",
AttributeProto::INTS)
.Attr("full_shape_outputs", "The indices of the outputs that must have full shape.", AttributeProto::INTS)
Review comment (Contributor): also add an assert in Yield's kernel? The output shape should match the input shape if "full_shape" is true.

.TypeConstraint("T", OpSchema::all_tensor_types(), "Allow inputs and outputs to be any kind of tensor.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
const std::string attribute_name = "required_grad";
auto required_grads = ctx.getAttribute(attribute_name);
if (nullptr == required_grads) { // attribute not present
ORT_ENFORCE(ctx.getNumInputs() == ctx.getNumOutputs());
for (size_t i = 0; i < ctx.getNumInputs(); ++i) {
propagateElemTypeFromInputToOutput(ctx, i, i);
}

const std::string attribute_name = "full_shape_outputs";
auto full_shape_outputs = ctx.getAttribute(attribute_name);
if (nullptr == full_shape_outputs) { // attribute not present
fail_type_inference("Value of attribute ", attribute_name, " not specified");
}
ORT_ENFORCE(ctx.getNumOutputs() == static_cast<size_t> (required_grads->ints_size()));
for (size_t i = 0, n = static_cast<size_t> (required_grads->ints_size()); i < n; ++i) {
size_t j = static_cast<size_t> (required_grads->ints(static_cast<int>(i)));
ORT_ENFORCE(ctx.getNumInputs() > j);
propagateElemTypeFromInputToOutput(ctx, j, i);

for (size_t i = 0, n = static_cast<size_t>(full_shape_outputs->ints_size()); i < n; ++i) {
size_t j = static_cast<size_t>(full_shape_outputs->ints(static_cast<int>(i)));
auto typeProto = ctx.getInputType(j);
if (!hasShape(*typeProto)) {
continue;
if (hasShape(*typeProto)) {
propagateShapeFromInputToOutput(ctx, j, j);
}
propagateShapeFromInputToOutput(ctx, j, i);
}
});
}
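Regarding the inline review comment above (asserting in the Yield kernel that a full_shape output grad matches the forward output's shape), a hedged sketch of what such a check could look like; the kernel plumbing, function name, and the way gradient shapes are obtained here are assumptions, not code from this PR:

```cpp
// Hypothetical shape check for YieldOp's kernel: every output index listed in the
// full_shape_outputs attribute must receive a gradient with exactly the shape of
// the corresponding forward output.
Status CheckFullShapeGrads(OpKernelContext* ctx,
                           const std::vector<int64_t>& full_shape_outputs,
                           const std::vector<TensorShape>& received_grad_shapes) {
  for (int64_t i : full_shape_outputs) {
    const TensorShape& fw_shape = ctx->Input<Tensor>(static_cast<int>(i))->Shape();
    ORT_RETURN_IF_NOT(received_grad_shapes[static_cast<size_t>(i)] == fw_shape,
                      "YieldOp: grad for full_shape output ", i,
                      " must match the forward output shape.");
  }
  return Status::OK();
}
```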
2 changes: 1 addition & 1 deletion orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -491,7 +491,7 @@ void addObjectMethodsForTraining(py::module& m) {
.def_readwrite("initializer_names_to_train", &TrainingGraphInfo::initializer_names_to_train)
.def_readwrite("initializer_grad_names_to_train", &TrainingGraphInfo::initializer_grad_names_to_train)
.def_readwrite("user_output_names", &TrainingGraphInfo::user_output_names)
.def_readwrite("backward_output_grad_names_map", &TrainingGraphInfo::backward_output_grad_names_map);
.def_readwrite("output_grad_indices_require_full_shape", &TrainingGraphInfo::output_grad_indices_require_full_shape);

py::class_<ModuleGradientGraphBuilder> module_gradient_graph_builder(m, "ModuleGradientGraphBuilder");
module_gradient_graph_builder.def(py::init([]() { return onnxruntime::make_unique<ModuleGradientGraphBuilder>(); }))