Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Replace EliminateExceptions lowering pass #1859

Merged
merged 1 commit into from
Jul 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "torch/csrc/jit/passes/lower_graph.h"
#include "torch/csrc/jit/passes/lower_tuples.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/passes/remove_exceptions.h"
#include "torch/csrc/jit/passes/remove_mutation.h"

#include "core/lowering/lowering.h"
Expand Down Expand Up @@ -105,7 +104,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
torch::jit::InlineFunctionalGraphs(g);
torch::jit::PeepholeOptimize(g, false);
torch::jit::FuseLinear(g);
torch::jit::EliminateExceptions(g);
passes::EliminateExceptionsSafe(g);
if (!lower_info.disable_cse) {
torch::jit::EliminateCommonSubexpression(g);
}
Expand Down
66 changes: 66 additions & 0 deletions core/lowering/passes/exception_elimination.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "torch/csrc/jit/ir/alias_analysis.h"
#include "torch/csrc/jit/jit_log.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/guard_elimination.h"
Expand Down Expand Up @@ -108,6 +109,71 @@ void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
}
}

/*
Below is a fork of the torch::jit::EliminateExceptions pass, with node replacement
using replaceAllUsesDominatedByNodeWith instead of replaceAllUsesWith,
so as to not invalidate the IR in challenging cases, such as nested Ifs

Original Source from which it was adapted:
https://github.com/pytorch/pytorch/blob/c29ab84115f40614d04e4557ea2e1ac40b7aa75c/torch/csrc/jit/passes/remove_exceptions.cpp
*/

bool certainlyThrows(Block* block) {
  // A block is guaranteed to throw iff one of its immediate nodes is the
  // prim::RaiseException operation (nested subblocks are not inspected here;
  // the caller recurses into them separately)
  bool raises = false;
  for (Node* node : block->nodes()) {
    if (node->kind() == prim::RaiseException) {
      raises = true;
      break;
    }
  }
  return raises;
}

void EliminateExceptionsSafe(Block* block) {
  // Fork of torch::jit::EliminateExceptions which rewrites the prim::If
  // condition input in place instead of calling replaceAllUsesWith, so the
  // IR stays valid in challenging cases such as nested Ifs.
  auto graph = block->owningGraph();
  // Generate false and true constant placeholders. These are inserted
  // unconditionally; unused/duplicate constants are cleaned up afterwards by
  // ConstantPooling in the graph-level entry point.
  Value* false_const = graph->insertConstant(IValue(false));
  Value* true_const = graph->insertConstant(IValue(true));

  // For each prim::If node, if either block certainly throws an exception,
  // replace the conditional input of the node with the logical opposite,
  // so the throwing branch becomes dead and can be folded away
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::If) {
      Block* true_block = n->blocks()[0];
      Block* false_block = n->blocks()[1];
      bool removed_exception = false;
      // Initialized to nullptr so the pointer is never read uninitialized
      // when neither branch throws
      Value* input_value_replacement = nullptr;

      // If a branch throws an exception, steer execution to the other branch
      if (certainlyThrows(true_block)) {
        removed_exception = true;
        input_value_replacement = false_const;
      } else if (certainlyThrows(false_block)) {
        removed_exception = true;
        input_value_replacement = true_const;
      }

      // Log the removal and perform input replacement. Only the inputs of
      // this prim::If node are modified, to avoid incorrect boolean
      // modifications elsewhere in the graph.
      if (removed_exception) {
        LOG_WARNING("Detected and removing exception in TorchScript IR for node: " << util::node_info(n));
        n->insertInput(0, input_value_replacement);
        n->removeInput(1);
      }
    }

    // Inspect and replace all instances within subblocks of the current node
    for (Block* subblock : n->blocks()) {
      EliminateExceptionsSafe(subblock);
    }
  }
}

// Graph-level entry point: remove exception-raising prim::If branches from
// the whole graph, then fold the resulting constant conditionals and
// deduplicate the boolean constants inserted by the block-level pass.
void EliminateExceptionsSafe(std::shared_ptr<Graph>& graph) {
  // Recursively rewrite prim::If nodes starting from the top-level block
  EliminateExceptionsSafe(graph->block());
  // Propagate the inserted constant conditionals so dead branches collapse
  ConstantPropagation(graph);
  // Pool duplicate constants (one true/false pair is inserted per recursive call)
  ConstantPooling(graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
void ConvTransposed3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void EliminateExceptionsSafe(std::shared_ptr<torch::jit::Graph>& graph);
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);
Expand Down
218 changes: 218 additions & 0 deletions tests/core/lowering/test_exception_elimination_pass.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/passes/canonicalize.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/passes/remove_exceptions.h"

TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) {
// parseIR does not support " = prim::If(%51)" with no return value
Expand Down Expand Up @@ -169,3 +173,217 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
}
EXPECT_EQ(1, if_count);
}

