diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index 1aaad9924bc01..8010b72f601c1 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -51,8 +51,10 @@ def parse_version(root):
 
 
 # Workaround for https://issues.apache.org/jira/browse/ARROW-2657
+# and https://issues.apache.org/jira/browse/ARROW-2920
 if _sys.platform in ('linux', 'linux2'):
     compat.import_tensorflow_extension()
+    compat.import_pytorch_extension()
 
 
 from pyarrow.lib import cpu_count, set_cpu_count
diff --git a/python/pyarrow/compat.py b/python/pyarrow/compat.py
index 47aeaa5bfd5f7..de0a14d30888b 100644
--- a/python/pyarrow/compat.py
+++ b/python/pyarrow/compat.py
@@ -160,31 +160,17 @@ def encode_file_path(path):
     # will convert utf8 to utf16
     return encoded_path
 
-def import_tensorflow_extension():
+def _iterate_python_module_paths(package_name):
     """
-    Load the TensorFlow extension if it exists.
+    Return an iterator to full paths of a python package.
 
-    This is used to load the TensorFlow extension before
-    pyarrow.lib. If we don't do this there are symbol clashes
-    between TensorFlow's use of threading and our global
-    thread pool, see also
-    https://issues.apache.org/jira/browse/ARROW-2657 and
-    https://github.com/apache/arrow/pull/2096.
+    This is a best effort and might fail (for example on Python 2).
+    It uses the official way of loading modules from
+    https://docs.python.org/3/library/importlib.html#approximating-importlib-import-module
     """
-    import os
-    tensorflow_loaded = False
-
-    # Try to load the tensorflow extension directly
-    # This is a performance optimization, tensorflow will always be
-    # loaded via the "import tensorflow" statement below if this
-    # doesn't succeed.
-    #
-    # This uses the official way of loading modules from
-    # https://docs.python.org/3/library/importlib.html#approximating-importlib-import-module
-
     try:
         import importlib
-        absolute_name = importlib.util.resolve_name("tensorflow", None)
+        absolute_name = importlib.util.resolve_name(package_name, None)
     except (ImportError, AttributeError):
         # Sometimes, importlib is not available (e.g. Python 2)
         # or importlib.util is not available (e.g. Python 2.7)
@@ -205,16 +191,37 @@ def import_tensorflow_extension():
     if spec:
         module = importlib.util.module_from_spec(spec)
         for path in module.__path__:
-            ext = os.path.join(path, "libtensorflow_framework.so")
-            if os.path.exists(ext):
-                import ctypes
-                try:
-                    ctypes.CDLL(ext)
-                except OSError:
-                    pass
-                tensorflow_loaded = True
-                break
+            yield path
+
+def import_tensorflow_extension():
+    """
+    Load the TensorFlow extension if it exists.
+
+    This is used to load the TensorFlow extension before
+    pyarrow.lib. If we don't do this there are symbol clashes
+    between TensorFlow's use of threading and our global
+    thread pool, see also
+    https://issues.apache.org/jira/browse/ARROW-2657 and
+    https://github.com/apache/arrow/pull/2096.
+    """
+    import os
+    tensorflow_loaded = False
+
+    # Try to load the tensorflow extension directly
+    # This is a performance optimization, tensorflow will always be
+    # loaded via the "import tensorflow" statement below if this
+    # doesn't succeed.
+    for path in _iterate_python_module_paths("tensorflow"):
+        ext = os.path.join(path, "libtensorflow_framework.so")
+        if os.path.exists(ext):
+            import ctypes
+            try:
+                ctypes.CDLL(ext)
+            except OSError:
+                pass
+            tensorflow_loaded = True
+            break
 
     # If the above failed, try to load tensorflow the normal way
     # (this is more expensive)
 
@@ -225,6 +232,28 @@ def import_tensorflow_extension():
         except ImportError:
             pass
 
+def import_pytorch_extension():
+    """
+    Load the PyTorch extension if it exists.
+
+    This is used to load the PyTorch extension before
+    pyarrow.lib. If we don't do this there are symbol clashes
+    between PyTorch's use of threading and our global
+    thread pool, see also
+    https://issues.apache.org/jira/browse/ARROW-2920
+    """
+    import ctypes
+    import os
+
+    for path in _iterate_python_module_paths("torch"):
+        try:
+            ctypes.CDLL(os.path.join(path, "lib/libcaffe2.so"))
+        except OSError:
+            # The library is missing or cannot be loaded from this path.
+            # Preloading is best effort (as for TensorFlow above), so
+            # importing pyarrow must not fail because of it.
+            pass
+
 integer_types = six.integer_types + (np.integer,)
 
 