diff --git a/.bazelrc b/.bazelrc index 7a1ab46ece5..768d3ccda38 100644 --- a/.bazelrc +++ b/.bazelrc @@ -12,6 +12,7 @@ build:cuda --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm # Options used to build with TPU support. build:tpu --distinct_host_configuration=false build:tpu --define=with_tpu_support=true --define=framework_shared_object=false +build:tpu --copt=-DLIBTPU_ON_GCE # Please note that MKL on MacOS or windows is still not supported. # If you would like to use a local MKL instead of downloading, please set the diff --git a/tensorflow_serving/model_servers/BUILD b/tensorflow_serving/model_servers/BUILD index 548c57192b1..51e92e4dc0e 100644 --- a/tensorflow_serving/model_servers/BUILD +++ b/tensorflow_serving/model_servers/BUILD @@ -395,7 +395,8 @@ cc_library( "@org_tensorflow//tensorflow/core:lib", "@org_tensorflow//tensorflow/core/platform/cloud:gcs_file_system", ] + if_libtpu([ - "@org_tensorflow//tensorflow/core/tpu:tpu_model_server_initializer", + "@org_tensorflow//tensorflow/core/tpu:tpu_global_init", + "@org_tensorflow//tensorflow/core/tpu:tpu_api_dlsym_initializer", ]), ) diff --git a/tensorflow_serving/model_servers/main.cc b/tensorflow_serving/model_servers/main.cc index 1e1a5af41c1..5079a9c84c8 100644 --- a/tensorflow_serving/model_servers/main.cc +++ b/tensorflow_serving/model_servers/main.cc @@ -50,6 +50,9 @@ limitations under the License. 
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/init_main.h" +#if defined(LIBTPU_ON_GCE) +#include "tensorflow/core/tpu/tpu_global_init.h" +#endif #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow_serving/model_servers/server.h" #include "tensorflow_serving/model_servers/version.h" @@ -266,6 +269,12 @@ int main(int argc, char** argv) { return -1; } +#if defined(LIBTPU_ON_GCE) + std::cout << "Initializing TPU system."; + TF_QCHECK_OK(tensorflow::InitializeTPUSystemGlobally()) + << "Failed to initialize the TPU system."; +#endif + if (display_version) { std::cout << "TensorFlow ModelServer: " << TF_Serving_Version() << "\n" << "TensorFlow Library: " << TF_Version() << "\n";