diff --git a/tests/wrappers/test_pytorch_wrapper.py b/tests/wrappers/test_pytorch_wrapper.py index 7fb41225..a910cd38 100644 --- a/tests/wrappers/test_pytorch_wrapper.py +++ b/tests/wrappers/test_pytorch_wrapper.py @@ -1,7 +1,7 @@ import numpy as np -import torch.nn as nn import tensorflow as tf +import torch.nn as nn import pytest diff --git a/xplique/wrappers/pytorch.py b/xplique/wrappers/pytorch.py index 78e0a1ce..28c3b1e2 100644 --- a/xplique/wrappers/pytorch.py +++ b/xplique/wrappers/pytorch.py @@ -27,7 +27,12 @@ def __init__(self, torch_model: "nn.Module", device: Union["torch.device", str], is_channel_first: Optional[bool] = None ): # pylint: disable=C0415,C0103 - super().__init__() + try: + super().__init__() + except tf.errors.InternalError as error: + raise Exception("If you have a tensorflow InternalError with cudaGetDevice() here, \ + it is possible that importing tensorflow before torch might resolve the issue." + ) from error try: # use PyTorch functionality