Skip to content

Commit

Permalink
Add ext_modules for install build
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Oct 10, 2024
1 parent 66ea7c3 commit d11b65d
Showing 1 changed file with 28 additions and 27 deletions.
55 changes: 28 additions & 27 deletions benchmarks/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,36 @@
import subprocess
import sys

from distutils import log
from distutils.dir_util import remove_tree
from distutils.command.clean import clean as _clean
from distutils.command.build import build as _build

from setuptools import setup
from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext as _build_ext
from setuptools._distutils import log
from setuptools._distutils.dir_util import remove_tree
from setuptools._distutils.command.clean import clean as _clean

import torch


class CMakeExtension(Extension):

def __init__(self, name):
# don't invoke the original build_ext for this special extension
super().__init__(name, sources=[])


class CMakeBuild():

def __init__(self, build_type="Debug"):
def __init__(self, debug=False, dry_run=False):
self.current_dir = os.path.abspath(os.path.dirname(__file__))
self.build_temp = self.current_dir + "/build/temp"
self.extdir = self.current_dir + "/triton_kernels_benchmark"
self.build_type = build_type
self.build_type = self.get_build_type(debug)
self.cmake_prefix_paths = [torch.utils.cmake_prefix_path]
self.use_ipex = False
self.dry_run = dry_run

def get_build_type(self, debug):
DEBUG_OPTION = os.getenv("DEBUG", "0")
return "Debug" if debug or (DEBUG_OPTION == "1") else "Release"

def run(self):
self.check_ipex()
Expand All @@ -41,7 +52,8 @@ def check_ipex(self):

def check_call(self, *popenargs, **kwargs):
print(" ".join(popenargs[0]))
subprocess.check_call(*popenargs, **kwargs)
if not self.dry_run:
subprocess.check_call(*popenargs, **kwargs)

def build_extension(self):
ninja_dir = shutil.which("ninja")
Expand Down Expand Up @@ -94,38 +106,27 @@ def clean(self):
os.path.dirname(__file__)))


class build(_build):
class build_ext(_build_ext):

def run(self):
self.build_cmake()
super().run()

def build_cmake(self):
DEBUG_OPTION = os.getenv("DEBUG", "0")
debug = DEBUG_OPTION == "1"
if hasattr(self, "debug"):
debug = debug or self.debug
build_type = "Debug" if debug else "Release"
cmake = CMakeBuild(build_type)
cmake = CMakeBuild(debug=self.debug, dry_run=self.dry_run)
cmake.run()
super().run()


class clean(_clean):

def run(self):
self.clean_cmake()
super().run()

def clean_cmake(self):
cmake = CMakeBuild()
cmake = CMakeBuild(dry_run=self.dry_run)
cmake.clean()
super().run()


setup(name="triton-kernels-benchmark", packages=[
"triton_kernels_benchmark",
], package_dir={
"triton_kernels_benchmark": "triton_kernels_benchmark",
}, package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]}, cmdclass={
"build": build,
"build_ext": build_ext,
"clean": clean,
})
}, ext_modules=[CMakeExtension("triton_kernels_benchmark")])

0 comments on commit d11b65d

Please sign in to comment.