fix: Replace EliminateExceptions lowering pass #1859
@@ -1,6 +1,10 @@
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/passes/canonicalize.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/passes/remove_exceptions.h"

TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) {
  // parseIR does not support " = prim::If(%51)" with no return value

@@ -169,3 +173,217 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
  }
  EXPECT_EQ(1, if_count);
}

TEST(LoweringPasses, EliminateExceptionsSafeIfBlock) {

Review comment: For these test cases, what's the difference between the graph before and after that pass?

Reply: The difference is that the graph will be collapsed, since the exception will be removed from the graph. The effect on the graph is the same as going from the original graph to the new graph below.

Original Graph:

New Graph:
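For reference, a minimal sketch (not part of this PR; the helper name countExceptionNodes is hypothetical) of one way to observe that difference: the helper counts prim::RaiseException nodes recursively, and these tests expect the count to drop to zero after EliminateExceptionsSafe runs.

#include "torch/csrc/jit/ir/ir.h"

// Hypothetical helper (illustration only): recursively count
// prim::RaiseException nodes in a block and its nested blocks.
static int countExceptionNodes(torch::jit::Block* block) {
  int count = 0;
  for (auto* node : block->nodes()) {
    if (node->kind() == torch::jit::prim::RaiseException) {
      count++;
    }
    for (auto* nested : node->blocks()) {
      count += countExceptionNodes(nested);
    }
  }
  return count;
}

// Usage, assuming `g` is constructed as in the tests below:
//   EXPECT_GT(countExceptionNodes(g->block()), 0);  // before the pass
//   torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g);
//   EXPECT_EQ(countExceptionNodes(g->block()), 0);  // after the pass
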
  /*std::string source_graph = R"IR(
    graph(%x, %y):
      %dim : int = aten::dim(%x)
      %48 : int = prim::Constant[value=2]()
      %66 : bool = aten::eq(%48, %dim)
      %45 : str = prim::Constant[value="EXCEPTION"]()
      %4 : Tensor = prim::If(%66)
        block0():
          = prim::RaiseException(%45)
          -> (%x)
        block1():
          %res = aten::mul(%x, %y)
          -> (%res)
      return (%4))IR";*/

  std::string target_graph = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %6 : Tensor = aten::mul(%x, %y)
      return (%6))IR";

  // Construct graph via manual commands, to avoid IR parsing issues with
  // unassigned variables (such as prim::RaiseException)
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->insertInput(0, "x");
  auto y = g->insertInput(1, "y");
  auto none_const_val = g->insertConstant(torch::jit::IValue());
  auto two_const_val = g->insertConstant(torch::jit::IValue(2));
  auto x_dims = g->create(torch::jit::aten::dim, {x}, 1);
  g->appendNode(x_dims);
  x_dims->output()->setType(torch::jit::IntType::get());
  auto eq = g->create(torch::jit::aten::eq, {two_const_val, x_dims->output()}, 1);
  g->appendNode(eq);
  eq->output()->setType(torch::jit::BoolType::get());
  torch::jit::IValue except("EXCEPTION");
  auto except_val = g->insertConstant(except);

  auto if_node = g->create(torch::jit::prim::If, {eq->output()}, 1);
  auto if_block0 = if_node->addBlock();
  auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
  if_block0->appendNode(exception_node);
  if_block0->registerOutput(x);

  auto if_block1 = if_node->addBlock();
  auto sum_node = g->create(torch::jit::aten::mul, {x, y}, 1);
  if_block1->appendNode(sum_node);
  if_block1->registerOutput(sum_node->output());

  g->insertNode(if_node);
  g->registerOutput(if_node->output());

  // Apply lowering pass and canonicalization to the graph
  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g);
  g = torch::jit::Canonicalize(g, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == g->toString()));
}

TEST(LoweringPasses, EliminateExceptionsSafeElseBlock) {
  /*std::string source_graph = R"IR(
    graph(%x, %y):
      %dim : int = aten::dim(%x)
      %48 : int = prim::Constant[value=2]()
      %66 : bool = aten::eq(%48, %dim)
      %45 : str = prim::Constant[value="EXCEPTION"]()
      %4 : Tensor = prim::If(%66)
        block0():
          %res = aten::matmul(%x, %y)
          -> (%res)
        block1():
          = prim::RaiseException(%45)
          -> (%x)
      return (%4))IR";*/

  std::string target_graph = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %6 : Tensor = aten::matmul(%x, %y)
      return (%6))IR";

  // Construct graph via manual commands, to avoid IR parsing issues with
  // unassigned variables (such as prim::RaiseException)
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->insertInput(0, "x");
  auto y = g->insertInput(1, "y");
  auto none_const_val = g->insertConstant(torch::jit::IValue());
  auto two_const_val = g->insertConstant(torch::jit::IValue(2));
  auto x_dims = g->create(torch::jit::aten::dim, {x}, 1);
  g->appendNode(x_dims);
  x_dims->output()->setType(torch::jit::IntType::get());
  auto eq = g->create(torch::jit::aten::eq, {two_const_val, x_dims->output()}, 1);
  g->appendNode(eq);
  eq->output()->setType(torch::jit::BoolType::get());
  torch::jit::IValue except("EXCEPTION");
  auto except_val = g->insertConstant(except);

  auto if_node = g->create(torch::jit::prim::If, {eq->output()}, 1);
  auto if_block0 = if_node->addBlock();
  auto sum_node = g->create(torch::jit::aten::matmul, {x, y}, 1);
  if_block0->appendNode(sum_node);
  if_block0->registerOutput(sum_node->output());

  auto if_block1 = if_node->addBlock();
  auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
  if_block1->appendNode(exception_node);
  if_block1->registerOutput(x);

  g->insertNode(if_node);
  g->registerOutput(if_node->output());

  // Apply lowering pass and canonicalization to the graph
  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g);
  g = torch::jit::Canonicalize(g, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == g->toString()));
}

TEST(LoweringPasses, EliminateExceptionsSafeNestedIfBlock) {
  /*std::string source_graph = R"IR(
    graph(%x, %y):
      %false : bool = prim::Constant[value=0]()
      %dim : int = aten::dim(%x)
      %48 : int = prim::Constant[value=2]()
      %66 : bool = aten::eq(%48, %dim)
      %45 : str = prim::Constant[value="EXCEPTION"]()
      %4 : Tensor = prim::If(%66)
        block0():
          %45 : str = prim::Constant[value="EXCEPTION"]()
          = prim::If(%false)
            block0():
              -> ()
            block1():
              = prim::RaiseException(%45)
              -> ()
          = prim::RaiseException(%45)
          -> (%x)
        block1():
          %res = aten::mul(%x, %y)
          -> (%res)
      return (%4))IR";*/

  std::string target_graph = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %6 : Tensor = aten::mul(%x, %y)
      return (%6))IR";

  // Construct graph via manual commands, to avoid IR parsing issues with
  // unassigned variables (such as prim::RaiseException)
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->insertInput(0, "x");
  auto y = g->insertInput(1, "y");
  auto none_const_val = g->insertConstant(torch::jit::IValue());
  auto false_const_val = g->insertConstant(torch::jit::IValue(false));
  auto two_const_val = g->insertConstant(torch::jit::IValue(2));
  auto x_dims = g->create(torch::jit::aten::dim, {x}, 1);
  g->appendNode(x_dims);
  x_dims->output()->setType(torch::jit::IntType::get());
  auto eq = g->create(torch::jit::aten::eq, {two_const_val, x_dims->output()}, 1);
  g->appendNode(eq);
  eq->output()->setType(torch::jit::BoolType::get());
  torch::jit::IValue except("EXCEPTION");
  auto except_val = g->insertConstant(except);

  // Construct nested-If substructure in graph
  auto if_node = g->create(torch::jit::prim::If, {eq->output()}, 1);
  auto if_block0 = if_node->addBlock();
  auto if_if_node = g->create(torch::jit::prim::If, {false_const_val}, 0);
  if_block0->appendNode(if_if_node);
  /* auto if_if_block0 = */ if_if_node->addBlock();
  auto if_if_block1 = if_if_node->addBlock();
  auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
  if_if_block1->appendNode(exception_node);
  auto exception_node_2 = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
  if_block0->appendNode(exception_node_2);
  if_block0->registerOutput(x);

  auto if_block1 = if_node->addBlock();
  auto sum_node = g->create(torch::jit::aten::mul, {x, y}, 1);
  if_block1->appendNode(sum_node);
  if_block1->registerOutput(sum_node->output());

  g->insertNode(if_node);
  g->registerOutput(if_node->output());

  // Apply lowering pass and canonicalization to the graph
  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g);
  g = torch::jit::Canonicalize(g, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == g->toString()));
}
Review comment: The new pass directly modifies inputs to the prim::If node only, to avoid incorrect boolean modifications elsewhere, and it logs a warning informing the user that an exception was automatically removed from the TorchScript IR.
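A minimal sketch of that behavior, under stated assumptions: the function name eliminateExceptionsSafeSketch and the helper blockOnlyRaises are hypothetical, TORCH_WARN stands in for the project's own logging, only top-level prim::If nodes are handled (the real pass must also cover nested blocks, as exercised by the nested-If test above), and later constant folding / dead-code elimination is relied on to collapse the rewired If.

#include <memory>

#include "c10/util/Exception.h"
#include "torch/csrc/jit/ir/ir.h"

// Hypothetical helper: true if the block's body is a single prim::RaiseException.
static bool blockOnlyRaises(torch::jit::Block* block) {
  auto nodes = block->nodes();
  auto it = nodes.begin();
  if (it == nodes.end() || it->kind() != torch::jit::prim::RaiseException) {
    return false;
  }
  return ++it == nodes.end();
}

// Sketch only: rewires the condition of prim::If nodes whose one branch does
// nothing but raise, so the exception branch becomes dead code.
void eliminateExceptionsSafeSketch(std::shared_ptr<torch::jit::Graph>& graph) {
  for (auto* node : graph->block()->nodes()) {
    if (node->kind() != torch::jit::prim::If || node->blocks().size() != 2) {
      continue;
    }
    bool true_branch_raises = blockOnlyRaises(node->blocks()[0]);
    bool false_branch_raises = blockOnlyRaises(node->blocks()[1]);
    if (true_branch_raises == false_branch_raises) {
      continue; // Either no exception branch, or both branches raise.
    }
    // Replace only this prim::If node's condition input with a constant that
    // selects the non-raising branch; other uses of the boolean are untouched.
    torch::jit::WithInsertPoint guard(node);
    auto* new_condition = graph->insertConstant(!true_branch_raises);
    node->replaceInput(0, new_condition);
    // Stand-in for the project's logging facilities.
    TORCH_WARN("Detected and removed an exception from the TorchScript IR");
  }
}

Rewiring only the prim::If condition, rather than mutating the boolean value itself, is what keeps any other consumers of that boolean correct.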