Error occurs when Lowering Torch Backend IR -> StableHLO Backend IR #3567

Open
jiang711 opened this issue Jul 29, 2024 · 0 comments

jiang711 commented Jul 29, 2024

When I try to generate the forward and backward graphs of a simple model in StableHLO IR, the forward graph is converted correctly, but converting the backward graph fails. The forward graph lowers to:

module {
  func.func @M__0_forward_1(%arg0: tensor<10x32xf32>, %arg1: tensor<10xf32>, %arg2: tensor<8x8xf32>) -> (tensor<2x10xf32>, tensor<2x32xf32>) {
    %cst = arith.constant dense<1> : tensor<1xi64>
    %0 = stablehlo.reshape %arg2 : (tensor<8x8xf32>) -> tensor<2x32xf32>
    %1 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x32xf32>) -> tensor<32x10xf32>
    %2 = stablehlo.dot_general %0, %1, contracting_dims = [1] x [0] : (tensor<2x32xf32>, tensor<32x10xf32>) -> tensor<2x10xf32>
    %3 = stablehlo.convert %cst : (tensor<1xi64>) -> tensor<1xf32>
    %4 = stablehlo.reshape %3 : (tensor<1xf32>) -> tensor<f32>
    %5 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
    %6 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<10xf32>
    %7 = stablehlo.multiply %5, %6 : tensor<10xf32>
    %8 = stablehlo.broadcast_in_dim %7, dims = [1] : (tensor<10xf32>) -> tensor<2x10xf32>
    %9 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<2x10xf32>) -> tensor<2x10xf32>
    %10 = stablehlo.add %8, %9 : tensor<2x10xf32>
    return %10, %0 : tensor<2x10xf32>, tensor<2x32xf32>
  }
}

The backward conversion then fails with:

error: failed to legalize operation 'torch.constant.none'
error: Module does not conform to the Stablehlo backend contract. See dialect conversion legality information above.
Traceback (most recent call last):
  File "/home/tomjiang/test/test-mlp/issue.py", line 54, in <module>
    loss.backward()
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch/autograd/__init__.py", line 346, in backward
    _engine_run_backward(
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch/autograd/graph.py", line 806, in _engine_run_backward
    return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch/autograd/function.py", line 306, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1976, in backward
    out = call_compiled_backward()
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1908, in call_compiled_backward
    CompiledFunction.compiled_bw = aot_config.bw_compiler(
                                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 47, in _wrapped_bw_compiler
    return disable(disable(bw_compiler)(*args, **kwargs))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 602, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 103, in f
    out_f = compiler(fx_g, inps)
            ^^^^^^^^^^^^^^^^^^^^
  File "/home/tomjiang/test/test-mlp/issue.py", line 32, in fx_import_aot_autograd_backend
    m = fx.stateless_fx_import(gm, model_name=get_aot_graph_name(), output_type=OutputType.STABLEHLO)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch_mlir/fx.py", line 116, in stateless_fx_import
    return _module_lowering(
           ^^^^^^^^^^^^^^^^^
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch_mlir/fx.py", line 47, in _module_lowering
    return lower_mlir_module(verbose, output_type, torch_mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch_mlir/compiler_utils.py", line 167, in lower_mlir_module
    run_pipeline_with_repro_report(
  File "/home/tomjiang/miniconda3/envs/torch-mlir/lib/python3.11/site-packages/torch_mlir/compiler_utils.py", line 78, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> StableHLO Backend IR failed with the following diagnostics:

python exception: Failure while executing pass pipeline

For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-stablehlo-backend-pipeline)' /tmp/UnnammedModule.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
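
For context, torch.constant.none is the Torch dialect's materialization of Python's None, and the error means one survives into the backend IR of the backward graph, where the StableHLO conversion has no pattern for it. One way to locate it (a debugging sketch of my own, not part of the failing run, and assuming stateless_fx_import accepts OutputType.TORCH the same way it accepts STABLEHLO) is to stop the import at Torch backend IR and look for the op in the dump:

# Debugging sketch: import the same `gm` only to Torch backend IR,
# skipping the StableHLO lowering that fails, so the surviving
# 'torch.constant.none' can be located by eye in the printed module.
m_torch = fx.stateless_fx_import(
    gm,
    model_name=get_aot_graph_name(),
    output_type=OutputType.TORCH,  # stop before torch-backend-to-stablehlo-backend-pipeline
)
print(m_torch, flush=True)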

The code is:

from typing import List

import torch
import torch.nn as nn
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import (
    make_boxed_compiler,
    get_aot_graph_name,
    set_model_name,
)

from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType
import torch_mlir

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(32, 10)
    def forward(self, x):
        x = x.view(-1, 32)
        return self.l1(x)

@make_boxed_compiler
def fx_import_aot_autograd_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # print(gm.print_readable(False), flush=True)
    # print(gm.code)
    m = fx.stateless_fx_import(gm, model_name=get_aot_graph_name(), output_type=OutputType.STABLEHLO)
    print(m, flush=True)
    return gm

if __name__ == '__main__':

    model = M()

    criterion = nn.CrossEntropyLoss()
    lr = 0.001
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    fx_import_backend = aot_autograd(fw_compiler=fx_import_aot_autograd_backend)
    set_model_name("M")
    model_opt = torch.compile(model, backend=fx_import_backend)

    out = model_opt(torch.randn(8, 8))
    labels = torch.randn(2, 10)

    loss = criterion(out, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
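
For triage, the backward graph can also be routed through its own compiler so it can be printed before any MLIR import. Below is a minimal sketch (inspect_backward is a hypothetical name of mine; it assumes aot_autograd forwards bw_compiler the same way it forwards fw_compiler):

@make_boxed_compiler
def inspect_backward(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # Hypothetical helper: dump the backward FX graph so the node that
    # becomes 'torch.constant.none' can be spotted before the import runs.
    print(gm.print_readable(False), flush=True)
    return gm

fx_import_backend = aot_autograd(
    fw_compiler=fx_import_aot_autograd_backend,
    bw_compiler=inspect_backward,  # only the backward graph goes through this hook
)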