Skip to content

Commit

Permalink
Use the same driver for the upstream profiler as for the legacy profi…
Browse files Browse the repository at this point in the history
…ler with IPEX

Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev committed Oct 18, 2024
1 parent 26baece commit 4cd0190
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
4 changes: 2 additions & 2 deletions benchmarks/triton_kernels_benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, USE_IPEX_OPTION # type: ignore # noqa: F401
from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401

if USE_IPEX_OPTION:
if USE_IPEX_OPTION or BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
from triton.runtime import driver
from . import benchmark_driver
# replace the launcher with the profilier hook.
Expand Down
35 changes: 23 additions & 12 deletions benchmarks/triton_kernels_benchmark/benchmark_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
from triton.runtime.build import _build, quiet

import torch
import intel_extension_for_pytorch

from .benchmark_testing import USE_IPEX_OPTION

_dirname = os.getenv("ZE_PATH", default="/usr/local")

include_dir = [
os.path.join(_dirname, "include"),
os.path.join(torch.utils.cmake_prefix_path, "../../include"),
os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include"),
os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../include")
os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include")
]

oneapi_root = os.getenv("ONEAPI_ROOT")
Expand All @@ -28,12 +28,15 @@
os.path.join(oneapi_root, "compiler/latest/include/sycl")
]

library_dir = [
os.path.join(_dirname, "lib"),
os.path.join(torch.utils.cmake_prefix_path, "../../lib"),
os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../lib")
]
libraries = ["ze_loader", "sycl", "torch", "intel-ext-pt-gpu"]
library_dir = [os.path.join(_dirname, "lib"), os.path.join(torch.utils.cmake_prefix_path, "../../lib")]
libraries = ["ze_loader", "sycl", "torch"]

if USE_IPEX_OPTION:
import intel_extension_for_pytorch

include_dir.append(os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../include"))
library_dir.append(os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../lib"))
libraries.append("intel-ext-pt-gpu")


def compile_module_from_src(src, name):
Expand Down Expand Up @@ -141,6 +144,14 @@ def format_of(ty):
fmt = "iiiOOOOOO" + args_format
args_list = ", " + ", ".join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""

record_function_header = "#include <ATen/record_function.h>"
ipex_header = ""
xpu_profiler_record = ""
if USE_IPEX_OPTION:
record_function_header = "#include <torch/extension.h>"
ipex_header = "#include <ipex.h>"
xpu_profiler_record = "xpu::profiler_record(kernel_name, event);"

# generate glue code
src = f"""
#include <cstddef>
Expand All @@ -149,8 +160,8 @@ def format_of(ty):
#include <iomanip>
#include <level_zero/ze_api.h>
#include <sycl/sycl.hpp>
#include <torch/extension.h>
#include <ipex.h>
{record_function_header}
{ipex_header}
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
Expand Down Expand Up @@ -291,7 +302,7 @@ def format_of(ty):
}}
}};
auto event = stream.submit(cgf);
xpu::profiler_record(kernel_name, event);
{xpu_profiler_record}
}}
// end sycl
static PyObject* launch(PyObject* self, PyObject* args) {{
Expand Down

0 comments on commit 4cd0190

Please sign in to comment.