Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use pyproject.toml to specify build requirements #1061

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
with:
submodules: recursive
- name: 'Build'
run: pip install . -v --no-deps
run: pip install . -v
ksivaman marked this conversation as resolved.
Show resolved Hide resolved
env:
NVTE_FRAMEWORK: pytorch
MAX_JOBS: 1
Expand Down
13 changes: 6 additions & 7 deletions build_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import shutil
import subprocess
import sys
import importlib
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Tuple
Expand Down Expand Up @@ -262,15 +261,15 @@ def copy_common_headers(te_src, dst):
shutil.copy(file_path, new_path)


def install_and_import(package):
"""Install a package via pip (if not already installed) and import into globals."""
main_package = package.split("[")[0]
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
globals()[main_package] = importlib.import_module(main_package)
def install_packages(packages):
"""Install a package via pip (if not already installed)."""
for package in packages:
main_package = package.split("[")[0]
subprocess.run([sys.executable, "-m", "pip", "install", package])
timmoon10 marked this conversation as resolved.
Show resolved Hide resolved


def uninstall_te_fw_packages():
subprocess.check_call(
subprocess.run(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check_call is nice since it will print out the error message if the process fails:

Suggested change
subprocess.run(
subprocess.check_call(

We could alternatively call with check=True or manually check the return code.

[
sys.executable,
"-m",
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Copy link
Collaborator

@yaox12 yaox12 Aug 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a section for the black formatter in pyproject.toml to declare the following arguments? It's useful when we want to format the code locally.

args: [--line-length=100, --preview, --enable-unstable-feature=string_processing]

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[build-system]
requires = ["setuptools >= 61.0", "cmake>=3.21", "pybind11", "ninja"]

# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
25 changes: 10 additions & 15 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
found_pybind11,
remove_dups,
get_frameworks,
install_and_import,
uninstall_te_fw_packages,
install_packages,
)
from build_tools.te_version import te_version

Expand All @@ -29,14 +29,19 @@

from setuptools.command.build_ext import build_ext as BuildExtension

# Same list as build_system.requires in pyproject.toml
install_packages(["setuptools >= 61.0", "cmake>=3.21", "pybind11", "ninja"])

os.environ["NVTE_PROJECT_BUILDING"] = "1"

if "pytorch" in frameworks:
install_packages(["torch>=1.13"])
from torch.utils.cpp_extension import BuildExtension
timmoon10 marked this conversation as resolved.
Show resolved Hide resolved
elif "paddle" in frameworks:
install_packages(["paddlepaddle-gpu>=2.6.1"])
from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import("pybind11[global]")
install_packages(["jax"])
from pybind11.setup_helpers import build_ext as BuildExtension


Expand All @@ -54,35 +59,26 @@ def setup_common_extension() -> CMakeExtension:
)


def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
def setup_requirements() -> Tuple[List[str], List[str]]:
"""Setup Python dependencies

Returns dependencies for build, runtime, and testing.
"""

# Common requirements
setup_reqs: List[str] = []
install_reqs: List[str] = [
"pydantic",
"importlib-metadata>=1.0",
"packaging",
]
test_reqs: List[str] = ["pytest>=8.2.1"]

# Requirements that may be installed outside of Python
if not found_cmake():
setup_reqs.append("cmake>=3.21")
if not found_ninja():
setup_reqs.append("ninja")
if not found_pybind11():
setup_reqs.append("pybind11")

return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]]


if __name__ == "__main__":
# Dependencies
setup_requires, install_requires, test_requires = setup_requirements()
install_requires, test_requires = setup_requirements()

__version__ = te_version()

Expand Down Expand Up @@ -150,7 +146,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
setup_requires=setup_requires,
install_requires=install_requires,
license_files=("LICENSE",),
include_package_data=True,
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/jax/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@


from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers, install_and_import
from build_tools.utils import copy_common_headers, install_packages
from build_tools.te_version import te_version
from build_tools.jax import setup_jax_extension

install_and_import("pybind11")
install_packages("pybind11")
timmoon10 marked this conversation as resolved.
Show resolved Hide resolved
from pybind11.setup_helpers import build_ext as BuildExtension

os.environ["NVTE_PROJECT_BUILDING"] = "1"
Expand Down
Loading