// Verifies that a prim::If whose true branch raises an exception is
// collapsed to just its false branch by EliminateExceptionsSafe.
TEST(LoweringPasses, EliminateExceptionsSafeIfBlock) {
  /*std::string source_graph = R"IR(
    graph(%x, %y):
      %dim : int = aten::dim(%x)
      %48 : int = prim::Constant[value=2]()
      %66 : bool = aten::eq(%48, %dim)
      %45 : str = prim::Constant[value="EXCEPTION"]()
      %4 : Tensor = prim::If(%66)
        block0():
          = prim::RaiseException(%45)
          -> (%x)
        block1():
          %res = aten::mul(%x, %y)
          -> (%res)
      return (%4))IR";*/

  std::string target_graph = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %6 : Tensor = aten::mul(%x, %y)
      return (%6))IR";

  // Construct graph via manual commands, to avoid IR parsing issues with
  // unassigned variables (such as prim::RaiseException)
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->insertInput(0, "x");
  auto y = g->insertInput(1, "y");
  auto none_const_val = g->insertConstant(torch::jit::IValue());
  auto two_const_val = g->insertConstant(torch::jit::IValue(2));
  auto x_dims = g->create(torch::jit::aten::dim, {x}, 1);
  g->appendNode(x_dims);
  x_dims->output()->setType(torch::jit::IntType::get());
  auto eq = g->create(torch::jit::aten::eq, {two_const_val, x_dims->output()}, 1);
  g->appendNode(eq);
  eq->output()->setType(torch::jit::BoolType::get());
  torch::jit::IValue except("EXCEPTION");
  auto except_val = g->insertConstant(except);

  // block0 raises; block1 computes the product
  auto if_node = g->create(torch::jit::prim::If, {eq->output()}, 1);
  auto if_block0 = if_node->addBlock();
  auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
  if_block0->appendNode(exception_node);
  if_block0->registerOutput(x);

  auto if_block1 = if_node->addBlock();
  auto sum_node = g->create(torch::jit::aten::mul, {x, y}, 1);
  if_block1->appendNode(sum_node);
  if_block1->registerOutput(sum_node->output());

  g->insertNode(if_node);
  g->registerOutput(if_node->output());

  // Apply lowering pass and canonicalization to the graph
  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g);
  g = torch::jit::Canonicalize(g, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == g->toString()));
}

// Verifies that a prim::If whose false branch raises an exception is
// collapsed to just its true branch by EliminateExceptionsSafe.
TEST(LoweringPasses, EliminateExceptionsSafeElseBlock) {
  /*std::string source_graph = R"IR(
    graph(%x, %y):
      %dim : int = aten::dim(%x)
      %48 : int = prim::Constant[value=2]()
      %66 : bool = aten::eq(%48, %dim)
      %45 : str = prim::Constant[value="EXCEPTION"]()
      %4 : Tensor = prim::If(%66)
        block0():
          %res = aten::matmul(%x, %y)
          -> (%res)
        block1():
          = prim::RaiseException(%45)
          -> (%x)
      return (%4))IR";*/

  std::string target_graph = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %6 : Tensor = aten::matmul(%x, %y)
      return (%6))IR";

  // Build the source graph programmatically, since parseIR cannot express
  // unassigned operations such as prim::RaiseException
  auto graph = std::make_shared<torch::jit::Graph>();
  auto in_x = graph->insertInput(0, "x");
  auto in_y = graph->insertInput(1, "y");
  auto none_val = graph->insertConstant(torch::jit::IValue());
  auto const_two = graph->insertConstant(torch::jit::IValue(2));
  auto dim_node = graph->create(torch::jit::aten::dim, {in_x}, 1);
  graph->appendNode(dim_node);
  dim_node->output()->setType(torch::jit::IntType::get());
  auto eq_node = graph->create(torch::jit::aten::eq, {const_two, dim_node->output()}, 1);
  graph->appendNode(eq_node);
  eq_node->output()->setType(torch::jit::BoolType::get());
  torch::jit::IValue exc_str("EXCEPTION");
  auto exc_val = graph->insertConstant(exc_str);

  // block0 computes the product; block1 raises
  auto cond_node = graph->create(torch::jit::prim::If, {eq_node->output()}, 1);
  auto then_block = cond_node->addBlock();
  auto matmul_node = graph->create(torch::jit::aten::matmul, {in_x, in_y}, 1);
  then_block->appendNode(matmul_node);
  then_block->registerOutput(matmul_node->output());

  auto else_block = cond_node->addBlock();
  auto raise_node = graph->create(torch::jit::prim::RaiseException, {exc_val, none_val}, 0);
  else_block->appendNode(raise_node);
  else_block->registerOutput(in_x);

  graph->insertNode(cond_node);
  graph->registerOutput(cond_node->output());

  // Run the pass under test, then canonicalize for stable comparison
  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(graph);
  graph = torch::jit::Canonicalize(graph, false);

  // Parse, pool, and canonicalize the expected graph the same way
  auto expected = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, expected.get());

  torch::jit::ConstantPooling(expected);
  expected = torch::jit::Canonicalize(expected, false);

  // The lowered graph must match the expected graph exactly
  ASSERT_TRUE((expected->toString() == graph->toString()));
}

