From c18cccb67cfe2ac8c62eceb934308e46932cdbd2 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 3 Jul 2018 15:58:09 -0700 Subject: [PATCH] address comments --- python/pyarrow/__init__.py | 22 +++++----------------- python/pyarrow/compat.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index 1662cd0d25d48..c8f23ca3fee0c 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -44,23 +44,11 @@ def parse_version(root): __version__ = None -try: - # We need to do the following to load TensorFlow before - # pyarrow.lib. If we don't do this there are symbol clashes - # between TensorFlow's use of threading and our global - # thread pool, see also - # https://issues.apache.org/jira/browse/ARROW-2657 and - # https://github.com/apache/arrow/pull/2096. - import ctypes - import os - from sys import platform - import site - if platform == "linux" or platform == "linux2": - SITE_PATH, = site.getsitepackages() - ctypes.CDLL(os.path.join(SITE_PATH, "tensorflow", - "libtensorflow_framework.so")) -except: - pass +from pyarrow.compat import import_tensorflow_extension + + +# Workaround for https://issues.apache.org/jira/browse/ARROW-2657 +import_tensorflow_extension() from pyarrow.lib import cpu_count, set_cpu_count diff --git a/python/pyarrow/compat.py b/python/pyarrow/compat.py index 1b19ca0e4029b..b2e98082253b5 100644 --- a/python/pyarrow/compat.py +++ b/python/pyarrow/compat.py @@ -160,6 +160,25 @@ def encode_file_path(path): # will convert utf8 to utf16 return encoded_path +def import_tensorflow_extension(): + """ + Load the TensorFlow extension if it exists. + + This is used to load the TensorFlow extension before + pyarrow.lib. If we don't do this there are symbol clashes + between TensorFlow's use of threading and our global + thread pool, see also + https://issues.apache.org/jira/browse/ARROW-2657 and + https://github.com/apache/arrow/pull/2096. + """ + import os + import site + for site_path in site.getsitepackages(): + ext = os.path.join(site_path, "tensorflow", + "libtensorflow_framework.so") + if os.path.exists(ext): + import ctypes + ctypes.CDLL(ext) integer_types = six.integer_types + (np.integer,)