Skip to content

Commit

Permalink
[MLIR][TORCH] Add torch-onnx-to-torch-backend pipeline (#3801)
Browse files Browse the repository at this point in the history
This commit adds the torch-onnx-to-torch-backend pipeline which
converts the Torch Onnx IR to Torch Backend IR.

This commit also moves the `ScalarizeShapes` pass from the
`torch-backend-to-linalg-on-tensors-backend-pipeline` to the
`torch-onnx-to-torch-backend` pipeline since the primary goal of
this pass is to scalarize the shapes in the IR coming from the
Onnx models.
  • Loading branch information
vivekkhandelwal1 authored Oct 21, 2024
1 parent d2330df commit fa4794d
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 18 deletions.
5 changes: 5 additions & 0 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ void createTorchDynamoExportToTorchBackendPipeline(
void createTorchFunctionToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);

/// Creates a pipeline that lowers the torch Onnx IR that is produced by
/// Onnx import into the form expected by torch-verify-backend-contract.
void createTorchOnnxToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);

/// Creates a pipeline that simplifies the computations in the program.
/// This pass does not do any global program restructuring -- it works entirely
/// within a single semantic model of a `builtin.module` with
Expand Down
36 changes: 36 additions & 0 deletions lib/Dialect/Torch/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"

void mlir::torch::registerTorchPasses() {
mlir::torch::registerPasses();
Expand All @@ -25,6 +26,10 @@ void mlir::torch::registerTorchPasses() {
"torch-function-to-torch-backend-pipeline",
"Pipeline lowering a Torch function to Torch backend form.",
mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-onnx-to-torch-backend-pipeline",
"Pipeline lowering Torch Onnx IR to Torch backend form.",
mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-simplification-pipeline",
"Pipeline simplifying computations in the program.",
Expand Down Expand Up @@ -86,6 +91,37 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
options.backendLegalOps, options.extraLibrary));
}

void mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(onnx_c::createTorchOnnxToTorchPass());
// The above pass just converts the torch onnx IR to torch, hence the given
// pipeline will make sure that the IR is transformed such that it satisfies
// the backend contract.
if (options.decompose) {
pm.addNestedPass<func::FuncOp>(
Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}
// TODO: Move the combination of two passes i.e., ScalarizeShapes and
// TorchShapeRefinementPipeline out of here and create an onnx shape
// refinement pipeline which runs iteratively over the IR.
createTorchShapeRefinementPipeline(pm, options);
// This pass scalarizes the tensor shape computations.
pm.addNestedPass<mlir::func::FuncOp>(
mlir::torch::Torch::createScalarizeShapesPass());
createTorchShapeRefinementPipeline(pm, options);
pm.addPass(Torch::createRefinePublicReturnPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The decompose pass is run again here since the scalarize shapes pass and
// shape refinement pipeline might create some ops for which decomposition
// exists.
if (options.decompose) {
pm.addNestedPass<func::FuncOp>(
Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}
}

// A simplification pipeline to establish the invariants of the backend
// contract (see `satisfiedBackendContract` in `LowerToBackendContract`).
//
Expand Down
1 change: 0 additions & 1 deletion lib/Dialect/TorchConversion/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(

// We want to fuse quantized operations together before lowering to linalg.
pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
pm.addNestedPass<func::FuncOp>(Torch::createScalarizeShapesPass());

// Lower to linalg + guards which is the input to codegen backends.
// We do this first as it tends to involve pattern-matching against constants,
Expand Down
26 changes: 9 additions & 17 deletions projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,33 +100,25 @@ def _module_lowering(
print("ONNX RAW IR")
print(torch_mod)

# Lower from ONNX to Torch
run_pipeline_with_repro_report(
torch_mod,
# The importer may produce additional MLIR functions corresponding to
# ONNX operators that are functions. In some cases they need to be
# inlined to avoid the backend choking on them.
f"builtin.module(inline, func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
"Lowering Onnx backend contract to Linalg-on-Tensors backend contract",
)

if verbose:
print("\n====================")
print("TorchFX IR")
print(torch_mod)

backend_legal_ops = [
"aten.flatten.using_ints",
"aten.adaptive_avg_pool1d",
"aten.unflatten.int",
]
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"

# Lower from ONNX to Torch
run_pipeline_with_repro_report(
torch_mod,
f"builtin.module(torch-lower-to-backend-contract{option_string})",
"Lowering TorchFX IR -> Torch Backend IR",
f"builtin.module(torch-onnx-to-torch-backend-pipeline{option_string})",
"Lowering Onnx Raw IR -> Torch Backend IR",
)

if verbose:
print("\n====================")
print("Torch IR")
print(torch_mod)

return lower_mlir_module(verbose, output_type, torch_mod)


Expand Down
67 changes: 67 additions & 0 deletions test/Dialect/Torch/torch-onnx-to-torch-backend-pipeline.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-onnx-to-torch-backend-pipeline{backend-legal-ops=aten.flatten.using_ints,aten.unflatten.int})' -split-input-file %s | FileCheck %s

// CHECK-LABEL: func.func @test_reshape_negative_dim_decompose
func.func @test_reshape_negative_dim_decompose(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[INT6:.+]] = torch.constant.int 6
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT6]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.view %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,6,2],f32>
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32>
return %0 : !torch.vtensor<[2,6,2],f32>
}

