diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/Xspace.png b/rfcs/20200624-pluggable-device-for-tensorflow/Xspace.png new file mode 100644 index 000000000..d79249427 Binary files /dev/null and b/rfcs/20200624-pluggable-device-for-tensorflow/Xspace.png differ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/flow.png b/rfcs/20200624-pluggable-device-for-tensorflow/flow.png new file mode 100644 index 000000000..59b61575d Binary files /dev/null and b/rfcs/20200624-pluggable-device-for-tensorflow/flow.png differ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/modular_TensorFlow.png b/rfcs/20200624-pluggable-device-for-tensorflow/modular_TensorFlow.png new file mode 100644 index 000000000..39a1c26e7 Binary files /dev/null and b/rfcs/20200624-pluggable-device-for-tensorflow/modular_TensorFlow.png differ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/profiler_result.png b/rfcs/20200624-pluggable-device-for-tensorflow/profiler_result.png new file mode 100644 index 000000000..b10cef062 Binary files /dev/null and b/rfcs/20200624-pluggable-device-for-tensorflow/profiler_result.png differ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/.bazelrc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/.bazelrc new file mode 100644 index 000000000..ea5676684 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/.bazelrc @@ -0,0 +1,16 @@ +build --define=use_fast_cpp_protos=true +build --define=allow_oversize_protos=true + +build --spawn_strategy=standalone +# build --strategy=Genrule=standalone +build -c opt + +# Default paths for TF_SYSTEM_LIBS +build --define=PREFIX=/usr +build --define=LIBDIR=$(PREFIX)/lib +build --define=INCLUDEDIR=$(PREFIX)/include + +# host build is useless +build --distinct_host_configuration=false + +try-import %workspace%/.tf_plugin_configure.bazelrc diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/README.md b/rfcs/20200624-pluggable-device-for-tensorflow/sample/README.md new file mode 100644 index 000000000..002eb14ca --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/README.md @@ -0,0 +1,96 @@ +# TensorFlow Plugin demo +This sample is a simple demo shows how to implement, build, install and run a TensorFlow plugin. + +## Supported OS +* Linux + +## Prerequisites + +* [Bazel](https://docs.bazel.build/versions/master/install-ubuntu.html) (version 3.1 and above) +* Git (version 1.8 and above) +* Python (version 3.6 and above) + +## Build and Run + +### Linux +1. Run the following command to install the latest `tensorflow`. +``` +$ pip install tensorflow +``` +2. In the plug-in `sample` code folder, configure the build options: +``` +$ ./configure + +Please specify the location of python. [Default is /home/test/miniconda2/envs/sycl3.6/bin/python]: + + +Found possible Python library paths: + /home/test/miniconda2/envs/sycl3.6/lib/python3.6/site-packages +Please input the desired Python library path to use. Default is [/home/test/miniconda2/envs/sycl3.6/lib/python3.6/site-packages] + +Do you wish to build TensorFlow plug-in with MPI support? [y/N]: +No MPI support will be enabled for TensorFlow plug-in. + +Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]: +``` + +3. Built the plug-in with +``` +$ bazel build -c opt //tensorflow_plugin/tools/pip_package:build_pip_package --verbose_failures +``` +4. Then generate a python wheel and install it. +``` +$ bazel-bin/tensorflow_plugin/tools/pip_package/build_pip_package . +$ pip install tensorflow_plugins-0.0.1-cp36-cp36m-linux_x86_64.whl +``` +5. Now we can run the TensorFlow with plug-in device enabled. +``` +$ python +>>> import tensorflow as tf +>>> tf.config.list_physical_devices() +[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:MY_DEVICE:0', device_type='MY_DEVICE')] +``` +* Relu case: +``` +$ python relu.py +random_normal/RandomStandardNormal: (RandomStandardNormal): /job:localhost/replica:0/task:0/device:CPU:0 +2021-10-21 12:48:20.714819: I tensorflow/core/common_runtime/placer.cc:114] random_normal/RandomStandardNormal: (RandomStandardNormal): /job:localhost/replica:0/task:0/device:CPU:0 +random_normal/mul: (Mul): /job:localhost/replica:0/task:0/device:CPU:0 +2021-10-21 12:48:20.714864: I tensorflow/core/common_runtime/placer.cc:114] random_normal/mul: (Mul): /job:localhost/replica:0/task:0/device:CPU:0 +random_normal: (AddV2): /job:localhost/replica:0/task:0/device:CPU:0 +2021-10-21 12:48:20.714903: I tensorflow/core/common_runtime/placer.cc:114] random_normal: (AddV2): /job:localhost/replica:0/task:0/device:CPU:0 +Relu: (Relu): /job:localhost/replica:0/task:0/device:MY_DEVICE:0 +2021-10-21 12:48:20.714937: I tensorflow/core/common_runtime/placer.cc:114] Relu: (Relu): /job:localhost/replica:0/task:0/device:MY_DEVICE:0 +random_normal/shape: (Const): /job:localhost/replica:0/task:0/device:CPU:0 +2021-10-21 12:48:20.714968: I tensorflow/core/common_runtime/placer.cc:114] random_normal/shape: (Const): /job:localhost/replica:0/task:0/device:CPU:0 +random_normal/mean: (Const): /job:localhost/replica:0/task:0/device:CPU:0 +2021-10-21 12:48:20.714997: I tensorflow/core/common_runtime/placer.cc:114] random_normal/mean: (Const): /job:localhost/replica:0/task:0/device:CPU:0 +random_normal/stddev: (Const): /job:localhost/replica:0/task:0/device:CPU:0 +2021-10-21 12:48:20.715022: I tensorflow/core/common_runtime/placer.cc:114] random_normal/stddev: (Const): /job:localhost/replica:0/task:0/device:CPU:0 +[2.9109507 0. 0. 0. 0. 0. 0. + 0. 0. 1.316411 ] + +``` +* Conv + Relu case: +``` +$ python conv_relu.py +2021-10-21 12:53:36.389514: I tensorflow/core/common_runtime/placer.cc:114] random_normal_3/mul: (Mul): /job:localhost/replica:0/task:0/device:CPU:0 +random_normal_3: (AddV2): /job:localhost/replica:0/task:0/device:CPU:0 +2021-10-21 12:53:36.389537: I tensorflow/core/common_runtime/placer.cc:114] random_normal_3: (AddV2): /job:localhost/replica:0/task:0/device:CPU:0 +Relu: (Relu): /job:localhost/replica:0/task:0/device:MY_DEVICE:0 +2021-10-21 12:53:36.389565: I tensorflow/core/common_runtime/placer.cc:114] Relu: (Relu): /job:localhost/replica:0/task:0/device:MY_DEVICE:0 +Conv2D: (Conv2D): /job:localhost/replica:0/task:0/device:MY_DEVICE:0 +2021-10-21 12:53:36.389592: I tensorflow/core/common_runtime/placer.cc:114] Conv2D: (Conv2D): /job:localhost/replica:0/task:0/device:MY_DEVICE:0 +Relu_1: (Relu): /job:localhost/replica:0/task:0/device:CPU:0 +2021-10-21 12:53:36.389617: I tensorflow/core/common_runtime/placer.cc:114] Relu_1: (Relu): /job:localhost/replica:0/task:0/device:CPU:0 +Conv2D_1: (Conv2D): /job:localhost/replica:0/task:0/device:CPU:0 +2021-10-21 12:53:36.389641: I tensorflow/core/common_runtime/placer.cc:114] Conv2D_1: (Conv2D): /job:localhost/replica:0/task:0/device:CPU:0 +``` +* Profiler case: +``` +$python test_profiler.py +``` +
+ +
+ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/WORKSPACE b/rfcs/20200624-pluggable-device-for-tensorflow/sample/WORKSPACE new file mode 100644 index 000000000..1fe8eb76a --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/WORKSPACE @@ -0,0 +1,21 @@ +workspace(name = "org_tensorflow_plugin") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file") +load("//third_party:version_check.bzl", "check_bazel_version_at_least") + +check_bazel_version_at_least("3.1.0") + +load("//tensorflow_plugin:tf_configure.bzl", "tf_configure") + +tf_configure(name = "local_config_tf") + +load("//tensorflow_plugin:workspace.bzl", "clean_dep", "demo_plugin_workspace") + +demo_plugin_workspace() + +load( + "@bazel_toolchains//repositories:repositories.bzl", + bazel_toolchains_repositories = "repositories", +) + +bazel_toolchains_repositories() diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/build.sh b/rfcs/20200624-pluggable-device-for-tensorflow/sample/build.sh new file mode 100644 index 000000000..30cb29095 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/build.sh @@ -0,0 +1,2 @@ +#!/bin/bash +bazel build -c opt //tensorflow_plugin/tools/pip_package:build_pip_package --verbose_failures diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/configure b/rfcs/20200624-pluggable-device-for-tensorflow/sample/configure new file mode 100755 index 000000000..66b66ba54 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/configure @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +set -e +set -o pipefail + +if [ -z "$PYTHON_BIN_PATH" ]; then + PYTHON_BIN_PATH=$(which python || which python3 || true) +fi + +# Set all env variables +CONFIGURE_DIR=$(dirname "$0") +"$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@" + +echo "Configuration finished" + diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/configure.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/configure.py new file mode 100644 index 000000000..b4fdb36a7 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/configure.py @@ -0,0 +1,1127 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""configure script to get build parameters from user.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import errno +import os +import platform +import re +import subprocess +import sys + +# pylint: disable=g-import-not-at-top +try: + from shutil import which +except ImportError: + from distutils.spawn import find_executable as which +# pylint: enable=g-import-not-at-top + + +_DEFAULT_GCC_TOOLCHAIN_PATH = '' +_DEFAULT_GCC_TOOLCHAIN_TARGET = '' + +_DEFAULT_PROMPT_ASK_ATTEMPTS = 10 + +_TF_BAZELRC_FILENAME = '.tf_plugin_configure.bazelrc' +_TF_WORKSPACE_ROOT = '' +_TF_BAZELRC = '' +_TF_CURRENT_BAZEL_VERSION = None + +NCCL_LIB_PATHS = [ + 'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', '' +] + + +class UserInputError(Exception): + pass + + +def is_windows(): + return platform.system() == 'Windows' + + +def is_linux(): + return platform.system() == 'Linux' + + +def is_macos(): + return platform.system() == 'Darwin' + + +def is_ppc64le(): + return platform.machine() == 'ppc64le' + + +def is_cygwin(): + return platform.system().startswith('CYGWIN_NT') + + +def get_input(question): + try: + try: + answer = raw_input(question) + except NameError: + answer = input(question) # pylint: disable=bad-builtin + except EOFError: + answer = '' + return answer + + +def symlink_force(target, link_name): + """Force symlink, equivalent of 'ln -sf'. + + Args: + target: items to link to. + link_name: name of the link. + """ + try: + os.symlink(target, link_name) + except OSError as e: + if e.errno == errno.EEXIST: + os.remove(link_name) + os.symlink(target, link_name) + else: + raise e + + +def sed_in_place(filename, old, new): + """Replace old string with new string in file. + + Args: + filename: string for filename. + old: string to replace. + new: new string to replace to. + """ + with open(filename, 'r') as f: + filedata = f.read() + newdata = filedata.replace(old, new) + with open(filename, 'w') as f: + f.write(newdata) + + +def write_to_bazelrc(line): + with open(_TF_BAZELRC, 'a') as f: + f.write(line + '\n') + + +def write_action_env_to_bazelrc(var_name, var): + write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var))) + + +def run_shell(cmd, allow_non_zero=False): + if allow_non_zero: + try: + output = subprocess.check_output(cmd) + except subprocess.CalledProcessError as e: + output = e.output + else: + output = subprocess.check_output(cmd) + return output.decode('UTF-8').strip() + + +def cygpath(path): + """Convert path from posix to windows.""" + return os.path.abspath(path).replace('\\', '/') + + +def get_python_path(environ_cp, python_bin_path): + """Get the python site package paths.""" + python_paths = [] + if environ_cp.get('PYTHONPATH'): + python_paths = environ_cp.get('PYTHONPATH').split(':') + try: + library_paths = run_shell([ + python_bin_path, '-c', + 'import site; print("\\n".join(site.getsitepackages()))' + ]).split('\n') + except subprocess.CalledProcessError: + library_paths = [ + run_shell([ + python_bin_path, '-c', + 'from distutils.sysconfig import get_python_lib;' + 'print(get_python_lib())' + ]) + ] + + all_paths = set(python_paths + library_paths) + + paths = [] + for path in all_paths: + if os.path.isdir(path): + paths.append(path) + return paths + + +def get_python_major_version(python_bin_path): + """Get the python major version.""" + return run_shell([python_bin_path, '-c', 'import sys; print(sys.version[0])']) + + +def setup_python(environ_cp): + """Setup python related env variables.""" + # Get PYTHON_BIN_PATH, default is the current running python. + default_python_bin_path = sys.executable + ask_python_bin_path = ('Please specify the location of python. [Default is ' + '%s]: ') % default_python_bin_path + while True: + python_bin_path = get_from_env_or_user_or_default(environ_cp, + 'PYTHON_BIN_PATH', + ask_python_bin_path, + default_python_bin_path) + # Check if the path is valid + if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK): + break + elif not os.path.exists(python_bin_path): + print('Invalid python path: %s cannot be found.' % python_bin_path) + else: + print('%s is not executable. Is it the python binary?' % python_bin_path) + environ_cp['PYTHON_BIN_PATH'] = '' + + # Convert python path to Windows style before checking lib and version + if is_windows() or is_cygwin(): + python_bin_path = cygpath(python_bin_path) + + # Get PYTHON_LIB_PATH + python_lib_path = environ_cp.get('PYTHON_LIB_PATH') + if not python_lib_path: + python_lib_paths = get_python_path(environ_cp, python_bin_path) + if environ_cp.get('USE_DEFAULT_PYTHON_LIB_PATH') == '1': + python_lib_path = python_lib_paths[0] + else: + print('Found possible Python library paths:\n %s' % + '\n '.join(python_lib_paths)) + default_python_lib_path = python_lib_paths[0] + python_lib_path = get_input( + 'Please input the desired Python library path to use. ' + 'Default is [%s]\n' % python_lib_paths[0]) + if not python_lib_path: + python_lib_path = default_python_lib_path + environ_cp['PYTHON_LIB_PATH'] = python_lib_path + + _ = get_python_major_version(python_bin_path) + + # Convert python path to Windows style before writing into bazel.rc + if is_windows() or is_cygwin(): + python_lib_path = cygpath(python_lib_path) + + # Set-up env variables used by python_configure.bzl + write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) + write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) + write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) + environ_cp['PYTHON_BIN_PATH'] = python_bin_path + + # If choosen python_lib_path is from a path specified in the PYTHONPATH + # variable, need to tell bazel to include PYTHONPATH + if environ_cp.get('PYTHONPATH'): + python_paths = environ_cp.get('PYTHONPATH').split(':') + if python_lib_path in python_paths: + write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH')) + + # Write tools/python_bin_path.sh + with open( + os.path.join(_TF_WORKSPACE_ROOT, 'tensorflow_plugin', 'tools', 'python_bin_path.sh'), + 'w') as f: + f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) + +def get_python_lib_name(environ_cp): + python_bin_path = environ_cp['PYTHON_BIN_PATH'] + path_list = python_bin_path.split(os.sep)[:-2] + path_list.append('lib') + py_lib_path = os.sep.join(path_list) + for _, _, files in os.walk(py_lib_path): + for name in files: + if str(name).startswith('libpython') and str(name).endswith('.so'): + # strip libxxx.so to get xxx + return str(name).strip()[3:-3] + + +def get_python_link_path(environ_cp): + # TODO(quintin): we need to link libpythonx.y.so for _pywrap_tensorflow_internal.so + # once google change CAPI symbols into libtensorflow.so, we don't need this + python_bin_path = environ_cp['PYTHON_BIN_PATH'] + path_list = python_bin_path.split(os.sep)[:-2] + path_list.append('lib') + py_lib_path = os.sep.join(path_list) + return py_lib_path + +def create_build_configuration(environ_cp): + + tf_header_dir = environ_cp['PYTHON_LIB_PATH'] + "/tensorflow/include" + tf_shared_lib_dir = environ_cp['PYTHON_LIB_PATH'] + "/tensorflow/" + + write_action_env_to_bazelrc("TF_HEADER_DIR", tf_header_dir) + write_action_env_to_bazelrc("TF_SHARED_LIBRARY_DIR", tf_shared_lib_dir) + write_action_env_to_bazelrc("TF_CXX11_ABI_FLAG", 1) + write_action_env_to_bazelrc("PYTHON_LINK_LIB_NAME", get_python_lib_name(environ_cp)) + write_action_env_to_bazelrc("PYTHON_LINK_PATH", get_python_link_path(environ_cp)) + + +def reset_tf_configure_bazelrc(): + """Reset file that contains customized config settings.""" + open(_TF_BAZELRC, 'w').close() + + +def cleanup_makefile(): + """Delete any leftover BUILD files from the Makefile build. + + These files could interfere with Bazel parsing. + """ + makefile_download_dir = os.path.join(_TF_WORKSPACE_ROOT, 'tensorflow', + 'contrib', 'makefile', 'downloads') + if os.path.isdir(makefile_download_dir): + for root, _, filenames in os.walk(makefile_download_dir): + for f in filenames: + if f.endswith('BUILD'): + os.remove(os.path.join(root, f)) + + +def get_var(environ_cp, + var_name, + query_item, + enabled_by_default, + question=None, + yes_reply=None, + no_reply=None): + """Get boolean input from user. + + If var_name is not set in env, ask user to enable query_item or not. If the + response is empty, use the default. + + Args: + environ_cp: copy of the os.environ. + var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". + query_item: string for feature related to the variable, e.g. "CUDA for + Nvidia GPUs". + enabled_by_default: boolean for default behavior. + question: optional string for how to ask for user input. + yes_reply: optional string for reply when feature is enabled. + no_reply: optional string for reply when feature is disabled. + + Returns: + boolean value of the variable. + + Raises: + UserInputError: if an environment variable is set, but it cannot be + interpreted as a boolean indicator, assume that the user has made a + scripting error, and will continue to provide invalid input. + Raise the error to avoid infinitely looping. + """ + if not question: + question = 'Do you wish to build TensorFlow plug-in with %s support?' % query_item + if not yes_reply: + yes_reply = '%s support will be enabled for TensorFlow plug-in.' % query_item + if not no_reply: + no_reply = 'No %s' % yes_reply + + yes_reply += '\n' + no_reply += '\n' + + if enabled_by_default: + question += ' [Y/n]: ' + else: + question += ' [y/N]: ' + + var = environ_cp.get(var_name) + if var is not None: + var_content = var.strip().lower() + true_strings = ('1', 't', 'true', 'y', 'yes') + false_strings = ('0', 'f', 'false', 'n', 'no') + if var_content in true_strings: + var = True + elif var_content in false_strings: + var = False + else: + raise UserInputError( + 'Environment variable %s must be set as a boolean indicator.\n' + 'The following are accepted as TRUE : %s.\n' + 'The following are accepted as FALSE: %s.\n' + 'Current value is %s.' % + (var_name, ', '.join(true_strings), ', '.join(false_strings), var)) + + while var is None: + user_input_origin = get_input(question) + user_input = user_input_origin.strip().lower() + if user_input == 'y': + print(yes_reply) + var = True + elif user_input == 'n': + print(no_reply) + var = False + elif not user_input: + if enabled_by_default: + print(yes_reply) + var = True + else: + print(no_reply) + var = False + else: + print('Invalid selection: %s' % user_input_origin) + return var + + +def set_build_var(environ_cp, + var_name, + query_item, + option_name, + enabled_by_default, + bazel_config_name=None): + """Set if query_item will be enabled for the build. + + Ask user if query_item will be enabled. Default is used if no input is given. + Set subprocess environment variable and write to .bazelrc if enabled. + + Args: + environ_cp: copy of the os.environ. + var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". + query_item: string for feature related to the variable, e.g. "CUDA for + Nvidia GPUs". + option_name: string for option to define in .bazelrc. + enabled_by_default: boolean for default behavior. + bazel_config_name: Name for Bazel --config argument to enable build feature. + """ + + var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default))) + environ_cp[var_name] = var + if var == '1': + write_to_bazelrc('build:%s --define %s=true' % + (bazel_config_name, option_name)) + write_to_bazelrc('build --config=%s' % bazel_config_name) + elif bazel_config_name is not None: + # TODO(mikecase): Migrate all users of configure.py to use --config Bazel + # options and not to set build configs through environment variables. + write_to_bazelrc('build:%s --define %s=true' % + (bazel_config_name, option_name)) + + +def set_action_env_var(environ_cp, + var_name, + query_item, + enabled_by_default, + question=None, + yes_reply=None, + no_reply=None): + """Set boolean action_env variable. + + Ask user if query_item will be enabled. Default is used if no input is given. + Set environment variable and write to .bazelrc. + + Args: + environ_cp: copy of the os.environ. + var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". + query_item: string for feature related to the variable, e.g. "CUDA for + Nvidia GPUs". + enabled_by_default: boolean for default behavior. + question: optional string for how to ask for user input. + yes_reply: optional string for reply when feature is enabled. + no_reply: optional string for reply when feature is disabled. + """ + var = int( + get_var(environ_cp, var_name, query_item, enabled_by_default, question, + yes_reply, no_reply)) + + write_action_env_to_bazelrc(var_name, var) + environ_cp[var_name] = str(var) + + +def convert_version_to_int(version): + """Convert a version number to a integer that can be used to compare. + + Version strings of the form X.YZ and X.Y.Z-xxxxx are supported. The + 'xxxxx' part, for instance 'homebrew' on OS/X, is ignored. + + Args: + version: a version to be converted + + Returns: + An integer if converted successfully, otherwise return None. + """ + version = version.split('-')[0] + version_segments = version.split('.') + # Treat "0.24" as "0.24.0" + if len(version_segments) == 2: + version_segments.append('0') + for seg in version_segments: + if not seg.isdigit(): + return None + + version_str = ''.join(['%03d' % int(seg) for seg in version_segments]) + return int(version_str) + + +def check_bazel_version(min_version, max_version): + """Check installed bazel version is between min_version and max_version. + + Args: + min_version: string for minimum bazel version (must exist!). + max_version: string for maximum bazel version (must exist!). + + Returns: + The bazel version detected. + """ + if which('bazel') is None: + print('Cannot find bazel. Please install bazel.') + sys.exit(0) + curr_version = run_shell( + ['bazel', '--batch', '--bazelrc=/dev/null', 'version']) + + for line in curr_version.split('\n'): + if 'Build label: ' in line: + curr_version = line.split('Build label: ')[1] + break + + min_version_int = convert_version_to_int(min_version) + curr_version_int = convert_version_to_int(curr_version) + max_version_int = convert_version_to_int(max_version) + + # Check if current bazel version can be detected properly. + if not curr_version_int: + print('WARNING: current bazel installation is not a release version.') + print('Make sure you are running at least bazel %s' % min_version) + return curr_version + + print('You have bazel %s installed.' % curr_version) + + if curr_version_int < min_version_int: + print('Please upgrade your bazel installation to version %s or higher to ' + 'build TensorFlow!' % min_version) + sys.exit(1) + if (curr_version_int > max_version_int and + 'TF_IGNORE_MAX_BAZEL_VERSION' not in os.environ): + print('Please downgrade your bazel installation to version %s or lower to ' + 'build TensorFlow! To downgrade: download the installer for the old ' + 'version (from https://github.com/bazelbuild/bazel/releases) then ' + 'run the installer.' % max_version) + sys.exit(1) + return curr_version + + +def set_cc_opt_flags(environ_cp): + """Set up architecture-dependent optimization flags. + + Also append CC optimization flags to bazel.rc.. + + Args: + environ_cp: copy of the os.environ. + """ + if is_ppc64le(): + # gcc on ppc64le does not support -march, use mcpu instead + default_cc_opt_flags = '-mcpu=native' + elif is_windows(): + default_cc_opt_flags = '/arch:AVX' + else: + default_cc_opt_flags = '-march=native -Wno-sign-compare' + question = ('Please specify optimization flags to use during compilation when' + ' bazel option "--config=opt" is specified [Default is %s]: ' + ) % default_cc_opt_flags + cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS', + question, default_cc_opt_flags) + for opt in cc_opt_flags.split(): + write_to_bazelrc('build:opt --copt=%s' % opt) + # It should be safe on the same build host. + if not is_ppc64le() and not is_windows(): + write_to_bazelrc('build:opt --host_copt=-march=native') + write_to_bazelrc('build:opt --define with_default_optimizations=true') + + + +def set_tf_download_clang(environ_cp): + """Set TF_DOWNLOAD_CLANG action_env.""" + question = 'Do you wish to download a fresh release of clang? (Experimental)' + yes_reply = 'Clang will be downloaded and used to compile tensorflow.' + no_reply = 'Clang will not be downloaded.' + set_action_env_var( + environ_cp, + 'TF_DOWNLOAD_CLANG', + None, + False, + question=question, + yes_reply=yes_reply, + no_reply=no_reply) + + +def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var, + var_default): + """Get var_name either from env, or user or default. + + If var_name has been set as environment variable, use the preset value, else + ask for user input. If no input is provided, the default is used. + + Args: + environ_cp: copy of the os.environ. + var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". + ask_for_var: string for how to ask for user input. + var_default: default value string. + + Returns: + string value for var_name + """ + var = environ_cp.get(var_name) + if not var: + var = get_input(ask_for_var) + print('\n') + if not var: + var = var_default + return var + + +def prompt_loop_or_load_from_env(environ_cp, + var_name, + var_default, + ask_for_var, + check_success, + error_msg, + suppress_default_error=False, + n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS): + """Loop over user prompts for an ENV param until receiving a valid response. + + For the env param var_name, read from the environment or verify user input + until receiving valid input. When done, set var_name in the environ_cp to its + new value. + + Args: + environ_cp: (Dict) copy of the os.environ. + var_name: (String) string for name of environment variable, e.g. "TF_MYVAR". + var_default: (String) default value string. + ask_for_var: (String) string for how to ask for user input. + check_success: (Function) function that takes one argument and returns a + boolean. Should return True if the value provided is considered valid. May + contain a complex error message if error_msg does not provide enough + information. In that case, set suppress_default_error to True. + error_msg: (String) String with one and only one '%s'. Formatted with each + invalid response upon check_success(input) failure. + suppress_default_error: (Bool) Suppress the above error message in favor of + one from the check_success function. + n_ask_attempts: (Integer) Number of times to query for valid input before + raising an error and quitting. + + Returns: + [String] The value of var_name after querying for input. + + Raises: + UserInputError: if a query has been attempted n_ask_attempts times without + success, assume that the user has made a scripting error, and will + continue to provide invalid input. Raise the error to avoid infinitely + looping. + """ + default = environ_cp.get(var_name) or var_default + full_query = '%s [Default is %s]: ' % ( + ask_for_var, + default, + ) + + for _ in range(n_ask_attempts): + val = get_from_env_or_user_or_default(environ_cp, var_name, full_query, + default) + if check_success(val): + break + if not suppress_default_error: + print(error_msg % val) + environ_cp[var_name] = '' + else: + raise UserInputError('Invalid %s setting was provided %d times in a row. ' + 'Assuming to be a scripting mistake.' % + (var_name, n_ask_attempts)) + + environ_cp[var_name] = val + return val + + +def create_android_ndk_rule(environ_cp): + """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule.""" + if is_windows() or is_cygwin(): + default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' % + environ_cp['APPDATA']) + elif is_macos(): + default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + else: + default_ndk_path = '%s/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + + def valid_ndk_path(path): + return (os.path.exists(path) and + os.path.exists(os.path.join(path, 'source.properties'))) + + android_ndk_home_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_NDK_HOME', + var_default=default_ndk_path, + ask_for_var='Please specify the home path of the Android NDK to use.', + check_success=valid_ndk_path, + error_msg=('The path %s or its child file "source.properties" ' + 'does not exist.')) + write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path) + write_action_env_to_bazelrc( + 'ANDROID_NDK_API_LEVEL', + get_ndk_api_level(environ_cp, android_ndk_home_path)) + + +def create_android_sdk_rule(environ_cp): + """Set Android variables and write Android SDK WORKSPACE rule.""" + if is_windows() or is_cygwin(): + default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA']) + elif is_macos(): + default_sdk_path = '%s/library/Android/Sdk' % environ_cp['HOME'] + else: + default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME'] + + def valid_sdk_path(path): + return (os.path.exists(path) and + os.path.exists(os.path.join(path, 'platforms')) and + os.path.exists(os.path.join(path, 'build-tools'))) + + android_sdk_home_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_SDK_HOME', + var_default=default_sdk_path, + ask_for_var='Please specify the home path of the Android SDK to use.', + check_success=valid_sdk_path, + error_msg=('Either %s does not exist, or it does not contain the ' + 'subdirectories "platforms" and "build-tools".')) + + platforms = os.path.join(android_sdk_home_path, 'platforms') + api_levels = sorted(os.listdir(platforms)) + api_levels = [x.replace('android-', '') for x in api_levels] + + def valid_api_level(api_level): + return os.path.exists( + os.path.join(android_sdk_home_path, 'platforms', + 'android-' + api_level)) + + android_api_level = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_API_LEVEL', + var_default=api_levels[-1], + ask_for_var=('Please specify the Android SDK API level to use. ' + '[Available levels: %s]') % api_levels, + check_success=valid_api_level, + error_msg='Android-%s is not present in the SDK path.') + + build_tools = os.path.join(android_sdk_home_path, 'build-tools') + versions = sorted(os.listdir(build_tools)) + + def valid_build_tools(version): + return os.path.exists( + os.path.join(android_sdk_home_path, 'build-tools', version)) + + android_build_tools_version = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_BUILD_TOOLS_VERSION', + var_default=versions[-1], + ask_for_var=('Please specify an Android build tools version to use. ' + '[Available versions: %s]') % versions, + check_success=valid_build_tools, + error_msg=('The selected SDK does not have build-tools version %s ' + 'available.')) + + write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION', + android_build_tools_version) + write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', android_api_level) + write_action_env_to_bazelrc('ANDROID_SDK_HOME', android_sdk_home_path) + + +def get_ndk_api_level(environ_cp, android_ndk_home_path): + """Gets the appropriate NDK API level to use for the provided Android NDK path.""" + + # First check to see if we're using a blessed version of the NDK. + properties_path = '%s/source.properties' % android_ndk_home_path + if is_windows() or is_cygwin(): + properties_path = cygpath(properties_path) + with open(properties_path, 'r') as f: + filedata = f.read() + + revision = re.search(r'Pkg.Revision = (\d+)', filedata) + if revision: + ndk_version = revision.group(1) + else: + raise Exception('Unable to parse NDK revision.') + if int(ndk_version) not in _SUPPORTED_ANDROID_NDK_VERSIONS: + print('WARNING: The NDK version in %s is %s, which is not ' + 'supported by Bazel (officially supported versions: %s). Please use ' + 'another version. Compiling Android targets may result in confusing ' + 'errors.\n' % (android_ndk_home_path, ndk_version, + _SUPPORTED_ANDROID_NDK_VERSIONS)) + + # Now grab the NDK API level to use. Note that this is different from the + # SDK API level, as the NDK API level is effectively the *min* target SDK + # version. + platforms = os.path.join(android_ndk_home_path, 'platforms') + api_levels = sorted(os.listdir(platforms)) + api_levels = [ + x.replace('android-', '') for x in api_levels if 'android-' in x + ] + + def valid_api_level(api_level): + return os.path.exists( + os.path.join(android_ndk_home_path, 'platforms', + 'android-' + api_level)) + + android_ndk_api_level = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_NDK_API_LEVEL', + var_default='18', # 18 is required for GPU acceleration. + ask_for_var=('Please specify the (min) Android NDK API level to use. ' + '[Available levels: %s]') % api_levels, + check_success=valid_api_level, + error_msg='Android-%s is not present in the NDK path.') + + return android_ndk_api_level + + +def set_gcc_host_compiler_path(environ_cp): + """Set GCC_HOST_COMPILER_PATH.""" + default_gcc_host_compiler_path = which('gcc') + if os.path.islink(default_gcc_host_compiler_path): + # os.readlink is only available in linux + default_gcc_host_compiler_path = os.path.realpath(default_gcc_host_compiler_path) + + gcc_host_compiler_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='GCC_HOST_COMPILER_PATH', + var_default=default_gcc_host_compiler_path, + ask_for_var='Please specify which gcc should be used by nvcc as the host compiler.', + check_success=os.path.exists, + error_msg='Invalid gcc path. %s cannot be found.', + ) + + write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) + + +def reformat_version_sequence(version_str, sequence_count): + """Reformat the version string to have the given number of sequences. + + For example: + Given (7, 2) -> 7.0 + (7.0.1, 2) -> 7.0 + (5, 1) -> 5 + (5.0.3.2, 1) -> 5 + + Args: + version_str: String, the version string. + sequence_count: int, an integer. + + Returns: + string, reformatted version string. + """ + v = version_str.split('.') + if len(v) < sequence_count: + v = v + (['0'] * (sequence_count - len(v))) + + return '.'.join(v[:sequence_count]) + + +def set_host_cxx_compiler(environ_cp): + """Set HOST_CXX_COMPILER.""" + default_cxx_host_compiler = which('g++') or '' + + host_cxx_compiler = prompt_loop_or_load_from_env( + environ_cp, + var_name='HOST_CXX_COMPILER', + var_default=default_cxx_host_compiler, + ask_for_var=('Please specify which C++ compiler should be used as the ' + 'host C++ compiler.'), + check_success=os.path.exists, + error_msg='Invalid C++ compiler path. %s cannot be found.', + ) + + write_action_env_to_bazelrc('HOST_CXX_COMPILER', host_cxx_compiler) + + +def set_host_c_compiler(environ_cp): + """Set HOST_C_COMPILER.""" + default_c_host_compiler = which('gcc') or '' + + host_c_compiler = prompt_loop_or_load_from_env( + environ_cp, + var_name='HOST_C_COMPILER', + var_default=default_c_host_compiler, + ask_for_var=('Please specify which C compiler should be used as the host ' + 'C compiler.'), + check_success=os.path.exists, + error_msg='Invalid C compiler path. %s cannot be found.', + ) + + write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler) + + +def set_opencl_sdk_root(environ_cp): + """Set OPENCL SDK ROOT""" + + def toolkit_exists(toolkit_path): + """Check if a CL header path is valid.""" + if toolkit_path == '': + return True + + if is_linux(): + cl_header_path = 'opencl/SDK/include/CL/cl.h' + else: + cl_header_path = '' + + cl_path_full = os.path.join(toolkit_path, cl_header_path) + exists = os.path.exists(cl_path_full) + if not exists: + print('Invalid OPENCL SDK ROOT path. %s cannot be found' % + (cl_path_full)) + return exists + + ocl_sdk_root = prompt_loop_or_load_from_env( + environ_cp, + var_name='OCL_SDK_ROOT', + var_default=_DEFAULT_OCL_SDK_ROOT, + ask_for_var=( + 'Please specify the location of opencl SDK install path ' + 'for ocl headers and libOpenCL.so'), + check_success=toolkit_exists, + error_msg='Invalid OPENCL SDK ROOT path.', + suppress_default_error=True) + + write_action_env_to_bazelrc('OCL_SDK_ROOT', + ocl_sdk_root) + +def set_gcc_toolchain_path(environ_cp): + """Set GCC_TOOLCHAIN_PATH.""" + def no_check(arg): + return True + + gcc_toolchain_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='GCC_TOOLCHAIN_PATH', + var_default=_DEFAULT_GCC_TOOLCHAIN_PATH, + ask_for_var=( + 'Please specify the location of gcc toolchain used by the compiler'), + check_success=no_check, + error_msg='Invalid GCC_TOOLCHAIN path.', + suppress_default_error=True) + + write_action_env_to_bazelrc('GCC_TOOLCHAIN_PATH', + gcc_toolchain_path) + return gcc_toolchain_path + +def set_gcc_toolchain_target(environ_cp, gcc_toolchain_path): + """Set GCC_TOOLCHAIN_TARGET.""" + if gcc_toolchain_path == "": + return "" + + def toolkit_exists(target): + """Check if a gcc toolchain-target is valid.""" + if is_linux(): + if target == '': + gcc_bin_path = 'bin/gcc' + else: + gcc_bin_path = 'bin/' + target + '-gcc' + else: + gcc_bin_path = '' + + gcc_bin_path_full = os.path.join(gcc_toolchain_path, gcc_bin_path) + exists = os.path.exists(gcc_bin_path_full) + if not exists: + print('Invalid GCC_TOOLCHAIN path and TARGET. %s cannot be found' % + (gcc_bin_path_full)) + return exists + + gcc_toolchain_target = prompt_loop_or_load_from_env( + environ_cp, + var_name='GCC_TOOLCHAIN_TARGET', + var_default=_DEFAULT_GCC_TOOLCHAIN_TARGET, + ask_for_var=( + 'Please specify the target of gcc toolchain (e.g. x86_64-pc-linux) ' + 'the compiler will use.'), + check_success=toolkit_exists, + error_msg='Invalid GCC_TOOLCHAIN_TARGET', + suppress_default_error=True) + + write_action_env_to_bazelrc('GCC_TOOLCHAIN_TARGET', + gcc_toolchain_target) + +def set_mpi_home(environ_cp): + """Set MPI_HOME.""" + + default_mpi_home = which('mpirun') or which('mpiexec') or '' + default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home)) + + def valid_mpi_path(mpi_home): + exists = ( + os.path.exists(os.path.join(mpi_home, 'include')) and + (os.path.exists(os.path.join(mpi_home, 'lib')) or + os.path.exists(os.path.join(mpi_home, 'lib64')) or + os.path.exists(os.path.join(mpi_home, 'lib32')))) + if not exists: + print( + 'Invalid path to the MPI Toolkit. %s or %s or %s or %s cannot be found' + % (os.path.join(mpi_home, 'include'), + os.path.exists(os.path.join(mpi_home, 'lib')), + os.path.exists(os.path.join(mpi_home, 'lib64')), + os.path.exists(os.path.join(mpi_home, 'lib32')))) + return exists + + _ = prompt_loop_or_load_from_env( + environ_cp, + var_name='MPI_HOME', + var_default=default_mpi_home, + ask_for_var='Please specify the MPI toolkit folder.', + check_success=valid_mpi_path, + error_msg='', + suppress_default_error=True) + + +def set_other_mpi_vars(environ_cp): + """Set other MPI related variables.""" + # Link the MPI header files + mpi_home = environ_cp.get('MPI_HOME') + symlink_force('%s/include/mpi.h' % mpi_home, 'third_party/mpi/mpi.h') + + # Determine if we use OpenMPI or MVAPICH, these require different header files + # to be included here to make bazel dependency checker happy + if os.path.exists(os.path.join(mpi_home, 'include/mpi_portable_platform.h')): + symlink_force( + os.path.join(mpi_home, 'include/mpi_portable_platform.h'), + 'third_party/mpi/mpi_portable_platform.h') + # TODO(gunan): avoid editing files in configure + sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI=False', + 'MPI_LIB_IS_OPENMPI=True') + else: + # MVAPICH / MPICH + symlink_force( + os.path.join(mpi_home, 'include/mpio.h'), 'third_party/mpi/mpio.h') + symlink_force( + os.path.join(mpi_home, 'include/mpicxx.h'), 'third_party/mpi/mpicxx.h') + # TODO(gunan): avoid editing files in configure + sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI=True', + 'MPI_LIB_IS_OPENMPI=False') + + if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')): + symlink_force( + os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so') + elif os.path.exists(os.path.join(mpi_home, 'lib64/libmpi.so')): + symlink_force( + os.path.join(mpi_home, 'lib64/libmpi.so'), 'third_party/mpi/libmpi.so') + elif os.path.exists(os.path.join(mpi_home, 'lib32/libmpi.so')): + symlink_force( + os.path.join(mpi_home, 'lib32/libmpi.so'), 'third_party/mpi/libmpi.so') + + else: + raise ValueError( + 'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' % + (mpi_home, mpi_home, mpi_home)) + + +def set_system_libs_flag(environ_cp): + syslibs = environ_cp.get('TF_SYSTEM_LIBS', '') + if syslibs: + if ',' in syslibs: + syslibs = ','.join(sorted(syslibs.split(','))) + else: + syslibs = ','.join(sorted(syslibs.split())) + write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs) + + if 'PREFIX' in environ_cp: + write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX']) + if 'LIBDIR' in environ_cp: + write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR']) + if 'INCLUDEDIR' in environ_cp: + write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR']) + + +def set_windows_build_flags(environ_cp): + """Set Windows specific build options.""" + # The non-monolithic build is not supported yet + write_to_bazelrc('build --config monolithic') + # Suppress warning messages + write_to_bazelrc('build --copt=-w --host_copt=-w') + # Fix winsock2.h conflicts + write_to_bazelrc( + 'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN ' + '--copt=-DNOGDI --host_copt=-DNOGDI') + # Output more verbose information when something goes wrong + write_to_bazelrc('build --verbose_failures') + # The host and target platforms are the same in Windows build. So we don't + # have to distinct them. This avoids building the same targets twice. + write_to_bazelrc('build --distinct_host_configuration=false') + + if get_var( + environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline', + True, ('Would you like to override eigen strong inline for some C++ ' + 'compilation to reduce the compilation time?'), + 'Eigen strong inline overridden.', 'Not overriding eigen strong inline, ' + 'some compilations could take more than 20 mins.'): + # Due to a known MSVC compiler issue + # https://github.com/tensorflow/tensorflow/issues/10521 + # Overriding eigen strong inline speeds up the compiling of + # conv_grad_ops_3d.cc and conv_ops_3d.cc by 20 minutes, + # but this also hurts the performance. Let users decide what they want. + write_to_bazelrc('build --define=override_eigen_strong_inline=true') + + +def config_info_line(name, help_text): + """Helper function to print formatted help text for Bazel config options.""" + print('\t--config=%-12s\t# %s' % (name, help_text)) + + +def main(): + global _TF_WORKSPACE_ROOT + global _TF_BAZELRC + global _TF_CURRENT_BAZEL_VERSION + + parser = argparse.ArgumentParser() + parser.add_argument( + '--workspace', + type=str, + default=os.path.abspath(os.path.dirname(__file__)), + help='The absolute path to your active Bazel workspace.') + args = parser.parse_args() + + _TF_WORKSPACE_ROOT = args.workspace + _TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) + + # Make a copy of os.environ to be clear when functions and getting and setting + # environment variables. + environ_cp = dict(os.environ) + + current_bazel_version = check_bazel_version('3.1.0', '3.7.0') + _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version) + + reset_tf_configure_bazelrc() + + cleanup_makefile() + setup_python(environ_cp) + create_build_configuration(environ_cp) + + if is_windows(): + environ_cp['TF_DOWNLOAD_CLANG'] = '0' + environ_cp['TF_NEED_MPI'] = '0' + + # The numpy package on ppc64le uses OpenBLAS which has multi-threading + # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at + # runtime to allow the Tensorflow testcases which compare numpy + # results to Tensorflow results to succeed. + if is_ppc64le(): + write_action_env_to_bazelrc('OMP_NUM_THREADS', 1) + + + set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False) + if environ_cp.get('TF_NEED_MPI') == '1': + set_mpi_home(environ_cp) + set_other_mpi_vars(environ_cp) + + set_cc_opt_flags(environ_cp) + set_system_libs_flag(environ_cp) + if is_windows(): + set_windows_build_flags(environ_cp) + +if __name__ == '__main__': + main() diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/conv_relu.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/conv_relu.py new file mode 100644 index 000000000..276433c24 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/conv_relu.py @@ -0,0 +1,38 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +#!/usr/bin/env python +# coding=utf-8 +import tensorflow as tf +import numpy as np +tf.compat.v1.disable_eager_execution() +a = tf.random.normal(shape=[1,10, 10, 8], dtype=tf.float32, seed=1) +w = tf.random.normal(shape=[3, 3, 8, 4], dtype=tf.float32, seed=1) + +a1 = tf.random.normal(shape=[1, 10, 10, 8], dtype=tf.float32, seed=1) +w1 = tf.random.normal(shape=[3, 3, 8, 4], dtype=tf.float32, seed=1) + + +with tf.device("/MY_DEVICE:0"): + b = tf.nn.relu(a) + c = tf.nn.conv2d(b, w, strides=[1, 1, 1, 1], padding='SAME', data_format='NHWC') + +with tf.device("/CPU:0"): + b1 = tf.nn.relu(a1) + c1 = tf.nn.conv2d(b1, w1, strides=[1, 1, 1, 1], padding='SAME', data_format='NHWC') + + +sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=False, log_device_placement=True)) +print(sess.run(tf.reduce_all(tf.less(c - c1, 1e-5)))) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/genpip.sh b/rfcs/20200624-pluggable-device-for-tensorflow/sample/genpip.sh new file mode 100644 index 000000000..f344dfc41 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/genpip.sh @@ -0,0 +1,4 @@ +#!/bin/bash +PACKAGE_PATH=$PWD +bazel-bin/tensorflow_plugin/tools/pip_package/build_pip_package $PACKAGE_PATH + diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/profiler_result.png b/rfcs/20200624-pluggable-device-for-tensorflow/sample/profiler_result.png new file mode 100644 index 000000000..b10cef062 Binary files /dev/null and b/rfcs/20200624-pluggable-device-for-tensorflow/sample/profiler_result.png differ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/relu.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/relu.py new file mode 100644 index 000000000..7ca4574ab --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/relu.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + + +# coding=utf-8 +import tensorflow as tf +import numpy as np +tf.compat.v1.disable_eager_execution() +a = tf.random.normal(shape=[10], dtype=tf.float32) + +with tf.device("/MY_DEVICE:0"): + b = tf.nn.relu(a) + +sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=False, log_device_placement=True)) +print(sess.run(b)) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/BUILD new file mode 100644 index 000000000..636069aa7 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/BUILD @@ -0,0 +1,27 @@ +cc_binary( + name = "libdemo_plugin.so", + linkshared = True, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_plugin/src:plugin_device", + "//tensorflow_plugin/src:plugin_graph", + "//tensorflow_plugin/src:plugin_kernel", + "//tensorflow_plugin/src:plugin_profiler", + ], +) + +config_setting( + name = "linux_x86_64", + values = {"cpu": "k8"}, + visibility = ["//visibility:public"], +) + +cc_library( + name = "core", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_plugin/src/device/cpu:cpu_device_impl", + "@local_config_tf//:tf_header_lib", + ], + alwayslink = True, +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/build_config.bzl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/build_config.bzl new file mode 100644 index 000000000..01d4a549c --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/build_config.bzl @@ -0,0 +1,35 @@ +# Platform-specific build configurations. + +load("@com_google_protobuf//:protobuf.bzl", "proto_gen") +load("//tensorflow_plugin:workspace.bzl", "clean_dep") +load("@rules_cc//cc:defs.bzl", "cc_library") + +def cc_proto(name, src, deps = []): + native.genrule( + name = "%s_cc" % name, + outs = ["%s.pb.cc" % name, "%s.pb.h" % name], + cmd = "echo $(GENDIR); which $(location @com_google_protobuf//:protoc); $(location @com_google_protobuf//:protoc) --cpp_out=$(GENDIR) $<", + srcs = [src], + tools = ["@com_google_protobuf//:protoc"], + ) + native.cc_library( + name = "%s_proto" % name, + srcs = ["%s.pb.cc" % name], + hdrs = ["%s.pb.h" % name], + deps = [ + "@com_google_protobuf//:protobuf_headers", + "@com_google_protobuf//:protobuf", + ] + deps, + copts = ["-I$(GENDIR)"], + ) + +def if_static(extra_deps = [], otherwise = []): + return otherwise + +def tf_protobuf_deps(): + return if_static( + [ + clean_dep("@com_google_protobuf//:protobuf"), + ], + otherwise = [clean_dep("@com_google_protobuf//:protobuf_headers")], + ) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/demo_plugin.bzl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/demo_plugin.bzl new file mode 100644 index 000000000..865909429 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/demo_plugin.bzl @@ -0,0 +1,56 @@ +# Return the options to use for a C++ library or binary build. +# Uses the ":optmode" config_setting to pick the options. + +def if_linux_x86_64(a, otherwise = []): + return select({ + "//conditons:default": otherwise, + }) + +def tf_copts(android_optimization_level_override = "-O2", is_external = False): + # For compatibility reasons, android_optimization_level_override + # is currently only being set for Android. + # To clear this value, and allow the CROSSTOOL default + # to be used, pass android_optimization_level_override=None + return ( + [ + "-Wno-sign-compare", + "-fno-exceptions", + "-ftemplate-depth=900", + "-msse3", + "-pthread", + ] + ) + +def _get_transitive_headers(hdrs, deps): + return depset( + hdrs, + transitive = [dep[CcInfo].compilation_context.headers for dep in deps], + ) + +def _transitive_hdrs_impl(ctx): + outputs = _get_transitive_headers([], ctx.attr.deps) + return struct(files = outputs) + +_transitive_hdrs = rule( + attrs = { + "deps": attr.label_list( + allow_files = True, + providers = [CcInfo], + ), + }, + implementation = _transitive_hdrs_impl, +) + +def transitive_hdrs(name, deps = [], **kwargs): + _transitive_hdrs(name = name + "_gather", deps = deps) + native.filegroup(name = name, srcs = [":" + name + "_gather"]) + +def cc_header_only_library(name, deps = [], includes = [], extra_deps = [], **kwargs): + _transitive_hdrs(name = name + "_gather", deps = deps) + native.cc_library( + name = name, + srcs = [":" + name + "_gather"], + hdrs = includes, + deps = extra_deps, + **kwargs + ) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/python/__init__.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/python/__init__.py new file mode 100644 index 000000000..8d1c8b69c --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/python/__init__.py @@ -0,0 +1 @@ + diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/python/test.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/python/test.py new file mode 100644 index 000000000..b5e88fec1 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/python/test.py @@ -0,0 +1,20 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This is a module for dummy test.""" + +import os + +if __name__ == '__main__': + print(os.path.realpath('.')) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/BUILD new file mode 100644 index 000000000..1a8c74193 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/BUILD @@ -0,0 +1,61 @@ +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "plugin_device", + srcs = ["plugin_device.cc"], + hdrs = ["plugin_device.h"], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_plugin/src/device/cpu:cpu_device_impl", + "@local_config_tf//:tf_header_lib", + ], + alwayslink = True, +) + +cc_library( + name = "plugin_graph", + srcs = ["plugin_graph.cc"], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + ":plugin_device", + "//tensorflow_plugin/src/graph:plugin_optimizer", + "//tensorflow_plugin/src/kernels/cpu:cpu_kernel_impl", + "@local_config_tf//:tf_header_lib", + ], + alwayslink = True, +) + +cc_library( + name = "plugin_kernel", + srcs = ["plugin_kernel.cc"], + hdrs = ["plugin_kernel.h"], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + ":plugin_device", + "//tensorflow_plugin/src/kernels/cpu:cpu_kernel_impl", + "@local_config_tf//:tf_header_lib", + ], + alwayslink = True, +) + + +cc_library( + name = "plugin_profiler", + srcs = ["plugin_profiler.cc"], + hdrs = ["plugin_device.h"], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_plugin/src/profiler/cpu:demo_profiler", + "@local_config_tf//:tf_header_lib", + ], + alwayslink = True, +) + + + diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/BUILD new file mode 100644 index 000000000..6d997f51b --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/BUILD @@ -0,0 +1,25 @@ +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "cpu_device_impl", + srcs = ["cpu_device_plugin.cc"], + hdrs = [ + "cpu_device_plugin.h", + ], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + "@local_config_gcc//:framework_lib", + "@local_config_tf//:tf_header_lib", + ], + alwayslink = True, +) + +exports_files( + srcs = [ + "cpu_device_plugin.h", + ], + visibility = ["//visibility:public"], +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/cpu_device_plugin.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/cpu_device_plugin.cc new file mode 100644 index 000000000..db8c10b73 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/cpu_device_plugin.cc @@ -0,0 +1,294 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#include "tensorflow_plugin/src/device/cpu/cpu_device_plugin.h" +#include + +#include +#include +#include + +#include + +void plugin_get_device_count(const SP_Platform* platform, int* device_count, + TF_Status* status) { + *device_count = 1; +} + +void plugin_create_device(const SP_Platform* platform, + SE_CreateDeviceParams* params, + TF_Status* const status) { + params->device->struct_size = SP_DEVICE_STRUCT_SIZE; + params->device->device_handle = nullptr; + params->device->ordinal = 0; +} + +void plugin_destroy_device(const SP_Platform* platform, SP_Device* device) { + device->device_handle = nullptr; + device->ordinal = -1; +} + +void plugin_create_device_fns(const SP_Platform* platform, + SE_CreateDeviceFnsParams* params, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE}; +} +void plugin_destroy_device_fns(const SP_Platform* platform, + SP_DeviceFns* device_fns) {} + +/*StreamExecutor Backend Impl*/ +void plugin_allocate(const SP_Device* device, uint64_t size, + int64_t memory_space, SP_DeviceMemoryBase* mem) { + mem->struct_size = SP_DEVICE_MEMORY_BASE_STRUCT_SIZE; + mem->opaque = aligned_alloc(64, size); + mem->size = size; +} + +void plugin_deallocate(const SP_Device* device, SP_DeviceMemoryBase* mem) { + free(mem->opaque); + mem->opaque = nullptr; + mem->size = 0; +} + +void* plugin_host_memory_allocate(const SP_Device* device, uint64_t size) { + void* ptr = aligned_alloc(64, size); + return ptr; +} + +void plugin_host_memory_deallocate(const SP_Device* device, void* mem) { + free(mem); +} + +TF_Bool plugin_get_allocator_stats(const SP_Device* device, + SP_AllocatorStats* stats) { + stats->struct_size = SP_ALLOCATORSTATS_STRUCT_SIZE; + stats->bytes_in_use = 123; + return true; +} + +// TODO(plugin):Check correctness of this function +TF_Bool plugin_device_memory_usage(const SP_Device* device, int64_t* free, + int64_t* total) { + struct sysinfo info; + int err = sysinfo(&info); + *free = info.freeram; + *total = info.totalram; + + return (err == 0); +} + +void plugin_create_stream(const SP_Device* device, SP_Stream* stream, + TF_Status* status) {} + +// Destroys SP_Stream and deallocates any underlying resources. +void plugin_destroy_stream(const SP_Device* device, SP_Stream stream) {} + +void plugin_create_stream_dependency(const SP_Device* device, + SP_Stream dependent, SP_Stream other, + TF_Status* status) {} + +// Without blocking the device, retrieve the current stream status. +void plugin_get_stream_status(const SP_Device* device, SP_Stream stream, + TF_Status* status) {} + +void plugin_create_event(const SP_Device* device, SP_Event* event, + TF_Status* status) {} + +// Destroy SE_Event and perform any platform-specific deallocation and +// cleanup of an event. +void plugin_destroy_event(const SP_Device* device, SP_Event event) {} + +// Requests the current status of the event from the underlying platform. +SE_EventStatus plugin_get_event_status(const SP_Device* device, + SP_Event event) { + return SE_EVENT_COMPLETE; +} + +// Inserts the specified event at the end of the specified stream. +void plugin_record_event(const SP_Device* device, SP_Stream stream, + SP_Event event, TF_Status* status) {} + +// Wait for the specified event at the end of the specified stream. +void plugin_wait_for_event(const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status) {} + +/*** TIMER CALLBACKS ***/ +// Creates SP_Timer. Allocates timer resources on the underlying platform +// and initializes its internals, setting `timer` output variable. Sets +// values in `timer_fns` struct. +void plugin_create_timer(const SP_Device* device, SP_Timer* timer, + TF_Status* status) {} + +// Destroy timer and deallocates timer resources on the underlying platform. +void plugin_destroy_timer(const SP_Device* device, SP_Timer timer) {} + +// Records a start event for an interval timer. +void plugin_start_timer(const SP_Device* device, SP_Stream stream, + SP_Timer timer, TF_Status* status) {} + +// Records a stop event for an interval timer. +void plugin_stop_timer(const SP_Device* device, SP_Stream stream, + SP_Timer timer, TF_Status* status) {} + +/*** MEMCPY CALLBACKS ***/ +// Enqueues a memcpy operation onto stream, with a host destination location +// `host_dst` and a device memory source, with target size `size`. +void plugin_memcpy_dtoh(const SP_Device* device, SP_Stream stream, + void* host_dst, const SP_DeviceMemoryBase* device_src, + uint64_t size, TF_Status* status) { + memcpy(host_dst, device_src->opaque, size); +} + +// Enqueues a memcpy operation onto stream, with a device destination +// location and a host memory source, with target size `size`. +void plugin_memcpy_htod(const SP_Device* device, SP_Stream stream, + SP_DeviceMemoryBase* device_dst, const void* host_src, + uint64_t size, TF_Status* status) { + memcpy(device_dst->opaque, host_src, size); +} + +// Enqueues a memcpy operation onto stream, with a device destination +// location and a device memory source, with target size `size`. +void plugin_memcpy_dtod(const SP_Device* device, SP_Stream stream, + SP_DeviceMemoryBase* device_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status) { + memcpy(device_dst->opaque, device_src->opaque, size); +} + +// Blocks the caller while a data segment of the given size is +// copied from the device source to the host destination. +void plugin_sync_memcpy_dtoh(const SP_Device* device, void* host_dst, + const SP_DeviceMemoryBase* device_src, + uint64_t size, TF_Status* status) { + memcpy(host_dst, device_src->opaque, size); +} + +// Blocks the caller while a data segment of the given size is +// copied from the host source to the device destination. +void plugin_sync_memcpy_htod(const SP_Device* device, + SP_DeviceMemoryBase* device_dst, + const void* host_src, uint64_t size, + TF_Status* status) { + memcpy(device_dst->opaque, host_src, size); +} + +// Blocks the caller while a data segment of the given size is copied from the +// device source to the device destination. +void plugin_sync_memcpy_dtod(const SP_Device* device, + SP_DeviceMemoryBase* device_dst, + const SP_DeviceMemoryBase* device_src, + uint64_t size, TF_Status* status) { + memcpy(device_dst->opaque, device_src->opaque, size); +} + +// Causes the host code to synchronously wait for the event to complete. +void plugin_block_host_for_event(const SP_Device* device, SP_Event event, + TF_Status* status) {} + +void plugin_block_host_until_done(const SP_Device* device, SP_Stream stream, + TF_Status* status) {} + +// Synchronizes all activity occurring in the StreamExecutor's context (most +// likely a whole device). +void plugin_synchronize_all_activity(const SP_Device* device, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); +} + +// Enqueues on a stream a user-specified function to be run on the host. +// `callback_arg` should be passed as the first argument to `callback_fn`. +TF_Bool plugin_host_callback(const SP_Device* device, SP_Stream stream, + SE_StatusCallbackFn callback_fn, + void* callback_arg) { + return TF_OK; +} + +/*Timer Backer Impl*/ +uint64_t nanoseconds(SP_Timer timer) { return timer->timer_handle; } + +void plugin_create_timer_fns(const SP_Platform* platform, + SP_TimerFns* timer_fns, TF_Status* const status) { + timer_fns->nanoseconds = nanoseconds; +} + +void plugin_destroy_timer_fns(const SP_Platform* platform, + SP_TimerFns* timer_fns) {} + +void plugin_create_stream_executor(const SP_Platform* platform, + SE_CreateStreamExecutorParams* params, + TF_Status* const status) { + params->stream_executor->struct_size = SP_STREAMEXECUTOR_STRUCT_SIZE; + params->stream_executor->allocate = plugin_allocate; + params->stream_executor->deallocate = plugin_deallocate; + params->stream_executor->host_memory_allocate = plugin_host_memory_allocate; + params->stream_executor->host_memory_deallocate = + plugin_host_memory_deallocate; + params->stream_executor->get_allocator_stats = plugin_get_allocator_stats; + params->stream_executor->device_memory_usage = plugin_device_memory_usage; + + params->stream_executor->create_stream = plugin_create_stream; + params->stream_executor->destroy_stream = plugin_destroy_stream; + params->stream_executor->create_stream_dependency = + plugin_create_stream_dependency; + params->stream_executor->get_stream_status = plugin_get_stream_status; + params->stream_executor->create_event = plugin_create_event; + params->stream_executor->destroy_event = plugin_destroy_event; + params->stream_executor->get_event_status = plugin_get_event_status; + params->stream_executor->record_event = plugin_record_event; + params->stream_executor->wait_for_event = plugin_wait_for_event; + params->stream_executor->create_timer = plugin_create_timer; + params->stream_executor->destroy_timer = plugin_destroy_timer; + params->stream_executor->start_timer = plugin_start_timer; + params->stream_executor->stop_timer = plugin_stop_timer; + + params->stream_executor->memcpy_dtoh = plugin_memcpy_dtoh; + params->stream_executor->memcpy_htod = plugin_memcpy_htod; + params->stream_executor->memcpy_dtod = plugin_memcpy_dtod; + params->stream_executor->sync_memcpy_dtoh = plugin_sync_memcpy_dtoh; + params->stream_executor->sync_memcpy_htod = plugin_sync_memcpy_htod; + params->stream_executor->sync_memcpy_dtod = plugin_sync_memcpy_dtod; + + // TODO(plugin): Fill the function for block stream + params->stream_executor->block_host_until_done = plugin_block_host_until_done; + params->stream_executor->block_host_for_event = plugin_block_host_for_event; + + params->stream_executor->synchronize_all_activity = + plugin_synchronize_all_activity; + params->stream_executor->host_callback = plugin_host_callback; +} + +void plugin_destroy_stream_executor(const SP_Platform* platform, + SP_StreamExecutor* stream_executor) {} + +void plugin_destroy_platform(SP_Platform* const platform) {} +void plugin_destroy_platform_fns(SP_PlatformFns* const platform_fns) {} + +void SE_InitPluginFns(SE_PlatformRegistrationParams* const params, + TF_Status* const status) { + params->platform_fns->get_device_count = plugin_get_device_count; + params->platform_fns->create_device = plugin_create_device; + params->platform_fns->destroy_device = plugin_destroy_device; + params->platform_fns->create_device_fns = plugin_create_device_fns; + params->platform_fns->destroy_device_fns = plugin_destroy_device_fns; + params->platform_fns->create_stream_executor = plugin_create_stream_executor; + params->platform_fns->destroy_stream_executor = + plugin_destroy_stream_executor; + params->platform_fns->create_timer_fns = plugin_create_timer_fns; + params->platform_fns->destroy_timer_fns = plugin_destroy_timer_fns; + params->destroy_platform = plugin_destroy_platform; + params->destroy_platform_fns = plugin_destroy_platform_fns; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/cpu_device_plugin.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/cpu_device_plugin.h new file mode 100644 index 000000000..e77bf369b --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/device/cpu/cpu_device_plugin.h @@ -0,0 +1,38 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_DEVICE_CPU_H_ +#define TENSORFLOW_PLUGIN_SRC_DEVICE_CPU_H_ +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" +#include "tensorflow/c/tf_status.h" + +void SE_InitPluginFns(SE_PlatformRegistrationParams* const params, + TF_Status* const status); + +struct SP_Stream_st { + explicit SP_Stream_st(void* stream_h) : stream_handle(stream_h) {} + void* stream_handle; +}; + +struct SP_Event_st { + explicit SP_Event_st(void* event_h) : event_handle(event_h) {} + void* event_handle; +}; + +struct SP_Timer_st { + explicit SP_Timer_st(int id) : timer_handle(id) {} + int timer_handle; +}; + +#endif // TENSORFLOW_PLUGIN_SRC_DEVICE_CPU_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/BUILD new file mode 100644 index 000000000..060eeaa89 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/BUILD @@ -0,0 +1,30 @@ +load( + "//tensorflow_plugin:build_config.bzl", + "tf_protobuf_deps", +) + +cc_library( + name = "tf_buffer", + srcs = ["tf_buffer.cc"], + hdrs = [ + "tf_buffer.h", + ], + visibility = ["//visibility:public"], + deps = [ + "@local_config_tf//:tf_header_lib", + ] + tf_protobuf_deps(), +) + +cc_library( + name = "plugin_optimizer", + srcs = ["plugin_optimizer.cc"], + hdrs = ["plugin_optimizer.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_plugin/src/utils:protos_all", + # "//tensorflow_plugin/src/utils:types_proto", + "tf_buffer", + "@local_config_tf//:tf_header_lib", + ], + alwayslink = True, +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/plugin_optimizer.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/plugin_optimizer.cc new file mode 100644 index 000000000..e24e4b808 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/plugin_optimizer.cc @@ -0,0 +1,53 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#include "tensorflow_plugin/src/graph/plugin_optimizer.h" +#include "tensorflow/c/experimental/grappler/grappler.h" +#include "tensorflow_plugin/src/graph/tf_buffer.h" +#include "tensorflow_plugin/src/utils/graph.pb.h" + +namespace demo_plugin { +namespace graph { + +void *Optimizer_Create() { + auto *optimizer = new Optimizer; + return reinterpret_cast(optimizer); +} + +void Optimizer_Destroy(void *optimizer) { + if (optimizer) + delete reinterpret_cast(optimizer); +} + +void Optimizer_Optimize(void *optimizer, const TF_Buffer *graph_buf, + const TF_GrapplerItem *item, + TF_Buffer *optimized_graph_buf, TF_Status *tf_status) { + // Deserialize graph_buf into GraphDef. + GraphDef graph_def; + BufferToMessage(graph_buf, graph_def, tf_status); + if (TF_GetCode(tf_status) != TF_OK) + return; + + // Doing graph transformation. + GraphDef optimized_graph_def = graph_def; + + // Serialize output GraphDef into optimized_graph_buf. + MessageToBuffer(optimized_graph_def, optimized_graph_buf, tf_status); + if (TF_GetCode(tf_status) != TF_OK) + return; +} + +} // namespace graph +} // namespace demo_plugin diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/plugin_optimizer.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/plugin_optimizer.h new file mode 100644 index 000000000..55b73cf25 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/plugin_optimizer.h @@ -0,0 +1,38 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_GRAPH_PLUGIN_OPTIMIZER_H_ +#define TENSORFLOW_PLUGIN_SRC_GRAPH_PLUGIN_OPTIMIZER_H_ + +#include "tensorflow/c/experimental/grappler/grappler.h" + +namespace demo_plugin { +namespace graph { + +typedef struct Optimizer { +} Optimizer; + +void *Optimizer_Create(); + +void Optimizer_Destroy(void *optimizer); + +void Optimizer_Optimize(void *optimizer, const TF_Buffer *graph_buf, + const TF_GrapplerItem *item, + TF_Buffer *optimized_graph_buf, TF_Status *tf_status); + +} // namespace graph +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_GRAPH_PLUGIN_OPTIMIZER_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/tf_buffer.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/tf_buffer.cc new file mode 100644 index 000000000..88be206e2 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/tf_buffer.cc @@ -0,0 +1,57 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#include "tensorflow_plugin/src/graph/tf_buffer.h" +#include "tensorflow/c/tf_status.h" + +namespace demo_plugin { + +void MessageToBuffer(const demo_plugin::protobuf::MessageLite& in, + TF_Buffer* out, TF_Status* status) { + if (out->data != nullptr) { + TF_SetStatus(status, TF_Code::TF_INVALID_ARGUMENT, + "Passing non-empty TF_Buffer is invalid."); + return; + } + const size_t proto_size = in.ByteSizeLong(); + void* buf = malloc(proto_size); + if (buf == nullptr) { + TF_SetStatus(status, TF_Code::TF_RESOURCE_EXHAUSTED, + "Failed to allocate memory to serialize message."); + return; + } + if (!in.SerializeWithCachedSizesToArray(static_cast(buf))) { + free(buf); + TF_SetStatus(status, TF_Code::TF_INVALID_ARGUMENT, + "Unable to serialize protocol buffer."); + return; + } + out->data = buf; + out->length = proto_size; + out->data_deallocator = [](void* data, size_t length) { free(data); }; + TF_SetStatus(status, TF_Code::TF_OK, ""); +} + +void BufferToMessage(const TF_Buffer* in, + demo_plugin::protobuf::MessageLite& out, + TF_Status* status) { + if (in == nullptr || !out.ParseFromArray(in->data, in->length)) { + TF_SetStatus(status, TF_Code::TF_INVALID_ARGUMENT, "Unparsable proto."); + return; + } + TF_SetStatus(status, TF_Code::TF_OK, ""); +} + +} // namespace demo_plugin diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/tf_buffer.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/tf_buffer.h new file mode 100644 index 000000000..745330f24 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/graph/tf_buffer.h @@ -0,0 +1,54 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_GRAPH_UTILS_TF_BUFFER_H_ +#define TENSORFLOW_PLUGIN_SRC_GRAPH_UTILS_TF_BUFFER_H_ + +#include "tensorflow/c/c_api.h" + +// Import whatever namespace protobuf comes from into the +// ::tensorflow::protobuf namespace. +// +// TensorFlow code should use the ::tensorflow::protobuf namespace to +// refer to all protobuf APIs. + +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor.pb.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/tokenizer.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "google/protobuf/map.h" +#include "google/protobuf/message.h" +#include "google/protobuf/repeated_field.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/util/json_util.h" +#include "google/protobuf/util/type_resolver_util.h" + +namespace demo_plugin { + +namespace protobuf = ::google::protobuf; + +void MessageToBuffer(const demo_plugin::protobuf::MessageLite& in, + TF_Buffer* out, TF_Status* status); + +void BufferToMessage(const TF_Buffer* in, + demo_plugin::protobuf::MessageLite& out, + TF_Status* status); +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_GRAPH_UTILS_TF_BUFFER_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/BUILD new file mode 100644 index 000000000..790abc431 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/BUILD @@ -0,0 +1,53 @@ +load("//tensorflow_plugin:demo_plugin.bzl", "tf_copts") + +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "relu_op", + srcs = ["relu_op.cc"], + copts = tf_copts(), + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_plugin:core", + "@com_google_absl//absl/container:inlined_vector", + ], + alwayslink = True, +) + +cc_library( + name = "conv2d_op", + srcs = ["conv_ops_using_gemm.cc"], + hdrs = ["gemm_functors.h"], + copts = tf_copts(), + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_plugin:core", + "@com_google_absl//absl/container:inlined_vector", + "@eigen_archive//:eigen", + "//third_party/eigen3" + ], + alwayslink = True, +) + +CPU_KERNELS = [ + ":relu_op", + ":conv2d_op", +] + +cc_library( + name = "cpu_kernel_impl", + srcs = ["cpu_kernel_init.cc"], + hdrs = [ + "cpu_kernel_init.h", + "//tensorflow_plugin/src/device/cpu:cpu_device_plugin.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_plugin:core", + ] + CPU_KERNELS, + alwayslink = True, +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/conv_ops_using_gemm.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/conv_ops_using_gemm.cc new file mode 100644 index 000000000..f1758d192 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/conv_ops_using_gemm.cc @@ -0,0 +1,623 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +//#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "gemm_functors.h" +#include "tensorflow/c/kernels.h" +#include "tensorflow/c/ops.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" + +namespace demo_plugin { + +struct StatusDeleter { + void operator()(TF_Status *s) { + if (s != nullptr) { + TF_DeleteStatus(s); + } + } +}; + +struct TensorDeleter { + void operator()(TF_Tensor *t) { + if (t != nullptr) { + TF_DeleteTensor(t); + } + } +}; + +using StatusSafePtr = std::unique_ptr; +using TensorSafePtr = std::unique_ptr; + +enum Padding { + VALID = 1, // No padding. + SAME = 2, // Input and output layers have the same size. + EXPLICIT = 3, // Padding is explicitly specified. +}; + +enum TensorFormat { + FORMAT_NHWC = 0, + FORMAT_NCHW = 1, + FORMAT_NCHW_VECT_C = 2, + FORMAT_NHWC_VECT_W = 3, + FORMAT_HWNC = 4, + FORMAT_HWCN = 5, +}; + +template struct TypeToEnum {}; + +template <> struct TypeToEnum { + static TF_DataType v() { return TF_DataType::TF_FLOAT; } +}; + +template <> struct TypeToEnum { + static TF_DataType v() { return TF_DataType::TF_DOUBLE; } +}; + +template <> struct TypeToEnum { + static TF_DataType v() { return TF_DataType::TF_HALF; } +}; + +template <> struct TypeToEnum { + static TF_DataType v() { return TF_DataType::TF_BFLOAT16; } +}; + +static bool GetWindowedOutputSize(int64_t input_size, int64_t filter_size, + int64_t dilation_rate, int64_t stride, + Padding padding_type, int64_t *output_size, + int64_t *padding_before) { + if (stride <= 0) { + std::cerr << "Stride must be > 0, but got " << stride << std::endl; + return false; + } + if (dilation_rate < 1) { + std::cerr << "Dilation rate must be >= 1, but got " << dilation_rate + << std::endl; + return false; + } + + int64_t effective_filter_size = (filter_size - 1) * dilation_rate + 1; + switch (padding_type) { + case Padding::VALID: + *output_size = (input_size - effective_filter_size + stride) / stride; + *padding_before = 0; + break; + case Padding::SAME: + *output_size = (input_size + stride - 1) / stride; + const int64_t padding_needed = + std::max(int64_t{0}, (*output_size - 1) * stride + + effective_filter_size - input_size); + // For odd values of total padding, add more padding at the 'right' + // side of the given dimension. + *padding_before = padding_needed / 2; + break; + } + if (*output_size < 0) { + std::cerr << "Computed output size would be negative: " << *output_size + << " [input_size: " << input_size + << ", effective_filter_size: " << effective_filter_size + << ", stride: " << stride << "]" << std::endl; + return false; + } + return true; +} + +static int64_t GetTensorDim(TF_Tensor *tensor, std::string &format, char dim) { + int idx = -1; + if (format == "NCHW") { + switch (dim) { + case 'N': { + idx = 0; + break; + } + case 'C': { + idx = 1; + break; + } + case 'H': { + idx = 2; + break; + } + case 'W': { + idx = 3; + break; + } + default: { + idx = -1; + } + } + } else if (format == "NHWC") { + switch (dim) { + case 'N': { + idx = 0; + break; + } + case 'C': { + idx = 3; + break; + } + case 'H': { + idx = 1; + break; + } + case 'W': { + idx = 2; + break; + } + default: { + idx = -1; + } + } + } else { + std::cerr << "Unsupport data_format now" << std::endl; + return -1; + } + return TF_Dim(tensor, idx); +} + +#define CHECK_CONSTRUCT_STATUS(ctx, status) \ + do { \ + if (TF_GetCode(status) != TF_OK) { \ + TF_OpKernelConstruction_Failure(ctx, status); \ + } \ + } while (0); + +#define CHECK_CTX_STATUS(ctx, status) \ + do { \ + if (TF_GetCode(status) != TF_OK) { \ + TF_OpKernelContext_Failure(ctx, status); \ + } \ + } while (0); + +namespace { +const size_t kMaxChunkSize = (16 * 1024 * 1024); + +// Implements convolution as a two stage process, first packing the patches of +// the input image into columns (im2col) and then running GEMM to produce the +// final result. +template +class Im2ColConvFunctor { +public: + void operator()(const T1 *input_data, int input_batches, int input_height, + int input_width, int input_depth, const T2 *filter_data, + int filter_height, int filter_width, int filter_count, + int stride_rows, int stride_cols, Padding padding, + T3 *output_data, int output_height, int output_width) { + if ((input_batches <= 0) || (input_width <= 0) || (input_height <= 0) || + (input_depth <= 0)) { + std::cerr << "Conv2D was called with bad input dimensions: " + << input_batches << ", " << input_height << ", " << input_width + << ", " << input_depth; + return; + } + if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) { + std::cerr << "Conv2D was called with bad filter dimensions: " + << filter_width << ", " << filter_height << ", " + << filter_count; + return; + } + if ((output_width <= 0) || (output_height <= 0)) { + std::cerr << "Conv2D was called with bad output width or height: " + << output_width << ", " << output_height; + return; + } + // We can just use a GEMM if the im2col is the identity operator, e.g., if + // // the kernel is 1x1 or the input data and filter have same height/width. + if (filter_height == 1 && filter_width == 1 && stride_rows == 1 && + stride_cols == 1) { + // The kernel is 1x1. + const int m = input_batches * input_height * input_width; + const int n = filter_count; + const int k = input_depth; + const int lda = k; + const int ldb = filter_count; + const int ldc = filter_count; + TGemmFunctor gemm_functor; + gemm_functor(m, n, k, input_data, lda, filter_data, ldb, output_data, + ldc); + return; + } else if (filter_height == input_height && filter_width == input_width && + padding == VALID) { + // The input data and filter have the same height/width. + const int m = input_batches; + const int n = filter_count; + const int k = input_height * input_width * input_depth; + const int lda = k; + const int ldb = filter_count; + const int ldc = filter_count; + TGemmFunctor gemm_functor; + gemm_functor(m, n, k, input_data, lda, filter_data, ldb, output_data, + ldc); + return; + } + + // These calculations define how the patches will be positioned within the + // input image. The actual definitions are quite complex, and rely on the + // previously-calculated output size. + int filter_left_offset; + int filter_top_offset; + if (padding == VALID) { + filter_left_offset = + ((output_width - 1) * stride_cols + filter_width - input_width + 1) / + 2; + filter_top_offset = ((output_height - 1) * stride_rows + filter_height - + input_height + 1) / + 2; + } else { + filter_left_offset = + ((output_width - 1) * stride_cols + filter_width - input_width) / 2; + filter_top_offset = + ((output_height - 1) * stride_rows + filter_height - input_height) / + 2; + } + + // The im2col buffer has # of patches rows, and # of filters cols. + // It's laid out like this, in row major order in memory: + // < filter value count > + // ^ +---------------------+ + // patch | | + // count | | + // v +---------------------+ + // Each patch row contains a filter_width x filter_height patch of the + // input, with the depth channel as the most contiguous in memory, followed + // by the width, then the height. This is the standard memory order in the + // image world if it helps to visualize it. + const int filter_value_count = filter_width * filter_height * input_depth; + if ((filter_value_count * sizeof(T1)) > kMaxChunkSize) { + std::cerr << "Im2Col patch too large for buffer" << std::endl; + return; + } + const int64_t patches_per_chunk = + kMaxChunkSize / (filter_value_count * sizeof(T1)); + const int64_t chunk_value_count = + (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1); + // This means that multiple ops can't be run simultaneously on different + // threads, because we have a single shared resource. The platforms this is + // aimed at have intra-op parallelism as their focus though, so it shouldn't + // be an issue. + // T1* im2col_buffer = new T1[chunk_value_count]; + std::unique_ptr im2col_buffer(new T1[chunk_value_count]); + + const int64_t patch_count = (input_batches * output_height * output_width); + const int64_t chunk_count = + (patch_count + (patches_per_chunk - 1)) / patches_per_chunk; + for (int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index) { + const int64_t patch_index_start = chunk_index * patches_per_chunk; + const int64_t patch_index_end = + std::min(patch_index_start + patches_per_chunk, patch_count); + for (int64_t patch_index = patch_index_start; + patch_index < patch_index_end; ++patch_index) { + const int64_t batch = patch_index / (output_height * output_width); + const int64_t out_y = (patch_index / output_width) % output_height; + const int64_t out_x = patch_index % output_width; + const T1 *input_batch_start = + input_data + (batch * input_height * input_width * input_depth); + const int in_y_origin = (out_y * stride_rows) - filter_top_offset; + const int in_x_origin = (out_x * stride_cols) - filter_left_offset; + const int patch_index_within_chunk = patch_index % patches_per_chunk; + T1 *im2col_patch_start = + im2col_buffer.get() + + (patch_index_within_chunk * filter_value_count); + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int in_y = in_y_origin + filter_y; + T1 *im2col_row_start = + im2col_patch_start + (filter_y * filter_width * input_depth); + // If we're off the top or the bottom of the input, fill the + // whole row with zeroes. + if ((in_y < 0) || (in_y >= input_height)) { + T1 *im2col_row_end = + im2col_row_start + (filter_width * input_depth); + std::fill(im2col_row_start, im2col_row_end, T1(0)); + } else { + // What we're doing here is trying to copy and fill the im2col + // buffer as efficiently as possible, using functions to set or + // duplicate values en masse. We know we don't have to worry about + // vertical edges because we dealt with that case above, so we + // just need to handle filters that overlap the left or right + // edges. Here's what that looks like: + // + // < left_zero_count > < center_copy_count > < right_zero_count > + // +------------------+---------------------+--------------------+ + // | (filter) | (image) | (filter) | + // +------------------+---------------------+--------------------+ + // in_x_origin 0 input_width in_x_end + // + // In reality it's unlikely that a filter patch will be wider + // than an input, but this shows all the edge cases. + // We use std::fill() to set the left and right sections to zeroes + // and std::copy() to copy over the input data for the center. + const int in_x_end = in_x_origin + filter_width; + const int left_zero_count = std::max(0, 0 - in_x_origin); + const int right_zero_count = std::max(0, in_x_end - input_width); + const int center_copy_count = + filter_width - (left_zero_count + right_zero_count); + if (left_zero_count > 0) { + T1 *im2col_left_start = im2col_row_start; + T1 *im2col_left_end = + im2col_left_start + (left_zero_count * input_depth); + std::fill(im2col_left_start, im2col_left_end, T1(0)); + } + if (center_copy_count > 0) { + const T1 *input_row_start = + input_batch_start + (in_y * input_width * input_depth) + + (std::max(0, in_x_origin) * input_depth); + const T1 *input_row_end = + input_row_start + (center_copy_count * input_depth); + T1 *im2col_center_start = + im2col_row_start + (left_zero_count * input_depth); + std::copy(input_row_start, input_row_end, im2col_center_start); + } + if (right_zero_count > 0) { + T1 *im2col_right_start = + im2col_row_start + + ((left_zero_count + center_copy_count) * input_depth); + T1 *im2col_right_end = + im2col_right_start + (right_zero_count * input_depth); + std::fill(im2col_right_start, im2col_right_end, T1(0)); + } + } + } + } + // Now we've assembled a set of image patches into a matrix, apply a + // GEMM matrix multiply of the patches as rows, times the filter + // weights in columns, to get partial results in the output matrix. + const int how_many_patches = patch_index_end - patch_index_start; + const int m = how_many_patches; + const int n = filter_count; + const int k = filter_value_count; + const int lda = filter_value_count; + const int ldb = filter_count; + const int ldc = filter_count; + T3 *chunk_output_data = output_data + (patch_index_start * filter_count); + TGemmFunctor gemm_functor; + gemm_functor(m, n, k, im2col_buffer.get(), lda, filter_data, ldb, + chunk_output_data, ldc); + } + } +}; + +} // namespace + +template struct Conv2DUsingGemmOp { + Conv2DUsingGemmOp() : data_format_("") {} + std::vector strides_; + Padding padding_; + std::string data_format_; +}; + +template +void *Conv2DUsingGemmOp_Create(TF_OpKernelConstruction *ctx) { + auto kernel = new Conv2DUsingGemmOp(); + + StatusSafePtr status(TF_NewStatus()); + int32_t list_size = 0; + int32_t total_size = 0; + + // Get strides + TF_OpKernelConstruction_GetAttrSize(ctx, "strides", &list_size, &total_size, + status.get()); + CHECK_CONSTRUCT_STATUS(ctx, status.get()); + kernel->strides_.resize(list_size); + TF_OpKernelConstruction_GetAttrInt32List( + ctx, "strides", kernel->strides_.data(), list_size, status.get()); + CHECK_CONSTRUCT_STATUS(ctx, status.get()); + + // Get data_format + TF_OpKernelConstruction_GetAttrSize(ctx, "data_format", &list_size, + &total_size, status.get()); + CHECK_CONSTRUCT_STATUS(ctx, status.get()); + std::vector format_vec(total_size); + TF_OpKernelConstruction_GetAttrString(ctx, "data_format", format_vec.data(), + total_size, status.get()); + CHECK_CONSTRUCT_STATUS(ctx, status.get()); + kernel->data_format_ = std::move(std::string(format_vec.data(), total_size)); + + // Get padding + TF_OpKernelConstruction_GetAttrSize(ctx, "padding", &list_size, &total_size, + status.get()); + CHECK_CONSTRUCT_STATUS(ctx, status.get()); + std::vector padding_vec(total_size); + TF_OpKernelConstruction_GetAttrString(ctx, "padding", padding_vec.data(), + total_size, status.get()); + CHECK_CONSTRUCT_STATUS(ctx, status.get()); + std::string padding_str(padding_vec.data(), total_size); + if (padding_str == "VALID") { + kernel->padding_ = Padding::VALID; + } else if (padding_str == "SAME") { + kernel->padding_ = Padding::SAME; + } else { + std::cerr << "Unsupported padding type: " << padding_str; + return nullptr; + } + return kernel; +} + +template void Conv2DUsingGemmOp_Delete(void *kernel) { + if (kernel != nullptr) { + delete static_cast *>(kernel); + } +} + +template +void Conv2DUsingGemmOp_Compute(void *kernel, TF_OpKernelContext *ctx) { + StatusSafePtr status(TF_NewStatus()); + // Input tensor is of the following dimensions: + // [ batch, in_rows, in_cols, in_depth ] + TF_Tensor *input = nullptr; + TF_GetInput(ctx, 0, &input, status.get()); + CHECK_CTX_STATUS(ctx, status.get()); + TensorSafePtr input_safe_ptr(input); + + // Input filter is of the following dimensions: + // [ filter_rows, filter_cols, in_depth, out_depth] + TF_Tensor *filter = nullptr; + TF_GetInput(ctx, 1, &filter, status.get()); + CHECK_CTX_STATUS(ctx, status.get()); + TensorSafePtr filter_safe_ptr(filter); + + if (TF_NumDims(input) != 4) { + std::cerr << "input must be 4 dimensional" << std::endl; + return; + } + if (TF_NumDims(filter) != 4) { + std::cerr << "filter must be 4 dimensional" << std::endl; + return; + } + + for (int i = 0; i < 3; i++) { + if (TF_Dim(filter, i) >= std::numeric_limits::max()) { + std::cerr << "filter too large" << std::endl; + return; + } + } + + // The last dimension for input is in_depth. It must be the same as the + // filter's in_depth. + const int64_t in_depth = GetTensorDim( + input, static_cast *>(kernel)->data_format_, 'C'); + if (in_depth != TF_Dim(filter, 2)) { + std::cerr << "input and filter must have the same depth" << std::endl; + return; + } + + // The last dimension for filter is out_depth. + const int out_depth = static_cast(TF_Dim(filter, 3)); + + // The second dimension for input is rows/height. + // The first dimension for filter is rows/height. + const int64_t input_rows_raw = GetTensorDim( + input, static_cast *>(kernel)->data_format_, 'H'); + if (input_rows_raw >= std::numeric_limits::max()) { + std::cerr << "Input rows too large"; + return; + } + const int input_rows = static_cast(input_rows_raw); + const int filter_rows = static_cast(TF_Dim(filter, 0)); + + // The third dimension for input is columns/width. + // The second dimension for filter is columns/width. + const int64_t input_cols_raw = GetTensorDim( + input, static_cast *>(kernel)->data_format_, 'W'); + if (input_cols_raw >= std::numeric_limits::max()) { + std::cerr << "Input cols too large" << std::endl; + return; + } + const int input_cols = static_cast(input_cols_raw); + const int filter_cols = static_cast(TF_Dim(filter, 1)); + + // The first dimension for input is batch. + const int64_t batch_raw = GetTensorDim( + input, static_cast *>(kernel)->data_format_, 'N'); + if (batch_raw >= std::numeric_limits::max()) { + std::cerr << "batch is too large" << std::endl; + return; + } + const int batch = static_cast(batch_raw); + + // For now we take the stride from the second and third dimensions only (we + // do not support striding on the batch or depth dimension). + int stride_rows = 0; + int stride_cols = 0; + if (static_cast *>(kernel)->data_format_ == "NCHW") { + stride_rows = static_cast *>(kernel)->strides_[2]; + stride_cols = static_cast *>(kernel)->strides_[3]; + } else if (static_cast *>(kernel)->data_format_ == + "NHWC") { + stride_rows = static_cast *>(kernel)->strides_[1]; + stride_cols = static_cast *>(kernel)->strides_[2]; + } else { + std::cerr << "Unsupported data format" << std::endl; + return; + } + + int64_t out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; + if (!GetWindowedOutputSize( + input_rows, filter_rows, 1, stride_rows, + static_cast *>(kernel)->padding_, &out_rows, + &pad_rows)) { + std::cerr << "Invalid filter size" << std::endl; + return; + } + + if (!GetWindowedOutputSize( + input_cols, filter_cols, 1, stride_cols, + static_cast *>(kernel)->padding_, &out_cols, + &pad_cols)) { + std::cerr << "Invalid filter size" << std::endl; + return; + } + auto output_size = batch * out_rows * out_cols * out_depth; + std::vector out_shape; + out_shape.push_back(batch); + if (static_cast *>(kernel)->data_format_ == "NCHW") { + out_shape.push_back(out_depth); + out_shape.push_back(out_rows); + out_shape.push_back(out_cols); + } else if (static_cast *>(kernel)->data_format_ == + "NHWC") { + out_shape.push_back(out_rows); + out_shape.push_back(out_cols); + out_shape.push_back(out_depth); + } else { + std::cerr << "Unsupported data_foramt" << std::endl; + return; + } + + // Output tensor is of the following dimensions: + // [ in_batch, out_rows, out_cols, out_depth ]`` + TensorSafePtr output_safe_ptr(TF_AllocateOutput( + ctx, 0, TF_ExpectedOutputDataType(ctx, 0), out_shape.data(), + out_shape.size(), sizeof(T) * output_size, status.get())); + + // If there is nothing to compute, return. + if (output_size == 0) { + return; + } + TConvFunctor conv_functor; + conv_functor(static_cast(TF_TensorData(input_safe_ptr.get())), batch, + input_rows, input_cols, in_depth, + static_cast(TF_TensorData(filter_safe_ptr.get())), + filter_rows, filter_cols, out_depth, stride_rows, stride_cols, + static_cast *>(kernel)->padding_, + static_cast(TF_TensorData(output_safe_ptr.get())), out_rows, + out_cols); +}; +template void RegisterConvOpKernel(const char *device_type) { + StatusSafePtr status(TF_NewStatus()); + auto *builder = TF_NewKernelBuilder( + "Conv2D", device_type, &Conv2DUsingGemmOp_Create, + &Conv2DUsingGemmOp_Compute< + T, Im2ColConvFunctor>>, + &Conv2DUsingGemmOp_Delete); + TF_KernelBuilder_TypeConstraint(builder, "T", TypeToEnum::v(), + status.get()); + if (TF_OK != TF_GetCode(status.get())) + std::cout << " Error while registering relu kernel with attribute T"; + TF_RegisterKernelBuilder("Conv2DOp", builder, status.get()); + if (TF_OK != TF_GetCode(status.get())) + std::cout << " Error while registering relu kernel"; +} + +} // namespace demo_plugin + +void RegisterDeviceConv2D(const char *device_type) { + demo_plugin::RegisterConvOpKernel(device_type); +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/cpu_kernel_init.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/cpu_kernel_init.cc new file mode 100644 index 000000000..d5447806e --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/cpu_kernel_init.cc @@ -0,0 +1,22 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#include "tensorflow_plugin/src/kernels/cpu/cpu_kernel_init.h" +#include "tensorflow/c/kernels.h" + +void RegisterDeviceKernels(const char* device_type) { + RegisterDeviceRelu(device_type); + RegisterDeviceConv2D(device_type); +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/cpu_kernel_init.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/cpu_kernel_init.h new file mode 100644 index 000000000..30a02f7c1 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/cpu_kernel_init.h @@ -0,0 +1,27 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_KERNELS_CPU_KERNEL_INIT_H_ +#define TENSORFLOW_PLUGIN_SRC_KERNELS_CPU_KERNEL_INIT_H_ + +#include + +#include "tensorflow_plugin/src/device/cpu/cpu_device_plugin.h" + +void RegisterDeviceRelu(const char* device_type); +void RegisterDeviceConv2D(const char* device_type); + +void RegisterDeviceKernels(const char* device_type); +#endif // TENSORFLOW_PLUGIN_SRC_KERNELS_GPU_KERNEL_INIT_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/gemm_functors.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/gemm_functors.h new file mode 100644 index 000000000..e2d75618b --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/gemm_functors.h @@ -0,0 +1,60 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_KERNEL_CPU_H_ +#define TENSORFLOW_PLUGIN_SRC_KERNEL_CPU_H_ + +//#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +// FixedPoint header must be included after Tensor. +// clang-format off +#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" +// clang-format on + +template class FastGemmFunctor { +public: + void operator()(size_t m, size_t n, size_t k, const T1 *a, size_t lda, + const T2 *b, size_t ldb, T3 *c, size_t ldc) { + Eigen::array dim_a = {{m, k}}; + Eigen::array dim_b = {{k, n}}; + Eigen::array dim_c = {{m, n}}; + Eigen::TensorMap< + Eigen::Tensor, + Eigen::Aligned> + a_matrix(a, dim_a); + Eigen::TensorMap< + Eigen::Tensor, + Eigen::Aligned> + b_matrix(b, dim_b); + Eigen::TensorMap, + Eigen::Aligned> + c_matrix(c, dim_c); + + Eigen::array, 1> dim_pair; + dim_pair[0].first = 1; + dim_pair[0].second = 0; + Eigen::ThreadPool tp(8); + Eigen::ThreadPoolDevice thread_pool_device(&tp, 8); + c_matrix.device(thread_pool_device) = a_matrix.contract(b_matrix, dim_pair); + } +}; + +#endif // TENSORFLOW_PLUGIN_SRC_KERNEL_CPU_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/relu_op.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/relu_op.cc new file mode 100644 index 000000000..41f64bd25 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/kernels/cpu/relu_op.cc @@ -0,0 +1,107 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#include "absl/container/inlined_vector.h" +#include "tensorflow/c/kernels.h" +#include "tensorflow/c/ops.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" + +#include +#include +#include + +namespace demo_plugin { + +struct StatusDeleter { + void operator()(TF_Status* s) { + if (s != nullptr) { + TF_DeleteStatus(s); + } + } +}; + +struct TensorDeleter { + void operator()(TF_Tensor* t) { + if (t != nullptr) { + TF_DeleteTensor(t); + } + } +}; + +using StatusSafePtr = std::unique_ptr; +using TensorSafePtr = std::unique_ptr; + +struct ReluOp { + float alpha_; + float beta_; +}; + +void* ReluOp_Create(ReluOp* kernel, float alpha, float beta) { + kernel->alpha_ = alpha; + kernel->beta_ = beta; + return kernel; +} + +template +void ReluOp_Compute(void* kernel, TF_OpKernelContext* ctx) { + ReluOp* relu = static_cast(kernel); + StatusSafePtr status(TF_NewStatus()); + TF_Tensor* input = nullptr; + TF_GetInput(ctx, 0, &input, status.get()); + TensorSafePtr input_safe_ptr(input); + if (TF_GetCode(status.get()) != TF_OK) { + TF_OpKernelContext_Failure(ctx, status.get()); + return; + } + if (TF_TensorElementCount(input_safe_ptr.get()) == 0) return; + absl::InlinedVector dims(TF_NumDims(input_safe_ptr.get())); + for (auto i = 0; i < TF_NumDims(input_safe_ptr.get()); ++i) { + dims[i] = TF_Dim(input_safe_ptr.get(), i); + } + + TensorSafePtr output_safe_ptr(TF_AllocateOutput( + ctx, 0, TF_ExpectedOutputDataType(ctx, 0), dims.data(), dims.size(), + TF_TensorElementCount(input_safe_ptr.get()) * sizeof(T), status.get())); + if (TF_GetCode(status.get()) != TF_OK) { + TF_OpKernelContext_Failure(ctx, status.get()); + return; + } + + auto input_ptr = static_cast(TF_TensorData(input_safe_ptr.get())); + auto output_ptr = static_cast(TF_TensorData(output_safe_ptr.get())); + for (auto i = 0; i < TF_TensorElementCount(input_safe_ptr.get()); ++i) { + output_ptr[i] = input_ptr[i] > 0 ? input_ptr[i] : 0; + } +} + +template +void RegisterReluOpKernel(const char* device_type) { + StatusSafePtr status(TF_NewStatus()); + auto* builder = TF_NewKernelBuilder("Relu", device_type, nullptr, + &ReluOp_Compute, nullptr); + TF_KernelBuilder_TypeConstraint(builder, "T", TF_FLOAT, status.get()); + if (TF_OK != TF_GetCode(status.get())) + std::cout << " Error while registering relu kernel with attribute T"; + TF_RegisterKernelBuilder("ReluOp", builder, status.get()); + if (TF_OK != TF_GetCode(status.get())) + std::cout << " Error while registering relu kernel"; +} + +} // namespace demo_plugin + +void RegisterDeviceRelu(const char* device_type) { + demo_plugin::RegisterReluOpKernel(device_type); +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_device.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_device.cc new file mode 100644 index 000000000..bd9efe672 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_device.cc @@ -0,0 +1,29 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" +#include "tensorflow_plugin/src/device/cpu/cpu_device_plugin.h" + +#include "plugin_device.h" + +void SE_InitPlugin(SE_PlatformRegistrationParams* const params, + TF_Status* const status) { + params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE; + params->platform->name = DEVICE_NAME; + params->platform->type = DEVICE_TYPE; + SE_InitPluginFns(params, status); +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_device.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_device.h new file mode 100644 index 000000000..b9bb0f4d9 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_device.h @@ -0,0 +1,23 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_PLUGIN_DEVICE_H_ +#define TENSORFLOW_PLUGIN_SRC_PLUGIN_DEVICE_H_ + +constexpr char DEVICE_TYPE[] = "MY_DEVICE"; +constexpr char DEVICE_NAME[] = "FAKE_CPU_DEVICE"; + +#endif // TENSORFLOW_PLUGIN_SRC_PLUGIN_DEVICE_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_graph.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_graph.cc new file mode 100644 index 000000000..49c3ee6fa --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_graph.cc @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#include "tensorflow/c/experimental/grappler/grappler.h" +#include "tensorflow_plugin/src/graph/plugin_optimizer.h" + +void TF_InitGraphPlugin(TP_OptimizerRegistrationParams* params, + TF_Status* status) { + params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE; + params->optimizer_configs->struct_size = TP_OPTIMIZER_CONFIGS_STRUCT_SIZE; + params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE; + + // Define some configs to turn off existing optimizers. + params->optimizer_configs->remapping = TF_TriState_Off; + params->optimizer_configs->layout_optimizer = TF_TriState_Off; + + // Set functions to create a new optimizer. + params->device_type = "MY_DEVICE"; + params->optimizer->create_func = (demo_plugin::graph::Optimizer_Create); + params->optimizer->optimize_func = (demo_plugin::graph::Optimizer_Optimize); + params->optimizer->destroy_func = (demo_plugin::graph::Optimizer_Destroy); +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_kernel.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_kernel.cc new file mode 100644 index 000000000..3ca6eaa37 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_kernel.cc @@ -0,0 +1,23 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#include +#include "tensorflow/c/kernels.h" +#include "tensorflow_plugin/src/kernels/cpu/cpu_kernel_init.h" + +#include "plugin_device.h" + +void TF_InitKernel() { RegisterDeviceKernels(DEVICE_TYPE); } diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_kernel.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_kernel.h new file mode 100644 index 000000000..aacccc633 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_kernel.h @@ -0,0 +1,22 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_PLUGIN_KERNEL_H_ +#define TENSORFLOW_PLUGIN_SRC_PLUGIN_KERNEL_H_ + +void TF_InitKernel(); + +#endif // TENSORFLOW_PLUGIN_SRC_PLUGIN_KERNEL_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_profiler.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_profiler.cc new file mode 100644 index 000000000..b5e925f3c --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/plugin_profiler.cc @@ -0,0 +1,27 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "plugin_device.h" +#include "tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.h" +#include "tensorflow_plugin/src/profiler/cpu/demo_profiler.h" + +void TF_InitProfiler(TF_ProfilerRegistrationParams *params, TF_Status *status) { + params->struct_size = TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE; + params->struct_size = TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE; + params->profiler_fns->struct_size = TP_PROFILER_FNS_STRUCT_SIZE; + params->profiler->type = + DEVICE_TYPE; // type is device type, such as GPU, APU.. + TF_InitPluginProfilerFns(params, status); +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/BUILD new file mode 100644 index 000000000..aaef631d1 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/BUILD @@ -0,0 +1,16 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "demo_profiler", + srcs = ["demo_profiler.cc"], + hdrs = ["demo_profiler.h"], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + "@local_config_tf//:tf_header_lib", + "//tensorflow_plugin/src/utils:protos_all", + "//tensorflow_plugin/src/profiler/cpu/utils:xplane_utils", + ], + alwayslink = True, +) + diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/demo_profiler.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/demo_profiler.cc new file mode 100644 index 000000000..c2d46851a --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/demo_profiler.cc @@ -0,0 +1,67 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_plugin/src/profiler/cpu/demo_profiler.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/time_utils.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/xplane_utils.h" +static void NormalizeTimeStamps(demo_plugin::profiler::XPlaneBuilder* plane, + uint64_t start_walltime_ns) { + plane->ForEachLine([&](demo_plugin::profiler::XLineBuilder line) { + line.SetTimestampNs(start_walltime_ns); + }); +} + +uint64_t start_walltime; + + +void plugin_start(const TP_Profiler* profiler, TF_Status* status) { + start_walltime = demo_plugin::profiler::GetCurrentTimeNanos(); +} + +void plugin_stop(const TP_Profiler* profiler, TF_Status* status) { +} + + + +void plugin_collect_data_xspace(const TP_Profiler* profiler, uint8_t* buffer, + size_t* size_in_bytes, TF_Status* status) { + demo_plugin::profiler::PerDeviceCollector collector(0, start_walltime); + demo_plugin::profiler::XSpace space; + std::string name = "/device:GPU:0"; + demo_plugin::profiler::XPlaneBuilder device_plane(demo_plugin::profiler::FindOrAddMutablePlaneWithName(&space, name)); + device_plane.SetId(0); + collector.Flush(&device_plane); + NormalizeTimeStamps(&device_plane, start_walltime); + + *size_in_bytes = space.ByteSizeLong(); + if (buffer == nullptr) { + return; + } + space.SerializeToArray(buffer, space.ByteSizeLong()); +} + +void plugin_destroy_profiler(TP_Profiler* profiler) {} + +void plugin_destroy_profiler_fns(TP_ProfilerFns* profiler_fns) {} + + +void TF_InitPluginProfilerFns(TF_ProfilerRegistrationParams* params, TF_Status* status) { + params->profiler_fns->start = plugin_start; + params->profiler_fns->stop = plugin_stop; + params->profiler_fns->collect_data_xspace = plugin_collect_data_xspace; + params->destroy_profiler = plugin_destroy_profiler; + params->destroy_profiler_fns = plugin_destroy_profiler_fns; +} + diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/demo_profiler.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/demo_profiler.h new file mode 100644 index 000000000..b18143d6d --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/demo_profiler.h @@ -0,0 +1,67 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_CPU_PROFILER_PLUGIN_H_ +#define TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_CPU_PROFILER_PLUGIN_H_ +#include "tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/xplane_builder.h" +#include "tensorflow_plugin/src/utils/xplane.pb.h" +namespace demo_plugin { +namespace profiler { +class PerDeviceCollector { + public: + PerDeviceCollector(int device_id, uint64_t start_walltime_ns) + : device_id_(device_id),start_walltime_ns_(start_walltime_ns) {} + + void CreateXEvent(XPlaneBuilder* plane, XLineBuilder* line) { + // Just provide a dummy case here, plugin authors need to get kernel + // execution profing data from their own device runtime. + std::string kernel_name = "DummyKernel"; + XEventMetadata* event_metadata = plane->GetOrCreateEventMetadata(std::move(kernel_name)); + XEventBuilder xevent = line->AddEvent(*event_metadata); + xevent.SetTimestampNs(10000 + start_walltime_ns_); + xevent.SetEndTimestampNs(1000000 + start_walltime_ns_); + + xevent.AddStatValue( + *plane->GetOrCreateStatMetadata(std::string("SIMD width")),8); + } + + void Flush(XPlaneBuilder* device_plane) { + int64_t line_id = 0; + XLineBuilder line = device_plane->GetOrCreateLine(line_id); + line.SetTimestampNs(start_walltime_ns_); + CreateXEvent(device_plane, &line); + + device_plane->ForEachLine([&](XLineBuilder line) { + line.SetName("PluginDevice stream"); + }); + } + + private: + int device_id_; + uint64_t start_walltime_ns_; +}; + +} // namespace profiler +} // namespace demo_plugin +void TF_InitPluginProfilerFns(TF_ProfilerRegistrationParams* params, TF_Status* status); + + + +#endif // TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_CPU_PROFILER_PLUGIN_H_ + + diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/BUILD new file mode 100644 index 000000000..224148ae8 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/BUILD @@ -0,0 +1,96 @@ +package(default_visibility = ["//visibility:public"]) + + + +cc_library( + name = "time_utils_impl", + srcs = [ + "time_utils.cc", + "time_utils.h", + ], + deps = [ + "@com_google_absl//absl/time", + "//tensorflow_plugin/src/utils:logging", + "//third_party/eigen3", + ], + alwayslink = True, +) + +cc_library( + name = "xplane_builder", + srcs = ["xplane_builder.cc"], + hdrs = ["xplane_builder.h"], + visibility = ["//visibility:public"], + deps = [ + ":time_utils", + ":timespan", + "//tensorflow_plugin/src/utils:protos_all", + "//tensorflow_plugin/src/utils:logging", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + + +cc_library( + name = "xplane_visitor", + srcs = ["xplane_visitor.cc"], + hdrs = ["xplane_visitor.h"], + visibility = ["//visibility:public"], + deps = [ + ":time_utils", + ":timespan", + "//tensorflow_plugin/src/utils:protos_all", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "time_utils", + hdrs = ["time_utils.h"], + deps = [ + ":time_utils_impl", + ], +) + + +cc_library( + name = "trace_utils", + hdrs = ["trace_utils.h"], + deps = [ + ], +) + + +cc_library( + name = "timespan", + hdrs = ["timespan.h"], + deps = [ + ":time_utils", + "@com_google_absl//absl/strings", + ], +) + + +cc_library( + name = "xplane_utils", + srcs = ["xplane_utils.cc"], + hdrs = ["xplane_utils.h"], + visibility = ["//visibility:public"], + deps = [ + ":time_utils", + ":timespan", + ":trace_utils", + ":xplane_builder", + ":xplane_visitor", + "//tensorflow_plugin/src/utils:protos_all", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "//third_party/eigen3", + ], +) + diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/time_utils.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/time_utils.cc new file mode 100644 index 000000000..a7174a924 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/time_utils.cc @@ -0,0 +1,26 @@ +#include "tensorflow_plugin/src/profiler/cpu/utils/time_utils.h" + +#include "absl/time/clock.h" +#include "absl/time/time.h" + +namespace demo_plugin { +namespace profiler { + +int64_t GetCurrentTimeNanos() { + // absl::GetCurrentTimeNanos() is much faster than EnvTime::NowNanos(). + // It is wrapped under tensorflow::profiler::GetCurrentTimeNanos to avoid ODR + // violation and to allow switching to yet another implementation if required. + return absl::GetCurrentTimeNanos(); +} + +void SleepForNanos(int64_t ns) { absl::SleepFor(absl::Nanoseconds(ns)); } + +void SpinForNanos(int64_t ns) { + if (ns <= 0) return; + int64_t deadline = GetCurrentTimeNanos() + ns; + while (GetCurrentTimeNanos() < deadline) { + } +} + +} // namespace profiler +} // namespace demo_plugin diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/time_utils.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/time_utils.h new file mode 100644 index 000000000..cd35f6642 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/time_utils.h @@ -0,0 +1,57 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_TIME_UTILS_H_ +#define TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_TIME_UTILS_H_ + +#include "tensorflow_plugin/src/utils/types.h" + +namespace demo_plugin { +namespace profiler { + +// Converts among different time units. +// NOTE: We use uint64 for picoseconds and nanoseconds, which are used in +// storage, and double for other units that are used in the UI. +inline double PicosToNanos(uint64 ps) { return ps / 1E3; } +inline double PicosToMicros(uint64 ps) { return ps / 1E6; } +inline double PicosToMillis(uint64 ps) { return ps / 1E9; } +inline double PicosToSeconds(uint64 ps) { return ps / 1E12; } +inline uint64 NanosToPicos(uint64 ns) { return ns * 1000; } +inline double NanosToMicros(uint64 ns) { return ns / 1E3; } +inline double MicrosToNanos(double us) { return us * 1E3; } +inline double MicrosToMillis(double us) { return us / 1E3; } +inline uint64 MillisToPicos(double ms) { return ms * 1E9; } +inline uint64 MillisToNanos(double ms) { return ms * 1E6; } +inline double MillisToSeconds(double ms) { return ms / 1E3; } +inline uint64 SecondsToNanos(double s) { return s * 1E9; } + +// Returns the current CPU wallclock time in nanoseconds. +int64_t GetCurrentTimeNanos(); + +// Sleeps for the specified duration. +void SleepForNanos(int64_t ns); +inline void SleepForMicros(int64_t us) { SleepForNanos(us * 1000); } +inline void SleepForMillis(int64_t ms) { SleepForNanos(ms * 1000000); } +inline void SleepForSeconds(int64_t s) { SleepForNanos(s * 1000000000); } + +// Spins to simulate doing some work instead of sleeping, because sleep +// precision is poor. For testing only. +void SpinForNanos(int64_t ns); +inline void SpinForMicros(int64_t us) { SpinForNanos(us * 1000); } + +} // namespace profiler +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_TIME_UTILS_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/timespan.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/timespan.h new file mode 100644 index 000000000..cbda11f56 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/timespan.h @@ -0,0 +1,117 @@ +#ifndef TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_TIMESPAN_H_ +#define TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_TIMESPAN_H_ + +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow_plugin/src/utils/logging.h" +#include "tensorflow_plugin/src/utils/types.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/time_utils.h" + +namespace demo_plugin { +namespace profiler { + +// A Timespan is the time extent of an event: a pair of (begin, duration). +// Events may have duration 0 ("instant events") but duration can't be negative. +class Timespan { + public: + static Timespan FromEndPoints(uint64 begin_ps, uint64 end_ps) { + DCHECK_LE(begin_ps, end_ps); + return Timespan(begin_ps, end_ps - begin_ps); + } + + explicit Timespan(uint64 begin_ps = 0, uint64 duration_ps = 0) + : begin_ps_(begin_ps), duration_ps_(duration_ps) {} + + uint64 begin_ps() const { return begin_ps_; } + uint64 middle_ps() const { return begin_ps_ + duration_ps_ / 2; } + uint64 end_ps() const { return begin_ps_ + duration_ps_; } + uint64 duration_ps() const { return duration_ps_; } + + // Returns true if the Timespan represents an instant in time (duration 0). + bool Instant() const { return duration_ps() == 0; } + + // Returns true if this is an empty timespan. + bool Empty() const { return begin_ps() == 0 && duration_ps() == 0; } + + // Note for Overlaps() and Includes(Timespan& other) below: + // We have a design choice whether the end-point comparison should be + // inclusive or exclusive. We decide to go for inclusive. The implication + // is that an instant timespan could belong to two consecutive intervals + // (e.g., Timespan(12, 0) will be included in both Timespan(11, 1) and + // Timespan(12, 1)). We think this is okay because the common scenario + // would be that we search for the interval that includes a point + // in time from left to right, and return the first interval found. + + // Returns true if the Timespan overlaps with other. + bool Overlaps(const Timespan& other) const { + return begin_ps() <= other.end_ps() && other.begin_ps() <= end_ps(); + } + + // Returns true if this Timespan includes the other. + bool Includes(const Timespan& other) const { + return begin_ps() <= other.begin_ps() && other.end_ps() <= end_ps(); + } + + // Returns true if time_ps is within this Timespan. + bool Includes(uint64 time_ps) const { return Includes(Timespan(time_ps)); } + + // Returns the duration in ps that this Timespan overlaps with the other. + uint64 OverlappedDurationPs(const Timespan& other) const { + if (!Overlaps(other)) return 0; + return std::min(end_ps(), other.end_ps()) - + std::max(begin_ps(), other.begin_ps()); + } + + // Expands the timespan to include other. + void ExpandToInclude(const Timespan& other) { + *this = FromEndPoints(std::min(begin_ps(), other.begin_ps()), + std::max(end_ps(), other.end_ps())); + } + + // Compares timespans by their begin time (ascending), duration (descending) + // so nested spans are sorted from outer to innermost. + bool operator<(const Timespan& other) const { + if (begin_ps_ < other.begin_ps_) return true; + if (begin_ps_ > other.begin_ps_) return false; + return duration_ps_ > other.duration_ps_; + } + + // Returns true if this timespan is equal to the given timespan. + bool operator==(const Timespan& other) const { + return begin_ps_ == other.begin_ps_ && duration_ps_ == other.duration_ps_; + } + + // Returns a string that shows the begin and end times. + std::string DebugString() const { + return absl::StrCat("[", begin_ps(), ", ", end_ps(), "]"); + } + + // Compares timespans by their duration_ps (ascending), begin time + // (ascending). + static bool ByDuration(const Timespan& a, const Timespan& b) { + if (a.duration_ps_ < b.duration_ps_) return true; + if (a.duration_ps_ > b.duration_ps_) return false; + return a.begin_ps_ < b.begin_ps_; + } + + private: + uint64 begin_ps_; + uint64 duration_ps_; // 0 for an instant event. +}; + +// Creates a Timespan from endpoints in picoseconds. +inline Timespan PicoSpan(uint64 start_ps, uint64 end_ps) { + return Timespan::FromEndPoints(start_ps, end_ps); +} + +// Creates a Timespan from endpoints in milliseconds. +inline Timespan MilliSpan(double start_ms, double end_ms) { + return PicoSpan(MillisToPicos(start_ms), MillisToPicos(end_ms)); +} + +} // namespace profiler +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_TIMESPAN_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/trace_utils.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/trace_utils.h new file mode 100644 index 000000000..d8c8a453a --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/trace_utils.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_TRACE_UTILS_H_ +#define TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_TRACE_UTILS_H_ + +#include "tensorflow_plugin/src/utils/integral_types.h" +namespace demo_plugin { +namespace profiler { + +// Constants used as trace_viewer PID (device_id in trace_events.proto). +// PID 0 is unused. +// Support up to 500 accelerator devices. +constexpr uint32 kFirstDeviceId = 1; +constexpr uint32 kLastDeviceId = 500; +// Host threads are shown as a single fake device. +constexpr uint32 kHostThreadsDeviceId = kLastDeviceId + 1; + +// Constants used as trace_viewer TID (resource_id in trace_events.proto). +constexpr int kThreadIdDerivedMin = 0xdeadbeef; +constexpr int kThreadIdStepInfo = kThreadIdDerivedMin; +constexpr int kThreadIdKernelLaunch = kThreadIdDerivedMin + 1; +constexpr int kThreadIdTfNameScope = kThreadIdDerivedMin + 2; +constexpr int kThreadIdTfOp = kThreadIdDerivedMin + 3; +constexpr int kThreadIdHloModule = kThreadIdDerivedMin + 4; +constexpr int kThreadIdHloOp = kThreadIdDerivedMin + 5; +constexpr int kThreadIdOverhead = kThreadIdDerivedMin + 6; +constexpr int kThreadIdSource = kThreadIdDerivedMin + 7; +constexpr int kThreadIdDerivedMax = kThreadIdSource; + +static inline bool IsDerivedThreadId(int thread_id) { + return thread_id >= kThreadIdDerivedMin && thread_id <= kThreadIdDerivedMax; +} + +} // namespace profiler +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_TRACE_UTILS_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_builder.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_builder.cc new file mode 100644 index 000000000..e59f8148b --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_builder.cc @@ -0,0 +1,127 @@ +#include "tensorflow_plugin/src/profiler/cpu/utils/xplane_builder.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow_plugin/src/utils/types.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/time_utils.h" +#include "tensorflow_plugin/src/utils/xplane.pb.h" +namespace demo_plugin { +namespace profiler { + +XPlaneBuilder::XPlaneBuilder(XPlane* plane) + : XStatsBuilder(plane, this), plane_(plane) { + for (auto& id_and_metadata : *plane->mutable_event_metadata()) { + auto& metadata = id_and_metadata.second; + last_event_metadata_id_ = + std::max(last_event_metadata_id_, metadata.id()); + if (!metadata.name().empty()) { + event_metadata_by_name_.try_emplace(metadata.name(), &metadata); + } + } + for (auto& id_and_metadata : *plane->mutable_stat_metadata()) { + auto& metadata = id_and_metadata.second; + last_stat_metadata_id_ = + std::max(last_stat_metadata_id_, metadata.id()); + if (!metadata.name().empty()) { + stat_metadata_by_name_.try_emplace(metadata.name(), &metadata); + } + } + for (XLine& line : *plane->mutable_lines()) { + lines_by_id_.try_emplace(line.id(), &line); + } +} + +XEventMetadata* XPlaneBuilder::GetOrCreateEventMetadata(int64 metadata_id) { + XEventMetadata& metadata = (*plane_->mutable_event_metadata())[metadata_id]; + metadata.set_id(metadata_id); + return &metadata; +} + +XEventMetadata* XPlaneBuilder::CreateEventMetadata() { + return GetOrCreateEventMetadata(++last_event_metadata_id_); +} + +XEventMetadata* XPlaneBuilder::GetOrCreateEventMetadata( + absl::string_view name) { + XEventMetadata*& metadata = event_metadata_by_name_[name]; + if (metadata == nullptr) { + metadata = CreateEventMetadata(); + metadata->set_name(std::string(name)); + } + return metadata; +} + +XEventMetadata* XPlaneBuilder::GetOrCreateEventMetadata(std::string&& name) { + XEventMetadata*& metadata = event_metadata_by_name_[name]; + if (metadata == nullptr) { + metadata = CreateEventMetadata(); + metadata->set_name(std::move(name)); + } + return metadata; +} + +XStatMetadata* XPlaneBuilder::GetOrCreateStatMetadata(int64 metadata_id) { + XStatMetadata& metadata = (*plane_->mutable_stat_metadata())[metadata_id]; + metadata.set_id(metadata_id); + return &metadata; +} + +XStatMetadata* XPlaneBuilder::CreateStatMetadata() { + return GetOrCreateStatMetadata(++last_stat_metadata_id_); +} + +XStatMetadata* XPlaneBuilder::GetOrCreateStatMetadata(absl::string_view name) { + XStatMetadata*& metadata = stat_metadata_by_name_[name]; + if (metadata == nullptr) { + metadata = CreateStatMetadata(); + metadata->set_name(std::string(name)); + } + return metadata; +} + +XStatMetadata* XPlaneBuilder::GetOrCreateStatMetadata(std::string&& name) { + XStatMetadata*& metadata = stat_metadata_by_name_[name]; + if (metadata == nullptr) { + metadata = CreateStatMetadata(); + metadata->set_name(std::move(name)); + } + return metadata; +} + +XLineBuilder XPlaneBuilder::GetOrCreateLine(int64 line_id) { + XLine*& line = lines_by_id_[line_id]; + if (line == nullptr) { + line = plane_->add_lines(); + line->set_id(line_id); + } + return XLineBuilder(line, this); +} + +XEventBuilder XLineBuilder::AddEvent(const XEventMetadata& metadata) { + XEvent* event = line_->add_events(); + event->set_metadata_id(metadata.id()); + return XEventBuilder(line_, plane_, event); +} + +XEventBuilder XLineBuilder::AddEvent(const XEvent& event) { + XEvent* new_event = line_->add_events(); + *new_event = event; + return XEventBuilder(line_, plane_, new_event); +} + +void XLineBuilder::SetTimestampNsAndAdjustEventOffsets(int64 timestamp_ns) { + int64 offset_ps = NanosToPicos(line_->timestamp_ns() - timestamp_ns); + line_->set_timestamp_ns(timestamp_ns); + if (offset_ps) { + for (auto& event : *line_->mutable_events()) { + event.set_offset_ps(event.offset_ps() + offset_ps); + } + } +} + +} // namespace profiler +} // namespace demo_plugin diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_builder.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_builder.h new file mode 100644 index 000000000..2a5b81d07 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_builder.h @@ -0,0 +1,339 @@ +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_XPLANE_BUILDER_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_XPLANE_BUILDER_H_ + +#include + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" +#include "tensorflow_plugin/src/utils/macros.h" +#include "tensorflow_plugin/src/utils/protobuf.h" +#include "tensorflow_plugin/src/utils/types.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/time_utils.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/timespan.h" +#include "tensorflow_plugin/src/utils/xplane.pb.h" + +namespace demo_plugin { +namespace profiler { + +class XPlaneBuilder; + +template +class XStatsBuilder { + public: + explicit XStatsBuilder(T* stats_owner, XPlaneBuilder* stats_metadata_owner) + : stats_owner_(stats_owner), + stats_metadata_owner_(stats_metadata_owner) {} + + void AddStatValue(const XStatMetadata& metadata, uint32 value) { + AddStat(metadata)->set_uint64_value(value); + } + void AddStatValue(const XStatMetadata& metadata, + unsigned long value) { // NOLINT + AddStat(metadata)->set_uint64_value(value); + } + void AddStatValue(const XStatMetadata& metadata, + unsigned long long value) { // NOLINT + AddStat(metadata)->set_uint64_value(value); + } + void AddStatValue(const XStatMetadata& metadata, int32 value) { + AddStat(metadata)->set_int64_value(value); + } + void AddStatValue(const XStatMetadata& metadata, long value) { // NOLINT + AddStat(metadata)->set_int64_value(value); + } + void AddStatValue(const XStatMetadata& metadata, long long value) { // NOLINT + AddStat(metadata)->set_int64_value(value); + } + void AddStatValue(const XStatMetadata& metadata, double value) { + AddStat(metadata)->set_double_value(value); + } + void AddStatValue(const XStatMetadata& metadata, absl::string_view value) { + AddStat(metadata)->set_str_value(std::string(value)); + } + void AddStatValue(const XStatMetadata& metadata, std::string&& value) { + AddStat(metadata)->set_str_value(std::move(value)); + } + void AddStatValue(const XStatMetadata& metadata, const XStatMetadata& value) { + AddStat(metadata)->set_ref_value(value.id()); + } + void AddStatValue(const XStatMetadata& metadata, + const protobuf::MessageLite& proto) { + auto* bytes = AddStat(metadata)->mutable_bytes_value(); + proto.SerializeToString(bytes); + } + + // Adds a stat by copying a stat from another XPlane. Does not check if a stat + // with the same metadata already exists in the event. To avoid duplicated + // stats, use the variant below. + void AddStat(const XStatMetadata& metadata, const XStat& src_stat, + const XPlane& src_plane) { + CopyStatValue(src_stat, src_plane, AddStat(metadata)); + } + // Same as above but overrides an existing stat with the same metadata. + void SetOrAddStat(const XStatMetadata& metadata, const XStat& src_stat, + const XPlane& src_plane) { + CopyStatValue(src_stat, src_plane, FindOrAddStat(metadata)); + } + + void ParseAndAddStatValue(const XStatMetadata& metadata, + absl::string_view value) { + int64 int_value; + uint64 uint_value; + double double_value; + if (absl::SimpleAtoi(value, &int_value)) { + AddStatValue(metadata, int_value); + } else if (absl::SimpleAtoi(value, &uint_value)) { + AddStatValue(metadata, uint_value); + } else if (absl::SimpleAtod(value, &double_value)) { + AddStatValue(metadata, double_value); + } else { + AddStatValue(metadata, GetOrCreateStatMetadata(value)); + } + } + + void ReserveStats(size_t num_stats) { + stats_owner_->mutable_stats()->Reserve(num_stats); + } + + private: + XStat* AddStat(const XStatMetadata& metadata) { + XStat* stat = stats_owner_->add_stats(); + stat->set_metadata_id(metadata.id()); + return stat; + } + + XStat* FindOrAddStat(const XStatMetadata& metadata) { + for (auto& stat : *stats_owner_->mutable_stats()) { + if (stat.metadata_id() == metadata.id()) { + return &stat; + } + } + return AddStat(metadata); + } + + void CopyStatValue(const XStat& src_stat, const XPlane& src_plane, + XStat* dst_stat) { + switch (src_stat.value_case()) { + case XStat::VALUE_NOT_SET: + break; + case XStat::kInt64Value: + dst_stat->set_int64_value(src_stat.int64_value()); + break; + case XStat::kUint64Value: + dst_stat->set_uint64_value(src_stat.uint64_value()); + break; + case XStat::kDoubleValue: + dst_stat->set_double_value(src_stat.double_value()); + break; + case XStat::kStrValue: + dst_stat->set_str_value(src_stat.str_value()); + break; + case XStat::kRefValue: { + const auto& stat_metadata_by_id = src_plane.stat_metadata(); + const auto it = stat_metadata_by_id.find(src_stat.ref_value()); + if (TF_PREDICT_TRUE(it != stat_metadata_by_id.end())) { + absl::string_view value = it->second.name(); + dst_stat->set_ref_value(GetOrCreateStatMetadata(value).id()); + } + break; + } + case XStat::kBytesValue: + dst_stat->set_bytes_value(src_stat.bytes_value()); + break; + } + } + + const XStatMetadata& GetOrCreateStatMetadata(absl::string_view value); + + T* stats_owner_; + XPlaneBuilder* stats_metadata_owner_; +}; + +class XEventBuilder : public XStatsBuilder { + public: + XEventBuilder(const XLine* line, XPlaneBuilder* plane, XEvent* event) + : XStatsBuilder(event, plane), line_(line), event_(event) {} + + int64 OffsetPs() const { return event_->offset_ps(); } + int64 MetadataId() const { return event_->metadata_id(); } + + void SetOffsetPs(int64 offset_ps) { event_->set_offset_ps(offset_ps); } + + void SetOffsetNs(int64 offset_ns) { SetOffsetPs(NanosToPicos(offset_ns)); } + + void SetTimestampNs(int64 timestamp_ns) { + SetOffsetPs(NanosToPicos(timestamp_ns - line_->timestamp_ns())); + } + + void SetNumOccurrences(int64 num_occurrences) { + event_->set_num_occurrences(num_occurrences); + } + + void SetDurationPs(int64 duration_ps) { + event_->set_duration_ps(duration_ps); + } + void SetDurationNs(int64 duration_ns) { + SetDurationPs(NanosToPicos(duration_ns)); + } + + void SetEndTimestampPs(int64 end_timestamp_ps) { + SetDurationPs(end_timestamp_ps - PicosToNanos(line_->timestamp_ns()) - + event_->offset_ps()); + } + void SetEndTimestampNs(int64 end_timestamp_ns) { + SetDurationPs(NanosToPicos(end_timestamp_ns - line_->timestamp_ns()) - + event_->offset_ps()); + } + + Timespan GetTimespan() const { + return Timespan(NanosToPicos(line_->timestamp_ns()) + event_->offset_ps(), + event_->duration_ps()); + } + + private: + const XLine* line_; + XEvent* event_; +}; + +class XLineBuilder { + public: + explicit XLineBuilder(XLine* line, XPlaneBuilder* plane) + : line_(line), plane_(plane) {} + + // Returns the owner plane. + XPlaneBuilder* Plane() const { return plane_; } + + int64 Id() const { return line_->id(); } + void SetId(int64 id) { line_->set_id(id); } + + int64 NumEvents() const { return line_->events_size(); } + + void SetName(absl::string_view name) { line_->set_name(std::string(name)); } + + void SetNameIfEmpty(absl::string_view name) { + if (line_->name().empty()) SetName(name); + } + + int64 TimestampNs() const { return line_->timestamp_ns(); } + // This will set the line start timestamp. + // WARNING: The offset_ps of existing events will not be altered. + void SetTimestampNs(int64 timestamp_ns) { + line_->set_timestamp_ns(timestamp_ns); + } + // This will set the line start timestamp to specific time, and adjust + // the offset_ps of all existing events. + void SetTimestampNsAndAdjustEventOffsets(int64 timestamp_ns); + + void SetDurationPs(int64 duration_ps) { line_->set_duration_ps(duration_ps); } + + void ReserveEvents(size_t num_events) { + line_->mutable_events()->Reserve(num_events); + } + + void SetDisplayNameIfEmpty(absl::string_view display_name) { + if (line_->display_name().empty()) { + line_->set_display_name(std::string(display_name)); + } + } + + XEventBuilder AddEvent(const XEventMetadata& metadata); + XEventBuilder AddEvent(const XEvent& event); + + private: + XLine* line_; + XPlaneBuilder* plane_; +}; + +// Provides methods to build an XPlane. +// NOTE: avoid to use two builders to wrap the same XPlane. +class XPlaneBuilder : public XStatsBuilder { + public: + explicit XPlaneBuilder(XPlane* plane); + + int64 Id() const { return plane_->id(); } + void SetId(int64 id) { plane_->set_id(id); } + + void SetName(absl::string_view name) { plane_->set_name(std::string(name)); } + + void ReserveLines(size_t num_lines) { + plane_->mutable_lines()->Reserve(num_lines); + } + + template + void ForEachLine(ForEachLineFunc&& for_each_line) { + for (XLine& line : *plane_->mutable_lines()) { + for_each_line(XLineBuilder(&line, this)); + } + } + + // Returns a builder for the line with the given id. Creates a new line if the + // id was unused, otherwise the builder will add events to an existing line. + XLineBuilder GetOrCreateLine(int64 line_id); + + // Returns a new event metadata with an automatically generated metadata_id. + // WARNING: If calling this function, don't call GetOrCreateEventMetadata. + XEventMetadata* CreateEventMetadata(); + + // Returns event metadata with the given id. Creates a new metadata if the id + // was unused. + // WARNING: If calling this function, don't call the string overloads below + // on the same instance. + XEventMetadata* GetOrCreateEventMetadata(int64 metadata_id); + + // Returns event metadata with the given name. The id is internally assigned. + // Creates a new metadata if the name was unused. + // Using these overloads guarantees names are unique. + // WARNING: If calling any of these overloads, do not call the integer one + // above on the same instance. + XEventMetadata* GetOrCreateEventMetadata(absl::string_view name); + XEventMetadata* GetOrCreateEventMetadata(std::string&& name); + XEventMetadata* GetOrCreateEventMetadata(const char* name) { + return GetOrCreateEventMetadata(absl::string_view(name)); + } + + // Returns a new stat metadata with an automatically generated metadata_id. + // WARNING: If calling this function, don't call GetOrCreateEventMetadata. + XStatMetadata* CreateStatMetadata(); + + // Returns stat metadata with the given id. Creates a new metadata if the id + // was unused. + // WARNING: If calling this function, don't call the string overloads below + // on the same instance. + XStatMetadata* GetOrCreateStatMetadata(int64 metadata_id); + + // Returns stat metadata with the given name. The id is internally assigned. + // Creates a new metadata if the name was unused. + // Using these overloads guarantees names are unique. + // WARNING: If calling any of these overloads, do not call the integer one + // above on the same instance. + XStatMetadata* GetOrCreateStatMetadata(absl::string_view name); + XStatMetadata* GetOrCreateStatMetadata(std::string&& name); + XStatMetadata* GetOrCreateStatMetadata(const char* name) { + return GetOrCreateStatMetadata(absl::string_view(name)); + } + + private: + XPlane* plane_; + + // Artifacts to accelerate the builders. + int64 last_event_metadata_id_ = 0LL; + int64 last_stat_metadata_id_ = 0LL; + absl::flat_hash_map event_metadata_by_name_; + absl::flat_hash_map stat_metadata_by_name_; + absl::flat_hash_map lines_by_id_; +}; + +template +const XStatMetadata& XStatsBuilder::GetOrCreateStatMetadata( + absl::string_view value) { + return *stats_metadata_owner_->GetOrCreateStatMetadata(value); +} + +} // namespace profiler +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_XPLANE_BUILDER_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_utils.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_utils.cc new file mode 100644 index 000000000..6e5eb7997 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_utils.cc @@ -0,0 +1,322 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_plugin/src/profiler/cpu/utils/xplane_utils.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "tensorflow_plugin/src/utils/logging.h" +#include "tensorflow_plugin/src/utils/protobuf.h" +#include "tensorflow_plugin/src/utils/types.h" +#include "tensorflow_plugin/src/utils/xplane.pb.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/time_utils.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/timespan.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/xplane_builder.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/xplane_visitor.h" + +namespace demo_plugin { +namespace profiler { +namespace { + +// Returns the index of the first element in array for which pred is true. +// Returns -1 if no such element is found. +template +int Find(const protobuf::RepeatedPtrField& array, const Pred& pred) { + for (int i = 0; i < array.size(); ++i) { + if (pred(&array.Get(i))) return i; + } + return -1; +} + +// Returns the indices of all elements in array for which pred is true. +template +std::vector FindAll(const protobuf::RepeatedPtrField& array, + const Pred& pred) { + std::vector indices; + for (int i = 0; i < array.size(); ++i) { + if (pred(&array.Get(i))) indices.push_back(i); + } + return indices; +} + +template +void RemoveAt(protobuf::RepeatedPtrField* array, + const std::vector& indices) { + if (indices.empty()) return; + if (array->size() == indices.size()) { + // Assumes that 'indices' consists of [0 ... N-1]. + array->Clear(); + return; + } + auto remove_iter = indices.begin(); + int i = *(remove_iter++); + for (int j = i + 1; j < array->size(); ++j) { + if (remove_iter != indices.end() && *remove_iter == j) { + ++remove_iter; + } else { + array->SwapElements(j, i++); + } + } + array->DeleteSubrange(i, array->size() - i); +} + +// Removes the given element from array. +template +void Remove(protobuf::RepeatedPtrField* array, const T* elem) { + int i = Find(*array, [elem](const T* e) { return elem == e; }); + RemoveAt(array, {i}); +} + +template +void RemoveIf(protobuf::RepeatedPtrField* array, Pred&& pred) { + std::vector indices = FindAll(*array, pred); + RemoveAt(array, indices); +} + +} // namespace + +const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name) { + int i = Find(space.planes(), + [name](const XPlane* plane) { return plane->name() == name; }); + return (i != -1) ? &space.planes(i) : nullptr; +} + +std::vector FindPlanesWithNames( + const XSpace& space, const std::vector& names) { + absl::flat_hash_set names_set(names.begin(), names.end()); + std::vector indices = + FindAll(space.planes(), [&names_set](const XPlane* plane) { + return names_set.contains(plane->name()); + }); + std::vector planes; + planes.reserve(indices.size()); + for (int i : indices) { + planes.push_back(&space.planes(i)); + } + return planes; +} + +XPlane* FindMutablePlaneWithName(XSpace* space, absl::string_view name) { + int i = Find(space->planes(), + [name](const XPlane* plane) { return plane->name() == name; }); + return (i != -1) ? space->mutable_planes(i) : nullptr; +} + +XPlane* FindOrAddMutablePlaneWithName(XSpace* space, absl::string_view name) { + XPlane* plane = FindMutablePlaneWithName(space, name); + if (plane == nullptr) { + plane = space->add_planes(); + plane->set_name(name.data(), name.size()); + } + return plane; +} + +std::vector FindPlanesWithPrefix(const XSpace& space, + absl::string_view prefix) { + std::vector result; + for (const XPlane& plane : space.planes()) { + if (absl::StartsWith(plane.name(), prefix)) result.push_back(&plane); + } + return result; +} + +std::vector FindMutablePlanesWithPrefix(XSpace* space, + absl::string_view prefix) { + std::vector result; + for (XPlane& plane : *space->mutable_planes()) { + if (absl::StartsWith(plane.name(), prefix)) result.push_back(&plane); + } + return result; +} + +const XLine* FindLineWithId(const XPlane& plane, int64_t id) { + int i = + Find(plane.lines(), [id](const XLine* line) { return line->id() == id; }); + return (i != -1) ? &plane.lines(i) : nullptr; +} + +XStat* FindOrAddMutableStat(const XStatMetadata& stat_metadata, XEvent* event) { + for (auto& stat : *event->mutable_stats()) { + if (stat.metadata_id() == stat_metadata.id()) { + return &stat; + } + } + XStat* stat = event->add_stats(); + stat->set_metadata_id(stat_metadata.id()); + return stat; +} + +void RemovePlane(XSpace* space, const XPlane* plane) { + DCHECK(plane != nullptr); + Remove(space->mutable_planes(), plane); +} + +void RemovePlanes(XSpace* space, const std::vector& planes) { + absl::flat_hash_set planes_set(planes.begin(), planes.end()); + RemoveIf(space->mutable_planes(), [&planes_set](const XPlane* plane) { + return planes_set.contains(plane); + }); +} + +void RemoveLine(XPlane* plane, const XLine* line) { + DCHECK(line != nullptr); + Remove(plane->mutable_lines(), line); +} + +void RemoveEvents(XLine* line, + const absl::flat_hash_set& events) { + RemoveIf(line->mutable_events(), + [&](const XEvent* event) { return events.contains(event); }); +} + +void RemoveEmptyPlanes(XSpace* space) { + RemoveIf(space->mutable_planes(), + [&](const XPlane* plane) { return plane->lines().empty(); }); +} + +void RemoveEmptyLines(XPlane* plane) { + RemoveIf(plane->mutable_lines(), + [&](const XLine* line) { return line->events().empty(); }); +} + +bool XEventsComparator::operator()(const XEvent* a, const XEvent* b) const { + return XEventTimespan(*a) < XEventTimespan(*b); +} + +void SortXPlane(XPlane* plane) { + for (XLine& line : *plane->mutable_lines()) { + auto& events = *line.mutable_events(); + std::sort(events.pointer_begin(), events.pointer_end(), + XEventsComparator()); + } +} + +void SortXSpace(XSpace* space) { + for (XPlane& plane : *space->mutable_planes()) SortXPlane(&plane); +} + +// Normalize the line's timestamp in this XPlane. +// NOTE: This can be called multiple times on the same plane. Only the first +// call will do the normalization, subsequent calls will do nothing. +// The assumption is that both line's timestamp_ns and start_time_ns are +// nano-seconds from epoch time, the different of these values is much +// smaller than these value. +void NormalizeTimestamps(XPlane* plane, uint64 start_time_ns) { + for (XLine& line : *plane->mutable_lines()) { + if (line.timestamp_ns() >= static_cast(start_time_ns)) { + line.set_timestamp_ns(line.timestamp_ns() - start_time_ns); + } + } +} + +void NormalizeTimestamps(XSpace* space, uint64 start_time_ns) { + for (XPlane& plane : *space->mutable_planes()) { + NormalizeTimestamps(&plane, start_time_ns); + } +} + +void MergePlanes(const XPlane& src_plane, XPlane* dst_plane) { + RemoveEmptyLines(dst_plane); + XPlaneVisitor src(&src_plane); + XPlaneBuilder dst(dst_plane); + src.ForEachStat([&](const demo_plugin::profiler::XStatVisitor& stat) { + XStatMetadata* stat_metadata = dst.GetOrCreateStatMetadata(stat.Name()); + // Use SetOrAddStat to avoid duplicating stats in dst_plane. + dst.SetOrAddStat(*stat_metadata, stat.RawStat(), src_plane); + }); + src.ForEachLine([&](const demo_plugin::profiler::XLineVisitor& line) { + XLineBuilder dst_line = dst.GetOrCreateLine(line.Id()); + int64_t time_offset_ps = 0LL; + if (dst_line.NumEvents() == 0) { + // Since we RemoveEmptyLines above, this could only mean that current + // line only exist in src plane. + dst_line.SetTimestampNs(line.TimestampNs()); + dst_line.SetName(line.Name()); + dst_line.SetDisplayNameIfEmpty(line.DisplayName()); + } else { + if (line.TimestampNs() <= dst_line.TimestampNs()) { + dst_line.SetTimestampNsAndAdjustEventOffsets(line.TimestampNs()); + } else { + time_offset_ps = + NanosToPicos(line.TimestampNs() - dst_line.TimestampNs()); + } + dst_line.SetNameIfEmpty(line.Name()); + // Don't override dst_line's display name because if both lines have name, + // but no display name, line's name will became display name of dst_line. + } + + line.ForEachEvent([&](const demo_plugin::profiler::XEventVisitor& event) { + const XEventMetadata* src_event_metadata = event.metadata(); + XEventMetadata* dst_event_metadata = + dst.GetOrCreateEventMetadata(event.Name()); + if (dst_event_metadata->display_name().empty() && + !src_event_metadata->display_name().empty()) { + dst_event_metadata->set_display_name( + src_event_metadata->display_name()); + } + if (dst_event_metadata->metadata().empty() && + !src_event_metadata->metadata().empty()) { + dst_event_metadata->set_metadata(src_event_metadata->metadata()); + } + XEventBuilder dst_event = dst_line.AddEvent(*dst_event_metadata); + dst_event.SetOffsetPs(event.OffsetPs() + time_offset_ps); + dst_event.SetDurationPs(event.DurationPs()); + if (event.NumOccurrences()) { + dst_event.SetNumOccurrences(event.NumOccurrences()); + } + event.ForEachStat([&](const demo_plugin::profiler::XStatVisitor& stat) { + // Here we can call AddStat instead of SetOrAddStat because dst_event + // was just added. + dst_event.AddStat(*dst.GetOrCreateStatMetadata(stat.Name()), + stat.RawStat(), src_plane); + }); + }); + }); +} + +void MergePlanes(const std::vector& src_planes, + XPlane* dst_plane) { + for (const XPlane* src_plane : src_planes) { + MergePlanes(*src_plane, dst_plane); + } +} + +uint64 GetStartTimestampNs(const XPlane& plane) { + int64_t plane_timestamp = 0; + for (const auto& line : plane.lines()) { + plane_timestamp = std::min(plane_timestamp, line.timestamp_ns()); + } + return plane_timestamp; +} + +bool IsEmpty(const XSpace& space) { + for (const auto& plane : space.planes()) { + for (const auto& line : plane.lines()) { + if (!line.events().empty()) { + return false; + } + } + } + return true; +} + +} // namespace profiler +} // namespace demo_plugin diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_utils.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_utils.h new file mode 100644 index 000000000..fdcae14f3 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_utils.h @@ -0,0 +1,120 @@ +#ifndef TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_XPLANE_UTILS_H_ +#define TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_XPLANE_UTILS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "tensorflow_plugin/src/utils/xplane.pb.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/timespan.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/trace_utils.h" + +namespace demo_plugin { +namespace profiler { + +// Returns a Timespan from an XEvent. +// WARNING: This should only be used when comparing events from the same XLine. +inline Timespan XEventTimespan(const XEvent& event) { + return Timespan(event.offset_ps(), event.duration_ps()); +} + +// Returns the plane with the given name or nullptr if not found. +const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name); +XPlane* FindMutablePlaneWithName(XSpace* space, absl::string_view name); + +// Returns the planes with the given names, if found. +std::vector FindPlanesWithNames( + const XSpace& space, const std::vector& names); + +// Returns the plane with the given name in the container. If necessary, adds a +// new plane to the container. +XPlane* FindOrAddMutablePlaneWithName(XSpace* space, absl::string_view name); + +// Returns all the planes with a given prefix. +std::vector FindPlanesWithPrefix(const XSpace& space, + absl::string_view prefix); +std::vector FindMutablePlanesWithPrefix(XSpace* space, + absl::string_view prefix); + +// Returns the plane with the given id or nullptr if not found. +const XLine* FindLineWithId(const XPlane& plane, int64_t id); + +XStat* FindOrAddMutableStat(const XStatMetadata& stat_metadata, XEvent* event); + +void RemovePlane(XSpace* space, const XPlane* plane); +void RemovePlanes(XSpace* space, const std::vector& planes); +void RemoveLine(XPlane* plane, const XLine* line); +void RemoveEvents(XLine* line, + const absl::flat_hash_set& events); + +void RemoveEmptyPlanes(XSpace* space); +void RemoveEmptyLines(XPlane* plane); + +// Sort lines in plane with a provided comparator. +template +void SortXLinesBy(XPlane* plane, Compare comp) { + std::sort(plane->mutable_lines()->pointer_begin(), + plane->mutable_lines()->pointer_end(), comp); +} + +class XLinesComparatorByName { + public: + bool operator()(const XLine* a, const XLine* b) const { + auto& line_a = a->display_name().empty() ? a->name() : a->display_name(); + auto& line_b = b->display_name().empty() ? b->name() : b->display_name(); + return line_a < line_b; + } +}; + +// Sorts each XLine's XEvents by offset_ps (ascending) and duration_ps +// (descending) so nested events are sorted from outer to innermost. +void SortXPlane(XPlane* plane); +// Sorts each plane of the XSpace. +void SortXSpace(XSpace* space); + +// Functor that compares XEvents for sorting by timespan. +struct XEventsComparator { + bool operator()(const XEvent* a, const XEvent* b) const; +}; + +// Returns a sorted vector of all XEvents in the given XPlane. +template +std::vector GetSortedEvents(XPlane* plane, Compare comp, + bool include_derived_events = false) { + std::vector events; + for (XLine& line : *plane->mutable_lines()) { + if (!include_derived_events && IsDerivedThreadId(line.id())) continue; + for (XEvent& event : *line.mutable_events()) { + events.push_back(&event); + } + } + absl::c_sort(events, XEventsComparator()); + return events; +} + +// Normalize timestamps by time-shifting to start_time_ns_ as origin. +void NormalizeTimestamps(XPlane* plane, uint64 start_time_ns); +void NormalizeTimestamps(XSpace* space, uint64 start_time_ns); + +// Merges src_plane into dst_plane. Both plane level stats, lines, events and +// event level stats are merged. If src_plane and dst_plane both have the same +// line, which have different start timestamps, we will normalize the events +// offset timestamp correspondingly. +void MergePlanes(const XPlane& src_plane, XPlane* dst_plane); + +// Merges each plane with a src_planes, into the dst_plane. +void MergePlanes(const std::vector& src_planes, + XPlane* dst_plane); + +// Plane's start timestamp is defined as the minimum of all lines' start +// timestamps. If zero line exists, return 0; +uint64 GetStartTimestampNs(const XPlane& plane); + +// Returns true if there are no XEvents. +bool IsEmpty(const XSpace& space); + +} // namespace profiler +} // namespace demo_profiler + +#endif // TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_XPLANE_UTILS_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_visitor.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_visitor.cc new file mode 100644 index 000000000..8a7b604f8 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_visitor.cc @@ -0,0 +1,146 @@ +#include "tensorflow_plugin/src/profiler/cpu/utils/xplane_visitor.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "tensorflow_plugin/src/utils/xplane.pb.h" + +namespace demo_plugin { +namespace profiler { + +XStatVisitor::XStatVisitor(const XPlaneVisitor* plane, const XStat* stat) + : XStatVisitor(plane, stat, plane->GetStatMetadata(stat->metadata_id()), + plane->GetStatType(stat->metadata_id())) {} + +XStatVisitor::XStatVisitor(const XPlaneVisitor* plane, const XStat* stat, + const XStatMetadata* metadata, + absl::optional type) + : stat_(stat), metadata_(metadata), plane_(plane), type_(type) {} + +std::string XStatVisitor::ToString() const { + switch (stat_->value_case()) { + case XStat::kInt64Value: + return absl::StrCat(stat_->int64_value()); + case XStat::kUint64Value: + return absl::StrCat(stat_->uint64_value()); + case XStat::kDoubleValue: + return absl::StrCat(stat_->double_value()); + case XStat::kStrValue: + return stat_->str_value(); + case XStat::kBytesValue: + return ""; + case XStat::kRefValue: + return plane_->GetStatMetadata(stat_->ref_value())->name(); + case XStat::VALUE_NOT_SET: + return ""; + } +} + +absl::string_view XStatVisitor::StrOrRefValue() const { + switch (stat_->value_case()) { + case XStat::kStrValue: + return stat_->str_value(); + case XStat::kRefValue: + return plane_->GetStatMetadata(stat_->ref_value())->name(); + case XStat::kInt64Value: + case XStat::kUint64Value: + case XStat::kDoubleValue: + case XStat::kBytesValue: + case XStat::VALUE_NOT_SET: + return absl::string_view(); + } +} + +XEventVisitor::XEventVisitor(const XPlaneVisitor* plane, const XLine* line, + const XEvent* event) + : XStatsOwner(plane, event), + plane_(plane), + line_(line), + event_(event), + metadata_(plane->GetEventMetadata(event_->metadata_id())), + type_(plane->GetEventType(event_->metadata_id())) {} + +XPlaneVisitor::XPlaneVisitor(const XPlane* plane, + const TypeGetterList& event_type_getter_list, + const TypeGetterList& stat_type_getter_list) + : XStatsOwner(this, plane), plane_(plane) { + BuildEventTypeMap(plane, event_type_getter_list); + BuildStatTypeMap(plane, stat_type_getter_list); +} + +void XPlaneVisitor::BuildEventTypeMap( + const XPlane* plane, const TypeGetterList& event_type_getter_list) { + for (const auto& event_metadata : plane->event_metadata()) { + uint64 metadata_id = event_metadata.first; + const auto& metadata = event_metadata.second; + for (const auto& event_type_getter : event_type_getter_list) { + absl::optional event_type = event_type_getter(metadata.name()); + if (event_type.has_value()) { + auto result = event_type_by_id_.emplace(metadata_id, *event_type); + DCHECK(result.second); // inserted + break; + } + } + } +} + +const XEventMetadata* XPlaneVisitor::GetEventMetadata( + int64_t event_metadata_id) const { + const auto& event_metadata_by_id = plane_->event_metadata(); + const auto it = event_metadata_by_id.find(event_metadata_id); + if (it != event_metadata_by_id.end()) return &it->second; + return &XEventMetadata::default_instance(); +} + +absl::optional XPlaneVisitor::GetEventType( + int64_t event_metadata_id) const { + const auto it = event_type_by_id_.find(event_metadata_id); + if (it != event_type_by_id_.end()) return it->second; + return absl::nullopt; +} + +void XPlaneVisitor::BuildStatTypeMap( + const XPlane* plane, const TypeGetterList& stat_type_getter_list) { + for (const auto& stat_metadata : plane->stat_metadata()) { + uint64 metadata_id = stat_metadata.first; + const auto& metadata = stat_metadata.second; + for (const auto& stat_type_getter : stat_type_getter_list) { + absl::optional stat_type = stat_type_getter(metadata.name()); + if (stat_type.has_value()) { + auto result = stat_type_by_id_.emplace(metadata_id, *stat_type); + DCHECK(result.second); // inserted + stat_metadata_by_type_.emplace(*stat_type, &metadata); + break; + } + } + } +} + +const XStatMetadata* XPlaneVisitor::GetStatMetadata( + int64_t stat_metadata_id) const { + const auto& stat_metadata_by_id = plane_->stat_metadata(); + const auto it = stat_metadata_by_id.find(stat_metadata_id); + if (it != stat_metadata_by_id.end()) return &it->second; + return &XStatMetadata::default_instance(); +} + +absl::optional XPlaneVisitor::GetStatType( + int64_t stat_metadata_id) const { + const auto it = stat_type_by_id_.find(stat_metadata_id); + if (it != stat_type_by_id_.end()) return it->second; + return absl::nullopt; +} + +const XStatMetadata* XPlaneVisitor::GetStatMetadataByType( + int64_t stat_type) const { + const auto it = stat_metadata_by_type_.find(stat_type); + if (it != stat_metadata_by_type_.end()) return it->second; + return nullptr; +} + +} // namespace profiler +} // namespace demo_plugin diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_visitor.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_visitor.h new file mode 100644 index 000000000..0efd2f99e --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/profiler/cpu/utils/xplane_visitor.h @@ -0,0 +1,305 @@ +#ifndef TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_XPLANE_VISITOR_H_ +#define TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_XPLANE_VISITOR_H_ + +#include + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "tensorflow_plugin/src/utils/types.h" +#include "tensorflow_plugin/src/utils/xplane.pb.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/time_utils.h" +#include "tensorflow_plugin/src/profiler/cpu/utils/timespan.h" + +namespace demo_plugin { +namespace profiler { + +class XPlaneVisitor; + +class XStatVisitor { + public: + // REQUIRED: plane and stat cannot be nullptr. + XStatVisitor(const XPlaneVisitor* plane, const XStat* stat); + + // REQUIRED: plane, stat and metadata cannot be nullptr. + XStatVisitor(const XPlaneVisitor* plane, const XStat* stat, + const XStatMetadata* metadata, absl::optional type); + + int64_t Id() const { return stat_->metadata_id(); } + + absl::string_view Name() const { return metadata_->name(); } + + absl::optional Type() const { return type_; } + + absl::string_view Description() const { return metadata_->description(); } + + XStat::ValueCase ValueCase() const { return stat_->value_case(); } + + int64_t IntValue() const { return stat_->int64_value(); } + + uint64 UintValue() const { return stat_->uint64_value(); } + + uint64 IntOrUintValue() const { + return ValueCase() == XStat::kUint64Value ? UintValue() + : static_cast(IntValue()); + } + + double DoubleValue() const { return stat_->double_value(); } + + // Returns a string view. + // REQUIRED: the value type should be string type or reference type. + absl::string_view StrOrRefValue() const; + + const XStat& RawStat() const { return *stat_; } + + // Return a string representation of all value type. + std::string ToString() const; + + private: + const XStat* stat_; + const XStatMetadata* metadata_; + const XPlaneVisitor* plane_; + absl::optional type_; +}; + +template +class XStatsOwner { + public: + // REQUIRED: plane and stats_owner cannot be nullptr. + XStatsOwner(const XPlaneVisitor* plane, const T* stats_owner) + : plane_(plane), stats_owner_(stats_owner) {} + + // For each stat, call the specified lambda. + template + void ForEachStat(ForEachStatFunc&& for_each_stat) const { + for (const XStat& stat : stats_owner_->stats()) { + for_each_stat(XStatVisitor(plane_, &stat)); + } + } + + // Shortcut to get a specific stat type, nullopt if absent. + // This function performs a linear search for the requested stat value. + // Prefer ForEachStat above when multiple stat values are necessary. + absl::optional GetStat(int64_t stat_type) const; + + // Same as above that skips searching for the stat. + absl::optional GetStat( + int64_t stat_type, const XStatMetadata& stat_metadata) const { + for (const XStat& stat : stats_owner_->stats()) { + if (stat.metadata_id() == stat_metadata.id()) { + return XStatVisitor(plane_, &stat, &stat_metadata, stat_type); + } + } + return absl::nullopt; // type does not exist in this owner. + } + + protected: + const XPlaneVisitor* plane() const { return plane_; } + const T* stats_owner() const { return stats_owner_; } + + private: + const XPlaneVisitor* plane_; + const T* stats_owner_; +}; + +class XEventMetadataVisitor : public XStatsOwner { + public: + // REQUIRED: plane and metadata cannot be nullptr. + XEventMetadataVisitor(const XPlaneVisitor* plane, + const XEventMetadata* metadata) + : XStatsOwner(plane, metadata) {} + + absl::string_view Name() const { return metadata()->name(); } + + bool HasDisplayName() const { return !metadata()->display_name().empty(); } + + absl::string_view DisplayName() const { return metadata()->display_name(); } + + // For each child event metadata, call the specified lambda. + template + void ForEachChild(ForEachChildFunc&& for_each_child) const; + + private: + const XEventMetadata* metadata() const { return stats_owner(); } +}; + +class XEventVisitor : public XStatsOwner { + public: + // REQUIRED: plane, line and event cannot be nullptr. + XEventVisitor(const XPlaneVisitor* plane, const XLine* line, + const XEvent* event); + + int64_t Id() const { return event_->metadata_id(); } + + absl::string_view Name() const { return metadata_->name(); } + + absl::optional Type() const { return type_; } + + bool HasDisplayName() const { return !metadata_->display_name().empty(); } + + absl::string_view DisplayName() const { return metadata_->display_name(); } + + double OffsetNs() const { return PicosToNanos(event_->offset_ps()); } + + int64_t OffsetPs() const { return event_->offset_ps(); } + + int64_t LineTimestampNs() const { return line_->timestamp_ns(); } + + double TimestampNs() const { return line_->timestamp_ns() + OffsetNs(); } + + int64_t TimestampPs() const { + return NanosToPicos(line_->timestamp_ns()) + event_->offset_ps(); + } + + double DurationNs() const { return PicosToNanos(event_->duration_ps()); } + + int64_t DurationPs() const { return event_->duration_ps(); } + + int64_t EndOffsetPs() const { + return event_->offset_ps() + event_->duration_ps(); + } + int64_t EndTimestampPs() const { return TimestampPs() + DurationPs(); } + + int64_t NumOccurrences() const { return event_->num_occurrences(); } + + bool operator<(const XEventVisitor& other) const { + return GetTimespan() < other.GetTimespan(); + } + + const XEventMetadata* metadata() const { return metadata_; } + + XEventMetadataVisitor Metadata() const { + return XEventMetadataVisitor(plane_, metadata_); + } + + Timespan GetTimespan() const { return Timespan(TimestampPs(), DurationPs()); } + + private: + const XPlaneVisitor* plane_; + const XLine* line_; + const XEvent* event_; + const XEventMetadata* metadata_; + absl::optional type_; +}; + +class XLineVisitor { + public: + // REQUIRED: plane and line cannot be nullptr. + XLineVisitor(const XPlaneVisitor* plane, const XLine* line) + : plane_(plane), line_(line) {} + + int64_t Id() const { return line_->id(); } + + int64_t DisplayId() const { + return line_->display_id() ? line_->display_id() : line_->id(); + } + + absl::string_view Name() const { return line_->name(); } + + absl::string_view DisplayName() const { + return !line_->display_name().empty() ? line_->display_name() + : line_->name(); + } + + double TimestampNs() const { return line_->timestamp_ns(); } + + int64_t DurationPs() const { return line_->duration_ps(); } + + size_t NumEvents() const { return line_->events_size(); } + + template + void ForEachEvent(ForEachEventFunc&& for_each_event) const { + for (const XEvent& event : line_->events()) { + for_each_event(XEventVisitor(plane_, line_, &event)); + } + } + + private: + const XPlaneVisitor* plane_; + const XLine* line_; +}; + +using TypeGetter = std::function(absl::string_view)>; +using TypeGetterList = std::vector; + +class XPlaneVisitor : public XStatsOwner { + public: + // REQUIRED: plane cannot be nullptr. + explicit XPlaneVisitor( + const XPlane* plane, + const TypeGetterList& event_type_getter_list = TypeGetterList(), + const TypeGetterList& stat_type_getter_list = TypeGetterList()); + + int64_t Id() const { return plane_->id(); } + + absl::string_view Name() const { return plane_->name(); } + + size_t NumLines() const { return plane_->lines_size(); } + + template + void ForEachLine(ForEachLineFunc&& for_each_line) const { + for (const XLine& line : plane_->lines()) { + for_each_line(XLineVisitor(this, &line)); + } + } + + // Returns event metadata given its id. Returns a default value if not found. + const XEventMetadata* GetEventMetadata(int64_t event_metadata_id) const; + + // Returns the type of an event given its id. + absl::optional GetEventType(int64_t event_metadata_id) const; + + // Returns stat metadata given its id. Returns a default value if not found. + const XStatMetadata* GetStatMetadata(int64_t stat_metadata_id) const; + + // Returns stat metadata given its type. Returns nullptr if not found. + // Use as an alternative to GetStatMetadata above. + const XStatMetadata* GetStatMetadataByType(int64_t stat_type) const; + + // Returns the type of an stat given its id. + absl::optional GetStatType(int64_t stat_metadata_id) const; + + private: + void BuildEventTypeMap(const XPlane* plane, + const TypeGetterList& event_type_getter_list); + void BuildStatTypeMap(const XPlane* plane, + const TypeGetterList& stat_type_getter_list); + + const XPlane* plane_; + + absl::flat_hash_map + event_type_by_id_; + absl::flat_hash_map + stat_type_by_id_; + absl::flat_hash_map + stat_metadata_by_type_; +}; + +template +absl::optional XStatsOwner::GetStat(int64_t stat_type) const { + const auto* stat_metadata = plane_->GetStatMetadataByType(stat_type); + if (stat_metadata != nullptr) { + return GetStat(stat_type, *stat_metadata); + } + return absl::nullopt; // type does not exist in this owner. +} + +template +void XEventMetadataVisitor::ForEachChild( + ForEachChildFunc&& for_each_child) const { + for (int64_t child_id : metadata()->child_id()) { + const auto* event_metadata = plane()->GetEventMetadata(child_id); + if (event_metadata != nullptr) { + for_each_child(XEventMetadataVisitor(plane(), event_metadata)); + } + } +} + +} // namespace profiler +} // namespace demo_profiler + +#endif // TENSORFLOW_PLUGIN_SRC_PROFILER_CPU_UTILS_XPLANE_VISITOR_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/BUILD new file mode 100644 index 000000000..091a7d79b --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/BUILD @@ -0,0 +1,191 @@ +load("@org_tensorflow_plugin//tensorflow_plugin:build_config.bzl", "cc_proto") + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +cc_proto( + name = "types", + src = "types.proto", +) + +cc_proto( + name = "tensor_shape", + src = "tensor_shape.proto", +) + +cc_proto( + name = "versions", + src = "versions.proto", +) + +cc_proto( + name = "cost_graph", + src = "cost_graph.proto", + deps = [ + ":tensor_shape_proto", + ":types_proto", + ], +) + +cc_proto( + name = "resource_handle", + src = "resource_handle.proto", + deps = [ + ":tensor_shape_proto", + ":types_proto", + ], +) + +cc_proto( + name = "tensor", + src = "tensor.proto", + deps = [ + ":resource_handle_proto", + ], +) + +cc_proto( + name = "attr_value", + src = "attr_value.proto", + deps = [ + ":tensor_proto", + ], +) + +cc_proto( + name = "node_def", + src = "node_def.proto", + deps = [ + ":attr_value_proto", + ], +) + +cc_proto( + name = "op_def", + src = "op_def.proto", + deps = [ + ":attr_value_proto", + ], +) + +cc_proto( + name = "function", + src = "function.proto", + deps = [ + ":node_def_proto", + ":op_def_proto", + ], +) + +cc_proto( + name = "graph", + src = "graph.proto", + deps = [ + ":function_proto", + ":node_def_proto", + ":versions_proto", + ], +) + +cc_proto( + name = "device_properties", + src = "device_properties.proto", +) + +cc_proto( + name = "op_performance_data", + src = "op_performance_data.proto", + deps = [ + ":attr_value_proto", + ":device_properties_proto", + ], +) + +cc_proto( + name = "api_def", + src = "api_def.proto", + deps = [ + ":attr_value_proto", + ], +) + +cc_proto( + name = "xplane", + src = "xplane.proto", +) + + + +cc_library( + name = "protos_all", + visibility = ["//visibility:public"], + deps = [ + ":api_def_proto", + ":graph_proto", + ":op_performance_data_proto", + ":xplane_proto", + ], +) + +cc_library( + name = "platform", + hdrs = ["platform.h"], +) + + +cc_library( + name = "prefetch", + hdrs = ["prefetch.h"], + visibility = ["//visibility:public"], + deps = [":platform"], +) + +cc_library( + name = "tstring", + hdrs = [ + "ctstring.h", + "tstring.h", + ], + deps = [ + "@com_google_absl//absl/strings", + "@local_config_tf//:tf_header_lib", + ], +) + +cc_library( + name = "types", + hdrs = ["types.h"], + visibility = ["//visibility:public"], + deps = [ + ":platform", + ":tstring" + ], +) + + +cc_library( + name = "logging", + srcs = ["logging.cc"], + visibility = ["//visibility:public"], + hdrs = [ + "env_time.h", + "integral_types.h", + "logging.h", + "macros.h", + #"numeric_types.h", + "stringpiece.h", + "types.h", + "protobuf.h" + ], + deps = [ + ":platform", + ":tstring", + "//tensorflow_plugin/src/utils/gtl:gtl_libs", + #"//third_party/eigen3", + ":protos_all", + ], +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/api_def.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/api_def.proto new file mode 100644 index 000000000..15c81803f --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/api_def.proto @@ -0,0 +1,136 @@ +// Defines the text format for including per-op API definition and +// overrides for client language op code generators. + +syntax = "proto3"; + +package demo_plugin; +option cc_enable_arenas = true; +option java_outer_classname = "ApiDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/api_def_go_proto"; +import "tensorflow_plugin/src/utils/attr_value.proto"; + +// Used to specify and override the default API & behavior in the +// generated code for client languages, from what you would get from +// the OpDef alone. There will be a set of ApiDefs that are common +// to all client languages, and another set per client language. +// The per-client-language ApiDefs will inherit values from the +// common ApiDefs which it can either replace or modify. +// +// We separate the API definition from the OpDef so we can evolve the +// API while remaining backwards compatible when interpretting old +// graphs. Overrides go in an "api_def.pbtxt" file with a text-format +// ApiDefs message. +// +// WARNING: Be *very* careful changing the API for any existing op -- +// you can change the semantics of existing code. These changes may +// need to wait until a major release of TensorFlow to avoid breaking +// our compatibility promises. +message ApiDef { + // Name of the op (in the OpDef) to specify the API for. + string graph_op_name = 1; + // If this op is deprecated, set deprecation message to the message + // that should be logged when this op is used. + // The message should indicate alternative op to use, if any. + string deprecation_message = 12; + // Major version when the op will be deleted. For e.g. set this + // value to 2 if op API should be removed in TensorFlow 2.0 and + // deprecated in versions before that. + int32 deprecation_version = 13; + + enum Visibility { + // Normally this is "VISIBLE" unless you are inheriting a + // different value from another ApiDef. + DEFAULT_VISIBILITY = 0; + // Publicly visible in the API. + VISIBLE = 1; + // Do not include this op in the generated API. If visibility is + // set to 'SKIP', other fields are ignored for this op. + SKIP = 2; + // Hide this op by putting it into an internal namespace (or whatever + // is appropriate in the target language). + HIDDEN = 3; + } + Visibility visibility = 2; + + // If you specify any endpoint, this will replace all of the + // inherited endpoints. The first endpoint should be the + // "canonical" endpoint, and should not be deprecated (unless all + // endpoints are deprecated). + message Endpoint { + // Name should be either like "CamelCaseName" or + // "Package.CamelCaseName". Client-language-specific ApiDefs may + // use a snake_case convention instead of CamelCase. + string name = 1; + + // Set if this endpoint is deprecated. If set to true, a message suggesting + // to use a non-deprecated endpoint instead will be printed. If all + // endpoints are deprecated, set deprecation_message in ApiDef instead. + bool deprecated = 3; + + // Major version when an endpoint will be deleted. For e.g. set this + // value to 2 if endpoint should be removed in TensorFlow 2.0 and + // deprecated in versions before that. + int32 deprecation_version = 4; + } + repeated Endpoint endpoint = 3; + + message Arg { + string name = 1; + + // Change the name used to access this arg in the API from what + // is used in the GraphDef. Note that these names in `backticks` + // will also be replaced in the summary & description fields. + string rename_to = 2; + + // Note: this will replace any inherited arg doc. There is no + // current way of modifying arg descriptions (other than replacing + // them entirely) as can be done with op descriptions. + string description = 3; + } + repeated Arg in_arg = 4; + repeated Arg out_arg = 5; + // List of original in_arg names to specify new argument order. + // Length of arg_order should be either empty to keep current order + // or match size of in_arg. + repeated string arg_order = 11; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message Attr { + string name = 1; + + // Change the name used to access this attr in the API from what + // is used in the GraphDef. Note that these names in `backticks` + // will also be replaced in the summary & description fields. + string rename_to = 2; + + // Specify a new default value to use for this attr. This default + // will be used when creating new graphs, as opposed to the + // default in the OpDef, which will be used when interpreting old + // GraphDefs. + AttrValue default_value = 3; + + // Note: this will replace any inherited attr doc, there is no current + // way of modifying attr descriptions as can be done with op descriptions. + string description = 4; + } + repeated Attr attr = 6; + + // One-line human-readable description of what the Op does. + string summary = 7; + + // Additional, longer human-readable description of what the Op does. + string description = 8; + + // Modify an existing/inherited description by adding text to the beginning + // or end. + string description_prefix = 9; + string description_suffix = 10; +} + +message ApiDefs { + repeated ApiDef op = 1; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/attr_value.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/attr_value.proto new file mode 100644 index 000000000..461f3fa5f --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/attr_value.proto @@ -0,0 +1,64 @@ +syntax = "proto3"; + +package demo_plugin; + +import "tensorflow_plugin/src/utils/tensor.proto"; +import "tensorflow_plugin/src/utils/tensor_shape.proto"; +import "tensorflow_plugin/src/utils/types.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/attr_value_go_proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/cost_graph.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/cost_graph.proto new file mode 100644 index 000000000..4dc8e0f01 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/cost_graph.proto @@ -0,0 +1,89 @@ +syntax = "proto3"; + +package demo_plugin; + +import "tensorflow_plugin/src/utils/tensor_shape.proto"; +import "tensorflow_plugin/src/utils/types.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "CostGraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/cost_graph_go_proto"; + +message CostGraphDef { + message Node { + // The name of the node. Names are globally unique. + string name = 1; + + // The device of the node. Can be empty if the node is mapped to the + // default partition or partitioning hasn't been run yet. + string device = 2; + + // The id of the node. Node ids are only unique inside a partition. + int32 id = 3; + + // Inputs of this node. They must be executed before this node can be + // executed. An input is a particular output of another node, specified + // by the node id and the output index. + message InputInfo { + int32 preceding_node = 1; + int32 preceding_port = 2; + } + repeated InputInfo input_info = 4; + + // Outputs of this node. + message OutputInfo { + int64 size = 1; + // If >= 0, the output is an alias of an input. Note that an alias input + // may itself be an alias. The algorithm will therefore need to follow + // those pointers. + int64 alias_input_port = 2; + TensorShapeProto shape = 3; + DataType dtype = 4; + } + repeated OutputInfo output_info = 5; + + // Temporary memory used by this node. + int64 temporary_memory_size = 6; + + // Persistent memory used by this node. + int64 persistent_memory_size = 12; + + int64 host_temp_memory_size = 10 [deprecated = true]; + int64 device_temp_memory_size = 11 [deprecated = true]; + int64 device_persistent_memory_size = 16 [deprecated = true]; + + // Estimate of the computational cost of this node, in microseconds. + int64 compute_cost = 9; + + // Analytical estimate of the computational cost of this node, in + // microseconds. + int64 compute_time = 14; + + // Analytical estimate of the memory access cost of this node, in + // microseconds. + int64 memory_time = 15; + + // If true, the output is permanent: it can't be discarded, because this + // node is part of the "final output". Nodes may depend on final nodes. + bool is_final = 7; + + // Ids of the control inputs for this node. + repeated int32 control_input = 8; + + // Are the costs inaccurate? + bool inaccurate = 17; + } + repeated Node node = 1; + + // Total cost of this graph, typically used for balancing decisions. + message AggregatedCost { + // Aggregated cost value. + float cost = 1; + + // Aggregated cost dimension (e.g. 'memory', 'compute', 'network'). + string dimension = 2; + } + repeated AggregatedCost cost = 2; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/ctstring.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/ctstring.h new file mode 100644 index 000000000..8c6b358b2 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/ctstring.h @@ -0,0 +1,121 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_CTSTRING_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_CTSTRING_H_ + +#include +#include + +#include "tensorflow/core/platform/ctstring_internal.h" + +// Initialize a new tstring. This must be called before using any function +// below. +inline void TF_TString_Init(TF_TString *str); +// Deallocate a tstring. +inline void TF_TString_Dealloc(TF_TString *str); + +// Resizes `str' to `new_size'. This function will appropriately grow or shrink +// the string buffer to fit a `new_size' string. Grown regions of the string +// will be initialized with `c'. +inline char *TF_TString_Resize(TF_TString *str, size_t new_size, char c); +// Similar to TF_TString_Resize, except the newly allocated regions will remain +// uninitialized. This is useful if you plan on overwriting the newly grown +// regions immediately after allocation; doing so will elide a superfluous +// initialization of the new buffer. +inline char *TF_TString_ResizeUninitialized(TF_TString *str, size_t new_size); +// Reserves a string buffer with a capacity of at least `new_cap'. +// ResizeUninitialized will not change the size, or the contents of the existing +// string. This is useful if you have a rough idea of `str's upperbound in +// size, and want to avoid allocations as you append to `str'. It should not be +// considered safe to write in the region between size and capacity; explicitly +// resize before doing so. +inline void TF_TString_Reserve(TF_TString *str, size_t new_cap); + +// Returns the size of the string. +inline size_t TF_TString_GetSize(const TF_TString *str); +// Returns the capacity of the string buffer. It should not be considered safe +// to write in the region between size and capacity---call Resize or +// ResizeUninitialized before doing so. +inline size_t TF_TString_GetCapacity(const TF_TString *str); +// Returns the underlying type of the tstring: +// TF_TSTR_SMALL: +// Small string optimization; the contents of strings +// less than 22-bytes are stored in the TF_TString struct. This avoids any +// heap allocations. +// TF_TSTR_LARGE: +// Heap allocated string. +// TF_TSTR_OFFSET: (currently unused) +// An offset defined string. The string buffer begins at an internally +// defined little-endian offset from `str'; i.e. GetDataPointer() = str + +// offset. This type is useful for memory mapping or reading string tensors +// directly from file, without the need to deserialize the data. For +// security reasons, it is imperative that OFFSET based string tensors are +// validated before use, or are from a trusted source. +// TF_TSTR_VIEW: +// A view into an unowned character string. +// +// NOTE: +// VIEW and OFFSET types are immutable, so any modifcation via Append, +// AppendN, or GetMutableDataPointer of a VIEW/OFFSET based tstring will +// result in a conversion to an owned type (SMALL/LARGE). +inline TF_TString_Type TF_TString_GetType(const TF_TString *str); + +// Returns a const char pointer to the start of the underlying string. The +// underlying character buffer may not be null-terminated. +inline const char *TF_TString_GetDataPointer(const TF_TString *str); +// Returns a char pointer to a mutable representation of the underlying string. +// In the case of VIEW and OFFSET types, `src' is converted to an owned type +// (SMALL/LARGE). The underlying character buffer may not be null-terminated. +inline char *TF_TString_GetMutableDataPointer(TF_TString *str); + +// Sets `dst' as a VIEW type to `src'. `dst' will not take ownership of `src'. +// It is the user's responsibility to ensure that the lifetime of `src' exceeds +// `dst'. Any mutations to `dst' via Append, AppendN, or GetMutableDataPointer, +// will result in a copy into an owned SMALL or LARGE type, and will not modify +// `src'. +inline void TF_TString_AssignView(TF_TString *dst, const char *src, + size_t size); + +// Appends `src' onto `dst'. If `dst' is a VIEW or OFFSET type, it will first +// be converted to an owned LARGE or SMALL type. `dst' should not point to +// memory owned by `src'. +inline void TF_TString_Append(TF_TString *dst, const TF_TString *src); +inline void TF_TString_AppendN(TF_TString *dst, const char *src, size_t size); + +// Copy/Move/Assign semantics +// +// | src | dst | complexity +// Copy | * | SMALL/LARGE | fixed/O(size) +// Assign | SMALL | SMALL | fixed +// Assign | OFFSET | VIEW | fixed +// Assign | VIEW | VIEW | fixed +// Assign | LARGE | LARGE | O(size) +// Move | * | same as src | fixed + +// Copies `src' to `dst'. `dst' will be an owned type (SMALL/LARGE). `src' +// should not point to memory owned by `dst'. +inline void TF_TString_Copy(TF_TString *dst, const char *src, size_t size); +// Assigns a `src' tstring to `dst'. An OFFSET `src' type will yield a `VIEW' +// `dst'. LARGE `src' types will be copied to a new buffer; all other `src' +// types will incur a fixed cost. +inline void TF_TString_Assign(TF_TString *dst, const TF_TString *src); +// Moves a `src' tstring to `dst'. Moving a LARGE `src' to `dst' will result in +// a valid but unspecified `src'. This function incurs a fixed cost for all +// inputs. +inline void TF_TString_Move(TF_TString *dst, TF_TString *src); + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_CTSTRING_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/ctstring_internal.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/ctstring_internal.h new file mode 100644 index 000000000..0e7f9829b --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/ctstring_internal.h @@ -0,0 +1,452 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_CTSTRING_INTERNAL_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_CTSTRING_INTERNAL_H_ + +#include +#include +#include +#include + +#if (defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) && \ + __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) || \ + defined(_WIN32) +#define TF_TSTRING_LITTLE_ENDIAN 1 +#elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \ + __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define TF_TSTRING_LITTLE_ENDIAN 0 +#else +#error "Unable to detect endianness." +#endif + +#if defined(__clang__) || \ + (defined(__GNUC__) && \ + ((__GNUC__ == 4 && __GNUC_MINOR__ >= 8) || __GNUC__ >= 5)) +static inline uint32_t TF_swap32(uint32_t host_int) { + return __builtin_bswap32(host_int); +} + +#elif defined(_MSC_VER) +static inline uint32_t TF_swap32(uint32_t host_int) { + return _byteswap_ulong(host_int); +} + +#elif defined(__APPLE__) +static inline uint32_t TF_swap32(uint32_t host_int) { + return OSSwapInt32(host_int); +} + +#else +static inline uint32_t TF_swap32(uint32_t host_int) { +#if defined(__GLIBC__) + return bswap_32(host_int); +#else // defined(__GLIBC__) + return (((host_int & uint32_t{0xFF}) << 24) | + ((host_int & uint32_t{0xFF00}) << 8) | + ((host_int & uint32_t{0xFF0000}) >> 8) | + ((host_int & uint32_t{0xFF000000}) >> 24)); +#endif // defined(__GLIBC__) +} +#endif + +#if TF_TSTRING_LITTLE_ENDIAN +#define TF_le32toh(x) TF_swap32(x) +#else // TF_TSTRING_LITTLE_ENDIAN +#define TF_le32toh(x) x +#endif // TF_TSTRING_LITTLE_ENDIAN + +static inline size_t TF_align16(size_t i) { return (i + 0xF) & ~0xF; } + +static inline size_t TF_max(size_t a, size_t b) { return a > b ? a : b; } +static inline size_t TF_min(size_t a, size_t b) { return a < b ? a : b; } + +typedef enum TF_TString_Type { // NOLINT + TF_TSTR_SMALL = 0x00, + TF_TSTR_LARGE = 0x01, + TF_TSTR_OFFSET = 0x02, + TF_TSTR_VIEW = 0x03, + TF_TSTR_TYPE_MASK = 0x03 +} TF_TString_Type; + +typedef struct TF_TString_Large { // NOLINT + size_t size; + size_t cap; + char *ptr; +} TF_TString_Large; + +typedef struct TF_TString_Offset { // NOLINT + uint32_t size; + uint32_t offset; + uint32_t count; +} TF_TString_Offset; + +typedef struct TF_TString_View { // NOLINT + size_t size; + const char *ptr; +} TF_TString_View; + +typedef struct TF_TString_Raw { // NOLINT + uint8_t raw[24]; +} TF_TString_Raw; + +typedef union TF_TString_Union { // NOLINT + TF_TString_Large large; + TF_TString_Offset offset; + TF_TString_View view; + TF_TString_Raw raw; +} TF_TString_Union; + +enum { + TF_TString_SmallCapacity = + (sizeof(TF_TString_Union) - sizeof(/* null delim */ char) - + sizeof(/* uint8_t size */ uint8_t)), +}; + +typedef struct TF_TString_Small { // NOLINT + uint8_t size; + char str[TF_TString_SmallCapacity + sizeof(/* null delim */ char)]; +} TF_TString_Small; + +typedef struct TF_TString { // NOLINT + union { + // small conflicts with '#define small char' in RpcNdr.h for MSVC, so we use + // smll instead. + TF_TString_Small smll; + TF_TString_Large large; + TF_TString_Offset offset; + TF_TString_View view; + TF_TString_Raw raw; + } u; +} TF_TString; + +// TODO(dero): Fix for OSS, and add C only build test. +// _Static_assert(CHAR_BIT == 8); +// _Static_assert(sizeof(TF_TString) == 24); + +static inline TF_TString_Type TF_TString_GetType(const TF_TString *str) { + return (TF_TString_Type)(str->u.raw.raw[0] & TF_TSTR_TYPE_MASK); // NOLINT +} + +// XXX(dero): For the big-endian case, this function could potentially be more +// performant and readable by always storing the string size as little-endian +// and always byte-swapping on big endian, resulting in a simple 'bswap'+'shr' +// (for architectures that have a bswap op). +static inline size_t TF_TString_ToActualSizeT(size_t size) { +#if TF_TSTRING_LITTLE_ENDIAN + return size >> 2; +#else // TF_TSTRING_LITTLE_ENDIAN + // 0xFF000000 or 0xFF00000000000000 depending on platform + static const size_t mask = ~((~(size_t)0) >> 8); + + return (((mask << 2) & size) >> 2) | (~mask & size); +#endif // TF_TSTRING_LITTLE_ENDIAN +} + +static inline size_t TF_TString_ToInternalSizeT(size_t size, + TF_TString_Type type) { +#if TF_TSTRING_LITTLE_ENDIAN + return (size << 2) | type; +#else // TF_TSTRING_LITTLE_ENDIAN + // 0xFF000000 or 0xFF00000000000000 depending on platform + static const size_t mask = ~((~(size_t)0) >> 8); + + return (mask & (size << 2)) | (~mask & size) | + ((size_t)type << ((sizeof(size_t) - 1) * 8)); // NOLINT +#endif // TF_TSTRING_LITTLE_ENDIAN +} + +static inline void TF_TString_Init(TF_TString *str) { + memset(str->u.raw.raw, 0, sizeof(TF_TString_Raw)); +} + +static inline void TF_TString_Dealloc(TF_TString *str) { + if (TF_TString_GetType(str) == TF_TSTR_LARGE && + str->u.large.ptr != NULL) { // NOLINT + free(str->u.large.ptr); + TF_TString_Init(str); + } +} + +static inline size_t TF_TString_GetSize(const TF_TString *str) { + switch (TF_TString_GetType(str)) { + case TF_TSTR_SMALL: + return str->u.smll.size >> 2; + case TF_TSTR_LARGE: + return TF_TString_ToActualSizeT(str->u.large.size); + case TF_TSTR_OFFSET: + return TF_le32toh(str->u.offset.size) >> 2; + case TF_TSTR_VIEW: + return TF_TString_ToActualSizeT(str->u.view.size); + default: + return 0; // Unreachable. + } +} + +static inline size_t TF_TString_GetCapacity(const TF_TString *str) { + switch (TF_TString_GetType(str)) { + case TF_TSTR_SMALL: + return TF_TString_SmallCapacity; + case TF_TSTR_LARGE: + return str->u.large.cap; + case TF_TSTR_OFFSET: + case TF_TSTR_VIEW: + default: + return 0; + } +} + +static inline const char *TF_TString_GetDataPointer(const TF_TString *str) { + switch (TF_TString_GetType(str)) { + case TF_TSTR_SMALL: + return str->u.smll.str; + case TF_TSTR_LARGE: + return str->u.large.ptr; + case TF_TSTR_OFFSET: + return (const char *)str + str->u.offset.offset; // NOLINT + case TF_TSTR_VIEW: + return str->u.view.ptr; + default: + // Unreachable. + return NULL; // NOLINT + } +} + +static inline char *TF_TString_ResizeUninitialized(TF_TString *str, + size_t new_size) { + size_t curr_size = TF_TString_GetSize(str); + size_t copy_size = TF_min(new_size, curr_size); + + TF_TString_Type curr_type = TF_TString_GetType(str); + const char *curr_ptr = TF_TString_GetDataPointer(str); + + // Case: SMALL/LARGE/VIEW/OFFSET -> SMALL + if (new_size <= TF_TString_SmallCapacity) { + str->u.smll.size = (uint8_t)((new_size << 2) | TF_TSTR_SMALL); // NOLINT + str->u.smll.str[new_size] = '\0'; + + if (curr_type != TF_TSTR_SMALL && copy_size) { + memcpy(str->u.smll.str, curr_ptr, copy_size); + } + + if (curr_type == TF_TSTR_LARGE) { + free((void *)curr_ptr); // NOLINT + } + + // We do not clear out the newly excluded region. + + return str->u.smll.str; + } + + // Case: SMALL/LARGE/VIEW/OFFSET -> LARGE + size_t new_cap; + size_t curr_cap = TF_TString_GetCapacity(str); + // We assume SIZE_MAX % 16 == 0. + size_t curr_cap_x2 = curr_cap >= SIZE_MAX / 2 ? SIZE_MAX - 1 : curr_cap * 2; + + if (new_size < curr_size && new_size < curr_cap / 2) { + // TODO(dero): Replace with shrink_to_fit flag. + new_cap = TF_align16(curr_cap / 2 + 1) - 1; + } else if (new_size > curr_cap_x2) { + new_cap = TF_align16(new_size + 1) - 1; + } else if (new_size > curr_cap) { + new_cap = TF_align16(curr_cap_x2 + 1) - 1; + } else { + new_cap = curr_cap; + } + + char *new_ptr; + if (new_cap == curr_cap) { + new_ptr = str->u.large.ptr; + } else if (curr_type == TF_TSTR_LARGE) { + new_ptr = (char *)realloc(str->u.large.ptr, new_cap + 1); // NOLINT + } else { + new_ptr = (char *)malloc(new_cap + 1); // NOLINT + if (copy_size) { + memcpy(new_ptr, curr_ptr, copy_size); + } + } + + str->u.large.size = TF_TString_ToInternalSizeT(new_size, TF_TSTR_LARGE); + str->u.large.ptr = new_ptr; + str->u.large.ptr[new_size] = '\0'; + str->u.large.cap = new_cap; + + return str->u.large.ptr; +} + +static inline char *TF_TString_GetMutableDataPointer(TF_TString *str) { + switch (TF_TString_GetType(str)) { + case TF_TSTR_SMALL: + return str->u.smll.str; + case TF_TSTR_OFFSET: + case TF_TSTR_VIEW: + // Convert OFFSET/VIEW to SMALL/LARGE + TF_TString_ResizeUninitialized(str, TF_TString_GetSize(str)); + return (TF_TString_GetType(str) == TF_TSTR_SMALL) ? str->u.smll.str + : str->u.large.ptr; + case TF_TSTR_LARGE: + return str->u.large.ptr; + default: + // Unreachable. + return NULL; // NOLINT + } +} + +static inline void TF_TString_Reserve(TF_TString *str, size_t new_cap) { + TF_TString_Type curr_type = TF_TString_GetType(str); + + if (new_cap <= TF_TString_SmallCapacity) { + // We do nothing, we let Resize/GetMutableDataPointer handle the + // conversion to SMALL from VIEW/OFFSET when the need arises. + // In the degenerate case, where new_cap <= TF_TString_SmallCapacity, + // curr_size > TF_TString_SmallCapacity, and the type is VIEW/OFFSET, we + // defer the malloc to Resize/GetMutableDataPointer. + return; + } + + if (curr_type == TF_TSTR_LARGE && new_cap <= str->u.large.cap) { + // We handle reduced cap in resize. + return; + } + + // Case: VIEW/OFFSET -> LARGE or grow an existing LARGE type + size_t curr_size = TF_TString_GetSize(str); + const char *curr_ptr = TF_TString_GetDataPointer(str); + + // Since VIEW and OFFSET types are read-only, their capacity is effectively 0. + // So we make sure we have enough room in the VIEW and OFFSET cases. + new_cap = TF_align16(TF_max(new_cap, curr_size) + 1) - 1; + + if (curr_type == TF_TSTR_LARGE) { + str->u.large.ptr = (char *)realloc(str->u.large.ptr, new_cap + 1); // NOLINT + } else { + // Convert to Large + char *new_ptr = (char *)malloc(new_cap + 1); // NOLINT + memcpy(new_ptr, curr_ptr, curr_size); + + str->u.large.size = TF_TString_ToInternalSizeT(curr_size, TF_TSTR_LARGE); + str->u.large.ptr = new_ptr; + str->u.large.ptr[curr_size] = '\0'; + } + + str->u.large.cap = new_cap; +} + +static inline char *TF_TString_Resize(TF_TString *str, size_t new_size, + char c) { + size_t curr_size = TF_TString_GetSize(str); + char *cstr = TF_TString_ResizeUninitialized(str, new_size); + + if (new_size > curr_size) { + memset(cstr + curr_size, c, new_size - curr_size); + } + + return cstr; +} + +static inline void TF_TString_AssignView(TF_TString *dst, const char *src, + size_t size) { + TF_TString_Dealloc(dst); + + dst->u.view.size = TF_TString_ToInternalSizeT(size, TF_TSTR_VIEW); + dst->u.view.ptr = src; +} + +static inline void TF_TString_AppendN(TF_TString *dst, const char *src, + size_t src_size) { + if (!src_size) + return; + + size_t dst_size = TF_TString_GetSize(dst); + + char *dst_c = TF_TString_ResizeUninitialized(dst, dst_size + src_size); + + memcpy(dst_c + dst_size, src, src_size); +} + +static inline void TF_TString_Append(TF_TString *dst, const TF_TString *src) { + const char *src_c = TF_TString_GetDataPointer(src); + size_t size = TF_TString_GetSize(src); + + TF_TString_AppendN(dst, src_c, size); +} + +static inline void TF_TString_Copy(TF_TString *dst, const char *src, + size_t size) { + char *dst_c = TF_TString_ResizeUninitialized(dst, size); + + if (size) + memcpy(dst_c, src, size); +} + +static inline void TF_TString_Assign(TF_TString *dst, const TF_TString *src) { + if (dst == src) + return; + + TF_TString_Dealloc(dst); + + switch (TF_TString_GetType(src)) { + case TF_TSTR_SMALL: + case TF_TSTR_VIEW: + *dst = *src; + return; + case TF_TSTR_LARGE: { + const char *src_c = TF_TString_GetDataPointer(src); + size_t size = TF_TString_GetSize(src); + + TF_TString_Copy(dst, src_c, size); + } + return; + case TF_TSTR_OFFSET: { + const char *src_c = TF_TString_GetDataPointer(src); + size_t size = TF_TString_GetSize(src); + + TF_TString_AssignView(dst, src_c, size); + } + return; + default: + return; // Unreachable. + } +} + +static inline void TF_TString_Move(TF_TString *dst, TF_TString *src) { + if (dst == src) + return; + + TF_TString_Dealloc(dst); + + switch (TF_TString_GetType(src)) { + case TF_TSTR_SMALL: + case TF_TSTR_VIEW: + *dst = *src; + return; + case TF_TSTR_LARGE: + *dst = *src; + TF_TString_Init(src); + return; + case TF_TSTR_OFFSET: { + const char *src_c = TF_TString_GetDataPointer(src); + size_t size = TF_TString_GetSize(src); + + TF_TString_AssignView(dst, src_c, size); + } + return; + default: + return; // Unreachable. + } +} + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_CTSTRING_INTERNAL_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/device_properties.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/device_properties.proto new file mode 100644 index 000000000..25f31252b --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/device_properties.proto @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package demo_plugin; + +option cc_enable_arenas = true; +option java_outer_classname = "DevicePropertiesProtos"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; + +message DeviceProperties { + // Device type (CPU, GPU, ...) + string type = 1; + // Vendor (Intel, nvidia, ...) + string vendor = 2; + // Model (Haswell, K40, ...) + string model = 3; + // Core Frequency in Mhz + int64 frequency = 4; + // Number of cores + int64 num_cores = 5; + // Version of the tools and libraries used with this device (e.g. gcc 4.9, + // cudnn 5.1) + map environment = 6; + // Number of registers per core. + int64 num_registers = 7; + // L1 cache size in bytes + int64 l1_cache_size = 8; + // L2 cache size in bytes + int64 l2_cache_size = 9; + // L3 cache size in bytes + int64 l3_cache_size = 10; + // Shared memory size per multiprocessor in bytes. This field is + // applicable to GPUs only. + int64 shared_memory_size_per_multiprocessor = 11; + // Memory size in bytes + int64 memory_size = 12; + // Memory bandwidth in KB/s + int64 bandwidth = 13; +} + +message NamedDevice { + string name = 1; + DeviceProperties properties = 2; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/env_time.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/env_time.cc new file mode 100644 index 000000000..8457f3c6a --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/env_time.cc @@ -0,0 +1,32 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#include +#include + +#include "tensorflow_plugin/src/utils/env_time.h" + +namespace demo_plugin { + +/* static */ +uint64 EnvTime::NowNanos() { + struct timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + return (static_cast(ts.tv_sec) * kSecondsToNanos + + static_cast(ts.tv_nsec)); +} + +} // namespace demo_plugin diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/env_time.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/env_time.h new file mode 100644 index 000000000..9274fe72d --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/env_time.h @@ -0,0 +1,66 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_ENV_TIME_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_ENV_TIME_H_ + +#include + +#include "tensorflow_plugin/src/utils/types.h" + +namespace demo_plugin { + +/// \brief An interface used by the tensorflow implementation to +/// access timer related operations. +class EnvTime { +public: + static constexpr uint64 kMicrosToPicos = 1000ULL * 1000ULL; + static constexpr uint64 kMicrosToNanos = 1000ULL; + static constexpr uint64 kMillisToMicros = 1000ULL; + static constexpr uint64 kMillisToNanos = 1000ULL * 1000ULL; + static constexpr uint64 kNanosToPicos = 1000ULL; + static constexpr uint64 kSecondsToMillis = 1000ULL; + static constexpr uint64 kSecondsToMicros = 1000ULL * 1000ULL; + static constexpr uint64 kSecondsToNanos = 1000ULL * 1000ULL * 1000ULL; + + EnvTime() = default; + virtual ~EnvTime() = default; + + /// \brief Returns the number of nano-seconds since the Unix epoch. + static uint64 NowNanos(); + + /// \brief Returns the number of micro-seconds since the Unix epoch. + static uint64 NowMicros() { return NowNanos() / kMicrosToNanos; } + + /// \brief Returns the number of seconds since the Unix epoch. + static uint64 NowSeconds() { return NowNanos() / kSecondsToNanos; } + + /// \brief A version of NowNanos() that may be overridden by a subclass. + virtual uint64 GetOverridableNowNanos() const { return NowNanos(); } + + /// \brief A version of NowMicros() that may be overridden by a subclass. + virtual uint64 GetOverridableNowMicros() const { + return GetOverridableNowNanos() / kMicrosToNanos; + } + + /// \brief A version of NowSeconds() that may be overridden by a subclass. + virtual uint64 GetOverridableNowSeconds() const { + return GetOverridableNowNanos() / kSecondsToNanos; + } +}; + +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_ENV_TIME_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/function.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/function.proto new file mode 100644 index 000000000..a3c95eb39 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/function.proto @@ -0,0 +1,126 @@ +syntax = "proto3"; + +package demo_plugin; + +import "tensorflow_plugin/src/utils/attr_value.proto"; +import "tensorflow_plugin/src/utils/node_def.proto"; +import "tensorflow_plugin/src/utils/op_def.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/function_go_proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// +// TODO(zhifengc): +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // Attributes for function arguments. These attributes are the same set of + // valid attributes as to _Arg nodes. + message ArgAttrs { + map attr = 1; + } + map arg_attr = 7; + + // Unique IDs for each resource argument, used to track aliasing resources. If + // Argument A and Argument B alias each other, then + // resource_arg_unique_ids[A.index] == resource_arg_unique_ids[B.index]. + // + // If this field is empty, none of the arguments could alias; otherwise, every + // resource argument should have an entry in this field. + // + // When instantiated, the unique IDs will be attached to the _Arg nodes' + // "_resource_arg_unique_id" attribute. + map resource_arg_unique_id = 8; + + // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. + reserved 2; + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). + // Unlike the NodeDefs in GraphDef, we need to be able to specify a + // list in some cases (instead of just single outputs). Also, we + // need to be able to deal with lists of unknown length (so the + // output index may not be known at function definition time). So + // we use the following format instead: + // * "fun_in" where "fun_in" is the name of a function input arg in + // the `signature` field above. This represents that input, whether + // it is a single tensor or a list. + // * "fun_in:0" gives the first element of a function input arg (a + // non-list input is considered a list of length 1 for these + // purposes). + // * "node:out" where "node" is the name of a node in `node_def` and + // "out" is the name one of its op's output arguments (the name + // comes from the OpDef of the node's op). This represents that + // node's output, whether it is a single tensor or a list. + // Note: We enforce that an op's output arguments are never + // renamed in the backwards-compatibility test. + // * "node:out:0" gives the first element of a node output arg (a + // non-list output is considered a list of length 1 for these + // purposes). + // + // NOT CURRENTLY SUPPORTED (but may be in the future): + // * "node:out:-1" gives last element in a node output list + // * "node:out:1:" gives a list with all but the first element in a + // node output list + // * "node:out::-1" gives a list with all but the last element in a + // node output list + + // The body of the function. Unlike the NodeDefs in a GraphDef, attrs + // may have values of type `placeholder` and the `input` field uses + // the "output" format above. + + // By convention, "op" in node_def is resolved by consulting with a + // user-defined library first. If not resolved, "func" is assumed to + // be a builtin op. + repeated NodeDef node_def = 3; + + // A mapping from the output arg names from `signature` to the + // outputs from `node_def` that should be returned by the function. + map ret = 4; + + // A mapping from control output names from `signature` to node names in + // `node_def` which should be control outputs of this function. + map control_ret = 6; +} + +// GradientDef defines the gradient function of a function defined in +// a function library. +// +// A gradient function g (specified by gradient_func) for a function f +// (specified by function_name) must follow the following: +// +// The function 'f' must be a numerical function which takes N inputs +// and produces M outputs. Its gradient function 'g', which is a +// function taking N + M inputs and produces N outputs. +// +// I.e. if we have +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// then, g is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the +// loss function). dL/dx_i is the partial derivative of L with respect +// to x_i. +message GradientDef { + string function_name = 1; // The function name. + string gradient_func = 2; // The gradient function's name. +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/graph.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/graph.proto new file mode 100644 index 000000000..81474cf3e --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package demo_plugin; + +import "tensorflow_plugin/src/utils/function.proto"; +import "tensorflow_plugin/src/utils/node_def.proto"; +import "tensorflow_plugin/src/utils/versions.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "GraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/graph_go_proto"; + +// Represents the graph of operations +message GraphDef { + repeated NodeDef node = 1; + + // Compatibility versions of the graph. See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/BUILD new file mode 100644 index 000000000..5f597fc28 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/BUILD @@ -0,0 +1,18 @@ +cc_library( + name = "gtl_libs", + srcs = glob([ + "*.cc", + ]), + hdrs = glob([ + "*.h", + ]), + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_plugin/src/utils:prefetch", + "//tensorflow_plugin/src/utils:types", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/array_slice.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/array_slice.h new file mode 100644 index 000000000..4c0417c92 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/array_slice.h @@ -0,0 +1,36 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_GTL_ARRAY_SLICE_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_GTL_ARRAY_SLICE_H_ + +#include "absl/types/span.h" +// TODO(Intel-tf): This is kept only because lots of targets transitively depend +// on it. Remove all targets' dependencies. +#include "tensorflow_plugin/src/utils/gtl/inlined_vector.h" + +namespace demo_plugin { +namespace gtl { + +template +using ArraySlice = absl::Span; + +template +using MutableArraySlice = absl::Span; + +} // namespace gtl +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_GTL_ARRAY_SLICE_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/flatmap.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/flatmap.h new file mode 100644 index 000000000..79af8676a --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/flatmap.h @@ -0,0 +1,393 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_GTL_FLATMAP_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_GTL_FLATMAP_H_ + +#include +#include +#include +#include +#include +#include "tensorflow_plugin/src/utils/gtl/flatrep.h" +#include "tensorflow_plugin/src/utils/hash.h" +#include "tensorflow_plugin/src/utils/logging.h" +#include "tensorflow_plugin/src/utils/types.h" + +namespace demo_plugin { +namespace gtl { + +// FlatMap provides a map from K to V. +// +// The map is implemented using an open-addressed hash table. A +// single array holds entire map contents and collisions are resolved +// by probing at a sequence of locations in the array. +template , + class Eq = std::equal_to> +class FlatMap { + private: + // Forward declare some internal types needed in public section. + struct Bucket; + + // We cannot use std::pair<> since internal representation stores + // keys and values in separate arrays, so we make a custom struct + // that holds references to the internal key, value elements. + // + // We define the struct as private ValueType, and typedef it as public + // value_type, to work around a gcc bug when compiling the iterators. + struct ValueType { + typedef Key first_type; + typedef Val second_type; + + const Key& first; + Val& second; + ValueType(const Key& k, Val& v) : first(k), second(v) {} + }; + + public: + typedef Key key_type; + typedef Val mapped_type; + typedef Hash hasher; + typedef Eq key_equal; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + typedef ValueType value_type; + typedef value_type* pointer; + typedef const value_type* const_pointer; + typedef value_type& reference; + typedef const value_type& const_reference; + + FlatMap() : FlatMap(1) {} + + explicit FlatMap(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq()) + : rep_(N, hf, eq) {} + + FlatMap(const FlatMap& src) : rep_(src.rep_) {} + + // Move constructor leaves src in a valid but unspecified state (same as + // std::unordered_map). + FlatMap(FlatMap&& src) : rep_(std::move(src.rep_)) {} + + template + FlatMap(InputIter first, InputIter last, size_t N = 1, + const Hash& hf = Hash(), const Eq& eq = Eq()) + : FlatMap(N, hf, eq) { + insert(first, last); + } + + FlatMap(std::initializer_list> init, size_t N = 1, + const Hash& hf = Hash(), const Eq& eq = Eq()) + : FlatMap(init.begin(), init.end(), N, hf, eq) {} + + FlatMap& operator=(const FlatMap& src) { + rep_.CopyFrom(src.rep_); + return *this; + } + + // Move-assignment operator leaves src in a valid but unspecified state (same + // as std::unordered_map). + FlatMap& operator=(FlatMap&& src) { + rep_.MoveFrom(std::move(src.rep_)); + return *this; + } + + ~FlatMap() {} + + void swap(FlatMap& x) { rep_.swap(x.rep_); } + void clear_no_resize() { rep_.clear_no_resize(); } + void clear() { rep_.clear(); } + void reserve(size_t N) { rep_.Resize(std::max(N, size())); } + void rehash(size_t N) { rep_.Resize(std::max(N, size())); } + void resize(size_t N) { rep_.Resize(std::max(N, size())); } + size_t size() const { return rep_.size(); } + bool empty() const { return size() == 0; } + size_t bucket_count() const { return rep_.bucket_count(); } + hasher hash_function() const { return rep_.hash_function(); } + key_equal key_eq() const { return rep_.key_eq(); } + + class iterator { + public: + typedef typename FlatMap::difference_type difference_type; + typedef typename FlatMap::value_type value_type; + typedef typename FlatMap::pointer pointer; + typedef typename FlatMap::reference reference; + typedef ::std::forward_iterator_tag iterator_category; + + iterator() : b_(nullptr), end_(nullptr), i_(0) {} + + // Make iterator pointing at first element at or after b. + iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { SkipUnused(); } + + // Make iterator pointing exactly at ith element in b, which must exist. + iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) { + FillValue(); + } + + reference operator*() { return *val(); } + pointer operator->() { return val(); } + bool operator==(const iterator& x) const { + return b_ == x.b_ && i_ == x.i_; + } + bool operator!=(const iterator& x) const { return !(*this == x); } + iterator& operator++() { + DCHECK(b_ != end_); + i_++; + SkipUnused(); + return *this; + } + iterator operator++(int /*indicates postfix*/) { + iterator tmp(*this); + ++*this; + return tmp; + } + + private: + friend class FlatMap; + Bucket* b_; + Bucket* end_; + char space_ alignas(value_type)[sizeof(value_type)]; + uint32 i_; + + pointer val() { return reinterpret_cast(space_); } + void FillValue() { new (space_) value_type(b_->key(i_), b_->val(i_)); } + void SkipUnused() { + while (b_ < end_) { + if (i_ >= Rep::kWidth) { + i_ = 0; + b_++; + } else if (b_->marker[i_] < 2) { + i_++; + } else { + FillValue(); + break; + } + } + } + }; + + class const_iterator { + private: + mutable iterator rep_; // Share state and logic with non-const iterator. + public: + typedef typename FlatMap::difference_type difference_type; + typedef typename FlatMap::value_type value_type; + typedef typename FlatMap::const_pointer pointer; + typedef typename FlatMap::const_reference reference; + typedef ::std::forward_iterator_tag iterator_category; + + const_iterator() : rep_() {} + const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {} + const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {} + + reference operator*() const { return *rep_.val(); } + pointer operator->() const { return rep_.val(); } + bool operator==(const const_iterator& x) const { return rep_ == x.rep_; } + bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; } + const_iterator& operator++() { + ++rep_; + return *this; + } + const_iterator operator++(int /*indicates postfix*/) { + const_iterator tmp(*this); + ++*this; + return tmp; + } + }; + + iterator begin() { return iterator(rep_.start(), rep_.limit()); } + iterator end() { return iterator(rep_.limit(), rep_.limit()); } + const_iterator begin() const { + return const_iterator(rep_.start(), rep_.limit()); + } + const_iterator end() const { + return const_iterator(rep_.limit(), rep_.limit()); + } + + size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; } + iterator find(const Key& k) { + auto r = rep_.Find(k); + return r.found ? iterator(r.b, rep_.limit(), r.index) : end(); + } + const_iterator find(const Key& k) const { + auto r = rep_.Find(k); + return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end(); + } + + Val& at(const Key& k) { + auto r = rep_.Find(k); + DCHECK(r.found); + return r.b->val(r.index); + } + const Val& at(const Key& k) const { + auto r = rep_.Find(k); + DCHECK(r.found); + return r.b->val(r.index); + } + + template + std::pair insert(const P& p) { + return Insert(p.first, p.second); + } + std::pair insert(const std::pair& p) { + return Insert(p.first, p.second); + } + template + void insert(InputIter first, InputIter last) { + for (; first != last; ++first) { + insert(*first); + } + } + + Val& operator[](const Key& k) { return IndexOp(k); } + Val& operator[](Key&& k) { return IndexOp(std::forward(k)); } + + template + std::pair emplace(Args&&... args) { + return InsertPair(std::make_pair(std::forward(args)...)); + } + + size_t erase(const Key& k) { + auto r = rep_.Find(k); + if (!r.found) return 0; + rep_.Erase(r.b, r.index); + return 1; + } + iterator erase(iterator pos) { + rep_.Erase(pos.b_, pos.i_); + ++pos; + return pos; + } + iterator erase(iterator pos, iterator last) { + for (; pos != last; ++pos) { + rep_.Erase(pos.b_, pos.i_); + } + return pos; + } + + std::pair equal_range(const Key& k) { + auto pos = find(k); + if (pos == end()) { + return std::make_pair(pos, pos); + } else { + auto next = pos; + ++next; + return std::make_pair(pos, next); + } + } + std::pair equal_range(const Key& k) const { + auto pos = find(k); + if (pos == end()) { + return std::make_pair(pos, pos); + } else { + auto next = pos; + ++next; + return std::make_pair(pos, next); + } + } + + bool operator==(const FlatMap& x) const { + if (size() != x.size()) return false; + for (auto& p : x) { + auto i = find(p.first); + if (i == end()) return false; + if (i->second != p.second) return false; + } + return true; + } + bool operator!=(const FlatMap& x) const { return !(*this == x); } + + // If key exists in the table, prefetch the associated value. This + // is a hint, and may have no effect. + void prefetch_value(const Key& key) const { rep_.Prefetch(key); } + + private: + using Rep = internal::FlatRep; + + // Bucket stores kWidth triples. + // The data is organized as three parallel arrays to reduce padding. + struct Bucket { + uint8 marker[Rep::kWidth]; + + // Wrap keys and values in union to control construction and destruction. + union Storage { + struct { + Key key[Rep::kWidth]; + Val val[Rep::kWidth]; + }; + Storage() {} + ~Storage() {} + } storage; + + Key& key(uint32 i) { + DCHECK_GE(marker[i], 2); + return storage.key[i]; + } + Val& val(uint32 i) { + DCHECK_GE(marker[i], 2); + return storage.val[i]; + } + template + void InitVal(uint32 i, V&& v) { + new (&storage.val[i]) Val(std::forward(v)); + } + void Destroy(uint32 i) { + storage.key[i].Key::~Key(); + storage.val[i].Val::~Val(); + } + void MoveFrom(uint32 i, Bucket* src, uint32 src_index) { + new (&storage.key[i]) Key(std::move(src->storage.key[src_index])); + new (&storage.val[i]) Val(std::move(src->storage.val[src_index])); + } + void CopyFrom(uint32 i, Bucket* src, uint32 src_index) { + new (&storage.key[i]) Key(src->storage.key[src_index]); + new (&storage.val[i]) Val(src->storage.val[src_index]); + } + }; + + template + std::pair InsertPair(Pair&& p) { + return Insert(std::forward(p.first), + std::forward(p.second)); + } + + template + std::pair Insert(K&& k, V&& v) { + rep_.MaybeResize(); + auto r = rep_.FindOrInsert(std::forward(k)); + const bool inserted = !r.found; + if (inserted) { + r.b->InitVal(r.index, std::forward(v)); + } + return {iterator(r.b, rep_.limit(), r.index), inserted}; + } + + template + Val& IndexOp(K&& k) { + rep_.MaybeResize(); + auto r = rep_.FindOrInsert(std::forward(k)); + Val* vptr = &r.b->val(r.index); + if (!r.found) { + new (vptr) Val(); // Initialize value in new slot. + } + return *vptr; + } + + Rep rep_; +}; + +} // namespace gtl +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_GTL_FLATMAP_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/flatrep.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/flatrep.h new file mode 100644 index 000000000..73fdc1720 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/flatrep.h @@ -0,0 +1,351 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_GTL_FLATREP_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_GTL_FLATREP_H_ + +#include +#include +#include "tensorflow_plugin/src/utils/prefetch.h" +#include "tensorflow_plugin/src/utils/types.h" + +namespace demo_plugin { +namespace gtl { +namespace internal { + +// Internal representation for FlatMap and FlatSet. +// +// The representation is an open-addressed hash table. Conceptually, +// the representation is a flat array of entries. However we +// structure it as an array of buckets where each bucket holds +// kWidth entries along with metadata for the kWidth entries. The +// metadata marker is +// +// (a) kEmpty: the entry is empty +// (b) kDeleted: the entry has been deleted +// (c) other: the entry is occupied and has low-8 bits of its hash. +// These hash bits can be used to avoid potentially expensive +// key comparisons. +// +// FlatMap passes in a bucket that contains keys and values, FlatSet +// passes in a bucket that does not contain values. +template +class FlatRep { + public: + // kWidth is the number of entries stored in a bucket. + static constexpr uint32 kBase = 3; + static constexpr uint32 kWidth = (1 << kBase); + + FlatRep(size_t N, const Hash& hf, const Eq& eq) : hash_(hf), equal_(eq) { + Init(N); + } + FlatRep(const FlatRep& src) : hash_(src.hash_), equal_(src.equal_) { + Init(src.size()); + CopyEntries(src.array_, src.end_, CopyEntry()); + } + + FlatRep(FlatRep&& src) + // Copy rather than move src.hash_ and src.equal_. This is necessary to + // leave src in a valid state -- otherwise e.g. if hash_ is an + // std::function, moving it would null it out. + : hash_(src.hash_), equal_(src.equal_) { + // TODO(jlebar): Init(1) still allocates some memory, so this isn't as cheap + // as it could be. The fundamental problem is that we need to leave src in + // a valid state, and FlatRep *always* owns a nonzero amount of memory. + Init(1); + swap(src); + } + + ~FlatRep() { + clear_no_resize(); + delete[] array_; + } + + // Simple accessors. + size_t size() const { return not_empty_ - deleted_; } + size_t bucket_count() const { return mask_ + 1; } + Bucket* start() const { return array_; } + Bucket* limit() const { return end_; } + const Hash& hash_function() const { return hash_; } + const Eq& key_eq() const { return equal_; } + + // Overwrite contents of *this with contents of src. + void CopyFrom(const FlatRep& src) { + if (this != &src) { + clear_no_resize(); + delete[] array_; + Init(src.size()); + CopyEntries(src.array_, src.end_, CopyEntry()); + } + } + + void MoveFrom(FlatRep&& src) { + if (this != &src) { + swap(src); + } + } + + void clear_no_resize() { + for (Bucket* b = array_; b != end_; b++) { + for (uint32 i = 0; i < kWidth; i++) { + if (b->marker[i] >= 2) { + b->Destroy(i); + b->marker[i] = kEmpty; + } + } + } + not_empty_ = 0; + deleted_ = 0; + } + + void clear() { + clear_no_resize(); + grow_ = 0; // Consider shrinking in MaybeResize() + MaybeResize(); + } + + void swap(FlatRep& x) { + using std::swap; + swap(array_, x.array_); + swap(end_, x.end_); + swap(lglen_, x.lglen_); + swap(mask_, x.mask_); + swap(not_empty_, x.not_empty_); + swap(deleted_, x.deleted_); + swap(grow_, x.grow_); + swap(shrink_, x.shrink_); + } + + struct SearchResult { + bool found; + Bucket* b; + uint32 index; + }; + + // Hash value is partitioned as follows: + // 1. Bottom 8 bits are stored in bucket to help speed up comparisons. + // 2. Next 3 bits give index inside bucket. + // 3. Remaining bits give bucket number. + + // Find bucket/index for key k. + SearchResult Find(const Key& k) const { + size_t h = hash_(k); + const uint32 marker = Marker(h & 0xff); + size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket + uint32 num_probes = 1; // Needed for quadratic probing + while (true) { + uint32 bi = index & (kWidth - 1); + Bucket* b = &array_[index >> kBase]; + const uint32 x = b->marker[bi]; + if (x == marker && equal_(b->key(bi), k)) { + return {true, b, bi}; + } else if (x == kEmpty) { + return {false, nullptr, 0}; + } + index = NextIndex(index, num_probes); + num_probes++; + } + } + + // Find bucket/index for key k, creating a new one if necessary. + // + // KeyType is a template parameter so that k's type is deduced and it + // becomes a universal reference which allows the key initialization + // below to use an rvalue constructor if available. + template + SearchResult FindOrInsert(KeyType&& k) { + size_t h = hash_(k); + const uint32 marker = Marker(h & 0xff); + size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket + uint32 num_probes = 1; // Needed for quadratic probing + Bucket* del = nullptr; // First encountered deletion for kInsert + uint32 di = 0; + while (true) { + uint32 bi = index & (kWidth - 1); + Bucket* b = &array_[index >> kBase]; + const uint32 x = b->marker[bi]; + if (x == marker && equal_(b->key(bi), k)) { + return {true, b, bi}; + } else if (!del && x == kDeleted) { + // Remember deleted index to use for insertion. + del = b; + di = bi; + } else if (x == kEmpty) { + if (del) { + // Store in the first deleted slot we encountered + b = del; + bi = di; + deleted_--; // not_empty_ does not change + } else { + not_empty_++; + } + b->marker[bi] = marker; + new (&b->key(bi)) Key(std::forward(k)); + return {false, b, bi}; + } + index = NextIndex(index, num_probes); + num_probes++; + } + } + + void Erase(Bucket* b, uint32 i) { + b->Destroy(i); + b->marker[i] = kDeleted; + deleted_++; + grow_ = 0; // Consider shrinking on next insert + } + + void Prefetch(const Key& k) const { + size_t h = hash_(k); + size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket + uint32 bi = index & (kWidth - 1); + Bucket* b = &array_[index >> kBase]; + port::prefetch(&b->marker[bi]); + port::prefetch(&b->storage.key[bi]); + } + + inline void MaybeResize() { + if (not_empty_ < grow_) { + return; // Nothing to do + } + if (grow_ == 0) { + // Special value set by erase to cause shrink on next insert. + if (size() >= shrink_) { + // Not small enough to shrink. + grow_ = static_cast(bucket_count() * 0.8); + if (not_empty_ < grow_) return; + } + } + Resize(size() + 1); + } + + void Resize(size_t N) { + Bucket* old = array_; + Bucket* old_end = end_; + Init(N); + CopyEntries(old, old_end, MoveEntry()); + delete[] old; + } + + private: + enum { kEmpty = 0, kDeleted = 1 }; // Special markers for an entry. + + Hash hash_; // User-supplied hasher + Eq equal_; // User-supplied comparator + uint8 lglen_; // lg(#buckets) + Bucket* array_; // array of length (1 << lglen_) + Bucket* end_; // Points just past last bucket in array_ + size_t mask_; // (# of entries in table) - 1 + size_t not_empty_; // Count of entries with marker != kEmpty + size_t deleted_; // Count of entries with marker == kDeleted + size_t grow_; // Grow array when not_empty_ >= grow_ + size_t shrink_; // Shrink array when size() < shrink_ + + // Avoid kEmpty and kDeleted markers when computing hash values to + // store in Bucket::marker[]. + static uint32 Marker(uint32 hb) { return hb + (hb < 2 ? 2 : 0); } + + void Init(size_t N) { + // Make enough room for N elements. + size_t lg = 0; // Smallest table is just one bucket. + while (N >= 0.8 * ((1 << lg) * kWidth)) { + lg++; + } + const size_t n = (1 << lg); + Bucket* array = new Bucket[n]; + for (size_t i = 0; i < n; i++) { + Bucket* b = &array[i]; + memset(b->marker, kEmpty, kWidth); + } + const size_t capacity = (1 << lg) * kWidth; + lglen_ = lg; + mask_ = capacity - 1; + array_ = array; + end_ = array + n; + not_empty_ = 0; + deleted_ = 0; + grow_ = static_cast(capacity * 0.8); + if (lg == 0) { + // Already down to one bucket; no more shrinking. + shrink_ = 0; + } else { + shrink_ = static_cast(grow_ * 0.4); // Must be less than 0.5 + } + } + + // Used by FreshInsert when we should copy from source. + struct CopyEntry { + inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) { + dst->CopyFrom(dsti, src, srci); + } + }; + + // Used by FreshInsert when we should move from source. + struct MoveEntry { + inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) { + dst->MoveFrom(dsti, src, srci); + src->Destroy(srci); + src->marker[srci] = kDeleted; + } + }; + + template + void CopyEntries(Bucket* start, Bucket* end, Copier copier) { + for (Bucket* b = start; b != end; b++) { + for (uint32 i = 0; i < kWidth; i++) { + if (b->marker[i] >= 2) { + FreshInsert(b, i, copier); + } + } + } + } + + // Create an entry for the key numbered src_index in *src and return + // its bucket/index. Used for insertion into a fresh table. We + // assume that there are no deletions, and k does not already exist + // in the table. + template + void FreshInsert(Bucket* src, uint32 src_index, Copier copier) { + size_t h = hash_(src->key(src_index)); + const uint32 marker = Marker(h & 0xff); + size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket + uint32 num_probes = 1; // Needed for quadratic probing + while (true) { + uint32 bi = index & (kWidth - 1); + Bucket* b = &array_[index >> kBase]; + const uint32 x = b->marker[bi]; + if (x == 0) { + b->marker[bi] = marker; + not_empty_++; + copier(b, bi, src, src_index); + return; + } + index = NextIndex(index, num_probes); + num_probes++; + } + } + + inline size_t NextIndex(size_t i, uint32 num_probes) const { + // Quadratic probing. + return (i + num_probes) & mask_; + } +}; + +} // namespace internal +} // namespace gtl +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_GTL_FLATREP_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/flatset.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/flatset.h new file mode 100644 index 000000000..7928eeafb --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/flatset.h @@ -0,0 +1,295 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_GTL_FLATSET_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_GTL_FLATSET_H_ + +#include +#include +#include +#include +#include +#include "tensorflow_plugin/src/utils/gtl/flatrep.h" +#include "tensorflow_plugin/src/utils/hash.h" +#include "tensorflow_plugin/src/utils/logging.h" +#include "tensorflow_plugin/src/utils/types.h" + +namespace demo_plugin { +namespace gtl { + +// FlatSet provides a set of K. +// +// The map is implemented using an open-addressed hash table. A +// single array holds entire map contents and collisions are resolved +// by probing at a sequence of locations in the array. +template , class Eq = std::equal_to> +class FlatSet { + private: + // Forward declare some internal types needed in public section. + struct Bucket; + + public: + typedef Key key_type; + typedef Key value_type; + typedef Hash hasher; + typedef Eq key_equal; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + typedef value_type* pointer; + typedef const value_type* const_pointer; + typedef value_type& reference; + typedef const value_type& const_reference; + + FlatSet() : FlatSet(1) {} + + explicit FlatSet(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq()) + : rep_(N, hf, eq) {} + + FlatSet(const FlatSet& src) : rep_(src.rep_) {} + + // Move constructor leaves src in a valid but unspecified state (same as + // std::unordered_set). + FlatSet(FlatSet&& src) : rep_(std::move(src.rep_)) {} + + template + FlatSet(InputIter first, InputIter last, size_t N = 1, + const Hash& hf = Hash(), const Eq& eq = Eq()) + : FlatSet(N, hf, eq) { + insert(first, last); + } + + FlatSet(std::initializer_list init, size_t N = 1, + const Hash& hf = Hash(), const Eq& eq = Eq()) + : FlatSet(init.begin(), init.end(), N, hf, eq) {} + + FlatSet& operator=(const FlatSet& src) { + rep_.CopyFrom(src.rep_); + return *this; + } + + // Move-assignment operator leaves src in a valid but unspecified state (same + // as std::unordered_set). + FlatSet& operator=(FlatSet&& src) { + rep_.MoveFrom(std::move(src.rep_)); + return *this; + } + + ~FlatSet() {} + + void swap(FlatSet& x) { rep_.swap(x.rep_); } + void clear_no_resize() { rep_.clear_no_resize(); } + void clear() { rep_.clear(); } + void reserve(size_t N) { rep_.Resize(std::max(N, size())); } + void rehash(size_t N) { rep_.Resize(std::max(N, size())); } + void resize(size_t N) { rep_.Resize(std::max(N, size())); } + size_t size() const { return rep_.size(); } + bool empty() const { return size() == 0; } + size_t bucket_count() const { return rep_.bucket_count(); } + hasher hash_function() const { return rep_.hash_function(); } + key_equal key_eq() const { return rep_.key_eq(); } + + class const_iterator { + public: + typedef typename FlatSet::difference_type difference_type; + typedef typename FlatSet::value_type value_type; + typedef typename FlatSet::const_pointer pointer; + typedef typename FlatSet::const_reference reference; + typedef ::std::forward_iterator_tag iterator_category; + + const_iterator() : b_(nullptr), end_(nullptr), i_(0) {} + + // Make iterator pointing at first element at or after b. + const_iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { + SkipUnused(); + } + + // Make iterator pointing exactly at ith element in b, which must exist. + const_iterator(Bucket* b, Bucket* end, uint32 i) + : b_(b), end_(end), i_(i) {} + + reference operator*() const { return key(); } + pointer operator->() const { return &key(); } + bool operator==(const const_iterator& x) const { + return b_ == x.b_ && i_ == x.i_; + } + bool operator!=(const const_iterator& x) const { return !(*this == x); } + const_iterator& operator++() { + DCHECK(b_ != end_); + i_++; + SkipUnused(); + return *this; + } + const_iterator operator++(int /*indicates postfix*/) { + const_iterator tmp(*this); + ++*this; + return tmp; + } + + private: + friend class FlatSet; + Bucket* b_; + Bucket* end_; + uint32 i_; + + reference key() const { return b_->key(i_); } + void SkipUnused() { + while (b_ < end_) { + if (i_ >= Rep::kWidth) { + i_ = 0; + b_++; + } else if (b_->marker[i_] < 2) { + i_++; + } else { + break; + } + } + } + }; + + typedef const_iterator iterator; + + iterator begin() { return iterator(rep_.start(), rep_.limit()); } + iterator end() { return iterator(rep_.limit(), rep_.limit()); } + const_iterator begin() const { + return const_iterator(rep_.start(), rep_.limit()); + } + const_iterator end() const { + return const_iterator(rep_.limit(), rep_.limit()); + } + + size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; } + iterator find(const Key& k) { + auto r = rep_.Find(k); + return r.found ? iterator(r.b, rep_.limit(), r.index) : end(); + } + const_iterator find(const Key& k) const { + auto r = rep_.Find(k); + return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end(); + } + + std::pair insert(const Key& k) { return Insert(k); } + std::pair insert(Key&& k) { return Insert(std::move(k)); } + template + void insert(InputIter first, InputIter last) { + for (; first != last; ++first) { + insert(*first); + } + } + + template + std::pair emplace(Args&&... args) { + rep_.MaybeResize(); + auto r = rep_.FindOrInsert(std::forward(args)...); + const bool inserted = !r.found; + return {iterator(r.b, rep_.limit(), r.index), inserted}; + } + + size_t erase(const Key& k) { + auto r = rep_.Find(k); + if (!r.found) return 0; + rep_.Erase(r.b, r.index); + return 1; + } + iterator erase(iterator pos) { + rep_.Erase(pos.b_, pos.i_); + ++pos; + return pos; + } + iterator erase(iterator pos, iterator last) { + for (; pos != last; ++pos) { + rep_.Erase(pos.b_, pos.i_); + } + return pos; + } + + std::pair equal_range(const Key& k) { + auto pos = find(k); + if (pos == end()) { + return std::make_pair(pos, pos); + } else { + auto next = pos; + ++next; + return std::make_pair(pos, next); + } + } + std::pair equal_range(const Key& k) const { + auto pos = find(k); + if (pos == end()) { + return std::make_pair(pos, pos); + } else { + auto next = pos; + ++next; + return std::make_pair(pos, next); + } + } + + bool operator==(const FlatSet& x) const { + if (size() != x.size()) return false; + for (const auto& elem : x) { + auto i = find(elem); + if (i == end()) return false; + } + return true; + } + bool operator!=(const FlatSet& x) const { return !(*this == x); } + + // If key exists in the table, prefetch it. This is a hint, and may + // have no effect. + void prefetch_value(const Key& key) const { rep_.Prefetch(key); } + + private: + using Rep = internal::FlatRep; + + // Bucket stores kWidth triples. + // The data is organized as three parallel arrays to reduce padding. + struct Bucket { + uint8 marker[Rep::kWidth]; + + // Wrap keys in union to control construction and destruction. + union Storage { + Key key[Rep::kWidth]; + Storage() {} + ~Storage() {} + } storage; + + Key& key(uint32 i) { + DCHECK_GE(marker[i], 2); + return storage.key[i]; + } + void Destroy(uint32 i) { storage.key[i].Key::~Key(); } + void MoveFrom(uint32 i, Bucket* src, uint32 src_index) { + new (&storage.key[i]) Key(std::move(src->storage.key[src_index])); + } + void CopyFrom(uint32 i, Bucket* src, uint32 src_index) { + new (&storage.key[i]) Key(src->storage.key[src_index]); + } + }; + + template + std::pair Insert(K&& k) { + rep_.MaybeResize(); + auto r = rep_.FindOrInsert(std::forward(k)); + const bool inserted = !r.found; + return {iterator(r.b, rep_.limit(), r.index), inserted}; + } + + Rep rep_; +}; + +} // namespace gtl +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_GTL_FLATSET_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/inlined_vector.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/inlined_vector.h new file mode 100644 index 000000000..90c006c56 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/inlined_vector.h @@ -0,0 +1,29 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_GTL_INLINED_VECTOR_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_GTL_INLINED_VECTOR_H_ + +#include "absl/container/inlined_vector.h" + +namespace demo_plugin { +namespace gtl { +using absl::InlinedVector; + +} // namespace gtl +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_GTL_INLINED_VECTOR_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/map_traits.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/map_traits.h new file mode 100644 index 000000000..217a4e660 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/map_traits.h @@ -0,0 +1,78 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_GTL_MAP_TRAITS_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_GTL_MAP_TRAITS_H_ + +#include + +// Traits classes for performing uniform lookup on different map value types. +// +// The access is computed as follows: +// +// 1. If T has a `first` or `second` field, use them. +// 2. Otherwise if it has `key()` or `value()` methods, use them. +// 3. Otherwise the program is ill-formed. +namespace demo_plugin { +namespace gtl { +namespace subtle { +namespace internal_map_traits { +struct Rank1 {}; +struct Rank0 : Rank1 {}; + +template +auto GetKey(V&& v, Rank0) -> decltype((std::forward(v).first)) { + return std::forward(v).first; +} +template +auto GetKey(V&& v, Rank1) -> decltype(std::forward(v).key()) { + return std::forward(v).key(); +} + +template +auto GetMapped(V&& v, Rank0) -> decltype((std::forward(v).second)) { + return std::forward(v).second; +} +template +auto GetMapped(V&& v, Rank1) -> decltype(std::forward(v).value()) { + return std::forward(v).value(); +} + +} // namespace internal_map_traits + +// Accesses the `key_type` from a `value_type`. +template +auto GetKey(V&& v) + -> decltype(internal_map_traits::GetKey(std::forward(v), + internal_map_traits::Rank0())) { + return internal_map_traits::GetKey(std::forward(v), + internal_map_traits::Rank0()); +} + +// Accesses the `mapped_type` from a `value_type`. +template +auto GetMapped(V&& v) + -> decltype(internal_map_traits::GetMapped(std::forward(v), + internal_map_traits::Rank0())) { + return internal_map_traits::GetMapped(std::forward(v), + internal_map_traits::Rank0()); +} + +} // namespace subtle +} // namespace gtl +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_GTL_MAP_TRAITS_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/map_util.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/map_util.h new file mode 100644 index 000000000..d99b9e8b5 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/gtl/map_util.h @@ -0,0 +1,214 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_GTL_MAP_UTIL_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_GTL_MAP_UTIL_H_ + +#include + +#include +#include +#include +#include + +#include "tensorflow_plugin/src/utils/gtl/map_traits.h" + +// This file provides utility functions for use with STL map-like data +// structures, such as std::map and hash_map. Some functions will also work with +// sets, such as ContainsKey(). +namespace demo_plugin { +namespace gtl { +// Returns a pointer to the const value associated with the given key if it +// exists, or NULL otherwise. +template +const typename Collection::value_type::second_type* FindOrNull( + const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return 0; + } + return &it->second; +} + +// Same as above but returns a pointer to the non-const value. +template +typename Collection::value_type::second_type* FindOrNull( + Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + if (it == collection.end()) { + return 0; + } + return &it->second; +} + +// Returns the pointer value associated with the given key. If none is found, +// NULL is returned. The function is designed to be used with a map of keys to +// pointers. +// +// This function does not distinguish between a missing key and a key mapped +// to a NULL value. +template +typename Collection::value_type::second_type FindPtrOrNull( + const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return typename Collection::value_type::second_type(); + } + return it->second; +} + +// Returns a const reference to the value associated with the given key if it +// exists, otherwise returns a const reference to the provided default value. +// +// WARNING: If a temporary object is passed as the default "value," +// this function will return a reference to that temporary object, +// which will be destroyed at the end of the statement. A common +// example: if you have a map with string values, and you pass a char* +// as the default "value," either use the returned value immediately +// or store it in a string (not string&). +template +const typename Collection::value_type::second_type& FindWithDefault( + const Collection& collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return value; + } + return it->second; +} + +// Inserts the given key-value pair into the collection. Returns true if and +// only if the key from the given pair didn't previously exist. Otherwise, the +// value in the map is replaced with the value from the given pair. +template +bool InsertOrUpdate(Collection* const collection, + const typename Collection::value_type& vt) { + std::pair ret = collection->insert(vt); + if (!ret.second) { + // update + ret.first->second = vt.second; + return false; + } + return true; +} + +// Same as above, except that the key and value are passed separately. +template +bool InsertOrUpdate(Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return InsertOrUpdate(collection, + typename Collection::value_type(key, value)); +} + +// Inserts the given key and value into the given collection if and only if the +// given key did NOT already exist in the collection. If the key previously +// existed in the collection, the value is not changed. Returns true if the +// key-value pair was inserted; returns false if the key was already present. +template +bool InsertIfNotPresent(Collection* const collection, + const typename Collection::value_type& vt) { + return collection->insert(vt).second; +} + +// Same as above except the key and value are passed separately. +template +bool InsertIfNotPresent( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return InsertIfNotPresent(collection, + typename Collection::value_type(key, value)); +} + +// Looks up a given key and value pair in a collection and inserts the key-value +// pair if it's not already present. Returns a reference to the value associated +// with the key. +template +typename Collection::value_type::second_type& LookupOrInsert( + Collection* const collection, const typename Collection::value_type& vt) { + return collection->insert(vt).first->second; +} + +// Same as above except the key-value are passed separately. +template +typename Collection::value_type::second_type& LookupOrInsert( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return LookupOrInsert(collection, + typename Collection::value_type(key, value)); +} + +// Saves the reverse mapping into reverse. Returns true if values could all be +// inserted. +template +bool ReverseMap(const M& m, ReverseM* reverse) { + bool all_unique = true; + for (const auto& kv : m) { + if (!InsertOrUpdate(reverse, kv.second, kv.first)) { + all_unique = false; + } + } + return all_unique; +} + +// Like ReverseMap above, but returns its output m. Return type has to +// be specified explicitly. Example: +// M::M(...) : m_(...), r_(ReverseMap(m_)) {} +template +ReverseM ReverseMap(const M& m) { + typename std::remove_const::type reverse; + ReverseMap(m, &reverse); + return reverse; +} + +// Erases the m item identified by the given key, and returns the value +// associated with that key. It is assumed that the value (i.e., the +// mapped_type) is a pointer. Returns null if the key was not found in the +// m. +// +// Examples: +// std::map my_map; +// +// One line cleanup: +// delete EraseKeyReturnValuePtr(&my_map, "abc"); +// +// Use returned value: +// std::unique_ptr value_ptr( +// EraseKeyReturnValuePtr(&my_map, "abc")); +// if (value_ptr.get()) +// value_ptr->DoSomething(); +// +template +typename Collection::value_type::second_type EraseKeyReturnValuePtr( + Collection* collection, + const typename Collection::value_type::first_type& key) { + auto it = collection->find(key); + if (it == collection->end()) return nullptr; + auto v = gtl::subtle::GetMapped(*it); + collection->erase(it); + return v; +} + +} // namespace gtl +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_GTL_MAP_UTIL_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/integral_types.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/integral_types.h new file mode 100644 index 000000000..de71fa89b --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/integral_types.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_INTEGRAL_TYPES_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_INTEGRAL_TYPES_H_ + +namespace demo_plugin { + +typedef signed char int8; +typedef short int16; +typedef int int32; + +// for compatible with int64_t +typedef long int64; + +typedef unsigned char uint8; +typedef unsigned short uint16; +typedef unsigned int uint32; +typedef unsigned long long uint64; + +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_INTEGRAL_TYPES_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/logging.cc b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/logging.cc new file mode 100644 index 000000000..356085350 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/logging.cc @@ -0,0 +1,350 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_plugin/src/utils/logging.h" + +#include +#include +#include + +#include +#include + +#include "absl/base/internal/cycleclock.h" +#include "absl/base/internal/sysinfo.h" +#include "tensorflow_plugin/src/utils/env_time.h" +#include "tensorflow_plugin/src/utils/macros.h" + +namespace demo_plugin { + +void TFAddLogSink(TFLogSink *sink) { + // LogSink is not implemented. + // If necessary, one can add the log sink support as follows. + // 1. Define a global vector to keep track of all registered + // TFLogSink objects. Protect the global vector with mutex to make it + // thread-safe. + // 2. Add/remove elements from the global vector in TFAddLogSink + // and TFRemoveLogSink function + // 3. Add logic in LogMessage::GenerateLogMessage() below to dispatch log + // messages to all the registered log sinks. +} + +void TFRemoveLogSink(TFLogSink *sink) { + // LogSink is not implemented. +} + +namespace internal { +namespace { + +int ParseInteger(const char *str, size_t size) { + // Ideally we would use env_var / safe_strto64, but it is + // hard to use here without pulling in a lot of dependencies, + // so we use std:istringstream instead + string integer_str(str, size); + int level = 0; + level = std::stoi(integer_str); + return level; +} + +// Parse log level (int64) from environment variable (char*) +int64 LogLevelStrToInt(const char *tf_env_var_val) { + if (tf_env_var_val == nullptr) { + return 0; + } + return ParseInteger(tf_env_var_val, strlen(tf_env_var_val)); +} + +// Using StringPiece breaks Windows build. +struct StringData { + struct Hasher { + size_t operator()(const StringData &sdata) const { + // For dependency reasons, we cannot use hash.h here. Use DBJHash instead. + size_t hash = 5381; + const char *data = sdata.data; + for (const char *top = data + sdata.size; data < top; ++data) { + hash = ((hash << 5) + hash) + (*data); + } + return hash; + } + }; + + StringData() = default; + StringData(const char *data, size_t size) : data(data), size(size) {} + + bool operator==(const StringData &rhs) const { + return size == rhs.size && memcmp(data, rhs.data, size) == 0; + } + + const char *data = nullptr; + size_t size = 0; +}; + +using VmoduleMap = std::unordered_map; + +// Returns a mapping from module name to VLOG level, derived from the +// TF_CPP_VMODULE environment variable; ownership is transferred to the caller. +VmoduleMap *VmodulesMapFromEnv() { + // The value of the env var is supposed to be of the form: + // "foo=1,bar=2,baz=3" + const char *env = getenv("TF_CPP_VMODULE"); + if (env == nullptr) { + // If there is no TF_CPP_VMODULE configuration (most common case), return + // nullptr so that the ShouldVlogModule() API can fast bail out of it. + return nullptr; + } + // The memory returned by getenv() can be invalidated by following getenv() or + // setenv() calls. And since we keep references to it in the VmoduleMap in + // form of StringData objects, make a copy of it. + const char *env_data = strdup(env); + VmoduleMap *result = new VmoduleMap(); + while (true) { + const char *eq = strchr(env_data, '='); + if (eq == nullptr) { + break; + } + const char *after_eq = eq + 1; + + // Comma either points at the next comma delimiter, or at a null terminator. + // We check that the integer we parse ends at this delimiter. + const char *comma = strchr(after_eq, ','); + const char *new_env_data; + if (comma == nullptr) { + comma = strchr(after_eq, '\0'); + new_env_data = comma; + } else { + new_env_data = comma + 1; + } + (*result)[StringData(env_data, eq - env_data)] = + ParseInteger(after_eq, comma - after_eq); + env_data = new_env_data; + } + return result; +} + +bool EmitThreadIdFromEnv() { + const char *tf_env_var_val = getenv("TF_CPP_LOG_THREAD_ID"); + return tf_env_var_val == nullptr + ? false + : ParseInteger(tf_env_var_val, strlen(tf_env_var_val)) != 0; +} + +} // namespace + +int64 MinLogLevelFromEnv() { + // We don't want to print logs during fuzzing as that would slow fuzzing down + // by almost 2x. So, if we are in fuzzing mode (not just running a test), we + // return a value so that nothing is actually printed. Since LOG uses >= + // (see ~LogMessage in this file) to see if log messages need to be printed, + // the value we're interested on to disable printing is the maximum severity. + // See also http://llvm.org/docs/LibFuzzer.html#fuzzer-friendly-build-mode +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + return tensorflow::NUM_SEVERITIES; +#else + // Read TENSORFLOW_PLUGIN env var first. + const char *tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL"); + if (tf_env_var_val == nullptr) + tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL"); + return LogLevelStrToInt(tf_env_var_val); +#endif +} + +int64 MinVLogLevelFromEnv() { + // We don't want to print logs during fuzzing as that would slow fuzzing down + // by almost 2x. So, if we are in fuzzing mode (not just running a test), we + // return a value so that nothing is actually printed. Since VLOG uses <= + // (see VLOG_IS_ON in logging.h) to see if log messages need to be printed, + // the value we're interested on to disable printing is 0. + // See also http://llvm.org/docs/LibFuzzer.html#fuzzer-friendly-build-mode +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + return 0; +#else + // Read TENSORFLOW_PLUGIN env var first. + const char *tf_env_var_val = getenv("TF_CPP_MAX_VLOG_LEVEL"); + if (tf_env_var_val == nullptr) + tf_env_var_val = getenv("TF_CPP_MAX_VLOG_LEVEL"); + return LogLevelStrToInt(tf_env_var_val); +#endif +} + +LogMessage::LogMessage(const char *fname, int line, int severity) + : fname_(fname), line_(line), severity_(severity) {} + +LogMessage &LogMessage::AtLocation(const char *fname, int line) { + fname_ = fname; + line_ = line; + return *this; +} + +LogMessage::~LogMessage() { + // Read the min log level once during the first call to logging. + static int64 min_log_level = MinLogLevelFromEnv(); + if (severity_ >= min_log_level) { + GenerateLogMessage(); + } +} + +void LogMessage::GenerateLogMessage() { + static bool log_thread_id = EmitThreadIdFromEnv(); + uint64 now_micros = EnvTime::NowMicros(); + time_t now_seconds = static_cast(now_micros / 1000000); + int32 micros_remainder = static_cast(now_micros % 1000000); + const size_t time_buffer_size = 30; + char time_buffer[time_buffer_size]; + strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S", + localtime(&now_seconds)); + const size_t tid_buffer_size = 10; + char tid_buffer[tid_buffer_size] = ""; + if (log_thread_id) { + snprintf(tid_buffer, sizeof(tid_buffer), " %7u", + absl::base_internal::GetTID()); + } + // TODO(jeff,sanjay): Replace this with something that logs through the env. + fprintf(stderr, "%s.%06d: %c%s %s:%d] %s\n", time_buffer, micros_remainder, + "IWEF"[severity_], tid_buffer, fname_, line_, str().c_str()); +} + +int64 LogMessage::MinVLogLevel() { + static int64 min_vlog_level = MinVLogLevelFromEnv(); + return min_vlog_level; +} + +bool LogMessage::VmoduleActivated(const char *fname, int level) { + if (level <= MinVLogLevel()) { + return true; + } + static VmoduleMap *vmodules = VmodulesMapFromEnv(); + if (TF_PREDICT_TRUE(vmodules == nullptr)) { + return false; + } + const char *last_slash = strrchr(fname, '/'); + const char *module_start = last_slash == nullptr ? fname : last_slash + 1; + const char *dot_after = strchr(module_start, '.'); + const char *module_limit = + dot_after == nullptr ? strchr(fname, '\0') : dot_after; + StringData module(module_start, module_limit - module_start); + auto it = vmodules->find(module); + return it != vmodules->end() && it->second >= level; +} + +LogMessageFatal::LogMessageFatal(const char *file, int line) + : LogMessage(file, line, FATAL) {} +LogMessageFatal::~LogMessageFatal() { + // abort() ensures we don't return (we promised we would not via + // ATTRIBUTE_NORETURN). + GenerateLogMessage(); + abort(); +} + +void LogString(const char *fname, int line, int severity, + const string &message) { + LogMessage(fname, line, severity) << message; +} + +template <> void MakeCheckOpValueString(std::ostream *os, const char &v) { + if (v >= 32 && v <= 126) { + (*os) << "'" << v << "'"; + } else { + (*os) << "char value " << static_cast(v); + } +} + +template <> +void MakeCheckOpValueString(std::ostream *os, const signed char &v) { + if (v >= 32 && v <= 126) { + (*os) << "'" << v << "'"; + } else { + (*os) << "signed char value " << static_cast(v); + } +} + +template <> +void MakeCheckOpValueString(std::ostream *os, const unsigned char &v) { + if (v >= 32 && v <= 126) { + (*os) << "'" << v << "'"; + } else { + (*os) << "unsigned char value " << static_cast(v); + } +} + +#if LANG_CXX11 +template <> +void MakeCheckOpValueString(std::ostream *os, const std::nullptr_t &v) { + (*os) << "nullptr"; +} +#endif + +CheckOpMessageBuilder::CheckOpMessageBuilder(const char *exprtext) + : stream_(new std::ostringstream) { + *stream_ << "Check failed: " << exprtext << " ("; +} + +CheckOpMessageBuilder::~CheckOpMessageBuilder() { delete stream_; } + +std::ostream *CheckOpMessageBuilder::ForVar2() { + *stream_ << " vs. "; + return stream_; +} + +string *CheckOpMessageBuilder::NewString() { + *stream_ << ")"; + return new string(stream_->str()); +} + +namespace { +// The following code behaves like AtomicStatsCounter::LossyAdd() for +// speed since it is fine to lose occasional updates. +// Returns old value of *counter. +uint32 LossyIncrement(std::atomic *counter) { + const uint32 value = counter->load(std::memory_order_relaxed); + counter->store(value + 1, std::memory_order_relaxed); + return value; +} +} // namespace + +bool LogEveryNState::ShouldLog(int n) { + return n != 0 && (LossyIncrement(&counter_) % n) == 0; +} + +bool LogFirstNState::ShouldLog(int n) { + const int counter_value = + static_cast(counter_.load(std::memory_order_relaxed)); + if (counter_value < n) { + counter_.store(counter_value + 1, std::memory_order_relaxed); + return true; + } + return false; +} + +bool LogEveryPow2State::ShouldLog(int ignored) { + const uint32 new_value = LossyIncrement(&counter_) + 1; + return (new_value & (new_value - 1)) == 0; +} + +bool LogEveryNSecState::ShouldLog(double seconds) { + LossyIncrement(&counter_); + const int64 now_cycles = absl::base_internal::CycleClock::Now(); + int64 next_cycles = next_log_time_cycles_.load(std::memory_order_relaxed); + do { + if (now_cycles <= next_cycles) + return false; + } while (!next_log_time_cycles_.compare_exchange_weak( + next_cycles, + now_cycles + seconds * absl::base_internal::CycleClock::Frequency(), + std::memory_order_relaxed, std::memory_order_relaxed)); + return true; +} + +} // namespace internal +} // namespace demo_plugin diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/logging.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/logging.h new file mode 100644 index 000000000..488ecf0d3 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/logging.h @@ -0,0 +1,521 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_LOGGING_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_LOGGING_H_ + +#include +#include +#include +#include + +#include "absl/base/log_severity.h" +#include "absl/strings/string_view.h" +#include "tensorflow_plugin/src//utils/macros.h" +#include "tensorflow_plugin/src/utils/integral_types.h" + +#undef ERROR + +namespace demo_plugin { +const int INFO = 0; // base_logging::INFO; +const int WARNING = 1; // base_logging::WARNING; +const int ERROR = 2; // base_logging::ERROR; +const int FATAL = 3; // base_logging::FATAL; +const int NUM_SEVERITIES = 4; // base_logging::NUM_SEVERITIES; + +namespace internal { + +using std::string; + +class LogMessage : public std::basic_ostringstream { +public: + LogMessage(const char *fname, int line, int severity); + ~LogMessage() override; + + // Change the location of the log message. + LogMessage &AtLocation(const char *fname, int line); + + // Returns the minimum log level for VLOG statements. + // E.g., if MinVLogLevel() is 2, then VLOG(2) statements will produce output, + // but VLOG(3) will not. Defaults to 0. + static int64 MinVLogLevel(); + + // Returns whether VLOG level lvl is activated for the file fname. + // + // E.g. if the environment variable TF_CPP_VMODULE contains foo=3 and fname is + // foo.cc and lvl is <= 3, this will return true. It will also return true if + // the level is lower or equal to TF_CPP_MIN_VLOG_LEVEL (default zero). + // + // It is expected that the result of this query will be cached in the VLOG-ing + // call site to avoid repeated lookups. This routine performs a hash-map + // access against the VLOG-ing specification provided by the env var. + static bool VmoduleActivated(const char *fname, int level); + +protected: + void GenerateLogMessage(); + +private: + const char *fname_; + int line_; + int severity_; +}; + +// Uses the lower operator & precedence to voidify a LogMessage reference, so +// that the ternary VLOG() implementation is balanced, type wise. +struct Voidifier { + template void operator&(const T &)const {} +}; + +// LogMessageFatal ensures the process will exit in failure after +// logging this message. +class LogMessageFatal : public LogMessage { +public: + LogMessageFatal(const char *file, int line) TF_ATTRIBUTE_COLD; + TF_ATTRIBUTE_NORETURN ~LogMessageFatal() override; +}; + +// LogMessageNull supports the DVLOG macro by simply dropping any log messages. +class LogMessageNull : public std::basic_ostringstream { +public: + LogMessageNull() {} + ~LogMessageNull() override {} +}; + +#define _TF_LOG_INFO \ + ::demo_plugin::internal::LogMessage(__FILE__, __LINE__, ::demo_plugin::INFO) +#define _TF_LOG_WARNING \ + ::demo_plugin::internal::LogMessage(__FILE__, __LINE__, \ + ::demo_plugin::WARNING) +#define _TF_LOG_ERROR \ + ::demo_plugin::internal::LogMessage(__FILE__, __LINE__, ::demo_plugin::ERROR) +#define _TF_LOG_FATAL \ + ::demo_plugin::internal::LogMessageFatal(__FILE__, __LINE__) + +#define _TF_LOG_QFATAL _TF_LOG_FATAL + +#define LOG(severity) _TF_LOG_##severity + +#ifdef IS_MOBILE_PLATFORM + +// Turn VLOG off when under mobile devices for considerations of binary size. +#define VLOG_IS_ON(lvl) ((lvl) <= 0) + +#else + +// Otherwise, set TF_CPP_MIN_VLOG_LEVEL environment to update minimum log level +// of VLOG, or TF_CPP_VMODULE to set the minimum log level for individual +// translation units. +#define VLOG_IS_ON(lvl) \ + (([](int level, const char *fname) { \ + static const bool vmodule_activated = \ + ::demo_plugin::internal::LogMessage::VmoduleActivated(fname, level); \ + return vmodule_activated; \ + })(lvl, __FILE__)) + +#endif + +#define VLOG(level) \ + TF_PREDICT_TRUE(!VLOG_IS_ON(level)) \ + ? (void)0 \ + : ::demo_plugin::internal::Voidifier() & \ + ::demo_plugin::internal::LogMessage(__FILE__, __LINE__, \ + demo_plugin::INFO) + +// `DVLOG` behaves like `VLOG` in debug mode (i.e. `#ifndef NDEBUG`). +// Otherwise, it compiles away and does nothing. +#ifndef NDEBUG +#define DVLOG VLOG +#else +#define DVLOG(verbose_level) \ + while (false && (verbose_level) > 0) \ + ::demo_plugin::internal::LogMessageNull() +#endif + +class LogEveryNState { +public: + bool ShouldLog(int n); + uint32_t counter() { return counter_.load(std::memory_order_relaxed); } + +private: + std::atomic counter_{0}; +}; + +class LogFirstNState { +public: + bool ShouldLog(int n); + uint32 counter() { return counter_.load(std::memory_order_relaxed); } + +private: + std::atomic counter_{0}; +}; + +class LogEveryPow2State { +public: + bool ShouldLog(int ignored); + uint32 counter() { return counter_.load(std::memory_order_relaxed); } + +private: + std::atomic counter_{0}; +}; + +class LogEveryNSecState { +public: + bool ShouldLog(double seconds); + uint32 counter() { return counter_.load(std::memory_order_relaxed); } + +private: + std::atomic counter_{0}; + // Cycle count according to CycleClock that we should next log at. + std::atomic next_log_time_cycles_{0}; +}; + +// This macro has a lot going on! +// +// * A local static (`logging_internal_stateful_condition_state`) is +// declared in a scope such that each `LOG_EVERY_N` (etc.) line has its own +// state. +// * `COUNTER`, the third variable, is used to support `<< COUNTER`. It is not +// mangled, so shadowing can be a problem, albeit more of a +// shoot-yourself-in-the-foot one. Don't name your variables `COUNTER`. +// * A single for loop can declare state and also test +// `condition && state.ShouldLog()`, but there's no way to constrain it to run +// only once (or not at all) without declaring another variable. The outer +// for-loop declares this variable (`do_log`). +// * Using for loops instead of if statements means there's no risk of an +// ambiguous dangling else statement. +#define LOGGING_INTERNAL_STATEFUL_CONDITION(kind, condition, arg) \ + for (bool logging_internal_stateful_condition_do_log(condition); \ + logging_internal_stateful_condition_do_log; \ + logging_internal_stateful_condition_do_log = false) \ + for (static ::demo_plugin::internal::Log##kind##State \ + logging_internal_stateful_condition_state; \ + logging_internal_stateful_condition_do_log && \ + logging_internal_stateful_condition_state.ShouldLog(arg); \ + logging_internal_stateful_condition_do_log = false) \ + for (const uint32_t COUNTER ABSL_ATTRIBUTE_UNUSED = \ + logging_internal_stateful_condition_state.counter(); \ + logging_internal_stateful_condition_do_log; \ + logging_internal_stateful_condition_do_log = false) + +// An instance of `LOG_EVERY_N` increments a hidden zero-initialized counter +// every time execution passes through it and logs the specified message when +// the counter's value is a multiple of `n`, doing nothing otherwise. Each +// instance has its own counter. The counter's value can be logged by streaming +// the symbol `COUNTER`. `LOG_EVERY_N` is thread-safe. +// Example: +// +// for (const auto& user : all_users) { +// LOG_EVERY_N(INFO, 1000) << "Processing user #" << COUNTER; +// ProcessUser(user); +// } +#define LOG_EVERY_N(severity, n) \ + LOGGING_INTERNAL_STATEFUL_CONDITION(EveryN, true, n) \ + LOG(severity) +// `LOG_FIRST_N` behaves like `LOG_EVERY_N` except that the specified message is +// logged when the counter's value is less than `n`. `LOG_FIRST_N` is +// thread-safe. +#define LOG_FIRST_N(severity, n) \ + LOGGING_INTERNAL_STATEFUL_CONDITION(FirstN, true, n) \ + LOG(severity) +// `LOG_EVERY_POW_2` behaves like `LOG_EVERY_N` except that the specified +// message is logged when the counter's value is a power of 2. +// `LOG_EVERY_POW_2` is thread-safe. +#define LOG_EVERY_POW_2(severity) \ + LOGGING_INTERNAL_STATEFUL_CONDITION(EveryPow2, true, 0) \ + LOG(severity) +// An instance of `LOG_EVERY_N_SEC` uses a hidden state variable to log the +// specified message at most once every `n_seconds`. A hidden counter of +// executions (whether a message is logged or not) is also maintained and can be +// logged by streaming the symbol `COUNTER`. `LOG_EVERY_N_SEC` is thread-safe. +// Example: +// +// LOG_EVERY_N_SEC(INFO, 2.5) << "Got " << COUNTER << " cookies so far"; +#define LOG_EVERY_N_SEC(severity, n_seconds) \ + LOGGING_INTERNAL_STATEFUL_CONDITION(EveryNSec, true, n_seconds) \ + LOG(severity) + +// CHECK dies with a fatal error if condition is not true. It is *not* +// controlled by NDEBUG, so the check will be executed regardless of +// compilation mode. Therefore, it is safe to do things like: +// CHECK(fp->Write(x) == 4) +#define CHECK(condition) \ + if (TF_PREDICT_FALSE(!(condition))) \ + LOG(FATAL) << "Check failed: " #condition " " + +// Function is overloaded for integral types to allow static const +// integrals declared in classes and not defined to be used as arguments to +// CHECK* macros. It's not encouraged though. +template inline const T &GetReferenceableValue(const T &t) { + return t; +} +inline char GetReferenceableValue(char t) { return t; } +inline unsigned char GetReferenceableValue(unsigned char t) { return t; } +inline signed char GetReferenceableValue(signed char t) { return t; } +inline int16 GetReferenceableValue(int16 t) { return t; } +inline uint16 GetReferenceableValue(uint16 t) { return t; } +inline int GetReferenceableValue(int t) { return t; } +inline unsigned int GetReferenceableValue(unsigned int t) { return t; } +inline int64 GetReferenceableValue(int64 t) { return t; } +inline uint64 GetReferenceableValue(uint64 t) { return t; } + +// This formats a value for a failing CHECK_XX statement. Ordinarily, +// it uses the definition for operator<<, with a few special cases below. +template +inline void MakeCheckOpValueString(std::ostream *os, const T &v) { + (*os) << v; +} + +// Overrides for char types provide readable values for unprintable +// characters. +template <> void MakeCheckOpValueString(std::ostream *os, const char &v); +template <> void MakeCheckOpValueString(std::ostream *os, const signed char &v); +template <> +void MakeCheckOpValueString(std::ostream *os, const unsigned char &v); + +#if LANG_CXX11 +// We need an explicit specialization for std::nullptr_t. +template <> +void MakeCheckOpValueString(std::ostream *os, const std::nullptr_t &v); +#endif + +// A container for a string pointer which can be evaluated to a bool - +// true iff the pointer is non-NULL. +struct CheckOpString { + explicit CheckOpString(string *str) : str_(str) {} + // No destructor: if str_ is non-NULL, we're about to LOG(FATAL), + // so there's no point in cleaning up str_. + explicit operator bool() const { return TF_PREDICT_FALSE(str_ != nullptr); } + string *str_; +}; + +// Build the error message string. Specify no inlining for code size. +template +string *MakeCheckOpString(const T1 &v1, const T2 &v2, + const char *exprtext) TF_ATTRIBUTE_NOINLINE; + +// A helper class for formatting "expr (V1 vs. V2)" in a CHECK_XX +// statement. See MakeCheckOpString for sample usage. Other +// approaches were considered: use of a template method (e.g., +// base::BuildCheckOpString(exprtext, base::Print, &v1, +// base::Print, &v2), however this approach has complications +// related to volatile arguments and function-pointer arguments). +class CheckOpMessageBuilder { +public: + // Inserts "exprtext" and " (" to the stream. + explicit CheckOpMessageBuilder(const char *exprtext); + // Deletes "stream_". + ~CheckOpMessageBuilder(); + // For inserting the first variable. + std::ostream *ForVar1() { return stream_; } + // For inserting the second variable (adds an intermediate " vs. "). + std::ostream *ForVar2(); + // Get the result (inserts the closing ")"). + string *NewString(); + +private: + std::ostringstream *stream_; +}; + +template +string *MakeCheckOpString(const T1 &v1, const T2 &v2, const char *exprtext) { + CheckOpMessageBuilder comb(exprtext); + MakeCheckOpValueString(comb.ForVar1(), v1); + MakeCheckOpValueString(comb.ForVar2(), v2); + return comb.NewString(); +} + +// Helper functions for CHECK_OP macro. +// The (int, int) specialization works around the issue that the compiler +// will not instantiate the template version of the function on values of +// unnamed enum type - see comment below. +// The (size_t, int) and (int, size_t) specialization are to handle unsigned +// comparison errors while still being thorough with the comparison. +#define TF_DEFINE_CHECK_OP_IMPL(name, op) \ + template \ + inline string *name##Impl(const T1 &v1, const T2 &v2, \ + const char *exprtext) { \ + if (TF_PREDICT_TRUE(v1 op v2)) \ + return NULL; \ + else \ + return ::demo_plugin::internal::MakeCheckOpString(v1, v2, exprtext); \ + } \ + inline string *name##Impl(int v1, int v2, const char *exprtext) { \ + return name##Impl(v1, v2, exprtext); \ + } \ + inline string *name##Impl(const size_t v1, const int v2, \ + const char *exprtext) { \ + if (TF_PREDICT_FALSE(v2 < 0)) { \ + return ::demo_plugin::internal::MakeCheckOpString(v1, v2, exprtext); \ + } \ + return name##Impl(v1, v2, exprtext); \ + } \ + inline string *name##Impl(const int v1, const size_t v2, \ + const char *exprtext) { \ + if (TF_PREDICT_FALSE(v2 >= std::numeric_limits::max())) { \ + return ::demo_plugin::internal::MakeCheckOpString(v1, v2, exprtext); \ + } \ + const size_t uval = (size_t)((unsigned)v2); \ + return name##Impl(v1, uval, exprtext); \ + } + +// We use the full name Check_EQ, Check_NE, etc. in case the file including +// base/logging.h provides its own #defines for the simpler names EQ, NE, etc. +// This happens if, for example, those are used as token names in a +// yacc grammar. +TF_DEFINE_CHECK_OP_IMPL(Check_EQ, + ==) // Compilation error with CHECK_EQ(NULL, x)? +TF_DEFINE_CHECK_OP_IMPL(Check_NE, !=) // Use CHECK(x == NULL) instead. +TF_DEFINE_CHECK_OP_IMPL(Check_LE, <=) +TF_DEFINE_CHECK_OP_IMPL(Check_LT, <) +TF_DEFINE_CHECK_OP_IMPL(Check_GE, >=) +TF_DEFINE_CHECK_OP_IMPL(Check_GT, >) +#undef TF_DEFINE_CHECK_OP_IMPL + +// In optimized mode, use CheckOpString to hint to compiler that +// the while condition is unlikely. +#define CHECK_OP_LOG(name, op, val1, val2) \ + while (::demo_plugin::internal::CheckOpString _result{ \ + ::demo_plugin::internal::name##Impl( \ + ::demo_plugin::internal::GetReferenceableValue(val1), \ + ::demo_plugin::internal::GetReferenceableValue(val2), \ + #val1 " " #op " " #val2)}) \ + ::demo_plugin::internal::LogMessageFatal(__FILE__, __LINE__) \ + << *(_result.str_) + +#define CHECK_OP(name, op, val1, val2) CHECK_OP_LOG(name, op, val1, val2) + +// CHECK_EQ/NE/... +#define CHECK_EQ(val1, val2) CHECK_OP(Check_EQ, ==, val1, val2) +#define CHECK_NE(val1, val2) CHECK_OP(Check_NE, !=, val1, val2) +#define CHECK_LE(val1, val2) CHECK_OP(Check_LE, <=, val1, val2) +#define CHECK_LT(val1, val2) CHECK_OP(Check_LT, <, val1, val2) +#define CHECK_GE(val1, val2) CHECK_OP(Check_GE, >=, val1, val2) +#define CHECK_GT(val1, val2) CHECK_OP(Check_GT, >, val1, val2) +#define CHECK_NOTNULL(val) \ + ::demo_plugin::internal::CheckNotNull(__FILE__, __LINE__, \ + "'" #val "' Must be non NULL", (val)) + +#ifndef NDEBUG +// DCHECK_EQ/NE/... +#define DCHECK(condition) CHECK(condition) +#define DCHECK_EQ(val1, val2) CHECK_EQ(val1, val2) +#define DCHECK_NE(val1, val2) CHECK_NE(val1, val2) +#define DCHECK_LE(val1, val2) CHECK_LE(val1, val2) +#define DCHECK_LT(val1, val2) CHECK_LT(val1, val2) +#define DCHECK_GE(val1, val2) CHECK_GE(val1, val2) +#define DCHECK_GT(val1, val2) CHECK_GT(val1, val2) + +#else + +#define DCHECK(condition) \ + while (false && (condition)) \ + LOG(FATAL) + +// NDEBUG is defined, so DCHECK_EQ(x, y) and so on do nothing. +// However, we still want the compiler to parse x and y, because +// we don't want to lose potentially useful errors and warnings. +// _DCHECK_NOP is a helper, and should not be used outside of this file. +#define _TF_DCHECK_NOP(x, y) \ + while (false && ((void)(x), (void)(y), 0)) \ + LOG(FATAL) + +#define DCHECK_EQ(x, y) _TF_DCHECK_NOP(x, y) +#define DCHECK_NE(x, y) _TF_DCHECK_NOP(x, y) +#define DCHECK_LE(x, y) _TF_DCHECK_NOP(x, y) +#define DCHECK_LT(x, y) _TF_DCHECK_NOP(x, y) +#define DCHECK_GE(x, y) _TF_DCHECK_NOP(x, y) +#define DCHECK_GT(x, y) _TF_DCHECK_NOP(x, y) + +#endif + +// These are for when you don't want a CHECK failure to print a verbose +// stack trace. The implementation of CHECK* in this file already doesn't. +#define QCHECK(condition) CHECK(condition) +#define QCHECK_EQ(x, y) CHECK_EQ(x, y) +#define QCHECK_NE(x, y) CHECK_NE(x, y) +#define QCHECK_LE(x, y) CHECK_LE(x, y) +#define QCHECK_LT(x, y) CHECK_LT(x, y) +#define QCHECK_GE(x, y) CHECK_GE(x, y) +#define QCHECK_GT(x, y) CHECK_GT(x, y) + +template +T &&CheckNotNull(const char *file, int line, const char *exprtext, T &&t) { + if (t == nullptr) { + LogMessageFatal(file, line) << string(exprtext); + } + return std::forward(t); +} + +int64 MinLogLevelFromEnv(); + +int64 MinVLogLevelFromEnv(); + +} // namespace internal + +// LogSink support adapted from //base/logging.h +// +// `LogSink` is an interface which can be extended to intercept and process +// all log messages. LogSink implementations must be thread-safe. A single +// instance will be called from whichever thread is performing a logging +// operation. +class TFLogEntry { + static absl::LogSeverity AsAbslLogSeverity(int severity) { + return static_cast(severity); + } + +public: + explicit TFLogEntry(int severity, absl::string_view log_line) + : severity_(AsAbslLogSeverity(severity)), log_line_(log_line) {} + + absl::LogSeverity log_severity() const { return severity_; } + std::string ToString() const { return std::string(log_line_); } + +private: + const absl::LogSeverity severity_; + const absl::string_view log_line_; +}; + +class TFLogSink { +public: + virtual ~TFLogSink() = default; + + // `Send` is called synchronously during the log statement. The logging + // module guarantees not to call `Send` concurrently on the same log sink. + // Implementations should be careful not to call`LOG` or `CHECK` or take + // any locks that might be held by the `LOG` caller, to avoid deadlock. + // + // `e` is guaranteed to remain valid until the subsequent call to + // `WaitTillSent` completes, so implementations may store a pointer to or + // copy of `e` (e.g. in a thread local variable) for use in `WaitTillSent`. + virtual void Send(const TFLogEntry &entry) = 0; + + // `WaitTillSent` blocks the calling thread (the thread that generated a log + // message) until the sink has finished processing the log message. + // `WaitTillSent` is called once per log message, following the call to + // `Send`. This may be useful when log messages are buffered or processed + // asynchronously by an expensive log sink. + // The default implementation returns immediately. Like `Send`, + // implementations should be careful not to call `LOG` or `CHECK or take any + // locks that might be held by the `LOG` caller, to avoid deadlock. + virtual void WaitTillSent() {} +}; + +// Add or remove a `LogSink` as a consumer of logging data. Thread-safe. +void TFAddLogSink(TFLogSink *sink); +void TFRemoveLogSink(TFLogSink *sink); + +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_LOGGING_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/macros.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/macros.h new file mode 100644 index 000000000..aecdb01ca --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/macros.h @@ -0,0 +1,152 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_MACROS_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_MACROS_H_ + +// Compiler attributes +#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG) +// Compiler supports GCC-style attributes +#define TF_ATTRIBUTE_NORETURN __attribute__((noreturn)) +#define TF_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) +#define TF_ATTRIBUTE_NOINLINE __attribute__((noinline)) +#define TF_ATTRIBUTE_UNUSED __attribute__((unused)) +#define TF_ATTRIBUTE_COLD __attribute__((cold)) +#define TF_ATTRIBUTE_WEAK __attribute__((weak)) +#define TF_PACKED __attribute__((packed)) +#define TF_MUST_USE_RESULT __attribute__((warn_unused_result)) +#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) \ + __attribute__((__format__(__printf__, string_index, first_to_check))) +#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) \ + __attribute__((__format__(__scanf__, string_index, first_to_check))) +#elif defined(_MSC_VER) +// Non-GCC equivalents +#define TF_ATTRIBUTE_NORETURN __declspec(noreturn) +#define TF_ATTRIBUTE_ALWAYS_INLINE __forceinline +#define TF_ATTRIBUTE_NOINLINE +#define TF_ATTRIBUTE_UNUSED +#define TF_ATTRIBUTE_COLD +#define TF_ATTRIBUTE_WEAK +#define TF_MUST_USE_RESULT +#define TF_PACKED +#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) +#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) +#else +// Non-GCC equivalents +#define TF_ATTRIBUTE_NORETURN +#define TF_ATTRIBUTE_ALWAYS_INLINE +#define TF_ATTRIBUTE_NOINLINE +#define TF_ATTRIBUTE_UNUSED +#define TF_ATTRIBUTE_COLD +#define TF_ATTRIBUTE_WEAK +#define TF_MUST_USE_RESULT +#define TF_PACKED +#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) +#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) +#endif + +// Control visibility outside .so +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_EXPORT __declspec(dllexport) +#else +#define TF_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 + +#ifdef __has_builtin +#define TF_HAS_BUILTIN(x) __has_builtin(x) +#else +#define TF_HAS_BUILTIN(x) 0 +#endif + +// C++11-style attributes (N2761) +#if defined(__has_cpp_attribute) +// Safely checks if an attribute is supported. Equivalent to +// ABSL_HAVE_CPP_ATTRIBUTE. +#define TF_HAS_CPP_ATTRIBUTE(n) __has_cpp_attribute(n) +#else +#define TF_HAS_CPP_ATTRIBUTE(n) 0 +#endif + +// [[clang::annotate("x")]] allows attaching custom strings (e.g. "x") to +// declarations (variables, functions, fields, etc.) for use by tools. They are +// represented in the Clang AST (as AnnotateAttr nodes) and in LLVM IR, but not +// in final output. +#if TF_HAS_CPP_ATTRIBUTE(clang::annotate) +#define TF_ATTRIBUTE_ANNOTATE(str) [[clang::annotate(str)]] +#else +#define TF_ATTRIBUTE_ANNOTATE(str) +#endif + +// Compilers can be told that a certain branch is not likely to be taken +// (for instance, a CHECK failure), and use that information in static +// analysis. Giving it this information can help it optimize for the +// common case in the absence of better information (ie. +// -fprofile-arcs). +#if TF_HAS_BUILTIN(__builtin_expect) || (defined(__GNUC__) && __GNUC__ >= 3) +#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0)) +#define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) +#else +#define TF_PREDICT_FALSE(x) (x) +#define TF_PREDICT_TRUE(x) (x) +#endif + +// A macro to disallow the copy constructor and operator= functions +// This is usually placed in the private: declarations for a class. +#define TF_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName &) = delete; \ + void operator=(const TypeName &) = delete + +// The TF_ARRAYSIZE(arr) macro returns the # of elements in an array arr. +// +// The expression TF_ARRAYSIZE(a) is a compile-time constant of type +// size_t. +#define TF_ARRAYSIZE(a) \ + ((sizeof(a) / sizeof(*(a))) / \ + static_cast(!(sizeof(a) % sizeof(*(a))))) + +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L || \ + (defined(_MSC_VER) && _MSC_VER >= 1900) +// Define this to 1 if the code is compiled in C++11 mode; leave it +// undefined otherwise. Do NOT define it to 0 -- that causes +// '#ifdef LANG_CXX11' to behave differently from '#if LANG_CXX11'. +#define LANG_CXX11 1 +#endif + +#if defined(__clang__) && defined(LANG_CXX11) && defined(__has_warning) +#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough") +#define TF_FALLTHROUGH_INTENDED [[clang::fallthrough]] // NOLINT +#endif +#endif + +#ifndef TF_FALLTHROUGH_INTENDED +#define TF_FALLTHROUGH_INTENDED \ + do { \ + } while (0) +#endif + +namespace demo_plugin { +namespace internal { +template void remove_unused_variable_compiler_warning(const T &){}; +} +} // namespace demo_plugin +#define TF_UNUSED_VARIABLE(x) \ + tensorflow::internal::remove_unused_variable_compiler_warning(x) + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_MACROS_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/node_def.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/node_def.proto new file mode 100644 index 000000000..68d5a8215 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/node_def.proto @@ -0,0 +1,88 @@ +syntax = "proto3"; + +package demo_plugin; + +import "tensorflow_plugin/src/utils/attr_value.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/node_def_go_proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_>./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. + // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= PARTIAL_SPEC + // + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) + // * "/job:worker/device:GPU:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // TODO(josh11b): Add some examples here showing best practices. + map attr = 5; + + message ExperimentalDebugInfo { + // Opaque string inserted into error messages created by the runtime. + // + // This is intended to store the list of names of the nodes from the + // original graph that this node was derived. For example if this node, say + // C, was result of a fusion of 2 nodes A and B, then 'original_node' would + // be {A, B}. This information can be used to map errors originating at the + // current node to some top level source code. + repeated string original_node_names = 1; + + // This is intended to store the list of names of the functions from the + // original graph that this node was derived. For example if this node, say + // C, was result of a fusion of node A in function FA and node B in function + // FB, then `original_funcs` would be {FA, FB}. If the node is in the top + // level graph, the `original_func` is empty. This information, with the + // `original_node_names` can be used to map errors originating at the + // current ndoe to some top level source code. + repeated string original_func_names = 2; + } + + // This stores debug information associated with the node. + ExperimentalDebugInfo experimental_debug_info = 6; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/numeric_types.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/numeric_types.h new file mode 100644 index 000000000..f7e1fedf5 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/numeric_types.h @@ -0,0 +1,94 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_NUMERIC_TYPES_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_NUMERIC_TYPES_H_ + +#include "tensorflow_plugin/src/utils/tstring.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include +// Disable clang-format to prevent 'FixedPoint' header from being included +// before 'Tensor' header on which it depends. +// clang-format off +#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" +// clang-format on + +namespace demo_plugin { + +// Single precision complex. +typedef std::complex complex64; +// Double precision complex. +typedef std::complex complex128; + +// We use Eigen's QInt implementations for our quantized int types. +typedef Eigen::QInt8 qint8; +typedef Eigen::QUInt8 quint8; +typedef Eigen::QInt32 qint32; +typedef Eigen::QInt16 qint16; +typedef Eigen::QUInt16 quint16; + +} // namespace demo_plugin + +static inline Eigen::bfloat16 FloatToBFloat16(float float_val) { +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + return *reinterpret_cast( + reinterpret_cast(&float_val)); +#else + return *reinterpret_cast( + &(reinterpret_cast(&float_val)[1])); +#endif +} + +namespace Eigen { +template <> +struct NumTraits + : GenericNumTraits { + enum { + RequireInitialization = 1, + ReadCost = HugeCost, + AddCost = HugeCost, + MulCost = HugeCost + }; + + static inline int digits10() { return 0; } + +private: + static inline demo_plugin::tstring epsilon(); + static inline demo_plugin::tstring dummy_precision(); + static inline demo_plugin::tstring lowest(); + static inline demo_plugin::tstring highest(); + static inline demo_plugin::tstring infinity(); + static inline demo_plugin::tstring quiet_NaN(); +}; + +} // namespace Eigen + +#if defined(_MSC_VER) && !defined(__clang__) +namespace std { +template <> struct hash { + std::size_t operator()(const Eigen::half &a) const { + return static_cast(a.x); + } +}; + +template <> struct hash { + std::size_t operator()(const Eigen::bfloat16 &a) const { + return hash()(static_cast(a)); + } +}; +} // namespace std +#endif // _MSC_VER + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_NUMERIC_TYPES_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/op_def.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/op_def.proto new file mode 100644 index 000000000..fdc9c4010 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/op_def.proto @@ -0,0 +1,174 @@ +syntax = "proto3"; + +package demo_plugin; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/op_def_go_proto"; +import "tensorflow_plugin/src/utils/attr_value.proto"; +import "tensorflow_plugin/src/utils/types.proto"; +import "tensorflow_plugin/src/utils/resource_handle.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +// LINT.IfChange +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9>_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // The handle data for resource inputs. + repeated ResourceHandleProto.DtypeAndShape handle_data = 7; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Named control outputs for this operation. Useful only for composite + // operations (i.e. functions) which want to name different control outputs. + repeated string control_output = 20; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + // TODO(josh11b): bool is_optional? + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + // TODO(josh11b): Implement that optimization. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // Ops are marked as stateful if their behavior depends on some state beyond + // their input tensors (e.g. variable reading op) or if they have + // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops + // must always produce the same output for the same input and have + // no side-effects. + // + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. + string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/op_performance_data.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/op_performance_data.proto new file mode 100644 index 000000000..091324187 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/op_performance_data.proto @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package demo_plugin; +option cc_enable_arenas = true; + +import "tensorflow_plugin/src/utils/tensor.proto"; +import "tensorflow_plugin/src/utils/tensor_shape.proto"; +import "tensorflow_plugin/src/utils/types.proto"; +import "tensorflow_plugin/src/utils/attr_value.proto"; +import "tensorflow_plugin/src/utils/device_properties.proto"; + +// Description of the session when an op is run. +message SessionInfo { + int64 intra_op_parallelism = 1; +} + +// Description of an operation as well as the parameters expected to impact its +// performance. +message OpInfo { + // The operation name. There may be custom parameters in attrs. + string op = 1; + + // Custom parameters impacting the behavior of the op. + map attr = 2; + + // Input data types, shapes and values if known. + message TensorProperties { + DataType dtype = 1; + TensorShapeProto shape = 2; + TensorProto value = 3; + }; + repeated TensorProperties inputs = 3; + + // Optional description of the op outputs + repeated TensorProperties outputs = 5; + + // Device on which the operation is run. + DeviceProperties device = 4; + + // Information about the session configs. + SessionInfo session_info = 6; +} + +message NormalDistribution { + double mu = 1; + double sigma = 2; +} + +message LogNormalDistribution { + double mu = 1; + double sigma = 2; +} + +// Performance data for tensorflow operations +message OpPerformance { + // The op + OpInfo op = 1; + + // Information about the session configs. + SessionInfo session_info = 12 [deprecated = true]; + + // The node name (optional). Makes it easier to associate the performance data + // with a specific graph node. + string node = 5; + + // Temporary memory used by this node (in bytes). + int64 temporary_memory_size = 2; + + // Time it takes to run the op (in nanoseconds). + int64 compute_cost = 3; + + // Analytical compute cost (in nanoseconds). + int64 compute_time = 6; + + // Analytical memory access cost (in nanoseconds). + int64 memory_time = 7; + + // Percentage of theoretical compute performance. + double compute_efficiency = 4; + + // Percentage of theoretical memory performance. + double memory_efficiency = 8; + + // Expected execution time, modeled using one of 2 possible distributions. + oneof execution_time { + NormalDistribution execution_time_normal = 10; + LogNormalDistribution execution_time_log_normal = 11; + }; + + // Memory usage data for a tensorflow operation. + message OpMemory { + // The output information may have memory usage and output shapes. + repeated int64 output_memory = 1; + + // Temp and persistent memory allocated by this node. + int64 temp_memory = 2; + int64 persistent_memory = 4; + + int64 device_temp_memory = 3 [deprecated = true]; + int64 device_persistent_memory = 5 [deprecated = true]; + } + OpMemory op_memory = 9; +} + +// A collection of OpPerformance data points. +message OpPerformanceList { + repeated OpPerformance op_performance = 1; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/platform.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/platform.h new file mode 100644 index 000000000..d2cbda874 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/platform.h @@ -0,0 +1,28 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_PLATFORM_DEFINE_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_PLATFORM_DEFINE_H_ + +#define PLATFORM_POSIX + +// Look for both gcc/clang and Visual Studio macros indicating we're compiling +// for an x86 device. +#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || \ + defined(_M_X64) +#define PLATFORM_IS_X86 +#endif + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_PLATFORM_DEFINE_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/prefetch.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/prefetch.h new file mode 100644 index 000000000..803d3853e --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/prefetch.h @@ -0,0 +1,58 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_PLATFORM_PREFETCH_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_PLATFORM_PREFETCH_H_ + +#include "tensorflow_plugin/src/utils/platform.h" + +namespace demo_plugin { +namespace port { + +// Prefetching support +// +// Defined behavior on some of the uarchs: +// PREFETCH_HINT_T0: +// prefetch to all levels of the hierarchy (except on p4: prefetch to L2) +// PREFETCH_HINT_NTA: +// p4: fetch to L2, but limit to 1 way (out of the 8 ways) +// core: skip L2, go directly to L1 +// k8 rev E and later: skip L2, can go to either of the 2-ways in L1 +enum PrefetchHint { + PREFETCH_HINT_T0 = 3, // More temporal locality + PREFETCH_HINT_T1 = 2, + PREFETCH_HINT_T2 = 1, // Less temporal locality + PREFETCH_HINT_NTA = 0 // No temporal locality +}; +template void prefetch(const void *x); + +// --------------------------------------------------------------------------- +// Inline implementation +// --------------------------------------------------------------------------- +template inline void prefetch(const void *x) { +// Check of COMPILER_GCC macro below is kept only for backward-compatibility +// reasons. COMPILER_GCC3 is the macro that actually enables prefetch. +#if defined(__llvm__) || defined(COMPILER_GCC) || defined(COMPILER_GCC3) + __builtin_prefetch(x, 0, hint); +#else +// You get no effect. Feel free to add more sections above. +#endif +} + +} // namespace port +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_PLATFORM_PREFETCH_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/protobuf.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/protobuf.h new file mode 100644 index 000000000..909ef3fcf --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/protobuf.h @@ -0,0 +1,122 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_PROTOBUF_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_PROTOBUF_H_ + +#include "tensorflow_plugin/src/utils/platform.h" +#include "tensorflow_plugin/src/utils/types.h" + +// Import whatever namespace protobuf comes from into the +// ::tensorflow::protobuf namespace. +// +// TensorFlow code should use the ::tensorflow::protobuf namespace to +// refer to all protobuf APIs. + +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor.pb.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/tokenizer.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "google/protobuf/map.h" +#include "google/protobuf/message.h" +#include "google/protobuf/repeated_field.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/util/json_util.h" +#include "google/protobuf/util/type_resolver_util.h" + +namespace demo_plugin { + +namespace protobuf = ::google::protobuf; +using protobuf_int64 = ::google::protobuf::int64; +using protobuf_uint64 = ::google::protobuf::uint64; +extern const char *kProtobufInt64Typename; +extern const char *kProtobufUint64Typename; + +// Parses a protocol buffer contained in a string in the binary wire format. +// Returns true on success. Note: Unlike protobuf's builtin ParseFromString, +// this function has no size restrictions on the total size of the encoded +// protocol buffer. +bool ParseProtoUnlimited(protobuf::MessageLite *proto, + const std::string &serialized); +bool ParseProtoUnlimited(protobuf::MessageLite *proto, const void *serialized, + size_t size); +inline bool ParseProtoUnlimited(protobuf::MessageLite *proto, + const tstring &serialized) { + return ParseProtoUnlimited(proto, serialized.data(), serialized.size()); +} + +// Returns the string value for the value of a string or bytes protobuf field. +inline const std::string &ProtobufStringToString(const std::string &s) { + return s; +} + +// Set to . Swapping is allowed, as does not need to be +// preserved. +inline void SetProtobufStringSwapAllowed(std::string *src, std::string *dest) { + *dest = std::move(*src); +} + +#if defined(TENSORFLOW_PROTOBUF_USES_CORD) +// These versions of ProtobufStringToString and SetProtobufString get used by +// tools/proto_text's generated code. They have the same name as the versions +// in core/platform/protobuf.h, so the generation code doesn't need to determine +// if the type is Cord or string at generation time. +inline std::string ProtobufStringToString(const absl::Cord &s) { + return std::string(s); +} +inline void SetProtobufStringSwapAllowed(std::string *src, absl::Cord *dest) { + dest->CopyFrom(*src); +} +#endif // defined(TENSORFLOW_PROTOBUF_USES_CORD) + +inline bool SerializeToTString(const protobuf::MessageLite &proto, + tstring *output) { + size_t size = proto.ByteSizeLong(); + output->resize_uninitialized(size); + return proto.SerializeWithCachedSizesToArray( + reinterpret_cast(output->data())); +} + +inline bool ParseFromTString(const tstring &input, + protobuf::MessageLite *proto) { + return proto->ParseFromArray(input.data(), static_cast(input.size())); +} + +// Analogue to StringOutputStream for tstring. +class TStringOutputStream : public protobuf::io::ZeroCopyOutputStream { +public: + explicit TStringOutputStream(tstring *target); + ~TStringOutputStream() override = default; + + TStringOutputStream(const TStringOutputStream &) = delete; + void operator=(const TStringOutputStream &) = delete; + + bool Next(void **data, int *size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + +private: + static constexpr int kMinimumSize = 16; + + tstring *target_; +}; + +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_PROTOBUF_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/resource_handle.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/resource_handle.proto new file mode 100644 index 000000000..1d0effdc7 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/resource_handle.proto @@ -0,0 +1,45 @@ +syntax = "proto3"; + +package demo_plugin; + +import "tensorflow_plugin/src/utils/tensor_shape.proto"; +import "tensorflow_plugin/src/utils/types.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/resource_handle_go_proto"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; + + // Protocol buffer representing a pair of (data type, tensor shape). + message DtypeAndShape { + DataType dtype = 1; + TensorShapeProto shape = 2; + } + + // Data types and shapes for the underlying resource. + repeated DtypeAndShape dtypes_and_shapes = 6; + + reserved 7; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/stringpiece.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/stringpiece.h new file mode 100644 index 000000000..09b2afcd8 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/stringpiece.h @@ -0,0 +1,37 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_STRINGPIECE_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_STRINGPIECE_H_ + +#include "absl/strings/string_view.h" // IWYU pragma: export + +// StringPiece is a simple structure containing a pointer into some external +// storage and a size. The user of a StringPiece must ensure that the slice +// is not used after the corresponding external storage has been +// deallocated. +// +// Multiple threads can invoke const methods on a StringPiece without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same StringPiece must use +// external synchronization. + +namespace demo_plugin { + +using StringPiece = absl::string_view; + +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_STRINGPIECE_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/tensor.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/tensor.proto new file mode 100644 index 000000000..e6aa8dd74 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/tensor.proto @@ -0,0 +1,96 @@ +syntax = "proto3"; + +package demo_plugin; + +import "tensorflow_plugin/src/utils/resource_handle.proto"; +import "tensorflow_plugin/src/utils/tensor_shape.proto"; +import "tensorflow_plugin/src/utils/types.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_go_proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. TODO(touts): sort out the 0-rank issues. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; +} + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/tensor_shape.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/tensor_shape.proto new file mode 100644 index 000000000..7f75c5539 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/tensor_shape.proto @@ -0,0 +1,46 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto"; + +package demo_plugin; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/tstring.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/tstring.h new file mode 100644 index 000000000..11937c2f8 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/tstring.h @@ -0,0 +1,598 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_TSTRING_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_TSTRING_H_ + +#include +#include +#include + +#include "tensorflow_plugin/src/utils/ctstring.h" + +// TODO(intel-tf): This include is temporary, and will be superfluous once +// absl::string_view is aliased to std::string_view. +#include "absl/strings/string_view.h" +namespace absl { +#ifdef ABSL_NAMESPACE_BEGIN +ABSL_NAMESPACE_BEGIN +#endif // ABSL_NAMESPACE_BEGIN +class AlphaNum; +#ifdef ABSL_NAMESPACE_END +ABSL_NAMESPACE_END +#endif // ABSL_NAMESPACE_END +} // namespace absl + +namespace demo_plugin { + +// tensorflow::tstring is the scalar type for DT_STRING tensors. +// +// tstrings are meant to be used when interfacing with string tensors, and +// should not be considered as a general replacement for std::string in +// tensorflow. The primary purpose of tstring is to provide a unified and +// stable ABI for string tensors across TF Core/C-API/Lite/etc---mitigating +// unnecessary conversions across language boundaries, and allowing for compiler +// agnostic interoperability across dynamically loaded modules. +// +// In addition to ABI stability, tstrings features two string subtypes, VIEW and +// OFFSET. +// +// VIEW tstrings are views into unowned character buffers; they can be used to +// pass around existing character strings without incurring a per object heap +// allocation. Note that, like std::string_view, it is the user's +// responsibility to ensure that the underlying buffer of a VIEW tstring exceeds +// the lifetime of the associated tstring object. +// +// TODO(dero): Methods for creating OFFSET tensors are not currently +// implemented. +// +// OFFSET tstrings are platform independent offset defined strings which can be +// directly mmaped or copied into a tensor buffer without the need for +// deserialization or processing. For security reasons, it is imperative that +// OFFSET based string tensors are validated before use, or are from a trusted +// source. +// +// Underlying VIEW and OFFSET buffers are considered immutable, so l-value +// assignment, mutation, or non-const access to data() of tstrings will result +// in the conversion to an owned SMALL/LARGE type. +// +// The interface for tstring largely overlaps with std::string. Except where +// noted, expect equivalent semantics with synonymous std::string methods. +class tstring { + TF_TString tstr_; + +public: + enum Type { + // See cstring.h + SMALL = TF_TSTR_SMALL, + LARGE = TF_TSTR_LARGE, + OFFSET = TF_TSTR_OFFSET, + VIEW = TF_TSTR_VIEW, + }; + + // Assignment to a tstring object with a tstring::view type will create a VIEW + // type tstring. + class view { + const char *data_; + size_t size_; + + public: + explicit view(const char *data, size_t size) : data_(data), size_(size) {} + explicit view(const char *data) : data_(data), size_(::strlen(data)) {} + + const char *data() const { return data_; } + + size_t size() const { return size_; } + + view() = delete; + view(const view &) = delete; + view &operator=(const view &) = delete; + }; + + typedef const char *const_iterator; + + // Ctor + tstring(); + tstring(const std::string &str); // NOLINT TODO(b/147740521): Make explicit. + tstring(const char *str, size_t len); + tstring(const char *str); // NOLINT TODO(b/147740521): Make explicit. + tstring(size_t n, char c); + explicit tstring(const absl::string_view str); +#ifdef PLATFORM_GOOGLE + explicit tstring(const absl::Cord &cord); +#endif // PLATFORM_GOOGLE + + // Copy + tstring(const tstring &str); + + // Move + tstring(tstring &&str) noexcept; + + // Dtor + ~tstring(); + + // Copy Assignment + tstring &operator=(const tstring &str); + tstring &operator=(const std::string &str); + tstring &operator=(const char *str); + tstring &operator=(char ch); + tstring &operator=(const absl::string_view str); +#ifdef PLATFORM_GOOGLE + tstring &operator=(const absl::Cord &cord); +#endif // PLATFORM_GOOGLE + + // View Assignment + tstring &operator=(const view &tsv); + + // Move Assignment + tstring &operator=(tstring &&str); + + // Comparison + int compare(const char *str, size_t len) const; + bool operator<(const tstring &o) const; + bool operator>(const tstring &o) const; + bool operator==(const char *str) const; + bool operator==(const tstring &o) const; + bool operator!=(const char *str) const; + bool operator!=(const tstring &o) const; + + // Conversion Operators + // TODO(b/147740521): Make explicit. + operator std::string() const; // NOLINT + // TODO(b/147740521): Make explicit. + operator absl::string_view() const; // NOLINT +#ifdef PLATFORM_GOOGLE + template ::value, + T>::type * = nullptr> + operator T() const; // NOLINT TODO(b/147740521): Remove. +#endif // PLATFORM_GOOGLE + + // Attributes + size_t size() const; + size_t length() const; + size_t capacity() const; + bool empty() const; + Type type() const; + + // Allocation + void resize(size_t new_size, char c = 0); + // Similar to resize, but will leave the newly grown region uninitialized. + void resize_uninitialized(size_t new_size); + void clear() noexcept; + void reserve(size_t n); + + // Iterators + const_iterator begin() const; + const_iterator end() const; + + // Const Element Access + const char *c_str() const; + const char *data() const; + const char &operator[](size_t i) const; + const char &back() const; + + // Mutable Element Access + // NOTE: For VIEW/OFFSET types, calling these methods will result in the + // conversion to a SMALL or heap allocated LARGE type. As a result, + // previously obtained pointers, references, or iterators to the underlying + // buffer will point to the original VIEW/OFFSET and not the new allocation. + char *mdata(); + char *data(); // DEPRECATED: Use mdata(). + char &operator[](size_t i); + + // Assignment + tstring &assign(const char *str, size_t len); + tstring &assign(const char *str); + + // View Assignment + tstring &assign_as_view(const tstring &str); + tstring &assign_as_view(const std::string &str); + tstring &assign_as_view(const absl::string_view str); + tstring &assign_as_view(const char *str, size_t len); + tstring &assign_as_view(const char *str); + + // Modifiers + // NOTE: Invalid input will result in undefined behavior. + tstring &append(const tstring &str); + tstring &append(const char *str, size_t len); + tstring &append(const char *str); + tstring &append(size_t n, char c); + + tstring &erase(size_t pos, size_t len); + + tstring &insert(size_t pos, const tstring &str, size_t subpos, size_t sublen); + tstring &insert(size_t pos, size_t n, char c); + void swap(tstring &str); + void push_back(char ch); + + // Friends + friend bool operator==(const char *a, const tstring &b); + friend bool operator==(const std::string &a, const tstring &b); + friend tstring operator+(const tstring &a, const tstring &b); + friend std::ostream &operator<<(std::ostream &o, const tstring &str); + friend std::hash; +}; + +// Non-member function overloads + +bool operator==(const char *a, const tstring &b); +bool operator==(const std::string &a, const tstring &b); +tstring operator+(const tstring &a, const tstring &b); +std::ostream &operator<<(std::ostream &o, const tstring &str); + +// Implementations + +// Ctor + +inline tstring::tstring() { TF_TString_Init(&tstr_); } + +inline tstring::tstring(const char *str, size_t len) { + TF_TString_Init(&tstr_); + TF_TString_Copy(&tstr_, str, len); +} + +inline tstring::tstring(const char *str) : tstring(str, ::strlen(str)) {} + +inline tstring::tstring(size_t n, char c) { + TF_TString_Init(&tstr_); + TF_TString_Resize(&tstr_, n, c); +} + +inline tstring::tstring(const std::string &str) + : tstring(str.data(), str.size()) {} + +inline tstring::tstring(const absl::string_view str) + : tstring(str.data(), str.size()) {} + +#ifdef PLATFORM_GOOGLE +inline tstring::tstring(const absl::Cord &cord) { + TF_TString_Init(&tstr_); + TF_TString_ResizeUninitialized(&tstr_, cord.size()); + + cord.CopyToArray(data()); +} +#endif // PLATFORM_GOOGLE + +// Copy + +inline tstring::tstring(const tstring &str) { + TF_TString_Init(&tstr_); + TF_TString_Assign(&tstr_, &str.tstr_); +} + +// Move + +inline tstring::tstring(tstring &&str) noexcept { + TF_TString_Init(&tstr_); + TF_TString_Move(&tstr_, &str.tstr_); +} + +// Dtor + +inline tstring::~tstring() { TF_TString_Dealloc(&tstr_); } + +// Copy Assignment + +inline tstring &tstring::operator=(const tstring &str) { + TF_TString_Assign(&tstr_, &str.tstr_); + + return *this; +} + +inline tstring &tstring::operator=(const std::string &str) { + TF_TString_Copy(&tstr_, str.data(), str.size()); + return *this; +} + +inline tstring &tstring::operator=(const char *str) { + TF_TString_Copy(&tstr_, str, ::strlen(str)); + + return *this; +} + +inline tstring &tstring::operator=(char c) { + resize_uninitialized(1); + (*this)[0] = c; + + return *this; +} + +inline tstring &tstring::operator=(const absl::string_view str) { + TF_TString_Copy(&tstr_, str.data(), str.size()); + + return *this; +} + +#ifdef PLATFORM_GOOGLE +inline tstring &tstring::operator=(const absl::Cord &cord) { + TF_TString_ResizeUninitialized(&tstr_, cord.size()); + + cord.CopyToArray(data()); + + return *this; +} +#endif // PLATFORM_GOOGLE + +// View Assignment + +inline tstring &tstring::operator=(const tstring::view &tsv) { + assign_as_view(tsv.data(), tsv.size()); + + return *this; +} + +// Move Assignment + +inline tstring &tstring::operator=(tstring &&str) { + TF_TString_Move(&tstr_, &str.tstr_); + + return *this; +} + +// Comparison + +inline int tstring::compare(const char *str, size_t len) const { + int ret = ::memcmp(data(), str, std::min(len, size())); + + if (ret < 0) + return -1; + if (ret > 0) + return +1; + + if (size() < len) + return -1; + if (size() > len) + return +1; + + return 0; +} + +inline bool tstring::operator<(const tstring &o) const { + return compare(o.data(), o.size()) < 0; +} + +inline bool tstring::operator>(const tstring &o) const { + return compare(o.data(), o.size()) > 0; +} + +inline bool tstring::operator==(const char *str) const { + return ::strlen(str) == size() && ::memcmp(data(), str, size()) == 0; +} + +inline bool tstring::operator==(const tstring &o) const { + return o.size() == size() && ::memcmp(data(), o.data(), size()) == 0; +} + +inline bool tstring::operator!=(const char *str) const { + return !(*this == str); +} + +inline bool tstring::operator!=(const tstring &o) const { + return !(*this == o); +} + +// Conversion Operators + +inline tstring::operator std::string() const { + return std::string(data(), size()); +} + +inline tstring::operator absl::string_view() const { + return absl::string_view(data(), size()); +} + +#ifdef PLATFORM_GOOGLE +template ::value, T>::type *> +inline tstring::operator T() const { + return T(absl::string_view(*this)); +} +#endif // PLATFORM_GOOGLE + +// Attributes + +inline size_t tstring::size() const { return TF_TString_GetSize(&tstr_); } + +inline size_t tstring::length() const { return size(); } + +inline size_t tstring::capacity() const { + return TF_TString_GetCapacity(&tstr_); +} + +inline bool tstring::empty() const { return size() == 0; } + +inline tstring::Type tstring::type() const { + return static_cast(TF_TString_GetType(&tstr_)); +} + +// Allocation + +inline void tstring::resize(size_t new_size, char c) { + TF_TString_Resize(&tstr_, new_size, c); +} + +inline void tstring::resize_uninitialized(size_t new_size) { + TF_TString_ResizeUninitialized(&tstr_, new_size); +} + +inline void tstring::clear() noexcept { + TF_TString_ResizeUninitialized(&tstr_, 0); +} + +inline void tstring::reserve(size_t n) { TF_TString_Reserve(&tstr_, n); } + +// Iterators + +inline tstring::const_iterator tstring::begin() const { return &(*this)[0]; } +inline tstring::const_iterator tstring::end() const { return &(*this)[size()]; } + +// Element Access + +inline const char *tstring::c_str() const { return data(); } + +inline const char *tstring::data() const { + return TF_TString_GetDataPointer(&tstr_); +} + +inline const char &tstring::operator[](size_t i) const { return data()[i]; } + +inline const char &tstring::back() const { return (*this)[size() - 1]; } + +inline char *tstring::mdata() { + return TF_TString_GetMutableDataPointer(&tstr_); +} + +inline char *tstring::data() { + // Deprecated + return mdata(); +} + +inline char &tstring::operator[](size_t i) { return mdata()[i]; } + +// Assignment + +inline tstring &tstring::assign(const char *str, size_t len) { + TF_TString_Copy(&tstr_, str, len); + + return *this; +} + +inline tstring &tstring::assign(const char *str) { + assign(str, ::strlen(str)); + + return *this; +} + +// View Assignment + +inline tstring &tstring::assign_as_view(const tstring &str) { + assign_as_view(str.data(), str.size()); + + return *this; +} + +inline tstring &tstring::assign_as_view(const std::string &str) { + assign_as_view(str.data(), str.size()); + + return *this; +} + +inline tstring &tstring::assign_as_view(const absl::string_view str) { + assign_as_view(str.data(), str.size()); + + return *this; +} + +inline tstring &tstring::assign_as_view(const char *str, size_t len) { + TF_TString_AssignView(&tstr_, str, len); + + return *this; +} + +inline tstring &tstring::assign_as_view(const char *str) { + assign_as_view(str, ::strlen(str)); + + return *this; +} + +// Modifiers + +inline tstring &tstring::append(const tstring &str) { + TF_TString_Append(&tstr_, &str.tstr_); + + return *this; +} + +inline tstring &tstring::append(const char *str, size_t len) { + TF_TString_AppendN(&tstr_, str, len); + + return *this; +} + +inline tstring &tstring::append(const char *str) { + append(str, ::strlen(str)); + + return *this; +} + +inline tstring &tstring::append(size_t n, char c) { + resize(size() + n, c); + + return *this; +} + +inline tstring &tstring::erase(size_t pos, size_t len) { + memmove(mdata() + pos, data() + pos + len, size() - len - pos); + + resize(size() - len); + + return *this; +} + +inline tstring &tstring::insert(size_t pos, const tstring &str, size_t subpos, + size_t sublen) { + size_t orig_size = size(); + TF_TString_ResizeUninitialized(&tstr_, orig_size + sublen); + + memmove(mdata() + pos + sublen, data() + pos, orig_size - pos); + memmove(mdata() + pos, str.data() + subpos, sublen); + + return *this; +} + +inline tstring &tstring::insert(size_t pos, size_t n, char c) { + size_t size_ = size(); + TF_TString_ResizeUninitialized(&tstr_, size_ + n); + + memmove(mdata() + pos + n, data() + pos, size_ - pos); + memset(mdata() + pos, c, n); + + return *this; +} + +inline void tstring::swap(tstring &str) { + // TODO(dero): Invalid for OFFSET (unimplemented). + std::swap(tstr_, str.tstr_); +} + +inline void tstring::push_back(char ch) { append(1, ch); } + +// Friends + +inline bool operator==(const char *a, const tstring &b) { + return ::strlen(a) == b.size() && ::memcmp(a, b.data(), b.size()) == 0; +} + +inline bool operator==(const std::string &a, const tstring &b) { + return a.size() == b.size() && ::memcmp(a.data(), b.data(), b.size()) == 0; +} + +inline tstring operator+(const tstring &a, const tstring &b) { + tstring r; + r.reserve(a.size() + b.size()); + r.append(a); + r.append(b); + + return r; +} + +inline std::ostream &operator<<(std::ostream &o, const tstring &str) { + return o.write(str.data(), str.size()); +} + +} // namespace demo_plugin + +#endif // TENSORFLOW_PLUGIN_SRC_UTILS_TSTRING_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/types.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/types.h new file mode 100644 index 000000000..619ab1ea6 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/types.h @@ -0,0 +1,469 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLUGIN_SRC_UTILS_TYPES_H_ +#define TENSORFLOW_PLUGIN_SRC_UTILS_TYPES_H_ + +#include +#include +#include + +//#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +// Disable clang-format to prevent 'FixedPoint' header from being included +// before 'Tensor' header on which it depends. +// clang-format off +//#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" +// clang-format on +#include "tensorflow_plugin/src/utils/types.pb.h" +//#include "tensorflow_plugin/src/utils/numeric_types.h" +#include "tensorflow_plugin/src/utils/integral_types.h" +#include "tensorflow_plugin/src/utils/gtl/array_slice.h" +#include "tensorflow_plugin/src/utils/gtl/inlined_vector.h" +#include "tensorflow_plugin/src/utils/logging.h" +#include "tensorflow_plugin/src/utils/stringpiece.h" +#include "tensorflow_plugin/src/utils/tstring.h" + +namespace demo_plugin { + +// class Variant; +// +// MemoryType is used to describe whether input or output Tensors of +// an OpKernel should reside in "Host memory" (e.g., CPU memory) or +// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU +// devices). +enum MemoryType { + DEVICE_MEMORY = 0, + HOST_MEMORY = 1, +}; + +// A DeviceType is just a string, but we wrap it up in a class to give +// some type checking as we're passing these around +class DeviceType { +public: + DeviceType(const char *type) // NOLINT(runtime/explicit) + : type_(type) {} + + explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {} + + const char *type() const { return type_.c_str(); } + const std::string &type_string() const { return type_; } + + bool operator<(const DeviceType &other) const; + bool operator==(const DeviceType &other) const; + bool operator!=(const DeviceType &other) const { return !(*this == other); } + +private: + std::string type_; +}; +std::ostream &operator<<(std::ostream &os, const DeviceType &d); + +// Convenient constants that can be passed to a DeviceType constructor +TF_EXPORT extern const char *const DEVICE_DEFAULT; // "DEFAULT" +TF_EXPORT extern const char *const DEVICE_CPU; // "CPU" +TF_EXPORT extern const char *const DEVICE_GPU; // "GPU" +TF_EXPORT extern const char *const DEVICE_XPU; +TF_EXPORT extern const char *const DEVICE_AUTO; // "AUTO" + +template struct DeviceName {}; +/* +template <> struct DeviceName { + static const std::string value; +}; + +template <> struct DeviceName { + static const std::string value; +}; +*/ +typedef gtl::InlinedVector MemoryTypeVector; +typedef gtl::ArraySlice MemoryTypeSlice; + +typedef gtl::InlinedVector DataTypeVector; +typedef gtl::ArraySlice DataTypeSlice; + +typedef gtl::InlinedVector DeviceTypeVector; +typedef gtl::InlinedVector, 4> + PrioritizedDeviceTypeVector; + +//// Convert the enums to strings for errors: +std::string DataTypeString(DataType dtype); +std::string DeviceTypeString(const DeviceType &device_type); +// std::string DataTypeSliceString(const DataTypeSlice dtypes); +// inline std::string DataTypeVectorString(const DataTypeVector& dtypes) { +// return DataTypeSliceString(dtypes); +//} + +// DataTypeSet represents a set of DataType values as a simple and efficient +// bit mask. Note that DataTypeSet cannot represent all DataType values; it +// cannot represent any of the DT_*_REF values. +class DataTypeSet { +private: + const uint32 mask_; + + static constexpr uint32 kNumBits = 32; + +public: + constexpr DataTypeSet(const DataTypeSet &other) : mask_(other.mask_) {} + explicit constexpr DataTypeSet(uint32 mask) : mask_(mask) {} + + constexpr bool Contains(DataType dt) const { + return (static_cast(dt) < kNumBits) && + ((mask_ >> static_cast(dt)) & 1u) != 0u; + } + + class Iterator { + const DataTypeSet &set_; + uint32 pos_; + + public: + Iterator(const DataTypeSet &set, uint32 pos) : set_(set), pos_(pos) { + DCHECK_LE(pos, kNumBits); + } + DataType operator*() const { return static_cast(pos_); } + Iterator &operator++() { + ++pos_; + DCHECK_LE(pos_, kNumBits); + if (pos_ < kNumBits) { + uint32 remaining_mask = set_.mask_ >> pos_; + if (remaining_mask != 0u) { + pos_ += ctz_uint32(remaining_mask); + } + } + DCHECK_LE(pos_, kNumBits); + return *this; + } + bool operator==(const Iterator &other) const { return pos_ == other.pos_; } + bool operator!=(const Iterator &other) const { return !(*this == other); } + size_t operator-(const Iterator &other) const { + return this->pos_ - other.pos_; + } + }; + + static uint32 ctz_uint32(uint32 x) { + DCHECK_NE(x, 0u); +#ifdef __GNUC__ + return __builtin_ctz(x); +#else + uint32 n = 0u; + while ((x & 1u) == 0u) { + x >>= 1; + ++n; + } + return n; +#endif + } + + static uint32 clz_uint32(uint32 x) { + DCHECK_NE(x, 0u); +#ifdef __GNUC__ + return __builtin_clz(x); +#else + uint32 n = 0u; + while ((x >> (kNumBits - 1u)) == 0u) { + x <<= 1; + ++n; + } + return n; +#endif + } + + Iterator begin() const { + // The begin position is the index of the first bit set to 1 in the entire + // bit mask. If there are no bits set to 1, then the index is 0. + if (mask_ != 0) { + return Iterator(*this, ctz_uint32(mask_)); + } + // The set is empty. + return Iterator(*this, 0); + } + + Iterator end() const { + // The end position is the index of the highest bit that is set, plus 1. + // If there are no bits set to 1, then the index is 0. + if (mask_ != 0) { + return Iterator(*this, kNumBits - clz_uint32(mask_)); + } + // The set is empty. + return Iterator(*this, 0); + } + + size_t size() const { +#if defined(__GNUC__) + return __builtin_popcount(mask_); +#else + size_t n = 0; + uint32 x = mask_; + while (x > 0) { + n += x & 1u; + x >>= 1; + } + return n; +#endif + } + + constexpr DataTypeSet operator|(const DataTypeSet &other) const { + return DataTypeSet(mask_ | other.mask_); + } +}; + +// If "sp" names a valid type, store it in "*dt" and return true. Otherwise, +// return false. +bool DataTypeFromString(StringPiece sp, DataType *dt); + +constexpr inline DataTypeSet ToSet(DataType dt) { + return DataTypeSet(1u << static_cast(dt)); +} + +// DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc. +enum { kDataTypeRefOffset = 100 }; +inline bool IsRefType(DataType dtype) { + return dtype > static_cast(kDataTypeRefOffset); +} +inline DataType MakeRefType(DataType dtype) { + DCHECK(!IsRefType(dtype)); + return static_cast(dtype + kDataTypeRefOffset); +} +inline DataType RemoveRefType(DataType dtype) { + DCHECK(IsRefType(dtype)); + return static_cast(dtype - kDataTypeRefOffset); +} +inline DataType BaseType(DataType dtype) { + return IsRefType(dtype) ? RemoveRefType(dtype) : dtype; +} + +// Returns true if the actual type is the same as or ref of the expected type. +inline bool TypesCompatible(DataType expected, DataType actual) { + return expected == actual || expected == BaseType(actual); +} + +//// Does not include _ref types. +// constexpr DataTypeSet kAllTypes = +// ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT8) | +// ToSet(DT_INT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) | ToSet(DT_STRING) | +// ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) | +// ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | +// ToSet(DT_QUINT16) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_RESOURCE) +// | ToSet(DT_VARIANT) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | +// ToSet(DT_BFLOAT16); +// inline const DataTypeSet& AllTypes() { return kAllTypes; } +// +//// Types that support '<' and '>'. +// constexpr DataTypeSet kRealNumberTypes = +// ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) | +// ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_INT8) | ToSet(DT_UINT16) | +// ToSet(DT_HALF) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16); +// inline const DataTypeSet RealNumberTypes() { return kRealNumberTypes; } +// +//// Return the list of all numeric types. +//// Includes complex and quantized types. +//// NOTE: On Android, we only include the float and int32 types for now. +// const DataTypeSet kNumberTypes = +// ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT64) | ToSet(DT_INT32) | +// ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | +// ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_QINT8) | +// ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_UINT32) | +// ToSet(DT_UINT64) | ToSet(DT_BFLOAT16); +// inline const DataTypeSet& NumberTypes() { return kNumberTypes; } +// +// constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | +// ToSet(DT_QINT16) | ToSet(DT_QUINT16) | +// ToSet(DT_QINT32); +// inline const DataTypeSet& QuantizedTypes() { return kQuantizedTypes; } +// +//// Types that support '<' and '>', including quantized types. +// const DataTypeSet kRealAndQuantizedTypes = +// ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) | +// ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | +// ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) +// | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_BFLOAT16); +// inline const DataTypeSet& RealAndQuantizedTypes() { +// return kRealAndQuantizedTypes; +//} +// + +// Validates type T for whether it is a supported DataType. +template struct IsValidDataType; + +// DataTypeToEnum::v() and DataTypeToEnum::value are the DataType +// constants for T, e.g. DataTypeToEnum::v() is DT_FLOAT. +template struct DataTypeToEnum { + static_assert(IsValidDataType::value, "Specified Data Type not supported"); +}; // Specializations below + +// EnumToDataType::Type is the type for DataType constant VALUE, e.g. +// EnumToDataType::Type is float. +template struct EnumToDataType {}; // Specializations below + +// Template specialization for both DataTypeToEnum and EnumToDataType. +#define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \ + template <> struct DataTypeToEnum { \ + static DataType v() { return ENUM; } \ + static DataType ref() { return MakeRefType(ENUM); } \ + static constexpr DataType value = ENUM; \ + }; \ + template <> struct IsValidDataType { \ + static constexpr bool value = true; \ + }; \ + template <> struct EnumToDataType { typedef TYPE Type; } + +MATCH_TYPE_AND_ENUM(float, DT_FLOAT); +MATCH_TYPE_AND_ENUM(double, DT_DOUBLE); +MATCH_TYPE_AND_ENUM(int32, DT_INT32); +MATCH_TYPE_AND_ENUM(uint32, DT_UINT32); +MATCH_TYPE_AND_ENUM(uint16, DT_UINT16); +MATCH_TYPE_AND_ENUM(uint8, DT_UINT8); +MATCH_TYPE_AND_ENUM(int16, DT_INT16); +MATCH_TYPE_AND_ENUM(int8, DT_INT8); +MATCH_TYPE_AND_ENUM(tstring, DT_STRING); +//MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64); +//MATCH_TYPE_AND_ENUM(complex128, DT_COMPLEX128); +MATCH_TYPE_AND_ENUM(bool, DT_BOOL); +//MATCH_TYPE_AND_ENUM(qint8, DT_QINT8); +//MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8); +//MATCH_TYPE_AND_ENUM(qint16, DT_QINT16); +//MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16); +//MATCH_TYPE_AND_ENUM(qint32, DT_QINT32); +//MATCH_TYPE_AND_ENUM(Eigen::bfloat16, DT_BFLOAT16); +//MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF); +// MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE); +// MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT); + +template <> struct DataTypeToEnum { + static DataType v() { return value; } + static DataType ref() { return MakeRefType(value); } + static constexpr DataType value = sizeof(long) == 4 ? DT_INT32 : DT_INT64; +}; +template <> struct IsValidDataType { + static constexpr bool value = true; +}; +template <> struct EnumToDataType { + typedef demo_plugin::int64 Type; +}; + +template <> struct DataTypeToEnum { + static DataType v() { return value; } + static DataType ref() { return MakeRefType(value); } + static constexpr DataType value = + sizeof(unsigned long) == 4 ? DT_UINT32 : DT_UINT64; +}; +template <> struct IsValidDataType { + static constexpr bool value = true; +}; +template <> struct EnumToDataType { + typedef demo_plugin::uint64 Type; +}; + +template <> struct DataTypeToEnum { + static DataType v() { return DT_INT64; } + static DataType ref() { return MakeRefType(DT_INT64); } + static constexpr DataType value = DT_INT64; +}; +template <> struct IsValidDataType { + static constexpr bool value = true; +}; + +template <> struct DataTypeToEnum { + static DataType v() { return DT_UINT64; } + static DataType ref() { return MakeRefType(DT_UINT64); } + static constexpr DataType value = DT_UINT64; +}; +template <> struct IsValidDataType { + static constexpr bool value = true; +}; + +#undef MATCH_TYPE_AND_ENUM + +// All types not specialized are marked invalid. +template struct IsValidDataType { + static constexpr bool value = false; +}; + +// Extra validity checking; not part of public API. +static_assert(IsValidDataType::value, "Incorrect impl for int64"); +static_assert(IsValidDataType::value, "Incorrect impl for int32"); + +//// TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying +//// is_simple in tensor.cc (and possible choose a more general name?) +// constexpr DataTypeSet kDataTypesCanUseMemcpy = +// ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT32) | +// ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | +// ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) | +// ToSet(DT_UINT64) | ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | +// ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) | +// ToSet(DT_BFLOAT16) | ToSet(DT_HALF); +// inline bool DataTypeCanUseMemcpy(DataType dt) { +// return kDataTypesCanUseMemcpy.Contains(dt); +//} +// +//// Returns true iff 'dt' is a real, non-quantized floating point type. +// constexpr DataTypeSet kDataTypeIsFloating = +// ToSet(DT_HALF) | ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE); +// inline bool DataTypeIsFloating(DataType dt) { +// return kDataTypeIsFloating.Contains(dt); +//} +// +//// Returns true iff 'dt' is a complex type. +// constexpr DataTypeSet kDataTypeIsComplex = +// ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128); +// inline bool DataTypeIsComplex(DataType dt) { +// return kDataTypeIsComplex.Contains(dt); +//} +// +// inline bool DataTypeIsQuantized(DataType dt) { +// return kQuantizedTypes.Contains(dt); +//} + +// Is the dtype nonquantized integral? +constexpr DataTypeSet kDataTypeIsInteger = + ToSet(DT_INT8) | ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_UINT16) | + ToSet(DT_INT32) | ToSet(DT_UINT32) | ToSet(DT_INT64) | ToSet(DT_UINT64); +inline bool DataTypeIsInteger(DataType dt) { + return kDataTypeIsInteger.Contains(dt); +} + +// Is the dtype a signed integral type? +constexpr DataTypeSet kDataTypeIsSigned = + ToSet(DT_INT8) | ToSet(DT_INT16) | ToSet(DT_INT32) | ToSet(DT_INT64); +inline bool DataTypeIsSigned(DataType dt) { + return kDataTypeIsSigned.Contains(dt); +} +// +//// Is the dtype an unsigned integral type? +// constexpr DataTypeSet kDataTypeIsUnsigned = +// ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT32) | ToSet(DT_UINT64); +// inline bool DataTypeIsUnsigned(DataType dt) { +// return kDataTypeIsUnsigned.Contains(dt); +//} +// +//// Returns a 0 on failure +int DataTypeSize(DataType dt); + +// Returns HOST_MEMORY if `dtype` is always on host or is a DT_INT32, +// DEVICE_MEMORY otherwise. +MemoryType MTypeFromDType(const DataType dtype); + +//// Returns HOST_MEMORY if `dtype` is always on host, DEVICE_MEMORY otherwise. +//// The reason we have MTypeFromDType() and MTypeFromDTypeIntsOnDevice(): for +//// GPUs, we would like to keep int operations on host for performance +/// concerns. / But for TPUs (and other devices), int operations are placed on +/// device. +// MemoryType MTypeFromDTypeIntsOnDevice(const DataType dtype); +// +//// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE. +//// For DT_RESOURCE, the handle always sits on host (even if the underlying +//// object has device-allocated resources). +// bool DataTypeAlwaysOnHost(DataType dt); + +} // namespace demo_plugin +#endif // TAENSORFLOW_PLUGIN_SRC_UTILS_TYPES_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/types.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/types.proto new file mode 100644 index 000000000..6a600d223 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/types.proto @@ -0,0 +1,89 @@ +syntax = "proto3"; + +package demo_plugin; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/types_go_proto"; + +// (== suppress_warning documentation-presence ==) +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). + DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/tf_datatype.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, +// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, +// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) + +// For identifying the underlying type of a variant. For variants, the types +// listed here are a subset of the types in the variant type registry, +// corresponding to commonly used variants which must occasionally be +// special-cased. +enum SpecializedType { + // Invalid/unknown specialized type. + ST_INVALID = 0; + // "tensorflow::TensorList" in the variant type registry. + ST_TENSOR_LIST = 1; + // "tensorflow::data::Optional" in the variant type registry. + ST_OPTIONAL = 2; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/versions.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/versions.proto new file mode 100644 index 000000000..00adb8ed8 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/versions.proto @@ -0,0 +1,33 @@ +syntax = "proto3"; + +package demo_plugin; + +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/versions_go_proto"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). + repeated int32 bad_consumers = 3; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/xplane.proto b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/xplane.proto new file mode 100644 index 000000000..26fdb7c91 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/src/utils/xplane.proto @@ -0,0 +1,156 @@ +syntax = "proto3"; + +package demo_plugin.profiler; + +option cc_enable_arenas = true; + +// A container of parallel XPlanes, generated by one or more profiling sources. +// Next ID: 5 +message XSpace { + repeated XPlane planes = 1; + // Errors (if any) in the generation of planes. + repeated string errors = 2; + // Warnings (if any) in the generation of planes; + repeated string warnings = 3; + // List of hostnames that XPlanes are generated from. + repeated string hostnames = 4; +} + +// An XPlane is a container of parallel timelines (XLines), generated by a +// profiling source or by post-processing one or more XPlanes. +// Next ID: 7 +message XPlane { + int64 id = 1; + + // Name of this line. + string name = 2; + + // Parallel timelines grouped in this plane. XLines with the same id + // are effectively the same timeline. + repeated XLine lines = 3; + + // XEventMetadata map, each entry uses the XEventMetadata.id as key. This map + // should be used for events that share the same ID over the whole XPlane. + map event_metadata = 4; + + // XStatMetadata map, each entry uses the XStatMetadata.id as key. This map + // should be used for stats that share the same ID over the whole XPlane. + map stat_metadata = 5; + + // XStats associated with this plane, e.g. device capabilities. + // Each of these XStats should have a different metadata_id. + repeated XStat stats = 6; +} + +// An XLine is a timeline of trace events (XEvents). +// Next ID: 12 +message XLine { + // Id of this line, can be repeated within an XPlane. All XLines with the + // same id are effectively the same timeline. + int64 id = 1; + + // Display id of this line. Multiple lines with the same display_id are + // grouped together in the same trace viewer row. + int64 display_id = 10; + + // Name of this XLine. + string name = 2; + + // Name of this XLine to display in trace viewer. + string display_name = 11; + + // Start time of this line in nanoseconds since the UNIX epoch. + // XEvent.offset_ps is relative to this timestamp. + int64 timestamp_ns = 3; + + // Profiling duration for this line in picoseconds. + int64 duration_ps = 9; + + // XEvents within the same XLine should not overlap in time, but they can be + // nested. + repeated XEvent events = 4; + + reserved 5, 6, 7, 8; +} + +// An XEvent is a trace event, optionally annotated with XStats. +// Next ID: 6 +message XEvent { + // XEventMetadata.id of corresponding metadata. + int64 metadata_id = 1; + + oneof data { + // Start time of the event in picoseconds, as offset from + // XLine.timestamp_ns(). + int64 offset_ps = 2; + + // Number of occurrences of the event, if aggregated. + int64 num_occurrences = 5; + } + + // Duration of the event in picoseconds. Can be zero for an instant event. + int64 duration_ps = 3; + + // XStats associated with the event. + // Each of these XStats should have a different metadata_id. + repeated XStat stats = 4; +} + +// An XStat is a named value associated with an XEvent, e.g., a performance +// counter value, a metric computed by a formula applied over nested XEvents +// and XStats. +// Next ID: 8 +message XStat { + // XStatMetadata.id of corresponding metadata. + int64 metadata_id = 1; + + // Value of this stat. + oneof value { + double double_value = 2; + uint64 uint64_value = 3; + int64 int64_value = 4; + string str_value = 5; + bytes bytes_value = 6; + // A string value that stored in XStatMetadata::name. + uint64 ref_value = 7; + } +} + +// Metadata for an XEvent, corresponds to an event type and is shared by +// all XEvents with the same metadata_id. +// Next ID: 7 +message XEventMetadata { + // XPlane.event_metadata map key. + int64 id = 1; + + // Name of the event. + string name = 2; + + // Name of the event shown in trace viewer. + string display_name = 4; + + // Additional metadata in serialized format. + bytes metadata = 3; + + // XStats that are constant for all XEvents with the same metadata_id. + // Each of these XStats should have a different metadata_id. + repeated XStat stats = 5; + + // XPlane.event_metadata map key for children events. + repeated int64 child_id = 6; +} + +// Metadata for an XStat, corresponds to a stat type and is shared by all +// XStats with the same metadata_id. +// Next ID: 4 +message XStatMetadata { + // XPlane.stat_metadata map key. + int64 id = 1; + + // Name of the stat (should be short). + // Two XStatMetadata with different id should have different names. + string name = 2; + + // Description of the stat (might be long). + string description = 3; +} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tf_configure.bzl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tf_configure.bzl new file mode 100644 index 000000000..cf7e11046 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tf_configure.bzl @@ -0,0 +1,205 @@ +"""Setup TensorFlow as external dependency""" + +_TF_HEADER_DIR = "TF_HEADER_DIR" + +def _tpl(repository_ctx, tpl, substitutions = {}, out = None): + if not out: + out = tpl + repository_ctx.template( + out, + Label("//third_party/tf_dependency:%s.tpl" % tpl), + substitutions, + ) + +def _fail(msg): + """Output failure message when auto configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg)) + +def _is_windows(repository_ctx): + """Returns true if the host operating system is windows.""" + os_name = repository_ctx.os.name.lower() + if os_name.find("windows") != -1: + return True + return False + +def _execute( + repository_ctx, + cmdline, + error_msg = None, + error_details = None, + empty_stdout_fine = False): + """Executes an arbitrary shell command. + + Helper for executes an arbitrary shell command. + + Args: + repository_ctx: the repository_ctx object. + cmdline: list of strings, the command to execute. + error_msg: string, a summary of the error if the command fails. + error_details: string, details about the error or steps to fix it. + empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise + it's an error. + + Returns: + The result of repository_ctx.execute(cmdline). + """ + result = repository_ctx.execute(cmdline) + if result.stderr or not (empty_stdout_fine or result.stdout): + _fail("\n".join([ + error_msg.strip() if error_msg else "Repository command failed", + result.stderr.strip(), + error_details if error_details else "", + ])) + return result + +def _read_dir(repository_ctx, src_dir): + """Returns a string with all files in a directory. + + Finds all files inside a directory, traversing subfolders and following + symlinks. The returned string contains the full path of all files + separated by line breaks. + + Args: + repository_ctx: the repository_ctx object. + src_dir: directory to find files from. + + Returns: + A string of all files inside the given dir. + """ + if _is_windows(repository_ctx): + src_dir = src_dir.replace("/", "\\") + find_result = _execute( + repository_ctx, + ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"], + empty_stdout_fine = True, + ) + + # src_files will be used in genrule.outs where the paths must + # use forward slashes. + result = find_result.stdout.replace("\\", "/") + else: + find_result = _execute( + repository_ctx, + ["find", src_dir, "-follow", "-type", "f"], + empty_stdout_fine = True, + ) + result = find_result.stdout + return result + +def _genrule(genrule_name, command, outs): + """Returns a string with a genrule. + + Genrule executes the given command and produces the given outputs. + + Args: + genrule_name: A unique name for genrule target. + command: The command to run. + outs: A list of files generated by this rule. + + Returns: + A genrule target. + """ + return ( + "genrule(\n" + + ' name = "' + + genrule_name + '",\n' + + " outs = [\n" + + outs + + "\n ],\n" + + ' cmd = """\n' + + command + + '\n """,\n' + + ")\n" + ) + +def _norm_path(path): + """Returns a path with '/' and remove the trailing slash.""" + path = path.replace("\\", "/") + if path[-1] == "/": + path = path[:-1] + return path + +def _symlink_genrule_for_dir( + repository_ctx, + src_dir, + dest_dir, + genrule_name, + src_files = [], + dest_files = [], + tf_pip_dir_rename_pair = []): + """Returns a genrule to symlink(or copy if on Windows) a set of files. + If src_dir is passed, files will be read from the given directory; otherwise + we assume files are in src_files and dest_files. + Args: + repository_ctx: the repository_ctx object. + src_dir: source directory. + dest_dir: directory to create symlink in. + genrule_name: genrule name. + src_files: list of source files instead of src_dir. + dest_files: list of corresonding destination files. + tf_pip_dir_rename_pair: list of the pair of tf pip parent directory to + replace. For example, in TF pip package, the source code is under + "tensorflow_core", and we might want to replace it with + "tensorflow" to match the header includes. + Returns: + genrule target that creates the symlinks. + """ + + # Check that tf_pip_dir_rename_pair has the right length + tf_pip_dir_rename_pair_len = len(tf_pip_dir_rename_pair) + if tf_pip_dir_rename_pair_len != 0 and tf_pip_dir_rename_pair_len != 2: + _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % tf_pip_dir_rename_pair_len) + + if src_dir != None: + src_dir = _norm_path(src_dir) + dest_dir = _norm_path(dest_dir) + files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines())) + + # Create a list with the src_dir stripped to use for outputs. + if tf_pip_dir_rename_pair_len: + dest_files = files.replace(src_dir, "").replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1]).splitlines() + else: + dest_files = files.replace(src_dir, "").splitlines() + src_files = files.splitlines() + command = [] + outs = [] + + for i in range(len(dest_files)): + if dest_files[i] != "": + # If we have only one file to link we do not want to use the dest_dir, as + # $(@D) will include the full path to the file. + dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i] + + # Copy the headers to create a sandboxable setup. + cmd = "cp -f" + command.append(cmd + ' "%s" "%s"' % (src_files[i], dest)) + outs.append(' "' + dest_dir + dest_files[i] + '",') + + genrule = _genrule( + genrule_name, + ";\n".join(command), + "\n".join(outs), + ) + return genrule + +def _tf_pip_impl(repository_ctx): + tf_header_dir = repository_ctx.os.environ[_TF_HEADER_DIR] + tf_header_rule = _symlink_genrule_for_dir( + repository_ctx, + tf_header_dir, + "include", + "tf_header_include", + tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"], + ) + _tpl(repository_ctx, "BUILD", { + "%{TF_HEADER_GENRULE}": tf_header_rule, + }) + +tf_configure = repository_rule( + environ = [ + _TF_HEADER_DIR, + ], + implementation = _tf_pip_impl, +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/BUILD new file mode 100644 index 000000000..5d8298be2 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/BUILD @@ -0,0 +1,47 @@ +# Description: +# Tools for building the TensorFlow pip package. + +package(default_visibility = ["//visibility:private"]) + +load( + "//tensorflow_plugin:demo_plugin.bzl", + "transitive_hdrs", +) +load( + "@local_config_syslibs//:build_defs.bzl", + "if_not_system_lib", +) + +# This returns a list of headers of all public header libraries (e.g., +# framework, lib), and all of the transitive dependencies of those +# public headers. Not all of the headers returned by the filegroup +# are public (e.g., internal headers that are included by public +# headers), but the internal headers need to be packaged in the +# pip_package for the public headers to be properly included. +# +# Public headers are therefore defined by those that are both: +# +# 1) "publicly visible" as defined by bazel +# 2) Have documentation. +# +# This matches the policy of "public" for our python API. + +COMMON_PIP_DEPS = [ + "MANIFEST.in", + "README", + "setup.py", + "//tensorflow_plugin:libdemo_plugin.so", +] + +py_binary( + name = "simple_console", + srcs = ["simple_console.py"], + srcs_version = "PY2AND3", + deps = [], +) + +sh_binary( + name = "build_pip_package", + srcs = ["build_pip_package.sh"], + data = ["simple_console"] + COMMON_PIP_DEPS, +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/MANIFEST.in b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/MANIFEST.in new file mode 100644 index 000000000..1631b1c92 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/MANIFEST.in @@ -0,0 +1,12 @@ +include README +recursive-include * *.py +recursive-include * *.pyd +recursive-include * *.pd +recursive-include * *.so +recursive-include * *.so.[0-9] +recursive-include * *.dylib +recursive-include * *.dll +recursive-include * *.lib +recursive-include * *.csv +recursive-include * *.h +recursive-include * *.hpp \ No newline at end of file diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/README b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/README new file mode 100644 index 000000000..5590d70e0 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/README @@ -0,0 +1 @@ +demo_plugin diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/build_pip_package.sh b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/build_pip_package.sh new file mode 100755 index 000000000..66f364d39 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/build_pip_package.sh @@ -0,0 +1,226 @@ +#!/usr/bin/env bash +# Copyright 2012 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +set -e + +function is_absolute { + [[ "$1" = /* ]] || [[ "$1" =~ ^[a-zA-Z]:[/\\].* ]] +} + +function real_path() { + is_absolute "$1" && echo "$1" || echo "$PWD/${1#./}" +} + +function cp_external() { + local src_dir=$1 + local dest_dir=$2 + + pushd . + cd "$src_dir" + for f in `find . ! -type d ! -name '*.py' ! -path '*local_config_syslibs*' ! -path '*org_tensorflow_tensorflow-plugin*'`; do + mkdir -p "${dest_dir}/$(dirname ${f})" + cp "${f}" "${dest_dir}/$(dirname ${f})/" + done + popd +} + +PLATFORM="$(uname -s | tr 'A-Z' 'a-z')" +function is_windows() { + if [[ "${PLATFORM}" =~ (cygwin|mingw32|mingw64|msys)_nt* ]]; then + true + else + false + fi +} + +function prepare_src() { + if [ $# -lt 1 ] ; then + echo "No destination dir provided" + exit 1 + fi + + TMPDIR="$1" + mkdir -p "$TMPDIR" + EXTERNAL_INCLUDES="${TMPDIR}/tensorflow-plugins/include/external" + + echo $(date) : "=== Preparing sources in dir: ${TMPDIR}" + + if [ ! -d bazel-bin/tensorflow_plugin ]; then + echo "Could not find bazel-bin. Did you run from the root of the build tree?" + exit 1 + fi + + RUNFILES=bazel-bin/tensorflow_plugin/tools/pip_package/build_pip_package.runfiles/org_tensorflow_plugin + if [ -d bazel-bin/tensorflow_plugin/tools/pip_package/build_pip_package.runfiles/org_tensorflow_plugin/external ]; then + # Old-style runfiles structure (--legacy_external_runfiles). + cp -R \ + bazel-bin/tensorflow_plugin/tools/pip_package/build_pip_package.runfiles/org_tensorflow_plugin/tensorflow_plugin \ + "${TMPDIR}" + mkdir -p ${EXTERNAL_INCLUDES} + if [ -d bazel-tensorflow-plugin/external/com_google_absl ]; then + cp -R bazel-tensorflow-plugin/external/com_google_absl "${EXTERNAL_INCLUDES}" + fi + if [ -d bazel-tensorflow-plugin/external/eigen_archive ]; then + cp -R bazel-tensorflow-plugin/external/eigen_archive "${EXTERNAL_INCLUDES}" + fi + # Copy MKL libs over so they can be loaded at runtime + so_lib_dir=$(ls $RUNFILES | grep solib) || true + if [ -n "${so_lib_dir}" ]; then + mkl_so_dir=$(ls ${RUNFILES}/${so_lib_dir} | grep mkl) || true + plugin_so_dir=$(ls ${RUNFILES}/${so_lib_dir} | grep plugin) || true + if [ -n "${mkl_so_dir}" ]; then + mkdir "${TMPDIR}/${so_lib_dir}" + cp -R ${RUNFILES}/${so_lib_dir}/${mkl_so_dir} "${TMPDIR}/${so_lib_dir}" + fi + if [ -n "${plugin_so_dir}" ]; then + #mkdir "${TMPDIR}/${so_lib_dir}" + cp -R -d ${RUNFILES}/${so_lib_dir}/${plugin_so_dir} "${TMPDIR}/tensorflow-plugins" + fi + fi + else + # New-style runfiles structure (--nolegacy_external_runfiles). + cp -R \ + bazel-bin/tensorflow_plugin/tools/pip_package/build_pip_package.runfiles/org_tensorflow_plugin/plugin \ + "${TMPDIR}" + cp_external \ + bazel-bin/tensorflow_plugin/tools/pip_package/build_pip_package.runfiles \ + "${EXTERNAL_INCLUDES}" + # Copy MKL libs over so they can be loaded at runtime + so_lib_dir=$(ls $RUNFILES | grep solib) || true + if [ -n "${so_lib_dir}" ]; then + mkl_so_dir=$(ls ${RUNFILES}/${so_lib_dir} | grep mkl) || true + if [ -n "${mkl_so_dir}" ]; then + mkdir "${TMPDIR}/${so_lib_dir}" + cp -R ${RUNFILES}/${so_lib_dir}/${mkl_so_dir} "${TMPDIR}/${so_lib_dir}" + fi + fi + fi + + + cp tensorflow_plugin/tools/pip_package/MANIFEST.in ${TMPDIR} + cp tensorflow_plugin/tools/pip_package/README ${TMPDIR} + cp tensorflow_plugin/tools/pip_package/setup.py ${TMPDIR} + # my_plugin_dir should be the same with _MY_PLUGIN_PATH in setup.py + mkdir -p ${TMPDIR}/my_plugin_dir + cp -r tensorflow_plugin/python/ ${TMPDIR}/my_plugin_dir + touch ${TMPDIR}/my_plugin_dir/__init__.py + if [ -d ${TMPDIR}/tensorflow_plugin ] ; then + mv ${TMPDIR}/tensorflow_plugin/* ${TMPDIR}/tensorflow-plugins + fi + +} + +function build_wheel() { + if [ $# -lt 2 ] ; then + echo "No src and dest dir provided" + exit 1 + fi + + TMPDIR="$1" + DEST="$2" + PKG_NAME_FLAG="$3" + + # Before we leave the top-level directory, make sure we know how to + # call python. + if [[ -e tools/python_bin_path.sh ]]; then + source tools/python_bin_path.sh + fi + + pushd ${TMPDIR} > /dev/null + rm -f MANIFEST + echo $(date) : "=== Building wheel" + "${PYTHON_BIN_PATH:-python}" setup.py bdist_wheel ${PKG_NAME_FLAG} >/dev/null + mkdir -p ${DEST} + cp dist/* ${DEST} + popd > /dev/null + echo $(date) : "=== Output wheel file is in: ${DEST}" +} + +function usage() { + echo "Usage:" + echo "$0 [--src srcdir] [--dst dstdir] [options]" + echo "$0 dstdir [options]" + echo "" + echo " --src prepare sources in srcdir" + echo " will use temporary dir if not specified" + echo "" + echo " --dst build wheel in dstdir" + echo " if dstdir is not set do not build, only prepare sources" + echo "" + exit 1 +} + +function main() { + PKG_NAME_FLAG="" + PROJECT_NAME="" + GPU_BUILD=0 + NIGHTLY_BUILD=0 + SRCDIR="" + DSTDIR="" + CLEANSRC=1 + while true; do + if [[ "$1" == "--help" ]]; then + usage + exit 1 + elif [[ "$1" == "--project_name" ]]; then + shift + if [[ -z "$1" ]]; then + break + fi + PROJECT_NAME="$1" + elif [[ "$1" == "--src" ]]; then + shift + SRCDIR="$(real_path $1)" + CLEANSRC=0 + elif [[ "$1" == "--dst" ]]; then + shift + DSTDIR="$(real_path $1)" + else + DSTDIR="$(real_path $1)" + fi + shift + + if [[ -z "$1" ]]; then + break + fi + done + + if [[ -z "$DSTDIR" ]] && [[ -z "$SRCDIR" ]]; then + echo "No destination dir provided" + usage + exit 1 + fi + + if [[ -z "$SRCDIR" ]]; then + # make temp srcdir if none set + SRCDIR="$(mktemp -d -t tmp.XXXXXXXXXX)" + fi + + prepare_src "$SRCDIR" + + if [[ -z "$DSTDIR" ]]; then + # only want to prepare sources + exit + fi + + build_wheel "$SRCDIR" "$DSTDIR" "$PKG_NAME_FLAG" + + if [[ $CLEANSRC -ne 0 ]]; then + rm -rf "${TMPDIR}" + fi +} + +main "$@" diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/setup.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/setup.py new file mode 100644 index 000000000..ce1bd8e2e --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/setup.py @@ -0,0 +1,262 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +''' +setup.py file to build wheel for tensorflow_plugin +''' + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import fnmatch +import os +import re +import sys + +from setuptools import Command +from setuptools import find_packages +from setuptools import setup +from setuptools.command.install import install as InstallCommandBase +from setuptools.dist import Distribution + +DOCLINES = __doc__.split('\n') + +# This version string is semver compatible, but incompatible with pip. +# For pip, we will remove all '-' characters from this string, and use the +# result for pip. +# Also update tensorflow/demo_plugin.bzl and +# tensorflow/core/public/version.h +_VERSION = '0.0.1' +# this path can't be modified. +_PLUGIN_LIB_PATH = 'tensorflow-plugins' +_MY_PLUGIN_PATH = 'my_plugin_dir' + +REQUIRED_PACKAGES = [ + 'tensorflow >= 2.5.0', +] + +if sys.byteorder == 'little': + # grpcio does not build correctly on big-endian machines due to lack of + # BoringSSL support. + # See https://github.com/tensorflow/tensorflow/issues/17882. + REQUIRED_PACKAGES.append('grpcio >= 1.8.6') + +# The wheel package name, change it as your requirements +project_name = 'my_tensorflow_plugin_package' + +# python3 requires wheel 0.26 +if sys.version_info.major == 3: + REQUIRED_PACKAGES.append('wheel >= 0.26') +else: + REQUIRED_PACKAGES.append('wheel') + # mock comes with unittest.mock for python3, need to install for python2 + REQUIRED_PACKAGES.append('mock >= 2.0.0') + +# weakref.finalize and enum were introduced in Python 3.4 +if sys.version_info < (3, 4): + REQUIRED_PACKAGES.append('backports.weakref >= 1.0rc1') + REQUIRED_PACKAGES.append('enum34 >= 1.1.6') + +# pylint: disable=line-too-long +CONSOLE_SCRIPTS = [ +# 'freeze_graph = tensorflow.python.tools.freeze_graph:run_main', +# 'toco_from_protos = tensorflow.lite.toco.python.toco_from_protos:main', +# 'tflite_convert = tensorflow.lite.python.tflite_convert:main', +# 'toco = tensorflow.lite.python.tflite_convert:main', +# 'saved_model_cli = tensorflow.python.tools.saved_model_cli:main', +# # We need to keep the TensorBoard command, even though the console script +# # is now declared by the tensorboard pip package. If we remove the +# # TensorBoard command, pip will inappropriately remove it during install, +# # even though the command is not removed, just moved to a different wheel. +# 'tensorboard = tensorboard.main:run_main', +# 'tf_upgrade_v2 = tensorflow.tools.compatibility.tf_upgrade_v2_main:main', +] +# pylint: enable=line-too-long + +TEST_PACKAGES = [ + 'scipy >= 0.15.1', +] + + +class BinaryDistribution(Distribution): + + def has_ext_modules(self): + return True + + +class InstallCommand(InstallCommandBase): + """Override the dir where the headers go.""" + + def finalize_options(self): + ret = InstallCommandBase.finalize_options(self) + self.install_headers = os.path.join(self.install_purelib, + 'tensorflow-plugins', 'include') + return ret + + +class InstallHeaders(Command): + """Override how headers are copied. + + The install_headers that comes with setuptools copies all files to + the same directory. But we need the files to be in a specific directory + hierarchy for -I to work correctly. + """ + description = 'install C/C++ header files' + + user_options = [('install-dir=', 'd', + 'directory to install header files to'), + ('force', 'f', + 'force installation (overwrite existing files)'), + ] + + boolean_options = ['force'] + + def initialize_options(self): + self.install_dir = None + self.force = 0 + self.outfiles = [] + + def finalize_options(self): + self.set_undefined_options('install', + ('install_headers', 'install_dir'), + ('force', 'force')) + + def mkdir_and_copy_file(self, header): + install_dir = os.path.join(self.install_dir, os.path.dirname(header)) + # Get rid of some extra intervening directories so we can have fewer + # directories for -I + install_dir = re.sub('/google/protobuf_archive/src', '', install_dir) + + # Copy external code headers into tensorflow/include. + # A symlink would do, but the wheel file that gets created ignores + # symlink within the directory hierarchy. + # NOTE(keveman): Figure out how to customize bdist_wheel package so + # we can do the symlink. + external_header_locations = [ + 'tensorflow-plugins/include/external/eigen_archive/', + 'tensorflow-plugins/include/external/com_google_absl/', + 'tensorflow-plugins/include/external/com_google_protobuf', + ] + for location in external_header_locations: + if location in install_dir: + extra_dir = install_dir.replace(location, '') + if not os.path.exists(extra_dir): + self.mkpath(extra_dir) + self.copy_file(header, extra_dir) + + if not os.path.exists(install_dir): + self.mkpath(install_dir) + return self.copy_file(header, install_dir) + + def run(self): + hdrs = self.distribution.headers + if not hdrs: + return + + self.mkpath(self.install_dir) + for header in hdrs: + (out, _) = self.mkdir_and_copy_file(header) + self.outfiles.append(out) + + def get_inputs(self): + return self.distribution.headers or [] + + def get_outputs(self): + return self.outfiles + + +def find_files(pattern, root): + """Return all the files matching pattern below root dir.""" + for dirpath, _, files in os.walk(root): + for filename in fnmatch.filter(files, pattern): + yield os.path.join(dirpath, filename) + + +so_lib_paths = [ + i for i in os.listdir('.') + if os.path.isdir(i) and fnmatch.fnmatch(i, '_solib_*') +] + +print(os.listdir('.')) +matches = [] +for path in so_lib_paths: + matches.extend( + ['../' + x for x in find_files('*', path) if '.py' not in x] + ) + +if os.name == 'nt': + EXTENSION_NAME = 'libdemo_plugin.pyd' +else: + EXTENSION_NAME = 'libdemo_plugin.so' + +headers = ( + list(find_files('*.h', 'tensorflow-plugins/c_api/c')) + + list(find_files('*.h', 'tensorflow-plugins/c_api/src'))) + +setup( + name=project_name, + version=_VERSION.replace('-', ''), + description=DOCLINES[0], + long_description='\n'.join(DOCLINES[2:]), + url='https://github.com/tensorflow', + download_url='https://github.com/tensorflow', + author='Tensorflow', + author_email='packages@tensorflow.org', + # Contained modules and scripts. + packages= [_PLUGIN_LIB_PATH, _MY_PLUGIN_PATH], + entry_points={ + 'console_scripts': CONSOLE_SCRIPTS, + }, + headers=headers, + install_requires=REQUIRED_PACKAGES, + tests_require=REQUIRED_PACKAGES + TEST_PACKAGES, + package_data={ + _PLUGIN_LIB_PATH: [ + '*.so' + ], + _MY_PLUGIN_PATH: [ + '*', '*/*' + ] + }, + zip_safe=False, + distclass=BinaryDistribution, + cmdclass={ + 'install_headers': InstallHeaders, + 'install': InstallCommand, + }, + # PyPI package information. + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Mathematics', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + license='Apache 2.0', + keywords='tensorflow tensor machine learning plugin', +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/simple_console.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/simple_console.py new file mode 100644 index 000000000..4d2554928 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/tools/pip_package/simple_console.py @@ -0,0 +1,34 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +"""Start a simple interactive console with TensorFlow available.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import code +import sys + + +def main(_): + """Run an interactive console.""" + code.interact() + return 0 + + +if __name__ == '__main__': + sys.exit(main(sys.argv)) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/workspace.bzl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/workspace.bzl new file mode 100644 index 000000000..4f4297758 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/tensorflow_plugin/workspace.bzl @@ -0,0 +1,71 @@ +load("//third_party:repo.bzl", "tf_http_archive", "third_party_http_archive") +load("//third_party/build_option:gcc_configure.bzl", "gcc_configure") +load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +def clean_dep(dep): + return str(Label(dep)) + +def demo_plugin_workspace(path_prefix = "", tf_repo_name = ""): + """All external dependencies for TF builds""" + gcc_configure(name = "local_config_gcc") + syslibs_configure(name = "local_config_syslibs") + + http_archive( + name = "bazel_toolchains", + sha256 = "109a99384f9d08f9e75136d218ebaebc68cc810c56897aea2224c57932052d30", + strip_prefix = "bazel-toolchains-94d31935a2c94fe7e7c7379a0f3393e181928ff7", + urls = [ + "http://mirror.tensorflow.org/github.com/bazelbuild/bazel-toolchains/archive/94d31935a2c94fe7e7c7379a0f3393e181928ff7.tar.gz", + "https://github.com/bazelbuild/bazel-toolchains/archive/94d31935a2c94fe7e7c7379a0f3393e181928ff7.tar.gz", + ], + ) + + tf_http_archive( + name = "eigen_archive", + build_file = clean_dep("//third_party:eigen.BUILD"), + sha256 = "df23a89e4cdfa7de2d81ee28190bd194413e47ff177c94076f845b32d7280344", # SHARED_EIGEN_SHA + strip_prefix = "eigen-5dc2fbabeee17fe023c38756ebde0c1d56472913", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/5dc2fbabeee17fe023c38756ebde0c1d56472913/eigen-5dc2fbabeee17fe023c38756ebde0c1d56472913.tar.gz", + "https://gitlab.com/libeigen/eigen/-/archive/5dc2fbabeee17fe023c38756ebde0c1d56472913/eigen-5dc2fbabeee17fe023c38756ebde0c1d56472913.tar.gz", + ], + ) + + third_party_http_archive( + name = "com_google_absl", + build_file = clean_dep("//third_party:com_google_absl.BUILD"), + sha256 = "56cd3fbbbd94468a5fff58f5df2b6f9de7a0272870c61f6ca05b869934f4802a", + strip_prefix = "abseil-cpp-daf381e8535a1f1f1b8a75966a74e7cca63dee89", + urls = [ + "http://mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/daf381e8535a1f1f1b8a75966a74e7cca63dee89.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/daf381e8535a1f1f1b8a75966a74e7cca63dee89.tar.gz", + ], + ) + + tf_http_archive( + name = "zlib", + build_file = clean_dep("//third_party:zlib.BUILD"), + sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", + strip_prefix = "zlib-1.2.11", + system_build_file = clean_dep("//third_party/systemlibs:zlib.BUILD"), + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/zlib.net/zlib-1.2.11.tar.gz", + "https://zlib.net/zlib-1.2.11.tar.gz", + ], + ) + + tf_http_archive( + name = "com_google_protobuf", + patch_file = clean_dep("//third_party/protobuf:protobuf.patch"), + sha256 = "cfcba2df10feec52a84208693937c17a4b5df7775e1635c1e3baffc487b24c9b", + strip_prefix = "protobuf-3.9.2", + system_build_file = clean_dep("//third_party/systemlibs:protobuf.BUILD"), + system_link_files = { + "//third_party/systemlibs:protobuf.bzl": "protobuf.bzl", + }, + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/protocolbuffers/protobuf/archive/v3.9.2.zip", + "https://github.com/protocolbuffers/protobuf/archive/v3.9.2.zip", + ], + ) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/test_profiler.py b/rfcs/20200624-pluggable-device-for-tensorflow/sample/test_profiler.py new file mode 100644 index 000000000..bf07b61c2 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/test_profiler.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + + +import tensorflow as tf +import numpy as np +import os +tf.compat.v1.disable_eager_execution() + +profile_options = tf.profiler.experimental.ProfilerOptions( + host_tracer_level = 3, + device_tracer_level = 1) + +logpath = os.path.join('data', 'logs', 'profiler_demo') + +a = tf.random.normal(shape=[1,10, 10, 8], dtype=tf.float32, seed=1) +w = tf.random.normal(shape=[3, 3, 8, 4], dtype=tf.float32, seed=1) + +a1 = tf.random.normal(shape=[1, 10, 10, 8], dtype=tf.float32, seed=1) +w1 = tf.random.normal(shape=[3, 3, 8, 4], dtype=tf.float32, seed=1) + + +with tf.device("/MY_DEVICE:0"): + tf.profiler.experimental.start(logpath) + b = tf.nn.relu(a) + c = tf.nn.conv2d(b, w, strides=[1, 1, 1, 1], padding='SAME', data_format='NHWC') + tf.profiler.experimental.stop() +with tf.device("/CPU:0"): + b1 = tf.nn.relu(a1) + c1 = tf.nn.conv2d(b1, w1, strides=[1, 1, 1, 1], padding='SAME', data_format='NHWC') + + +sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=False, log_device_placement=True)) +print(sess.run(tf.reduce_all(tf.less(c - c1, 1e-5)))) + diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/BUILD new file mode 100644 index 000000000..ffd0fb0cd --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/BUILD @@ -0,0 +1 @@ +package(default_visibility = ["//visibility:public"]) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/build_option/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/build_option/BUILD new file mode 100644 index 000000000..ffd0fb0cd --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/build_option/BUILD @@ -0,0 +1 @@ +package(default_visibility = ["//visibility:public"]) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/build_option/gcc_configure.bzl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/build_option/gcc_configure.bzl new file mode 100644 index 000000000..2d73251af --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/build_option/gcc_configure.bzl @@ -0,0 +1,19 @@ +_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR" + +def _cpu_autoconf_imp(repository_ctx): + tf_lib_path = repository_ctx.os.environ[_TF_SHARED_LIBRARY_DIR] + repository_ctx.symlink(tf_lib_path, "proper") + repository_ctx.file("BUILD", """ +cc_import( + name = "framework_lib", + shared_library = "proper/libtensorflow_framework.so.2", + # interface_library = "libtensorflow_framework.so", + # system_provided = 1, + visibility = ['//visibility:public'], +) +""") + +gcc_configure = repository_rule( + implementation = _cpu_autoconf_imp, + local = True, +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/com_google_absl.BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/com_google_absl.BUILD new file mode 100644 index 000000000..8fca145f7 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/com_google_absl.BUILD @@ -0,0 +1,5 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache + +exports_files(["LICENSE"]) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/common.bzl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/common.bzl new file mode 100644 index 000000000..148035c49 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/common.bzl @@ -0,0 +1,42 @@ +# Rule for simple expansion of template files. This performs a simple +# search over the template file for the keys in substitutions, +# and replaces them with the corresponding values. +# +# Typical usage: +# load("/tools/build_rules/template_rule", "expand_header_template") +# template_rule( +# name = "ExpandMyTemplate", +# src = "my.template", +# out = "my.txt", +# substitutions = { +# "$VAR1": "foo", +# "$VAR2": "bar", +# } +# ) +# +# Args: +# name: The name of the rule. +# template: The template file to expand +# out: The destination of the expanded file +# substitutions: A dictionary mapping strings to their substitutions + +def template_rule_impl(ctx): + ctx.actions.expand_template( + template = ctx.file.src, + output = ctx.outputs.out, + substitutions = ctx.attr.substitutions, + ) + +template_rule = rule( + attrs = { + "src": attr.label( + mandatory = True, + allow_single_file = True, + ), + "substitutions": attr.string_dict(mandatory = True), + "out": attr.output(mandatory = True), + }, + # output_to_genfiles is required for header files. + #output_to_genfiles = True, + implementation = template_rule_impl, +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen.BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen.BUILD new file mode 100644 index 000000000..042e7396a --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen.BUILD @@ -0,0 +1,79 @@ +# Description: +# Eigen is a C++ template library for linear algebra: vectors, +# matrices, and related algorithms. + +licenses([ + # Note: Eigen is an MPL2 library that includes GPL v3 and LGPL v2.1+ code. + # We've taken special care to not reference any restricted code. + "reciprocal", # MPL2 + "notice", # Portions BSD +]) + +exports_files(["COPYING.MPL2"]) + +# License-restricted (i.e. not reciprocal or notice) files inside Eigen/... +EIGEN_RESTRICTED_FILES = [ + "Eigen/src/OrderingMethods/Amd.h", + "Eigen/src/SparseCholesky/**", +] + +# Notable transitive dependencies of restricted files inside Eigen/... +EIGEN_RESTRICTED_DEPS = [ + "Eigen/Eigen", + "Eigen/IterativeLinearSolvers", + "Eigen/MetisSupport", + "Eigen/Sparse", + "Eigen/SparseCholesky", + "Eigen/SparseLU", +] + +EIGEN_FILES = [ + "Eigen/**", + "unsupported/Eigen/CXX11/**", + "unsupported/Eigen/FFT", + "unsupported/Eigen/KroneckerProduct", + "unsupported/Eigen/src/FFT/**", + "unsupported/Eigen/src/KroneckerProduct/**", + "unsupported/Eigen/MatrixFunctions", + "unsupported/Eigen/SpecialFunctions", + "unsupported/Eigen/src/MatrixFunctions/**", + "unsupported/Eigen/src/SpecialFunctions/**", +] + +# List of files picked up by glob but actually part of another target. +EIGEN_EXCLUDE_FILES = [ + "Eigen/src/Core/arch/AVX/PacketMathGoogleTest.cc", +] + +# Files known to be under MPL2 license. +EIGEN_MPL2_HEADER_FILES = glob( + EIGEN_FILES, + exclude = EIGEN_EXCLUDE_FILES + + EIGEN_RESTRICTED_FILES + + EIGEN_RESTRICTED_DEPS + [ + # Guarantees any file missed by excludes above will not compile. + "Eigen/src/Core/util/NonMPL2.h", + "Eigen/**/CMakeLists.txt", + ], +) + +cc_library( + name = "eigen", + hdrs = EIGEN_MPL2_HEADER_FILES, + defines = [ + # This define (mostly) guarantees we don't link any problematic + # code. We use it, but we do not rely on it, as evidenced above. + "EIGEN_MPL2_ONLY", + "EIGEN_MAX_ALIGN_BYTES=64", + "EIGEN_HAS_TYPE_TRAITS=0", + "EIGEN_USE_THREADS=1", + ], + includes = ["."], + visibility = ["//visibility:public"], +) + +filegroup( + name = "eigen_header_files", + srcs = EIGEN_MPL2_HEADER_FILES, + visibility = ["//visibility:public"], +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/BUILD new file mode 100644 index 000000000..6c2f481eb --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/BUILD @@ -0,0 +1,76 @@ +# Description: +# Eigen is a C++ template library for linear algebra: vectors, +# matrices, and related algorithms. + +licenses([ + # Note: Eigen is an MPL2 library that includes GPL v3 and LGPL v2.1+ code. + # We've taken special care to not reference any restricted code. + "reciprocal", # MPL2 + "notice", # Portions BSD +]) + +exports_files(["LICENSE"]) + +#load("//third_party/mkl:build_defs.bzl", "if_mkl") + +EIGEN3_THIRD_PARTY_HEADERS = [ + "Eigen/Core", + "Eigen/LU", + "Eigen/Cholesky", + "Eigen/Eigenvalues", + "Eigen/OrderingMethods", + "Eigen/QR", + "Eigen/SparseCholesky", + "Eigen/SparseCore", + "Eigen/SVD", + "unsupported/Eigen/MatrixFunctions", + "unsupported/Eigen/SpecialFunctions", + "unsupported/Eigen/CXX11/ThreadPool", + "unsupported/Eigen/CXX11/Tensor", + "unsupported/Eigen/CXX11/FixedPoint", +] + glob(["unsupported/Eigen/CXX11/src/FixedPoint/*.h"]) + +cc_library( + name = "eigen3", + hdrs = EIGEN3_THIRD_PARTY_HEADERS, + includes = [], #+ if_mkl(["./mkl_include"]), + visibility = ["//visibility:public"], + deps = [ + "@eigen_archive//:eigen", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = ["**/OWNERS"], + ), + visibility = ["//tensorflow_plugin:__subpackages__"], +) + +filegroup( + name = "eigen_third_party_header_files", + srcs = EIGEN3_THIRD_PARTY_HEADERS, + visibility = ["//visibility:public"], +) + +genrule( + name = "install_eigen_headers", + srcs = [ + "@eigen_archive//:eigen_header_files", + ":eigen_third_party_header_files", + ], + outs = ["include"], + cmd = """ + mkdir $@ + for f in $(SRCS); do + d="$${f%/*}" + d="$${d#*external/eigen_archive/}" + + mkdir -p "$@/$${d}" + cp "$${f}" "$@/$${d}/" + done + """, + tags = ["manual"], +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/Cholesky b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/Cholesky new file mode 100644 index 000000000..c199a0255 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/Cholesky @@ -0,0 +1 @@ +#include "Eigen/Cholesky" diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/Core b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/Core new file mode 100644 index 000000000..d4b036772 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/Core @@ -0,0 +1 @@ +#include "Eigen/Core" diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/Eigenvalues b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/Eigenvalues new file mode 100644 index 000000000..bf739b9b8 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/Eigenvalues @@ -0,0 +1 @@ +#include "Eigen/Eigenvalues" diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/LU b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/LU new file mode 100644 index 000000000..536149cea --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/LU @@ -0,0 +1 @@ +#include "Eigen/LU" diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/OrderingMethods b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/OrderingMethods new file mode 100644 index 000000000..190fc224b --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/OrderingMethods @@ -0,0 +1 @@ +#include "Eigen/OrderingMethods" \ No newline at end of file diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/QR b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/QR new file mode 100644 index 000000000..be067d3ed --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/QR @@ -0,0 +1 @@ +#include "Eigen/QR" diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/SVD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/SVD new file mode 100644 index 000000000..eecf47c10 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/SVD @@ -0,0 +1 @@ +#include "Eigen/SVD" diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/SparseCholesky b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/SparseCholesky new file mode 100644 index 000000000..a6d362b9c --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/SparseCholesky @@ -0,0 +1 @@ +#include "Eigen/SparseCholesky" \ No newline at end of file diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/SparseCore b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/SparseCore new file mode 100644 index 000000000..3c60745d0 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/Eigen/SparseCore @@ -0,0 +1 @@ +#include "Eigen/SparseCore" \ No newline at end of file diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/LICENSE b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/LICENSE new file mode 100644 index 000000000..c355a5ec0 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/LICENSE @@ -0,0 +1,1938 @@ +Eigen is primarily MPL2 licensed. See COPYING.MPL2 and these links: + http://www.mozilla.org/MPL/2.0/ + http://www.mozilla.org/MPL/2.0/FAQ.html + +Some files contain third-party code under BSD or LGPL licenses, whence +the other COPYING.* files here. + +All the LGPL code is either LGPL 2.1-only, or LGPL 2.1-or-later. +For this reason, the COPYING.LGPL file contains the LGPL 2.1 text. + +If you want to guarantee that the Eigen code that you are #including +is licensed under the MPL2 and possibly more permissive licenses (like +BSD), #define this preprocessor symbol: EIGEN_MPL2_ONLY +For example, with most compilers, you could add this to your project + CXXFLAGS: -DEIGEN_MPL2_ONLY +This will cause a compilation error to be generated if you #include +any code that is LGPL licensed. + +---------------------------------------------------------------------- +Following applies to: +./test/mapstaticmethods.cpp +./test/schur_real.cpp +./test/prec_inverse_4x4.cpp +./test/smallvectors.cpp +./test/redux.cpp +./test/special_numbers.cpp +./test/adjoint.cpp +./test/resize.cpp +./test/mixingtypes.cpp +./test/product_trmv.cpp +./test/sparse_solvers.cpp +./test/cholesky.cpp +./test/geo_quaternion.cpp +./test/miscmatrices.cpp +./test/stddeque.cpp +./test/integer_types.cpp +./test/product_large.cpp +./test/eigensolver_generic.cpp +./test/householder.cpp +./test/geo_orthomethods.cpp +./test/array_for_matrix.cpp +./test/sparseLM.cpp +./test/upperbidiagonalization.cpp +./test/nomalloc.cpp +./test/packetmath.cpp +./test/jacobisvd.cpp +./test/geo_transformations.cpp +./test/swap.cpp +./test/eigensolver_selfadjoint.cpp +./test/inverse.cpp +./test/product_selfadjoint.cpp +./test/product_trsolve.cpp +./test/product_extra.cpp +./test/sparse_solver.h +./test/mapstride.cpp +./test/mapped_matrix.cpp +./test/geo_eulerangles.cpp +./test/eigen2support.cpp +./test/denseLM.cpp +./test/stdvector.cpp +./test/nesting_ops.cpp +./test/sparse_permutations.cpp +./test/zerosized.cpp +./test/exceptions.cpp +./test/vectorwiseop.cpp +./test/cwiseop.cpp +./test/basicstuff.cpp +./test/product_trmm.cpp +./test/linearstructure.cpp +./test/sparse_product.cpp +./test/stdvector_overload.cpp +./test/stable_norm.cpp +./test/umeyama.cpp +./test/unalignedcount.cpp +./test/triangular.cpp +./test/product_mmtr.cpp +./test/sparse_basic.cpp +./test/sparse_vector.cpp +./test/meta.cpp +./test/real_qz.cpp +./test/ref.cpp +./test/eigensolver_complex.cpp +./test/cholmod_support.cpp +./test/conjugate_gradient.cpp +./test/sparse.h +./test/simplicial_cholesky.cpp +./test/bicgstab.cpp +./test/dynalloc.cpp +./test/product_notemporary.cpp +./test/geo_hyperplane.cpp +./test/lu.cpp +./test/qr.cpp +./test/hessenberg.cpp +./test/sizeof.cpp +./test/main.h +./test/selfadjoint.cpp +./test/permutationmatrices.cpp +./test/superlu_support.cpp +./test/qtvector.cpp +./test/geo_homogeneous.cpp +./test/determinant.cpp +./test/array_reverse.cpp +./test/unalignedassert.cpp +./test/stdlist.cpp +./test/product_symm.cpp +./test/corners.cpp +./test/dontalign.cpp +./test/visitor.cpp +./test/geo_alignedbox.cpp +./test/diagonalmatrices.cpp +./test/product_small.cpp +./test/eigensolver_generalized_real.cpp +./test/umfpack_support.cpp +./test/first_aligned.cpp +./test/qr_fullpivoting.cpp +./test/array_replicate.cpp +./test/geo_parametrizedline.cpp +./test/eigen2/eigen2_unalignedassert.cpp +./test/eigen2/eigen2_prec_inverse_4x4.cpp +./test/eigen2/eigen2_alignedbox.cpp +./test/eigen2/eigen2_sparse_product.cpp +./test/eigen2/eigen2_meta.cpp +./test/eigen2/eigen2_nomalloc.cpp +./test/eigen2/eigen2_visitor.cpp +./test/eigen2/eigen2_packetmath.cpp +./test/eigen2/eigen2_svd.cpp +./test/eigen2/eigen2_mixingtypes.cpp +./test/eigen2/eigen2_qr.cpp +./test/eigen2/eigen2_cwiseop.cpp +./test/eigen2/eigen2_geometry_with_eigen2_prefix.cpp +./test/eigen2/eigen2_smallvectors.cpp +./test/eigen2/eigen2_commainitializer.cpp +./test/eigen2/eigen2_sparse_solvers.cpp +./test/eigen2/eigen2_hyperplane.cpp +./test/eigen2/eigen2_eigensolver.cpp +./test/eigen2/eigen2_linearstructure.cpp +./test/eigen2/eigen2_sizeof.cpp +./test/eigen2/eigen2_parametrizedline.cpp +./test/eigen2/eigen2_lu.cpp +./test/eigen2/eigen2_adjoint.cpp +./test/eigen2/eigen2_geometry.cpp +./test/eigen2/eigen2_stdvector.cpp +./test/eigen2/eigen2_newstdvector.cpp +./test/eigen2/eigen2_submatrices.cpp +./test/eigen2/sparse.h +./test/eigen2/eigen2_swap.cpp +./test/eigen2/eigen2_triangular.cpp +./test/eigen2/eigen2_basicstuff.cpp +./test/eigen2/gsl_helper.h +./test/eigen2/eigen2_dynalloc.cpp +./test/eigen2/eigen2_array.cpp +./test/eigen2/eigen2_map.cpp +./test/eigen2/main.h +./test/eigen2/eigen2_miscmatrices.cpp +./test/eigen2/eigen2_product_large.cpp +./test/eigen2/eigen2_first_aligned.cpp +./test/eigen2/eigen2_cholesky.cpp +./test/eigen2/eigen2_determinant.cpp +./test/eigen2/eigen2_sum.cpp +./test/eigen2/eigen2_inverse.cpp +./test/eigen2/eigen2_regression.cpp +./test/eigen2/eigen2_product_small.cpp +./test/eigen2/eigen2_qtvector.cpp +./test/eigen2/eigen2_sparse_vector.cpp +./test/eigen2/product.h +./test/eigen2/eigen2_sparse_basic.cpp +./test/eigen2/eigen2_bug_132.cpp +./test/array.cpp +./test/product_syrk.cpp +./test/commainitializer.cpp +./test/conservative_resize.cpp +./test/qr_colpivoting.cpp +./test/nullary.cpp +./test/bandmatrix.cpp +./test/pastix_support.cpp +./test/product.h +./test/block.cpp +./test/vectorization_logic.cpp +./test/jacobi.cpp +./test/diagonal.cpp +./test/schur_complex.cpp +./test/sizeoverflow.cpp +./bench/BenchTimer.h +./bench/benchFFT.cpp +./bench/eig33.cpp +./bench/spbench/spbenchsolver.h +./bench/spbench/spbenchstyle.h +./lapack/complex_double.cpp +./lapack/cholesky.cpp +./lapack/lapack_common.h +./lapack/eigenvalues.cpp +./lapack/single.cpp +./lapack/lu.cpp +./lapack/complex_single.cpp +./lapack/double.cpp +./demos/mix_eigen_and_c/binary_library.cpp +./demos/mix_eigen_and_c/binary_library.h +./demos/mix_eigen_and_c/example.c +./demos/mandelbrot/mandelbrot.cpp +./demos/mandelbrot/mandelbrot.h +./demos/opengl/icosphere.cpp +./demos/opengl/icosphere.h +./demos/opengl/camera.cpp +./demos/opengl/quaternion_demo.h +./demos/opengl/camera.h +./demos/opengl/trackball.h +./demos/opengl/gpuhelper.h +./demos/opengl/trackball.cpp +./demos/opengl/gpuhelper.cpp +./demos/opengl/quaternion_demo.cpp +./debug/gdb/printers.py +./unsupported/test/minres.cpp +./unsupported/test/openglsupport.cpp +./unsupported/test/jacobisvd.cpp +./unsupported/test/dgmres.cpp +./unsupported/test/matrix_square_root.cpp +./unsupported/test/bdcsvd.cpp +./unsupported/test/matrix_exponential.cpp +./unsupported/test/forward_adolc.cpp +./unsupported/test/polynomialsolver.cpp +./unsupported/test/matrix_function.cpp +./unsupported/test/sparse_extra.cpp +./unsupported/test/matrix_functions.h +./unsupported/test/svd_common.h +./unsupported/test/FFTW.cpp +./unsupported/test/alignedvector3.cpp +./unsupported/test/autodiff.cpp +./unsupported/test/gmres.cpp +./unsupported/test/BVH.cpp +./unsupported/test/levenberg_marquardt.cpp +./unsupported/test/matrix_power.cpp +./unsupported/test/kronecker_product.cpp +./unsupported/test/splines.cpp +./unsupported/test/polynomialutils.cpp +./unsupported/bench/bench_svd.cpp +./unsupported/Eigen/IterativeSolvers +./unsupported/Eigen/src/IterativeSolvers/DGMRES.h +./unsupported/Eigen/src/IterativeSolvers/IncompleteLU.h +./unsupported/Eigen/src/IterativeSolvers/GMRES.h +./unsupported/Eigen/src/IterativeSolvers/IncompleteCholesky.h +./unsupported/Eigen/src/IterativeSolvers/Scaling.h +./unsupported/Eigen/src/IterativeSolvers/MINRES.h +./unsupported/Eigen/src/SparseExtra/RandomSetter.h +./unsupported/Eigen/src/SparseExtra/MatrixMarketIterator.h +./unsupported/Eigen/src/SparseExtra/DynamicSparseMatrix.h +./unsupported/Eigen/src/SparseExtra/MarketIO.h +./unsupported/Eigen/src/SparseExtra/BlockOfDynamicSparseMatrix.h +./unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h +./unsupported/Eigen/src/NonLinearOptimization/LevenbergMarquardt.h +./unsupported/Eigen/src/NonLinearOptimization/HybridNonLinearSolver.h +./unsupported/Eigen/src/BVH/BVAlgorithms.h +./unsupported/Eigen/src/BVH/KdBVH.h +./unsupported/Eigen/src/AutoDiff/AutoDiffScalar.h +./unsupported/Eigen/src/AutoDiff/AutoDiffJacobian.h +./unsupported/Eigen/src/AutoDiff/AutoDiffVector.h +./unsupported/Eigen/src/Splines/Spline.h +./unsupported/Eigen/src/Splines/SplineFitting.h +./unsupported/Eigen/src/Splines/SplineFwd.h +./unsupported/Eigen/src/SVD/JacobiSVD.h +./unsupported/Eigen/src/SVD/BDCSVD.h +./unsupported/Eigen/src/SVD/SVDBase.h +./unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h +./unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h +./unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h +./unsupported/Eigen/src/MatrixFunctions/StemFunction.h +./unsupported/Eigen/src/MatrixFunctions/MatrixPower.h +./unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h +./unsupported/Eigen/src/MatrixFunctions/MatrixFunctionAtomic.h +./unsupported/Eigen/src/MoreVectorization/MathFunctions.h +./unsupported/Eigen/src/LevenbergMarquardt/LevenbergMarquardt.h +./unsupported/Eigen/src/FFT/ei_fftw_impl.h +./unsupported/Eigen/src/FFT/ei_kissfft_impl.h +./unsupported/Eigen/src/Polynomials/PolynomialSolver.h +./unsupported/Eigen/src/Polynomials/Companion.h +./unsupported/Eigen/src/Polynomials/PolynomialUtils.h +./unsupported/Eigen/src/NumericalDiff/NumericalDiff.h +./unsupported/Eigen/src/Skyline/SkylineProduct.h +./unsupported/Eigen/src/Skyline/SkylineMatrixBase.h +./unsupported/Eigen/src/Skyline/SkylineStorage.h +./unsupported/Eigen/src/Skyline/SkylineUtil.h +./unsupported/Eigen/src/Skyline/SkylineInplaceLU.h +./unsupported/Eigen/src/Skyline/SkylineMatrix.h +./unsupported/Eigen/SparseExtra +./unsupported/Eigen/AdolcForward +./unsupported/Eigen/KroneckerProduct +./unsupported/Eigen/NonLinearOptimization +./unsupported/Eigen/BVH +./unsupported/Eigen/OpenGLSupport +./unsupported/Eigen/ArpackSupport +./unsupported/Eigen/AutoDiff +./unsupported/Eigen/Splines +./unsupported/Eigen/MPRealSupport +./unsupported/Eigen/MatrixFunctions +./unsupported/Eigen/MoreVectorization +./unsupported/Eigen/LevenbergMarquardt +./unsupported/Eigen/AlignedVector3 +./unsupported/Eigen/FFT +./unsupported/Eigen/Polynomials +./unsupported/Eigen/NumericalDiff +./unsupported/Eigen/Skyline +./COPYING.README +./COPYING.README +./LICENSE +./LICENSE +./LICENSE +./Eigen/Eigen2Support +./Eigen/src/Eigen2Support/VectorBlock.h +./Eigen/src/Eigen2Support/Cwise.h +./Eigen/src/Eigen2Support/Minor.h +./Eigen/src/Eigen2Support/Lazy.h +./Eigen/src/Eigen2Support/Memory.h +./Eigen/src/Eigen2Support/MathFunctions.h +./Eigen/src/Eigen2Support/Geometry/AlignedBox.h +./Eigen/src/Eigen2Support/Geometry/Hyperplane.h +./Eigen/src/Eigen2Support/Geometry/Quaternion.h +./Eigen/src/Eigen2Support/Geometry/Rotation2D.h +./Eigen/src/Eigen2Support/Geometry/ParametrizedLine.h +./Eigen/src/Eigen2Support/Geometry/RotationBase.h +./Eigen/src/Eigen2Support/Geometry/Translation.h +./Eigen/src/Eigen2Support/Geometry/Scaling.h +./Eigen/src/Eigen2Support/Geometry/AngleAxis.h +./Eigen/src/Eigen2Support/Geometry/Transform.h +./Eigen/src/Eigen2Support/TriangularSolver.h +./Eigen/src/Eigen2Support/LU.h +./Eigen/src/Eigen2Support/QR.h +./Eigen/src/Eigen2Support/SVD.h +./Eigen/src/Eigen2Support/Meta.h +./Eigen/src/Eigen2Support/Block.h +./Eigen/src/Eigen2Support/Macros.h +./Eigen/src/Eigen2Support/LeastSquares.h +./Eigen/src/Eigen2Support/CwiseOperators.h +./Eigen/src/Jacobi/Jacobi.h +./Eigen/src/misc/Kernel.h +./Eigen/src/misc/SparseSolve.h +./Eigen/src/misc/Solve.h +./Eigen/src/misc/Image.h +./Eigen/src/SparseCore/SparseColEtree.h +./Eigen/src/SparseCore/SparseTranspose.h +./Eigen/src/SparseCore/SparseUtil.h +./Eigen/src/SparseCore/SparseCwiseBinaryOp.h +./Eigen/src/SparseCore/SparseDiagonalProduct.h +./Eigen/src/SparseCore/SparseProduct.h +./Eigen/src/SparseCore/SparseDot.h +./Eigen/src/SparseCore/SparseCwiseUnaryOp.h +./Eigen/src/SparseCore/SparseSparseProductWithPruning.h +./Eigen/src/SparseCore/SparseBlock.h +./Eigen/src/SparseCore/SparseDenseProduct.h +./Eigen/src/SparseCore/CompressedStorage.h +./Eigen/src/SparseCore/SparseMatrixBase.h +./Eigen/src/SparseCore/MappedSparseMatrix.h +./Eigen/src/SparseCore/SparseTriangularView.h +./Eigen/src/SparseCore/SparseView.h +./Eigen/src/SparseCore/SparseFuzzy.h +./Eigen/src/SparseCore/TriangularSolver.h +./Eigen/src/SparseCore/SparseSelfAdjointView.h +./Eigen/src/SparseCore/SparseMatrix.h +./Eigen/src/SparseCore/SparseVector.h +./Eigen/src/SparseCore/AmbiVector.h +./Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +./Eigen/src/SparseCore/SparseRedux.h +./Eigen/src/SparseCore/SparsePermutation.h +./Eigen/src/Eigenvalues/RealSchur.h +./Eigen/src/Eigenvalues/ComplexEigenSolver.h +./Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +./Eigen/src/Eigenvalues/ComplexSchur.h +./Eigen/src/Eigenvalues/RealQZ.h +./Eigen/src/Eigenvalues/EigenSolver.h +./Eigen/src/Eigenvalues/HessenbergDecomposition.h +./Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +./Eigen/src/Eigenvalues/Tridiagonalization.h +./Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +./Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +./Eigen/src/SuperLUSupport/SuperLUSupport.h +./Eigen/src/StlSupport/StdDeque.h +./Eigen/src/StlSupport/StdVector.h +./Eigen/src/StlSupport/StdList.h +./Eigen/src/StlSupport/details.h +./Eigen/src/SparseQR/SparseQR.h +./Eigen/src/LU/Inverse.h +./Eigen/src/LU/arch/Inverse_SSE.h +./Eigen/src/LU/Determinant.h +./Eigen/src/LU/PartialPivLU.h +./Eigen/src/LU/FullPivLU.h +./Eigen/src/UmfPackSupport/UmfPackSupport.h +./Eigen/src/OrderingMethods/Ordering.h +./Eigen/src/OrderingMethods/Eigen_Colamd.h +./Eigen/src/QR/HouseholderQR.h +./Eigen/src/QR/ColPivHouseholderQR.h +./Eigen/src/QR/FullPivHouseholderQR.h +./Eigen/src/SVD/JacobiSVD.h +./Eigen/src/SVD/UpperBidiagonalization.h +./Eigen/src/Geometry/OrthoMethods.h +./Eigen/src/Geometry/AlignedBox.h +./Eigen/src/Geometry/Hyperplane.h +./Eigen/src/Geometry/Quaternion.h +./Eigen/src/Geometry/EulerAngles.h +./Eigen/src/Geometry/Rotation2D.h +./Eigen/src/Geometry/ParametrizedLine.h +./Eigen/src/Geometry/RotationBase.h +./Eigen/src/Geometry/arch/Geometry_SSE.h +./Eigen/src/Geometry/Umeyama.h +./Eigen/src/Geometry/Homogeneous.h +./Eigen/src/Geometry/Translation.h +./Eigen/src/Geometry/Scaling.h +./Eigen/src/Geometry/AngleAxis.h +./Eigen/src/Geometry/Transform.h +./Eigen/src/plugins/BlockMethods.h +./Eigen/src/plugins/CommonCwiseUnaryOps.h +./Eigen/src/plugins/CommonCwiseBinaryOps.h +./Eigen/src/plugins/MatrixCwiseUnaryOps.h +./Eigen/src/plugins/MatrixCwiseBinaryOps.h +./Eigen/src/Householder/Householder.h +./Eigen/src/Householder/HouseholderSequence.h +./Eigen/src/Householder/BlockHouseholder.h +./Eigen/src/Core/VectorBlock.h +./Eigen/src/Core/Matrix.h +./Eigen/src/Core/Ref.h +./Eigen/src/Core/SelfAdjointView.h +./Eigen/src/Core/MathFunctions.h +./Eigen/src/Core/GlobalFunctions.h +./Eigen/src/Core/MapBase.h +./Eigen/src/Core/EigenBase.h +./Eigen/src/Core/GenericPacketMath.h +./Eigen/src/Core/NestByValue.h +./Eigen/src/Core/CwiseUnaryOp.h +./Eigen/src/Core/SolveTriangular.h +./Eigen/src/Core/Fuzzy.h +./Eigen/src/Core/Visitor.h +./Eigen/src/Core/Map.h +./Eigen/src/Core/NoAlias.h +./Eigen/src/Core/Diagonal.h +./Eigen/src/Core/StableNorm.h +./Eigen/src/Core/CoreIterators.h +./Eigen/src/Core/products/Parallelizer.h +./Eigen/src/Core/products/SelfadjointMatrixVector.h +./Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +./Eigen/src/Core/products/TriangularSolverMatrix.h +./Eigen/src/Core/products/GeneralMatrixMatrix.h +./Eigen/src/Core/products/SelfadjointProduct.h +./Eigen/src/Core/products/CoeffBasedProduct.h +./Eigen/src/Core/products/TriangularMatrixVector.h +./Eigen/src/Core/products/SelfadjointMatrixMatrix.h +./Eigen/src/Core/products/TriangularSolverVector.h +./Eigen/src/Core/products/SelfadjointRank2Update.h +./Eigen/src/Core/products/GeneralBlockPanelKernel.h +./Eigen/src/Core/products/GeneralMatrixVector.h +./Eigen/src/Core/products/TriangularMatrixMatrix.h +./Eigen/src/Core/Reverse.h +./Eigen/src/Core/BooleanRedux.h +./Eigen/src/Core/Replicate.h +./Eigen/src/Core/arch/AltiVec/PacketMath.h +./Eigen/src/Core/arch/AltiVec/Complex.h +./Eigen/src/Core/arch/SSE/PacketMath.h +./Eigen/src/Core/arch/SSE/Complex.h +./Eigen/src/Core/arch/SSE/MathFunctions.h +./Eigen/src/Core/arch/NEON/PacketMath.h +./Eigen/src/Core/arch/NEON/Complex.h +./Eigen/src/Core/arch/Default/Settings.h +./Eigen/src/Core/CwiseUnaryView.h +./Eigen/src/Core/Array.h +./Eigen/src/Core/ArrayWrapper.h +./Eigen/src/Core/Swap.h +./Eigen/src/Core/Transpositions.h +./Eigen/src/Core/Random.h +./Eigen/src/Core/IO.h +./Eigen/src/Core/SelfCwiseBinaryOp.h +./Eigen/src/Core/VectorwiseOp.h +./Eigen/src/Core/Select.h +./Eigen/src/Core/ArrayBase.h +./Eigen/src/Core/DenseCoeffsBase.h +./Eigen/src/Core/DiagonalProduct.h +./Eigen/src/Core/Assign.h +./Eigen/src/Core/Redux.h +./Eigen/src/Core/ForceAlignedAccess.h +./Eigen/src/Core/BandMatrix.h +./Eigen/src/Core/PlainObjectBase.h +./Eigen/src/Core/DenseBase.h +./Eigen/src/Core/Flagged.h +./Eigen/src/Core/CwiseBinaryOp.h +./Eigen/src/Core/ProductBase.h +./Eigen/src/Core/TriangularMatrix.h +./Eigen/src/Core/Transpose.h +./Eigen/src/Core/DiagonalMatrix.h +./Eigen/src/Core/Dot.h +./Eigen/src/Core/Functors.h +./Eigen/src/Core/PermutationMatrix.h +./Eigen/src/Core/NumTraits.h +./Eigen/src/Core/MatrixBase.h +./Eigen/src/Core/DenseStorage.h +./Eigen/src/Core/util/Memory.h +./Eigen/src/Core/util/StaticAssert.h +./Eigen/src/Core/util/BlasUtil.h +./Eigen/src/Core/util/MatrixMapper.h +./Eigen/src/Core/util/XprHelper.h +./Eigen/src/Core/util/ForwardDeclarations.h +./Eigen/src/Core/util/Meta.h +./Eigen/src/Core/util/Macros.h +./Eigen/src/Core/util/Constants.h +./Eigen/src/Core/CwiseNullaryOp.h +./Eigen/src/Core/Block.h +./Eigen/src/Core/GeneralProduct.h +./Eigen/src/Core/CommaInitializer.h +./Eigen/src/Core/ReturnByValue.h +./Eigen/src/Core/Stride.h +./Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +./Eigen/src/SparseLU/SparseLU_column_dfs.h +./Eigen/src/SparseLU/SparseLU_panel_dfs.h +./Eigen/src/SparseLU/SparseLU_relax_snode.h +./Eigen/src/SparseLU/SparseLU_panel_bmod.h +./Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +./Eigen/src/SparseLU/SparseLU_Utils.h +./Eigen/src/SparseLU/SparseLU_gemm_kernel.h +./Eigen/src/SparseLU/SparseLU_kernel_bmod.h +./Eigen/src/SparseLU/SparseLU_pivotL.h +./Eigen/src/SparseLU/SparseLU_Memory.h +./Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +./Eigen/src/SparseLU/SparseLUImpl.h +./Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +./Eigen/src/SparseLU/SparseLU_Structs.h +./Eigen/src/SparseLU/SparseLU.h +./Eigen/src/SparseLU/SparseLU_column_bmod.h +./Eigen/src/SparseLU/SparseLU_pruneL.h +./Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +./Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +./Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +./Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +./Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +./Eigen/src/SparseCholesky/SimplicialCholesky.h +./Eigen/src/Cholesky/LDLT.h +./Eigen/src/Cholesky/LLT.h +./Eigen/src/CholmodSupport/CholmodSupport.h +./Eigen/src/PaStiXSupport/PaStiXSupport.h +./Eigen/src/MetisSupport/MetisSupport.h +./Eigen/StdVector +./Eigen/Core +./Eigen/OrderingMethods +./Eigen/SparseLU +./Eigen/StdList +./Eigen/StdDeque +./Eigen/SparseCholesky +./Eigen/SparseCore +./scripts/relicense.py +./scripts/relicense.py +./blas/BandTriangularSolver.h +./blas/PackedTriangularMatrixVector.h +./blas/complex_double.cpp +./blas/level2_real_impl.h +./blas/level1_cplx_impl.h +./blas/level1_impl.h +./blas/level1_real_impl.h +./blas/level3_impl.h +./blas/single.cpp +./blas/level2_cplx_impl.h +./blas/PackedSelfadjointProduct.h +./blas/Rank2Update.h +./blas/complex_single.cpp +./blas/PackedTriangularSolverVector.h +./blas/double.cpp +./blas/common.h +./blas/level2_impl.h +./blas/GeneralRank1Update.h + +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. + +---------------------------------------------------------------------- +Following applies to: +./doc/UsingIntelMKL.dox +./doc/UsingIntelMKL.dox +./Eigen/src/Eigenvalues/ComplexSchur_MKL.h +./Eigen/src/Eigenvalues/ComplexSchur_MKL.h +./Eigen/src/Eigenvalues/SelfAdjointEigenSolver_MKL.h +./Eigen/src/Eigenvalues/SelfAdjointEigenSolver_MKL.h +./Eigen/src/Eigenvalues/RealSchur_MKL.h +./Eigen/src/Eigenvalues/RealSchur_MKL.h +./Eigen/src/LU/arch/Inverse_SSE.h +./Eigen/src/LU/arch/Inverse_SSE.h +./Eigen/src/LU/PartialPivLU_MKL.h +./Eigen/src/LU/PartialPivLU_MKL.h +./Eigen/src/QR/HouseholderQR_MKL.h +./Eigen/src/QR/HouseholderQR_MKL.h +./Eigen/src/QR/ColPivHouseholderQR_MKL.h +./Eigen/src/QR/ColPivHouseholderQR_MKL.h +./Eigen/src/SVD/JacobiSVD_MKL.h +./Eigen/src/SVD/JacobiSVD_MKL.h +./Eigen/src/PardisoSupport/PardisoSupport.h +./Eigen/src/PardisoSupport/PardisoSupport.h +./Eigen/src/Core/Assign_MKL.h +./Eigen/src/Core/Assign_MKL.h +./Eigen/src/Core/products/SelfadjointMatrixVector_MKL.h +./Eigen/src/Core/products/SelfadjointMatrixVector_MKL.h +./Eigen/src/Core/products/GeneralMatrixVector_MKL.h +./Eigen/src/Core/products/GeneralMatrixVector_MKL.h +./Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h +./Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h +./Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h +./Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h +./Eigen/src/Core/products/GeneralMatrixMatrix_MKL.h +./Eigen/src/Core/products/GeneralMatrixMatrix_MKL.h +./Eigen/src/Core/products/TriangularMatrixVector_MKL.h +./Eigen/src/Core/products/TriangularMatrixVector_MKL.h +./Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h +./Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h +./Eigen/src/Core/products/TriangularSolverMatrix_MKL.h +./Eigen/src/Core/products/TriangularSolverMatrix_MKL.h +./Eigen/src/Core/util/MKL_support.h +./Eigen/src/Core/util/MKL_support.h +./Eigen/src/Cholesky/LLT_MKL.h +./Eigen/src/Cholesky/LLT_MKL.h + +/* + Copyright (c) 2011, Intel Corporation. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. * + Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the + distribution. * Neither the name of Intel Corporation nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +---------------------------------------------------------------------- +Following applies to: + everything under ./bench/btl + + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds +of works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, +family, or household purposes, or (2) anything designed or sold for +incorporation into a dwelling. In determining whether a product is a +consumer product, doubtful cases shall be resolved in favor of +coverage. For a particular product received by a particular user, +"normally used" refers to a typical or common use of that class of +product, regardless of the status of the particular user or of the way +in which the particular user actually uses, or expects or is expected +to use, the product. A product is a consumer product regardless of +whether the product has substantial commercial, industrial or +non-consumer uses, unless such uses represent the only significant +mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to +install and execute modified versions of a covered work in that User +Product from a modified version of its Corresponding Source. The +information must suffice to ensure that the continued functioning of +the modified object code is in no case prevented or interfered with +solely because modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include +a requirement to continue to provide support service, warranty, or +updates for a work that has been modified or installed by the +recipient, or for the User Product in which it has been modified or +installed. Access to a network may be denied when the modification +itself materially and adversely affects the operation of the network +or violates the rules and protocols for communication across the +network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material +you add to a covered work, you may (if authorized by the copyright +holders of that material) supplement the terms of this License with +terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions + of it) with contractual assumptions of liability to the recipient, + for any liability that these contractual assumptions directly + impose on those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement +or otherwise) that contradict the conditions of this License, they do +not excuse you from the conditions of this License. If you cannot +convey a covered work so as to satisfy simultaneously your obligations +under this License and any other pertinent obligations, then as a +consequence you may not convey it at all. For example, if you agree +to terms that obligate you to collect a royalty for further conveying +from those to whom you convey the Program, the only way you could +satisfy both those terms and this License would be to refrain entirely +from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT +WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND +PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE +DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR +CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN +WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES +AND/OR CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR +DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL +DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM +(INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED +INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF +THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER +OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these +terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or + modify it under the terms of the GNU General Public License as + published by the Free Software Foundation, either version 3 of the + License, or (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see + . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) This program comes + with ABSOLUTELY NO WARRANTY; for details type `show w'. This is + free software, and you are welcome to redistribute it under + certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the +appropriate parts of the General Public License. Of course, your +program's commands might be different; for a GUI interface, you would +use an "about box". + + You should also get your employer (if you work as a programmer) or +school, if any, to sign a "copyright disclaimer" for the program, if +necessary. For more information on this, and how to apply and follow +the GNU GPL, see . + + The GNU General Public License does not permit incorporating your +program into proprietary programs. If your program is a subroutine +library, you may consider it more useful to permit linking proprietary +applications with the library. If this is what you want to do, use +the GNU Lesser General Public License instead of this License. But +first, please read . + + +---------------------------------------------------------------------- +Following applies to: +./test/metis_support.cpp +./test/sparselu.cpp +./unsupported/test/mpreal/mpreal.h +./unsupported/Eigen/src/IterativeSolvers/IterationController.h +./unsupported/Eigen/src/IterativeSolvers/ConstrainedConjGrad.h +./unsupported/Eigen/src/Eigenvalues/ArpackSelfAdjointEigenSolver.h +./Eigen/src/OrderingMethods/Amd.h +./Eigen/src/SparseCholesky/SimplicialCholesky_impl.h + + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the +GNU General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this + license document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this + license document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of + this License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. + + +---------------------------------------------------------------------- +Following applies to: +./unsupported/Eigen/src/LevenbergMarquardt/LevenbergMarquardt.h +./unsupported/Eigen/src/LevenbergMarquardt/LMcovar.h +./unsupported/Eigen/src/LevenbergMarquardt/LMonestep.h +./unsupported/Eigen/src/LevenbergMarquardt/LMpar.h +./unsupported/Eigen/src/LevenbergMarquardt/LMqrsolv.h + +Minpack Copyright Notice (1999) University of Chicago. All rights +reserved + +Redistribution and use in source and binary forms, with or +without modification, are permitted provided that the +following conditions are met: + +1. Redistributions of source code must retain the above +copyright notice, this list of conditions and the following +disclaimer. + +2. Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following +disclaimer in the documentation and/or other materials +provided with the distribution. + +3. The end-user documentation included with the +redistribution, if any, must include the following +acknowledgment: + + "This product includes software developed by the + University of Chicago, as Operator of Argonne National + Laboratory. + +Alternately, this acknowledgment may appear in the software +itself, if and wherever such third-party acknowledgments +normally appear. + +4. WARRANTY DISCLAIMER. THE SOFTWARE IS SUPPLIED "AS IS" +WITHOUT WARRANTY OF ANY KIND. THE COPYRIGHT HOLDER, THE +UNITED STATES, THE UNITED STATES DEPARTMENT OF ENERGY, AND +THEIR EMPLOYEES: (1) DISCLAIM ANY WARRANTIES, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE +OR NON-INFRINGEMENT, (2) DO NOT ASSUME ANY LEGAL LIABILITY +OR RESPONSIBILITY FOR THE ACCURACY, COMPLETENESS, OR +USEFULNESS OF THE SOFTWARE, (3) DO NOT REPRESENT THAT USE OF +THE SOFTWARE WOULD NOT INFRINGE PRIVATELY OWNED RIGHTS, (4) +DO NOT WARRANT THAT THE SOFTWARE WILL FUNCTION +UNINTERRUPTED, THAT IT IS ERROR-FREE OR THAT ANY ERRORS WILL +BE CORRECTED. + +5. LIMITATION OF LIABILITY. IN NO EVENT WILL THE COPYRIGHT +HOLDER, THE UNITED STATES, THE UNITED STATES DEPARTMENT OF +ENERGY, OR THEIR EMPLOYEES: BE LIABLE FOR ANY INDIRECT, +INCIDENTAL, CONSEQUENTIAL, SPECIAL OR PUNITIVE DAMAGES OF +ANY KIND OR NATURE, INCLUDING BUT NOT LIMITED TO LOSS OF +PROFITS OR LOSS OF DATA, FOR ANY REASON WHATSOEVER, WHETHER +SUCH LIABILITY IS ASSERTED ON THE BASIS OF CONTRACT, TORT +(INCLUDING NEGLIGENCE OR STRICT LIABILITY), OR OTHERWISE, +EVEN IF ANY OF SAID PARTIES HAS BEEN WARNED OF THE +POSSIBILITY OF SUCH LOSS OR DAMAGES. diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/gpu_packet_math.patch b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/gpu_packet_math.patch new file mode 100644 index 000000000..fdc8961b9 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/gpu_packet_math.patch @@ -0,0 +1,25 @@ +diff -ru a/Eigen/src/Geometry/arch/Geometry_SSE.h b/Eigen/src/Geometry/arch/Geometry_SSE.h +--- a/Eigen/src/Geometry/arch/Geometry_SSE.h ++++ b/Eigen/src/Geometry/arch/Geometry_SSE.h +@@ -33,13 +33,14 @@ + Packet4f b = be.template packet(0); + Packet4f s1 = pmul(vec4f_swizzle1(a,1,2,0,2),vec4f_swizzle1(b,2,0,1,2)); + Packet4f s2 = pmul(vec4f_swizzle1(a,3,3,3,1),vec4f_swizzle1(b,0,1,2,1)); +- pstoret( +- &res.x(), +- padd(psub(pmul(a,vec4f_swizzle1(b,3,3,3,3)), +- pmul(vec4f_swizzle1(a,2,0,1,0), +- vec4f_swizzle1(b,1,2,0,0))), +- pxor(mask,padd(s1,s2)))); +- ++ pstoret( ++ &res.x(), ++ padd( ++ psub(pmul(a, vec4f_swizzle1(b, 3, 3, 3, 3)), ++ pmul(vec4f_swizzle1(a, 2, 0, 1, 0), ++ vec4f_swizzle1(b, 1, 2, 0, 0))), ++ pxor(mask, padd(s1, s2)))); ++ + return res; + } + }; diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint new file mode 100644 index 000000000..67cb111db --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint @@ -0,0 +1,58 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_FIXED_POINT_MODULE +#define EIGEN_CXX11_FIXED_POINT_MODULE + +#include +#include + +/** \defgroup CXX11_FixedPoint_Module Fixed Point Module + * + * This module provides common core features for all modules that + * explicitly depend on C++11. Currently, this is only the Tensor + * module. Note that at this stage, you should not need to include + * this module directly. + * + * It also provides a limited fallback for compilers that don't support + * CXX11 yet, such as nvcc. + * + * \code + * #include + * \endcode + */ + +#include "src/FixedPoint/FixedPointTypes.h" + +// Use optimized implementations whenever available +#if defined (EIGEN_VECTORIZE_AVX512DQ) || defined (EIGEN_VECTORIZE_AVX512BW) +#include "src/FixedPoint/PacketMathAVX512.h" +#include "src/FixedPoint/TypeCastingAVX512.h" + +#elif defined EIGEN_VECTORIZE_AVX2 +#define EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT +#define EIGEN_USE_OPTIMIZED_INT16_INT16_MAT_MAT_PRODUCT +#include "src/FixedPoint/PacketMathAVX2.h" +#include "src/FixedPoint/MatMatProductAVX2.h" +#include "src/FixedPoint/TypeCastingAVX2.h" + +#elif defined EIGEN_VECTORIZE_AVX +#include "src/FixedPoint/PacketMathAVX.h" + +#elif defined EIGEN_VECTORIZE_NEON +#define EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT +#include "src/FixedPoint/MatMatProductNEON.h" +#endif + +// Use the default implementation when no optimized code is available +#include "src/FixedPoint/MatMatProduct.h" +#include "src/FixedPoint/MatVecProduct.h" + + +#endif // EIGEN_CXX11_FIXED_POINT_MODULE diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/Tensor new file mode 100644 index 000000000..b45f30c69 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/Tensor @@ -0,0 +1,15 @@ + +#include "unsupported/Eigen/CXX11/Tensor" + +#ifdef _WIN32 +#ifndef SLEEP_FUNC_HEADER_GUARD +#define SLEEP_FUNC_HEADER_GUARD +inline void sleep(unsigned int seconds) { Sleep(1000*seconds); } +#endif + +// On Windows, Eigen will include Windows.h, which defines various +// macros that conflict with TensorFlow symbols. Undefine them here to +// prevent clashes. +#undef DeleteFile +#undef ERROR +#endif // _WIN32 diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool new file mode 100644 index 000000000..d2639af4d --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool @@ -0,0 +1 @@ +#include "unsupported/Eigen/CXX11/ThreadPool" diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h new file mode 100644 index 000000000..5cdc957c5 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h @@ -0,0 +1,340 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef CXX11_SRC_FIXEDPOINT_FIXEDPOINTTYPES_H_ +#define CXX11_SRC_FIXEDPOINT_FIXEDPOINTTYPES_H_ + +#include +#include + +namespace Eigen { + +// The mantissa part of the fixed point representation. See +// go/tensorfixedpoint for details +struct QInt8; +struct QUInt8; +struct QInt16; +struct QUInt16; +struct QInt32; + +template <> +struct NumTraits : GenericNumTraits {}; +template <> +struct NumTraits : GenericNumTraits {}; +template <> +struct NumTraits : GenericNumTraits {}; +template <> +struct NumTraits : GenericNumTraits {}; +template <> +struct NumTraits : GenericNumTraits {}; + +namespace internal { +template <> +struct scalar_product_traits { + enum { + // Cost = NumTraits::MulCost, + Defined = 1 + }; + typedef QInt32 ReturnType; +}; +} // namespace internal + +// Wrap the 8bit int into a QInt8 struct instead of using a typedef to prevent +// the compiler from silently type cast the mantissa into a bigger or a smaller +// representation. +struct QInt8 { + QInt8() {} + QInt8(const int8_t v) : value(v) {} + QInt8(const QInt32 v); + + operator int() const { return static_cast(value); } + + int8_t value; +}; + +struct QUInt8 { + QUInt8() {} + QUInt8(const uint8_t v) : value(v) {} + QUInt8(const QInt32 v); + + operator int() const { return static_cast(value); } + + uint8_t value; +}; + +struct QInt16 { + QInt16() {} + QInt16(const int16_t v) : value(v) {} + QInt16(const QInt32 v); + operator int() const { return static_cast(value); } + + int16_t value; +}; + +struct QUInt16 { + QUInt16() {} + QUInt16(const uint16_t v) : value(v) {} + QUInt16(const QInt32 v); + operator int() const { return static_cast(value); } + + uint16_t value; +}; + +struct QInt32 { + QInt32() {} + QInt32(const int8_t v) : value(v) {} + QInt32(const int32_t v) : value(v) {} + QInt32(const uint32_t v) : value(static_cast(v)) {} + QInt32(const QInt8 v) : value(v.value) {} + QInt32(const float v) : value(static_cast(lrint(v))) {} +#ifdef EIGEN_MAKING_DOCS + // Workaround to fix build on PPC. + QInt32(unsigned long v) : value(v) {} +#endif + + operator float() const { return static_cast(value); } + + int32_t value; +}; + +EIGEN_STRONG_INLINE QInt8::QInt8(const QInt32 v) + : value(v.value > 127 ? 127 : (v.value < -128 ? -128 : v.value)) {} +EIGEN_STRONG_INLINE QUInt8::QUInt8(const QInt32 v) + : value(v.value > 255 ? 255 : (v.value < 0 ? 0 : v.value)) {} +EIGEN_STRONG_INLINE QInt16::QInt16(const QInt32 v) + : value(v.value > 32767 ? 32767 : (v.value < -32768 ? -32768 : v.value)) {} +EIGEN_STRONG_INLINE QUInt16::QUInt16(const QInt32 v) + : value(v.value > 65535 ? 65535 : (v.value < 0 ? 0 : v.value)) {} + +// Basic widening 8-bit operations: This will be vectorized in future CLs. +EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QInt8 b) { + return QInt32(static_cast(a.value) * static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QUInt8 b) { + return QInt32(static_cast(a.value) * static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator+(const QInt8 a, const QInt8 b) { + return QInt32(static_cast(a.value) + static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt8 a, const QInt8 b) { + return QInt32(static_cast(a.value) - static_cast(b.value)); +} + +// Basic widening 16-bit operations: This will be vectorized in future CLs. +EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QInt16 b) { + return QInt32(static_cast(a.value) * static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QUInt16 b) { + return QInt32(static_cast(a.value) * static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator+(const QInt16 a, const QInt16 b) { + return QInt32(static_cast(a.value) + static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt16 a, const QInt16 b) { + return QInt32(static_cast(a.value) - static_cast(b.value)); +} + +// Mixed QInt32 op QInt8 operations. This will be vectorized in future CLs. +EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt8 b) { + return QInt32(a.value + static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator+(const QInt8 a, const QInt32 b) { + return QInt32(static_cast(a.value) + b.value); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt8 b) { + return QInt32(a.value - static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt8 a, const QInt32 b) { + return QInt32(static_cast(a.value) - b.value); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt8 b) { + return QInt32(a.value * static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QInt32 b) { + return QInt32(static_cast(a.value) * b.value); +} + +// Mixed QInt32 op QInt16 operations. This will be vectorized in future CLs. +EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt16 b) { + return QInt32(a.value + static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator+(const QInt16 a, const QInt32 b) { + return QInt32(static_cast(a.value) + b.value); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt16 b) { + return QInt32(a.value - static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt16 a, const QInt32 b) { + return QInt32(static_cast(a.value) - b.value); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt16 b) { + return QInt32(a.value * static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QInt32 b) { + return QInt32(static_cast(a.value) * b.value); +} + +// Mixed QInt32 op QUInt8 operations. This will be vectorized in future CLs. +EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QUInt8 b) { + return QInt32(a.value + static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator+(const QUInt8 a, const QInt32 b) { + return QInt32(static_cast(a.value) + b.value); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QUInt8 b) { + return QInt32(a.value - static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QUInt8 a, const QInt32 b) { + return QInt32(static_cast(a.value) - b.value); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QUInt8 b) { + return QInt32(a.value * static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QUInt8 a, const QInt32 b) { + return QInt32(static_cast(a.value) * b.value); +} + +// Mixed QInt32 op QUInt16 operations. This will be vectorized in future CLs. +EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QUInt16 b) { + return QInt32(a.value + static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator+(const QUInt16 a, const QInt32 b) { + return QInt32(static_cast(a.value) + b.value); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QUInt16 b) { + return QInt32(a.value - static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator-(const QUInt16 a, const QInt32 b) { + return QInt32(static_cast(a.value) - b.value); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QUInt16 b) { + return QInt32(a.value * static_cast(b.value)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const QUInt16 a, const QInt32 b) { + return QInt32(static_cast(a.value) * b.value); +} + +// Basic arithmetic operations on QInt32, which behaves like a int32_t. +EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt32 b) { + return a.value + b.value; +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt32 b) { + return a.value - b.value; +} +EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt32 b) { + return a.value * b.value; +} +EIGEN_STRONG_INLINE QInt32 operator/(const QInt32 a, const QInt32 b) { + return a.value / b.value; +} +EIGEN_STRONG_INLINE QInt32& operator+=(QInt32& a, const QInt32 b) { + a.value += b.value; + return a; +} +EIGEN_STRONG_INLINE QInt32& operator-=(QInt32& a, const QInt32 b) { + a.value -= b.value; + return a; +} +EIGEN_STRONG_INLINE QInt32& operator*=(QInt32& a, const QInt32 b) { + a.value *= b.value; + return a; +} +EIGEN_STRONG_INLINE QInt32& operator/=(QInt32& a, const QInt32 b) { + a.value /= b.value; + return a; +} +EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a) { return -a.value; } + +// Scaling QInt32 by double. We do the arithmetic in double because +// float only has 23 bits of mantissa, so casting QInt32 to float might reduce +// accuracy by discarding up to 7 (least significant) bits. +EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const double b) { + return static_cast(lrint(static_cast(a.value) * b)); +} +EIGEN_STRONG_INLINE QInt32 operator*(const double a, const QInt32 b) { + return static_cast(lrint(a * static_cast(b.value))); +} +EIGEN_STRONG_INLINE QInt32& operator*=(QInt32& a, const double b) { + a.value = static_cast(lrint(static_cast(a.value) * b)); + return a; +} + +// Comparisons +EIGEN_STRONG_INLINE bool operator==(const QInt8 a, const QInt8 b) { + return a.value == b.value; +} +EIGEN_STRONG_INLINE bool operator==(const QUInt8 a, const QUInt8 b) { + return a.value == b.value; +} +EIGEN_STRONG_INLINE bool operator==(const QInt16 a, const QInt16 b) { + return a.value == b.value; +} +EIGEN_STRONG_INLINE bool operator==(const QUInt16 a, const QUInt16 b) { + return a.value == b.value; +} +EIGEN_STRONG_INLINE bool operator==(const QInt32 a, const QInt32 b) { + return a.value == b.value; +} + +EIGEN_STRONG_INLINE bool operator<(const QInt8 a, const QInt8 b) { + return a.value < b.value; +} +EIGEN_STRONG_INLINE bool operator<(const QUInt8 a, const QUInt8 b) { + return a.value < b.value; +} +EIGEN_STRONG_INLINE bool operator<(const QInt16 a, const QInt16 b) { + return a.value < b.value; +} +EIGEN_STRONG_INLINE bool operator<(const QUInt16 a, const QUInt16 b) { + return a.value < b.value; +} +EIGEN_STRONG_INLINE bool operator<(const QInt32 a, const QInt32 b) { + return a.value < b.value; +} + +EIGEN_STRONG_INLINE bool operator>(const QInt8 a, const QInt8 b) { + return a.value > b.value; +} +EIGEN_STRONG_INLINE bool operator>(const QUInt8 a, const QUInt8 b) { + return a.value > b.value; +} +EIGEN_STRONG_INLINE bool operator>(const QInt16 a, const QInt16 b) { + return a.value > b.value; +} +EIGEN_STRONG_INLINE bool operator>(const QUInt16 a, const QUInt16 b) { + return a.value > b.value; +} +EIGEN_STRONG_INLINE bool operator>(const QInt32 a, const QInt32 b) { + return a.value > b.value; +} + +EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt8 a) { + os << static_cast(a.value); + return os; +} +EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QUInt8 a) { + os << static_cast(a.value); + return os; +} +EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt16 a) { + os << static_cast(a.value); + return os; +} +EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QUInt16 a) { + os << static_cast(a.value); + return os; +} +EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt32 a) { + os << a.value; + return os; +} + +} // namespace Eigen + +#endif // CXX11_SRC_FIXEDPOINT_FIXEDPOINTTYPES_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h new file mode 100644 index 000000000..3f93f9f73 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h @@ -0,0 +1,345 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef CXX11_SRC_FIXEDPOINT_MATMATPRODUCT_H_ +#define CXX11_SRC_FIXEDPOINT_MATMATPRODUCT_H_ + +namespace Eigen { +namespace internal { + +// Accumulate the product of 2 QInt8 inputs on 32 bits to prevent +// overflows +template <> +struct scalar_product_traits { + enum { Defined = 1 }; + typedef QInt32 ReturnType; +}; + +// Accumulate the product of 2 QInt16 inputs on 32 bits to prevent +// overflows +template <> +struct scalar_product_traits { + enum { Defined = 1 }; + typedef QInt32 ReturnType; +}; + +// Accumulate the product of QInt8 inputs with QUint8 inputs on 32 bits +// to prevent overflows +template <> +struct scalar_product_traits { + enum { Defined = 1 }; + typedef QInt32 ReturnType; +}; + +// Accumulate the product of QUInt8 inputs with Qint8 inputs on 32 bits +// to prevent overflows +template <> +struct scalar_product_traits { + enum { Defined = 1 }; + typedef QInt32 ReturnType; +}; + +// Description of the product implementation. It's pretty simple now since +// nothing is vectorized yet. +// This definition tackle the case where both lhs and rhs are encoded using +// signed 8bit integers +#ifndef EIGEN_USE_OPTIMIZED_INT8_INT8_MAT_MAT_PRODUCT + +template +class gebp_traits { + public: + typedef QInt8 LhsScalar; + typedef QInt8 RhsScalar; + typedef QInt32 ResScalar; + + typedef typename packet_traits::type LhsPacket; + typedef LhsPacket LhsPacket4Packing; + + enum { + // register block size along the M and N directions + // One for the current implementation + nr = 1, + mr = 1, + // Progress made at each iteration of the product loop + // also 1 for the current implementation + LhsProgress = 1, + RhsProgress = 1 + }; +}; + +// The signed 8bit Mat-Mat product itself. +template +struct gebp_kernel { + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QInt8* blockA, + const QInt8* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); +}; + +template +EIGEN_DONT_INLINE void gebp_kernel:: +operator()(const DataMapper& res, const QInt8* blockA, const QInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { + EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + + eigen_assert(alpha.value == 1); + eigen_assert(strideA == -1); + eigen_assert(strideB == -1); + eigen_assert(offsetA == 0); + eigen_assert(offsetB == 0); + + eigen_assert(rows > 0); + eigen_assert(cols > 0); + eigen_assert(depth > 0); + eigen_assert(blockA); + eigen_assert(blockB); + + for (Index j = 0; j < cols; ++j) { + Index startB = j * depth; + + for (Index i = 0; i < rows; ++i) { + Index startA = i * depth; + + for (Index k = 0; k < depth; ++k) { + res(i, j) += blockA[startA + k] * blockB[startB + k]; + } + } + } +} +#endif + +// This definition tackle the case where the lhs is encoded using signed 8bit +// integers and the rhs using unsigned 8bit integers. +#ifndef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT +template +class gebp_traits { + public: + typedef QInt8 LhsScalar; + typedef QUInt8 RhsScalar; + typedef QInt32 ResScalar; + + typedef typename packet_traits::type LhsPacket; + typedef LhsPacket LhsPacket4Packing; + + enum { + // register block size along the M and N directions + // One for the current implementation + nr = 1, + mr = 1, + // Progress made at each iteration of the product loop + // also 1 for the current implementation + LhsProgress = 1, + RhsProgress = 1 + }; +}; + +// Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs +template +struct gebp_kernel { + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QInt8* blockA, + const QUInt8* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); +}; + +template +EIGEN_DONT_INLINE void gebp_kernel:: +operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { + EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + + eigen_assert(alpha.value == 1); + eigen_assert(strideA == -1); + eigen_assert(strideB == -1); + eigen_assert(offsetA == 0); + eigen_assert(offsetB == 0); + + eigen_assert(rows > 0); + eigen_assert(cols > 0); + eigen_assert(depth > 0); + eigen_assert(blockA); + eigen_assert(blockB); + + for (Index j = 0; j < cols; ++j) { + Index startB = j * depth; + + for (Index i = 0; i < rows; ++i) { + Index startA = i * depth; + + for (Index k = 0; k < depth; ++k) { + res(i, j) += blockA[startA + k] * blockB[startB + k]; + } + } + } +} +#endif + +// This definition tackle the case where the khs is encoded using unsigned 8bit +// integers and the rhs using signed 8bit integers. +#ifndef EIGEN_USE_OPTIMIZED_UINT8_INT8_MAT_MAT_PRODUCT +template +class gebp_traits { + public: + typedef QUInt8 LhsScalar; + typedef QInt8 RhsScalar; + typedef QInt32 ResScalar; + + typedef typename packet_traits::type LhsPacket; + typedef LhsPacket LhsPacket4Packing; + + enum { + // register block size along the M and N directions + // One for the current implementation + nr = 1, + mr = 1, + // Progress made at each iteration of the product loop + // also 1 for the current implementation + LhsProgress = 1, + RhsProgress = 1 + }; +}; + +// Mat-Mat product of an unsigned 8bit lhs with a signed 8bit rhs +template +struct gebp_kernel { + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QUInt8* blockA, + const QInt8* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); +}; + +template +EIGEN_DONT_INLINE void gebp_kernel:: +operator()(const DataMapper& res, const QUInt8* blockA, const QInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { + EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + + eigen_assert(alpha.value == 1); + eigen_assert(strideA == -1); + eigen_assert(strideB == -1); + eigen_assert(offsetA == 0); + eigen_assert(offsetB == 0); + + eigen_assert(rows > 0); + eigen_assert(cols > 0); + eigen_assert(depth > 0); + eigen_assert(blockA); + eigen_assert(blockB); + + for (Index j = 0; j < cols; ++j) { + Index startB = j * depth; + + for (Index i = 0; i < rows; ++i) { + Index startA = i * depth; + + for (Index k = 0; k < depth; ++k) { + res(i, j) += blockA[startA + k] * blockB[startB + k]; + } + } + } +} +#endif + +#ifndef EIGEN_USE_OPTIMIZED_INT16_INT16_MAT_MAT_PRODUCT + +template +class gebp_traits { + public: + typedef QInt16 LhsScalar; + typedef QInt16 RhsScalar; + typedef QInt32 ResScalar; + + typedef typename packet_traits::type LhsPacket; + typedef LhsPacket LhsPacket4Packing; + + enum { + // register block size along the M and N directions + // One for the current implementation + nr = 1, + mr = 1, + // Progress made at each iteration of the product loop + // also 1 for the current implementation + LhsProgress = 1, + RhsProgress = 1 + }; +}; + +// The signed 16bit Mat-Mat product itself. +template +struct gebp_kernel { + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QInt16* blockA, + const QInt16* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); +}; + +template +EIGEN_DONT_INLINE void gebp_kernel:: +operator()(const DataMapper& res, const QInt16* blockA, const QInt16* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { + EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + + eigen_assert(alpha.value == 1); + eigen_assert(strideA == -1); + eigen_assert(strideB == -1); + eigen_assert(offsetA == 0); + eigen_assert(offsetB == 0); + + eigen_assert(rows > 0); + eigen_assert(cols > 0); + eigen_assert(depth > 0); + eigen_assert(blockA); + eigen_assert(blockB); + + for (Index j = 0; j < cols; ++j) { + Index startB = j * depth; + + for (Index i = 0; i < rows; ++i) { + Index startA = i * depth; + + for (Index k = 0; k < depth; ++k) { + res(i, j) += blockA[startA + k] * blockB[startB + k]; + } + } + } +} +#endif + +} // namespace internal +} // namespace Eigen + +#endif // CXX11_SRC_FIXEDPOINT_MATMATPRODUCT_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h new file mode 100644 index 000000000..b06c33521 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h @@ -0,0 +1,2292 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner +// Copyright (C) 2015 Matthew Sarett +// Copyright (C) 2016 Nishant Patil +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef CXX11_SRC_FIXEDPOINT_MATMATPRODUCTAVX2_H_ +#define CXX11_SRC_FIXEDPOINT_MATMATPRODUCTAVX2_H_ + +namespace Eigen { +namespace internal { + +// AVX2 optimized implementation of Mat-Mat product. +// LHS is encoded using signed 16-bit integers. +// RHS is encoded using signed 16-bit integers. +#ifdef EIGEN_USE_OPTIMIZED_INT16_INT16_MAT_MAT_PRODUCT + +// Define quantized traits +template +class gebp_traits { + public: + typedef QInt16 LhsScalar; + typedef QInt16 RhsScalar; + typedef QInt32 ResScalar; + + typedef typename packet_traits::type LhsPacket; + typedef LhsPacket LhsPacket4Packing; + + enum { + // Define register blocking scheme. + nr = 16, + mr = 16, + kr = 4, + // Ignore progress tracking per loop iteration. + LhsProgress = -1, + RhsProgress = -1 + }; +}; + +// Specialized blocking for quantized implementations. +// Used by TensorContractionThreadPool, inputs must have dimensions that are +// multiples of 32. +template +class TensorContractionBlocking { + public: + TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) + : kc_(((k + 15) / 16) * 16), + mc_(((m + 15) / 16) * 16), + nc_(((n + 15) / 16) * 16) { + eigen_assert(mc_ % 16 == 0); + eigen_assert(kc_ % 16 == 0); + if (!k || !m || !n) { + return; + } + + if (ShardingType == ShardByCol) { + eigen_assert(nc_ % 16 == 0); + nc_ = (((nc_ / num_threads) + 15) / 16) * 16; + } else { + eigen_assert(nc_ % 16 == 0); + mc_ = (((mc_ / num_threads) + 15) / 16) * 16; + } + } + + EIGEN_ALWAYS_INLINE Index kc() const { return kc_; } + EIGEN_ALWAYS_INLINE Index mc() const { return mc_; } + EIGEN_ALWAYS_INLINE Index nc() const { return nc_; } + + private: + Index kc_; + Index mc_; + Index nc_; +}; + +// Specialized blocking for quantized implementations. +// Used by TensorContraction and GeneralMatrixMatrix, inputs are padded to +// multiples of 32. +template +class gemm_blocking_space + : public level3_blocking { + DenseIndex m_sizeA; + DenseIndex m_sizeB; + + public: + gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth, + DenseIndex /*num_threads*/, bool /*l3_blocking*/) { + this->m_mc = ((rows + 15) / 16) * 16; + this->m_nc = ((cols + 15) / 16) * 16; + this->m_kc = ((depth + 15) / 16) * 16; + m_sizeA = this->m_mc * this->m_kc; + m_sizeB = this->m_kc * this->m_nc; + } + void allocateA() { + if (this->m_blockA == 0) this->m_blockA = aligned_new(m_sizeA); + } + void allocateB() { + if (this->m_blockB == 0) this->m_blockB = aligned_new(m_sizeB); + } + void allocateAll() { + allocateA(); + allocateB(); + } + ~gemm_blocking_space() { + aligned_delete(this->m_blockA, m_sizeA); + aligned_delete(this->m_blockB, m_sizeB); + } +}; + +// Below are the fully optimized versions that are correct only for sizes that +// are multiple of 16. It is about a 10% performance benefit to keep these +// implementations separate. + +// Arrange a block of the left input matrix in contiguous memory. +// +// Given column major input (A0 beside A1 in memory): +// A0 B0 C0 D0 E0 F0 G0 H0 ... +// A1 B1 C1 D1 E1 F1 G1 H1 ... +// A2 B2 C2 D2 E2 F2 G2 H2 ... +// A3 B3 C3 D3 E3 F3 G3 H3 ... +// A4 B4 C4 D4 E4 F4 G4 H4 ... +// A5 B5 C5 D5 E5 F5 G5 H5 ... +// A6 B6 C6 D6 E6 F6 G6 H6 ... +// A7 B7 C7 D7 E7 F7 G7 H7 ... +// A8 ... +// ... +// +// Packing with m = 8 yields row major output (A0 beside B0 in memory): +// A0 B0 +// A1 B1 +// A2 B2 +// A3 B3 +// A4 B4 +// A5 B5 +// A6 B6 +// A7 B7 +// ... +// +// The purpose is to collect m rows of size k. Two elements of the same +// row are arranged contiguously because madd performs an adjacent addition +// in the kernel. + +template +struct gemm_pack_lhs { + EIGEN_DONT_INLINE void operator()(QInt16* blockA, const DataMapper& lhs, + Index depth, Index rows, Index stride = 0, + Index offset = 0); +}; + +template +EIGEN_DONT_INLINE void gemm_pack_lhs:: +operator()(QInt16* blockA, const DataMapper& lhs, Index depth, Index rows, + Index stride, Index offset) { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + typedef typename packet_traits::type Packet; + + // Use alternate function for weird sizes + if (rows % 16 != 0 || depth % 16 != 0) { + assert(false && + "only depths and rows that are a multiple of 16 are currently " + "supported"); + // gemm_pack_lhs_any lhs_pack; + // return lhs_pack(blockA, lhs, depth, rows, stride, offset); + } + + // Get vector pointer + __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA); + + // Pack rows in sets of 16 + for (Index m = 0; m < rows; m += 16) { + // Pack depth in sets of 4 + for (Index k = 0; k < depth; k += 4) { + // Load vectors + __m256i L_A = lhs.template loadPacket(m, k); + __m256i L_B = lhs.template loadPacket(m, k + 1); + __m256i L_C = lhs.template loadPacket(m, k + 2); + __m256i L_D = lhs.template loadPacket(m, k + 3); + + // Rearrange the inputs as required by the kernel + __m256i L_AB0_AB7 = _mm256_unpacklo_epi16(L_A, L_B); + __m256i L_AB8_AB15 = _mm256_unpackhi_epi16(L_A, L_B); + __m256i L_CD0_CD7 = _mm256_unpacklo_epi16(L_C, L_D); + __m256i L_CD8_CD15 = _mm256_unpackhi_epi16(L_C, L_D); + + __m256i L_AD0 = _mm256_permute2x128_si256(L_AB0_AB7, L_AB8_AB15, 0x20); + _mm256_store_si256(blockA_256++, L_AD0); + __m256i L_AD8 = _mm256_permute2x128_si256(L_CD0_CD7, L_CD8_CD15, 0x20); + _mm256_store_si256(blockA_256++, L_AD8); + __m256i L_AD16 = _mm256_permute2x128_si256(L_AB0_AB7, L_AB8_AB15, 0x31); + _mm256_store_si256(blockA_256++, L_AD16); + __m256i L_AD24 = _mm256_permute2x128_si256(L_CD0_CD7, L_CD8_CD15, 0x31); + _mm256_store_si256(blockA_256++, L_AD24); + } + } +} + +// Arrange a block of the right input matrix in contiguous memory. +// +// Given column major input (A0 beside A1 in memory): +// A0 B0 C0 D0 E0 F0 G0 H0 ... +// A1 B1 C1 D1 E1 F1 G1 H1 ... +// A2 B2 C2 D2 E2 F2 G2 H2 ... +// A3 B3 C3 D3 E3 F3 G3 H3 ... +// A4 B4 C4 D4 E4 F4 G4 H4 ... +// A5 B5 C5 D5 E5 F5 G5 H5 ... +// A6 B6 C6 D6 E6 F6 G6 H6 ... +// A7 B7 C7 D7 E7 F7 G7 H7 ... +// A8 ... +// ... +// Packing yields row major output (A0 beside A1 in memory): +// A0 A1 A2 A3 A4 A5 A6 A7 +// B0 B1 B2 B3 B4 B5 B6 B7 +// ... +// +// At least two elements of the same col are arranged contiguously because +// maddubs and madd both perform an adjacent addition in the kernel. We can +// save work by leaving 4 adjacent elements because kr = 4. +// The purpose is to collect n cols of size k. Two elements of the same +// col are arranged contiguously because madd performs an adjacent addition +// in the kernel. +template +struct gemm_pack_rhs { + EIGEN_DONT_INLINE void operator()(QInt16* blockB, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0); +}; + +template +EIGEN_DONT_INLINE void +gemm_pack_rhs:: +operator()(QInt16* blockB, const DataMapper& rhs, Index depth, Index cols, + Index stride, Index offset) { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + typedef typename packet_traits::type Packet; + + // Use alternate function for weird sizes + if (cols % 16 != 0 || depth % 16 != 0) { + assert(false && + "only depths and cols that are a multiple of 16 are currently " + "supported"); + // gemm_pack_rhs_any rhs_pack; + // return rhs_pack(blockB, rhs, depth, cols, stride, offset); + } + + // Get vector pointer + __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB); + + // Perform a step of the packing for 4 columns + __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_4, R_AD_8, R_AD_12; +#define PACK_STEP \ + R_AB_L = _mm256_unpacklo_epi64(R_A, R_B); \ + R_CD_L = _mm256_unpacklo_epi64(R_C, R_D); \ + R_AB_H = _mm256_unpackhi_epi64(R_A, R_B); \ + R_CD_H = _mm256_unpackhi_epi64(R_C, R_D); \ + R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20); \ + R_AD_8 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31); \ + R_AD_4 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20); \ + R_AD_12 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31); \ + _mm256_store_si256(blockB_256, R_AD_0); \ + _mm256_store_si256(blockB_256 + 4, R_AD_4); \ + _mm256_store_si256(blockB_256 + 8, R_AD_8); \ + _mm256_store_si256(blockB_256 + 12, R_AD_12); \ + blockB_256++; + + // Pack cols in sets of 16 + for (Index n = 0; n < cols; n += 16) { + // Pack depth in sets of 16 + for (Index k = 0; k < depth; k += 16) { + __m256i R_A = rhs.template loadPacket(k, n); + __m256i R_B = rhs.template loadPacket(k, n + 1); + __m256i R_C = rhs.template loadPacket(k, n + 2); + __m256i R_D = rhs.template loadPacket(k, n + 3); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 4); + R_B = rhs.template loadPacket(k, n + 5); + R_C = rhs.template loadPacket(k, n + 6); + R_D = rhs.template loadPacket(k, n + 7); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 8); + R_B = rhs.template loadPacket(k, n + 9); + R_C = rhs.template loadPacket(k, n + 10); + R_D = rhs.template loadPacket(k, n + 11); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 12); + R_B = rhs.template loadPacket(k, n + 13); + R_C = rhs.template loadPacket(k, n + 14); + R_D = rhs.template loadPacket(k, n + 15); + PACK_STEP; + + blockB_256 += 12; + } + } +#undef PACK_STEP +} + +// Perform the actual multiplication on packed inputs +template +struct gebp_kernel { + typedef typename DataMapper::LinearMapper LinearMapper; + + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QInt16* blockA, + const QInt16* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); +}; + +template +EIGEN_DONT_INLINE void gebp_kernel:: +operator()(const DataMapper& res, const QInt16* blockA, const QInt16* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { + EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + eigen_assert(alpha.value == 1); + eigen_assert(strideA == -1); + eigen_assert(strideB == -1); + eigen_assert(offsetA == 0); + eigen_assert(offsetB == 0); + eigen_assert(rows > 0); + eigen_assert(cols > 0); + eigen_assert(depth > 0); + eigen_assert(blockA); + eigen_assert(blockB); + + // Use alternate function for weird sizes + if (rows % 16 != 0 || cols % 16 != 0 || depth % 16 != 0) { + assert(false && + "only depths, cols and rows that are a multiple of 16 are currently " + "supported"); + // gebp_kernel_any gebp; + // return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA, + // strideB, offsetA, offsetB); + } + + // Create result block + QInt32* blockO = aligned_new(16 * 16); + memset(blockO, 0, 16 * 16 * sizeof(QInt32)); + + // Get vectorized pointers + __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO); + const __m256i* blockA_256 = reinterpret_cast(blockA); + const __m256i* blockB_256 = reinterpret_cast(blockB); + + // Loop over blocks of 16 columns + for (Index n = 0; n < cols; n += 16) { + // Reset index into blockA + Index indexL = 0; + // Loop over blocks of 16 rows + for (Index m = 0; m < rows; m += 16) { + // Reset index into blockB + Index indexR = n / 16 * depth; + // Loop over blocks of 4 on depth + for (Index k = 0; k < depth; k += 4) { + // Load inputs + __m256i L_AD0 = blockA_256[indexL++]; + __m256i L_AD8 = blockA_256[indexL++]; + __m256i L_EH0 = blockA_256[indexL++]; + __m256i L_EH8 = blockA_256[indexL++]; + + __m256i R_AH0 = blockB_256[indexR++]; + __m256i R_AH4 = blockB_256[indexR++]; + __m256i R_AH8 = blockB_256[indexR++]; + __m256i R_AH12 = blockB_256[indexR++]; + + // Declare variables used in COMPUTE_STEP + __m256i P_32_A, P_32_B, P_32; + +#define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET) \ + P_32_A = _mm256_madd_epi16(R_INPUT_A, L_AD0); \ + P_32_B = _mm256_madd_epi16(R_INPUT_B, L_AD8); \ + P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ + _mm256_store_si256( \ + blockO_256 + 2 * OFFSET, \ + _mm256_add_epi32(_mm256_load_si256(blockO_256 + 2 * OFFSET), P_32)); \ + \ + P_32_A = _mm256_madd_epi16(R_INPUT_A, L_EH0); \ + P_32_B = _mm256_madd_epi16(R_INPUT_B, L_EH8); \ + P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ + _mm256_store_si256( \ + blockO_256 + 2 * OFFSET + 1, \ + _mm256_add_epi32(_mm256_load_si256(blockO_256 + 2 * OFFSET + 1), P_32)); + + // Permute and shuffle to copy a single value across the entire vector + // Then compute the multiplication + // Replicate lower 128-bits of R_AH0 across both lanes + __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00); + // Copy first two elements of R_AH0 across entire vector + __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + // Copy second two elements of R_AH0 across entire vector + __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + + COMPUTE_STEP(R_AD0, R_EH0, 0); + __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 1); + + // Replicate upper 128-bits of R_AH0 across both lanes + R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11); + __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 2); + __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 3); + + R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 4); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 5); + R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 6); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 7); + + R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 8); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 9); + R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 10); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 11); + + R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 12); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 13); + R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 14); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 15); + +#undef COMPUTE_STEP + } + + // Transfer the results to the result matrix + Index i = 0; + for (Index j = n; j < n + 16; j++) { + LinearMapper r0 = res.getLinearMapper(m, j); + LinearMapper r1 = res.getLinearMapper(m + 8, j); + typedef typename packet_traits::type Packet; + r0.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r0.template loadPacket(0))); + r1.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r1.template loadPacket(0))); + } + + // Zero the result block so it can be reused + memset(blockO, 0, 16 * 16 * sizeof(QInt32)); + } + } + aligned_delete(blockO, 16 * 16); +} + +#endif + +// AVX2 optimized implementation of Mat-Mat product. +// LHS is encoded using signed 8-bit integers. +// RHS is encoded using unsigned 8-bit integers. +#ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT + +// Define quantized traits +template +class gebp_traits { + public: + typedef QInt8 LhsScalar; + typedef QUInt8 RhsScalar; + typedef QInt32 ResScalar; + + typedef typename packet_traits::type LhsPacket; + typedef LhsPacket LhsPacket4Packing; + + enum { + // Define register blocking scheme. + nr = 32, + mr = 32, + kr = 8, + // Ignore progress tracking per loop iteration. + LhsProgress = -1, + RhsProgress = -1 + }; +}; + +// Specialized blocking for quantized implementations. +// Used by TensorContractionThreadPool, inputs must have dimensions that are +// multiples of 32. +template +class TensorContractionBlocking< + ResScalar, + TensorContractionInputMapper< + QInt8, Index, Lhs, LeftTensor, left_nocontract_t, left_contract_t, 32, + left_inner_dim_contiguous, left_inner_dim_reordered, LeftAlignment>, + TensorContractionInputMapper, + Index, ShardingType> { + public: + typedef QInt8 LhsScalar; + typedef QUInt8 RhsScalar; + + TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) + : kc_(k), mc_(m), nc_(n) { + eigen_assert(m % 32 == 0); + eigen_assert(k % 32 == 0); + if (!k || !m || !n) { + return; + } + + if (ShardingType == ShardByCol) { + eigen_assert(n % 32 == 0); + nc_ = (((n / num_threads) + 31) / 32) * 32; + } else { + eigen_assert(n % 32 == 0 || n == 1); + // Special case to avoid breaking the unimplemented matrix-vector case + if (n == 1) { + nc_ = 32; + } + mc_ = (((m / num_threads) + 31) / 32) * 32; + } + } + + EIGEN_ALWAYS_INLINE Index kc() const { return kc_; } + EIGEN_ALWAYS_INLINE Index mc() const { return mc_; } + EIGEN_ALWAYS_INLINE Index nc() const { return nc_; } + + private: + Index kc_; + Index mc_; + Index nc_; +}; + +// Specialized blocking for quantized implementations. +// Used by TensorContraction and GeneralMatrixMatrix, inputs are padded to +// multiples of 32. +template +class gemm_blocking_space + : public level3_blocking { + DenseIndex m_sizeA; + DenseIndex m_sizeB; + + public: + gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth, + DenseIndex /*num_threads*/, bool /*l3_blocking*/) { + this->m_mc = ((rows + 31) / 32) * 32; + this->m_nc = ((cols + 31) / 32) * 32; + this->m_kc = ((depth + 31) / 32) * 32; + m_sizeA = this->m_mc * this->m_kc; + m_sizeB = this->m_kc * this->m_nc; + } + void allocateA() { + if (this->m_blockA == 0) this->m_blockA = aligned_new(m_sizeA); + } + void allocateB() { + if (this->m_blockB == 0) this->m_blockB = aligned_new(m_sizeB); + } + void allocateAll() { + allocateA(); + allocateB(); + } + ~gemm_blocking_space() { + aligned_delete(this->m_blockA, m_sizeA); + aligned_delete(this->m_blockB, m_sizeB); + } +}; + +template +class gemm_blocking_space + : public level3_blocking { + DenseIndex m_sizeA; + DenseIndex m_sizeB; + + public: + gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth, + DenseIndex /*num_threads*/, bool /*l3_blocking*/) { + this->m_mc = ((rows + 31) / 32) * 32; + this->m_nc = ((cols + 31) / 32) * 32; + this->m_kc = ((depth + 31) / 32) * 32; + m_sizeA = this->m_mc * this->m_kc; + m_sizeB = this->m_kc * this->m_nc; + } + void allocateA() { + if (this->m_blockA == 0) this->m_blockA = aligned_new(m_sizeA); + } + void allocateB() { + if (this->m_blockB == 0) this->m_blockB = aligned_new(m_sizeB); + } + void allocateAll() { + allocateA(); + allocateB(); + } + ~gemm_blocking_space() { + aligned_delete(this->m_blockA, m_sizeA); + aligned_delete(this->m_blockB, m_sizeB); + } +}; + +// Alternate templates for any input sizes +template +struct gemm_pack_lhs_any; +template +struct gemm_pack_lhs_any { + EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs, + Index depth, Index rows, Index stride = 0, + Index offset = 0); +}; + +template +struct gemm_pack_rhs_any; +template +struct gemm_pack_rhs_any { + EIGEN_DONT_INLINE void operator()(QUInt8* blockB, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0); +}; + +template +struct gebp_kernel_any; +template +struct gebp_kernel_any { + typedef typename DataMapper::LinearMapper LinearMapper; + + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QInt8* blockA, + const QUInt8* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); +}; + +// Alternate implementations for any input sizes +template +EIGEN_DONT_INLINE void gemm_pack_lhs_any:: +operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, + Index stride, Index offset) { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + typedef typename packet_traits::type Packet; + + // Get vector pointer + __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA); + + // Get even multiples of the dimensions + Index rows_32 = (rows / 32) * 32; + Index depth_8 = (depth / 8) * 8; + + // Get padding for when depth is not a multiple of 32 + int padding = 0; + if (depth % 32 != 0) { + int depth_32 = (depth / 32) * 32; + int extra_depth = depth - depth_32; + int extra_depth_8 = ((extra_depth + 7) / 8) * 8; + padding = 32 - extra_depth_8; + } + + // Pack rows in sets of 32 + for (Index m = 0; m < rows_32; m += 32) { + // Pack depth in sets of 8 + for (Index k = 0; k < depth_8; k += 8) { + // Load vectors + __m256i L_A = lhs.template loadPacket(m, k); + __m256i L_B = lhs.template loadPacket(m, k + 1); + + // Interleave 8-bit elements + __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); + __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); + + __m256i L_C = lhs.template loadPacket(m, k + 2); + __m256i L_D = lhs.template loadPacket(m, k + 3); + __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); + __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); + + // Interleave 16-bit elements + __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16); + __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16); + + // Use permute before we store to cross 128-bit lanes + __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20); + _mm256_store_si256(blockA_256++, L_AD0); + + // Complete packing for 32 x 8 block + __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31); + __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20); + _mm256_store_si256(blockA_256++, L_AD8); + _mm256_store_si256(blockA_256++, L_AD16); + __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); + _mm256_store_si256(blockA_256++, L_AD24); + __m256i L_E = lhs.template loadPacket(m, k + 4); + __m256i L_F = lhs.template loadPacket(m, k + 5); + __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); + __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); + __m256i L_G = lhs.template loadPacket(m, k + 6); + __m256i L_H = lhs.template loadPacket(m, k + 7); + __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); + __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); + __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20); + _mm256_store_si256(blockA_256++, L_EH0); + __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31); + __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20); + _mm256_store_si256(blockA_256++, L_EH8); + _mm256_store_si256(blockA_256++, L_EH16); + __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31); + _mm256_store_si256(blockA_256++, L_EH24); + } + + // Finish the k dimension, padding with zeros + if (depth_8 < depth) { + __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H; + switch (depth - depth_8) { + case 1: + L_A = lhs.template loadPacket(m, depth_8); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 2: + L_A = lhs.template loadPacket(m, depth_8); + L_B = lhs.template loadPacket(m, depth_8 + 1); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 3: + L_A = lhs.template loadPacket(m, depth_8); + L_B = lhs.template loadPacket(m, depth_8 + 1); + L_C = lhs.template loadPacket(m, depth_8 + 2); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 4: + L_A = lhs.template loadPacket(m, depth_8); + L_B = lhs.template loadPacket(m, depth_8 + 1); + L_C = lhs.template loadPacket(m, depth_8 + 2); + L_D = lhs.template loadPacket(m, depth_8 + 3); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 5: + L_A = lhs.template loadPacket(m, depth_8); + L_B = lhs.template loadPacket(m, depth_8 + 1); + L_C = lhs.template loadPacket(m, depth_8 + 2); + L_D = lhs.template loadPacket(m, depth_8 + 3); + L_E = lhs.template loadPacket(m, depth_8 + 4); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 6: + L_A = lhs.template loadPacket(m, depth_8); + L_B = lhs.template loadPacket(m, depth_8 + 1); + L_C = lhs.template loadPacket(m, depth_8 + 2); + L_D = lhs.template loadPacket(m, depth_8 + 3); + L_E = lhs.template loadPacket(m, depth_8 + 4); + L_F = lhs.template loadPacket(m, depth_8 + 5); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + break; + case 7: + L_A = lhs.template loadPacket(m, depth_8); + L_B = lhs.template loadPacket(m, depth_8 + 1); + L_C = lhs.template loadPacket(m, depth_8 + 2); + L_D = lhs.template loadPacket(m, depth_8 + 3); + L_E = lhs.template loadPacket(m, depth_8 + 4); + L_F = lhs.template loadPacket(m, depth_8 + 5); + L_G = lhs.template loadPacket(m, depth_8 + 6); + L_H = _mm256_setzero_si256(); + break; + } + + // Interleave 8-bit elements + __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); + __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); + + __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); + __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); + + // Interleave 16-bit elements + __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16); + __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16); + + // Use permute before we store to cross 128-bit lanes + __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20); + _mm256_store_si256(blockA_256++, L_AD0); + + // Complete packing + __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31); + __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20); + _mm256_store_si256(blockA_256++, L_AD8); + _mm256_store_si256(blockA_256++, L_AD16); + __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); + _mm256_store_si256(blockA_256++, L_AD24); + __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); + __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); + __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); + __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); + __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20); + _mm256_store_si256(blockA_256++, L_EH0); + __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31); + __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20); + _mm256_store_si256(blockA_256++, L_EH8); + _mm256_store_si256(blockA_256++, L_EH16); + __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31); + _mm256_store_si256(blockA_256++, L_EH24); + } + blockA_256 += padding; + } + + // Finish the m dimension, padding with zeros + if (rows_32 < rows) { + // Pack depth in sets of 8 + for (Index k = 0; k < depth_8; k += 8) { + // Load vectors + __m256i L_A = _mm256_setzero_si256(); + __m256i L_B = _mm256_setzero_si256(); + __m256i L_C = _mm256_setzero_si256(); + __m256i L_D = _mm256_setzero_si256(); + __m256i L_E = _mm256_setzero_si256(); + __m256i L_F = _mm256_setzero_si256(); + __m256i L_G = _mm256_setzero_si256(); + __m256i L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + QInt8* ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, k); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, k + 1); + ptr = (QInt8*)&L_C; + ptr[m] = lhs(rows_32 + m, k + 2); + ptr = (QInt8*)&L_D; + ptr[m] = lhs(rows_32 + m, k + 3); + ptr = (QInt8*)&L_E; + ptr[m] = lhs(rows_32 + m, k + 4); + ptr = (QInt8*)&L_F; + ptr[m] = lhs(rows_32 + m, k + 5); + ptr = (QInt8*)&L_G; + ptr[m] = lhs(rows_32 + m, k + 6); + ptr = (QInt8*)&L_H; + ptr[m] = lhs(rows_32 + m, k + 7); + } + + // Interleave 8-bit elements + __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); + __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); + __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); + __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); + + // Interleave 16-bit elements + __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16); + __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16); + + // Use permute before we store to cross 128-bit lanes + __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20); + _mm256_store_si256(blockA_256++, L_AD0); + + // Complete packing for 32 x 8 block + __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31); + __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20); + _mm256_store_si256(blockA_256++, L_AD8); + _mm256_store_si256(blockA_256++, L_AD16); + __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); + _mm256_store_si256(blockA_256++, L_AD24); + __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); + __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); + __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); + __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); + __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20); + _mm256_store_si256(blockA_256++, L_EH0); + __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31); + __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20); + _mm256_store_si256(blockA_256++, L_EH8); + _mm256_store_si256(blockA_256++, L_EH16); + __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31); + _mm256_store_si256(blockA_256++, L_EH24); + } + + // Finish the k dimension, padding with zeros + if (depth_8 < depth) { + __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H; + QInt8* ptr; + switch (depth - depth_8) { + case 1: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + QInt8* ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + } + break; + case 2: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + } + break; + case 3: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*)&L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + } + break; + case 4: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*)&L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + ptr = (QInt8*)&L_D; + ptr[m] = lhs(rows_32 + m, depth_8 + 3); + } + break; + case 5: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*)&L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + ptr = (QInt8*)&L_D; + ptr[m] = lhs(rows_32 + m, depth_8 + 3); + ptr = (QInt8*)&L_E; + ptr[m] = lhs(rows_32 + m, depth_8 + 4); + } + break; + case 6: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*)&L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + ptr = (QInt8*)&L_D; + ptr[m] = lhs(rows_32 + m, depth_8 + 3); + ptr = (QInt8*)&L_E; + ptr[m] = lhs(rows_32 + m, depth_8 + 4); + ptr = (QInt8*)&L_F; + ptr[m] = lhs(rows_32 + m, depth_8 + 5); + } + break; + case 7: + L_A = _mm256_setzero_si256(); + L_B = _mm256_setzero_si256(); + L_C = _mm256_setzero_si256(); + L_D = _mm256_setzero_si256(); + L_E = _mm256_setzero_si256(); + L_F = _mm256_setzero_si256(); + L_G = _mm256_setzero_si256(); + L_H = _mm256_setzero_si256(); + for (Index m = 0; m < rows - rows_32; m++) { + ptr = (QInt8*)&L_A; + ptr[m] = lhs(rows_32 + m, depth_8); + ptr = (QInt8*)&L_B; + ptr[m] = lhs(rows_32 + m, depth_8 + 1); + ptr = (QInt8*)&L_C; + ptr[m] = lhs(rows_32 + m, depth_8 + 2); + ptr = (QInt8*)&L_D; + ptr[m] = lhs(rows_32 + m, depth_8 + 3); + ptr = (QInt8*)&L_E; + ptr[m] = lhs(rows_32 + m, depth_8 + 4); + ptr = (QInt8*)&L_F; + ptr[m] = lhs(rows_32 + m, depth_8 + 5); + ptr = (QInt8*)&L_G; + ptr[m] = lhs(rows_32 + m, depth_8 + 6); + } + break; + } + + // Interleave 8-bit elements + __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); + __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); + __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); + __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); + + // Interleave 16-bit elements + __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16); + __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16); + + // Use permute before we store to cross 128-bit lanes + __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20); + _mm256_store_si256(blockA_256++, L_AD0); + + // Complete packing + __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31); + __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20); + _mm256_store_si256(blockA_256++, L_AD8); + _mm256_store_si256(blockA_256++, L_AD16); + __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); + _mm256_store_si256(blockA_256++, L_AD24); + __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); + __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); + __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); + __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); + __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20); + _mm256_store_si256(blockA_256++, L_EH0); + __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31); + __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20); + _mm256_store_si256(blockA_256++, L_EH8); + _mm256_store_si256(blockA_256++, L_EH16); + __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31); + _mm256_store_si256(blockA_256++, L_EH24); + } + } +} + +template +EIGEN_DONT_INLINE void gemm_pack_rhs_any:: +operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, + Index stride, Index offset) { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + typedef typename packet_traits::type Packet; + + // Get vector pointer + __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB); + + // Get even multiples of the dimensions + Index cols_32 = (cols / 32) * 32; + Index depth_32 = (depth / 32) * 32; + + // Perform a step of the packing for 4 columns + __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_8, R_AD_16, R_AD_24; +#define PACK_STEP \ + R_AB_L = _mm256_unpacklo_epi64(R_A, R_B); \ + R_CD_L = _mm256_unpacklo_epi64(R_C, R_D); \ + R_AB_H = _mm256_unpackhi_epi64(R_A, R_B); \ + R_CD_H = _mm256_unpackhi_epi64(R_C, R_D); \ + R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20); \ + R_AD_16 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31); \ + R_AD_8 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20); \ + R_AD_24 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31); \ + _mm256_store_si256(blockB_256, R_AD_0); \ + _mm256_store_si256(blockB_256 + 8, R_AD_8); \ + _mm256_store_si256(blockB_256 + 16, R_AD_16); \ + _mm256_store_si256(blockB_256 + 24, R_AD_24); \ + blockB_256++; + + // Pack cols in sets of 32 + for (Index n = 0; n < cols_32; n += 32) { + // Pack depth in sets of 32 + for (Index k = 0; k < depth_32; k += 32) { + __m256i R_A = rhs.template loadPacket(k, n); + __m256i R_B = rhs.template loadPacket(k, n + 1); + __m256i R_C = rhs.template loadPacket(k, n + 2); + __m256i R_D = rhs.template loadPacket(k, n + 3); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 4); + R_B = rhs.template loadPacket(k, n + 5); + R_C = rhs.template loadPacket(k, n + 6); + R_D = rhs.template loadPacket(k, n + 7); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 8); + R_B = rhs.template loadPacket(k, n + 9); + R_C = rhs.template loadPacket(k, n + 10); + R_D = rhs.template loadPacket(k, n + 11); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 12); + R_B = rhs.template loadPacket(k, n + 13); + R_C = rhs.template loadPacket(k, n + 14); + R_D = rhs.template loadPacket(k, n + 15); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 16); + R_B = rhs.template loadPacket(k, n + 17); + R_C = rhs.template loadPacket(k, n + 18); + R_D = rhs.template loadPacket(k, n + 19); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 20); + R_B = rhs.template loadPacket(k, n + 21); + R_C = rhs.template loadPacket(k, n + 22); + R_D = rhs.template loadPacket(k, n + 23); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 24); + R_B = rhs.template loadPacket(k, n + 25); + R_C = rhs.template loadPacket(k, n + 26); + R_D = rhs.template loadPacket(k, n + 27); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 28); + R_B = rhs.template loadPacket(k, n + 29); + R_C = rhs.template loadPacket(k, n + 30); + R_D = rhs.template loadPacket(k, n + 31); + PACK_STEP; + + blockB_256 += 24; + } + + if (depth_32 < depth) { + QUInt8* ptr; + __m256i R_A = _mm256_setzero_si256(); + __m256i R_B = _mm256_setzero_si256(); + __m256i R_C = _mm256_setzero_si256(); + __m256i R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 1); + ptr = (QUInt8*)&R_C; + ptr[k - depth_32] = rhs(k, n + 2); + ptr = (QUInt8*)&R_D; + ptr[k - depth_32] = rhs(k, n + 3); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n + 4); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 5); + ptr = (QUInt8*)&R_C; + ptr[k - depth_32] = rhs(k, n + 6); + ptr = (QUInt8*)&R_D; + ptr[k - depth_32] = rhs(k, n + 7); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n + 8); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 9); + ptr = (QUInt8*)&R_C; + ptr[k - depth_32] = rhs(k, n + 10); + ptr = (QUInt8*)&R_D; + ptr[k - depth_32] = rhs(k, n + 11); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n + 12); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 13); + ptr = (QUInt8*)&R_C; + ptr[k - depth_32] = rhs(k, n + 14); + ptr = (QUInt8*)&R_D; + ptr[k - depth_32] = rhs(k, n + 15); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n + 16); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 17); + ptr = (QUInt8*)&R_C; + ptr[k - depth_32] = rhs(k, n + 18); + ptr = (QUInt8*)&R_D; + ptr[k - depth_32] = rhs(k, n + 19); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n + 20); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 21); + ptr = (QUInt8*)&R_C; + ptr[k - depth_32] = rhs(k, n + 22); + ptr = (QUInt8*)&R_D; + ptr[k - depth_32] = rhs(k, n + 23); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n + 24); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 25); + ptr = (QUInt8*)&R_C; + ptr[k - depth_32] = rhs(k, n + 26); + ptr = (QUInt8*)&R_D; + ptr[k - depth_32] = rhs(k, n + 27); + } + PACK_STEP; + + R_A = _mm256_setzero_si256(); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n + 28); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 29); + ptr = (QUInt8*)&R_C; + ptr[k - depth_32] = rhs(k, n + 30); + ptr = (QUInt8*)&R_D; + ptr[k - depth_32] = rhs(k, n + 31); + } + PACK_STEP; + blockB_256 += 24; + } + } + + // Finish packing cols + if (cols_32 < cols) { + // Pack depth in sets of 32 + for (Index k = 0; k < depth_32; k += 32) { + __m256i R_A, R_B, R_C, R_D; + Index n; + for (n = cols_32; n < cols; n += 4) { + switch (cols - n) { + case 1: + R_A = rhs.template loadPacket(k, n); + R_B = _mm256_setzero_si256(); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + PACK_STEP; + break; + case 2: + R_A = rhs.template loadPacket(k, n); + R_B = rhs.template loadPacket(k, n + 1); + R_C = _mm256_setzero_si256(); + R_D = _mm256_setzero_si256(); + PACK_STEP; + break; + case 3: + R_A = rhs.template loadPacket(k, n); + R_B = rhs.template loadPacket(k, n + 1); + R_C = rhs.template loadPacket(k, n + 2); + R_D = _mm256_setzero_si256(); + PACK_STEP; + break; + default: + R_A = rhs.template loadPacket(k, n); + R_B = rhs.template loadPacket(k, n + 1); + R_C = rhs.template loadPacket(k, n + 2); + R_D = rhs.template loadPacket(k, n + 3); + PACK_STEP; + break; + } + } + + // Increment the block pointer. + // We must pad if cols is not a multiple of 32. + blockB_256 += 32 - (n - cols_32) / 4; + } + + if (depth_32 < depth) { + for (Index n = cols_32; n < cols; n += 4) { + QUInt8* ptr; + __m256i R_A = _mm256_setzero_si256(); + __m256i R_B = _mm256_setzero_si256(); + __m256i R_C = _mm256_setzero_si256(); + __m256i R_D = _mm256_setzero_si256(); + switch (cols - n) { + case 1: + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n); + } + PACK_STEP; + break; + case 2: + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 1); + } + PACK_STEP; + break; + case 3: + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 1); + ptr = (QUInt8*)&R_C; + ptr[k - depth_32] = rhs(k, n + 2); + } + PACK_STEP; + break; + default: + for (Index k = depth_32; k < depth; k++) { + ptr = (QUInt8*)&R_A; + ptr[k - depth_32] = rhs(k, n); + ptr = (QUInt8*)&R_B; + ptr[k - depth_32] = rhs(k, n + 1); + ptr = (QUInt8*)&R_C; + ptr[k - depth_32] = rhs(k, n + 2); + ptr = (QUInt8*)&R_D; + ptr[k - depth_32] = rhs(k, n + 3); + } + PACK_STEP; + break; + } + } + } + } +#undef PACK_STEP +} + +template +EIGEN_DONT_INLINE void gebp_kernel_any:: +operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { + EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + eigen_assert(alpha.value == 1); + eigen_assert(strideA == -1); + eigen_assert(strideB == -1); + eigen_assert(offsetA == 0); + eigen_assert(offsetB == 0); + eigen_assert(rows > 0); + eigen_assert(cols > 0); + eigen_assert(depth > 0); + eigen_assert(blockA); + eigen_assert(blockB); + + Index rows_32 = ((rows + 31) / 32) * 32; + Index cols_32 = ((cols + 31) / 32) * 32; + Index depth_32 = ((depth + 31) / 32) * 32; + + // Create result block + ei_declare_aligned_stack_constructed_variable(QInt32, blockO, 32 * 32, 0); + memset(blockO, 0, 32 * 32 * sizeof(QInt32)); + + // Get vectorized pointers + __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO); + const __m256i* blockA_256 = reinterpret_cast(blockA); + const __m256i* blockB_256 = reinterpret_cast(blockB); + + // Loop over blocks of 32 columns + for (Index n = 0; n < cols_32; n += 32) { + // Reset index into blockA + Index indexL = 0; + // Loop over blocks of 32 rows + for (Index m = 0; m < rows_32; m += 32) { + // Reset index into blockB + Index indexR = n / 32 * depth_32; + // Loop over blocks of 8 on depth + for (Index k = 0; k < depth_32; k += 8) { + // Load inputs + __m256i L_AD0 = blockA_256[indexL++]; + __m256i L_AD8 = blockA_256[indexL++]; + __m256i L_AD16 = blockA_256[indexL++]; + __m256i L_AD24 = blockA_256[indexL++]; + __m256i L_EH0 = blockA_256[indexL++]; + __m256i L_EH8 = blockA_256[indexL++]; + __m256i L_EH16 = blockA_256[indexL++]; + __m256i L_EH24 = blockA_256[indexL++]; + __m256i R_AH0 = blockB_256[indexR++]; + __m256i R_AH4 = blockB_256[indexR++]; + __m256i R_AH8 = blockB_256[indexR++]; + __m256i R_AH12 = blockB_256[indexR++]; + __m256i R_AH16 = blockB_256[indexR++]; + __m256i R_AH20 = blockB_256[indexR++]; + __m256i R_AH24 = blockB_256[indexR++]; + __m256i R_AH28 = blockB_256[indexR++]; + + // This constant is used with madd to convert 16 bit to 32 bit + const __m256i ONE = _mm256_set1_epi32(0x00010001); + + // Declare variables used in COMPUTE_STEP + __m256i P_16_A, P_16_B, P_32_A, P_32_B, P_32; + +#define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET) \ + P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD0); \ + P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ + P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH0); \ + P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ + P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ + _mm256_store_si256( \ + blockO_256 + 4 * OFFSET, \ + _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET), P_32)); \ + \ + P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD8); \ + P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ + P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH8); \ + P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ + P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ + _mm256_store_si256( \ + blockO_256 + 4 * OFFSET + 1, \ + _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 1), P_32)); \ + \ + P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD16); \ + P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ + P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH16); \ + P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ + P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ + _mm256_store_si256( \ + blockO_256 + 4 * OFFSET + 2, \ + _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 2), P_32)); \ + \ + P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD24); \ + P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ + P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH24); \ + P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ + P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ + _mm256_store_si256( \ + blockO_256 + 4 * OFFSET + 3, \ + _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 3), P_32)); + + // Permute and shuffle to copy a single value across the entire vector + // Then compute the multiplication + __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00); + __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 0); + __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 1); + R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11); + __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 2); + __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 3); + + R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 4); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 5); + R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 6); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 7); + + R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 8); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 9); + R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 10); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 11); + + R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 12); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 13); + R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 14); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 15); + + R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 16); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 17); + R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 18); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 19); + + R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 20); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 21); + R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 22); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 23); + + R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 24); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 25); + R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 26); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 27); + + R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 28); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 29); + R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 30); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 31); + +#undef COMPUTE_STEP + } + + // Transfer the results to the result matrix. + if (m + 32 <= rows && n + 32 <= cols) { + Index i = 0; + for (Index j = n; j < n + 32; j++) { + LinearMapper r0 = res.getLinearMapper(m, j); + LinearMapper r1 = res.getLinearMapper(m + 8, j); + LinearMapper r2 = res.getLinearMapper(m + 16, j); + LinearMapper r3 = res.getLinearMapper(m + 24, j); + typedef typename packet_traits::type Packet; + r0.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r0.template loadPacket(0))); + r1.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r1.template loadPacket(0))); + r2.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r2.template loadPacket(0))); + r3.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r3.template loadPacket(0))); + } + } else { + for (Index j = n; j < cols; j++) { + for (Index i = m; i < rows; i++) { + res(i, j) = blockO[(j - n) * 32 + (i - m)]; + } + } + } + + // Zero the result block so it can be reused + memset(blockO, 0, 32 * 32 * sizeof(QInt32)); + } + } +} + +// Below are the fully optimized versions that are correct only for sizes that +// are multiple of 32. It is about a 10% performance benefit to keep these +// implementations separate. + +// Arrange a block of the left input matrix in contiguous memory. +// +// Given column major input (A0 beside A1 in memory): +// A0 B0 C0 D0 E0 F0 G0 H0 ... +// A1 B1 C1 D1 E1 F1 G1 H1 ... +// A2 B2 C2 D2 E2 F2 G2 H2 ... +// A3 B3 C3 D3 E3 F3 G3 H3 ... +// A4 B4 C4 D4 E4 F4 G4 H4 ... +// A5 B5 C5 D5 E5 F5 G5 H5 ... +// A6 B6 C6 D6 E6 F6 G6 H6 ... +// A7 B7 C7 D7 E7 F7 G7 H7 ... +// A8 ... +// ... +// +// Packing yields output (A0 beside B0 in memory): +// A0 B0 C0 D0 +// A1 B1 C1 D1 +// A2 B2 C2 D2 +// A3 B3 C3 D3 +// A4 B4 C4 D4 +// A5 B5 C5 D5 +// A6 B6 C6 D6 +// A7 B7 C7 D7 +// ... +// A31 B31 C31 D31 +// E0 F0 G0 H0 +// E1 F1 G1 H1 +// E2 F2 G2 H2 +// E3 F3 G3 H3 +// E4 F4 G4 H4 +// E5 F5 G5 H5 +// E6 F6 G6 H6 +// E7 F7 G7 H7 +// ... +// +// Four elements of the same row are arranged contiguously because maddubs and +// madd both perform an adjacent addition in the kernel. +template +struct gemm_pack_lhs { + EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs, + Index depth, Index rows, Index stride = 0, + Index offset = 0); +}; + +template +EIGEN_DONT_INLINE void gemm_pack_lhs:: +operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, + Index stride, Index offset) { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + typedef typename packet_traits::type Packet; + + // Use alternate function for weird sizes + if (rows % 32 != 0 || depth % 32 != 0) { + gemm_pack_lhs_any + lhs_pack; + return lhs_pack(blockA, lhs, depth, rows, stride, offset); + } + + // Get vector pointer + __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA); + + // Pack rows in sets of 32 + for (Index m = 0; m < rows; m += 32) { + // Pack depth in sets of 8 + for (Index k = 0; k < depth; k += 8) { + // Load vectors + __m256i L_A = lhs.template loadPacket(m, k); + __m256i L_B = lhs.template loadPacket(m, k + 1); + + // Interleave 8-bit elements + __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); + __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); + + __m256i L_C = lhs.template loadPacket(m, k + 2); + __m256i L_D = lhs.template loadPacket(m, k + 3); + __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); + __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); + + // Interleave 16-bit elements + __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16); + __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16); + + // Use permute before we store to cross 128-bit lanes + __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20); + _mm256_store_si256(blockA_256++, L_AD0); + + // Complete packing for 32 x 8 block + __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31); + __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24); + __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20); + _mm256_store_si256(blockA_256++, L_AD8); + _mm256_store_si256(blockA_256++, L_AD16); + __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); + _mm256_store_si256(blockA_256++, L_AD24); + __m256i L_E = lhs.template loadPacket(m, k + 4); + __m256i L_F = lhs.template loadPacket(m, k + 5); + __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); + __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); + __m256i L_G = lhs.template loadPacket(m, k + 6); + __m256i L_H = lhs.template loadPacket(m, k + 7); + __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); + __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); + __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16); + __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20); + _mm256_store_si256(blockA_256++, L_EH0); + __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31); + __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24); + __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20); + _mm256_store_si256(blockA_256++, L_EH8); + _mm256_store_si256(blockA_256++, L_EH16); + __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31); + _mm256_store_si256(blockA_256++, L_EH24); + } + } +} + +// Arrange a block of the right input matrix in contiguous memory. +// +// Given column major input (A0 beside A1 in memory): +// A0 B0 C0 D0 E0 F0 G0 H0 ... +// A1 B1 C1 D1 E1 F1 G1 H1 ... +// A2 B2 C2 D2 E2 F2 G2 H2 ... +// A3 B3 C3 D3 E3 F3 G3 H3 ... +// A4 B4 C4 D4 E4 F4 G4 H4 ... +// A5 B5 C5 D5 E5 F5 G5 H5 ... +// A6 B6 C6 D6 E6 F6 G6 H6 ... +// A7 B7 C7 D7 E7 F7 G7 H7 ... +// A8 ... +// ... +// +// Packing yields row major output (A0 beside A1 in memory): +// A0 A1 A2 A3 A4 A5 A6 A7 +// B0 B1 B2 B3 B4 B5 B6 B7 +// ... +// +// At least four elements of the same col are arranged contiguously because +// maddubs and madd both perform an adjacent addition in the kernel. We can +// save work by leaving 8 adjacent elements because kr = 8. +template +struct gemm_pack_rhs { + EIGEN_DONT_INLINE void operator()(QUInt8* blockB, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0); +}; + +template +EIGEN_DONT_INLINE void +gemm_pack_rhs:: +operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, + Index stride, Index offset) { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + typedef typename packet_traits::type Packet; + + // Use alternate function for weird sizes + if (cols % 32 != 0 || depth % 32 != 0) { + gemm_pack_rhs_any + rhs_pack; + return rhs_pack(blockB, rhs, depth, cols, stride, offset); + } + + // Get vector pointer + __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB); + + // Perform a step of the packing for 4 columns + __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_8, R_AD_16, R_AD_24; +#define PACK_STEP \ + R_AB_L = _mm256_unpacklo_epi64(R_A, R_B); \ + R_CD_L = _mm256_unpacklo_epi64(R_C, R_D); \ + R_AB_H = _mm256_unpackhi_epi64(R_A, R_B); \ + R_CD_H = _mm256_unpackhi_epi64(R_C, R_D); \ + R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20); \ + R_AD_16 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31); \ + R_AD_8 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20); \ + R_AD_24 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31); \ + _mm256_store_si256(blockB_256, R_AD_0); \ + _mm256_store_si256(blockB_256 + 8, R_AD_8); \ + _mm256_store_si256(blockB_256 + 16, R_AD_16); \ + _mm256_store_si256(blockB_256 + 24, R_AD_24); \ + blockB_256++; + + // Pack cols in sets of 32 + for (Index n = 0; n < cols; n += 32) { + // Pack depth in sets of 32 + for (Index k = 0; k < depth; k += 32) { + __m256i R_A = rhs.template loadPacket(k, n); + __m256i R_B = rhs.template loadPacket(k, n + 1); + __m256i R_C = rhs.template loadPacket(k, n + 2); + __m256i R_D = rhs.template loadPacket(k, n + 3); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 4); + R_B = rhs.template loadPacket(k, n + 5); + R_C = rhs.template loadPacket(k, n + 6); + R_D = rhs.template loadPacket(k, n + 7); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 8); + R_B = rhs.template loadPacket(k, n + 9); + R_C = rhs.template loadPacket(k, n + 10); + R_D = rhs.template loadPacket(k, n + 11); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 12); + R_B = rhs.template loadPacket(k, n + 13); + R_C = rhs.template loadPacket(k, n + 14); + R_D = rhs.template loadPacket(k, n + 15); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 16); + R_B = rhs.template loadPacket(k, n + 17); + R_C = rhs.template loadPacket(k, n + 18); + R_D = rhs.template loadPacket(k, n + 19); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 20); + R_B = rhs.template loadPacket(k, n + 21); + R_C = rhs.template loadPacket(k, n + 22); + R_D = rhs.template loadPacket(k, n + 23); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 24); + R_B = rhs.template loadPacket(k, n + 25); + R_C = rhs.template loadPacket(k, n + 26); + R_D = rhs.template loadPacket(k, n + 27); + PACK_STEP; + + R_A = rhs.template loadPacket(k, n + 28); + R_B = rhs.template loadPacket(k, n + 29); + R_C = rhs.template loadPacket(k, n + 30); + R_D = rhs.template loadPacket(k, n + 31); + PACK_STEP; + + blockB_256 += 24; + } + } +#undef PACK_STEP +} + +// Perform the actual multiplication on packed inputs +template +struct gebp_kernel { + typedef typename DataMapper::LinearMapper LinearMapper; + + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QInt8* blockA, + const QUInt8* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); +}; + +template +EIGEN_DONT_INLINE void gebp_kernel:: +operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { + EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + eigen_assert(alpha.value == 1); + eigen_assert(strideA == -1); + eigen_assert(strideB == -1); + eigen_assert(offsetA == 0); + eigen_assert(offsetB == 0); + eigen_assert(rows > 0); + eigen_assert(cols > 0); + eigen_assert(depth > 0); + eigen_assert(blockA); + eigen_assert(blockB); + + // Use alternate function for weird sizes + if (rows % 32 != 0 || cols % 32 != 0 || depth % 32 != 0) { + gebp_kernel_any + gebp; + return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, + offsetA, offsetB); + } + + // Create result block + QInt32* blockO = aligned_new(32 * 32); + // Allocating the result block is about 5-10% faster than declaring stack + // space. It is unclear why this is the case. + // ei_declare_aligned_stack_constructed_variable(QInt32, blockO, 32 * 32, 0); + memset(blockO, 0, 32 * 32 * sizeof(QInt32)); + + // Get vectorized pointers + __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO); + const __m256i* blockA_256 = reinterpret_cast(blockA); + const __m256i* blockB_256 = reinterpret_cast(blockB); + + // Loop over blocks of 32 columns + for (Index n = 0; n < cols; n += 32) { + // Reset index into blockA + Index indexL = 0; + // Loop over blocks of 32 rows + for (Index m = 0; m < rows; m += 32) { + // Reset index into blockB + Index indexR = n / 32 * depth; + // Loop over blocks of 8 on depth + for (Index k = 0; k < depth; k += 8) { + // Load inputs + __m256i L_AD0 = blockA_256[indexL++]; + __m256i L_AD8 = blockA_256[indexL++]; + __m256i L_AD16 = blockA_256[indexL++]; + __m256i L_AD24 = blockA_256[indexL++]; + __m256i L_EH0 = blockA_256[indexL++]; + __m256i L_EH8 = blockA_256[indexL++]; + __m256i L_EH16 = blockA_256[indexL++]; + __m256i L_EH24 = blockA_256[indexL++]; + __m256i R_AH0 = blockB_256[indexR++]; + __m256i R_AH4 = blockB_256[indexR++]; + __m256i R_AH8 = blockB_256[indexR++]; + __m256i R_AH12 = blockB_256[indexR++]; + __m256i R_AH16 = blockB_256[indexR++]; + __m256i R_AH20 = blockB_256[indexR++]; + __m256i R_AH24 = blockB_256[indexR++]; + __m256i R_AH28 = blockB_256[indexR++]; + + // This constant is used with madd to convert 16 bit to 32 bit + const __m256i ONE = _mm256_set1_epi32(0x00010001); + + // Declare variables used in COMPUTE_STEP + __m256i P_16_A, P_16_B, P_32_A, P_32_B, P_32; + +#define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET) \ + P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD0); \ + P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ + P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH0); \ + P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ + P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ + _mm256_store_si256( \ + blockO_256 + 4 * OFFSET, \ + _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET), P_32)); \ + \ + P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD8); \ + P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ + P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH8); \ + P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ + P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ + _mm256_store_si256( \ + blockO_256 + 4 * OFFSET + 1, \ + _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 1), P_32)); \ + \ + P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD16); \ + P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ + P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH16); \ + P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ + P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ + _mm256_store_si256( \ + blockO_256 + 4 * OFFSET + 2, \ + _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 2), P_32)); \ + \ + P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD24); \ + P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ + P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH24); \ + P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ + P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ + _mm256_store_si256( \ + blockO_256 + 4 * OFFSET + 3, \ + _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 3), P_32)); + + // Permute and shuffle to copy a single value across the entire vector + // Then compute the multiplication + __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00); + __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 0); + __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 1); + R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11); + __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 2); + __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 3); + + R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 4); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 5); + R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 6); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 7); + + R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 8); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 9); + R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 10); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 11); + + R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 12); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 13); + R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 14); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 15); + + R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 16); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 17); + R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 18); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 19); + + R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 20); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 21); + R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 22); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 23); + + R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 24); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 25); + R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 26); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 27); + + R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x00); + R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD0, R_EH0, 28); + R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD1, R_EH1, 29); + R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x11); + R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); + R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); + COMPUTE_STEP(R_AD2, R_EH2, 30); + R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); + R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); + COMPUTE_STEP(R_AD3, R_EH3, 31); + +#undef COMPUTE_STEP + } + + // Transfer the results to the result matrix + Index i = 0; + for (Index j = n; j < n + 32; j++) { + LinearMapper r0 = res.getLinearMapper(m, j); + LinearMapper r1 = res.getLinearMapper(m + 8, j); + LinearMapper r2 = res.getLinearMapper(m + 16, j); + LinearMapper r3 = res.getLinearMapper(m + 24, j); + typedef typename packet_traits::type Packet; + r0.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r0.template loadPacket(0))); + r1.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r1.template loadPacket(0))); + r2.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r2.template loadPacket(0))); + r3.template storePacket( + 0, _mm256_add_epi32(blockO_256[i++], + r3.template loadPacket(0))); + } + + // Zero the result block so it can be reused + memset(blockO, 0, 32 * 32 * sizeof(QInt32)); + } + } + aligned_delete(blockO, 32 * 32); +} + +#endif // EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT + +} // namespace internal +} // namespace Eigen + +#endif // CXX11_SRC_FIXEDPOINT_MATMATPRODUCTAVX2_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h new file mode 100644 index 000000000..9e0efae6c --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h @@ -0,0 +1,92 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner +// Copyright (C) 2015 Benoit Jacob +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef CXX11_SRC_FIXEDPOINT_MATMATPRODUCTNEON_H_ +#define CXX11_SRC_FIXEDPOINT_MATMATPRODUCTNEON_H_ + +namespace Eigen { +namespace internal { + +// AVX2 optimized implementation of the case where the lhs is encoded using +// signed 8bit +// integers and the rhs using unsigned 8bit integers. +#ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT + +template +class gebp_traits { + public: + typedef QInt8 LhsScalar; + typedef QUInt8 RhsScalar; + typedef QInt32 ResScalar; + + enum { + // register block size along the M and N directions + // One for the current implementation + nr = 1, + mr = 1, + // Progress made at each iteration of the product loop + // also 1 for the current implementation + LhsProgress = 1, + RhsProgress = 1 + }; +}; + +// Mat-Mat product of a signed 8bit lhs with an unsigned 8bit rhs +template +struct gebp_kernel { + EIGEN_DONT_INLINE + void operator()(const DataMapper& res, const QInt8* blockA, + const QUInt8* blockB, Index rows, Index depth, Index cols, + QInt32 alpha, Index strideA = -1, Index strideB = -1, + Index offsetA = 0, Index offsetB = 0); +}; + +template +EIGEN_DONT_INLINE void gebp_kernel:: +operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, + Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, + Index strideB, Index offsetA, Index offsetB) { + EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); + + eigen_assert(alpha.value == 1); + eigen_assert(strideA == -1); + eigen_assert(strideB == -1); + eigen_assert(offsetA == 0); + eigen_assert(offsetB == 0); + + eigen_assert(rows > 0); + eigen_assert(cols > 0); + eigen_assert(depth > 0); + eigen_assert(blockA); + eigen_assert(blockB); + + for (Index j = 0; j < cols; ++j) { + Index startB = j * depth; + + for (Index i = 0; i < rows; ++i) { + Index startA = i * depth; + + for (Index k = 0; k < depth; ++k) { + res(i, j) += blockA[startA + k] * blockB[startB + k]; + } + } + } +} +#endif + +} // namespace internal +} // namespace Eigen + +#endif // CXX11_SRC_FIXEDPOINT_MATMATPRODUCTNEON_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h new file mode 100644 index 000000000..f15200cab --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h @@ -0,0 +1,145 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef CXX11_SRC_FIXEDPOINT_MATVECPRODUCT_H_ +#define CXX11_SRC_FIXEDPOINT_MATVECPRODUCT_H_ + +namespace Eigen { +namespace internal { + +// Mat-Vec product +// Both lhs and rhs are encoded as 8bit signed integers +template +struct general_matrix_vector_product { + EIGEN_DONT_INLINE static void run(Index rows, Index cols, + const LhsMapper& lhs, const RhsMapper& rhs, + QInt32* res, Index resIncr, QInt8 alpha); +}; + +template +EIGEN_DONT_INLINE void general_matrix_vector_product< + Index, QInt8, LhsMapper, ColMajor, ConjugateLhs, QInt8, RhsMapper, + ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs, + const RhsMapper& rhs, QInt32* res, + Index resIncr, QInt8 alpha) { + eigen_assert(alpha.value == 1); + eigen_assert(resIncr == 1); + eigen_assert(rows > 0); + eigen_assert(cols > 0); + + for (Index i = 0; i < rows; ++i) { + for (Index j = 0; j < cols; ++j) { + res[i] += lhs(i, j) * rhs(j, 0); + } + } +} + +// Mat-Vec product +// Both lhs and rhs are encoded as 16bit signed integers +template +struct general_matrix_vector_product { + EIGEN_DONT_INLINE static void run(Index rows, Index cols, + const LhsMapper& lhs, const RhsMapper& rhs, + QInt32* res, Index resIncr, QInt16 alpha); +}; + +template +EIGEN_DONT_INLINE void general_matrix_vector_product< + Index, QInt16, LhsMapper, ColMajor, ConjugateLhs, QInt16, RhsMapper, + ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs, + const RhsMapper& rhs, QInt32* res, + Index resIncr, QInt16 alpha) { + eigen_assert(alpha.value == 1); + eigen_assert(resIncr == 1); + eigen_assert(rows > 0); + eigen_assert(cols > 0); + + for (Index i = 0; i < rows; ++i) { + for (Index j = 0; j < cols; ++j) { + res[i] += lhs(i, j) * rhs(j, 0); + } + } +} + +// Mat-Vec product +// The lhs is encoded using 8bit signed integers, the rhs using 8bit unsigned +// integers +template +struct general_matrix_vector_product { + EIGEN_DONT_INLINE static void run(Index rows, Index cols, + const LhsMapper& lhs, const RhsMapper& rhs, + QInt32* res, Index resIncr, QUInt8 alpha); +}; + +template +EIGEN_DONT_INLINE void general_matrix_vector_product< + Index, QInt8, LhsMapper, ColMajor, ConjugateLhs, QUInt8, RhsMapper, + ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs, + const RhsMapper& rhs, QInt32* res, + Index resIncr, QUInt8 alpha) { + eigen_assert(alpha.value == 1); + eigen_assert(resIncr == 1); + eigen_assert(rows > 0); + eigen_assert(cols > 0); + + for (Index i = 0; i < rows; ++i) { + for (Index j = 0; j < cols; ++j) { + res[i] += lhs(i, j) * rhs(j, 0); + } + } +} + +// Mat-Vec product +// The lhs is encoded using bit unsigned integers, the rhs using 8bit signed +// integers +template +struct general_matrix_vector_product { + EIGEN_DONT_INLINE static void run(Index rows, Index cols, + const LhsMapper& lhs, const RhsMapper& rhs, + QInt32* res, Index resIncr, QInt8 alpha); +}; + +template +EIGEN_DONT_INLINE void general_matrix_vector_product< + Index, QUInt8, LhsMapper, ColMajor, ConjugateLhs, QInt8, RhsMapper, + ConjugateRhs, Version>::run(Index rows, Index cols, const LhsMapper& lhs, + const RhsMapper& rhs, QInt32* res, + Index resIncr, QInt8 alpha) { + eigen_assert(alpha.value == 1); + eigen_assert(resIncr == 1); + eigen_assert(rows > 0); + eigen_assert(cols > 0); + + for (Index i = 0; i < rows; ++i) { + for (Index j = 0; j < cols; ++j) { + res[i] += lhs(i, j) * rhs(j, 0); + } + } +} + +} // namespace internal +} // namespace Eigen + +#endif // CXX11_SRC_FIXEDPOINT_MATVECPRODUCT_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX.h new file mode 100644 index 000000000..1a7cd03d4 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX.h @@ -0,0 +1,149 @@ +#ifndef CXX11_SRC_FIXEDPOINT_PACKETMATHAVX_H_ +#define CXX11_SRC_FIXEDPOINT_PACKETMATHAVX_H_ +#ifdef _MSC_VER + +#include +#include +#include + +#endif + +namespace Eigen { +namespace internal { + +typedef eigen_packet_wrapper<__m256i, 10> Packet32q8i; +typedef eigen_packet_wrapper<__m128i, 11> Packet16q8i; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet32q8i type; + typedef Packet16q8i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 32, + }; + enum { + HasAdd = 0, + HasSub = 0, + HasMul = 0, + HasNegate = 0, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasConj = 0, + HasSetLinear = 0 + }; +}; + +template <> +struct unpacket_traits { + typedef QInt8 type; + typedef Packet16q8i half; + enum { + size = 32, + alignment = Aligned32, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +template <> +struct unpacket_traits { + typedef QInt8 type; + typedef Packet16q8i half; + enum { + size = 16, + alignment = Aligned32, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template <> +EIGEN_STRONG_INLINE Packet32q8i pset1(const QInt8& from) { + return _mm256_set1_epi8(from.value); +} +template <> +EIGEN_STRONG_INLINE Packet32q8i ploadu(const QInt8* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet16q8i ploadu(const QInt8* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm_loadu_si128( + reinterpret_cast(from)); +} + +template <> +EIGEN_STRONG_INLINE Packet32q8i pload(const QInt8* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet16q8i pload(const QInt8* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128( + reinterpret_cast(from)); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt8* to, const Packet32q8i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256( + reinterpret_cast<__m256i*>(to), from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt8* to, const Packet16q8i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), + from.m_val); +} + +template <> +EIGEN_STRONG_INLINE void pstore(QInt8* to, const Packet32q8i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to), + from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstore(QInt8* to, const Packet16q8i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to), + from.m_val); +} + +typedef __m256 Packet8f; + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet32q8i +pcast(const Packet8f& a, const Packet8f& b, + const Packet8f& c, const Packet8f& d) { + const __m256i a_conv = _mm256_cvtps_epi32(a); + const __m256i b_conv = _mm256_cvtps_epi32(b); + const __m256i c_conv = _mm256_cvtps_epi32(c); + const __m256i d_conv = _mm256_cvtps_epi32(d); + __m128i low = _mm256_castsi256_si128(a_conv); + __m128i high = _mm256_extractf128_si256(a_conv, 1); + __m128i tmp = _mm_packs_epi32(low, high); + __m128i low2 = _mm256_castsi256_si128(b_conv); + __m128i high2 = _mm256_extractf128_si256(b_conv, 1); + __m128i tmp2 = _mm_packs_epi32(low2, high2); + __m128i converted_low = _mm_packs_epi16(tmp, tmp2); + low = _mm256_castsi256_si128(c_conv); + high = _mm256_extractf128_si256(c_conv, 1); + tmp = _mm_packs_epi32(low, high); + low2 = _mm256_castsi256_si128(d_conv); + high2 = _mm256_extractf128_si256(d_conv, 1); + tmp2 = _mm_packs_epi32(low2, high2); + __m128i converted_high = _mm_packs_epi16(tmp, tmp2); + return _mm256_insertf128_si256(_mm256_castsi128_si256(converted_low), + converted_high, 1); +} + +} // end namespace internal +} // end namespace Eigen + +#endif // CXX11_SRC_FIXEDPOINT_PACKETMATHAVX_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h new file mode 100644 index 000000000..4c5e02abc --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h @@ -0,0 +1,547 @@ +#ifndef CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_ +#define CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_ +#ifdef _MSC_VER + +#include +#include +#include + +#endif + +inline int _mm256_extract_epi16_N0(const __m256i X) { + return _mm_extract_epi16(_mm256_extractf128_si256(X, 0 >> 3), 0 % 8); +} + +inline int _mm256_extract_epi16_N1(const __m256i X) { + return _mm_extract_epi16(_mm256_extractf128_si256(X, 1 >> 3), 1 % 8); +} + +inline int _mm256_extract_epi8_N0(const __m256i X) { + return _mm_extract_epi8(_mm256_extractf128_si256((X), 0 >> 4), 0 % 16); +} + +inline int _mm256_extract_epi8_N1(const __m256i X) { + return _mm_extract_epi8(_mm256_extractf128_si256((X), 1 >> 4), 1 % 16); +} + +namespace Eigen { +namespace internal { + +typedef eigen_packet_wrapper<__m256i, 20> Packet32q8i; +typedef eigen_packet_wrapper<__m256i, 21> Packet16q16i; +typedef eigen_packet_wrapper<__m256i, 22> Packet32q8u; +typedef eigen_packet_wrapper<__m128i, 23> Packet16q8i; +typedef eigen_packet_wrapper<__m128i, 25> Packet16q8u; +typedef eigen_packet_wrapper<__m128i, 26> Packet8q16i; +typedef eigen_packet_wrapper<__m256i, 27> Packet8q32i; +typedef eigen_packet_wrapper<__m128i, 28> Packet4q32i; + +#ifndef EIGEN_VECTORIZE_AVX512 +template <> +struct packet_traits : default_packet_traits { + typedef Packet32q8i type; + typedef Packet16q8i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 32, + }; + enum { + HasAdd = 0, + HasSub = 0, + HasMul = 0, + HasNegate = 0, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 0, + HasSetLinear = 0 + }; +}; +template <> +struct packet_traits : default_packet_traits { + typedef Packet32q8u type; + typedef Packet16q8u half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 32, + }; + enum { + HasAdd = 0, + HasSub = 0, + HasMul = 0, + HasNegate = 0, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 0, + HasSetLinear = 0 + }; +}; +template <> +struct packet_traits : default_packet_traits { + typedef Packet16q16i type; + typedef Packet8q16i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 16, + }; + enum { + HasAdd = 0, + HasSub = 0, + HasMul = 0, + HasNegate = 0, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 0, + HasSetLinear = 0 + }; +}; +template <> +struct packet_traits : default_packet_traits { + typedef Packet8q32i type; + typedef Packet4q32i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + }; + enum { + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 0, + HasSetLinear = 0 + }; +}; +#endif + +template <> +struct unpacket_traits { + typedef QInt8 type; + typedef Packet16q8i half; + enum { + size = 32, + alignment = Aligned32, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template <> +struct unpacket_traits { + typedef QInt8 type; + typedef Packet16q8i half; + enum { + size = 16, + alignment = Aligned32, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template <> +struct unpacket_traits { + typedef QInt16 type; + typedef Packet8q16i half; + enum { + size = 16, + alignment = Aligned32, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template <> +struct unpacket_traits { + typedef QInt16 type; + typedef Packet8q16i half; + enum { + size = 8, + alignment = Aligned32, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template <> +struct unpacket_traits { + typedef QUInt8 type; + typedef Packet16q8u half; + enum { + size = 32, + alignment = Aligned32, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template <> +struct unpacket_traits { + typedef QInt32 type; + typedef Packet4q32i half; + enum { + size = 8, + alignment = Aligned32, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +// Unaligned load +template <> +EIGEN_STRONG_INLINE Packet32q8i ploadu(const QInt8* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet16q8i ploadu(const QInt8* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm_loadu_si128( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet32q8u ploadu(const QUInt8* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet16q16i ploadu(const QInt16* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet8q16i ploadu(const QInt16* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm_loadu_si128( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet8q32i ploadu(const QInt32* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256( + reinterpret_cast(from)); +} + +// Aligned load +template <> +EIGEN_STRONG_INLINE Packet32q8i pload(const QInt8* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet16q8i pload(const QInt8* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet32q8u pload(const QUInt8* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet16q16i pload(const QInt16* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet8q16i pload(const QInt16* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet8q32i pload(const QInt32* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256( + reinterpret_cast(from)); +} + +// Unaligned store +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt8* to, const Packet32q8i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256( + reinterpret_cast<__m256i*>(to), from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt8* to, const Packet16q8i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), + from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(QUInt8* to, const Packet32q8u& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256( + reinterpret_cast<__m256i*>(to), from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt16* to, const Packet16q16i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256( + reinterpret_cast<__m256i*>(to), from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt16* to, const Packet8q16i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), + from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt32* to, const Packet8q32i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256( + reinterpret_cast<__m256i*>(to), from.m_val); +} + +// Aligned store +template <> +EIGEN_STRONG_INLINE void pstore(QInt32* to, const Packet8q32i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to), + from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstore(QInt16* to, const Packet16q16i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to), + from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstore(QInt16* to, const Packet8q16i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to), + from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstore(QUInt8* to, const Packet32q8u& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to), + from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstore(QInt8* to, const Packet32q8i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to), + from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstore(QInt8* to, const Packet16q8i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to), + from.m_val); +} + +// Extract first element. +template <> +EIGEN_STRONG_INLINE QInt32 pfirst(const Packet8q32i& a) { + return _mm_cvtsi128_si32(_mm256_castsi256_si128(a)); +} +template <> +EIGEN_STRONG_INLINE QInt16 pfirst(const Packet16q16i& a) { + return _mm256_extract_epi16_N0(a.m_val); +} +template <> +EIGEN_STRONG_INLINE QUInt8 pfirst(const Packet32q8u& a) { + return static_cast(_mm256_extract_epi8_N0(a.m_val)); +} +template <> +EIGEN_STRONG_INLINE QInt8 pfirst(const Packet32q8i& a) { + return _mm256_extract_epi8_N0(a.m_val); +} + +// Initialize to constant value. +template <> +EIGEN_STRONG_INLINE Packet32q8i pset1(const QInt8& from) { + return _mm256_set1_epi8(from.value); +} +template <> +EIGEN_STRONG_INLINE Packet32q8u pset1(const QUInt8& from) { + return _mm256_set1_epi8(static_cast(from.value)); +} +template <> +EIGEN_STRONG_INLINE Packet8q32i pset1(const QInt32& from) { + return _mm256_set1_epi32(from.value); +} + +// Basic arithmetic packet ops for QInt32. +template <> +EIGEN_STRONG_INLINE Packet8q32i padd(const Packet8q32i& a, + const Packet8q32i& b) { + return _mm256_add_epi32(a.m_val, b.m_val); +} +template <> +EIGEN_STRONG_INLINE Packet16q16i pset1(const QInt16& from) { + return _mm256_set1_epi16(from.value); +} +template <> +EIGEN_STRONG_INLINE Packet8q32i psub(const Packet8q32i& a, + const Packet8q32i& b) { + return _mm256_sub_epi32(a.m_val, b.m_val); +} +// Note: mullo truncates the result to 32 bits. +template <> +EIGEN_STRONG_INLINE Packet8q32i pmul(const Packet8q32i& a, + const Packet8q32i& b) { + return _mm256_mullo_epi32(a.m_val, b.m_val); +} +template <> +EIGEN_STRONG_INLINE Packet8q32i pnegate(const Packet8q32i& a) { + return _mm256_sub_epi32(_mm256_setzero_si256(), a.m_val); +} + +// Min and max. +template <> +EIGEN_STRONG_INLINE Packet8q32i pmin(const Packet8q32i& a, + const Packet8q32i& b) { + return _mm256_min_epi32(a.m_val, b.m_val); +} +template <> +EIGEN_STRONG_INLINE Packet8q32i pmax(const Packet8q32i& a, + const Packet8q32i& b) { + return _mm256_max_epi32(a.m_val, b.m_val); +} + +template <> +EIGEN_STRONG_INLINE Packet16q16i pmin(const Packet16q16i& a, + const Packet16q16i& b) { + return _mm256_min_epi16(a.m_val, b.m_val); +} +template <> +EIGEN_STRONG_INLINE Packet16q16i pmax(const Packet16q16i& a, + const Packet16q16i& b) { + return _mm256_max_epi16(a.m_val, b.m_val); +} + +template <> +EIGEN_STRONG_INLINE Packet32q8u pmin(const Packet32q8u& a, + const Packet32q8u& b) { + return _mm256_min_epu8(a.m_val, b.m_val); +} +template <> +EIGEN_STRONG_INLINE Packet32q8u pmax(const Packet32q8u& a, + const Packet32q8u& b) { + return _mm256_max_epu8(a.m_val, b.m_val); +} + +template <> +EIGEN_STRONG_INLINE Packet32q8i pmin(const Packet32q8i& a, + const Packet32q8i& b) { + return _mm256_min_epi8(a.m_val, b.m_val); +} +template <> +EIGEN_STRONG_INLINE Packet32q8i pmax(const Packet32q8i& a, + const Packet32q8i& b) { + return _mm256_max_epi8(a.m_val, b.m_val); +} + +// Reductions. +template <> +EIGEN_STRONG_INLINE QInt32 predux_min(const Packet8q32i& a) { + __m256i tmp = _mm256_min_epi32(a, _mm256_permute2f128_si256(a, a, 1)); + tmp = + _mm256_min_epi32(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + return pfirst( + _mm256_min_epi32(tmp, _mm256_shuffle_epi32(tmp, 1))); +} +template <> +EIGEN_STRONG_INLINE QInt32 predux_max(const Packet8q32i& a) { + __m256i tmp = _mm256_max_epi32(a, _mm256_permute2f128_si256(a, a, 1)); + tmp = + _mm256_max_epi32(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + return pfirst( + _mm256_max_epi32(tmp, _mm256_shuffle_epi32(tmp, 1))); +} + +template <> +EIGEN_STRONG_INLINE QInt16 predux_min(const Packet16q16i& a) { + __m256i tmp = _mm256_min_epi16(a, _mm256_permute2f128_si256(a, a, 1)); + tmp = + _mm256_min_epi16(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + tmp = _mm256_min_epi16(tmp, _mm256_shuffle_epi32(tmp, 1)); + return std::min(_mm256_extract_epi16_N0(tmp), _mm256_extract_epi16_N1(tmp)); +} +template <> +EIGEN_STRONG_INLINE QInt16 predux_max(const Packet16q16i& a) { + __m256i tmp = _mm256_max_epi16(a, _mm256_permute2f128_si256(a, a, 1)); + tmp = + _mm256_max_epi16(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + tmp = _mm256_max_epi16(tmp, _mm256_shuffle_epi32(tmp, 1)); + return std::max(_mm256_extract_epi16_N0(tmp), _mm256_extract_epi16_N1(tmp)); +} + +template <> +EIGEN_STRONG_INLINE QUInt8 predux_min(const Packet32q8u& a) { + __m256i tmp = _mm256_min_epu8(a, _mm256_permute2f128_si256(a, a, 1)); + tmp = + _mm256_min_epu8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + tmp = _mm256_min_epu8(tmp, _mm256_shuffle_epi32(tmp, 1)); + tmp = _mm256_min_epu8(tmp, + _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + return std::min(static_cast(_mm256_extract_epi8_N0(tmp)), + static_cast(_mm256_extract_epi8_N1(tmp))); +} +template <> +EIGEN_STRONG_INLINE QUInt8 predux_max(const Packet32q8u& a) { + __m256i tmp = _mm256_max_epu8(a, _mm256_permute2f128_si256(a, a, 1)); + tmp = + _mm256_max_epu8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + tmp = _mm256_max_epu8(tmp, _mm256_shuffle_epi32(tmp, 1)); + tmp = _mm256_max_epu8(tmp, + _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + return std::max(static_cast(_mm256_extract_epi8_N0(tmp)), + static_cast(_mm256_extract_epi8_N1(tmp))); +} + +template <> +EIGEN_STRONG_INLINE QInt8 predux_min(const Packet32q8i& a) { + __m256i tmp = _mm256_min_epi8(a, _mm256_permute2f128_si256(a, a, 1)); + tmp = + _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + tmp = _mm256_min_epi8(tmp, _mm256_shuffle_epi32(tmp, 1)); + tmp = _mm256_min_epi8(tmp, + _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + return std::min(_mm256_extract_epi8_N0(tmp), _mm256_extract_epi8_N1(tmp)); +} +template <> +EIGEN_STRONG_INLINE QInt8 predux_max(const Packet32q8i& a) { + __m256i tmp = _mm256_max_epi8(a, _mm256_permute2f128_si256(a, a, 1)); + tmp = + _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + tmp = _mm256_max_epi8(tmp, _mm256_shuffle_epi32(tmp, 1)); + tmp = _mm256_max_epi8(tmp, + _mm256_shufflelo_epi16(tmp, _MM_SHUFFLE(1, 0, 3, 2))); + return std::max(_mm256_extract_epi8_N0(tmp), _mm256_extract_epi8_N1(tmp)); +} + +// Vectorized scaling of Packet32q8i by float. +template <> +struct scalar_product_op : binary_op_base { + typedef typename ScalarBinaryOpTraits::ReturnType result_type; +#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN + EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op) +#else + scalar_product_op() { EIGEN_SCALAR_BINARY_OP_PLUGIN } +#endif + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type + operator()(const QInt32& a, const double& b) const { + return a * b; + } + + EIGEN_STRONG_INLINE const Packet8q32i packetOp(const Packet8q32i& a, + const double& b) const { + __m256d scale = _mm256_set1_pd(b); + __m256d a_lo = _mm256_cvtepi32_pd(_mm256_castsi256_si128(a)); + __m128i result_lo = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_lo)); + __m256d a_hi = _mm256_cvtepi32_pd(_mm256_extracti128_si256(a, 1)); + __m128i result_hi = _mm256_cvtpd_epi32(_mm256_mul_pd(scale, a_hi)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, + 1); + } +}; + +template <> +struct functor_traits> { + enum { Cost = 4 * NumTraits::MulCost, PacketAccess = true }; +}; + +} // end namespace internal +} // end namespace Eigen + +#endif // CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h new file mode 100644 index 000000000..5a0ae2e8c --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h @@ -0,0 +1,516 @@ +#ifndef CXX11_SRC_FIXEDPOINT_PACKETMATHAVX512_H_ +#define CXX11_SRC_FIXEDPOINT_PACKETMATHAVX512_H_ + +#include "PacketMathAVX2.h" + +namespace Eigen { +namespace internal { + +typedef eigen_packet_wrapper<__m512i, 30> Packet64q8i; +typedef eigen_packet_wrapper<__m512i, 31> Packet32q16i; +typedef eigen_packet_wrapper<__m512i, 32> Packet64q8u; +typedef eigen_packet_wrapper<__m512i, 33> Packet16q32i; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet64q8i type; + typedef Packet32q8i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 64, + }; + enum { + HasAdd = 0, + HasSub = 0, + HasMul = 0, + HasNegate = 0, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 0, + HasSetLinear = 0 + }; +}; +template <> +struct packet_traits : default_packet_traits { + typedef Packet64q8u type; + typedef Packet32q8u half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 64, + }; + enum { + HasAdd = 0, + HasSub = 0, + HasMul = 0, + HasNegate = 0, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 0, + HasSetLinear = 0 + }; +}; +template <> +struct packet_traits : default_packet_traits { + typedef Packet32q16i type; + typedef Packet16q16i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 32, + }; + enum { + HasAdd = 0, + HasSub = 0, + HasMul = 0, + HasNegate = 0, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 0, + HasSetLinear = 0 + }; +}; +template <> +struct packet_traits : default_packet_traits { + typedef Packet16q32i type; + typedef Packet8q32i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 16, + }; + enum { + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 0, + HasSetLinear = 0 + }; +}; + +template <> +struct unpacket_traits { + typedef QInt8 type; + typedef Packet32q8i half; + enum { + size = 64, + alignment = Aligned64, + masked_load_available = false, + masked_store_available = false + }; +}; +template <> +struct unpacket_traits { + typedef QInt16 type; + typedef Packet16q16i half; + enum { + size = 32, + alignment = Aligned64, + masked_load_available = false, + masked_store_available = false + }; +}; +template <> +struct unpacket_traits { + typedef QUInt8 type; + typedef Packet32q8u half; + enum { + size = 64, + alignment = Aligned64, + masked_load_available = false, + masked_store_available = false + }; +}; +template <> +struct unpacket_traits { + typedef QInt32 type; + typedef Packet8q32i half; + enum { + size = 16, + alignment = Aligned64, + masked_load_available = false, + masked_store_available = false + }; +}; + +// Unaligned load +template <> +EIGEN_STRONG_INLINE Packet64q8i ploadu(const QInt8* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet32q16i ploadu(const QInt16* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet64q8u ploadu(const QUInt8* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet16q32i ploadu(const QInt32* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512( + reinterpret_cast(from)); +} + +// Aligned load +template <> +EIGEN_STRONG_INLINE Packet64q8i pload(const QInt8* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet32q16i pload(const QInt16* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet64q8u pload(const QUInt8* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512( + reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet16q32i pload(const QInt32* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512( + reinterpret_cast(from)); +} + +// Unaligned store +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt8* to, const Packet64q8i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512( + reinterpret_cast<__m512i*>(to), from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt16* to, const Packet32q16i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512( + reinterpret_cast<__m512i*>(to), from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(QUInt8* to, const Packet64q8u& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512( + reinterpret_cast<__m512i*>(to), from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(QInt32* to, const Packet16q32i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512( + reinterpret_cast<__m512i*>(to), from.m_val); +} + +// Aligned store +template <> +EIGEN_STRONG_INLINE void pstore(QInt32* to, const Packet16q32i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm512_store_si512(reinterpret_cast<__m512i*>(to), + from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstore(QUInt8* to, const Packet64q8u& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm512_store_si512(reinterpret_cast<__m512i*>(to), + from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstore(QInt8* to, const Packet64q8i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm512_store_si512(reinterpret_cast<__m512i*>(to), + from.m_val); +} +template <> +EIGEN_STRONG_INLINE void pstore(QInt16* to, const Packet32q16i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm512_store_si512(reinterpret_cast<__m512i*>(to), + from.m_val); +} + +// Extract first element. +template <> +EIGEN_STRONG_INLINE QInt32 pfirst(const Packet16q32i& a) { + return _mm_cvtsi128_si32(_mm512_extracti32x4_epi32(a, 0)); +} +template <> +EIGEN_STRONG_INLINE QUInt8 pfirst(const Packet64q8u& a) { + return static_cast( + _mm_extract_epi8(_mm512_extracti32x4_epi32(a.m_val, 0), 0)); +} +template <> +EIGEN_STRONG_INLINE QInt8 pfirst(const Packet64q8i& a) { + return _mm_extract_epi8(_mm512_extracti32x4_epi32(a.m_val, 0), 0); +} +template <> +EIGEN_STRONG_INLINE QInt16 pfirst(const Packet32q16i& a) { + return _mm_extract_epi16(_mm512_extracti32x4_epi32(a.m_val, 0), 0); +} + +// Initialize to constant value. +template <> +EIGEN_STRONG_INLINE Packet64q8i pset1(const QInt8& from) { + return _mm512_set1_epi8(from.value); +} +template <> +EIGEN_STRONG_INLINE Packet32q16i pset1(const QInt16& from) { + return _mm512_set1_epi16(from.value); +} +template <> +EIGEN_STRONG_INLINE Packet64q8u pset1(const QUInt8& from) { + return _mm512_set1_epi8(static_cast(from.value)); +} +template <> +EIGEN_STRONG_INLINE Packet16q32i pset1(const QInt32& from) { + return _mm512_set1_epi32(from.value); +} + +// Basic arithmetic packet ops for QInt32. +template <> +EIGEN_STRONG_INLINE Packet16q32i padd(const Packet16q32i& a, + const Packet16q32i& b) { + return _mm512_add_epi32(a.m_val, b.m_val); +} +template <> +EIGEN_STRONG_INLINE Packet16q32i psub(const Packet16q32i& a, + const Packet16q32i& b) { + return _mm512_sub_epi32(a.m_val, b.m_val); +} +// Note: mullo truncates the result to 32 bits. +template <> +EIGEN_STRONG_INLINE Packet16q32i pmul(const Packet16q32i& a, + const Packet16q32i& b) { + return _mm512_mullo_epi32(a.m_val, b.m_val); +} +template <> +EIGEN_STRONG_INLINE Packet16q32i pnegate(const Packet16q32i& a) { + return _mm512_sub_epi32(_mm512_setzero_si512(), a.m_val); +} + +// Min and max. +template <> +EIGEN_STRONG_INLINE Packet16q32i pmin(const Packet16q32i& a, + const Packet16q32i& b) { + return _mm512_min_epi32(a.m_val, b.m_val); +} +template <> +EIGEN_STRONG_INLINE Packet16q32i pmax(const Packet16q32i& a, + const Packet16q32i& b) { + return _mm512_max_epi32(a.m_val, b.m_val); +} + +template <> +EIGEN_STRONG_INLINE Packet64q8u pmin(const Packet64q8u& a, + const Packet64q8u& b) { +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_min_epu8(a.m_val, b.m_val); +#else + __m256i ap0 = _mm512_extracti32x8_epi32(a.m_val, 0); + __m256i ap1 = _mm512_extracti32x8_epi32(a.m_val, 1); + __m256i bp0 = _mm512_extracti32x8_epi32(b.m_val, 0); + __m256i bp1 = _mm512_extracti32x8_epi32(b.m_val, 1); + __m256i r0 = _mm256_min_epu8(ap0, bp0); + __m256i r1 = _mm256_min_epu8(ap1, bp1); + return _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); +#endif +} +template <> +EIGEN_STRONG_INLINE Packet64q8u pmax(const Packet64q8u& a, + const Packet64q8u& b) { +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_max_epu8(a.m_val, b.m_val); +#else + __m256i ap0 = _mm512_extracti32x8_epi32(a.m_val, 0); + __m256i ap1 = _mm512_extracti32x8_epi32(a.m_val, 1); + __m256i bp0 = _mm512_extracti32x8_epi32(b.m_val, 0); + __m256i bp1 = _mm512_extracti32x8_epi32(b.m_val, 1); + __m256i r0 = _mm256_max_epu8(ap0, bp0); + __m256i r1 = _mm256_max_epu8(ap1, bp1); + return _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet64q8i pmin(const Packet64q8i& a, + const Packet64q8i& b) { +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_min_epi8(a.m_val, b.m_val); +#else + __m256i ap0 = _mm512_extracti32x8_epi32(a.m_val, 0); + __m256i ap1 = _mm512_extracti32x8_epi32(a.m_val, 1); + __m256i bp0 = _mm512_extracti32x8_epi32(b.m_val, 0); + __m256i bp1 = _mm512_extracti32x8_epi32(b.m_val, 1); + __m256i r0 = _mm256_min_epi8(ap0, bp0); + __m256i r1 = _mm256_min_epi8(ap1, bp1); + return _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); +#endif +} +template <> +EIGEN_STRONG_INLINE Packet32q16i pmin(const Packet32q16i& a, + const Packet32q16i& b) { +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_min_epi16(a.m_val, b.m_val); +#else + __m256i ap0 = _mm512_extracti32x8_epi32(a.m_val, 0); + __m256i ap1 = _mm512_extracti32x8_epi32(a.m_val, 1); + __m256i bp0 = _mm512_extracti32x8_epi32(b.m_val, 0); + __m256i bp1 = _mm512_extracti32x8_epi32(b.m_val, 1); + __m256i r0 = _mm256_min_epi16(ap0, bp0); + __m256i r1 = _mm256_min_epi16(ap1, bp1); + return _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); +#endif +} +template <> +EIGEN_STRONG_INLINE Packet64q8i pmax(const Packet64q8i& a, + const Packet64q8i& b) { +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_max_epi8(a.m_val, b.m_val); +#else + __m256i ap0 = _mm512_extracti32x8_epi32(a.m_val, 0); + __m256i ap1 = _mm512_extracti32x8_epi32(a.m_val, 1); + __m256i bp0 = _mm512_extracti32x8_epi32(b.m_val, 0); + __m256i bp1 = _mm512_extracti32x8_epi32(b.m_val, 1); + __m256i r0 = _mm256_max_epi8(ap0, bp0); + __m256i r1 = _mm256_max_epi8(ap1, bp1); + return _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); +#endif +} +template <> +EIGEN_STRONG_INLINE Packet32q16i pmax(const Packet32q16i& a, + const Packet32q16i& b) { +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_max_epi16(a.m_val, b.m_val); +#else + __m256i ap0 = _mm512_extracti32x8_epi32(a.m_val, 0); + __m256i ap1 = _mm512_extracti32x8_epi32(a.m_val, 1); + __m256i bp0 = _mm512_extracti32x8_epi32(b.m_val, 0); + __m256i bp1 = _mm512_extracti32x8_epi32(b.m_val, 1); + __m256i r0 = _mm256_max_epi16(ap0, bp0); + __m256i r1 = _mm256_max_epi16(ap1, bp1); + return _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); +#endif +} + +// Reductions. +template <> +EIGEN_STRONG_INLINE QInt32 predux_min(const Packet16q32i& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.m_val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.m_val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.m_val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.m_val, 3); + Packet4i res = + _mm_min_epi32(_mm_min_epi32(lane0, lane1), _mm_min_epi32(lane2, lane3)); + res = _mm_min_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + res = _mm_min_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))); + return pfirst(res); +} +template <> +EIGEN_STRONG_INLINE QInt32 predux_max(const Packet16q32i& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.m_val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.m_val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.m_val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.m_val, 3); + Packet4i res = + _mm_max_epi32(_mm_max_epi32(lane0, lane1), _mm_max_epi32(lane2, lane3)); + res = _mm_max_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + res = _mm_max_epi32(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))); + return pfirst(res); +} +template <> +EIGEN_STRONG_INLINE QInt16 predux_min(const Packet32q16i& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.m_val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.m_val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.m_val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.m_val, 3); + Packet4i res = + _mm_min_epi16(_mm_min_epi16(lane0, lane1), _mm_min_epi16(lane2, lane3)); + res = _mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + res = _mm_min_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))); + std::uint32_t w = pfirst(res); + return std::min( + {static_cast(w >> 16), static_cast(w)}); +} +template <> +EIGEN_STRONG_INLINE QInt16 predux_max(const Packet32q16i& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.m_val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.m_val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.m_val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.m_val, 3); + Packet4i res = + _mm_max_epi16(_mm_max_epi16(lane0, lane1), _mm_max_epi16(lane2, lane3)); + res = _mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + res = _mm_max_epi16(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))); + std::uint32_t w = pfirst(res); + return std::max( + {static_cast(w >> 16), static_cast(w)}); +} +template <> +EIGEN_STRONG_INLINE QUInt8 predux_min(const Packet64q8u& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.m_val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.m_val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.m_val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.m_val, 3); + Packet4i res = + _mm_min_epu8(_mm_min_epu8(lane0, lane1), _mm_min_epu8(lane2, lane3)); + res = _mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + res = _mm_min_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))); + std::uint32_t w = pfirst(res); + return std::min( + {static_cast(w >> 24), static_cast(w >> 16), + static_cast(w >> 8), static_cast(w)}); +} +template <> +EIGEN_STRONG_INLINE QUInt8 predux_max(const Packet64q8u& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.m_val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.m_val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.m_val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.m_val, 3); + Packet4i res = + _mm_max_epu8(_mm_max_epu8(lane0, lane1), _mm_max_epu8(lane2, lane3)); + res = _mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + res = _mm_max_epu8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))); + std::uint32_t w = pfirst(res); + return std::max( + {static_cast(w >> 24), static_cast(w >> 16), + static_cast(w >> 8), static_cast(w)}); +} +template <> +EIGEN_STRONG_INLINE QInt8 predux_min(const Packet64q8i& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.m_val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.m_val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.m_val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.m_val, 3); + Packet4i res = + _mm_min_epi8(_mm_min_epi8(lane0, lane1), _mm_min_epi8(lane2, lane3)); + res = _mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + res = _mm_min_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))); + std::uint32_t w = pfirst(res); + return std::min( + {static_cast(w >> 24), static_cast(w >> 16), + static_cast(w >> 8), static_cast(w)}); +} +template <> +EIGEN_STRONG_INLINE QInt8 predux_max(const Packet64q8i& a) { + Packet4i lane0 = _mm512_extracti32x4_epi32(a.m_val, 0); + Packet4i lane1 = _mm512_extracti32x4_epi32(a.m_val, 1); + Packet4i lane2 = _mm512_extracti32x4_epi32(a.m_val, 2); + Packet4i lane3 = _mm512_extracti32x4_epi32(a.m_val, 3); + Packet4i res = + _mm_max_epi8(_mm_max_epi8(lane0, lane1), _mm_max_epi8(lane2, lane3)); + res = _mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 3, 2))); + res = _mm_max_epi8(res, _mm_shuffle_epi32(res, _MM_SHUFFLE(0, 0, 0, 1))); + std::uint32_t w = pfirst(res); + return std::min( + {static_cast(w >> 24), static_cast(w >> 16), + static_cast(w >> 8), static_cast(w)}); +} + +} // end namespace internal +} // end namespace Eigen + +#endif // CXX11_SRC_FIXEDPOINT_PACKETMATHAVX512_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h new file mode 100644 index 000000000..5dd2cd309 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h @@ -0,0 +1,93 @@ +#ifndef CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_ +#define CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_ + +namespace Eigen { +namespace internal { + +typedef __m256 Packet8f; + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet8f pcast(const Packet8q32i& a) { + return _mm256_cvtepi32_ps(a.m_val); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet8q32i pcast(const Packet8f& a) { + return _mm256_cvtps_epi32(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet32q8i +pcast(const Packet8q32i& a, const Packet8q32i& b, + const Packet8q32i& c, const Packet8q32i& d) { + __m256i converted = _mm256_packs_epi16(_mm256_packs_epi32(a.m_val, b.m_val), + _mm256_packs_epi32(c.m_val, d.m_val)); + // Since packs does not cross 128 bit lane boundaries, + // we have to permute to properly order the final result. + const __m256i permute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + return _mm256_permutevar8x32_epi32(converted, permute_mask); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet32q8i +pcast(const Packet8f& a, const Packet8f& b, + const Packet8f& c, const Packet8f& d) { + const __m256i a_conv = _mm256_cvtps_epi32(a); + const __m256i b_conv = _mm256_cvtps_epi32(b); + const __m256i c_conv = _mm256_cvtps_epi32(c); + const __m256i d_conv = _mm256_cvtps_epi32(d); + __m256i converted = _mm256_packs_epi16(_mm256_packs_epi32(a_conv, b_conv), + _mm256_packs_epi32(c_conv, d_conv)); + const __m256i permute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + return _mm256_permutevar8x32_epi32(converted, permute_mask); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet32q8u +pcast(const Packet8q32i& a, const Packet8q32i& b, + const Packet8q32i& c, const Packet8q32i& d) { + // _mm256_packus_epi32 trims negative numbers to 0 but we can't allow numbers + // that are too large because _mm256_packus_epi16 expects signed input + // (example of problem input: 0x11111111, which saturates to 0xffff = -1, + // which saturates to 0). + const __m256i a_clip = _mm256_min_epi32(a, _mm256_set1_epi32(255)); + const __m256i b_clip = _mm256_min_epi32(b, _mm256_set1_epi32(255)); + const __m256i c_clip = _mm256_min_epi32(c, _mm256_set1_epi32(255)); + const __m256i d_clip = _mm256_min_epi32(d, _mm256_set1_epi32(255)); + const __m256i converted = _mm256_packus_epi16( + _mm256_packus_epi32(a_clip, b_clip), _mm256_packus_epi32(c_clip, d_clip)); + // Since packus does not cross 128 bit lane boundaries, + // we have to permute to properly order the final result. + const __m256i permute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + return _mm256_permutevar8x32_epi32(converted, permute_mask); +} + +} // end namespace internal +} // end namespace Eigen + +#endif // CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h new file mode 100644 index 000000000..17408d13a --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h @@ -0,0 +1,191 @@ +#ifndef CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX512_H_ +#define CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX512_H_ + +namespace Eigen { +namespace internal { + +typedef __m512 Packet16f; +typedef __m512i Packet16i; + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet16f pcast(const Packet16q32i& a) { + return _mm512_cvtepi32_ps(a.m_val); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet16q32i pcast(const Packet16f& a) { + return _mm512_cvtps_epi32(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet32q16i pcast(const Packet16f& a, + const Packet16f& b) { + Packet16i a_int = _mm512_cvtps_epi32(a); + Packet16i b_int = _mm512_cvtps_epi32(b); +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_packs_epi32(a_int, b_int); +#else + Packet8i ab_int16_low = _mm256_permute4x64_epi64( + _mm256_packs_epi32(_mm512_castsi512_si256(a_int), + _mm512_castsi512_si256(b_int)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i ab_int16_high = _mm256_permute4x64_epi64( + _mm256_packs_epi32(_mm512_extracti32x8_epi32(a_int, 1), + _mm512_extracti32x8_epi32(b_int, 1)), + _MM_SHUFFLE(0, 2, 1, 3)); + return _mm512_inserti32x8(_mm512_castsi256_si512(ab_int16_low), ab_int16_high, + 1); +#endif +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet64q8i pcast(const Packet16f& a, + const Packet16f& b, + const Packet16f& c, + const Packet16f& d) { + Packet16i a_int = _mm512_cvtps_epi32(a); + Packet16i b_int = _mm512_cvtps_epi32(b); + Packet16i c_int = _mm512_cvtps_epi32(c); + Packet16i d_int = _mm512_cvtps_epi32(d); +#ifdef EIGEN_VECTORIZE_AVX512BW + return _mm512_packs_epi16(_mm512_packs_epi32(a_int, b_int), + _mm512_packs_epi32(c_int, d_int)); +#else + Packet8i ab_int16_low = _mm256_permute4x64_epi64( + _mm256_packs_epi32(_mm512_castsi512_si256(a_int), + _mm512_castsi512_si256(b_int)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i cd_int16_low = _mm256_permute4x64_epi64( + _mm256_packs_epi32(_mm512_castsi512_si256(c_int), + _mm512_castsi512_si256(d_int)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i ab_int16_high = _mm256_permute4x64_epi64( + _mm256_packs_epi32(_mm512_extracti32x8_epi32(a_int, 1), + _mm512_extracti32x8_epi32(b_int, 1)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i cd_int16_high = _mm256_permute4x64_epi64( + _mm256_packs_epi32(_mm512_extracti32x8_epi32(c_int, 1), + _mm512_extracti32x8_epi32(d_int, 1)), + _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i abcd_int8_low = _mm256_permute4x64_epi64( + _mm256_packs_epi16(ab_int16_low, cd_int16_low), _MM_SHUFFLE(0, 2, 1, 3)); + Packet8i abcd_int8_high = + _mm256_permute4x64_epi64(_mm256_packs_epi16(ab_int16_high, cd_int16_high), + _MM_SHUFFLE(0, 2, 1, 3)); + return _mm512_inserti32x8(_mm512_castsi256_si512(abcd_int8_low), + abcd_int8_high, 1); +#endif +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet64q8i +pcast(const Packet16q32i& a, const Packet16q32i& b, + const Packet16q32i& c, const Packet16q32i& d) { + __m128i a_part = _mm512_cvtsepi32_epi8(a); + __m128i b_part = _mm512_cvtsepi32_epi8(b); + __m128i c_part = _mm512_cvtsepi32_epi8(c); + __m128i d_part = _mm512_cvtsepi32_epi8(d); + __m256i ab = + _mm256_inserti128_si256(_mm256_castsi128_si256(a_part), b_part, 1); + __m256i cd = + _mm256_inserti128_si256(_mm256_castsi128_si256(c_part), d_part, 1); + __m512i converted = _mm512_inserti64x4(_mm512_castsi256_si512(ab), cd, 1); + return converted; +} + +template <> +EIGEN_STRONG_INLINE Packet32q16i pcast( + const Packet16q32i& a, const Packet16q32i& b) { + __m256i a_part = _mm512_cvtsepi32_epi16(a); + __m256i b_part = _mm512_cvtsepi32_epi16(b); + __m512i converted = + _mm512_inserti64x4(_mm512_castsi256_si512(a_part), b_part, 1); + return converted; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet64q8u +pcast(const Packet16q32i& a, const Packet16q32i& b, + const Packet16q32i& c, const Packet16q32i& d) { + // Brute-force saturation since there isn't a pack operation for unsigned + // numbers that keeps the elements in order. + __m128i a_part = _mm512_cvtepi32_epi8(_mm512_max_epi32( + _mm512_min_epi32(a, _mm512_set1_epi32(255)), _mm512_setzero_si512())); + __m128i b_part = _mm512_cvtepi32_epi8(_mm512_max_epi32( + _mm512_min_epi32(b, _mm512_set1_epi32(255)), _mm512_setzero_si512())); + __m128i c_part = _mm512_cvtepi32_epi8(_mm512_max_epi32( + _mm512_min_epi32(c, _mm512_set1_epi32(255)), _mm512_setzero_si512())); + __m128i d_part = _mm512_cvtepi32_epi8(_mm512_max_epi32( + _mm512_min_epi32(d, _mm512_set1_epi32(255)), _mm512_setzero_si512())); + __m256i ab = + _mm256_inserti128_si256(_mm256_castsi128_si256(a_part), b_part, 1); + __m256i cd = + _mm256_inserti128_si256(_mm256_castsi128_si256(c_part), d_part, 1); + __m512i converted = _mm512_inserti64x4(_mm512_castsi256_si512(ab), cd, 1); + return converted; +} + +#if 0 +// The type Packet32q16u does not exist for AVX-512 yet +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE Packet32q16u +pcast(const Packet16q32i& a, + const Packet16q32i& b) { + // Brute-force saturation since there isn't a pack operation for unsigned + // numbers that keeps the elements in order. + __m256i a_part = + _mm512_cvtepi32_epi16(_mm512_max_epi32( + _mm512_min_epi32(a, _mm512_set1_epi32(65535)), _mm512_setzero_si512())); + __m256i b_part = _mm512_cvtepi32_epi16( + _mm512_max_epi32(_mm512_min_epi32(b, _mm512_set1_epi32(65535)), + _mm512_setzero_si512())); + __m512i converted = + _mm512_inserti64x4(_mm512_castsi256_si512(a_part), b_part, 1); + return converted; +} +#endif + +} // end namespace internal +} // end namespace Eigen + +#endif // CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX512_H_ diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/MatrixFunctions b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/MatrixFunctions new file mode 100644 index 000000000..314b325f8 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/MatrixFunctions @@ -0,0 +1 @@ +#include "unsupported/Eigen/MatrixFunctions" diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/SpecialFunctions b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/SpecialFunctions new file mode 100644 index 000000000..ad13359ab --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/eigen3/unsupported/Eigen/SpecialFunctions @@ -0,0 +1 @@ +#include "unsupported/Eigen/SpecialFunctions" diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/protobuf/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/protobuf/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/protobuf/protobuf.patch b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/protobuf/protobuf.patch new file mode 100644 index 000000000..8ce4a8437 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/protobuf/protobuf.patch @@ -0,0 +1,43 @@ +diff --git a/BUILD b/BUILD +index dbae719ff..87dc38470 100644 +--- a/BUILD ++++ b/BUILD +@@ -23,7 +23,7 @@ config_setting( + # ZLIB configuration + ################################################################################ + +-ZLIB_DEPS = ["@zlib//:zlib"] ++ZLIB_DEPS = ["@zlib"] + + ################################################################################ + # Protobuf Runtime Library +@@ -143,6 +143,7 @@ cc_library( + copts = COPTS, + includes = ["src/"], + linkopts = LINK_OPTS, ++ alwayslink = 1, + visibility = ["//visibility:public"], + ) + +@@ -213,6 +214,7 @@ cc_library( + copts = COPTS, + includes = ["src/"], + linkopts = LINK_OPTS, ++ alwayslink = 1, + visibility = ["//visibility:public"], + deps = [":protobuf_lite"] + PROTOBUF_DEPS, + ) +diff --git a/protobuf.bzl b/protobuf.bzl +index e0653321f..253d9cbb5 100644 +--- a/protobuf.bzl ++++ b/protobuf.bzl +@@ -84,7 +84,9 @@ def _proto_gen_impl(ctx): + + for dep in ctx.attr.deps: + import_flags += dep.proto.import_flags + deps += dep.proto.deps ++ import_flags = depset(import_flags).to_list() ++ deps = depset(deps).to_list() + + if not ctx.attr.gen_cc and not ctx.attr.gen_py and not ctx.executable.plugin: + return struct( \ No newline at end of file diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/repo.bzl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/repo.bzl new file mode 100644 index 000000000..697410e76 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/repo.bzl @@ -0,0 +1,240 @@ +# Copyright 2017 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for defining TensorFlow Bazel dependencies.""" + +_SINGLE_URL_WHITELIST = depset([ + "arm_compiler", +]) + +def _is_windows(ctx): + return ctx.os.name.lower().find("windows") != -1 + +def _wrap_bash_cmd(ctx, cmd): + if _is_windows(ctx): + bazel_sh = _get_env_var(ctx, "BAZEL_SH") + if not bazel_sh: + fail("BAZEL_SH environment variable is not set") + cmd = [bazel_sh, "-l", "-c", " ".join(["\"%s\"" % s for s in cmd])] + return cmd + +def _get_env_var(ctx, name): + if name in ctx.os.environ: + return ctx.os.environ[name] + else: + return None + +# Checks if we should use the system lib instead of the bundled one +def _use_system_lib(ctx, name): + syslibenv = _get_env_var(ctx, "TF_SYSTEM_LIBS") + if syslibenv: + for n in syslibenv.strip().split(","): + if n.strip() == name: + return True + return False + +# Executes specified command with arguments and calls 'fail' if it exited with +# non-zero code +def _execute_and_check_ret_code(repo_ctx, cmd_and_args): + result = repo_ctx.execute(cmd_and_args, timeout = 60) + if result.return_code != 0: + fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n" + + "Stderr: {3}").format( + " ".join(cmd_and_args), + result.return_code, + result.stdout, + result.stderr, + )) + +def _repos_are_siblings(): + return Label("@foo//bar").workspace_root.startswith("../") + +# Apply a patch_file to the repository root directory +# Runs 'patch -p1' +def _apply_patch(ctx, patch_file): + # Don't check patch on Windows, because patch is only available under bash. + if not _is_windows(ctx) and not ctx.which("patch"): + fail("patch command is not found, please install it") + cmd = _wrap_bash_cmd( + ctx, + ["patch", "-p1", "-d", ctx.path("."), "-i", ctx.path(patch_file)], + ) + _execute_and_check_ret_code(ctx, cmd) + +def _apply_delete(ctx, paths): + for path in paths: + if path.startswith("/"): + fail("refusing to rm -rf path starting with '/': " + path) + if ".." in path: + fail("refusing to rm -rf path containing '..': " + path) + cmd = _wrap_bash_cmd(ctx, ["rm", "-rf"] + [ctx.path(path) for path in paths]) + _execute_and_check_ret_code(ctx, cmd) + +def _tf_http_archive(ctx): + if ("mirror.tensorflow.org" not in ctx.attr.urls[0] and + (len(ctx.attr.urls) < 2 and + ctx.attr.name not in _SINGLE_URL_WHITELIST.to_list())): + fail("tf_http_archive(urls) must have redundant URLs. The " + + "mirror.tensorflow.org URL must be present and it must come first. " + + "Even if you don't have permission to mirror the file, please " + + "put the correctly formatted mirror URL there anyway, because " + + "someone will come along shortly thereafter and mirror the file.") + + urls = [] + for url in ctx.attr.urls: + if "PWD" in url: + url = url.replace("PWD", _get_env_var(ctx, "PWD")) + urls.append(url) + use_syslib = _use_system_lib(ctx, ctx.attr.name) + + # Work around the bazel bug that redownloads the whole library. + # Remove this after https://github.com/bazelbuild/bazel/issues/10515 is fixed. + if ctx.attr.additional_build_files: + for internal_src in ctx.attr.additional_build_files: + _ = ctx.path(Label(internal_src)) + + if not use_syslib: + ctx.download_and_extract( + urls, + "", + ctx.attr.sha256, + ctx.attr.type, + ctx.attr.strip_prefix, + ) + if ctx.attr.delete: + _apply_delete(ctx, ctx.attr.delete) + if ctx.attr.patch_file != None: + _apply_patch(ctx, ctx.attr.patch_file) + + if use_syslib and ctx.attr.system_build_file != None: + # Use BUILD.bazel to avoid conflict with third party projects with + # BUILD or build (directory) underneath. + ctx.template("BUILD.bazel", ctx.attr.system_build_file, { + "%prefix%": ".." if _repos_are_siblings() else "external", + }, False) + + elif ctx.attr.build_file != None: + # Use BUILD.bazel to avoid conflict with third party projects with + # BUILD or build (directory) underneath. + ctx.template("BUILD.bazel", ctx.attr.build_file, { + "%prefix%": ".." if _repos_are_siblings() else "external", + }, False) + + if use_syslib: + for internal_src, external_dest in ctx.attr.system_link_files.items(): + ctx.symlink(Label(internal_src), ctx.path(external_dest)) + + if ctx.attr.additional_build_files: + for internal_src, external_dest in ctx.attr.additional_build_files.items(): + ctx.symlink(Label(internal_src), ctx.path(external_dest)) + +tf_http_archive = repository_rule( + implementation = _tf_http_archive, + attrs = { + "sha256": attr.string(mandatory = True), + "urls": attr.string_list(mandatory = True, allow_empty = False), + "strip_prefix": attr.string(), + "type": attr.string(), + "delete": attr.string_list(), + "patch_file": attr.label(), + "build_file": attr.label(), + "system_build_file": attr.label(), + "system_link_files": attr.string_dict(), + "additional_build_files": attr.string_dict(), + }, + environ = [ + "TF_SYSTEM_LIBS", + ], +) +"""Downloads and creates Bazel repos for dependencies. + +This is a swappable replacement for both http_archive() and +new_http_archive() that offers some additional features. It also helps +ensure best practices are followed. +""" + +def _third_party_http_archive(ctx): + if ("mirror.tensorflow.org" not in ctx.attr.urls[0] and + (len(ctx.attr.urls) < 2 and + ctx.attr.name not in _SINGLE_URL_WHITELIST.to_list())): + fail("tf_http_archive(urls) must have redundant URLs. The " + + "mirror.tensorflow.org URL must be present and it must come first. " + + "Even if you don't have permission to mirror the file, please " + + "put the correctly formatted mirror URL there anyway, because " + + "someone will come along shortly thereafter and mirror the file.") + + use_syslib = _use_system_lib(ctx, ctx.attr.name) + + # Use "BUILD.bazel" to avoid conflict with third party projects that contain a + # file or directory called "BUILD" + buildfile_path = ctx.path("BUILD.bazel") + + if use_syslib: + if ctx.attr.system_build_file == None: + fail("Bazel was configured with TF_SYSTEM_LIBS to use a system " + + "library for %s, but no system build file for %s was configured. " + + "Please add a system_build_file attribute to the repository rule" + + "for %s." % (ctx.attr.name, ctx.attr.name, ctx.attr.name)) + ctx.symlink(Label(ctx.attr.system_build_file), buildfile_path) + + else: + ctx.download_and_extract( + ctx.attr.urls, + "", + ctx.attr.sha256, + ctx.attr.type, + ctx.attr.strip_prefix, + ) + if ctx.attr.delete: + _apply_delete(ctx, ctx.attr.delete) + if ctx.attr.patch_file != None: + _apply_patch(ctx, ctx.attr.patch_file) + ctx.symlink(Label(ctx.attr.build_file), buildfile_path) + + link_dict = {} + if use_syslib: + link_dict.update(ctx.attr.system_link_files) + + for internal_src, external_dest in ctx.attr.link_files.items(): + # if syslib and link exists in both, use the system one + if external_dest not in link_dict.values(): + link_dict[internal_src] = external_dest + + for internal_src, external_dest in link_dict.items(): + ctx.symlink(Label(internal_src), ctx.path(external_dest)) + +# Downloads and creates Bazel repos for dependencies. +# +# This is an upgrade for tf_http_archive that works with go/tfbr-thirdparty. +# +# For link_files, specify each dict entry as: +# "//path/to/source:file": "localfile" +third_party_http_archive = repository_rule( + implementation = _third_party_http_archive, + attrs = { + "sha256": attr.string(mandatory = True), + "urls": attr.string_list(mandatory = True, allow_empty = False), + "strip_prefix": attr.string(), + "type": attr.string(), + "delete": attr.string_list(), + "build_file": attr.string(mandatory = True), + "system_build_file": attr.string(mandatory = False), + "patch_file": attr.label(), + "link_files": attr.string_dict(), + "system_link_files": attr.string_dict(), + }, + environ = [ + "TF_SYSTEM_LIBS", + ], +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/BUILD.tpl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/BUILD.tpl new file mode 100644 index 000000000..e69de29bb diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/build_defs.bzl.tpl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/build_defs.bzl.tpl new file mode 100644 index 000000000..3faa46c58 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/build_defs.bzl.tpl @@ -0,0 +1,32 @@ +# -*- Python -*- +"""Skylark macros for system libraries. +""" + +SYSTEM_LIBS_ENABLED = %{syslibs_enabled} + +SYSTEM_LIBS_LIST = [ +%{syslibs_list} +] + + +def if_any_system_libs(a, b=[]): + """Conditional which evaluates to 'a' if any system libraries are configured.""" + if SYSTEM_LIBS_ENABLED: + return a + else: + return b + + +def if_system_lib(lib, a, b=[]): + """Conditional which evaluates to 'a' if we're using the system version of lib""" + + if SYSTEM_LIBS_ENABLED and lib in SYSTEM_LIBS_LIST: + return a + else: + return b + + +def if_not_system_lib(lib, a, b=[]): + """Conditional which evaluates to 'a' if we're using the system version of lib""" + + return if_system_lib(lib, b, a) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/nsync.BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/nsync.BUILD new file mode 100644 index 000000000..c5d4ad0a7 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/nsync.BUILD @@ -0,0 +1,23 @@ +licenses(["notice"]) # BSD 3-Clause + +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nsync_headers", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nsync", + linkopts = ["-lnsync"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "nsync_cpp", + linkopts = ["-lnsync_cpp"], + visibility = ["//visibility:public"], +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/protobuf.BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/protobuf.BUILD new file mode 100644 index 000000000..003b4e6da --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/protobuf.BUILD @@ -0,0 +1,104 @@ +load( + "@com_google_protobuf//:protobuf.bzl", + "cc_proto_library", + "proto_gen", + "py_proto_library", +) + +licenses(["notice"]) + +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +HEADERS = [ + "google/protobuf/any.pb.h", + "google/protobuf/any.proto", + "google/protobuf/arena.h", + "google/protobuf/compiler/importer.h", + "google/protobuf/descriptor.h", + "google/protobuf/descriptor.pb.h", + "google/protobuf/descriptor.proto", + "google/protobuf/duration.pb.h", + "google/protobuf/duration.proto", + "google/protobuf/dynamic_message.h", + "google/protobuf/empty.pb.h", + "google/protobuf/empty.proto", + "google/protobuf/field_mask.pb.h", + "google/protobuf/field_mask.proto", + "google/protobuf/io/coded_stream.h", + "google/protobuf/io/zero_copy_stream.h", + "google/protobuf/io/zero_copy_stream_impl_lite.h", + "google/protobuf/map.h", + "google/protobuf/repeated_field.h", + "google/protobuf/text_format.h", + "google/protobuf/timestamp.pb.h", + "google/protobuf/timestamp.proto", + "google/protobuf/util/json_util.h", + "google/protobuf/util/type_resolver_util.h", + "google/protobuf/wrappers.pb.h", + "google/protobuf/wrappers.proto", +] + +genrule( + name = "link_headers", + outs = HEADERS, + cmd = """ + for i in $(OUTS); do + f=$${i#$(@D)/} + mkdir -p $(@D)/$${f%/*} + ln -sf $(INCLUDEDIR)/$$f $(@D)/$$f + done + """, +) + +cc_library( + name = "protobuf", + hdrs = HEADERS, + linkopts = ["-lprotobuf"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "protobuf_headers", + hdrs = HEADERS, + linkopts = ["-lprotobuf"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "protoc_lib", + linkopts = ["-lprotoc"], + visibility = ["//visibility:public"], +) + +genrule( + name = "protoc", + outs = ["protoc.bin"], + cmd = "which protoc; pwd; ln -s $$(which protoc) $@", + executable = 1, + visibility = ["//visibility:public"], +) + +cc_proto_library( + name = "cc_wkt_protos", + hdrs = HEADERS, + internal_bootstrap_hack = 1, + protoc = ":protoc", + visibility = ["//visibility:public"], +) + +proto_gen( + name = "protobuf_python_genproto", + includes = ["."], + protoc = "@com_google_protobuf//:protoc", + visibility = ["//visibility:public"], +) + +py_library( + name = "protobuf_python", + data = [":link_headers"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/protobuf.bzl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/protobuf.bzl new file mode 100644 index 000000000..367ac2863 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/protobuf.bzl @@ -0,0 +1,430 @@ +def _GetPath(ctx, path): + if ctx.label.workspace_root: + return ctx.label.workspace_root + "/" + path + else: + return path + +def _IsNewExternal(ctx): + # Bazel 0.4.4 and older have genfiles paths that look like: + # bazel-out/local-fastbuild/genfiles/external/repo/foo + # After the exec root rearrangement, they look like: + # ../repo/bazel-out/local-fastbuild/genfiles/foo + return ctx.label.workspace_root.startswith("../") + +def _GenDir(ctx): + if _IsNewExternal(ctx): + # We are using the fact that Bazel 0.4.4+ provides repository-relative paths + # for ctx.genfiles_dir. + return ctx.genfiles_dir.path + ( + "/" + ctx.attr.includes[0] if ctx.attr.includes and ctx.attr.includes[0] else "" + ) + + # This means that we're either in the old version OR the new version in the local repo. + # Either way, appending the source path to the genfiles dir works. + return ctx.var["GENDIR"] + "/" + _SourceDir(ctx) + +def _SourceDir(ctx): + if not ctx.attr.includes: + return ctx.label.workspace_root + if not ctx.attr.includes[0]: + return _GetPath(ctx, ctx.label.package) + if not ctx.label.package: + return _GetPath(ctx, ctx.attr.includes[0]) + return _GetPath(ctx, ctx.label.package + "/" + ctx.attr.includes[0]) + +def _CcHdrs(srcs, use_grpc_plugin = False): + ret = [s[:-len(".proto")] + ".pb.h" for s in srcs] + if use_grpc_plugin: + ret += [s[:-len(".proto")] + ".grpc.pb.h" for s in srcs] + return ret + +def _CcSrcs(srcs, use_grpc_plugin = False): + ret = [s[:-len(".proto")] + ".pb.cc" for s in srcs] + if use_grpc_plugin: + ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs] + return ret + +def _CcOuts(srcs, use_grpc_plugin = False): + return _CcHdrs(srcs, use_grpc_plugin) + _CcSrcs(srcs, use_grpc_plugin) + +def _PyOuts(srcs, use_grpc_plugin = False): + ret = [s[:-len(".proto")] + "_pb2.py" for s in srcs] + if use_grpc_plugin: + ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs] + return ret + +def _RelativeOutputPath(path, include, dest = ""): + if include == None: + return path + + if not path.startswith(include): + fail("Include path %s isn't part of the path %s." % (include, path)) + + if include and include[-1] != "/": + include = include + "/" + if dest and dest[-1] != "/": + dest = dest + "/" + + path = path[len(include):] + return dest + path + +def _proto_gen_impl(ctx): + """General implementation for generating protos""" + srcs = ctx.files.srcs + deps = [] + deps += ctx.files.srcs + source_dir = _SourceDir(ctx) + gen_dir = _GenDir(ctx) + if source_dir: + import_flags = ["-I" + source_dir, "-I" + gen_dir] + else: + import_flags = ["-I."] + + for dep in ctx.attr.deps: + import_flags += dep.proto.import_flags + deps += dep.proto.deps + import_flags = depset(import_flags).to_list() + deps = depset(deps).to_list() + + args = [] + if ctx.attr.gen_cc: + args += ["--cpp_out=" + gen_dir] + if ctx.attr.gen_py: + args += ["--python_out=" + gen_dir] + + inputs = srcs + deps + tools = [ctx.executable.protoc] + if ctx.executable.plugin: + plugin = ctx.executable.plugin + lang = ctx.attr.plugin_language + if not lang and plugin.basename.startswith("protoc-gen-"): + lang = plugin.basename[len("protoc-gen-"):] + if not lang: + fail("cannot infer the target language of plugin", "plugin_language") + + outdir = gen_dir + if ctx.attr.plugin_options: + outdir = ",".join(ctx.attr.plugin_options) + ":" + outdir + args += ["--plugin=protoc-gen-%s=%s" % (lang, plugin.path)] + args += ["--%s_out=%s" % (lang, outdir)] + tools.append(plugin) + + if args: + ctx.actions.run( + inputs = inputs, + outputs = ctx.outputs.outs, + arguments = args + import_flags + [s.path for s in srcs], + executable = ctx.executable.protoc, + mnemonic = "ProtoCompile", + tools = tools, + use_default_shell_env = True, + ) + + return struct( + proto = struct( + srcs = srcs, + import_flags = import_flags, + deps = deps, + ), + ) + +proto_gen = rule( + attrs = { + "srcs": attr.label_list(allow_files = True), + "deps": attr.label_list(providers = ["proto"]), + "includes": attr.string_list(), + "protoc": attr.label( + cfg = "host", + executable = True, + allow_single_file = True, + mandatory = True, + ), + "plugin": attr.label( + cfg = "host", + allow_files = True, + executable = True, + ), + "plugin_language": attr.string(), + "plugin_options": attr.string_list(), + "gen_cc": attr.bool(), + "gen_py": attr.bool(), + "outs": attr.output_list(), + }, + output_to_genfiles = True, + implementation = _proto_gen_impl, +) +"""Generates codes from Protocol Buffers definitions. + +This rule helps you to implement Skylark macros specific to the target +language. You should prefer more specific `cc_proto_library `, +`py_proto_library` and others unless you are adding such wrapper macros. + +Args: + srcs: Protocol Buffers definition files (.proto) to run the protocol compiler + against. + deps: a list of dependency labels; must be other proto libraries. + includes: a list of include paths to .proto files. + protoc: the label of the protocol compiler to generate the sources. + plugin: the label of the protocol compiler plugin to be passed to the protocol + compiler. + plugin_language: the language of the generated sources + plugin_options: a list of options to be passed to the plugin + gen_cc: generates C++ sources in addition to the ones from the plugin. + gen_py: generates Python sources in addition to the ones from the plugin. + outs: a list of labels of the expected outputs from the protocol compiler. +""" + +def cc_proto_library( + name, + srcs = [], + deps = [], + cc_libs = [], + include = None, + protoc = "@com_google_protobuf//:protoc", + internal_bootstrap_hack = False, + use_grpc_plugin = False, + default_runtime = "@com_google_protobuf//:protobuf", + **kargs): + """Bazel rule to create a C++ protobuf library from proto source files + + NOTE: the rule is only an internal workaround to generate protos. The + interface may change and the rule may be removed when bazel has introduced + the native rule. + + Args: + name: the name of the cc_proto_library. + srcs: the .proto files of the cc_proto_library. + deps: a list of dependency labels; must be cc_proto_library. + cc_libs: a list of other cc_library targets depended by the generated + cc_library. + include: a string indicating the include path of the .proto files. + protoc: the label of the protocol compiler to generate the sources. + internal_bootstrap_hack: a flag indicate the cc_proto_library is used only + for bootstraping. When it is set to True, no files will be generated. + The rule will simply be a provider for .proto files, so that other + cc_proto_library can depend on it. + use_grpc_plugin: a flag to indicate whether to call the grpc C++ plugin + when processing the proto files. + default_runtime: the implicitly default runtime which will be depended on by + the generated cc_library target. + **kargs: other keyword arguments that are passed to cc_library. + + """ + + includes = [] + if include != None: + includes = [include] + + if internal_bootstrap_hack: + # For pre-checked-in generated files, we add the internal_bootstrap_hack + # which will skip the codegen action. + proto_gen( + name = name + "_genproto", + srcs = srcs, + deps = [s + "_genproto" for s in deps], + includes = includes, + protoc = protoc, + visibility = ["//visibility:public"], + ) + + # An empty cc_library to make rule dependency consistent. + native.cc_library( + name = name, + **kargs + ) + return + + grpc_cpp_plugin = None + if use_grpc_plugin: + grpc_cpp_plugin = "//external:grpc_cpp_plugin" + + gen_srcs = _CcSrcs(srcs, use_grpc_plugin) + gen_hdrs = _CcHdrs(srcs, use_grpc_plugin) + outs = gen_srcs + gen_hdrs + + proto_gen( + name = name + "_genproto", + srcs = srcs, + deps = [s + "_genproto" for s in deps], + includes = includes, + protoc = protoc, + plugin = grpc_cpp_plugin, + plugin_language = "grpc", + gen_cc = 1, + outs = outs, + visibility = ["//visibility:public"], + ) + + if default_runtime and not default_runtime in cc_libs: + cc_libs = cc_libs + [default_runtime] + if use_grpc_plugin: + cc_libs = cc_libs + ["//external:grpc_lib"] + + native.cc_library( + name = name, + srcs = gen_srcs, + hdrs = gen_hdrs, + deps = cc_libs + deps, + includes = includes, + alwayslink = 1, + **kargs + ) + +def internal_gen_well_known_protos_java(srcs): + """Bazel rule to generate the gen_well_known_protos_java genrule + + Args: + srcs: the well known protos + """ + root = Label("%s//protobuf_java" % (native.repository_name())).workspace_root + pkg = native.package_name() + "/" if native.package_name() else "" + if root == "": + include = " -I%ssrc " % pkg + else: + include = " -I%s/%ssrc " % (root, pkg) + native.genrule( + name = "gen_well_known_protos_java", + srcs = srcs, + outs = [ + "wellknown.srcjar", + ], + cmd = "$(location :protoc) --java_out=$(@D)/wellknown.jar" + + " %s $(SRCS) " % include + + " && mv $(@D)/wellknown.jar $(@D)/wellknown.srcjar", + tools = [":protoc"], + ) + +def internal_copied_filegroup(name, srcs, strip_prefix, dest, **kwargs): + """Macro to copy files to a different directory and then create a filegroup. + + This is used by the //:protobuf_python py_proto_library target to work around + an issue caused by Python source files that are part of the same Python + package being in separate directories. + + Args: + srcs: The source files to copy and add to the filegroup. + strip_prefix: Path to the root of the files to copy. + dest: The directory to copy the source files into. + **kwargs: extra arguments that will be passesd to the filegroup. + """ + outs = [_RelativeOutputPath(s, strip_prefix, dest) for s in srcs] + + native.genrule( + name = name + "_genrule", + srcs = srcs, + outs = outs, + cmd = " && ".join( + ["cp $(location %s) $(location %s)" % + (s, _RelativeOutputPath(s, strip_prefix, dest)) for s in srcs], + ), + ) + + native.filegroup( + name = name, + srcs = outs, + **kwargs + ) + +def py_proto_library( + name, + srcs = [], + deps = [], + py_libs = [], + py_extra_srcs = [], + include = None, + default_runtime = "@com_google_protobuf//:protobuf_python", + protoc = "@com_google_protobuf//:protoc", + use_grpc_plugin = False, + **kargs): + """Bazel rule to create a Python protobuf library from proto source files + + NOTE: the rule is only an internal workaround to generate protos. The + interface may change and the rule may be removed when bazel has introduced + the native rule. + + Args: + name: the name of the py_proto_library. + srcs: the .proto files of the py_proto_library. + deps: a list of dependency labels; must be py_proto_library. + py_libs: a list of other py_library targets depended by the generated + py_library. + py_extra_srcs: extra source files that will be added to the output + py_library. This attribute is used for internal bootstrapping. + include: a string indicating the include path of the .proto files. + default_runtime: the implicitly default runtime which will be depended on by + the generated py_library target. + protoc: the label of the protocol compiler to generate the sources. + use_grpc_plugin: a flag to indicate whether to call the Python C++ plugin + when processing the proto files. + **kargs: other keyword arguments that are passed to cc_library. + + """ + outs = _PyOuts(srcs, use_grpc_plugin) + + includes = [] + if include != None: + includes = [include] + + grpc_python_plugin = None + if use_grpc_plugin: + grpc_python_plugin = "//external:grpc_python_plugin" + # Note: Generated grpc code depends on Python grpc module. This dependency + # is not explicitly listed in py_libs. Instead, host system is assumed to + # have grpc installed. + + proto_gen( + name = name + "_genproto", + srcs = srcs, + deps = [s + "_genproto" for s in deps], + includes = includes, + protoc = protoc, + gen_py = 1, + outs = outs, + visibility = ["//visibility:public"], + plugin = grpc_python_plugin, + plugin_language = "grpc", + ) + + if default_runtime and not default_runtime in py_libs + deps: + py_libs = py_libs + [default_runtime] + + native.py_library( + name = name, + srcs = outs + py_extra_srcs, + deps = py_libs + deps, + imports = includes, + **kargs + ) + +def internal_protobuf_py_tests( + name, + modules = [], + **kargs): + """Bazel rules to create batch tests for protobuf internal. + + Args: + name: the name of the rule. + modules: a list of modules for tests. The macro will create a py_test for + each of the parameter with the source "google/protobuf/%s.py" + kargs: extra parameters that will be passed into the py_test. + + """ + for m in modules: + s = "python/google/protobuf/internal/%s.py" % m + native.py_test( + name = "py_%s" % m, + srcs = [s], + main = s, + **kargs + ) + +def check_protobuf_required_bazel_version(): + """For WORKSPACE files, to check the installed version of bazel. + + This ensures bazel supports our approach to proto_library() depending on a + copied filegroup. (Fixed in bazel 0.5.4) + """ + expected = apple_common.dotted_version("0.5.4") + current = apple_common.dotted_version(native.bazel_version) + if current.compare_to(expected) < 0: + fail("Bazel must be newer than 0.5.4") diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/syslibs_configure.bzl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/syslibs_configure.bzl new file mode 100644 index 000000000..a2d4123fd --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/systemlibs/syslibs_configure.bzl @@ -0,0 +1,171 @@ +# -*- Python -*- +"""Repository rule for system library autoconfiguration. + +`syslibs_configure` depends on the following environment variables: + + * `TF_SYSTEM_LIBS`: list of third party dependencies that should use + the system version instead +""" + +_TF_SYSTEM_LIBS = "TF_SYSTEM_LIBS" + +VALID_LIBS = [ + "absl_py", + "astor_archive", + "boringssl", + "com_github_googleapis_googleapis", + "com_github_googlecloudplatform_google_cloud_cpp", + "com_google_protobuf", + "com_google_protobuf_cc", + "com_googlesource_code_re2", + "curl", + "cython", + "double_conversion", + "flatbuffers", + "gast_archive", + "gif_archive", + "grpc", + "hwloc", + "icu", + "jpeg", + "jsoncpp_git", + "keras_applications_archive", + "lmdb", + "nasm", + "nsync", + "org_sqlite", + "pasta", + "pcre", + "png_archive", + "protobuf_archive", + "six_archive", + "snappy", + "swig", + "termcolor_archive", + "wrapt", + "zlib_archive", +] + +def auto_configure_fail(msg): + """Output failure message when syslibs configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("\n%sSystem Library Configuration Error:%s %s\n" % (red, no_color, msg)) + +def _is_windows(repository_ctx): + """Returns true if the host operating system is windows.""" + os_name = repository_ctx.os.name.lower() + if os_name.find("windows") != -1: + return True + return False + +def _enable_syslibs(repository_ctx): + s = repository_ctx.os.environ.get(_TF_SYSTEM_LIBS, "").strip() + if not _is_windows(repository_ctx) and s != None and s != "": + return True + return False + +def _get_system_lib_list(repository_ctx): + """Gets the list of deps that should use the system lib. + + Args: + repository_ctx: The repository context. + + Returns: + A string version of a python list + """ + if _TF_SYSTEM_LIBS not in repository_ctx.os.environ: + return [] + + libenv = repository_ctx.os.environ[_TF_SYSTEM_LIBS].strip() + libs = [] + + for lib in list(libenv.split(",")): + lib = lib.strip() + if lib == "": + continue + if lib not in VALID_LIBS: + auto_configure_fail("Invalid system lib set: %s" % lib) + return [] + libs.append(lib) + + return libs + +def _format_system_lib_list(repository_ctx): + """Formats the list of deps that should use the system lib. + + Args: + repository_ctx: The repository context. + + Returns: + A list of the names of deps that should use the system lib. + """ + libs = _get_system_lib_list(repository_ctx) + ret = "" + for lib in libs: + ret += "'%s',\n" % lib + + return ret + +def _tpl(repository_ctx, tpl, substitutions = {}, out = None): + if not out: + out = tpl.replace(":", "") + repository_ctx.template( + out, + Label("//third_party/systemlibs%s.tpl" % tpl), + substitutions, + False, + ) + +def _create_dummy_repository(repository_ctx): + """Creates the dummy repository to build with all bundled libraries.""" + + _tpl(repository_ctx, ":BUILD") + _tpl( + repository_ctx, + ":build_defs.bzl", + { + "%{syslibs_enabled}": "False", + "%{syslibs_list}": "", + }, + ) + +def _create_local_repository(repository_ctx): + """Creates the repository to build with system libraries.""" + + _tpl(repository_ctx, ":BUILD") + _tpl( + repository_ctx, + ":build_defs.bzl", + { + "%{syslibs_enabled}": "True", + "%{syslibs_list}": _format_system_lib_list(repository_ctx), + }, + ) + +def _syslibs_autoconf_impl(repository_ctx): + """Implementation of the syslibs_configure repository rule.""" + if not _enable_syslibs(repository_ctx): + _create_dummy_repository(repository_ctx) + else: + _create_local_repository(repository_ctx) + +syslibs_configure = repository_rule( + implementation = _syslibs_autoconf_impl, + environ = [ + _TF_SYSTEM_LIBS, + ], +) + +"""Configures the build to link to system libraries +instead of using bundled versions. + +Add the following to your WORKSPACE FILE: + +```python +syslibs_configure(name = "local_config_syslibs") +``` + +Args: + name: A unique name for this workspace rule. +""" diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/tf_dependency/BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/tf_dependency/BUILD new file mode 100644 index 000000000..e69de29bb diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/tf_dependency/BUILD.tpl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/tf_dependency/BUILD.tpl new file mode 100644 index 000000000..c9c3d41ac --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/tf_dependency/BUILD.tpl @@ -0,0 +1,10 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "tf_header_lib", + hdrs = [":tf_header_include"], + includes = ["include"], + visibility = ["//visibility:public"], +) + +%{TF_HEADER_GENRULE} diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/version_check.bzl b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/version_check.bzl new file mode 100644 index 000000000..74feaa19f --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/version_check.bzl @@ -0,0 +1,52 @@ +""" Helpers to check minimum version of bazel.""" + +def _extract_version_number(bazel_version): + """Extracts the semantic version number from a version string + + Args: + bazel_version: the version string that begins with the semantic version + e.g. "1.2.3rc1 abc1234" where "abc1234" is a commit hash. + + Returns: + The semantic version string, like "1.2.3". + """ + for i in range(len(bazel_version)): + c = bazel_version[i] + if not (c.isdigit() or c == "."): + return bazel_version[:i] + return bazel_version + +# Parse the bazel version string from `native.bazel_version`. +# e.g. +# "0.10.0rc1 abc123d" => (0, 10, 0) +# "0.3.0" => (0, 3, 0) +def _parse_bazel_version(bazel_version): + """Parses a version string into a 3-tuple of ints + + int tuples can be compared directly using binary operators (<, >). + + Args: + bazel_version: the Bazel version string + + Returns: + An int 3-tuple of a (major, minor, patch) version. + """ + + version = _extract_version_number(bazel_version) + return tuple([int(n) for n in version.split(".")]) + +def check_bazel_version_at_least(minimum_bazel_version): + if "bazel_version" not in dir(native): + fail("\nCurrent Bazel version is lower than 0.2.1, expected at least %s\n" % minimum_bazel_version) + elif not native.bazel_version: + print("\nCurrent Bazel is not a release version, cannot check for compatibility.") + print("Make sure that you are running at least Bazel %s.\n" % minimum_bazel_version) + return + + if _parse_bazel_version(native.bazel_version) < _parse_bazel_version(minimum_bazel_version): + fail("\nCurrent Bazel version is {}, expected at least {}\n".format( + native.bazel_version, + minimum_bazel_version, + )) + +parse_bazel_version = _parse_bazel_version diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/zlib.BUILD b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/zlib.BUILD new file mode 100644 index 000000000..e35d02812 --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/sample/third_party/zlib.BUILD @@ -0,0 +1,40 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # BSD/MIT-like license (for zlib) + +cc_library( + name = "zlib", + srcs = [ + "adler32.c", + "compress.c", + "crc32.c", + "crc32.h", + "deflate.c", + "deflate.h", + "gzclose.c", + "gzguts.h", + "gzlib.c", + "gzread.c", + "gzwrite.c", + "infback.c", + "inffast.c", + "inffast.h", + "inffixed.h", + "inflate.c", + "inflate.h", + "inftrees.c", + "inftrees.h", + "trees.c", + "trees.h", + "uncompr.c", + "zconf.h", + "zutil.c", + "zutil.h", + ], + hdrs = ["zlib.h"], + copts = [ + "-Wno-shift-negative-value", + "-DZ_HAVE_UNISTD_H", + ], + includes = ["."], +) diff --git a/rfcs/20200624-pluggable-device-for-tensorflow/tutorial.md b/rfcs/20200624-pluggable-device-for-tensorflow/tutorial.md new file mode 100644 index 000000000..3e132b29a --- /dev/null +++ b/rfcs/20200624-pluggable-device-for-tensorflow/tutorial.md @@ -0,0 +1,1000 @@ +# Tutorial: How to create a TensorFlow plugin +1. [Introduction](#Introduction) + +2. [Getting started](#Getting-started) + + 1. [Plugin Implementation](#Plugin-Implementation) + + 1). [Device Runtime](#Device-Runtime) + + 2). [Kernels/Ops](#Kernels/Ops) + + 3). [Graph optimization](#Graph-optimization) + + 4). [Profiler](#Profiler) + + 2. [Plugin build](#Plugin-build) + + 3. [Plugin installation](#[Plugin-installation) + + 4. [Plugin Running](#Plugin-Running) + +# **Introduction** + +This tutorial is intended for those developers who want to extend TensorFlow to support a new device for the current TensorFlow runtime stack through the Modular TensorFlow interface. Plugin provides a decoupled way to add a new device to TensorFlow and has benefits: + + - Simpler process: Does not have to add a new build toolchain to TensorFlow + + - Faster time-to-solution: Does not need code review from the TensorFlow team. + + - Lower maintenance efforts: Only C-API-related changes could break the integration. Unrelated TensorFlow changes would not break the code. + +The article describes how to implement, build, install and run the plugin. The plugin implementation section covers device runtime registration, kernel registration, graph optimizer registration as well as profiler registration. + +Developers are also recommended to read the Modular TensorFlow design RFC to have a better understanding of the whole architecture. + +* [Modular TensorFlow](https://github.com/tensorflow/community/blob/master/rfcs/20190305-modular-tensorflow.md) + +* [Kernel and Op Implementation and Registration API](https://github.com/tensorflow/community/blob/master/rfcs/20190814-kernel-and-op-registration.md) + +* [StreamExecutor C API](https://github.com/tensorflow/community/blob/master/rfcs/20200612-stream-executor-c-api.md) + +* [Adding Pluggable Device for TensorFlow](https://github.com/tensorflow/community/blob/master/rfcs/20200624-pluggable-device-for-tensorflow.md) + +* [Modular TensorFlow Graph C API](https://github.com/tensorflow/community/blob/master/rfcs/20201027-modular-tensorflow-graph-c-api.md) + +* [Modular TensorFlow profiler C API](https://github.com/jzhoulon/community/blob/32df34e83472a6f6eb0655b5822afc498da49dd9/rfcs/20210513-pluggable-profiler-for-tensorflow.md) + +The build environment in this tutorial is based on Linux, however, it is also expected to work on other OS(Windows, MacOS, etc). + +# **Getting started** + +In this section, you will learn how to implement, build, install, and run a plugin. + +## **Plugin Implementation** + +Modular TensorFlow provides a set of C API as an ABI-stable way to register a custom device runtime, kernels/ops, graph optimizer and profiler. This will simplify the distribution of plugins and allow plugin authors to distribute binary artifacts without necessarily publishing plugin source code. + +
+ +
+ +We anticipate three basic functionalities within a device plug-in module: device runtime, kernel/op, graph optimize and profiler. + +### **Device Runtime** + +StreamExecutor is TensorFlow’s main device manager, responsible for work execution and memory management. It provides a set of methods (such as Memcpy) that can be customized for a particular device. Modular TensorFlow proposed a C API wrapper of a subset of methods in StreamExecutorInterface as an ABI-stable way to register a custom StreamExecutor platform. The API can be found in[ tensorflow/c/experimental/stream_executor/stream_executor.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/experimental/stream_executor/stream_executor.h). Plugins need to implement those interfaces declared in this file. + +Here we will introduce how to register a device runtime through StreamExecutor C API. Before that, we will have some conventions: + +* Struct defined in StreamExecutor C API: struct prefix indicates whether fields should be filled by the plugin or core implementation + + * SE_: set/filled by core unless explicitly marked otherwise. + + * SP_: set/filled by plugin unless explicitly marked otherwise. + +* Struct with Plugin prefix: these are structs defined in plugin, plugin can choose whatever name/definition they want. + +* Function with plugin_ prefix: these are functions defined in plugin, plugin can choose whatever function name they want. + +§ **SE_InitPlugin** + +Plugins need to define `SE_InitPlugin` function and populates `SE_PlatformRegistrationParams::SP_Platform` and `SE_PlatformRegistrationParams::SP_PlatformFns`.It is the entry point to initialize the plugin device runtime. When this plugin is loaded by TF(the loading procedure is transparent to user, it will automatically loaded by TF as long as the library installed in site-packages/tensorflow/python/tensorflow-plugins/), `SE_InitPlugin` method will be invoked and a new StreamExecutor platform will be created and registered by Core TensorFlow. + +Example: +```c++ +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status) { + std::string type = "MyDevice"; // It is device's type, such as GPU, APU, which is visible in the python front-end. + std::string name = "MyPlatform"; // it is SE platform's name, such as CUDA, ROCM. + // Sets struct_size to a valid value, and zero initializes other attributes. + params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE; + params->platform->type = type.c_str(); + params->platform->name = name.c_str(); + params->platform_fns->get_device_count = plugin_get_device_count; + params->platform_fns->create_device = plugin_create_device; + params->platform_fns->destroy_device = plugin_destroy_device; + params->platform_fns->create_stream_executor = plugin_create_stream_executor; + params->platform_fns->destroy_stream_executor = plugin_destroy_stream_executor; + params->platform_fns->create_timer_fns = plugin_create_timer_fns; + params->platform_fns->destroy_timer_fns = plugin_destroy_timer_fns; + params->destroy_platform = plugin_destroy_platform; + params->destroy_platform_fns = plugin_destroy_platform_fns; +} +``` +As you may see in the example, the plugin needs to populate the platform and platform_fns. + +* `platform->struct_size`: plugin needs to set it as `SP_PLATFORM_STRUCT_SIZE` (defined in stream_executor.h). This field is for the StreamExecutor C API version check between Core TensorFlow and the plugin. + +* `platform->type`: This field allows plugin authors to register a new device type to the Core TensorFlow, such as GPU, APU..,this device type will be visible in the python front-end, for example, user can assign the graph to "device type" through `with tf.device("device type")`. + +* `platform->name`: This field allows plugin authors to register a new StreamExecutor platform name to the Core TensorFlow, such as CUDA, ROCM, this name is not visible in python front-end. Note: this name should be a unique name, you can’t choose a name like "CUDA", “ROCM” which are first party platform names. + +* `platform_fns->get_device_count`: a callback for querying the number of physical devices discovered by the plugin's device runtime. +```c++ +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +void plugin_get_device_count() { + int device_count; + pluginGetDeviceCount(&device_count); + return device_count; +} +``` +* `platform_fns->create_device`: a callback for creating `SP_Device`. plugin authors need to define the function that populate the `SP_Device`: +```c++ +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +void plugin_create_device(const SP_Platform* platform, + SE_CreateDeviceParams* params, TF_Status* const status) { + params->device->struct_size = SP_DEVICE_STRUCT_SIZE; + PluginDeviceHandle* device_h; + plugin_get_device(&device_h, params->device->ordinal); + params->device->device_handle = static_cast(device_h); + params->device->ordinal = params->ordinal; +} +``` +* `platform_fns->destroy_device`: a callback for destroying `SP_Device`. plugin authors need to define the function that to destroy the `SP_Device`: +```c++ +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +void plugin_destroy_device(const SP_Platform* platform, SP_Device* device) { + device->device_handle = nullptr; + device->ordinal = -1; +} +``` +* `platform_fns->create_stream_executor`: a callback for creating `SP_StreamExecutor`. plugin authors need to define a function that populates `SP_StreamExecutor`. +```c++ +void plugin_create_stream_executor(const SP_Platform* platform, + SE_CreateStreamExecutorParams* params, + TF_Status* const status) { + params->stream_executor->struct_size = SP_STREAMEXECUTOR_STRUCT_SIZE; + params->stream_executor->allocate = plugin_allocate; + params->stream_executor->deallocate = plugin_deallocate; + params->stream_executor->host_memory_allocate= plugin_host_memory_allocate; + params->stream_executor->host_memory_deallocate = plugin_host_memory_deallocate; + params->stream_executor->get_allocator_stats = plugin_get_allocator_stats; + params->stream_executor->device_memory_usage = plugin_device_memory_usage; + params->stream_executor->create_stream = plugin_create_stream; + params->stream_executor->destroy_stream = plugin_destroy_stream; + params->stream_executor->create_stream_dependency = plugin_create_stream_dependency; + params->stream_executor->get_stream_status = plugin_get_stream_status; + params->stream_executor->create_event = plugin_create_event; + params->stream_executor->destroy_event = plugin_destroy_event; + params->stream_executor->get_event_status = plugin_get_event_status; + params->stream_executor->record_event = plugin_record_event; + params->stream_executor->wait_for_event = plugin_wait_for_event; + params->stream_executor->create_timer = plugin_create_timer; + params->stream_executor->destroy_timer = plugin_destroy_timer; + params->stream_executor->start_timer = plugin_start_timer; + params->stream_executor->stop_timer = plugin_stop_timer; + params->stream_executor->memcpy_dtoh = plugin_memcpy_dtoh; + params->stream_executor->memcpy_htod = plugin_memcpy_htod; + params->stream_executor->memcpy_dtod = plugin_memcpy_dtod; + ... ... +} +``` +plugin authors need to populate all fields in `SP_StreamExecutor`. For example, registering an allocation function with `plugin_allocate`, it synchronously allocates 'size' of bytes on the underlying platform and returns `SP_DeviceMemoryBase` representing that allocation. +```c++ +/*StreamExecutor Backend Impl*/ + +void plugin_allocate(const SP_Device* device, uint64_t size, int64_t memory_space, + SP_DeviceMemoryBase* mem) { + PluginDevice* device_handle = static_cast(device->device_handle); + mem->struct_size = SP_DEVICE_MEMORY_BASE_STRUCT_SIZE; + mem->opaque = plugin_malloc(device_handle, size); + mem->size = size; +} +``` +If the backend doesn't support this functionality, plugin authors can provide a dummy function + +* `platform_fns->destroy_stream_executor`: clean up fields inside `SP_StreamExecutor` that were allocated by the plugin. `stream_executor` itself should not be deleted here. +```c++ +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +void plugin_destroy_stream_executor(const SP_Platform* platform, + SP_StreamExecutor* stream_executor) { + stream_executor->allocate = nullptr; + stream_executor->deallocate = nullptr; + stream_executor->host_memory_allocate = nullptr; + stream_executor->host_memory_deallocate = nullptr; + stream_executor->get_allocator_stats = nullptr; + stream_executor->device_memory_usage = nullptr; + ... ... +} +``` +* `platform_fns-> create_timer_fns`: creating `SP_Timer`. Allocates timer resources on the underlying platform and initializes its internals, setting 'timer' output variable. You can provide a dummy function if you don’t need this. + +* `platform_fns->destroy_timer_fns`: destroy `SP_Timer` and deallocate timer resources on the underlying platform. You can provide a dummy implementation if you don't need this. + +* `platform_fns->destroy_platform`: clean up fields inside `SP_Platform` that were allocated by the plugin. platform itself should not be deleted here. + +* `platform_fns->destroy_platform_fns`: clean up fields inside `SP_PlatformFns`. + +### **Kernels/Ops** + +Modular TensorFlow provides a set of C APIs as the ABI-stable API for implementing kernels and ops. The intention is that existing kernels should be able to be ported to the new APIs with a minimum of reimplementation effort. The ops C API can be found in[ tensorflow/c/ops.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/ops.h) and kernels C API can be found in[ tensorflow/c/kernels.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/kernels.h).[ tensorflow/c/tf_tensor.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_tensor.h),[ tensorflow/c/tf_status.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_status.h). + +Plugin authors need to define the `TF_InitKernel` function (include Ops/Kernels registration). When the plugin is loaded by TF at runtime, `TF_InitKernel` method will be called and new Ops/Kernels will be registered to Core TensorFlow. + +§ **Ops registration** + +This section introduces how to register a new op to Core TensorFlow. In the C++ API, ops are registered at static initialization time using the `REGISTER_OP` macro. For example: +```c++ +REGISTER_OP("Bitcast") +.Input("input: T") +.Output("output: type") +.Attr("T: {bfloat16, ...}") +.Attr("type: {bfloat16, ...}") +.SetShapeFn([](InferenceContext* ctx) { ... }) +.Doc("A bitcast operator"); +``` +The equivalent C API will be a series of functions that operate on `TF_OpDefinitionBuilder*`, a pointer to an opaque struct (i.e. a struct whose content is not made known to the plugin authors). The functions include, but not limited to: + +* `TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name)`: constructs and returns a new op registration builder for an op with the given name. + +* `void TF_OpDefinitionBuilderAddAttr(TF_OpDefinitionBuilder* builder, const char* attr)`: adds the given attribute to the builder(equivalent to Attr above). + +* `void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder, const char* input)`: adds the given input to the builder(equivalent to Input above). + +Additional functions are provided for setting other properties of the operation (e.g. `TF_OpDefinitionBuilderSetIsCommutative`). + +Registration is then actually performed using the `TF_RegisterOpDefinition` function. This function populates a `TF_Status` indicating whether registration was successful and frees the resources associated with the op definition builder. + +The C equivalent of the bitcast op registration example above is shown below: +```c++ +#include "tensorflow/c/ops.h" +#include "tensorflow/c/kernels.h" + +void InferBitcastShape(TF_ShapeInferenceContext* ctx, // see the section below on + TF_Status* status); // shape inference + +void PluginRegisterBitCastOp() { + TF_OpDefinitionBuilder* b = TF_NewOpDefinitionBuilder("Bitcast"); + TF_OpDefinitionBuilderAddInput(b, "input: T"); + TF_OpDefinitionBuilderAddOutput(b, "output: type"); + TF_OpDefinitionBuilderAddAttr(b, "T: {bfloat16, ...}"); + TF_OpDefinitionBuilderAddAttr(b, "type: {bfloat16, ...}"); + TF_OpDefinitionBuilderSetShapeInferenceFunction(b, &InferBitcastShape); + TF_Status* status = TF_NewStatus(); + TF_RegisterOpDefinition(b, status); + if (TF_GetCode(status) != TF_OK) { /* handle errors */ } +} + +void TF_InitKernel() { + PluginRegisterBitCastOp(); +} +``` +§ **Ops shape inference** + +A significant feature of certain ops is their ability to infer their output shapes. TensorFlow will invoke the registered shape inference function (if one is provided) when it needs to know the op’s output shape. The registration function declaration is shown below: + +A series of functions prefixed with `TF_ShapeInferenceContext` is provided for the following purposes: + +* Examining operator input shapes (`TF_ShapeInferenceContextGetInput`). + +* Creating and deleting shape and dimension handles (`TF_{New,Delete}ShapeHandle`, `TF_{New,Delete}DimensionHandle`). + +* Manipulating shape and dimension handles (`TF_ShapeInferenceContextWithRank`, `TF_ShapeInferenceContextDim`). + +In general, C analogues to the C++ methods in `tensorflow::shape_inference` (see[ tensorflow/core/framework/shape_inference.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/shape_inference.h)) will be provided. + +§ **Kernels implementation and registration.** + +In this section, you will learn how to implement kernels and register them to Core TensorFlow. Here we will use Conv2D as the example. + +***Kernel Implementation*** + +The main classes for C++ kernel implementations are `OpKernelConstruction` (provided by TensorFlow to the kernel's constructor) and `OpKernelContext` (provided to the kernel's compute method). The analogues in the C API are `TF_OpKernelConstruction` and `TF_OpKernelContext`.The aim of the C API is providing functions for working with these structs that match, as closely as possible, the C++ API. +See below for an example of Conv2D kernel with the C++ API: +```c++ +struct Conv2DParameters { + std::vector dilations; + std::vector strides; + Padding padding; + std::vector explicit_paddings; +}; + +template +class Conv2DOp : public BinaryOp { +public: + explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp(context) {} + void Compute(OpKernelContext* context) override {} +private: + Conv2DParameters params_; +} +``` +Above code shows a prototype of Conv2D C++ kernel, basically we can find that it has a constructor, a compute function and a parameter struct. The C equivalent Conv2D op can be: +```c++ +#include "tensorflow/c/kernels.h" + +struct Conv2DParameters { + std::vector dilations; + std::vector strides; + Padding padding; + std::vector explicit_paddings; +}; + +typedef struct Conv2DOp{ + Conv2DParameters params_; +}; + +void* Conv2DOp_Create(Conv2DOp* kernel, TF_OpKernelConstruction* ctx); + +template +void Conv2DOp_Compute(void* kernel, TF_OpKernelContext* ctx); + +void Conv2DOp_Destroy(void* kernel) +``` +Usually, plugin authors need to provide three functions: a creation function, a compute function and a deletion function. Compute function is a must, creation function and deletion functions are optional but if a creation is provided that causes memory allocation, a deletion function that frees the memory should also be provided, otherwise a leak will occur. + +* **Creation function(optional)**: responsible for creating a kernel, allocating private resources (such as memory), and storing attributions (if it has) retrieved from `TF_OpKernelConstruction` to the kernel. Core TensorFlow will call this function when it needs to instantiate the kernel. The `TF_OpKernelConstruction` pointer is owned by TensorFlow and will be deleted once the creation function returns. + +* **Compute function**: responsible for retrieving inputs and a compute stream and producing outputs. Core TensorFlow will call this function when needed to perform a computation with this kernel. + +* **Destroy function(optional)**: responsible for destroying the kernel and free the resource allocated in the creation function. When TensorFlow no longer needs the kernel, it will call this function if one is provided. This function will retrieve the pointer returned in the creation function or nullptr if no creation function was provided. + +Here we will show how to use kernel C APIs to implement these functions: + + **Creation function** + +In the C++ API, kernel’s attributions are retrieved through the `GetAttr` method in `OpKernelConstruction`. +```c++ +explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp(context) { + TF_RETURN_IF_ERROR(context->GetAttr("dilations", ¶ms_.dilations)); + TF_RETURN_IF_ERROR(context->GetAttr("strides", ¶ms_.strides)); + TF_RETURN_IF_ERROR(context->GetAttr("padding", ¶ms_.padding)); + if (context->HasAttr("explicit_paddings")) { + TF_RETURN_IF_ERROR( + context->GetAttr("explicit_paddings", ¶ms_.explicit_paddings)); + } + ... ... +} +``` +Kernel C API provides a set of `TF_OpKernelConstruction_GetAttrXX` API to retrieve attributions from `TF_OpKernelConstruction`. These APIs can be separated into four categories according to the attribution’s container: + +1. Scalar + +`TF_OpKernelConstruction_GetAttr(Type, Float,Int32, Int64, Bool…)` interprets the named kernel construction attribute as scalar value and places it into *val, float for example: +```c++ +float value; +TF_OpKernelConstruction_GetAttrFloat(ctx, "float_attr", &val, status); +``` +2. Vector + +`TF_OpKernelConstruction_GetAttr(Type, Float, Int32, Int64, Bool…)List` interprets the named kernel construction as a (Type, Float, Int32, Int64, Bool) array and places it into *vals. vals must point to an array of length at least `max_values` (ideally set to the list_size from `TF_OpKernelConstruction_GetAttrSize()`). +```c++ +int32_t list_size = 0; +int32_t total_size = 0; +TF_OpKernelConstruction_GetAttrSize(ctx, "vector_float_attr", + &list_size, &total_size, status); +std::vector values(list_size); +TF_OpKernelConstruction_GetAttrFloatList(ctx, "vector_float_attr", + values.data(), list_size, status); +``` +3. String + +`TF_OpKernelConstruction_GetAttrString` interprets the named kernel construction attribute as string and places it into *val. vals must point to an array of length at least 'max_length' (ideally set to total_size from `TF_OpKernelConstruction_GetAttrSize()`). +``` +int32_t list_size = 0; +int32_t total_size = 0; +TF_OpKernelConstruction_GetAttrSize(ctx, "string_attr", &list_size, + &total_size, status); +std::vector val(total_size); +TF_OpKernelConstruction_GetAttrString(ctx, "string_attr", val.data(), + total_size, status); +std::string value = std::string(val.data(), total_size); +``` +4. Vector of strings + +`TF_OpKernelConstruction_GetAttrStringList` interprets the named kernel construction attribute as string array and fills in `vals` and `length`, each of which must point to an array of length at least `max_values`. The elements of values will point to addresses in `storage` which must be at least `storage_size` bytes in length. Ideally, `max_values` would be set to list_size and `storage` would be at least total_size, obtained from `TF_OpKernelConstruction_GetAttrSize()`. +```c++ +int32_t list_size = 0; +int32_t total_size = 0; +TF_OpKernelConstruction_GetAttrSize(ctx, "vector_string_attr", + &list_size, &total_size, status); +std::unique_ptr lens(new size_t[list_size]); +std::unique_ptr storage(new char[total_size]); +size_t storage_size(total_size); +TF_OpKernelConstruction_GetAttrStringList(ctx, "vector_string_attr", +reinterpret_cast(vals.get()), lens.get(),list_size, storage.get(), +storage_size, status); +for (size_t i = 0; i < list_size; ++i) { + (*value)[i] = string(static_cast(vals[i]), lens[i]); +} +``` +With these C APIs, we can retrieve Conv2D kernel's attributions from `TF_OpKernelConstruction`, see below for an example of creating a Conv2D kernel with C API. In this example, we use a series of C API for retrieving `std::vector`, `std::vector` and `std::string` attributions from `TF_OpKernelConstruction`. We also use a series of C APIs for error handling (`TF_NewStatus`, `TF_GetCode`, `TF_DeleteStatus`). +```c++ +void* Conv2D_Create(Conv2D* kernel, TF_OpKernelConstruction* ctx) { + auto* kernel = new Conv2DOp; + TF_Status* s = TF_NewStatus(); + // C++: context->GetAttr("dilations", ¶ms.dilations); + int32_t list_size = 0; + int32_t total_size = 0; + TF_OpKernelConstruction_GetAttrSize(ctx, "dilations", &list_size, &total_size, s); + if (TF_GetCode(s) == TF_OK) { + kernel->dilations_.resize(list_size); + TF_OpKernelConstruction_GetAttrInt32List(ctx, "dilations", kernel->dilations.data(), list_size, s); + } + + // C++: context->GetAttr("strides", ¶ms.strides); + if (TF_GetCode(s) == TF_OK) { + list_size = total_size = 0; + TF_OpKernelConstruction_GetAttrSize(ctx, "strides", &list_size, &total_size, s); + if (TF_GetCode(s) == TF_OK) { + kernel->strides_.resize(list_size); + TF_OpKernelConstruction_GetAttrInt32List(ctx, "strides", kernel->strides.data(), list_size, s); + } + } + + // C++: context->GetAttr("padding", ¶ms.padding) + if (TF_GetCode(s) == TF_OK) { + list_size = total_size = 0; + TF_OpKernelConstruction_GetAttrSize(ctx, "padding", &list_size, &total_size, s); + if (TF_GetCode(s) == TF_OK) { + std::vector val(total_size); + TF_OpKernelConstruction_GetAttrString(ctx, "padding", val.data(), total_size, s); + std::string padding_str = std::string(val.data(), total_size); + if (padding_str == "VALID") { + kernel->padding_ = Padding::VALID; + } elif(padding_str == "SAME") { + kernel->padding_ = Padding::SAME; + } elif(padding_str == "EXPLICIT") { + kernel->padding_ = Padding::EXPLICIT; + } + } + + } + + // C++: context->HasAttr("explicit_padding") + + if (TF_GetCode(s) == TF_OK) { + if (TF_OpKernelConstruction_HasAttr(ctx, "explicit_paddings", s)) { + list_size = total_size = 0; + TF_OpKernelConstruction_GetAttrSize(ctx, "explicit_paddings", &list_size, &total_size, s); + kernel->explicit_paddings_.resize(list_size); + TF_OpKernelConstruction_GetAttrInt64List(ctx, "explicit_paddings", kernel->explicit_paddings_.data(), list_size, s); + } + } + + if (TF_GetCode(s) != TF_OK) { + TF_OpKenrelConstruction_Failure(ctx, s); + delete kernel; + kernel = nullptr; + } + + TF_DeleteStatus(s); + return kernel; + +} +``` + **Compute function** + +Basically, compute functions are able to retrieve their input tensors and provide output tensors. In the C++ API, the `tensorflow::OpKernelContext::input` and `setoutput` family of functions provide this functionality. The equivalent C calls will be the `TF_GetInput` and `TF_SetOutput` family of functions. These C functions operate on `TF_Tensor`. Besides, the kernel C API provides `TF_GetStream()` for retrieving a computation stream, which allows kernels submitted to the hardware. + +In the C++ API, `OpKernelContext` provides a set of functions to retrieve input tensors, shapes, stream as well as allocate output tensors or forward input to output tensor. A simple Conv2D compute function with C++ API can be like: +```c++ +void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& filter = context->input(1); + Tensor* output = nullptr; + TensorShape out_shape = ComputeConv2DShape(params_, input, filter); + + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + gpuStream_t* stream = reinterpret_cast( + context->op_device_context()->stream()->implementation()->GpUStreamMemberHack()) + GpuLaunchKernel(conv_kernel, grid_dim, block_dim, 0, stream, input.data(), + filter.data(), output.data(), input_shape...) +} +``` +The equivalent OpKernelContext C functions provided by Modular TensorFlow are: + +* `TF_GetInput()`: retrieves the ith input from ctx. + +* `TF_NumInputs()`: returns the number of inputs available in ctx. + +* `TF_NumOutputs()`: returns the number of outputs to be placed in *ctx by the kernel. + +* `TF_SetOutput()`: Sets the ith output of ctx to tensor. + +* `TF_AllocateOutput()`: allocates Tensor for output at given index. + +* `TF_ForwardInputOrAllocateOutput()`: tries to forward one of the inputs given in input_indices to output[output_index]. + +* `TF_AllocateTmp()`: Allocates a temporary Tensor of the specified type and shape. + +* `TF_GetStream()`: returns the SP_Stream available in ctx.[tensorflow/c/tf_tensor.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_tensor.h) also provides some C API for manipulate TF_Tensor: + +* `TF_NewTensor()`: return a new tensor that holds the bytes data[0, len-1]; + +* `TF_DeleteTensor()`: destroy a tensor. + +* `TF_TensorType()`: return the type of a tensor element. + +* `TF_NumDims()`: return the number of dimensions that the tensor has. + +* `TF_Dim()`: return the length of the tensor in the "dim_index" dimension. + +* `TF_TensorByteSize()`: return the size of the underlying data in bytes. + +* `TF_TensorData()`: return a pointer to the underlying data buffer. + +* `TF_TensorElementCount()`: returns the number of elements in the tensor. + +* `TF_TensorBitcastFrom()`: copy the internal data representation of `from` to `to`. `new_dims` and `num_new_dims` specify the new shape of the `to` tensor, `type` specifies its data type. + +* `TF_TensorIsAligned()`: return bool if this tensor is aligned. + +**It should be noted that**: when you call functions that deal with `TF_Tensor` on `TF_OpKernelContext`, such as :`TF_GetInput`, `TF_AllocateOutput`, `TF_ForwardInputOrAllocateOutput`, `TF_AllocateTmp`, you are creating a new `TF_Tensor` indeed, so you need to call `TF_DeleteTensor()` to delete these `TF_Tensor` manually at the exit of compute function, or you will get mem leak since when creating `TF_Tensor` based on `tensorflow::Tensor` in `OpKernelContext`, it will increase the ref count in the C++ Tensor and the tensor will not be freed if these `TF_Tensors` are not deleted. + +With these C APIs, we can retrieve the input tensors and a computation stream, do the computation and then produce the output tensors. See below for an example of computing a Conv2D kernel, you may also notice that when the computation is finished, we need to delete the input, filter, output tensors manually. +```c++ +template +void Conv2D_Compute(void* kernel, TF_OpKernelContext* ctx) { + auto op_kernel = static_cast(kernel); + TF_Status* s = TF_NewStatus(); + auto stream = TF_GetStream(ctx, s); + if (TF_GetCode(s) != TF_OK) { + TF_OpKernelContext_Failure(ctx, s); + return; + } + TF_Tensor* input, filter; + TF_GetInput(ctx, 0, &input, s); + TF_GetInput(ctx, 1, &filter, s); + TF_Tensor* output = nullptr; + PluginTensorShape out_shape = ComputeConv2DShape(op_kernel->params_, input, filter); + + auto output_type = TF_ExpectedOutputDataType(ctx, 0); + output = TF_AllocateOutput(ctx, 0, static_cast(out_type), + shape.dims_size().data(), shape.dims(), shape.num_elements() * DataTypeSize(out_type), s); + plugin_launch_kernel(conv_kernel, stream, TF_TensorData(input), TF_TensorData(filter), + TF_TensorData(output), shape); + if (TF_GetCode(s) != TF_OK) { + TF_OpKernelContext_Failure(ctx, s); + } + TF_DeleteStatus(s); + TF_DeleteTensor(input); + TF_DeleteTensor(filter); + TF_DeleteTensor(output); +} +``` +**Destroy function** + +When Tensorflow no longer needs the kernel, it will call the destructor function in the OpKernel to release the resources created in the constructor. In plugin, we need to implement and register a destroy function to release those resources. +```c++ +void Conv2DOp_Destroy(void* kernel) { +if (kernel != nullptr) { + delete static_cast(kernel); +} +} +``` +* **Kernel Registration** + +After implementing a kernel, we need to register this kernel to the Core TensorFlow so that it can be dispatched at runtime. Kernel registration with the C++ API is accomplished with the `REGISTER_KERNEL_BUILD` macro. This macro expands to code that relies on static initialization to register the provided kernel with the global kernel registry. See below for an example of registering a kernel with the C++ API: +```c++ +REGISTER_KERNEL_BUILDER( + Name("Conv2D").Device(DEVICE_GPU).TypeConstraint("T"), + Conv2DOp); +``` +The equivalent C API provides a series of functions that operate on `TF_KernelBuilder`, an opaque struct obtained with the `TF_NewKernelBuilder` call. The kernel builder is registered with TensorFlow using the `TF_RegisterKenrelBuilder` function. See below for an example of registering the conv kernel using the C API: +```c++ +template +void RegisterConv2DKernel() { + TF_Status* s = TF_NewStatus(); + auto* builder = TF_NewKernelBuilder("Conv2D", "MY_DEVICE", &Conv2D_Create, &Conv2D_Compute, &Conv2D_Destroy); + TF_KernelBuilder_TypeConstraint(builder, "T", static_cast(DataTypeToEnum::v()), s) + if (TF_GetCode(s) != TF_OK()) {/* handle errors*/} + TF_RegisterKernelBuilder("Conv2D", builder, s); + if (TF_GetCode(s) != TF_OK()) {/* handle errors*/} + TF_DeleteStatus(s); +} + +void TF_InitKernel() { + RegisterConv2DKenrel(); + +} +``` +The registration function prototypes are provided below. Kernel authors must provide a compute function. creation and destroy functions are optional, but if a creation function is provided that causes memory allocation, a destroy function that frees the memory should be provided, otherwise a leak will occur. +```c++ +TF_KernelBuilder* TF_NewKernelBuilder( + const char* op_name, const char* device_name, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)); + +void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder, TF_Status* status); +``` +### **Graph optimization** + +Modular TensorFlow provides a new mechanism for custom graph optimizers and a set of C APIs as the ABI-stable APIs for implementing graph optimizers. +The C APIs follows current C++ API implementation, [TF_Buffer](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/c/c_api.h#L110-L114) and related proto files are the interface between proper and plugin. +When initializing, TensorFlow loads the plugin and registers a new graph optimizer into Grappler. In the [Optimize](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/c/experimental/grappler/grappler.h#L134) function, plugin authors need to deserialize `TF_Buffer` to `plugin::GraphDef` object to do some graph transformations, and serialize the optimized `plugin::GraphDef` object back to `TF_Buffer` as output. Note that the graph in this part is all represented by GraphDef/TF_Buffer, not [graph](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/graph/graph.h#L498). +The graph C APIs can be found in [grappler.h](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/c/experimental/grappler/grappler.h). + +We will introduce graph optimization C APIs from the following three aspects: optimize registration, implementation and util function. + +

+ +

+ +§ **Optimizer registration** + +Plugins need to define the `TF_InitGraph` function and populate `TP_OptimizerRegistrationParams`. +When the plugin is loaded by TF at runtime, `TF_InitGraph` method will be called and new plugin optimizers will be registered to Core TensorFlow. + +Example: +```c++ +#include "tensorflow/c/experimental/grappler/grappler.h" + +void TF_InitGraph(TP_OptimizerRegistrationParams* params, + TF_Status* status) { + params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE; + params->device_type = "CPU"; + + params->optimizer_configs->struct_size = TP_OPTIMIZER_CONFIGS_STRUCT_SIZE; + params->optimizer_configs->remapping = TF_TriState_Off; + params->optimizer_configs->layout_optimizer = TF_TriState_Off; + + params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE; + params->optimizer->create_func = Optimizer_Create; + params->optimizer->optimize_func = Optimizer_Optimize; + params->optimizer->destroy_func = Optimizer_Destroy; +} +``` + +As you may see in the example, the plugin needs to populate the `optimizer_configs` and `optimizer`. + +* `struct_size`: plugin needs to set it as `TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE` (defined in grappler.h). This field is used for the Graph C API version check between Core TensorFlow and the plugin. + +* `device_type`: This field indicates the backend device type that the graph optimizer is targeting. + +* `optimizer_configs->remapping`: This field indicates whether the remapping optimizer in Tensorflow proper should be disabled. It is a tri-state enum value `TF_TriState`, and the default value is on. Each optimizer defined in TensorFlow proper has a competitive config value. Detailed configuration of these optimizers can be seen in [grappler.h](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/c/experimental/grappler/grappler.h#L98-L115). + +* `optimizer->create_func`: This field is an optional function for creating an optimizer. Destroy functions are also optional. But if a creation is provided that causes memory allocation, a deletion function that frees the memory should also be provided, otherwise a leak will occur. + +* `optimizer->optimize_func`: This field is the main part of the optimizer. Core TensorFlow will call this function to perform a graph transformation. + +§ **Optimizer implementation** + +Graph Optimize function(`optimize_func`) is the main part that plugin authors need to implement. The function looks like below. The first param is an optimizer pointer created by `create_func`, or a nullptr if `create_func` is not provided. The second param is serialized input graph(`GraphDef`). The third param is the input `TF_GrapplerItem` handle which contains feed/fetch nodes info. The fourth param is serialized output graph(`GraphDef`). + +```cpp +void Optimizer_Optimize(void* optimizer, const TF_Buffer* graph_buf, const TF_GrapplerItem* item, + TF_Buffer* optimized_graph_buf, TF_Status* s); +``` + + +Example: +```cpp +void Optimizer_Optimize(void* optimizer, const TF_Buffer* graph_buf, const TF_GrapplerItem* item, + TF_Buffer* optimized_graph_buf, TF_Status* tf_status) { + + // Deserialize input graph + plugin::GraphDef graph_def; + BufferToMessage(graph_buf, graph_def); + + Status status; + // Create a GraphView object which provides helper functions to modify the graph. + GraphView graph_view(graph_def, status); + const int num_nodes = graph_def.node_size(); + for (int i = num_nodes - 1; i >= 0; --i) { + // Fetch a node. + const auto* node_view = graph_view.GetNode(i); + const auto* node_def = node_view->node(); + + // Create a new node. + NodeDef new_node; + new_node.set_name(node_def.name()); + new_node.set_op(node_def.name()); + + // Add new nodes into the graph. + Mutation* mutation = graph_view.GetMutationBuilder(); + mutation->AddNode(std::move(new_node), &status); + mutation->Apply(); + } + + // Serialize output graph. + plugin::GraphDef optimized_graph_def = graph_def; + MessageToBuffer(optimized_graph_def, optimized_graph_buf); +} +``` + +* `plugin::GraphDef`: This is a C++ object generated by protobuf toolchain with a predefined structure in graph.proto. Note that the namespace has changed from `tensorflow::` to `plugin::`, which means it is a class defined in plugin. Plugin should maintain protobuf toolchain and graph.proto files. They should copy graph.proto from tensorflow proper and change the package name to `plugin`. + + Here lists all proto files needed in plugin: + - [attr_value.proto](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/framework/attr_value.proto): AttrValue, NameAttrList + - [cost_graph.proto](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/framework/cost_graph.proto): CostGraphDef + - [function.proto](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/framework/function.proto): FunctionDefLibrary, FunctionDef, GradientDef + - [graph.proto](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/framework/graph.proto): GraphDef + - [node_def.proto](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/framework/node_def.proto): NodeDef + - [op_def.proto](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/framework/op_def.proto): OpDef, OpDeprecation, OpList + - [op_performance_data.proto](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/grappler/costs/op_performance_data.proto): SessionInfo, OpInfo, NormalDistribution, LogNormalDistribution, OpPerformance, OpPerformanceList + - [resource_handle.proto](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/framework/resource_handle.proto): ResourceHandleProto + - [tensor.proto](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/framework/tensor.proto): TensorProto, VariantTensorDataProto + - [tensor_shape.proto](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/framework/tensor_shape.proto): TensorShapeProto + - [types.proto](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/framework/types.proto): DataType, SpecializedType + - [versions.proto](https://github.com/tensorflow/tensorflow/blob/r2.5/tensorflow/core/framework/versions.proto): VersionDef + +* `BufferToMessage`, `MessageToBuffer`: They are serialization/deserialization functions for `TF_Buffer` and protobuf objects(e.g., `GraphDef`). Plugin can deserialize input graph(`TF_Buffer`) to the plugin `GraphDef` object, and serialize the output `GraphDef` object when graph transformation is finished. + + Example: + ```cpp + Status MessageToBuffer(const protobuf::MessageLite& in, TF_Buffer* out) { + if (out->data != nullptr) { + return errors::InvalidArgument("Passing non-empty TF_Buffer is invalid."); + } + const size_t proto_size = in.ByteSizeLong(); + void* buf = malloc(proto_size); + if (buf == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '", + in.GetTypeName(), "' and size ", proto_size); + } + if (!in.SerializeWithCachedSizesToArray(static_cast(buf))) { + free(buf); + return errors::InvalidArgument( + "Unable to serialize ", in.GetTypeName(), + " protocol buffer, perhaps the serialized size (", proto_size, + " bytes) is too large?"); + } + out->data = buf; + out->length = proto_size; + out->data_deallocator = [](void* data, size_t length) { free(data); }; + return Status::OK(); + } + + Status BufferToMessage(const TF_Buffer* in, protobuf::MessageLite& out) { + if (in == nullptr || !out.ParseFromArray(in->data, in->length)) { + return errors::InvalidArgument("Unparsable proto"); + } + return Status::OK(); + } + ``` + +* `GraphView`, `Mutation`: These are helper classes provided by TensorFlow in [tensorflow/core/grappler/utils](https://github.com/tensorflow/tensorflow/tree/r2.5/tensorflow/core/grappler/utils) folder to modify `GraphDef` objects. Plugin authors can manually copy this part into the plugin side, or they can write their own util functions. + +§ **Optimizer util functions** + +Modular TensorFlow provides three opaque handles, i.e., `TF_GrapplerItem`, `TF_GraphProperties` and `TF_FunctionLibraryDefinition`, and related C APIs for retrieving necessary graph information: + - `TF_GrapplerItem` represents a combination of a graph, and some more information about feed/fetch nodes, preserved nodes. + - `TF_GetNodesToPreserveListSize()`,`TF_GetNodesToPreserveList()`: Get a set of preserved node names which can not be transformed or removed during the graph transformation. This includes feed and fetch nodes, keep_ops, init_ops. + - `TF_GetFetchNodesListSize()`,`TF_GetFetchNodesList()`: Get a set of node names for fetch nodes. + + An example of how to get a set of preserved nodes: + + ```cpp + void Optimizer_Optimize(void* optimizer, const TF_Buffer* graph_buf, const TF_GrapplerItem* item, + TF_Buffer* optimized_graph_buf, TF_Status* tf_status) { + TF_GrapplerItem* item; + TF_Status* status = TF_NewStatus(); + int num_values = 0, storage_size = 0; + TF_GetNodesToPreserveListSize(item, &num_values, &storage_size, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) + << "Error for TF_GetNodesToPreserveListSize"; + + std::unique_ptr values(new char*[num_values]); + std::unique_ptr lens(new size_t[num_values]); + std::unique_ptr storage(new char[storage_size]); + TF_GetNodesToPreserveList( + item, reinterpret_cast(values.get()), lens.get(), num_values, + reinterpret_cast(storage.get()), storage_size, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error for TF_GetNodesToPreserveList"; + + std::unordered_set nodes; + for (int32_t i = 0; i < num_values; ++i) { + nodes.insert(string(values[i], lens[i])); + } + TF_DeleteStatus(status); + } + ``` + +- `TF_GraphProperties` can be used to infer OpInfo::TensorProperties. Typical use case is to first call `TF_InferStatically` to statically infer shapes and then call `TF_GetInputPropertiesList` to get input shapes. + - `TF_NewGraphProperties()`,`TF_DeleteGraphProperties()`: Create/Destroy GraphProperties. + - `TF_InferStatically()`: Infer tensor shapes through abstract interpretation. + - `TF_GetInputPropertiesListSize()`,`TF_GetInputPropertiesList()`: Get a list of input `OpInfo::TensorProperties` given node name. + + An example of how to get input properties: + + ```cpp + void Optimizer_Optimize(void* optimizer, const TF_Buffer* graph_buf, const TF_GrapplerItem* item, + TF_Buffer* optimized_graph_buf, TF_Status* tf_status) { + TF_GrapplerItem* item; + TF_Status* status = TF_NewStatus(); + int num_values = 0, storage_size = 0; + TF_GraphProperties* graph_properties = TF_NewGraphProperties(item); + TF_InferStatically(graph_properties, true, false, false, false, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error for TF_InferStatically"; + + for (const NodeDef& node : item->graph.node()) { + int num_values = 0; + TF_GetInputPropertiesListSize(graph_properties, node.name().c_str(), + &num_values, status); + CHECK_EQ(TF_OK, TF_GetCode(status)); + + std::vector in_props_buf(num_values, TF_NewBuffer()); + TF_GetInputPropertiesList(graph_properties, node.name().c_str(), + in_props_buf.data(), num_values, status); + CHECK_EQ(TF_OK, TF_GetCode(status)); + + OpInfo::TensorProperties in_props; + Status s = BufferToMessage(in_props_buf[0], &in_props); + + for (int i = 0; i < in_props_buf.size(); i++) + TF_DeleteBuffer(in_props_buf[i]); + } + TF_DeleteGraphProperties(graph_properties); + TF_DeleteStatus(status); + } + ``` + +- `TF_FunctionLibraryDefinition` maintains a map between op names and op definitions, typical use case is to look up an OpDef by op name, and then get some op attributes. + - `TF_NewFunctionLibraryDefinition()`,`TF_DeleteFunctionLibraryDefinition()`: Create/Destroy FunctionLibraryDefinition. + - `TF_LookUpOpDef()`: Shorthand for calling LookUp to get the OpDef from FunctionLibraryDefinition given op name. + + An example of how to get OpDef: + + ```cpp + void Optimizer_Optimize(void* optimizer, const TF_Buffer* graph_buf, const TF_GrapplerItem* item, + TF_Buffer* optimized_graph_buf, TF_Status* tf_status) { + TF_GrapplerItem* item; + TF_Buffer* g_buf = TF_NewBuffer(); + TF_Buffer* op_buf = TF_NewBuffer(); + TF_Status* status = TF_NewStatus(); + + string name = "Add"; + Status s = MessageToBuffer(item->graph, g_buf); + TF_FunctionLibraryDefinition* func = + TF_NewFunctionLibraryDefinition(g_buf, status); + TF_LookUpOpDef(func, name.c_str(), op_buf, status); + OpDef op_def; + BufferToMessage(op_buf, op_def); + + TF_DeleteBuffer(g_buf); + TF_DeleteBuffer(op_buf); + TF_DeleteStatus(status); + TF_DeleteFunctionLibraryDefinition(func); + } + ``` +### **Profiler** +Performance is a key consideration of successful ML research and production solutions. TensorFlow profiler provides a set of good tools to help users better understand the performance bottlenecks of TensorFlow models. TensorFlow Profiler C API provides the capability of connecting third-party device's profiler library(e.g. CUPTI) to TensorFlow profiler. + +Note: Profiler is an optional module in plugin, plugin authors can decide whether to implement this module. + +To make C APIs portable, Modular TensorFlow adopts serialized `XSpace` as the objects to pass between TensorFlow framework and plugin. When the framework invokes `CollectData()`, the plugin needs to serialize `XSpace` into a sufficiently sized buffer provided by framework.Subsequently, the framework deserializes the buffer back into `XSpace`, and generates a trace view. + +
+ +
+ +In this section, you will learn how to plugin a profiler library step by step. + +§ **TF_InitProfiler** + +TF_InitProfiler is the entry point to initialize the plugin profiler, you need to define and implement this function if you want to enable profiler through the Modular TensorFlow interface. This function will be automatically loaded and invoked by TensorFlow if you define this function. + +Example: +``` +void TF_InitProfiler(TF_ProfilerRegistrationParams *params, TF_Status *status) { + params->struct_size = TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE; + params->profiler_fns->struct_size = TP_PROFILER_FNS_STRUCT_SIZE; + params->profiler->type = + DEVICE_TYPE; // type is device type, such as GPU, APU.. + params->profiler_fns->start = plugin_start; + params->profiler_fns->stop = plugin_stop; + params->profiler_fns->collect_data_xspace = plugin_collect_data_xspace; + params->destroy_profiler = plugin_destroy_profiler; + params->destroy_profiler_fns = plugin_destroy_profiler_fns; +} +``` +As you may see in the example, plugin needs to populate the `profiler` and `profiler_fns`. + +* `params->struct_size`: plugin needs to set it as `TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE` (defined in pluggable_profiler.h). This field is used for the Profiler C API version check between TensorFlow and the plugin. +* `params->profiler_fns->struct_size`: plugin needs to set it as `TP_PROFILER_FNS_STRUCT_SIZE` (defined in pluggable_profiler.h). This field is used for the Profiler C API version check between TensorFlow and the plugin. +* `params->profiler->type`: This field is the device type(e.g. GPU, APU..). +* `params->profiler_fns->start`: a callback for starting the profiler. +```c++ +void profiler_start(const TP_Profiler* profiler, TF_Status* status) { + /* Enable profiler */ + ... +} +``` +* `params->profiler_fns->stop`: a callback for stopping the profiler. +```c++ +void profiler_stop(const TP_Profiler* profiler, TF_Status* status) { + /* Disable Profiler */ + ... +} +``` +* `params->profiler_fns->collect_data_xspace`: a callback for saving collected profile data into XSpace and serializers it into the buffer. If this have been called, subsequent calls might return empty data. +```c++ +void profiler_collect_data_xspace(const TP_Profiler* profiler, uint8_t* +buffer, size_t* size_in_bytes, TF_Status* status) { + Xspace xspace = get_my_xspace(); /* Plugin generates Xspace based on + collected profiler data. */ size_t buffer_size_in_bytes = *size_in_bytes; + *size_in_bytes = xspace.ByteSizeLong(); /* get the size of Xspace */ + if (buffer == nullptr) { + return; /* TensorFlow will first get the size of Xspace, then allocate + the big enough buffer and pass it to the plugin for retrieving Xspace. + */ + } + bool success = xspace.SerializeToArray(buffer, buffer_size_in_bytes); +} + +``` +* `params->destroy_profiler`: pointer to plugin's `TP_Profiler` clean up function. Cleans up fields inside `TP_Profiler` that were allocated by the plugin. `profiler` itself must not be deleted by the plugin. +* `params->destroy_profiler_fns`: pointer to plugin's `TP_ProfilerFns` clean up function. Cleans up fields inside `TP_ProfilerFns` that were allocated by the plugin. `profiler_fns` itself must not be deleted by the plugin. + +you can find profiler implementation sample code in `sample/tensorflow_plugin/src/profiler`. +* profiler example: +
+ +
+ +## **Plugin build** + +After implementing the plugin, we need to build it as a dynamic library. Build system is decided by plugin authors, you can choose bazel, cmake or other build systems, it is out of scope in this tutorial. To make things simple, we just use the gcc command here. + +When building the plugin, we have two dependencies here: + +1. We need to include those C API header files provided by Core TensorFlow. + +2. The built plugin library needs to add dependency to `_pywrap_tensorflow_internal.so`, which is built by Core TensorFlow. `_pywrap_tensorflow_internal.so` contains those C API implementations. If you don’t add this dependency, it will report an "undefined symbol" error when loading the plugin library. + +A recommended build procedure is: + +Step1: install TF with: +``` +python3 -m venv venv +source venv/bin/activate +pip install tensorflow +``` +Step2: Then build plugin with: +``` +g++ -std=c++11 -shared plugin.cc -o plugin.so -fPIC -Ivenv/lib/python3.8/site-packages/tensorflow/include -Lvenv/lib/python3.8/site-packages/tensorflow/python -l:_pywrap_tensorflow_internal.so -O2 +``` +With this procedure, you can always build the plugin with installed TensorFlow ‘s compatible C API. + +**It should be noted** that you should pick up a unique name for the plugin's dynamic library, otherwise you may get conflict with(overwrite) other installed plugins. + +## **Plugin installation** + +After building the plugin, you may want to distribute it through the python package. One additional thing you need to do is to make the plugin’s dynamic library (libplugin.so for example) be installed/copied to the specified path (site-packages/tensorflow/python/ tensorflow-plugins/) when the user installs the package. Core TensorFlow will automatically iterate and load all the installed dynamic libraries in this path, then it will register device runtime, kernels/ops and graph optimizer by calling `SE_InitPlugin`, `TF_InitKernel` and `TF_InitGraphPlugin`. + +## **Plugin Running** + +After installing the plugin to the specified path (site-packages/tensorflow/python/tensorflow-plugins/). we can run TensorFlow with the plugin now. + +Front-end usage of the plugged device has no difference with first party devices. Suppose you have installed a plugin registers a new device with "MY_DEVICE" device type, you can: + +1) List device + +You can use *tf.config.list_physical_device()* to query whether the MY_DEVICE device is present on the host machine. If it is not found, then the plugin may not be loaded correctly. +``` +>>tf.list_physical_devices() +[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:MY_DEVICE:0', device_type=MY_DEVICE)] +``` +2) tf.device + +you can use with tf.device("my_device:0") to specify the MY_DEVICE device to be used for ops created/executed in a particular context. +``` +>>with tf.device("my_device:0"): + # ops created here have the device my_device:0 +``` +3) automatic device placement + +if you don’t specify the device to be used for ops created/executed in a particular context, the op will be auto placed into the MY_DEVICE device if the op for the MY_DEVICE device is registered. Plugged devices currently have the highest priority. + + + + + + + + + + + + + + +