diff --git a/benchmarks/xetla_benchmark/__init__.py b/benchmarks/xetla_benchmark/__init__.py
index 152663fd00..9c14d71486 100644
--- a/benchmarks/xetla_benchmark/__init__.py
+++ b/benchmarks/xetla_benchmark/__init__.py
@@ -1,3 +1,12 @@
 import torch
 import intel_extension_for_pytorch
 from . import xetla_kernel
+from . import benchmark_testing
+from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark
+
+import triton
+import triton.runtime.driver as driver
+from . import benchmark_driver
+
+# replace the launcher with the profiler hook.
+driver.active.launcher_cls = benchmark_driver.XPULauncher
diff --git a/benchmarks/xetla_benchmark/benchmark_driver.py b/benchmarks/xetla_benchmark/benchmark_driver.py
new file mode 100644
index 0000000000..82067021af
--- /dev/null
+++ b/benchmarks/xetla_benchmark/benchmark_driver.py
@@ -0,0 +1,444 @@
+import os
+import hashlib
+import tempfile
+import sysconfig
+import setuptools
+from pathlib import Path
+from triton.runtime.cache import get_cache_manager
+from triton.backends.compiler import GPUTarget
+from triton.backends.driver import DriverBase
+from triton.runtime.build import _build, quiet
+
+import torch
+import intel_extension_for_pytorch
+
+dirname = os.getenv("ZE_PATH", default="/usr/local")
+
+include_dir = [
+    os.path.join(dirname, "include"),
+    os.path.join(torch.utils.cmake_prefix_path, "../../include"),
+    os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include"),
+    os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../include")
+]
+
+oneapi_root = os.getenv("ONEAPI_ROOT")
+if oneapi_root:
+    include_dir += [
+        os.path.join(oneapi_root, "compiler/latest/include"),
+        os.path.join(oneapi_root, "compiler/latest/include/sycl")
+    ]
+
+library_dir = [
+    os.path.join(dirname, "lib"),
+    os.path.join(torch.utils.cmake_prefix_path, "../../lib"),
+    os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../lib")
+]
+libraries = ['ze_loader', 'sycl', 'torch', 'intel-ext-pt-gpu']
+
+
+def compile_module_from_src(src, name):
+    key = hashlib.md5(src.encode("utf-8")).hexdigest()
+    cache = get_cache_manager(key)
+    cache_path = cache.get_file(f"{name}.so")
+    if cache_path is None:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            src_path = os.path.join(tmpdir, "main.cpp")
+            with open(src_path, "w") as f:
+                f.write(src)
+            with quiet():
+                so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries)
+            with open(so, "rb") as f:
+                cache_path = cache.put(f.read(), f"{name}.so", binary=True)
+    import importlib.util
+    spec = importlib.util.spec_from_file_location(name, cache_path)
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return mod
+
+
+# ------------------------
+# Utils
+# ------------------------
+
+
+class XPUUtils(object):
+
+    def __new__(cls):
+        if not hasattr(cls, "instance"):
+            cls.instance = super(XPUUtils, cls).__new__(cls)
+        return cls.instance
+
+    def __init__(self):
+        dirname = os.path.dirname(os.path.realpath(__file__))
+        mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "spirv_utils")
+        self.load_binary = mod.load_binary
+        self.get_device_properties = mod.get_device_properties
+        self.context = mod.init_context(self.get_sycl_queue())
+        self.device_count = mod.init_devices(self.get_sycl_queue())
+        self.current_device = 0 if self.device_count[0] > 0 else -1
+
+    def get_current_device(self):
+        return self.current_device
+
+    def get_event_pool(self):
+        return self.event_pool
+
+    def get_sycl_queue(self):
import torch + return torch.xpu.current_stream().sycl_queue + + def get_sycl_device(self, device_id): + import torch + return torch.xpu.device(device_id).sycl_device + + +# ------------------------ +# Launcher +# ------------------------ + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def make_launcher(constants, signature, ids): + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors. + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return ty_to_cpp(ty) + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + + args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiOKOOOO" + args_format + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + # generate glue code + src = f""" + #include + #include + #include + #include + #include + #include + #include + #include + + #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + #include + #include + #include + + static inline void gpuAssert(ze_result_t code, const char *file, int line) + {{ + if (code != ZE_RESULT_SUCCESS) + {{ + const char* prefix = "Triton Error [ZE]: "; + std::string str = std::to_string(code); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str.c_str()); + PyErr_SetString(PyExc_RuntimeError, err); + }} + }} + + #define ZE_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + + typedef struct _DevicePtrInfo {{ + void* dev_ptr; + bool valid; + }} DevicePtrInfo; + + static inline void checkDevicePointer(DevicePtrInfo *ptr_info, int idx, const sycl::queue &queue) {{ + if (!ptr_info->dev_ptr || !ptr_info->valid) {{ + return; + }} + auto context = queue.get_context(); + auto handle = sycl::get_native(context); + ze_memory_allocation_properties_t prop; + prop.stype = ZE_STRUCTURE_TYPE_MEMORY_ALLOCATION_PROPERTIES; + prop.pNext = nullptr; + ze_device_handle_t device; + auto res = zeMemGetAllocProperties((ze_context_handle_t)handle, ptr_info->dev_ptr, &prop, &device); + if (res != ZE_RESULT_SUCCESS) {{ + PyErr_Format(PyExc_ValueError, + "Cannot get memory properties for pointer argument (at %d, err=%d)", idx, res); + ptr_info->valid = false; + }} else if (prop.type != ZE_MEMORY_TYPE_DEVICE) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) doesn't reference XPU device memory (cpu tensor?)", idx); + ptr_info->valid = false; + }} + }} + + static inline DevicePtrInfo getPointer(PyObject *obj, int idx, const sycl::queue &queue) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = (void*) PyLong_AsLongLong(obj); + checkDevicePointer(&ptr_info, idx, queue); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret 
= PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = (void*) PyLong_AsLongLong(ret); + if(!ptr_info.dev_ptr) {{ + return ptr_info; + }} + checkDevicePointer(&ptr_info, idx, queue); + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; + }} +// start sycl + static void set_scalar_arg( + sycl::handler& cgh, + int index, + size_t size, + const void* value) {{ + switch (size) {{ + case sizeof(uint8_t): + cgh.set_arg(index, *static_cast(value)); + break; + case sizeof(uint16_t): + cgh.set_arg(index, *static_cast(value)); + break; + case sizeof(uint32_t): + cgh.set_arg(index, *static_cast(value)); + break; + case sizeof(uint64_t): + cgh.set_arg(index, *static_cast(value)); + break; + default: + assert(false && "wrong scalar size in sycl gen."); + }} + }} + static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + + std::string kernel_name = kernel_ptr.get_info(); + RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {{}}); + void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; + uint32_t num_params = sizeof(params)/sizeof(params[0]); + uint32_t expected_num_params = kernel_ptr.get_info(); + size_t global_range_x = gridX*threads_per_warp*num_warps; + size_t global_range_y = gridY; + size_t global_range_z = gridZ; + size_t local_range_x = num_warps*threads_per_warp; + size_t local_range_y = 1; + size_t local_range_z = 1; + sycl::range<3> global_range(global_range_z, global_range_y, global_range_x); + sycl::range<3> local_range(local_range_z, local_range_y, local_range_x); + sycl::nd_range<3> parallel_work_size(global_range, local_range); + if (shared_memory) {{ + expected_num_params -= 1; + }} + assert(num_params == expected_num_params && "number of kernel param not matched"); + // Submit the imported kernel. 
+ auto cgf = [&](sycl::handler &cgh) {{ + {" ".join(f'set_scalar_arg(cgh, {idx}, sizeof({ty_to_cpp(item)}), params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if i not in constants]))} + if (shared_memory) {{ + using share_mem_t = sycl::local_accessor; + share_mem_t local_buffer = share_mem_t(shared_memory, cgh); + cgh.set_arg(num_params, local_buffer); + cgh.parallel_for(parallel_work_size, kernel_ptr); + }} else {{ + cgh.parallel_for(parallel_work_size, kernel_ptr); + }} + }}; + auto event = stream.submit(cgf); + xpu::profiler_record(kernel_name, event); + }} +// end sycl + static PyObject* launch(PyObject* self, PyObject* args) {{ + + int gridX, gridY, gridZ; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + PyObject *py_obj_stream; + void* pKrnl; + + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &pKrnl, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {args_list})) {{ + return NULL; + }} + + // extract kernel metadata + int num_warps = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "num_warps")); + int num_ctas = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "num_ctas")); + int shared_memory = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "shared")); + int threads_per_warp = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "threads_per_warp")); + + // extract cluster dims + PyObject *clusterDim = PyObject_GetAttrString(kernel_metadata, "cluster_dims"); + if (!PyTuple_Check(kernel_metadata)) {{ + PyErr_SetString(PyExc_TypeError, "kernel_metadata.cluster_dims must be a tuple"); + return NULL; + }} + int clusterDimX = PyLong_AsLong(PyTuple_GetItem(clusterDim, 0)); + int clusterDimY = PyLong_AsLong(PyTuple_GetItem(clusterDim, 1)); + int clusterDimZ = PyLong_AsLong(PyTuple_GetItem(clusterDim, 2)); + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + void * pStream = PyLong_AsVoidPtr(py_obj_stream); + //error check + if(pStream == nullptr || pKrnl == nullptr) return NULL; + + sycl::queue stream = *(static_cast(pStream)); + sycl::kernel kernel = *(static_cast(pKrnl)); + + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + if (PyErr_Occurred()) {{ + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; + }} + + static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel + }}; + + static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods + }}; 
+ + PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; + }} + """ + return src + + +class XPULauncher(object): + + def __init__(self, src, metadata): + ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + src = make_launcher(constants, signature, ids) + mod = compile_module_from_src(src, "__triton_launcher") + self.launch = mod.launch + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class XPUDriver(DriverBase): + + def __init__(self): + self.launcher_cls = XPULauncher + + def __getattr__(self, name): + # Lazily initialize utils to avoid unnecessary XPU runtime invocations. + # See https://github.com/intel/intel-xpu-backend-for-triton/issues/624 + if name == "utils": + self.utils = XPUUtils() + return self.utils + else: + raise AttributeError + + def get_current_device(self): + return self.utils.get_current_device() + + def get_current_stream(self, device): + import torch + return torch.xpu.current_stream().sycl_queue + + def get_current_target(self): + import torch + device = self.get_current_device() + dev_property = torch.xpu.get_device_capability(device) + warp_size = 32 + return GPUTarget("xpu", dev_property, warp_size) + + @staticmethod + def is_active(): + import torch + return torch.xpu.is_available() diff --git a/benchmarks/xetla_benchmark/benchmark_testing.py b/benchmarks/xetla_benchmark/benchmark_testing.py new file mode 100644 index 0000000000..5ebea17aaf --- /dev/null +++ b/benchmarks/xetla_benchmark/benchmark_testing.py @@ -0,0 +1,389 @@ +import functools +import os +import subprocess +import sys +from contextlib import contextmanager +from typing import Any, Dict, List +import itertools + + +def synchronize(): + import torch + if torch.cuda.is_available(): + torch.cuda.synchronize() + elif torch.xpu.is_available(): + torch.xpu.synchronize() + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean", + device='xpu'): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. 
+ :type quantiles: list[float] + :param fast_flush: Use faster kernel to flush L2 between measurements + :type fast_flush: bool + """ + assert return_mode in ["min", "max", "mean", "median"] + import torch + from torch.autograd.profiler import record_function + + fn() + synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device) + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device=device) + + # Estimate the runtime of the function + start_event = torch.xpu.Event(enable_timing=True) + end_event = torch.xpu.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + + with torch.autograd.profiler_legacy.profile(True, use_xpu=True) as prof: + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + with record_function("__profile_kernel_of_func"): + fn() + # Record clocks + synchronize() + + profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), prof.function_events) + functions = [func for func in profiling_func_filter] + + def extract_kernels(funcs): + kernels = [] + kernels += list(itertools.chain.from_iterable(map(lambda func: extract_kernels(func.cpu_children), funcs))) + kernels += list(itertools.chain.from_iterable([func.kernels for func in funcs])) + return kernels + + kernels = [extract_kernels(func.cpu_children) for func in functions] + assert len(kernels) == n_repeat, "the profiling number not match" + # Make the time to the milliseconds. + times = torch.tensor([sum([k.duration for k in ks]) * 1e-3 for ks in kernels], dtype=torch.float) + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + return getattr(torch, return_mode)(times).item() + + +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + import numpy as np + import torch + + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance + if atol is None: + atol = 1e-2 + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook + if rtol is None: + rtol = 0. 
+ rtol = rtol(x.dtype) if callable(rtol) else rtol + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) + return + if not np.allclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + +def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. + :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper + + +class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ + + def __init__( + self, + x_names: List[str], + x_vals: List[Any], + line_arg: str, + line_vals: List[Any], + line_names: List[str], + plot_name: str, + args: Dict[str, Any], + xlabel: str = '', + ylabel: str = '', + x_log: bool = False, + y_log: bool = False, + color=None, + styles=None, + ): + """ + Constructor. + x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list + of scalars and there are multiple x_names, all arguments will have the same value. + If x_vals is a list of tuples/lists, each element should have the same length as + x_names. + + :param x_names: Name of the arguments that should appear on the x axis of the plot. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[Any] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. + :type args: Dict[str, Any] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. 
+ :type y_log: bool, optional + """ + self.x_names = x_names + self.x_vals = x_vals + self.x_log = x_log + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names + self.y_log = y_log + self.styles = styles + # plot info + self.xlabel = xlabel + self.ylabel = ylabel + self.plot_name = plot_name + self.args = args + + +class Mark: + + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, + save_precision=6, **kwrags): + import os + + import matplotlib.pyplot as plt + import pandas as pd + y_mean = bench.line_names + y_min = [f'{x}-min' for x in bench.line_names] + y_max = [f'{x}-max' for x in bench.line_names] + x_names = list(bench.x_names) + df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max) + for x in bench.x_vals: + # x can be a single value or a sequence of values. + if not isinstance(x, (list, tuple)): + x = [x for _ in x_names] + + if len(x) != len(x_names): + raise ValueError(f"Expected {len(x_names)} values, got {x}") + x_args = dict(zip(x_names, x)) + + row_mean, row_min, row_max = [], [], [] + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) + try: + y_mean, y_min, y_max = ret + except TypeError: + y_mean, y_min, y_max = ret, None, None + row_mean += [y_mean] + row_min += [y_min] + row_max += [y_max] + df.loc[len(df)] = list(x) + row_mean + row_min + row_max + + if bench.plot_name: + plt.figure() + ax = plt.subplot() + # Plot first x value on x axis if there are multiple. + first_x = x_names[0] + for i, y in enumerate(bench.line_names): + y_min, y_max = df[y + '-min'], df[y + '-max'] + col = bench.styles[i][0] if bench.styles else None + sty = bench.styles[i][1] if bench.styles else None + ax.plot(df[first_x], df[y], label=y, color=col, ls=sty) + if not y_min.isnull().all() and not y_max.isnull().all(): + y_min = y_min.astype(float) + y_max = y_max.astype(float) + ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) + ax.legend() + ax.set_xlabel(bench.xlabel or first_x) + ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + # df = df[x_names + bench.line_names] + if diff_col and df.shape[1] == 2: + col0, col1 = df.columns.tolist() + df['Diff'] = df[col1] - df[col0] + + if print_data: + print(bench.plot_name + ':') + print(df) + if save_path: + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f", + index=False) + return df + + def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + result_dfs = [] + if save_path: + # Create directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + html = open(os.path.join(save_path, "results.html"), "w") + html.write("\n") + for bench in benchmarks: + result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs)) + if save_path: + html.write(f"\n") + if save_path: + html.write("\n") + html.close() + if return_df: + if has_single_bench: + return result_dfs[0] + else: + return result_dfs + return None + + +def perf_report(benchmarks): + """ + Mark a 
function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. + :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper + + +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) diff --git a/benchmarks/xetla_benchmark/fused_softmax.py b/benchmarks/xetla_benchmark/fused_softmax.py index 023b83fa0b..1ec3debbca 100644 --- a/benchmarks/xetla_benchmark/fused_softmax.py +++ b/benchmarks/xetla_benchmark/fused_softmax.py @@ -16,6 +16,8 @@ import xetla_benchmark import xetla_benchmark.xetla_kernel as xetla_kernel +benchmark_suit = xetla_benchmark # triton.testing + @torch.jit.script def naive_softmax(x): @@ -40,28 +42,14 @@ def naive_softmax(x): @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE': 128}, num_warps=32), - triton.Config({'BLOCK_SIZE': 256}, num_warps=32), - triton.Config({'BLOCK_SIZE': 512}, num_warps=32), - triton.Config({'BLOCK_SIZE': 1024}, num_warps=32), - triton.Config({'BLOCK_SIZE': 2048}, num_warps=32), - triton.Config({'BLOCK_SIZE': 128}, num_warps=16), - triton.Config({'BLOCK_SIZE': 256}, num_warps=16), - triton.Config({'BLOCK_SIZE': 512}, num_warps=16), - triton.Config({'BLOCK_SIZE': 1024}, num_warps=16), - triton.Config({'BLOCK_SIZE': 2048}, num_warps=16), - triton.Config({'BLOCK_SIZE': 128}, num_warps=8), - triton.Config({'BLOCK_SIZE': 256}, num_warps=8), - triton.Config({'BLOCK_SIZE': 512}, num_warps=8), - triton.Config({'BLOCK_SIZE': 1024}, 
num_warps=8), - triton.Config({'BLOCK_SIZE': 2048}, num_warps=8), - triton.Config({'BLOCK_SIZE': 128}, num_warps=4), - triton.Config({'BLOCK_SIZE': 256}, num_warps=4), - triton.Config({'BLOCK_SIZE': 512}, num_warps=4), - triton.Config({'BLOCK_SIZE': 1024}, num_warps=4), - triton.Config({'BLOCK_SIZE': 2048}, num_warps=4), + triton.Config({}, num_warps=32), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=32), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=4), ], - key=['n_cols'], + key=['n_cols', 'BLOCK_SIZE'], ) @triton.jit def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr): @@ -91,21 +79,16 @@ def softmax(x): n_rows, n_cols = x.shape # The block size is the smallest power of two greater than the number of columns in `x` BLOCK_SIZE = triton.next_power_of_2(n_cols) - # Another trick we can use is to ask the compiler to use more threads per row by - # increasing the number of warps (`num_warps`) over which each row is distributed. - # You will see in the next tutorial how to auto-tune this value in a more natural - # way so you don't have to come up with manual heuristics yourself. - num_warps = 32 # Allocate output y = torch.empty_like(x) # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o # f the input matrix - softmax_kernel[(n_rows, )](y, x, x.stride(0), y.stride(0), n_cols) + softmax_kernel[(n_rows, )](y, x, x.stride(0), y.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE) return y -@triton.testing.perf_report( - triton.testing.Benchmark( +@benchmark_suit.perf_report( + benchmark_suit.Benchmark( x_names=['N'], # argument names to use as an x-axis for the plot x_vals=[256, 1024, 2048, 4096], # different possible values for `x_name` line_arg='provider', # argument name whose value corresponds to a different line in the plot @@ -130,16 +113,23 @@ def benchmark(M, N, provider): x = torch.randn(M, N, device='xpu', dtype=torch.bfloat16) quantiles = [0.5, 0.2, 0.8] if provider == 'torch-native': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, warmup=10, + ms, min_ms, max_ms = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, warmup=10, rep=10) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, warmup=10, rep=10) + triton_fn = lambda: softmax(x) + torch_fn = lambda: torch.softmax(x, axis=-1) + benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch") + ms, min_ms, max_ms = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10) + if provider == 'torch-jit': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10, rep=10) + ms, min_ms, max_ms = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10, rep=10) if provider == 'xetla': - name = "softmax_shape_{}_{}".format(N, N) + name = "softmax_shape_{}_{}".format(M, N) func = getattr(xetla_kernel, name) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: func(x, 0), quantiles=quantiles, warmup=10, rep=10) + xetla_fn = lambda: func(x, 0) + torch_fn = lambda: torch.softmax(x, axis=-1) + # benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch") + ms, min_ms, max_ms = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10) gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) 
return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/benchmarks/xetla_kernel/python_main.cpp b/benchmarks/xetla_kernel/python_main.cpp index 36844b4a1c..a8d6f04504 100644 --- a/benchmarks/xetla_kernel/python_main.cpp +++ b/benchmarks/xetla_kernel/python_main.cpp @@ -23,11 +23,13 @@ sycl::queue get_current_sycl_queue() { template at::Tensor softmax(const at::Tensor &input, const int64_t dim) { CHECK_INPUT(input); + RECORD_FUNCTION("xetla softmax", {input}); auto output = at::empty_like(input); auto queue = get_current_sycl_queue(); - softmax_forward(input.data_ptr(), output.data_ptr(), queue); + auto evt = softmax_forward(input.data_ptr(), output.data_ptr(), queue); + xpu::profiler_record("xetla kernel", evt); return output; } @@ -47,11 +49,11 @@ at::Tensor bgemm(const at::Tensor &a, const at::Tensor &b, const at::Tensor &c, } PYBIND11_MODULE(xetla_kernel, m) { - m.def("softmax_shape_256_256", &softmax, + m.def("softmax_shape_4096_256", &softmax, "softmax forward (XeTLA)"); - m.def("softmax_shape_1024_1024", &softmax, + m.def("softmax_shape_4096_1024", &softmax, "softmax forward (XeTLA)"); - m.def("softmax_shape_2048_2048", &softmax, + m.def("softmax_shape_4096_2048", &softmax, "softmax forward (XeTLA)"); m.def("softmax_shape_4096_4096", &softmax, "softmax forward (XeTLA)"); diff --git a/benchmarks/xetla_kernel/softmax/softmax.h b/benchmarks/xetla_kernel/softmax/softmax.h index 0bf89d070a..e159b4b3bb 100644 --- a/benchmarks/xetla_kernel/softmax/softmax.h +++ b/benchmarks/xetla_kernel/softmax/softmax.h @@ -92,7 +92,8 @@ sycl::event softmax_forward(void *input, void *output, sycl::queue &queue) { // sycl::info::event_profiling::command_start>()) / // (1000.0f * 1000.0f * 1000.f); - // printf("M: %d, Data_type_in: %d, Bandwidth: GB/S: %f \n", mat_m, + // printf("M: %d, N: %d Data_type_in: %d, Bandwidth: GB/S: %f \n", mat_m, + // mat_n, // sizeof(data_type_in), // ((mat_m * mat_n * sizeof(data_type_in) * 2 / 1e9) / time)); return e_softmax_fwd; diff --git a/benchmarks/xetla_kernel/softmax/softmax_config.hpp b/benchmarks/xetla_kernel/softmax/softmax_config.hpp index 00727bf50b..16079a833e 100644 --- a/benchmarks/xetla_kernel/softmax/softmax_config.hpp +++ b/benchmarks/xetla_kernel/softmax/softmax_config.hpp @@ -18,10 +18,10 @@ #include -class mat1_256x256_bf16_cfg0 { +class mat1_4096x256_bf16_cfg0 { public: static constexpr size_t mat_n = 256; - static constexpr size_t mat_m = 256; + static constexpr size_t mat_m = 4096; static constexpr size_t wg_n = mat_n; static constexpr size_t wg_m = 4; // 1 4 8 16 static constexpr size_t sg_n = mat_n; @@ -31,10 +31,10 @@ class mat1_256x256_bf16_cfg0 { using data_type_acc = float; }; -class mat1_1024x1024_bf16_cfg0 { +class mat1_4096x1024_bf16_cfg0 { public: static constexpr size_t mat_n = 1024; - static constexpr size_t mat_m = 1024; + static constexpr size_t mat_m = 4096; static constexpr size_t wg_n = mat_n; static constexpr size_t wg_m = 4; // 1 4 8 16 static constexpr size_t sg_n = mat_n; @@ -44,10 +44,10 @@ class mat1_1024x1024_bf16_cfg0 { using data_type_acc = float; }; -class mat1_2048x2048_bf16_cfg0 { +class mat1_4096x2048_bf16_cfg0 { public: static constexpr size_t mat_n = 2048; - static constexpr size_t mat_m = 2048; + static constexpr size_t mat_m = 4096; static constexpr size_t wg_n = mat_n; static constexpr size_t wg_m = 4; // 1 4 8 16 static constexpr size_t sg_n = mat_n;
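Illustrative usage (not part of the patch): the sketch below shows how the `benchmark_testing` API re-exported from `xetla_benchmark/__init__.py` is meant to be driven, mirroring the updated `fused_softmax.py`. It assumes an XPU machine with `intel_extension_for_pytorch` installed and the `xetla_benchmark` package (including the compiled `xetla_kernel` module) importable; the shapes and the single `torch-native` provider are placeholders.

```python
# Minimal sketch of driving the new benchmark harness; shapes/providers are
# illustrative only. Importing xetla_benchmark also installs the profiling
# launcher (driver.active.launcher_cls = benchmark_driver.XPULauncher).
import torch
import xetla_benchmark as benchmark_suit  # re-exports do_bench, perf_report, Benchmark


@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        x_names=['N'],                    # argument shown on the x-axis
        x_vals=[256, 1024, 2048, 4096],   # values swept on the x-axis
        line_arg='provider',              # one plotted line per provider
        line_vals=['torch-native'],
        line_names=['Torch'],
        ylabel='GB/s',
        plot_name='softmax-performance',
        args={'M': 4096},                 # fixed kwargs passed to the function
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='xpu', dtype=torch.bfloat16)
    # do_bench runs fn under torch's legacy XPU profiler and sums the device
    # kernel time recorded per repetition, then returns the requested
    # quantiles (median, 20th and 80th percentile here).
    ms, min_ms, max_ms = benchmark_suit.do_bench(
        lambda: torch.softmax(x, axis=-1), quantiles=[0.5, 0.2, 0.8],
        warmup=10, rep=10)
    gbps = lambda t: 2 * x.nelement() * x.element_size() * 1e-9 / (t * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(show_plots=False, print_data=True)
```

Returning a triple fills the mean/min/max columns that `Mark._run` records; a single scalar is also accepted, in which case the min/max columns are left empty.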