diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/Xspace.png b/rfcs/20200624-pluggable-device-for-tensorflow/Xspace.png
new file mode 100644
index 000000000..d79249427
Binary files /dev/null and b/rfcs/20200624-pluggable-device-for-tensorflow/Xspace.png differ
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/flow.png b/rfcs/20200624-pluggable-device-for-tensorflow/flow.png
new file mode 100644
index 000000000..59b61575d
Binary files /dev/null and b/rfcs/20200624-pluggable-device-for-tensorflow/flow.png differ
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/modular_TensorFlow.png b/rfcs/20200624-pluggable-device-for-tensorflow/modular_TensorFlow.png
new file mode 100644
index 000000000..39a1c26e7
Binary files /dev/null and b/rfcs/20200624-pluggable-device-for-tensorflow/modular_TensorFlow.png differ
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/profiler_result.png b/rfcs/20200624-pluggable-device-for-tensorflow/profiler_result.png
new file mode 100644
index 000000000..b10cef062
Binary files /dev/null and b/rfcs/20200624-pluggable-device-for-tensorflow/profiler_result.png differ
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/.bazelrc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/.bazelrc
new file mode 100644
index 000000000..ea5676684
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/.bazelrc
@@ -0,0 +1,16 @@
+build --define=use_fast_cpp_protos=true
+build --define=allow_oversize_protos=true
+
+build --spawn_strategy=standalone
+# build --strategy=Genrule=standalone
+build -c opt
+
+# Default paths for TF_SYSTEM_LIBS
+build --define=PREFIX=/usr
+build --define=LIBDIR=$(PREFIX)/lib
+build --define=INCLUDEDIR=$(PREFIX)/include
+
+# host build is useless
+build --distinct_host_configuration=false
+
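+# The file below is generated by ./configure (see configure.py) and holds user-specific build settings.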
+try-import %workspace%/.tf_plugin_configure.bazelrc
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/README.md b/rfcs/20200624-pluggable-device-for-tensorflow/sample/README.md
new file mode 100644
index 000000000..002eb14ca
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/README.md
@@ -0,0 +1,96 @@
+# TensorFlow Plugin demo
+This sample is a simple demo that shows how to implement, build, install, and run a TensorFlow plug-in.
+
+## Supported OS
+* Linux
+
+## Prerequisites
+
+* [Bazel](https://docs.bazel.build/versions/master/install-ubuntu.html) (version 3.1 and above)
+* Git (version 1.8 and above)
+* Python (version 3.6 and above)
+
+## Build and Run
+
+### Linux
+1. Run the following command to install the latest `tensorflow`.
+```
+$ pip install tensorflow
+```
+2. In the plug-in `sample` code folder, configure the build options:
+```
+$ ./configure
+
+Please specify the location of python. [Default is /home/test/miniconda2/envs/sycl3.6/bin/python]:
+
+
+Found possible Python library paths:
+ /home/test/miniconda2/envs/sycl3.6/lib/python3.6/site-packages
+Please input the desired Python library path to use. Default is [/home/test/miniconda2/envs/sycl3.6/lib/python3.6/site-packages]
+
+Do you wish to build TensorFlow plug-in with MPI support? [y/N]:
+No MPI support will be enabled for TensorFlow plug-in.
+
+Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:
+```
+
+3. Build the plug-in:
+```
+$ bazel build -c opt //tensorflow_plugin/tools/pip_package:build_pip_package --verbose_failures
+```
+4. Then generate a Python wheel and install it:
+```
+$ bazel-bin/tensorflow_plugin/tools/pip_package/build_pip_package .
+$ pip install tensorflow_plugins-0.0.1-cp36-cp36m-linux_x86_64.whl
+```
+5. Now we can run TensorFlow with the plug-in device enabled.
+```
+$ python
+>>> import tensorflow as tf
+>>> tf.config.list_physical_devices()
+[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:MY_DEVICE:0', device_type='MY_DEVICE')]
+```
+* Relu case:
+```
+$ python relu.py
+random_normal/RandomStandardNormal: (RandomStandardNormal): /job:localhost/replica:0/task:0/device:CPU:0
+2021-10-21 12:48:20.714819: I tensorflow/core/common_runtime/placer.cc:114] random_normal/RandomStandardNormal: (RandomStandardNormal): /job:localhost/replica:0/task:0/device:CPU:0
+random_normal/mul: (Mul): /job:localhost/replica:0/task:0/device:CPU:0
+2021-10-21 12:48:20.714864: I tensorflow/core/common_runtime/placer.cc:114] random_normal/mul: (Mul): /job:localhost/replica:0/task:0/device:CPU:0
+random_normal: (AddV2): /job:localhost/replica:0/task:0/device:CPU:0
+2021-10-21 12:48:20.714903: I tensorflow/core/common_runtime/placer.cc:114] random_normal: (AddV2): /job:localhost/replica:0/task:0/device:CPU:0
+Relu: (Relu): /job:localhost/replica:0/task:0/device:MY_DEVICE:0
+2021-10-21 12:48:20.714937: I tensorflow/core/common_runtime/placer.cc:114] Relu: (Relu): /job:localhost/replica:0/task:0/device:MY_DEVICE:0
+random_normal/shape: (Const): /job:localhost/replica:0/task:0/device:CPU:0
+2021-10-21 12:48:20.714968: I tensorflow/core/common_runtime/placer.cc:114] random_normal/shape: (Const): /job:localhost/replica:0/task:0/device:CPU:0
+random_normal/mean: (Const): /job:localhost/replica:0/task:0/device:CPU:0
+2021-10-21 12:48:20.714997: I tensorflow/core/common_runtime/placer.cc:114] random_normal/mean: (Const): /job:localhost/replica:0/task:0/device:CPU:0
+random_normal/stddev: (Const): /job:localhost/replica:0/task:0/device:CPU:0
+2021-10-21 12:48:20.715022: I tensorflow/core/common_runtime/placer.cc:114] random_normal/stddev: (Const): /job:localhost/replica:0/task:0/device:CPU:0
+[2.9109507 0. 0. 0. 0. 0. 0.
+ 0. 0. 1.316411 ]
+
+```
+* Conv + Relu case:
+```
+$ python conv_relu.py
+2021-10-21 12:53:36.389514: I tensorflow/core/common_runtime/placer.cc:114] random_normal_3/mul: (Mul): /job:localhost/replica:0/task:0/device:CPU:0
+random_normal_3: (AddV2): /job:localhost/replica:0/task:0/device:CPU:0
+2021-10-21 12:53:36.389537: I tensorflow/core/common_runtime/placer.cc:114] random_normal_3: (AddV2): /job:localhost/replica:0/task:0/device:CPU:0
+Relu: (Relu): /job:localhost/replica:0/task:0/device:MY_DEVICE:0
+2021-10-21 12:53:36.389565: I tensorflow/core/common_runtime/placer.cc:114] Relu: (Relu): /job:localhost/replica:0/task:0/device:MY_DEVICE:0
+Conv2D: (Conv2D): /job:localhost/replica:0/task:0/device:MY_DEVICE:0
+2021-10-21 12:53:36.389592: I tensorflow/core/common_runtime/placer.cc:114] Conv2D: (Conv2D): /job:localhost/replica:0/task:0/device:MY_DEVICE:0
+Relu_1: (Relu): /job:localhost/replica:0/task:0/device:CPU:0
+2021-10-21 12:53:36.389617: I tensorflow/core/common_runtime/placer.cc:114] Relu_1: (Relu): /job:localhost/replica:0/task:0/device:CPU:0
+Conv2D_1: (Conv2D): /job:localhost/replica:0/task:0/device:CPU:0
+2021-10-21 12:53:36.389641: I tensorflow/core/common_runtime/placer.cc:114] Conv2D_1: (Conv2D): /job:localhost/replica:0/task:0/device:CPU:0
+```
+* Profiler case:
+```
+$ python test_profiler.py
+```
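+
+A minimal profiling sketch is shown below (an illustration only, assuming the
+plug-in's profiler is registered; the bundled `test_profiler.py` may differ in
+detail). It uses the standard `tf.profiler` API to capture a trace that
+includes `MY_DEVICE` activity:
+```
+import tensorflow as tf
+
+tf.profiler.experimental.start("./logdir")   # begin collecting a trace
+with tf.device("/MY_DEVICE:0"):
+    x = tf.nn.relu(tf.random.normal(shape=[10], dtype=tf.float32))
+tf.profiler.experimental.stop()              # write the trace to ./logdir for TensorBoard
+```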
+
+
+
+
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/WORKSPACE b/rfcs/20200624-pluggable-device-for-tensorflow/sample/WORKSPACE
new file mode 100644
index 000000000..1fe8eb76a
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/WORKSPACE
@@ -0,0 +1,21 @@
+workspace(name = "org_tensorflow_plugin")
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
+load("//third_party:version_check.bzl", "check_bazel_version_at_least")
+
+check_bazel_version_at_least("3.1.0")
+
+load("//tensorflow_plugin:tf_configure.bzl", "tf_configure")
+
+tf_configure(name = "local_config_tf")
+
+load("//tensorflow_plugin:workspace.bzl", "clean_dep", "demo_plugin_workspace")
+
+demo_plugin_workspace()
+
+load(
+ "@bazel_toolchains//repositories:repositories.bzl",
+ bazel_toolchains_repositories = "repositories",
+)
+
+bazel_toolchains_repositories()
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/build.sh b/rfcs/20200624-pluggable-device-for-tensorflow/sample/build.sh
new file mode 100644
index 000000000..30cb29095
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/build.sh
@@ -0,0 +1,2 @@
+#!/bin/bash
+bazel build -c opt //tensorflow_plugin/tools/pip_package:build_pip_package --verbose_failures
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/configure b/rfcs/20200624-pluggable-device-for-tensorflow/sample/configure
new file mode 100755
index 000000000..66b66ba54
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/configure
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+
+set -e
+set -o pipefail
+
+if [ -z "$PYTHON_BIN_PATH" ]; then
+ PYTHON_BIN_PATH=$(which python || which python3 || true)
+fi
+
+# Set all env variables
+CONFIGURE_DIR=$(dirname "$0")
+"$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@"
+
+echo "Configuration finished"
+
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/configure.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/configure.py
new file mode 100644
index 000000000..b4fdb36a7
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/configure.py
@@ -0,0 +1,1127 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""configure script to get build parameters from user."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import os
+import platform
+import re
+import subprocess
+import sys
+
+# pylint: disable=g-import-not-at-top
+try:
+ from shutil import which
+except ImportError:
+ from distutils.spawn import find_executable as which
+# pylint: enable=g-import-not-at-top
+
+
+_DEFAULT_GCC_TOOLCHAIN_PATH = ''
+_DEFAULT_GCC_TOOLCHAIN_TARGET = ''
+# Referenced by set_opencl_sdk_root(); an empty default means "not set".
+_DEFAULT_OCL_SDK_ROOT = ''
+
+_DEFAULT_PROMPT_ASK_ATTEMPTS = 10
+
+_TF_BAZELRC_FILENAME = '.tf_plugin_configure.bazelrc'
+_TF_WORKSPACE_ROOT = ''
+_TF_BAZELRC = ''
+_TF_CURRENT_BAZEL_VERSION = None
+
+NCCL_LIB_PATHS = [
+ 'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
+]
+
+
+class UserInputError(Exception):
+ pass
+
+
+def is_windows():
+ return platform.system() == 'Windows'
+
+
+def is_linux():
+ return platform.system() == 'Linux'
+
+
+def is_macos():
+ return platform.system() == 'Darwin'
+
+
+def is_ppc64le():
+ return platform.machine() == 'ppc64le'
+
+
+def is_cygwin():
+ return platform.system().startswith('CYGWIN_NT')
+
+
+def get_input(question):
+ try:
+ try:
+ answer = raw_input(question)
+ except NameError:
+ answer = input(question) # pylint: disable=bad-builtin
+ except EOFError:
+ answer = ''
+ return answer
+
+
+def symlink_force(target, link_name):
+ """Force symlink, equivalent of 'ln -sf'.
+
+ Args:
+ target: items to link to.
+ link_name: name of the link.
+ """
+ try:
+ os.symlink(target, link_name)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ os.remove(link_name)
+ os.symlink(target, link_name)
+ else:
+ raise e
+
+
+def sed_in_place(filename, old, new):
+ """Replace old string with new string in file.
+
+ Args:
+ filename: string for filename.
+ old: string to replace.
+ new: new string to replace to.
+ """
+ with open(filename, 'r') as f:
+ filedata = f.read()
+ newdata = filedata.replace(old, new)
+ with open(filename, 'w') as f:
+ f.write(newdata)
+
+
+def write_to_bazelrc(line):
+ with open(_TF_BAZELRC, 'a') as f:
+ f.write(line + '\n')
+
+
+def write_action_env_to_bazelrc(var_name, var):
+ write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
+
+
+def run_shell(cmd, allow_non_zero=False):
+ if allow_non_zero:
+ try:
+ output = subprocess.check_output(cmd)
+ except subprocess.CalledProcessError as e:
+ output = e.output
+ else:
+ output = subprocess.check_output(cmd)
+ return output.decode('UTF-8').strip()
+
+
+def cygpath(path):
+ """Convert path from posix to windows."""
+ return os.path.abspath(path).replace('\\', '/')
+
+
+def get_python_path(environ_cp, python_bin_path):
+ """Get the python site package paths."""
+ python_paths = []
+ if environ_cp.get('PYTHONPATH'):
+ python_paths = environ_cp.get('PYTHONPATH').split(':')
+ try:
+ library_paths = run_shell([
+ python_bin_path, '-c',
+ 'import site; print("\\n".join(site.getsitepackages()))'
+ ]).split('\n')
+ except subprocess.CalledProcessError:
+ library_paths = [
+ run_shell([
+ python_bin_path, '-c',
+ 'from distutils.sysconfig import get_python_lib;'
+ 'print(get_python_lib())'
+ ])
+ ]
+
+ all_paths = set(python_paths + library_paths)
+
+ paths = []
+ for path in all_paths:
+ if os.path.isdir(path):
+ paths.append(path)
+ return paths
+
+
+def get_python_major_version(python_bin_path):
+ """Get the python major version."""
+ return run_shell([python_bin_path, '-c', 'import sys; print(sys.version[0])'])
+
+
+def setup_python(environ_cp):
+ """Setup python related env variables."""
+ # Get PYTHON_BIN_PATH, default is the current running python.
+ default_python_bin_path = sys.executable
+ ask_python_bin_path = ('Please specify the location of python. [Default is '
+ '%s]: ') % default_python_bin_path
+ while True:
+ python_bin_path = get_from_env_or_user_or_default(environ_cp,
+ 'PYTHON_BIN_PATH',
+ ask_python_bin_path,
+ default_python_bin_path)
+ # Check if the path is valid
+ if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
+ break
+ elif not os.path.exists(python_bin_path):
+ print('Invalid python path: %s cannot be found.' % python_bin_path)
+ else:
+ print('%s is not executable. Is it the python binary?' % python_bin_path)
+ environ_cp['PYTHON_BIN_PATH'] = ''
+
+ # Convert python path to Windows style before checking lib and version
+ if is_windows() or is_cygwin():
+ python_bin_path = cygpath(python_bin_path)
+
+ # Get PYTHON_LIB_PATH
+ python_lib_path = environ_cp.get('PYTHON_LIB_PATH')
+ if not python_lib_path:
+ python_lib_paths = get_python_path(environ_cp, python_bin_path)
+ if environ_cp.get('USE_DEFAULT_PYTHON_LIB_PATH') == '1':
+ python_lib_path = python_lib_paths[0]
+ else:
+ print('Found possible Python library paths:\n %s' %
+ '\n '.join(python_lib_paths))
+ default_python_lib_path = python_lib_paths[0]
+ python_lib_path = get_input(
+ 'Please input the desired Python library path to use. '
+ 'Default is [%s]\n' % python_lib_paths[0])
+ if not python_lib_path:
+ python_lib_path = default_python_lib_path
+ environ_cp['PYTHON_LIB_PATH'] = python_lib_path
+
+ _ = get_python_major_version(python_bin_path)
+
+ # Convert python path to Windows style before writing into bazel.rc
+ if is_windows() or is_cygwin():
+ python_lib_path = cygpath(python_lib_path)
+
+ # Set-up env variables used by python_configure.bzl
+ write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
+ write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
+ write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
+ environ_cp['PYTHON_BIN_PATH'] = python_bin_path
+
+  # If the chosen python_lib_path is from a path specified in the PYTHONPATH
+  # variable, we need to tell bazel to include PYTHONPATH
+ if environ_cp.get('PYTHONPATH'):
+ python_paths = environ_cp.get('PYTHONPATH').split(':')
+ if python_lib_path in python_paths:
+ write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH'))
+
+ # Write tools/python_bin_path.sh
+ with open(
+ os.path.join(_TF_WORKSPACE_ROOT, 'tensorflow_plugin', 'tools', 'python_bin_path.sh'),
+ 'w') as f:
+ f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
+
+def get_python_lib_name(environ_cp):
+ python_bin_path = environ_cp['PYTHON_BIN_PATH']
+ path_list = python_bin_path.split(os.sep)[:-2]
+ path_list.append('lib')
+ py_lib_path = os.sep.join(path_list)
+ for _, _, files in os.walk(py_lib_path):
+ for name in files:
+ if str(name).startswith('libpython') and str(name).endswith('.so'):
+ # strip libxxx.so to get xxx
+ return str(name).strip()[3:-3]
+
+
+def get_python_link_path(environ_cp):
+  # TODO(quintin): we need to link libpythonx.y.so for _pywrap_tensorflow_internal.so.
+  # Once Google moves the C API symbols into libtensorflow.so, this will no longer be needed.
+ python_bin_path = environ_cp['PYTHON_BIN_PATH']
+ path_list = python_bin_path.split(os.sep)[:-2]
+ path_list.append('lib')
+ py_lib_path = os.sep.join(path_list)
+ return py_lib_path
+
+def create_build_configuration(environ_cp):
+
+ tf_header_dir = environ_cp['PYTHON_LIB_PATH'] + "/tensorflow/include"
+ tf_shared_lib_dir = environ_cp['PYTHON_LIB_PATH'] + "/tensorflow/"
+
+ write_action_env_to_bazelrc("TF_HEADER_DIR", tf_header_dir)
+ write_action_env_to_bazelrc("TF_SHARED_LIBRARY_DIR", tf_shared_lib_dir)
+ write_action_env_to_bazelrc("TF_CXX11_ABI_FLAG", 1)
+ write_action_env_to_bazelrc("PYTHON_LINK_LIB_NAME", get_python_lib_name(environ_cp))
+ write_action_env_to_bazelrc("PYTHON_LINK_PATH", get_python_link_path(environ_cp))
+
+
+def reset_tf_configure_bazelrc():
+ """Reset file that contains customized config settings."""
+ open(_TF_BAZELRC, 'w').close()
+
+
+def cleanup_makefile():
+ """Delete any leftover BUILD files from the Makefile build.
+
+ These files could interfere with Bazel parsing.
+ """
+ makefile_download_dir = os.path.join(_TF_WORKSPACE_ROOT, 'tensorflow',
+ 'contrib', 'makefile', 'downloads')
+ if os.path.isdir(makefile_download_dir):
+ for root, _, filenames in os.walk(makefile_download_dir):
+ for f in filenames:
+ if f.endswith('BUILD'):
+ os.remove(os.path.join(root, f))
+
+
+def get_var(environ_cp,
+ var_name,
+ query_item,
+ enabled_by_default,
+ question=None,
+ yes_reply=None,
+ no_reply=None):
+ """Get boolean input from user.
+
+ If var_name is not set in env, ask user to enable query_item or not. If the
+ response is empty, use the default.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
+ query_item: string for feature related to the variable, e.g. "CUDA for
+ Nvidia GPUs".
+ enabled_by_default: boolean for default behavior.
+ question: optional string for how to ask for user input.
+ yes_reply: optional string for reply when feature is enabled.
+ no_reply: optional string for reply when feature is disabled.
+
+ Returns:
+ boolean value of the variable.
+
+ Raises:
+ UserInputError: if an environment variable is set, but it cannot be
+ interpreted as a boolean indicator, assume that the user has made a
+ scripting error, and will continue to provide invalid input.
+ Raise the error to avoid infinitely looping.
+ """
+ if not question:
+ question = 'Do you wish to build TensorFlow plug-in with %s support?' % query_item
+ if not yes_reply:
+ yes_reply = '%s support will be enabled for TensorFlow plug-in.' % query_item
+ if not no_reply:
+ no_reply = 'No %s' % yes_reply
+
+ yes_reply += '\n'
+ no_reply += '\n'
+
+ if enabled_by_default:
+ question += ' [Y/n]: '
+ else:
+ question += ' [y/N]: '
+
+ var = environ_cp.get(var_name)
+ if var is not None:
+ var_content = var.strip().lower()
+ true_strings = ('1', 't', 'true', 'y', 'yes')
+ false_strings = ('0', 'f', 'false', 'n', 'no')
+ if var_content in true_strings:
+ var = True
+ elif var_content in false_strings:
+ var = False
+ else:
+ raise UserInputError(
+ 'Environment variable %s must be set as a boolean indicator.\n'
+ 'The following are accepted as TRUE : %s.\n'
+ 'The following are accepted as FALSE: %s.\n'
+ 'Current value is %s.' %
+ (var_name, ', '.join(true_strings), ', '.join(false_strings), var))
+
+ while var is None:
+ user_input_origin = get_input(question)
+ user_input = user_input_origin.strip().lower()
+ if user_input == 'y':
+ print(yes_reply)
+ var = True
+ elif user_input == 'n':
+ print(no_reply)
+ var = False
+ elif not user_input:
+ if enabled_by_default:
+ print(yes_reply)
+ var = True
+ else:
+ print(no_reply)
+ var = False
+ else:
+ print('Invalid selection: %s' % user_input_origin)
+ return var
+
+
+def set_build_var(environ_cp,
+ var_name,
+ query_item,
+ option_name,
+ enabled_by_default,
+ bazel_config_name=None):
+ """Set if query_item will be enabled for the build.
+
+ Ask user if query_item will be enabled. Default is used if no input is given.
+ Set subprocess environment variable and write to .bazelrc if enabled.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
+ query_item: string for feature related to the variable, e.g. "CUDA for
+ Nvidia GPUs".
+ option_name: string for option to define in .bazelrc.
+ enabled_by_default: boolean for default behavior.
+ bazel_config_name: Name for Bazel --config argument to enable build feature.
+ """
+
+ var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
+ environ_cp[var_name] = var
+ if var == '1':
+ write_to_bazelrc('build:%s --define %s=true' %
+ (bazel_config_name, option_name))
+ write_to_bazelrc('build --config=%s' % bazel_config_name)
+ elif bazel_config_name is not None:
+ # TODO(mikecase): Migrate all users of configure.py to use --config Bazel
+ # options and not to set build configs through environment variables.
+ write_to_bazelrc('build:%s --define %s=true' %
+ (bazel_config_name, option_name))
+
+
+def set_action_env_var(environ_cp,
+ var_name,
+ query_item,
+ enabled_by_default,
+ question=None,
+ yes_reply=None,
+ no_reply=None):
+ """Set boolean action_env variable.
+
+ Ask user if query_item will be enabled. Default is used if no input is given.
+ Set environment variable and write to .bazelrc.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
+ query_item: string for feature related to the variable, e.g. "CUDA for
+ Nvidia GPUs".
+ enabled_by_default: boolean for default behavior.
+ question: optional string for how to ask for user input.
+ yes_reply: optional string for reply when feature is enabled.
+ no_reply: optional string for reply when feature is disabled.
+ """
+ var = int(
+ get_var(environ_cp, var_name, query_item, enabled_by_default, question,
+ yes_reply, no_reply))
+
+ write_action_env_to_bazelrc(var_name, var)
+ environ_cp[var_name] = str(var)
+
+
+def convert_version_to_int(version):
+ """Convert a version number to a integer that can be used to compare.
+
+ Version strings of the form X.YZ and X.Y.Z-xxxxx are supported. The
+ 'xxxxx' part, for instance 'homebrew' on OS/X, is ignored.
+
+ Args:
+ version: a version to be converted
+
+ Returns:
+ An integer if converted successfully, otherwise return None.
+ """
+ version = version.split('-')[0]
+ version_segments = version.split('.')
+ # Treat "0.24" as "0.24.0"
+ if len(version_segments) == 2:
+ version_segments.append('0')
+ for seg in version_segments:
+ if not seg.isdigit():
+ return None
+
+ version_str = ''.join(['%03d' % int(seg) for seg in version_segments])
+ return int(version_str)
+
+
+def check_bazel_version(min_version, max_version):
+ """Check installed bazel version is between min_version and max_version.
+
+ Args:
+ min_version: string for minimum bazel version (must exist!).
+ max_version: string for maximum bazel version (must exist!).
+
+ Returns:
+ The bazel version detected.
+ """
+ if which('bazel') is None:
+ print('Cannot find bazel. Please install bazel.')
+ sys.exit(0)
+ curr_version = run_shell(
+ ['bazel', '--batch', '--bazelrc=/dev/null', 'version'])
+
+ for line in curr_version.split('\n'):
+ if 'Build label: ' in line:
+ curr_version = line.split('Build label: ')[1]
+ break
+
+ min_version_int = convert_version_to_int(min_version)
+ curr_version_int = convert_version_to_int(curr_version)
+ max_version_int = convert_version_to_int(max_version)
+
+ # Check if current bazel version can be detected properly.
+ if not curr_version_int:
+ print('WARNING: current bazel installation is not a release version.')
+ print('Make sure you are running at least bazel %s' % min_version)
+ return curr_version
+
+ print('You have bazel %s installed.' % curr_version)
+
+ if curr_version_int < min_version_int:
+ print('Please upgrade your bazel installation to version %s or higher to '
+ 'build TensorFlow!' % min_version)
+ sys.exit(1)
+ if (curr_version_int > max_version_int and
+ 'TF_IGNORE_MAX_BAZEL_VERSION' not in os.environ):
+ print('Please downgrade your bazel installation to version %s or lower to '
+ 'build TensorFlow! To downgrade: download the installer for the old '
+ 'version (from https://github.com/bazelbuild/bazel/releases) then '
+ 'run the installer.' % max_version)
+ sys.exit(1)
+ return curr_version
+
+
+def set_cc_opt_flags(environ_cp):
+ """Set up architecture-dependent optimization flags.
+
+  Also appends CC optimization flags to bazel.rc.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ """
+ if is_ppc64le():
+ # gcc on ppc64le does not support -march, use mcpu instead
+ default_cc_opt_flags = '-mcpu=native'
+ elif is_windows():
+ default_cc_opt_flags = '/arch:AVX'
+ else:
+ default_cc_opt_flags = '-march=native -Wno-sign-compare'
+ question = ('Please specify optimization flags to use during compilation when'
+ ' bazel option "--config=opt" is specified [Default is %s]: '
+ ) % default_cc_opt_flags
+ cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS',
+ question, default_cc_opt_flags)
+ for opt in cc_opt_flags.split():
+ write_to_bazelrc('build:opt --copt=%s' % opt)
+ # It should be safe on the same build host.
+ if not is_ppc64le() and not is_windows():
+ write_to_bazelrc('build:opt --host_copt=-march=native')
+ write_to_bazelrc('build:opt --define with_default_optimizations=true')
+
+
+
+def set_tf_download_clang(environ_cp):
+ """Set TF_DOWNLOAD_CLANG action_env."""
+ question = 'Do you wish to download a fresh release of clang? (Experimental)'
+ yes_reply = 'Clang will be downloaded and used to compile tensorflow.'
+ no_reply = 'Clang will not be downloaded.'
+ set_action_env_var(
+ environ_cp,
+ 'TF_DOWNLOAD_CLANG',
+ None,
+ False,
+ question=question,
+ yes_reply=yes_reply,
+ no_reply=no_reply)
+
+
+def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
+ var_default):
+ """Get var_name either from env, or user or default.
+
+ If var_name has been set as environment variable, use the preset value, else
+ ask for user input. If no input is provided, the default is used.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
+ ask_for_var: string for how to ask for user input.
+ var_default: default value string.
+
+ Returns:
+ string value for var_name
+ """
+ var = environ_cp.get(var_name)
+ if not var:
+ var = get_input(ask_for_var)
+ print('\n')
+ if not var:
+ var = var_default
+ return var
+
+
+def prompt_loop_or_load_from_env(environ_cp,
+ var_name,
+ var_default,
+ ask_for_var,
+ check_success,
+ error_msg,
+ suppress_default_error=False,
+ n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS):
+ """Loop over user prompts for an ENV param until receiving a valid response.
+
+ For the env param var_name, read from the environment or verify user input
+ until receiving valid input. When done, set var_name in the environ_cp to its
+ new value.
+
+ Args:
+ environ_cp: (Dict) copy of the os.environ.
+ var_name: (String) string for name of environment variable, e.g. "TF_MYVAR".
+ var_default: (String) default value string.
+ ask_for_var: (String) string for how to ask for user input.
+ check_success: (Function) function that takes one argument and returns a
+ boolean. Should return True if the value provided is considered valid. May
+ contain a complex error message if error_msg does not provide enough
+ information. In that case, set suppress_default_error to True.
+ error_msg: (String) String with one and only one '%s'. Formatted with each
+ invalid response upon check_success(input) failure.
+ suppress_default_error: (Bool) Suppress the above error message in favor of
+ one from the check_success function.
+ n_ask_attempts: (Integer) Number of times to query for valid input before
+ raising an error and quitting.
+
+ Returns:
+ [String] The value of var_name after querying for input.
+
+ Raises:
+ UserInputError: if a query has been attempted n_ask_attempts times without
+ success, assume that the user has made a scripting error, and will
+ continue to provide invalid input. Raise the error to avoid infinitely
+ looping.
+ """
+ default = environ_cp.get(var_name) or var_default
+ full_query = '%s [Default is %s]: ' % (
+ ask_for_var,
+ default,
+ )
+
+ for _ in range(n_ask_attempts):
+ val = get_from_env_or_user_or_default(environ_cp, var_name, full_query,
+ default)
+ if check_success(val):
+ break
+ if not suppress_default_error:
+ print(error_msg % val)
+ environ_cp[var_name] = ''
+ else:
+ raise UserInputError('Invalid %s setting was provided %d times in a row. '
+ 'Assuming to be a scripting mistake.' %
+ (var_name, n_ask_attempts))
+
+ environ_cp[var_name] = val
+ return val
+
+
+def create_android_ndk_rule(environ_cp):
+ """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule."""
+ if is_windows() or is_cygwin():
+ default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' %
+ environ_cp['APPDATA'])
+ elif is_macos():
+ default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME']
+ else:
+ default_ndk_path = '%s/Android/Sdk/ndk-bundle' % environ_cp['HOME']
+
+ def valid_ndk_path(path):
+ return (os.path.exists(path) and
+ os.path.exists(os.path.join(path, 'source.properties')))
+
+ android_ndk_home_path = prompt_loop_or_load_from_env(
+ environ_cp,
+ var_name='ANDROID_NDK_HOME',
+ var_default=default_ndk_path,
+ ask_for_var='Please specify the home path of the Android NDK to use.',
+ check_success=valid_ndk_path,
+ error_msg=('The path %s or its child file "source.properties" '
+ 'does not exist.'))
+ write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path)
+ write_action_env_to_bazelrc(
+ 'ANDROID_NDK_API_LEVEL',
+ get_ndk_api_level(environ_cp, android_ndk_home_path))
+
+
+def create_android_sdk_rule(environ_cp):
+ """Set Android variables and write Android SDK WORKSPACE rule."""
+ if is_windows() or is_cygwin():
+ default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA'])
+ elif is_macos():
+ default_sdk_path = '%s/library/Android/Sdk' % environ_cp['HOME']
+ else:
+ default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME']
+
+ def valid_sdk_path(path):
+ return (os.path.exists(path) and
+ os.path.exists(os.path.join(path, 'platforms')) and
+ os.path.exists(os.path.join(path, 'build-tools')))
+
+ android_sdk_home_path = prompt_loop_or_load_from_env(
+ environ_cp,
+ var_name='ANDROID_SDK_HOME',
+ var_default=default_sdk_path,
+ ask_for_var='Please specify the home path of the Android SDK to use.',
+ check_success=valid_sdk_path,
+ error_msg=('Either %s does not exist, or it does not contain the '
+ 'subdirectories "platforms" and "build-tools".'))
+
+ platforms = os.path.join(android_sdk_home_path, 'platforms')
+ api_levels = sorted(os.listdir(platforms))
+ api_levels = [x.replace('android-', '') for x in api_levels]
+
+ def valid_api_level(api_level):
+ return os.path.exists(
+ os.path.join(android_sdk_home_path, 'platforms',
+ 'android-' + api_level))
+
+ android_api_level = prompt_loop_or_load_from_env(
+ environ_cp,
+ var_name='ANDROID_API_LEVEL',
+ var_default=api_levels[-1],
+ ask_for_var=('Please specify the Android SDK API level to use. '
+ '[Available levels: %s]') % api_levels,
+ check_success=valid_api_level,
+ error_msg='Android-%s is not present in the SDK path.')
+
+ build_tools = os.path.join(android_sdk_home_path, 'build-tools')
+ versions = sorted(os.listdir(build_tools))
+
+ def valid_build_tools(version):
+ return os.path.exists(
+ os.path.join(android_sdk_home_path, 'build-tools', version))
+
+ android_build_tools_version = prompt_loop_or_load_from_env(
+ environ_cp,
+ var_name='ANDROID_BUILD_TOOLS_VERSION',
+ var_default=versions[-1],
+ ask_for_var=('Please specify an Android build tools version to use. '
+ '[Available versions: %s]') % versions,
+ check_success=valid_build_tools,
+ error_msg=('The selected SDK does not have build-tools version %s '
+ 'available.'))
+
+ write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION',
+ android_build_tools_version)
+ write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', android_api_level)
+ write_action_env_to_bazelrc('ANDROID_SDK_HOME', android_sdk_home_path)
+
+
+def get_ndk_api_level(environ_cp, android_ndk_home_path):
+ """Gets the appropriate NDK API level to use for the provided Android NDK path."""
+
+ # First check to see if we're using a blessed version of the NDK.
+ properties_path = '%s/source.properties' % android_ndk_home_path
+ if is_windows() or is_cygwin():
+ properties_path = cygpath(properties_path)
+ with open(properties_path, 'r') as f:
+ filedata = f.read()
+
+ revision = re.search(r'Pkg.Revision = (\d+)', filedata)
+ if revision:
+ ndk_version = revision.group(1)
+ else:
+ raise Exception('Unable to parse NDK revision.')
+ if int(ndk_version) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
+ print('WARNING: The NDK version in %s is %s, which is not '
+ 'supported by Bazel (officially supported versions: %s). Please use '
+ 'another version. Compiling Android targets may result in confusing '
+ 'errors.\n' % (android_ndk_home_path, ndk_version,
+ _SUPPORTED_ANDROID_NDK_VERSIONS))
+
+ # Now grab the NDK API level to use. Note that this is different from the
+ # SDK API level, as the NDK API level is effectively the *min* target SDK
+ # version.
+ platforms = os.path.join(android_ndk_home_path, 'platforms')
+ api_levels = sorted(os.listdir(platforms))
+ api_levels = [
+ x.replace('android-', '') for x in api_levels if 'android-' in x
+ ]
+
+ def valid_api_level(api_level):
+ return os.path.exists(
+ os.path.join(android_ndk_home_path, 'platforms',
+ 'android-' + api_level))
+
+ android_ndk_api_level = prompt_loop_or_load_from_env(
+ environ_cp,
+ var_name='ANDROID_NDK_API_LEVEL',
+ var_default='18', # 18 is required for GPU acceleration.
+ ask_for_var=('Please specify the (min) Android NDK API level to use. '
+ '[Available levels: %s]') % api_levels,
+ check_success=valid_api_level,
+ error_msg='Android-%s is not present in the NDK path.')
+
+ return android_ndk_api_level
+
+
+def set_gcc_host_compiler_path(environ_cp):
+ """Set GCC_HOST_COMPILER_PATH."""
+ default_gcc_host_compiler_path = which('gcc')
+ if os.path.islink(default_gcc_host_compiler_path):
+ # os.readlink is only available in linux
+ default_gcc_host_compiler_path = os.path.realpath(default_gcc_host_compiler_path)
+
+ gcc_host_compiler_path = prompt_loop_or_load_from_env(
+ environ_cp,
+ var_name='GCC_HOST_COMPILER_PATH',
+ var_default=default_gcc_host_compiler_path,
+ ask_for_var='Please specify which gcc should be used by nvcc as the host compiler.',
+ check_success=os.path.exists,
+ error_msg='Invalid gcc path. %s cannot be found.',
+ )
+
+ write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path)
+
+
+def reformat_version_sequence(version_str, sequence_count):
+ """Reformat the version string to have the given number of sequences.
+
+ For example:
+ Given (7, 2) -> 7.0
+ (7.0.1, 2) -> 7.0
+ (5, 1) -> 5
+ (5.0.3.2, 1) -> 5
+
+ Args:
+ version_str: String, the version string.
+ sequence_count: int, an integer.
+
+ Returns:
+ string, reformatted version string.
+ """
+ v = version_str.split('.')
+ if len(v) < sequence_count:
+ v = v + (['0'] * (sequence_count - len(v)))
+
+ return '.'.join(v[:sequence_count])
+
+
+def set_host_cxx_compiler(environ_cp):
+ """Set HOST_CXX_COMPILER."""
+ default_cxx_host_compiler = which('g++') or ''
+
+ host_cxx_compiler = prompt_loop_or_load_from_env(
+ environ_cp,
+ var_name='HOST_CXX_COMPILER',
+ var_default=default_cxx_host_compiler,
+ ask_for_var=('Please specify which C++ compiler should be used as the '
+ 'host C++ compiler.'),
+ check_success=os.path.exists,
+ error_msg='Invalid C++ compiler path. %s cannot be found.',
+ )
+
+ write_action_env_to_bazelrc('HOST_CXX_COMPILER', host_cxx_compiler)
+
+
+def set_host_c_compiler(environ_cp):
+ """Set HOST_C_COMPILER."""
+ default_c_host_compiler = which('gcc') or ''
+
+ host_c_compiler = prompt_loop_or_load_from_env(
+ environ_cp,
+ var_name='HOST_C_COMPILER',
+ var_default=default_c_host_compiler,
+ ask_for_var=('Please specify which C compiler should be used as the host '
+ 'C compiler.'),
+ check_success=os.path.exists,
+ error_msg='Invalid C compiler path. %s cannot be found.',
+ )
+
+ write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler)
+
+
+def set_opencl_sdk_root(environ_cp):
+ """Set OPENCL SDK ROOT"""
+
+ def toolkit_exists(toolkit_path):
+ """Check if a CL header path is valid."""
+ if toolkit_path == '':
+ return True
+
+ if is_linux():
+ cl_header_path = 'opencl/SDK/include/CL/cl.h'
+ else:
+ cl_header_path = ''
+
+ cl_path_full = os.path.join(toolkit_path, cl_header_path)
+ exists = os.path.exists(cl_path_full)
+ if not exists:
+ print('Invalid OPENCL SDK ROOT path. %s cannot be found' %
+ (cl_path_full))
+ return exists
+
+ ocl_sdk_root = prompt_loop_or_load_from_env(
+ environ_cp,
+ var_name='OCL_SDK_ROOT',
+ var_default=_DEFAULT_OCL_SDK_ROOT,
+ ask_for_var=(
+ 'Please specify the location of opencl SDK install path '
+ 'for ocl headers and libOpenCL.so'),
+ check_success=toolkit_exists,
+ error_msg='Invalid OPENCL SDK ROOT path.',
+ suppress_default_error=True)
+
+ write_action_env_to_bazelrc('OCL_SDK_ROOT',
+ ocl_sdk_root)
+
+def set_gcc_toolchain_path(environ_cp):
+ """Set GCC_TOOLCHAIN_PATH."""
+ def no_check(arg):
+ return True
+
+ gcc_toolchain_path = prompt_loop_or_load_from_env(
+ environ_cp,
+ var_name='GCC_TOOLCHAIN_PATH',
+ var_default=_DEFAULT_GCC_TOOLCHAIN_PATH,
+ ask_for_var=(
+ 'Please specify the location of gcc toolchain used by the compiler'),
+ check_success=no_check,
+ error_msg='Invalid GCC_TOOLCHAIN path.',
+ suppress_default_error=True)
+
+ write_action_env_to_bazelrc('GCC_TOOLCHAIN_PATH',
+ gcc_toolchain_path)
+ return gcc_toolchain_path
+
+def set_gcc_toolchain_target(environ_cp, gcc_toolchain_path):
+ """Set GCC_TOOLCHAIN_TARGET."""
+ if gcc_toolchain_path == "":
+ return ""
+
+ def toolkit_exists(target):
+ """Check if a gcc toolchain-target is valid."""
+ if is_linux():
+ if target == '':
+ gcc_bin_path = 'bin/gcc'
+ else:
+ gcc_bin_path = 'bin/' + target + '-gcc'
+ else:
+ gcc_bin_path = ''
+
+ gcc_bin_path_full = os.path.join(gcc_toolchain_path, gcc_bin_path)
+ exists = os.path.exists(gcc_bin_path_full)
+ if not exists:
+ print('Invalid GCC_TOOLCHAIN path and TARGET. %s cannot be found' %
+ (gcc_bin_path_full))
+ return exists
+
+ gcc_toolchain_target = prompt_loop_or_load_from_env(
+ environ_cp,
+ var_name='GCC_TOOLCHAIN_TARGET',
+ var_default=_DEFAULT_GCC_TOOLCHAIN_TARGET,
+ ask_for_var=(
+ 'Please specify the target of gcc toolchain (e.g. x86_64-pc-linux) '
+ 'the compiler will use.'),
+ check_success=toolkit_exists,
+ error_msg='Invalid GCC_TOOLCHAIN_TARGET',
+ suppress_default_error=True)
+
+ write_action_env_to_bazelrc('GCC_TOOLCHAIN_TARGET',
+ gcc_toolchain_target)
+
+def set_mpi_home(environ_cp):
+ """Set MPI_HOME."""
+
+ default_mpi_home = which('mpirun') or which('mpiexec') or ''
+ default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home))
+
+ def valid_mpi_path(mpi_home):
+ exists = (
+ os.path.exists(os.path.join(mpi_home, 'include')) and
+ (os.path.exists(os.path.join(mpi_home, 'lib')) or
+ os.path.exists(os.path.join(mpi_home, 'lib64')) or
+ os.path.exists(os.path.join(mpi_home, 'lib32'))))
+ if not exists:
+ print(
+ 'Invalid path to the MPI Toolkit. %s or %s or %s or %s cannot be found'
+ % (os.path.join(mpi_home, 'include'),
+             os.path.join(mpi_home, 'lib'),
+             os.path.join(mpi_home, 'lib64'),
+             os.path.join(mpi_home, 'lib32'))))
+ return exists
+
+ _ = prompt_loop_or_load_from_env(
+ environ_cp,
+ var_name='MPI_HOME',
+ var_default=default_mpi_home,
+ ask_for_var='Please specify the MPI toolkit folder.',
+ check_success=valid_mpi_path,
+ error_msg='',
+ suppress_default_error=True)
+
+
+def set_other_mpi_vars(environ_cp):
+ """Set other MPI related variables."""
+ # Link the MPI header files
+ mpi_home = environ_cp.get('MPI_HOME')
+ symlink_force('%s/include/mpi.h' % mpi_home, 'third_party/mpi/mpi.h')
+
+ # Determine if we use OpenMPI or MVAPICH, these require different header files
+ # to be included here to make bazel dependency checker happy
+ if os.path.exists(os.path.join(mpi_home, 'include/mpi_portable_platform.h')):
+ symlink_force(
+ os.path.join(mpi_home, 'include/mpi_portable_platform.h'),
+ 'third_party/mpi/mpi_portable_platform.h')
+ # TODO(gunan): avoid editing files in configure
+ sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI=False',
+ 'MPI_LIB_IS_OPENMPI=True')
+ else:
+ # MVAPICH / MPICH
+ symlink_force(
+ os.path.join(mpi_home, 'include/mpio.h'), 'third_party/mpi/mpio.h')
+ symlink_force(
+ os.path.join(mpi_home, 'include/mpicxx.h'), 'third_party/mpi/mpicxx.h')
+ # TODO(gunan): avoid editing files in configure
+ sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI=True',
+ 'MPI_LIB_IS_OPENMPI=False')
+
+ if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')):
+ symlink_force(
+ os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so')
+ elif os.path.exists(os.path.join(mpi_home, 'lib64/libmpi.so')):
+ symlink_force(
+ os.path.join(mpi_home, 'lib64/libmpi.so'), 'third_party/mpi/libmpi.so')
+ elif os.path.exists(os.path.join(mpi_home, 'lib32/libmpi.so')):
+ symlink_force(
+ os.path.join(mpi_home, 'lib32/libmpi.so'), 'third_party/mpi/libmpi.so')
+
+ else:
+ raise ValueError(
+ 'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' %
+ (mpi_home, mpi_home, mpi_home))
+
+
+def set_system_libs_flag(environ_cp):
+ syslibs = environ_cp.get('TF_SYSTEM_LIBS', '')
+ if syslibs:
+ if ',' in syslibs:
+ syslibs = ','.join(sorted(syslibs.split(',')))
+ else:
+ syslibs = ','.join(sorted(syslibs.split()))
+ write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs)
+
+ if 'PREFIX' in environ_cp:
+ write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX'])
+ if 'LIBDIR' in environ_cp:
+ write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR'])
+ if 'INCLUDEDIR' in environ_cp:
+ write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR'])
+
+
+def set_windows_build_flags(environ_cp):
+ """Set Windows specific build options."""
+ # The non-monolithic build is not supported yet
+ write_to_bazelrc('build --config monolithic')
+ # Suppress warning messages
+ write_to_bazelrc('build --copt=-w --host_copt=-w')
+ # Fix winsock2.h conflicts
+ write_to_bazelrc(
+ 'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN '
+ '--copt=-DNOGDI --host_copt=-DNOGDI')
+ # Output more verbose information when something goes wrong
+ write_to_bazelrc('build --verbose_failures')
+  # The host and target platforms are the same in a Windows build, so we don't
+  # have to distinguish between them. This avoids building the same targets twice.
+ write_to_bazelrc('build --distinct_host_configuration=false')
+
+ if get_var(
+ environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline',
+ True, ('Would you like to override eigen strong inline for some C++ '
+ 'compilation to reduce the compilation time?'),
+ 'Eigen strong inline overridden.', 'Not overriding eigen strong inline, '
+ 'some compilations could take more than 20 mins.'):
+ # Due to a known MSVC compiler issue
+ # https://github.com/tensorflow/tensorflow/issues/10521
+ # Overriding eigen strong inline speeds up the compiling of
+ # conv_grad_ops_3d.cc and conv_ops_3d.cc by 20 minutes,
+ # but this also hurts the performance. Let users decide what they want.
+ write_to_bazelrc('build --define=override_eigen_strong_inline=true')
+
+
+def config_info_line(name, help_text):
+ """Helper function to print formatted help text for Bazel config options."""
+ print('\t--config=%-12s\t# %s' % (name, help_text))
+
+
+def main():
+ global _TF_WORKSPACE_ROOT
+ global _TF_BAZELRC
+ global _TF_CURRENT_BAZEL_VERSION
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--workspace',
+ type=str,
+ default=os.path.abspath(os.path.dirname(__file__)),
+ help='The absolute path to your active Bazel workspace.')
+ args = parser.parse_args()
+
+ _TF_WORKSPACE_ROOT = args.workspace
+ _TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME)
+
+  # Make a copy of os.environ so that functions that get and set environment
+  # variables operate on the copy rather than on os.environ itself.
+ environ_cp = dict(os.environ)
+
+ current_bazel_version = check_bazel_version('3.1.0', '3.7.0')
+ _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
+
+ reset_tf_configure_bazelrc()
+
+ cleanup_makefile()
+ setup_python(environ_cp)
+ create_build_configuration(environ_cp)
+
+ if is_windows():
+ environ_cp['TF_DOWNLOAD_CLANG'] = '0'
+ environ_cp['TF_NEED_MPI'] = '0'
+
+ # The numpy package on ppc64le uses OpenBLAS which has multi-threading
+ # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at
+ # runtime to allow the Tensorflow testcases which compare numpy
+ # results to Tensorflow results to succeed.
+ if is_ppc64le():
+ write_action_env_to_bazelrc('OMP_NUM_THREADS', 1)
+
+
+ set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
+ if environ_cp.get('TF_NEED_MPI') == '1':
+ set_mpi_home(environ_cp)
+ set_other_mpi_vars(environ_cp)
+
+ set_cc_opt_flags(environ_cp)
+ set_system_libs_flag(environ_cp)
+ if is_windows():
+ set_windows_build_flags(environ_cp)
+
+if __name__ == '__main__':
+ main()
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/conv_relu.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/conv_relu.py
new file mode 100644
index 000000000..276433c24
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/conv_relu.py
@@ -0,0 +1,38 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+#!/usr/bin/env python
+# coding=utf-8
+import tensorflow as tf
+import numpy as np
+tf.compat.v1.disable_eager_execution()
+a = tf.random.normal(shape=[1,10, 10, 8], dtype=tf.float32, seed=1)
+w = tf.random.normal(shape=[3, 3, 8, 4], dtype=tf.float32, seed=1)
+
+a1 = tf.random.normal(shape=[1, 10, 10, 8], dtype=tf.float32, seed=1)
+w1 = tf.random.normal(shape=[3, 3, 8, 4], dtype=tf.float32, seed=1)
+
+
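+# Run the same Relu + Conv2D graph once on the plug-in device and once on the
+# CPU so the two results can be compared at the end of the script.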
+with tf.device("/MY_DEVICE:0"):
+ b = tf.nn.relu(a)
+ c = tf.nn.conv2d(b, w, strides=[1, 1, 1, 1], padding='SAME', data_format='NHWC')
+
+with tf.device("/CPU:0"):
+ b1 = tf.nn.relu(a1)
+ c1 = tf.nn.conv2d(b1, w1, strides=[1, 1, 1, 1], padding='SAME', data_format='NHWC')
+
+
+sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=False, log_device_placement=True))
+print(sess.run(tf.reduce_all(tf.less(c - c1, 1e-5))))
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/genpip.sh b/rfcs/20200624-pluggable-device-for-tensorflow/sample/genpip.sh
new file mode 100644
index 000000000..f344dfc41
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/genpip.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+PACKAGE_PATH=$PWD
+bazel-bin/tensorflow_plugin/tools/pip_package/build_pip_package $PACKAGE_PATH
+
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/profiler_result.png b/rfcs/20200624-pluggable-device-for-tensorflow/sample/profiler_result.png
new file mode 100644
index 000000000..b10cef062
Binary files /dev/null and b/rfcs/20200624-pluggable-device-for-tensorflow/sample/profiler_result.png differ
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/relu.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/relu.py
new file mode 100644
index 000000000..7ca4574ab
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/relu.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+
+# coding=utf-8
+import tensorflow as tf
+import numpy as np
+tf.compat.v1.disable_eager_execution()
+a = tf.random.normal(shape=[10], dtype=tf.float32)
+
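+# Pin Relu to the plug-in device. Because the session below uses
+# allow_soft_placement=False, placement fails loudly if MY_DEVICE provides no
+# Relu kernel; log_device_placement=True prints each op's placement.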
+with tf.device("/MY_DEVICE:0"):
+ b = tf.nn.relu(a)
+
+sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=False, log_device_placement=True))
+print(sess.run(b))
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/BUILD
new file mode 100644
index 000000000..636069aa7
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/BUILD
@@ -0,0 +1,27 @@
+cc_binary(
+ name = "libdemo_plugin.so",
+ linkshared = True,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow_plugin/src:plugin_device",
+ "//tensorflow_plugin/src:plugin_graph",
+ "//tensorflow_plugin/src:plugin_kernel",
+ "//tensorflow_plugin/src:plugin_profiler",
+ ],
+)
+
+config_setting(
+ name = "linux_x86_64",
+ values = {"cpu": "k8"},
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "core",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow_plugin/src/device/cpu:cpu_device_impl",
+ "@local_config_tf//:tf_header_lib",
+ ],
+ alwayslink = True,
+)
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/build_config.bzl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/build_config.bzl
new file mode 100644
index 000000000..01d4a549c
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/build_config.bzl
@@ -0,0 +1,35 @@
+# Platform-specific build configurations.
+
+load("@com_google_protobuf//:protobuf.bzl", "proto_gen")
+load("//tensorflow_plugin:workspace.bzl", "clean_dep")
+load("@rules_cc//cc:defs.bzl", "cc_library")
+
+def cc_proto(name, src, deps = []):
+ native.genrule(
+ name = "%s_cc" % name,
+ outs = ["%s.pb.cc" % name, "%s.pb.h" % name],
+ cmd = "echo $(GENDIR); which $(location @com_google_protobuf//:protoc); $(location @com_google_protobuf//:protoc) --cpp_out=$(GENDIR) $<",
+ srcs = [src],
+ tools = ["@com_google_protobuf//:protoc"],
+ )
+ native.cc_library(
+ name = "%s_proto" % name,
+ srcs = ["%s.pb.cc" % name],
+ hdrs = ["%s.pb.h" % name],
+ deps = [
+ "@com_google_protobuf//:protobuf_headers",
+ "@com_google_protobuf//:protobuf",
+ ] + deps,
+ copts = ["-I$(GENDIR)"],
+ )
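+
+# Illustrative usage of cc_proto (the target name is hypothetical, not part of
+# this sample):
+#
+#   cc_proto(name = "demo", src = "demo.proto")
+#
+# This runs protoc to generate demo.pb.cc / demo.pb.h and exposes them through
+# the ":demo_proto" cc_library target.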
+
+def if_static(extra_deps = [], otherwise = []):
+ return otherwise
+
+def tf_protobuf_deps():
+ return if_static(
+ [
+ clean_dep("@com_google_protobuf//:protobuf"),
+ ],
+ otherwise = [clean_dep("@com_google_protobuf//:protobuf_headers")],
+ )
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/demo_plugin.bzl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/demo_plugin.bzl
new file mode 100644
index 000000000..865909429
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/demo_plugin.bzl
@@ -0,0 +1,56 @@
+# Return the options to use for a C++ library or binary build.
+# Uses the ":optmode" config_setting to pick the options.
+
+def if_linux_x86_64(a, otherwise = []):
+ return select({
+ "//conditons:default": otherwise,
+ })
+
+def tf_copts(android_optimization_level_override = "-O2", is_external = False):
+ # For compatibility reasons, android_optimization_level_override
+ # is currently only being set for Android.
+ # To clear this value, and allow the CROSSTOOL default
+ # to be used, pass android_optimization_level_override=None
+ return (
+ [
+ "-Wno-sign-compare",
+ "-fno-exceptions",
+ "-ftemplate-depth=900",
+ "-msse3",
+ "-pthread",
+ ]
+ )
+
+def _get_transitive_headers(hdrs, deps):
+ return depset(
+ hdrs,
+ transitive = [dep[CcInfo].compilation_context.headers for dep in deps],
+ )
+
+def _transitive_hdrs_impl(ctx):
+ outputs = _get_transitive_headers([], ctx.attr.deps)
+ return struct(files = outputs)
+
+_transitive_hdrs = rule(
+ attrs = {
+ "deps": attr.label_list(
+ allow_files = True,
+ providers = [CcInfo],
+ ),
+ },
+ implementation = _transitive_hdrs_impl,
+)
+
+def transitive_hdrs(name, deps = [], **kwargs):
+ _transitive_hdrs(name = name + "_gather", deps = deps)
+ native.filegroup(name = name, srcs = [":" + name + "_gather"])
+
+def cc_header_only_library(name, deps = [], includes = [], extra_deps = [], **kwargs):
+ _transitive_hdrs(name = name + "_gather", deps = deps)
+ native.cc_library(
+ name = name,
+ srcs = [":" + name + "_gather"],
+ hdrs = includes,
+ deps = extra_deps,
+ **kwargs
+ )
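+
+# Illustrative usage of cc_header_only_library (the target name is hypothetical):
+#
+#   cc_header_only_library(
+#       name = "tf_header_only",
+#       deps = ["@local_config_tf//:tf_header_lib"],
+#   )
+#
+# This gathers the transitive headers of `deps` into a header-only cc_library.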
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/python/__init__.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/python/__init__.py
new file mode 100644
index 000000000..8d1c8b69c
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/python/__init__.py
@@ -0,0 +1 @@
+
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/python/test.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/python/test.py
new file mode 100644
index 000000000..b5e88fec1
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/python/test.py
@@ -0,0 +1,20 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This is a module for dummy test."""
+
+import os
+
+if __name__ == '__main__':
+ print(os.path.realpath('.'))
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/BUILD
new file mode 100644
index 000000000..1a8c74193
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/BUILD
@@ -0,0 +1,61 @@
+package(
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "plugin_device",
+ srcs = ["plugin_device.cc"],
+ hdrs = ["plugin_device.h"],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow_plugin/src/device/cpu:cpu_device_impl",
+ "@local_config_tf//:tf_header_lib",
+ ],
+ alwayslink = True,
+)
+
+cc_library(
+ name = "plugin_graph",
+ srcs = ["plugin_graph.cc"],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":plugin_device",
+ "//tensorflow_plugin/src/graph:plugin_optimizer",
+ "//tensorflow_plugin/src/kernels/cpu:cpu_kernel_impl",
+ "@local_config_tf//:tf_header_lib",
+ ],
+ alwayslink = True,
+)
+
+cc_library(
+ name = "plugin_kernel",
+ srcs = ["plugin_kernel.cc"],
+ hdrs = ["plugin_kernel.h"],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":plugin_device",
+ "//tensorflow_plugin/src/kernels/cpu:cpu_kernel_impl",
+ "@local_config_tf//:tf_header_lib",
+ ],
+ alwayslink = True,
+)
+
+
+cc_library(
+ name = "plugin_profiler",
+ srcs = ["plugin_profiler.cc"],
+ hdrs = ["plugin_device.h"],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow_plugin/src/profiler/cpu:demo_profiler",
+ "@local_config_tf//:tf_header_lib",
+ ],
+ alwayslink = True,
+)
+
+
+
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/BUILD
new file mode 100644
index 000000000..6d997f51b
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/BUILD
@@ -0,0 +1,25 @@
+package(
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "cpu_device_impl",
+ srcs = ["cpu_device_plugin.cc"],
+ hdrs = [
+ "cpu_device_plugin.h",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ "@local_config_gcc//:framework_lib",
+ "@local_config_tf//:tf_header_lib",
+ ],
+ alwayslink = True,
+)
+
+exports_files(
+ srcs = [
+ "cpu_device_plugin.h",
+ ],
+ visibility = ["//visibility:public"],
+)
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/cpu_device_plugin.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/cpu_device_plugin.cc
new file mode 100644
index 000000000..db8c10b73
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/cpu_device_plugin.cc
@@ -0,0 +1,294 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+
+#include "tensorflow_plugin/src/device/cpu/cpu_device_plugin.h"
+#include <sys/sysinfo.h>
+
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+
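+// Reports the number of devices this plugin provides; the sample exposes a single device.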
+void plugin_get_device_count(const SP_Platform* platform, int* device_count,
+ TF_Status* status) {
+ *device_count = 1;
+}
+
+void plugin_create_device(const SP_Platform* platform,
+ SE_CreateDeviceParams* params,
+ TF_Status* const status) {
+ params->device->struct_size = SP_DEVICE_STRUCT_SIZE;
+ params->device->device_handle = nullptr;
+ params->device->ordinal = 0;
+}
+
+void plugin_destroy_device(const SP_Platform* platform, SP_Device* device) {
+ device->device_handle = nullptr;
+ device->ordinal = -1;
+}
+
+void plugin_create_device_fns(const SP_Platform* platform,
+ SE_CreateDeviceFnsParams* params,
+ TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
+ params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE};
+}
+void plugin_destroy_device_fns(const SP_Platform* platform,
+ SP_DeviceFns* device_fns) {}
+
+/*StreamExecutor Backend Impl*/
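+// Allocates `size` bytes of 64-byte-aligned host memory that stands in for device memory in this CPU sample.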
+void plugin_allocate(const SP_Device* device, uint64_t size,
+ int64_t memory_space, SP_DeviceMemoryBase* mem) {
+ mem->struct_size = SP_DEVICE_MEMORY_BASE_STRUCT_SIZE;
+ mem->opaque = aligned_alloc(64, size);
+ mem->size = size;
+}
+
+void plugin_deallocate(const SP_Device* device, SP_DeviceMemoryBase* mem) {
+ free(mem->opaque);
+ mem->opaque = nullptr;
+ mem->size = 0;
+}
+
+void* plugin_host_memory_allocate(const SP_Device* device, uint64_t size) {
+ void* ptr = aligned_alloc(64, size);
+ return ptr;
+}
+
+void plugin_host_memory_deallocate(const SP_Device* device, void* mem) {
+ free(mem);
+}
+
+TF_Bool plugin_get_allocator_stats(const SP_Device* device,
+ SP_AllocatorStats* stats) {
+ stats->struct_size = SP_ALLOCATORSTATS_STRUCT_SIZE;
+ stats->bytes_in_use = 123;
+ return true;
+}
+
+// TODO(plugin):Check correctness of this function
+TF_Bool plugin_device_memory_usage(const SP_Device* device, int64_t* free,
+ int64_t* total) {
+ struct sysinfo info;
+ int err = sysinfo(&info);
+ *free = info.freeram;
+ *total = info.totalram;
+
+ return (err == 0);
+}
+
+void plugin_create_stream(const SP_Device* device, SP_Stream* stream,
+ TF_Status* status) {}
+
+// Destroys SP_Stream and deallocates any underlying resources.
+void plugin_destroy_stream(const SP_Device* device, SP_Stream stream) {}
+
+void plugin_create_stream_dependency(const SP_Device* device,
+ SP_Stream dependent, SP_Stream other,
+ TF_Status* status) {}
+
+// Without blocking the device, retrieve the current stream status.
+void plugin_get_stream_status(const SP_Device* device, SP_Stream stream,
+ TF_Status* status) {}
+
+void plugin_create_event(const SP_Device* device, SP_Event* event,
+ TF_Status* status) {}
+
+// Destroy SE_Event and perform any platform-specific deallocation and
+// cleanup of an event.
+void plugin_destroy_event(const SP_Device* device, SP_Event event) {}
+
+// Requests the current status of the event from the underlying platform.
+SE_EventStatus plugin_get_event_status(const SP_Device* device,
+ SP_Event event) {
+ return SE_EVENT_COMPLETE;
+}
+
+// Inserts the specified event at the end of the specified stream.
+void plugin_record_event(const SP_Device* device, SP_Stream stream,
+ SP_Event event, TF_Status* status) {}
+
+// Wait for the specified event at the end of the specified stream.
+void plugin_wait_for_event(const SP_Device* const device, SP_Stream stream,
+ SP_Event event, TF_Status* const status) {}
+
+/*** TIMER CALLBACKS ***/
+// Creates SP_Timer. Allocates timer resources on the underlying platform
+// and initializes its internals, setting `timer` output variable. Sets
+// values in `timer_fns` struct.
+void plugin_create_timer(const SP_Device* device, SP_Timer* timer,
+ TF_Status* status) {}
+
+// Destroy timer and deallocates timer resources on the underlying platform.
+void plugin_destroy_timer(const SP_Device* device, SP_Timer timer) {}
+
+// Records a start event for an interval timer.
+void plugin_start_timer(const SP_Device* device, SP_Stream stream,
+ SP_Timer timer, TF_Status* status) {}
+
+// Records a stop event for an interval timer.
+void plugin_stop_timer(const SP_Device* device, SP_Stream stream,
+ SP_Timer timer, TF_Status* status) {}
+
+/*** MEMCPY CALLBACKS ***/
+// Enqueues a memcpy operation onto stream, with a host destination location
+// `host_dst` and a device memory source, with target size `size`.
+void plugin_memcpy_dtoh(const SP_Device* device, SP_Stream stream,
+ void* host_dst, const SP_DeviceMemoryBase* device_src,
+ uint64_t size, TF_Status* status) {
+ memcpy(host_dst, device_src->opaque, size);
+}
+
+// Enqueues a memcpy operation onto stream, with a device destination
+// location and a host memory source, with target size `size`.
+void plugin_memcpy_htod(const SP_Device* device, SP_Stream stream,
+ SP_DeviceMemoryBase* device_dst, const void* host_src,
+ uint64_t size, TF_Status* status) {
+ memcpy(device_dst->opaque, host_src, size);
+}
+
+// Enqueues a memcpy operation onto stream, with a device destination
+// location and a device memory source, with target size `size`.
+void plugin_memcpy_dtod(const SP_Device* device, SP_Stream stream,
+ SP_DeviceMemoryBase* device_dst,
+ const SP_DeviceMemoryBase* device_src, uint64_t size,
+ TF_Status* status) {
+ memcpy(device_dst->opaque, device_src->opaque, size);
+}
+
+// Blocks the caller while a data segment of the given size is
+// copied from the device source to the host destination.
+void plugin_sync_memcpy_dtoh(const SP_Device* device, void* host_dst,
+ const SP_DeviceMemoryBase* device_src,
+ uint64_t size, TF_Status* status) {
+ memcpy(host_dst, device_src->opaque, size);
+}
+
+// Blocks the caller while a data segment of the given size is
+// copied from the host source to the device destination.
+void plugin_sync_memcpy_htod(const SP_Device* device,
+ SP_DeviceMemoryBase* device_dst,
+ const void* host_src, uint64_t size,
+ TF_Status* status) {
+ memcpy(device_dst->opaque, host_src, size);
+}
+
+// Blocks the caller while a data segment of the given size is copied from the
+// device source to the device destination.
+void plugin_sync_memcpy_dtod(const SP_Device* device,
+ SP_DeviceMemoryBase* device_dst,
+ const SP_DeviceMemoryBase* device_src,
+ uint64_t size, TF_Status* status) {
+ memcpy(device_dst->opaque, device_src->opaque, size);
+}
+
+// Causes the host code to synchronously wait for the event to complete.
+void plugin_block_host_for_event(const SP_Device* device, SP_Event event,
+ TF_Status* status) {}
+
+void plugin_block_host_until_done(const SP_Device* device, SP_Stream stream,
+ TF_Status* status) {}
+
+// Synchronizes all activity occurring in the StreamExecutor's context (most
+// likely a whole device).
+void plugin_synchronize_all_activity(const SP_Device* device,
+ TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
+}
+
+// Enqueues on a stream a user-specified function to be run on the host.
+// `callback_arg` should be passed as the first argument to `callback_fn`.
+TF_Bool plugin_host_callback(const SP_Device* device, SP_Stream stream,
+ SE_StatusCallbackFn callback_fn,
+ void* callback_arg) {
+ return false;  // This sample does not enqueue host callbacks.
+}
+
+/*Timer Backer Impl*/
+uint64_t nanoseconds(SP_Timer timer) { return timer->timer_handle; }
+
+void plugin_create_timer_fns(const SP_Platform* platform,
+ SP_TimerFns* timer_fns, TF_Status* const status) {
+ timer_fns->nanoseconds = nanoseconds;
+}
+
+void plugin_destroy_timer_fns(const SP_Platform* platform,
+ SP_TimerFns* timer_fns) {}
+
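+// Populates the SP_StreamExecutor function table with the sample implementations defined above.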
+void plugin_create_stream_executor(const SP_Platform* platform,
+ SE_CreateStreamExecutorParams* params,
+ TF_Status* const status) {
+ params->stream_executor->struct_size = SP_STREAMEXECUTOR_STRUCT_SIZE;
+ params->stream_executor->allocate = plugin_allocate;
+ params->stream_executor->deallocate = plugin_deallocate;
+ params->stream_executor->host_memory_allocate = plugin_host_memory_allocate;
+ params->stream_executor->host_memory_deallocate =
+ plugin_host_memory_deallocate;
+ params->stream_executor->get_allocator_stats = plugin_get_allocator_stats;
+ params->stream_executor->device_memory_usage = plugin_device_memory_usage;
+
+ params->stream_executor->create_stream = plugin_create_stream;
+ params->stream_executor->destroy_stream = plugin_destroy_stream;
+ params->stream_executor->create_stream_dependency =
+ plugin_create_stream_dependency;
+ params->stream_executor->get_stream_status = plugin_get_stream_status;
+ params->stream_executor->create_event = plugin_create_event;
+ params->stream_executor->destroy_event = plugin_destroy_event;
+ params->stream_executor->get_event_status = plugin_get_event_status;
+ params->stream_executor->record_event = plugin_record_event;
+ params->stream_executor->wait_for_event = plugin_wait_for_event;
+ params->stream_executor->create_timer = plugin_create_timer;
+ params->stream_executor->destroy_timer = plugin_destroy_timer;
+ params->stream_executor->start_timer = plugin_start_timer;
+ params->stream_executor->stop_timer = plugin_stop_timer;
+
+ params->stream_executor->memcpy_dtoh = plugin_memcpy_dtoh;
+ params->stream_executor->memcpy_htod = plugin_memcpy_htod;
+ params->stream_executor->memcpy_dtod = plugin_memcpy_dtod;
+ params->stream_executor->sync_memcpy_dtoh = plugin_sync_memcpy_dtoh;
+ params->stream_executor->sync_memcpy_htod = plugin_sync_memcpy_htod;
+ params->stream_executor->sync_memcpy_dtod = plugin_sync_memcpy_dtod;
+
+ // TODO(plugin): Fill the function for block stream
+ params->stream_executor->block_host_until_done = plugin_block_host_until_done;
+ params->stream_executor->block_host_for_event = plugin_block_host_for_event;
+
+ params->stream_executor->synchronize_all_activity =
+ plugin_synchronize_all_activity;
+ params->stream_executor->host_callback = plugin_host_callback;
+}
+
+void plugin_destroy_stream_executor(const SP_Platform* platform,
+ SP_StreamExecutor* stream_executor) {}
+
+void plugin_destroy_platform(SP_Platform* const platform) {}
+void plugin_destroy_platform_fns(SP_PlatformFns* const platform_fns) {}
+
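+// Fills in the platform-level callbacks (device creation, stream executor, timers) used during plugin registration.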
+void SE_InitPluginFns(SE_PlatformRegistrationParams* const params,
+ TF_Status* const status) {
+ params->platform_fns->get_device_count = plugin_get_device_count;
+ params->platform_fns->create_device = plugin_create_device;
+ params->platform_fns->destroy_device = plugin_destroy_device;
+ params->platform_fns->create_device_fns = plugin_create_device_fns;
+ params->platform_fns->destroy_device_fns = plugin_destroy_device_fns;
+ params->platform_fns->create_stream_executor = plugin_create_stream_executor;
+ params->platform_fns->destroy_stream_executor =
+ plugin_destroy_stream_executor;
+ params->platform_fns->create_timer_fns = plugin_create_timer_fns;
+ params->platform_fns->destroy_timer_fns = plugin_destroy_timer_fns;
+ params->destroy_platform = plugin_destroy_platform;
+ params->destroy_platform_fns = plugin_destroy_platform_fns;
+}
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/cpu_device_plugin.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/cpu_device_plugin.h
new file mode 100644
index 000000000..e77bf369b
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/cpu_device_plugin.h
@@ -0,0 +1,38 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PLUGIN_SRC_DEVICE_CPU_H_
+#define TENSORFLOW_PLUGIN_SRC_DEVICE_CPU_H_
+#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
+#include "tensorflow/c/tf_status.h"
+
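+// Fills SE_PlatformRegistrationParams with this plugin's device and StreamExecutor callbacks.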
+void SE_InitPluginFns(SE_PlatformRegistrationParams* const params,
+ TF_Status* const status);
+
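+// Concrete definitions for the opaque handle types declared by the StreamExecutor C API.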
+struct SP_Stream_st {
+ explicit SP_Stream_st(void* stream_h) : stream_handle(stream_h) {}
+ void* stream_handle;
+};
+
+struct SP_Event_st {
+ explicit SP_Event_st(void* event_h) : event_handle(event_h) {}
+ void* event_handle;
+};
+
+struct SP_Timer_st {
+ explicit SP_Timer_st(int id) : timer_handle(id) {}
+ int timer_handle;
+};
+
+#endif // TENSORFLOW_PLUGIN_SRC_DEVICE_CPU_H_
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/BUILD
new file mode 100644
index 000000000..060eeaa89
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/BUILD
@@ -0,0 +1,30 @@
+load(
+ "//tensorflow_plugin:build_config.bzl",
+ "tf_protobuf_deps",
+)
+
+cc_library(
+ name = "tf_buffer",
+ srcs = ["tf_buffer.cc"],
+ hdrs = [
+ "tf_buffer.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@local_config_tf//:tf_header_lib",
+ ] + tf_protobuf_deps(),
+)
+
+cc_library(
+ name = "plugin_optimizer",
+ srcs = ["plugin_optimizer.cc"],
+ hdrs = ["plugin_optimizer.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow_plugin/src/utils:protos_all",
+ # "//tensorflow_plugin/src/utils:types_proto",
+ "tf_buffer",
+ "@local_config_tf//:tf_header_lib",
+ ],
+ alwayslink = True,
+)
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/plugin_optimizer.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/plugin_optimizer.cc
new file mode 100644
index 000000000..e24e4b808
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/plugin_optimizer.cc
@@ -0,0 +1,53 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+
+#include "tensorflow_plugin/src/graph/plugin_optimizer.h"
+#include "tensorflow/c/experimental/grappler/grappler.h"
+#include "tensorflow_plugin/src/graph/tf_buffer.h"
+#include "tensorflow_plugin/src/utils/graph.pb.h"
+
+namespace demo_plugin {
+namespace graph {
+
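+// Creates the optimizer instance that TensorFlow passes back as an opaque handle.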
+void *Optimizer_Create() {
+ auto *optimizer = new Optimizer;
+ return reinterpret_cast<void *>(optimizer);
+}
+
+void Optimizer_Destroy(void *optimizer) {
+ if (optimizer)
+ delete reinterpret_cast<Optimizer *>(optimizer);
+}
+
+void Optimizer_Optimize(void *optimizer, const TF_Buffer *graph_buf,
+ const TF_GrapplerItem *item,
+ TF_Buffer *optimized_graph_buf, TF_Status *tf_status) {
+ // Deserialize graph_buf into GraphDef.
+ GraphDef graph_def;
+ BufferToMessage(graph_buf, graph_def, tf_status);
+ if (TF_GetCode(tf_status) != TF_OK)
+ return;
+
+ // Perform graph transformations here; this sample returns the graph unchanged.
+ GraphDef optimized_graph_def = graph_def;
+
+ // Serialize output GraphDef into optimized_graph_buf.
+ MessageToBuffer(optimized_graph_def, optimized_graph_buf, tf_status);
+ if (TF_GetCode(tf_status) != TF_OK)
+ return;
+}
+
+} // namespace graph
+} // namespace demo_plugin
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/plugin_optimizer.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/plugin_optimizer.h
new file mode 100644
index 000000000..55b73cf25
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/plugin_optimizer.h
@@ -0,0 +1,38 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+
+#ifndef TENSORFLOW_PLUGIN_SRC_GRAPH_PLUGIN_OPTIMIZER_H_
+#define TENSORFLOW_PLUGIN_SRC_GRAPH_PLUGIN_OPTIMIZER_H_
+
+#include "tensorflow/c/experimental/grappler/grappler.h"
+
+namespace demo_plugin {
+namespace graph {
+
+typedef struct Optimizer {
+} Optimizer;
+
+void *Optimizer_Create();
+
+void Optimizer_Destroy(void *optimizer);
+
+void Optimizer_Optimize(void *optimizer, const TF_Buffer *graph_buf,
+ const TF_GrapplerItem *item,
+ TF_Buffer *optimized_graph_buf, TF_Status *tf_status);
+
+} // namespace graph
+} // namespace demo_plugin
+
+#endif // TENSORFLOW_PLUGIN_SRC_GRAPH_PLUGIN_OPTIMIZER_H_
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/tf_buffer.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/tf_buffer.cc
new file mode 100644
index 000000000..88be206e2
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/tf_buffer.cc
@@ -0,0 +1,57 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+
+#include "tensorflow_plugin/src/graph/tf_buffer.h"
+#include "tensorflow/c/tf_status.h"
+
+namespace demo_plugin {
+
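+// Serializes `in` into `out`, allocating a buffer that the TF_Buffer deallocator will later free.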
+void MessageToBuffer(const demo_plugin::protobuf::MessageLite& in,
+ TF_Buffer* out, TF_Status* status) {
+ if (out->data != nullptr) {
+ TF_SetStatus(status, TF_Code::TF_INVALID_ARGUMENT,
+ "Passing non-empty TF_Buffer is invalid.");
+ return;
+ }
+ const size_t proto_size = in.ByteSizeLong();
+ void* buf = malloc(proto_size);
+ if (buf == nullptr) {
+ TF_SetStatus(status, TF_Code::TF_RESOURCE_EXHAUSTED,
+ "Failed to allocate memory to serialize message.");
+ return;
+ }
+ if (!in.SerializeWithCachedSizesToArray(static_cast<uint8_t*>(buf))) {
+ free(buf);
+ TF_SetStatus(status, TF_Code::TF_INVALID_ARGUMENT,
+ "Unable to serialize protocol buffer.");
+ return;
+ }
+ out->data = buf;
+ out->length = proto_size;
+ out->data_deallocator = [](void* data, size_t length) { free(data); };
+ TF_SetStatus(status, TF_Code::TF_OK, "");
+}
+
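+// Parses the serialized bytes held by `in` into the protobuf message `out`.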
+void BufferToMessage(const TF_Buffer* in,
+ demo_plugin::protobuf::MessageLite& out,
+ TF_Status* status) {
+ if (in == nullptr || !out.ParseFromArray(in->data, in->length)) {
+ TF_SetStatus(status, TF_Code::TF_INVALID_ARGUMENT, "Unparsable proto.");
+ return;
+ }
+ TF_SetStatus(status, TF_Code::TF_OK, "");
+}
+
+} // namespace demo_plugin
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/tf_buffer.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/tf_buffer.h
new file mode 100644
index 000000000..745330f24
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/tf_buffer.h
@@ -0,0 +1,54 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+
+#ifndef TENSORFLOW_PLUGIN_SRC_GRAPH_UTILS_TF_BUFFER_H_
+#define TENSORFLOW_PLUGIN_SRC_GRAPH_UTILS_TF_BUFFER_H_
+
+#include "tensorflow/c/c_api.h"
+
+// Import whatever namespace protobuf comes from into the
+// ::tensorflow::protobuf namespace.
+//
+// TensorFlow code should use the ::tensorflow::protobuf namespace to
+// refer to all protobuf APIs.
+
+#include "google/protobuf/arena.h"
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/descriptor.pb.h"
+#include "google/protobuf/dynamic_message.h"
+#include "google/protobuf/io/coded_stream.h"
+#include "google/protobuf/io/tokenizer.h"
+#include "google/protobuf/io/zero_copy_stream.h"
+#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
+#include "google/protobuf/map.h"
+#include "google/protobuf/message.h"
+#include "google/protobuf/repeated_field.h"
+#include "google/protobuf/text_format.h"
+#include "google/protobuf/util/json_util.h"
+#include "google/protobuf/util/type_resolver_util.h"
+
+namespace demo_plugin {
+
+namespace protobuf = ::google::protobuf;
+
+void MessageToBuffer(const demo_plugin::protobuf::MessageLite& in,
+ TF_Buffer* out, TF_Status* status);
+
+void BufferToMessage(const TF_Buffer* in,
+ demo_plugin::protobuf::MessageLite& out,
+ TF_Status* status);
+} // namespace demo_plugin
+
+#endif // TENSORFLOW_PLUGIN_SRC_GRAPH_UTILS_TF_BUFFER_H_
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/BUILD
new file mode 100644
index 000000000..790abc431
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/BUILD
@@ -0,0 +1,53 @@
+load("//tensorflow_plugin:demo_plugin.bzl", "tf_copts")
+
+package(
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "relu_op",
+ srcs = ["relu_op.cc"],
+ copts = tf_copts(),
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow_plugin:core",
+ "@com_google_absl//absl/container:inlined_vector",
+ ],
+ alwayslink = True,
+)
+
+cc_library(
+ name = "conv2d_op",
+ srcs = ["conv_ops_using_gemm.cc"],
+ hdrs = ["gemm_functors.h"],
+ copts = tf_copts(),
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow_plugin:core",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@eigen_archive//:eigen",
+ "//third_party/eigen3"
+ ],
+ alwayslink = True,
+)
+
+CPU_KERNELS = [
+ ":relu_op",
+ ":conv2d_op",
+]
+
+cc_library(
+ name = "cpu_kernel_impl",
+ srcs = ["cpu_kernel_init.cc"],
+ hdrs = [
+ "cpu_kernel_init.h",
+ "//tensorflow_plugin/src/device/cpu:cpu_device_plugin.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow_plugin:core",
+ ] + CPU_KERNELS,
+ alwayslink = True,
+)
diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/conv_ops_using_gemm.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/conv_ops_using_gemm.cc
new file mode 100644
index 000000000..f1758d192
--- /dev/null
+++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/conv_ops_using_gemm.cc
@@ -0,0 +1,623 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+
+//#define EIGEN_USE_THREADS
+
+#include