Skip to content

Commit

Permalink
fix: when on gpu it is relevant to import tensorflow before torch
Browse files Browse the repository at this point in the history
  • Loading branch information
lucashervier committed Nov 13, 2023
1 parent 8e2cdcf commit 3e4dc20
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/wrappers/test_pytorch_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

import torch.nn as nn
import tensorflow as tf
import torch.nn as nn

import pytest

Expand Down
7 changes: 6 additions & 1 deletion xplique/wrappers/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ def __init__(self, torch_model: "nn.Module", device: Union["torch.device", str],
is_channel_first: Optional[bool] = None
): # pylint: disable=C0415,C0103

super().__init__()
try:
super().__init__()
except tf.errors.InternalError as error:
raise Exception("If you have a tensorflow InternalError with cudaGetDevice() here, \
it is possible that importing tensorflow before torch might resolve the issue."
) from error

try:
# use PyTorch functionality
Expand Down

0 comments on commit 3e4dc20

Please sign in to comment.