Skip to content

Commit

Permalink
workaround for pyarrow segfault
Browse files Browse the repository at this point in the history
  • Loading branch information
pcmoritz committed Jul 27, 2018
1 parent 47e462f commit 10c5a5c
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 29 deletions.
2 changes: 2 additions & 0 deletions python/pyarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ def parse_version(root):


# Workaround for https://issues.apache.org/jira/browse/ARROW-2657
# and https://issues.apache.org/jira/browse/ARROW-2920
if _sys.platform in ('linux', 'linux2'):
compat.import_tensorflow_extension()
compat.import_pytorch_extension()


from pyarrow.lib import cpu_count, set_cpu_count
Expand Down
80 changes: 51 additions & 29 deletions python/pyarrow/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,31 +160,17 @@ def encode_file_path(path):
# will convert utf8 to utf16
return encoded_path

def import_tensorflow_extension():
def _iterate_python_module_paths(package_name):
"""
Load the TensorFlow extension if it exists.
Return an iterator to full paths of a python package.
This is used to load the TensorFlow extension before
pyarrow.lib. If we don't do this there are symbol clashes
between TensorFlow's use of threading and our global
thread pool, see also
https://issues.apache.org/jira/browse/ARROW-2657 and
https://github.com/apache/arrow/pull/2096.
This is a best effort and might fail (for example on Python 2).
It uses the official way of loading modules from
https://docs.python.org/3/library/importlib.html#approximating-importlib-import-module
"""
import os
tensorflow_loaded = False

# Try to load the tensorflow extension directly
# This is a performance optimization, tensorflow will always be
# loaded via the "import tensorflow" statement below if this
# doesn't succeed.
#
# This uses the official way of loading modules from
# https://docs.python.org/3/library/importlib.html#approximating-importlib-import-module

try:
import importlib
absolute_name = importlib.util.resolve_name("tensorflow", None)
absolute_name = importlib.util.resolve_name(package_name, None)
except (ImportError, AttributeError):
# Sometimes, importlib is not available (e.g. Python 2)
# or importlib.util is not available (e.g. Python 2.7)
Expand All @@ -205,16 +191,37 @@ def import_tensorflow_extension():
if spec:
module = importlib.util.module_from_spec(spec)
for path in module.__path__:
ext = os.path.join(path, "libtensorflow_framework.so")
if os.path.exists(ext):
import ctypes
try:
ctypes.CDLL(ext)
except OSError:
pass
tensorflow_loaded = True
break
yield path

def import_tensorflow_extension():
"""
Load the TensorFlow extension if it exists.
This is used to load the TensorFlow extension before
pyarrow.lib. If we don't do this there are symbol clashes
between TensorFlow's use of threading and our global
thread pool, see also
https://issues.apache.org/jira/browse/ARROW-2657 and
https://github.com/apache/arrow/pull/2096.
"""
import os
tensorflow_loaded = False

# Try to load the tensorflow extension directly
# This is a performance optimization, tensorflow will always be
# loaded via the "import tensorflow" statement below if this
# doesn't succeed.

for path in _iterate_python_module_paths("tensorflow"):
ext = os.path.join(path, "libtensorflow_framework.so")
if os.path.exists(ext):
import ctypes
try:
ctypes.CDLL(ext)
except OSError:
pass
tensorflow_loaded = True
break

# If the above failed, try to load tensorflow the normal way
# (this is more expensive)
Expand All @@ -225,6 +232,21 @@ def import_tensorflow_extension():
except ImportError:
pass

def import_pytorch_extension():
"""
Load the PyTorch extension if it exists.
This is used to load the PyTorch extension before
pyarrow.lib. If we don't do this there are symbol clashes
between PyTorch's use of threading and our global
thread pool, see also
https://issues.apache.org/jira/browse/ARROW-2920
"""
import ctypes

for path in _iterate_python_module_paths("torch"):
ctypes.CDLL(os.path.join(path, "lib/libcaffe2.so")


integer_types = six.integer_types + (np.integer,)

Expand Down

0 comments on commit 10c5a5c

Please sign in to comment.