// -----

// CHECK-LABEL: func.func @test_triu_decompose
func.func @test_triu_decompose(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[ZERO_TENSOR:.+]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[INT4:.+]] = torch.constant.int 4
// CHECK: %[[INT5:.+]] = torch.constant.int 5
// CHECK: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[INT0]], %[[INT4]], %[[INT1]], %[[INT4]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4],si64>
// CHECK: %[[ARANGE_0:.+]] = torch.aten.arange.start_step %[[INT0]], %[[INT5]], %[[INT1]], %[[INT4]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[5],si64>
// CHECK: %[[UNSQUEEZE:.+]] = torch.aten.unsqueeze %[[ARANGE]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
// CHECK: %[[UNSQUEEZE_0:.+]] = torch.aten.unsqueeze %[[ARANGE_0]], %[[INT0]] : !torch.vtensor<[5],si64>, !torch.int -> !torch.vtensor<[1,5],si64>
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[UNSQUEEZE]], %[[INT0]], %[[INT1]] : !torch.vtensor<[4,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1],si64>
// CHECK: %[[COND:.+]] = torch.aten.ge.Tensor %[[UNSQUEEZE_0]], %[[ADD]] : !torch.vtensor<[1,5],si64>, !torch.vtensor<[4,1],si64> -> !torch.vtensor<[4,5],i1>
// CHECK: %[[RESULT:.+]] = torch.aten.where.self %[[COND]], %arg0, %[[ZERO_TENSOR]] : !torch.vtensor<[4,5],i1>, !torch.vtensor<[4,5],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[4,5],si64>
%0 = torch.operator "onnx.Trilu"(%arg0) : (!torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64>
return %0 : !torch.vtensor<[4,5],si64>
}

// -----

module {
// CHECK-LABEL: func.func @test_scalarize
func.func @test_scalarize(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.11.0"} {
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: %[[ADD:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT2]], %[[INT3]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
%0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64>
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__21> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%2 = torch.operator "onnx.Gather"(%0, %1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
%3 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64>
%4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__22> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%5 = torch.operator "onnx.Gather"(%3, %4) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
%6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%7 = torch.operator "onnx.Unsqueeze"(%2, %6) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64>
%8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%9 = torch.operator "onnx.Unsqueeze"(%5, %8) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64>
%10 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Concat_3209> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%11 = torch.operator "onnx.Concat"(%7, %9, %10) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3],si64>
%12 = torch.operator "onnx.Reshape"(%arg0, %11) : (!torch.vtensor<[?,?,16,64],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32>
return %12 : !torch.vtensor<[?,?,?],f32>
}
}

{-#
dialect_resources: {
builtin: {
__21: "0x080000000000000000000000",
__22: "0x080000000100000000000000",
_onnx__Concat_3209: "0x080000000004000000000000"
}
}
#-}

0 comments on commit fa4794d

Please sign in to comment.