Strange behavior while lowering nn.BatchNorm2d #110

Closed
JBloodless opened this issue Oct 26, 2023 · 15 comments
@JBloodless

Hello. I'm using the most basic BatchNorm in my model, and for some reason it is being lowered to a strange variant that is not supported.

Here's a minimal repro:

import shark_turbine.aot as aot
import torch
import torch.nn as nn

class AdaptCNN(nn.Module):
    def __init__(self,
                 input_channels=1,
                 c_out_1=16,
                 kernel_size=(3, 3),
                 ):
        super().__init__()

        self.input_channels = input_channels
        self.c_out_1 = c_out_1
        self.kernel_size = kernel_size

        self.cnn_pad = (1, 1)

        self.conv1 = nn.Conv2d(
            self.input_channels,
            self.c_out_1,
            self.kernel_size,
            padding=self.cnn_pad)

        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)

    def forward(self, x):
        x = self.bn1(self.conv1(x))
        return x


x = torch.zeros((63, 1, 48, 15))

model = AdaptCNN()
model.eval()

export_output = aot.export(model, x)
binary = export_output.compile(save_to=None)

And here's what I get:

Traceback (most recent call last):
  File "/Users/i.beskrovnyy/tts/NISQA-s/repro_bn.py", line 38, in <module>
    export_output = aot.export(model, x)
                    ^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/exporter.py", line 198, in export
    cm = Exported(context=context, import_to="import")
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/compiled_module.py", line 534, in __new__
    do_export(proc_def)
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/compiled_module.py", line 531, in do_export
    trace.trace_py_func(invoke_with_self)
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/support/procedural/tracer.py", line 120, in trace_py_func
    return_py_value = _unproxy(py_f(*self.proxy_posargs, **self.proxy_kwargs))
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/compiled_module.py", line 512, in invoke_with_self
    return proc_def.callable(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/exporter.py", line 182, in main
    return jittable(mdl.forward)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/support/procedural/base.py", line 137, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/support/procedural/tracer.py", line 136, in handle_call
    return target.resolve_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/builtins/jittable.py", line 239, in resolve_call
    fx_importer.import_stateless_graph(gm.graph, func_name=self.function_name)
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/dynamo/importer.py", line 265, in import_stateless_graph
    node_importer.import_nodes(g.nodes)
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/dynamo/importer.py", line 496, in import_nodes
    self._import_torch_op_overload(loc, node, target)
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/dynamo/importer.py", line 633, in _import_torch_op_overload
    raise NotImplementedError(
NotImplementedError: Unimplemented torch op in the IREE compiler: 'torch.aten._native_batch_norm_legit_no_training' (either implement this op/variant or configure the compiler to allow unknown operations and fallback to PyTorch).

Why does BatchNorm fall back to this _native_batch_norm_legit_no_training, and what am I doing wrong here?
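For reference, the op comes from PyTorch's own export-time decomposition rather than from the model code. A minimal sketch to see it in the traced graph, assuming PyTorch >= 2.1 with torch.export available (the exact op set can vary across torch versions):

# Illustrative side check, not part of the original repro: inspect the traced graph
# for the same module. AdaptCNN is the class defined above; in eval mode batch norm
# is decomposed to aten._native_batch_norm_legit_no_training, the op the importer
# reports as unimplemented.
import torch

ep = torch.export.export(AdaptCNN().eval(), (torch.zeros((63, 1, 48, 15)),))
print(ep.graph_module.graph)  # look for _native_batch_norm_legit_no_training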

@dan-garvey
Member

Thanks for filing this! You aren't doing anything wrong. We need to update the PyPI version of turbine. If you run pip uninstall shark-turbine, then navigate to a clone of shark-turbine and run pip install -e ., your code should work without issue. This issue was solved by #102. We'll push some new pip packages soon. Let me know if this works for you and I can close the issue.
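For reference, those steps look roughly like this (the clone URL is inferred from this repo; adjust paths to your setup):

pip uninstall shark-turbine
git clone https://github.com/nod-ai/SHARK-Turbine
cd SHARK-Turbine
pip install -e .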

@JBloodless
Author

JBloodless commented Oct 27, 2023

Export works with the latest turbine (I guess), but .compile throws

loc("<eval_with_key>.0 from /Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:489 in wrapped":8:0): error: 'tensor.cast' op operand type 'tensor<?xui8>' and result type 'tensor<0xi8>' are cast incompatible
Traceback (most recent call last):
  File "/Users/i.beskrovnyy/tts/NISQA-s/repro_bn.py", line 40, in <module>
    binary = export_output.compile(save_to=None)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/SHARK-Turbine/python/shark_turbine/aot/exporter.py", line 138, in compile
    raise RuntimeError("Compilation failed: See diagnostics")
RuntimeError: Compilation failed: See diagnostics

@zjgarvey
Collaborator

Here is a minimal reproduction:

import shark_turbine.aot as aot
import torch
import torch.nn as nn

class ExMod(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = nn.BatchNorm2d(100)

    def forward(self, x):
        return self.m(x)

x = torch.zeros(20, 100, 35, 45)

mod = ExMod()
mod.eval()

export_output = aot.export(mod, x)
# export_output.save_mlir('bnex.mlir')
binary = export_output.compile(save_to=None)

@JBloodless
Author

Bumping this issue, since not having a working nn.BatchNorm2d is a serious gap (it's one of the most basic operations in models).

@stellaraccident
Contributor

We may be missing patterns to elide zero-element tensors.

@AmosLewis
Contributor

AmosLewis commented Nov 16, 2023

(iree_venv) ➜  src  cd /nodclouddata/chi/src ; /usr/bin/env /nodclouddata/chi/src/SHARK-Turbine/.venv/bin/python3.11 /home/chi/.vscode-server/extensions/ms-python.python-2023.20.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher 38747 -- /nodclouddata/chi/src/SHARK-Turbine/tests/aot/batchnorm2d_test.py

loc("<eval_with_key>.0 from /nodclouddata/chi/src/SHARK-Turbine/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped":5:0): error: 'tensor.cast' op operand type 'tensor<?xui8>' and result type 'tensor<0xi8>' are cast incompatible

The key debug info is: 'tensor.cast' op operand type 'tensor<?xui8>' and result type 'tensor<0xi8>' are cast incompatible

With a manually built iree-compile from the IREE source code, I finally located the failing op, torch.aten.empty.memory_format:

(iree_venv) ➜  SHARK-Turbine git:(main) ✗ iree-compile bnex.mlir --compile-to=input
bnex.mlir:18:10: error: 'tensor.cast' op operand type 'tensor<?xui8>' and result type 'tensor<0xi8>' are cast incompatible
    %1 = torch.aten.empty.memory_format %0, %int0_0, %int0_1, %cpu, %none, %none_2 : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.none, !torch.none -> !torch.vtensor<[0],ui8>
         ^
bnex.mlir:18:10: note: see current operation: %20 = "tensor.cast"(%19) : (tensor<?xui8>) -> tensor<0xi8>

Then with iree-compile bnex.mlir --compile-to=input --debug, it looks like the failure happens when lowering the torch dialect to the linalg dialect.

The torch-mlir-opt debug info is here:
batchnorm2d_test.py
The extracted empty.memory_format op debug info is here:
empty_memory_format_test.mlir

@AmosLewis
Contributor

AmosLewis commented Nov 16, 2023

One more bug: export_output.save_mlir('bnex.mlir') with TOM IREE leads to the error below, but it is fine with shark_turbine's iree-compiler==20231113.707. This must have arisen from a change in the IREE module print interface after 20231113.707.

TypeError: print(): incompatible function arguments. The following argument types are supported:
    1. (self: iree.compiler._mlir_libs._mlir.ir._OperationBase, state: mlir::python::PyAsmState, file: object = None, binary: bool = False) -> None
    2. (self: iree.compiler._mlir_libs._mlir.ir._OperationBase, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False, file: object = None, binary: bool = False) -> None

Invoked with: <iree.compiler._mlir_libs._mlir.ir.Operation object at 0x7f9fea26e030>, <_io.BufferedWriter name='bnex.mlir'>; kwargs: binary=True

@JBloodless
Author

Is there any progress/workaround yet?

@AmosLewis
Contributor

Yes. Should be fixed this week.

@AmosLewis
Contributor

AmosLewis commented Nov 28, 2023

For the upstream fix, we will need to land a patch in upstream torch-mlir, then uplift the torch-mlir version in IREE, then uplift the IREE version in SHARK-Turbine, which will take some time. So here I provide a local quick fix for your usage.

For the runtime error, here is a quick local fix in IREE's third-party torch-mlir.
To build IREE, you need to look at the IREE getting started guide.
Then, once you have set up the IREE build environment (including the Python env), you need to manually change the code here: https://github.com/shark-infra/torch-mlir/blob/ad182198204e299377d4efba74fa2241e5404b0f/lib/Dialect/Torch/IR/TorchTypes.cpp#L405C1-L406C52
Replace the

    return IntegerType::get(context, integerType.getWidth(),
                            IntegerType::Signless);

with

    return dtype;

Then build IREE locally:

cmake -G Ninja -B ../iree-build/ -S . \
    -DCMAKE_BUILD_TYPE=Release \
    -DIREE_ENABLE_ASSERTIONS=ON \
    -DIREE_ENABLE_SPLIT_DWARF=ON \
    -DIREE_ENABLE_THIN_ARCHIVES=ON \
    -DCMAKE_C_COMPILER_LAUNCHER=ccache \
    -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
    -DIREE_BUILD_PYTHON_BINDINGS=ON \
    -DPython3_EXECUTABLE="$(which python)" \
    -DCMAKE_C_COMPILER=clang \
    -DCMAKE_CXX_COMPILER=clang++ \
    -DIREE_ENABLE_LLD=ON
cmake --build ../iree-build/

Then, in order to use the newly built Python bindings, in your shark-turbine Python env you have to:

pip uninstall iree-compiler
pip uninstall iree-runtime

And set PYTHONPATH to the newly built IREE Python bindings; here is an example with my paths:
export PYTHONPATH=/nodclouddata/chi/src/iree-build/compiler/bindings/python:/nodclouddata/chi/src/iree-build/runtime/bindings/python:/nodclouddata/chi/src/SHARK-Turbine/python:$PYTHONPATH
This quick fix won't be merged into torch-mlir; the final solution is the fix I linked, which will be merged into torch-mlir later.
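A quick way to confirm the newly built bindings are the ones being picked up (an illustrative check, not part of the original instructions):

python -c "import iree.compiler; print(iree.compiler.__file__)"

The printed path should point into the iree-build tree rather than into site-packages.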

For the save_mlir incompatibility error: you can use export_output.print_readable() instead.
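A minimal sketch of that workaround, assuming print_readable() writes the module text to stdout (mod and x as in the repro above):

import shark_turbine.aot as aot

export_output = aot.export(mod, x)
# Instead of export_output.save_mlir('bnex.mlir'), which hits the print()
# incompatibility above, dump the module text directly:
export_output.print_readable()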

AmosLewis pushed a commit to AmosLewis/torch-mlir that referenced this issue Nov 28, 2023
  When converting vtensor to builtin tensor, the converter wrongly converts ui8 to i8.
  This will lead to other type cast issues in later usage of i8. Here is an example:
  nod-ai/SHARK-Turbine#110
@AmosLewis
Contributor

AmosLewis commented Nov 29, 2023

@JBloodless With the TOM torch-mlir the issue is fixed now. We will next uplift the change into IREE and then shark-turbine.

pip install iree-compiler==20231130.724
pip install iree-runtime==20231130.724

AmosLewis pushed a commit to AmosLewis/iree that referenced this issue Nov 29, 2023
Bumping torch-mlir on 11/29/2023 up to commit e568f7e
This can fix the BatchNorm2d runtime error:
nod-ai/SHARK-Turbine#110
AmosLewis added a commit to AmosLewis/iree that referenced this issue Nov 29, 2023
Bumping torch-mlir on 11/29/2023 up to commit e568f7e.
This can fix the BatchNorm2d runtime error:
nod-ai/SHARK-Turbine#110
stellaraccident pushed a commit to iree-org/iree that referenced this issue Nov 29, 2023
Bumping torch-mlir on 11/29/2023 up to commit e568f7e.
This can fix the BatchNorm2d runtime error:
nod-ai/SHARK-Turbine#110
AmosLewis added a commit to AmosLewis/SHARK-Turbine that referenced this issue Nov 30, 2023
To solve the issue nod-ai#110
AmosLewis added a commit to AmosLewis/SHARK-Turbine that referenced this issue Nov 30, 2023
@AmosLewis
Contributor

AmosLewis commented Nov 30, 2023

@JBloodless Now you can fix your issue in the SHARK-Turbine Python env by changing the requirements to

iree-compiler==20231130.724
iree-runtime==20231130.724

@JBloodless
Author

(mlir) i.beskrovnyy@i-beskrovnyy-2 NISQA-s % pip install iree-compiler==20231130.724
ERROR: Could not find a version that satisfies the requirement iree-compiler==20231130.724 (from versions: 20230404.479, 20230419.494, 20230524.529, 20230815.614, 20231004.665, 20231113.707)
ERROR: No matching distribution found for iree-compiler==20231130.724
(mlir) i.beskrovnyy@i-beskrovnyy-2 NISQA-s % pip install iree-runtime==20231130.724
ERROR: Could not find a version that satisfies the requirement iree-runtime==20231130.724 (from versions: 20230228.444, 20230404.479, 20230419.494, 20230524.529, 20230815.614, 20231004.665, 20231113.707)
ERROR: No matching distribution found for iree-runtime==20231130.724

Seems like these versions are not in pypi yet, but I'll keep monitoring after the weekend, thanks

@AmosLewis
Contributor

You need the -f https://openxla.github.io/iree/pip-release-links.html flag. You can find it in the requirements.txt. What I usually do is change requirements.txt and then pip install -r requirements.txt.
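For example, the relevant requirements.txt lines would look roughly like this (versions and URL as given in this thread):

-f https://openxla.github.io/iree/pip-release-links.html
iree-compiler==20231130.724
iree-runtime==20231130.724

followed by pip install -r requirements.txt.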

@JBloodless
Author

Closing this since BatchNorm2d seems to be fixed with 20231130.724

AmosLewis added a commit to AmosLewis/SHARK-Turbine that referenced this issue Dec 5, 2023
  To solve the batchnorm2d issue nod-ai#110
  Xfail llama_test because of missing ops from torch to linalg
  Xfail uninitialized parameters test
AmosLewis added a commit to AmosLewis/SHARK-Turbine that referenced this issue Dec 5, 2023
  To solve the batchnorm2d issue nod-ai#110
  Xfail llama_test because of missing ops from torch to linalg
stellaraccident pushed a commit that referenced this issue Dec 6, 2023
ramiro050 pushed a commit to ramiro050/iree that referenced this issue Dec 19, 2023
Bumping torch-mlir on 11/29/2023 up to commit e568f7e.
This can fix the BatchNorm2d runtime error:
nod-ai/SHARK-Turbine#110