[CPU] Add an OpenMP-based CPU launcher (#15)
* [CPU] Add OpenMP launcher

* Address the comments

* Fix induction variable type

* Always use a preallocated output buffer for CPU with torch.add
minjang committed Jun 24, 2024
1 parent 0b18898 commit 74f111f
Showing 3 changed files with 101 additions and 25 deletions.
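
Before the per-file diffs, a brief orientation: the heart of this change is an OpenMP launcher in the generated C++ driver glue. The sketch below is a simplified, self-contained illustration of that pattern with a placeholder kernel signature (the real launcher is generated per kernel and appears in the third_party/cpu/backend/driver.py diff below); it flattens the 3-D launch grid and runs one program instance per OpenMP loop iteration.

// Simplified sketch of the OpenMP launch pattern added by this commit.
// The kernel signature is a placeholder; the generated launcher also passes
// the kernel's own arguments (see the driver.py diff below).
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <omp.h>

using kernel_ptr_t = void (*)(uint32_t x, uint32_t y, uint32_t z);

static void example_kernel(uint32_t x, uint32_t y, uint32_t z) {
  // A real Triton kernel would process one program instance (block) here.
  (void)x; (void)y; (void)z;
}

static void run_omp_kernel(uint32_t gridX, uint32_t gridY, uint32_t gridZ,
                           kernel_ptr_t kernel) {
  const size_t N = size_t(gridX) * gridY * gridZ;
  // One loop iteration per program instance; the static schedule gives each
  // thread one contiguous chunk of the flattened grid.
  #pragma omp parallel for schedule(static)
  for (size_t i = 0; i < N; ++i) {
    const uint32_t x = i % gridX;
    const uint32_t y = (i / gridX) % gridY;
    const uint32_t z = i / (size_t(gridX) * gridY);
    kernel(x, y, z);
  }
}

int main() {
  run_omp_kernel(/*gridX=*/64, /*gridY=*/1, /*gridZ=*/1, example_kernel);
  std::printf("up to %d OpenMP threads available\n", omp_get_max_threads());
  return 0;
}

The generated launcher additionally honors the TRITON_CPU_SINGLE_CORE, TRITON_CPU_MAX_THREADS, and TRITON_CPU_OMP_DEBUG environment variables, as shown in the diff below.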
2 changes: 1 addition & 1 deletion python/triton/runtime/build.py
@@ -47,7 +47,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
cc_cmd += [f"-I{dir}" for dir in include_dirs]
# The CPU backend uses C++ (driver.cpp). Some older compilers need an explicit C++17 flag.
if src.endswith(".cpp") or src.endswith(".cc"):
cc_cmd += ["-std=c++17"]
cc_cmd += ["-std=c++17", "-fopenmp"]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
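
As a quick sanity check that the toolchain accepts the new flag, the minimal program below (a hypothetical check_openmp.cpp, not part of this commit) exercises an OpenMP pragma and the runtime. Without -fopenmp the pragma is ignored and the OpenMP runtime is not linked, which is why build.py now appends the flag whenever it compiles a C++ source such as driver.cpp.

// check_openmp.cpp -- hypothetical sanity check, not part of this commit.
// Build roughly as build.py now does for C++ sources:
//   g++ -std=c++17 -fopenmp check_openmp.cpp -o check_openmp
#include <cstdio>
#include <omp.h>

int main() {
  #pragma omp parallel
  {
    // Print once, from one thread, the number of threads OpenMP may use.
    #pragma omp single
    std::printf("OpenMP max threads: %d\n", omp_get_max_threads());
  }
  return 0;
}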
49 changes: 33 additions & 16 deletions python/tutorials/01-vector-add.py
@@ -23,7 +23,9 @@
import triton
import triton.language as tl

BLOCK_SIZE = 1024
GPU_BLOCK_SIZE = 1024
CPU_BLOCK_SIZE = 4096
USE_GPU = True


@triton.jit
@@ -59,10 +61,11 @@ def add_kernel(x_ptr, # *Pointer* to first input vector.
# and (2) enqueue the above kernel with appropriate grid/block sizes:


def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
# We need to preallocate the output.
output = torch.empty_like(x)
assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu
def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, is_cpu):
if output is None:
# We need to preallocate the output.
output = torch.empty_like(x)
assert x.is_cpu == is_cpu and y.is_cpu == is_cpu and output.is_cpu == is_cpu
n_elements = output.numel()
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
@@ -72,7 +75,7 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
# - Each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
# - Don't forget to pass meta-parameters as keyword arguments.
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=CPU_BLOCK_SIZE if is_cpu else GPU_BLOCK_SIZE)
# We return a handle to the output but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
return output
@@ -87,22 +90,22 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
x = torch.rand(size, device='cpu')
y = torch.rand(size, device='cpu')
output_torch_cpu = x + y
output_triton_cpu = add(x, y, is_cpu=True)
output_triton_cpu = add(x, y, None, is_cpu=True)
print(output_torch_cpu)
print(output_triton_cpu)
print(f'The maximum difference between torch-cpu and triton-cpu is '
f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}')

LINE_VALS = ['triton-cpu', 'torch-cpu']
LINE_NAMES = ['TritonCPU', 'TorchCPU']
LINE_STYLES = [('blue', '-'), ('green', '-')]
LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu']
LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU']
LINE_STYLES = [('blue', '-'), ('green', '-'), ('cyan', '-')]

if triton.runtime.driver.get_active_gpus():
if USE_GPU and triton.runtime.driver.get_active_gpus():
triton.runtime.driver.set_active_to_gpu()
x = x.to('cuda')
y = y.to('cuda')
output_torch_gpu = x + y
output_triton_gpu = add(x, y, is_cpu=False)
output_triton_gpu = add(x, y, None, is_cpu=False)
print(output_torch_gpu)
print(output_triton_gpu)
print(f'The maximum difference between torch-gpu and triton-gpu is '
@@ -136,28 +139,42 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu):
ylabel='GB/s', # Label name for the y-axis.
plot_name=
# Name for the plot. Used also as a file name for saving the plot.
f'vector-add-performance (BLOCK_SIZE={BLOCK_SIZE})',
f'vector-add-performance (CPU_BLOCK_SIZE={CPU_BLOCK_SIZE}, GPU_BLOCK_SIZE={GPU_BLOCK_SIZE})',
args={}, # Values for function arguments not in `x_names` and `y_name`.
))
def benchmark(size, provider):
import os

device = 'cpu' if 'cpu' in provider else 'cuda'
x = torch.rand(size, device=device, dtype=torch.float32)
y = torch.rand(size, device=device, dtype=torch.float32)

if device == 'cpu':
triton.runtime.driver.set_active_to_cpu()
if 'single' in provider:
os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
else:
os.unsetenv('TRITON_CPU_SINGLE_CORE')
else:
triton.runtime.driver.set_active_to_gpu()

quantiles = [0.5, 0.2, 0.8]
if provider == 'torch-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
elif provider == 'triton-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, False), quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles)
elif provider == 'torch-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True)
# Note that we preallocate the output buffer here so that we measure only the kernel time,
# not the allocation of a large output tensor.
output = torch.empty_like(x)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles,
is_cpu=True)
elif provider == 'triton-cpu-single':
output = torch.empty_like(x)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
elif provider == 'triton-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True)
output = torch.empty_like(x)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True)
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
return gbps(ms), gbps(max_ms), gbps(min_ms)

75 changes: 67 additions & 8 deletions third_party/cpu/backend/driver.py
@@ -110,7 +110,6 @@ def __new__(cls):
return cls.instance

def __init__(self):
pass
dirname = os.path.dirname(os.path.realpath(__file__))
mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils")
self.load_binary = mod.load_binary
@@ -182,14 +181,39 @@ def format_of(ty):

# generate glue code
src = f"""
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <string>
#include <iostream>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <omp.h>
#include <optional>
#include <stdio.h>
#include <string>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include <stdio.h>
inline bool getBoolEnv(const std::string &env) {{
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c) {{ return std::tolower(c); }});
return str == "on" || str == "true" || str == "1";
}}
inline std::optional<int64_t> getIntEnv(const std::string &env) {{
const char *cstr = std::getenv(env.c_str());
if (!cstr)
return std::nullopt;
char *endptr;
long int result = std::strtol(cstr, &endptr, 10);
if (endptr == cstr)
assert(false && "invalid integer");
return result;
}}
using kernel_ptr_t = void(*)({kernel_fn_arg_types});
@@ -233,20 +257,55 @@ def format_of(ty):
return ptr_info;
}}
static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
// TODO: add OMP pragmas to run in parallel
static std::unique_ptr<uint32_t[][3]> get_all_grids(uint32_t gridX, uint32_t gridY, uint32_t gridZ) {{
std::unique_ptr<uint32_t[][3]> grids(new uint32_t[gridX * gridY * gridZ][3]);
// TODO: which order would be more effective for cache locality?
for (uint32_t z = 0; z < gridZ; ++z) {{
for (uint32_t y = 0; y < gridY; ++y) {{
for (uint32_t x = 0; x < gridX; ++x) {{
(*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
grids[z * gridY * gridX + y * gridX + x][0] = x;
grids[z * gridY * gridX + y * gridX + x][1] = y;
grids[z * gridY * gridX + y * gridX + x][2] = z;
}}
}}
}}
return grids;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
// TODO: Consider using omp collapse(3) clause for simplicity?
auto all_grids = get_all_grids(gridX, gridY, gridZ);
size_t N = gridX * gridY * gridZ;
if (getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{
if (getBoolEnv("TRITON_CPU_OMP_DEBUG"))
printf("Single core launcher\\n");
for (size_t i = 0; i < N; ++i) {{
const auto [x, y, z] = all_grids[i];
(*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
}}
return;
}}
std::optional<int> max_threads = getIntEnv("TRITON_CPU_MAX_THREADS");
if (max_threads.has_value())
max_threads = std::max(1, std::min(max_threads.value(), omp_get_max_threads()));
else
max_threads = omp_get_max_threads();
if (getBoolEnv("TRITON_CPU_OMP_DEBUG"))
printf("N: %zu, max_threads: %d\\n", N, max_threads.value());
// For now, use the default chunk size, total iterations / max_threads.
#pragma omp parallel for schedule(static) num_threads(max_threads.value())
for (size_t i = 0; i < N; ++i) {{
const auto [x, y, z] = all_grids[i];
(*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z);
}}
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
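
One note on the launcher's scheduling comment above: with schedule(static) and no explicit chunk size, OpenMP splits the flattened grid into one contiguous range of iterations per thread. The toy program below (hypothetical, not part of this commit) makes that assignment visible.

// schedule_demo.cpp -- hypothetical demo of the default static schedule.
// Build: g++ -std=c++17 -fopenmp schedule_demo.cpp -o schedule_demo
#include <cstdio>
#include <omp.h>

int main() {
  const int N = 8;
  // With 4 threads and 8 iterations, thread t gets iterations 2*t and 2*t+1.
  #pragma omp parallel for schedule(static) num_threads(4)
  for (int i = 0; i < N; ++i) {
    std::printf("iteration %d handled by thread %d\n", i, omp_get_thread_num());
  }
  return 0;
}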
