diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index cb1fd97327..0c58c79490 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -9,7 +9,6 @@ #include "torch/csrc/jit/passes/lower_graph.h" #include "torch/csrc/jit/passes/lower_tuples.h" #include "torch/csrc/jit/passes/peephole.h" -#include "torch/csrc/jit/passes/remove_exceptions.h" #include "torch/csrc/jit/passes/remove_mutation.h" #include "core/lowering/lowering.h" @@ -105,7 +104,7 @@ void LowerGraph(std::shared_ptr& g, std::vector graph) { } } +/* + Below is a fork of the torch::jit::EliminateExceptions pass, with node replacement + using replaceAllUsesDominatedByNodeWith instead of replaceAllUsesWith, + so as to not invalidate the IR in challenging cases, such as nested Ifs + + Original Source from which it was adapted: + https://github.com/pytorch/pytorch/blob/c29ab84115f40614d04e4557ea2e1ac40b7aa75c/torch/csrc/jit/passes/remove_exceptions.cpp +*/ + +bool certainlyThrows(Block* block) { + // A block certainly throws an exception if it contains + // the prim::RaiseException operation + for (Node* n : block->nodes()) { + if (n->kind() == prim::RaiseException) { + return true; + } + } + return false; +} + +void EliminateExceptionsSafe(Block* block) { + auto graph = block->owningGraph(); + // Generate false and true constant placeholders + Value* false_const = graph->insertConstant(IValue(false)); + Value* true_const = graph->insertConstant(IValue(true)); + + // For each prim::If node, if either block certainly throws an exception, + // replace input conditional of the node input with the logical opposite + for (Node* n : block->nodes()) { + if (n->kind() == prim::If) { + Block* true_block = n->blocks()[0]; + Block* false_block = n->blocks()[1]; + bool removed_exception = false; + Value* input_value_replacement; + + // If the block throws an exception, replace input with logical opposite + if (certainlyThrows(true_block)) { + removed_exception = true; + input_value_replacement = false_const; + } else if (certainlyThrows(false_block)) { + removed_exception = true; + input_value_replacement = true_const; + } + + // Log node and perform input replacement + if (removed_exception) { + LOG_WARNING("Detected and removing exception in TorchScript IR for node: " << util::node_info(n)); + n->insertInput(0, input_value_replacement); + n->removeInput(1); + } + } + + // Inspect and replace all instances within subblocks of the current node + for (Block* subblock : n->blocks()) { + EliminateExceptionsSafe(subblock); + } + } +} + +void EliminateExceptionsSafe(std::shared_ptr& graph) { + EliminateExceptionsSafe(graph->block()); + ConstantPropagation(graph); + ConstantPooling(graph); +} + } // namespace passes } // namespace lowering } // namespace core diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h index 77ff842198..1790537eb7 100644 --- a/core/lowering/passes/passes.h +++ b/core/lowering/passes/passes.h @@ -20,6 +20,7 @@ void Conv3DToConvolution(std::shared_ptr& graph); void ConvTransposed3DToConvolution(std::shared_ptr& graph); void FuseAddMMBranches(std::shared_ptr graph); void LinearToAddMM(std::shared_ptr& graph); +void EliminateExceptionsSafe(std::shared_ptr& graph); void EliminateExceptionOrPassPattern(std::shared_ptr graph); void ReduceToOperation(std::shared_ptr& graph); void ReduceGelu(std::shared_ptr& graph); diff --git a/tests/core/lowering/test_exception_elimination_pass.cpp b/tests/core/lowering/test_exception_elimination_pass.cpp index e35abbccef..aa1b8c53ee 100644 --- a/tests/core/lowering/test_exception_elimination_pass.cpp +++ b/tests/core/lowering/test_exception_elimination_pass.cpp @@ -1,6 +1,10 @@ #include "core/lowering/passes/passes.h" #include "gtest/gtest.h" +#include "tests/util/util.h" #include "torch/csrc/jit/ir/irparser.h" +#include "torch/csrc/jit/passes/canonicalize.h" +#include "torch/csrc/jit/passes/constant_pooling.h" +#include "torch/csrc/jit/passes/remove_exceptions.h" TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) { // parseIR does not support " = prim::If(%51)" with no return value @@ -169,3 +173,217 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) { } EXPECT_EQ(1, if_count); } + +TEST(LoweringPasses, EliminateExceptionsSafeIfBlock) { + /*std::string source_graph = R"IR( + graph(%x, %y): + %dim : int = aten::dim(%x) + %48 : int = prim::Constant[value=2]() + %66 : bool = aten::eq(%48, %dim) + %45 : str = prim::Constant[value="EXCEPTION"]() + %4 : Tensor = prim::If(%66) + block0(): + = prim::RaiseException(%45) + -> (%x) + block1(): + %res = aten::mul(%x, %y) + -> (%res) + return (%4))IR";*/ + + std::string target_graph = R"IR( + graph(%x : Tensor, + %y : Tensor): + %6 : Tensor = aten::mul(%x, %y) + return (%6))IR"; + + // Construct graph via manual commands, to avoid IR parsing issues with + // unassigned variables (such as prim::RaiseException) + auto g = std::make_shared(); + auto x = g->insertInput(0, "x"); + auto y = g->insertInput(1, "y"); + auto none_const_val = g->insertConstant(torch::jit::IValue()); + auto two_const_val = g->insertConstant(torch::jit::IValue(2)); + auto x_dims = g->create(torch::jit::aten::dim, {x}, 1); + g->appendNode(x_dims); + x_dims->output()->setType(torch::jit::IntType::get()); + auto eq = g->create(torch::jit::aten::eq, {two_const_val, x_dims->output()}, 1); + g->appendNode(eq); + eq->output()->setType(torch::jit::BoolType::get()); + torch::jit::IValue except("EXCEPTION"); + auto except_val = g->insertConstant(except); + + auto if_node = g->create(torch::jit::prim::If, {eq->output()}, 1); + auto if_block0 = if_node->addBlock(); + auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0); + if_block0->appendNode(exception_node); + if_block0->registerOutput(x); + + auto if_block1 = if_node->addBlock(); + auto sum_node = g->create(torch::jit::aten::mul, {x, y}, 1); + if_block1->appendNode(sum_node); + if_block1->registerOutput(sum_node->output()); + + g->insertNode(if_node); + g->registerOutput(if_node->output()); + + // Apply lowering pass and canonicalization to the graph + torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g); + g = torch::jit::Canonicalize(g, false); + + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph, tg.get()); + + torch::jit::ConstantPooling(tg); + tg = torch::jit::Canonicalize(tg, false); + + // Validate identical graphs after pooling constants and canonicalizing + ASSERT_TRUE((tg->toString() == g->toString())); +} + +TEST(LoweringPasses, EliminateExceptionsSafeElseBlock) { + /*std::string source_graph = R"IR( + graph(%x, %y): + %dim : int = aten::dim(%x) + %48 : int = prim::Constant[value=2]() + %66 : bool = aten::eq(%48, %dim) + %45 : str = prim::Constant[value="EXCEPTION"]() + %4 : Tensor = prim::If(%66) + block0(): + %res = aten::matmul(%x, %y) + -> (%res) + block1(): + = prim::RaiseException(%45) + -> (%x) + return (%4))IR";*/ + + std::string target_graph = R"IR( + graph(%x : Tensor, + %y : Tensor): + %6 : Tensor = aten::matmul(%x, %y) + return (%6))IR"; + + // Construct graph via manual commands, to avoid IR parsing issues with + // unassigned variables (such as prim::RaiseException) + auto g = std::make_shared(); + auto x = g->insertInput(0, "x"); + auto y = g->insertInput(1, "y"); + auto none_const_val = g->insertConstant(torch::jit::IValue()); + auto two_const_val = g->insertConstant(torch::jit::IValue(2)); + auto x_dims = g->create(torch::jit::aten::dim, {x}, 1); + g->appendNode(x_dims); + x_dims->output()->setType(torch::jit::IntType::get()); + auto eq = g->create(torch::jit::aten::eq, {two_const_val, x_dims->output()}, 1); + g->appendNode(eq); + eq->output()->setType(torch::jit::BoolType::get()); + torch::jit::IValue except("EXCEPTION"); + auto except_val = g->insertConstant(except); + + auto if_node = g->create(torch::jit::prim::If, {eq->output()}, 1); + auto if_block0 = if_node->addBlock(); + auto sum_node = g->create(torch::jit::aten::matmul, {x, y}, 1); + if_block0->appendNode(sum_node); + if_block0->registerOutput(sum_node->output()); + + auto if_block1 = if_node->addBlock(); + auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0); + if_block1->appendNode(exception_node); + if_block1->registerOutput(x); + + g->insertNode(if_node); + g->registerOutput(if_node->output()); + + // Apply lowering pass and canonicalization to the graph + torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g); + g = torch::jit::Canonicalize(g, false); + + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph, tg.get()); + + torch::jit::ConstantPooling(tg); + tg = torch::jit::Canonicalize(tg, false); + + // Validate identical graphs after pooling constants and canonicalizing + ASSERT_TRUE((tg->toString() == g->toString())); +} + +TEST(LoweringPasses, EliminateExceptionsSafeNestedIfBlock) { + /*std::string source_graph = R"IR( + graph(%x, %y): + %false : bool = prim::Constant[value=0]() + %dim : int = aten::dim(%x) + %48 : int = prim::Constant[value=2]() + %66 : bool = aten::eq(%48, %dim) + %45 : str = prim::Constant[value="EXCEPTION"]() + %4 : Tensor = prim::If(%66) + block0(): + %45 : str = prim::Constant[value="EXCEPTION"]() + = prim::If(%false) + block0(): + -> () + block1(): + = prim::RaiseException(%45) + -> () + = prim::RaiseException(%45) + -> (%x) + block1(): + %res = aten::mul(%x, %y) + -> (%res) + return (%4))IR";*/ + + std::string target_graph = R"IR( + graph(%x : Tensor, + %y : Tensor): + %6 : Tensor = aten::mul(%x, %y) + return (%6))IR"; + + // Construct graph via manual commands, to avoid IR parsing issues with + // unassigned variables (such as prim::RaiseException) + auto g = std::make_shared(); + auto x = g->insertInput(0, "x"); + auto y = g->insertInput(1, "y"); + auto none_const_val = g->insertConstant(torch::jit::IValue()); + auto false_const_val = g->insertConstant(torch::jit::IValue(false)); + auto two_const_val = g->insertConstant(torch::jit::IValue(2)); + auto x_dims = g->create(torch::jit::aten::dim, {x}, 1); + g->appendNode(x_dims); + x_dims->output()->setType(torch::jit::IntType::get()); + auto eq = g->create(torch::jit::aten::eq, {two_const_val, x_dims->output()}, 1); + g->appendNode(eq); + eq->output()->setType(torch::jit::BoolType::get()); + torch::jit::IValue except("EXCEPTION"); + auto except_val = g->insertConstant(except); + + // Construct nested-If substructure in graph + auto if_node = g->create(torch::jit::prim::If, {eq->output()}, 1); + auto if_block0 = if_node->addBlock(); + auto if_if_node = g->create(torch::jit::prim::If, {false_const_val}, 0); + if_block0->appendNode(if_if_node); + /* auto if_if_block0 = */ if_if_node->addBlock(); + auto if_if_block1 = if_if_node->addBlock(); + auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0); + if_if_block1->appendNode(exception_node); + auto exception_node_2 = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0); + if_block0->appendNode(exception_node_2); + if_block0->registerOutput(x); + + auto if_block1 = if_node->addBlock(); + auto sum_node = g->create(torch::jit::aten::mul, {x, y}, 1); + if_block1->appendNode(sum_node); + if_block1->registerOutput(sum_node->output()); + + g->insertNode(if_node); + g->registerOutput(if_node->output()); + + // Apply lowering pass and canonicalization to the graph + torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g); + g = torch::jit::Canonicalize(g, false); + + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph, tg.get()); + + torch::jit::ConstantPooling(tg); + tg = torch::jit::Canonicalize(tg, false); + + // Validate identical graphs after pooling constants and canonicalizing + ASSERT_TRUE((tg->toString() == g->toString())); +}