torch.export fails when model uses ImageList.from_tensors #5347

Open
dgcnz opened this issue Aug 13, 2024 · 0 comments

dgcnz commented Aug 13, 2024

Instructions To Reproduce the 🐛 Bug:

  1. Full runnable code or full changes you made:
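```python
from detectron2.structures import ImageList
from torch.export import export
import torch


class M(torch.nn.Module):
    def forward(self, images: list[torch.Tensor]):
        # Pads the images to a common size and batches them into one tensor.
        images = ImageList.from_tensors(images)
        return images.tensor


# list() over the first dimension yields a single 3x24x24 image tensor.
example_kwargs = {
    "images": list(torch.randn(1, 3, 24, 24)),
}
# Raises GuardOnDataDependentSymNode (see the logs below).
exported_program: torch.export.ExportedProgram = export(
    M(), (), kwargs=example_kwargs
)
```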
  2. What exact command you ran:
  3. Full logs or other relevant observations:
```
GuardOnDataDependentSymNode: Could not guard on data-dependent expression u1 < 0 (unhinted: u1 < 0).  (Size-like symbols: none)

Potential framework code culprit (scroll up for full backtrace):
  File ".venv/lib/python3.10/site-packages/torch/_refs/__init__.py", line 2849, in constant_pad_nd
    if pad[pad_idx + 1] < 0:
```
  4. Please simplify the steps as much as possible so they do not require additional resources to run, such as a private dataset.
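The backtrace ends in `constant_pad_nd`, which tests whether each padding amount is negative. One plausible reading (an assumption, not confirmed in this report) is that `ImageList.from_tensors` derives the padding from a tensor of image sizes, so during `torch.export` each amount becomes an unbacked symbol such as `u1` and the `pad < 0` test cannot be decided. Below is a minimal sketch of a possible workaround, assuming the padding can instead be computed from `tensor.shape` as plain integers. The helpers `batch_hw` and `pad_amounts` are hypothetical (not part of detectron2) and use only integer arithmetic; `stride` loosely mirrors the `size_divisibility` argument of `from_tensors`, and the returned list follows the `[left, right, top, bottom]` order that `torch.nn.functional.pad` uses for the last two dimensions.

```python
from typing import List, Sequence, Tuple


def batch_hw(sizes: Sequence[Tuple[int, int]], stride: int = 0) -> Tuple[int, int]:
    """Return the largest (H, W) in the batch, rounded up to a multiple of `stride` if stride > 1."""
    max_h = max(h for h, _ in sizes)
    max_w = max(w for _, w in sizes)
    if stride > 1:
        # Round up to the next multiple of `stride` with integer arithmetic only.
        max_h = (max_h + stride - 1) // stride * stride
        max_w = (max_w + stride - 1) // stride * stride
    return max_h, max_w


def pad_amounts(size: Tuple[int, int], target: Tuple[int, int]) -> List[int]:
    """Padding for the last two dims in F.pad order: [left, right, top, bottom]."""
    h, w = size
    target_h, target_w = target
    # Pad only on the right and bottom, so each image keeps its top-left corner.
    return [0, target_w - w, 0, target_h - h]


# Example: two images of different sizes, padded to a common 32-divisible shape.
sizes = [(24, 24), (20, 30)]
target = batch_hw(sizes, stride=32)
print(target)                         # (32, 32)
print(pad_amounts((20, 30), target))  # [0, 2, 0, 12]
```

In a module like `M`, each image could then be padded with `F.pad(img, pad_amounts(tuple(img.shape[-2:]), target), value=0.0)` and the results combined with `torch.stack`. This has not been verified against `torch.export`, particularly when dynamic shapes are requested.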

Expected behavior:

If there are no obvious errors in the "full logs" provided above,
please tell us the expected behavior.

Environment:

-------------------------------  ----------------------------------------------------------------------------------------------
sys.platform                     darwin
Python                           3.10.12 (main, Dec  4 2023, 21:38:54) [Clang 14.0.3 (clang-1403.0.22.14.1)]
numpy                            1.26.4
detectron2                       0.6 @/Users/dgcnz/development/amsterdam/edge/detrex/detectron2/detectron2
Compiler                         clang 15.0.0
CUDA compiler                    not available
DETECTRON2_ENV_MODULE            <not set>
PyTorch                          2.3.1 @/Users/dgcnz/development/amsterdam/edge/.venv/lib/python3.10/site-packages/torch
PyTorch debug build              False
torch._C._GLIBCXX_USE_CXX11_ABI  False
GPU available                    No: torch.cuda.is_available() == False
Pillow                           10.4.0
torchvision                      0.18.1 @/Users/dgcnz/development/amsterdam/edge/.venv/lib/python3.10/site-packages/torchvision
fvcore                           0.1.5.post20221221
iopath                           0.1.9
cv2                              4.10.0
-------------------------------  ----------------------------------------------------------------------------------------------
PyTorch built with:
  - GCC 4.2
  - C++ Version: 201703
  - clang 14.0.3
  - OpenMP 201811
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: NO AVX
  - Build settings: BLAS_INFO=accelerate, BUILD_TYPE=Release, CXX_COMPILER=/Applications/Xcode_14.3.1.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/c++, CXX_FLAGS= -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_PYTORCH_METAL_EXPORT -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DUSE_COREML_DELEGATE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=braced-scalar-init -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wvla-extension -Wsuggest-override -Wnewline-eof -Winconsistent-missing-override -Winconsistent-missing-destructor-override -Wno-pass-failed -Wno-error=pedantic -Wno-error=old-style-cast -Wno-error=inconsistent-missing-override -Wno-error=inconsistent-missing-destructor-override -Wconstant-conversion -Wno-invalid-partial-specialization -Wno-missing-braces -Qunused-arguments -fcolor-diagnostics -faligned-new -Wno-unused-but-set-variable -fno-math-errno -fno-trapping-math -Werror=format -DUSE_MPS -Wno-unused-private-field -Wno-missing-braces, LAPACK_INFO=accelerate, TORCH_VERSION=2.3.1, USE_CUDA=OFF, USE_CUDNN=OFF, USE_CUSPARSELT=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, 