examples/torchscript_resnet18.py doesn't work because incompatible function arguments #2298
Comments
I don't know the reason, but try building PyTorch from source; that works.
I got the same error as you when running the unit tests.
I did not recompile torch. I updated the versions in pytorch-requirements.txt and torchvision-requirements.txt:
Then I recompiled torch-mlir and it worked.
I got an initial report of this issue this morning on nod-ai/SHARK-Studio#1989, but we hadn't concluded whether it was isolated or something wrong with that use case. We did determine that this was not broken as of: https://github.com/llvm/torch-mlir/releases/tag/snapshot-20231119.1027
Given the timing of the issue, we suspect that whatever it is broke as a side effect of this patch (#2582), but we don't have a working theory yet as to why this has started flaking. It appears to be a permutation of a very old issue from the early days that perhaps was never truly squashed. If memory serves, the problem comes from the PyTorch and torch-mlir Python extensions seeing different subsets of certain key class identity symbols. The only thing that changed in the above patch is the mechanism for determining PyTorch compilation flags, and that is precisely what was causing this old issue (it is necessary to "massage" pybind11 to agree on the C++ ABI when dealing with binary packages).
That at least gives me a lead to follow tomorrow. For the moment, if you are experiencing this issue, I would recommend pinning to the above release or syncing to a commit prior to the above patch.
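For readers unfamiliar with this failure mode, here is a pure-Python analogy of the class identity problem described above. It is only a sketch: the real registration happens on the C++ side inside pybind11, and the names `make_class_type` and `export_none` are hypothetical stand-ins, not torch-mlir or PyTorch APIs. It shows how two classes with the same name are still distinct types, so a type check against one rejects instances of the other.

```python
def make_class_type():
    """Simulate one extension module registering its own copy of a class.

    Each call creates a brand-new class object, much like two extension
    modules that each end up with a private notion of the "same" C++ type.
    """
    class ClassType:
        pass
    return ClassType


def export_none(arg, expected_type):
    """Hypothetical stand-in for a bound function that checks its argument type."""
    if not isinstance(arg, expected_type):
        raise TypeError(
            "incompatible function arguments: expected "
            f"{expected_type.__name__}, got {type(arg).__name__}"
        )
    return None


# One copy as seen by "PyTorch", another as seen by the "torch-mlir" extension.
pytorch_class_type = make_class_type()
torch_mlir_class_type = make_class_type()

# Same name, different identity.
same_name = pytorch_class_type.__name__ == torch_mlir_class_type.__name__
same_type = pytorch_class_type is torch_mlir_class_type

try:
    export_none(pytorch_class_type(), torch_mlir_class_type)
    mismatch_error = None
except TypeError as exc:
    mismatch_error = str(exc)
```

The error message names the same type on both sides, which is what makes this class of bug confusing to read in a traceback.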
Can you also confirm for me what compiler you are using and see if you can get CMake logs like these from your invocation:
This is what I get on a source build of torch-mlir (via
Is this what you were referring to? Also, compiler:
Isn't the torch
Yes, I expect that is the issue. And thanks for confirming this is with GCC. The That means that the torch-mlir Python extensions are not being compiled to be compatible with the PyTorch version with which they need to mate, and the result will be that native Python types defined in PyTorch proper will appear to be distinct types from those in the torch-mlir extensions. And that will result in the signature mismatch errors you see. Without detecting the right ABI flags, it is a coin toss whether your system defaults line up.
This is useful. I need to repro this setup and find a fix. I'm building with a different setup, which is likely why I'm not seeing it.
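One way to check which libstdc++ ABI setting an installed PyTorch binary expects is sketched below, so it can be compared against whatever the torch-mlir configure step reports. This assumes the relevant flag is the GCC macro `_GLIBCXX_USE_CXX11_ABI`; the helper names `cxx11_abi_define` and `installed_torch_abi_define` are made up for illustration, and only `torch.compiled_with_cxx11_abi()` is an actual PyTorch call.

```python
from typing import Optional


def cxx11_abi_define(use_cxx11_abi: bool) -> str:
    """Format the compiler define that selects the libstdc++ ABI."""
    return f"-D_GLIBCXX_USE_CXX11_ABI={int(use_cxx11_abi)}"


def installed_torch_abi_define() -> Optional[str]:
    """Return the define matching the installed PyTorch, or None if torch is absent."""
    try:
        import torch
    except ImportError:
        return None
    # torch.compiled_with_cxx11_abi() reports whether PyTorch itself was
    # built with _GLIBCXX_USE_CXX11_ABI=1.
    return cxx11_abi_define(bool(torch.compiled_with_cxx11_abi()))


if __name__ == "__main__":
    print(installed_torch_abi_define())
```

If the value printed here disagrees with the flag used to compile the torch-mlir extensions, that would be consistent with the mismatch described above.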
The plot thickens. I repro'd this situation, but only on the very first cmake invocation in a build directory. In subsequent configures, it detects the flags properly. A bad theory is forming in my mind. In the prior arrangement, we were configuring PyTorch multiple times, once in each directory that needed it. I think this (somehow) resulted in the flags being mis-computed once but then latching correctly for the others. It is probably not observably fatal if most of the places that were doing this got it wrong, hence the coin flips.
I was trying to compile torch-mlir and run the test case examples/torchscript_resnet18.py. Has anyone else come across this kind of problem?
Traceback (most recent call last):
  File "/home/yinrun/hp_workspace/torch-mlir/examples/torchscript_resnet18.py", line 70, in <module>
    module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
  File "/home/yinrun/hp_workspace/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 359, in compile
    class_annotator.exportNone(scripted._c._type())
TypeError: exportNone(): incompatible function arguments. The following argument types are supported:
    1. (self: torch_mlir._mlir_libs._jit_ir_importer.ClassAnnotator, arg0: c10::ClassType) -> None
Invoked with: ClassAnnotator {
}
, __torch__.torchvision.models.resnet.ResNet