Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
pcmoritz committed Jul 3, 2018
1 parent ac38837 commit c18cccb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
22 changes: 5 additions & 17 deletions python/pyarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,11 @@ def parse_version(root):
__version__ = None


try:
# We need to do the following to load TensorFlow before
# pyarrow.lib. If we don't do this there are symbol clashes
# between TensorFlow's use of threading and our global
# thread pool, see also
# https://issues.apache.org/jira/browse/ARROW-2657 and
# https://github.com/apache/arrow/pull/2096.
import ctypes
import os
from sys import platform
import site
if platform == "linux" or platform == "linux2":
SITE_PATH, = site.getsitepackages()
ctypes.CDLL(os.path.join(SITE_PATH, "tensorflow",
"libtensorflow_framework.so"))
except:
pass
from pyarrow.compat import import_tensorflow_extension


# Workaround for https://issues.apache.org/jira/browse/ARROW-2657
import_tensorflow_extension()


from pyarrow.lib import cpu_count, set_cpu_count
Expand Down
19 changes: 19 additions & 0 deletions python/pyarrow/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,25 @@ def encode_file_path(path):
# will convert utf8 to utf16
return encoded_path

def import_tensorflow_extension():
"""
Load the TensorFlow extension if it exists.
This is used to load the TensorFlow extension before
pyarrow.lib. If we don't do this there are symbol clashes
between TensorFlow's use of threading and our global
thread pool, see also
https://issues.apache.org/jira/browse/ARROW-2657 and
https://github.com/apache/arrow/pull/2096.
"""
import os
import site
for site_path in site.getsitepackages():
ext = os.path.join(site_path, "tensorflow",
"libtensorflow_framework.so")
if os.path.exists(ext):
import ctypes
ctypes.CDLL(ext)

integer_types = six.integer_types + (np.integer,)

Expand Down

0 comments on commit c18cccb

Please sign in to comment.