// Verifies that EliminateExceptionsSafe handles exceptions raised inside a
// nested prim::If without invalidating the IR, collapsing the outer If.
TEST(LoweringPasses, EliminateExceptionsSafeNestedIfBlock) {
  /*std::string source_graph = R"IR(
    graph(%x, %y):
      %false : bool = prim::Constant[value=0]()
      %dim : int = aten::dim(%x)
      %48 : int = prim::Constant[value=2]()
      %66 : bool = aten::eq(%48, %dim)
      %45 : str = prim::Constant[value="EXCEPTION"]()
      %4 : Tensor = prim::If(%66)
        block0():
          %45 : str = prim::Constant[value="EXCEPTION"]()
          = prim::If(%false)
            block0():
              -> ()
            block1():
              = prim::RaiseException(%45)
              -> ()
          = prim::RaiseException(%45)
          -> (%x)
        block1():
          %res = aten::mul(%x, %y)
          -> (%res)
      return (%4))IR";*/

  std::string target_graph = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %6 : Tensor = aten::mul(%x, %y)
      return (%6))IR";

  // Build the source graph programmatically, since parseIR cannot express
  // unassigned operations such as prim::RaiseException
  auto graph = std::make_shared<torch::jit::Graph>();
  auto in_x = graph->insertInput(0, "x");
  auto in_y = graph->insertInput(1, "y");
  auto none_val = graph->insertConstant(torch::jit::IValue());
  auto const_false = graph->insertConstant(torch::jit::IValue(false));
  auto const_two = graph->insertConstant(torch::jit::IValue(2));
  auto dim_node = graph->create(torch::jit::aten::dim, {in_x}, 1);
  graph->appendNode(dim_node);
  dim_node->output()->setType(torch::jit::IntType::get());
  auto eq_node = graph->create(torch::jit::aten::eq, {const_two, dim_node->output()}, 1);
  graph->appendNode(eq_node);
  eq_node->output()->setType(torch::jit::BoolType::get());
  torch::jit::IValue exc_str("EXCEPTION");
  auto exc_val = graph->insertConstant(exc_str);

  // Outer If: block0 contains a nested If (which raises in its else branch)
  // followed by an unconditional raise; block1 computes the product
  auto outer_if = graph->create(torch::jit::prim::If, {eq_node->output()}, 1);
  auto outer_then = outer_if->addBlock();
  auto inner_if = graph->create(torch::jit::prim::If, {const_false}, 0);
  outer_then->appendNode(inner_if);
  /* auto inner_then = */ inner_if->addBlock();
  auto inner_else = inner_if->addBlock();
  auto inner_raise = graph->create(torch::jit::prim::RaiseException, {exc_val, none_val}, 0);
  inner_else->appendNode(inner_raise);
  auto outer_raise = graph->create(torch::jit::prim::RaiseException, {exc_val, none_val}, 0);
  outer_then->appendNode(outer_raise);
  outer_then->registerOutput(in_x);

  auto outer_else = outer_if->addBlock();
  auto mul_node = graph->create(torch::jit::aten::mul, {in_x, in_y}, 1);
  outer_else->appendNode(mul_node);
  outer_else->registerOutput(mul_node->output());

  graph->insertNode(outer_if);
  graph->registerOutput(outer_if->output());

  // Run the pass under test, then canonicalize for stable comparison
  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(graph);
  graph = torch::jit::Canonicalize(graph, false);

  // Parse, pool, and canonicalize the expected graph the same way
  auto expected = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, expected.get());

  torch::jit::ConstantPooling(expected);
  expected = torch::jit::Canonicalize(expected, false);

  // The lowered graph must match the expected graph exactly
  ASSERT_TRUE((expected->toString() == graph->toString()));
}