Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C/PyTorch] Userbuffers and comm+GEMM overlap algorithms refactored and moved to TE/common #1067

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4a8b30a
moved userbuffers code to TE/common
denera Aug 16, 2024
b5e18df
moved comm+GEMM overlap code to TE/common
denera Aug 23, 2024
8dfe6d6
removed PyTorch depdency from comm+GEMM overlap in TE/common
denera Aug 26, 2024
f488620
added TE/PyTorch wrappers for refactored comm+GEMM overlap code in TE…
denera Aug 26, 2024
a36ebf6
updated TE/PyTorch Python API to match the refactored comm+GEMM overl…
denera Aug 26, 2024
2d495bc
updated unit tests to work with refactored comm+GEMM overlap code
denera Aug 27, 2024
0bd4822
added a pylint exception to comm+GEMM overlap test runner
denera Aug 27, 2024
cb9f235
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
aaca9d8
fixing linting errors
denera Aug 27, 2024
45acb5e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
cbd22f2
added documentation for te.initialize_ub
denera Aug 27, 2024
557abbc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
59c1ced
fixed compile errors when building with NVTE_UB_WITH_MPI=1
denera Aug 27, 2024
e85062b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2024
502b217
fixed default bootstrap backend
denera Aug 28, 2024
b3cdf29
switched default bootstrap backend priority to MPI > Gloo > NCCL
denera Aug 28, 2024
e4679ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2024
f675d13
updated bootstrap backend documentation
denera Aug 28, 2024
3517beb
close UB bootstrap socket to avoid interfering with CUDA Multicast sh…
denera Aug 29, 2024
776ad27
added torch::Tensor wrappers for communication buffer and atomic coun…
denera Aug 29, 2024
ce9c34d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
935a403
automated handling of world, local and node ranks/sizes within C++ Co…
denera Sep 6, 2024
d80765e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
e5d31ce
fixed incorrect read of environment variables
denera Oct 9, 2024
3c53354
corrected priority for _SOCKET_IFNAME environment variables in UB boo…
denera Oct 9, 2024
1776282
moved multicast support check to cuda_runtime.h and replaced cudaDevi…
denera Oct 21, 2024
9dd300c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2024
d99734a
removed commented out old code and replaced external collective funct…
denera Oct 23, 2024
94dbe6a
compile-time CUDA version guard for CUDA Driver Multicast attribute
denera Oct 23, 2024
9c60c00
added compile-time CUDA version guards to Multicast code in Userbuffers
denera Oct 23, 2024
452b522
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
c1bded4
condensed UB docs, corrected const violations
denera Oct 24, 2024
a5504f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2024
7dd25c5
fixed autodoc rst for UB calls, added CUDA version guard on Multicast…
denera Oct 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ dist/
downloads/
.pytest_cache/
compile_commands.json
.nfs
19 changes: 4 additions & 15 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from .utils import (
all_files_in_dir,
cuda_archs,
cuda_path,
cuda_version,
)

Expand All @@ -29,9 +28,6 @@ def setup_pytorch_extension(
sources = [
csrc_source_files / "common.cu",
csrc_source_files / "ts_fp8_op.cpp",
csrc_source_files / "userbuffers" / "ipcsocket.cc",
csrc_source_files / "userbuffers" / "userbuffers.cu",
csrc_source_files / "userbuffers" / "userbuffers-host.cpp",
] + all_files_in_dir(extensions_dir)

# Header files
Expand Down Expand Up @@ -85,19 +81,14 @@ def setup_pytorch_extension(
continue # Already handled
nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

# Libraries
library_dirs = []
libraries = []
if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))):
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
mpi_home = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_home / "include")
), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
mpi_path = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_path / "include")
cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DNVTE_UB_WITH_MPI")
library_dirs.append(mpi_home / "lib")
libraries.append("mpi")

# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
Expand All @@ -112,6 +103,4 @@ def setup_pytorch_extension(
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
4 changes: 4 additions & 0 deletions docs/api/pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,7 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.moe_permute

.. autoapifunction:: transformer_engine.pytorch.moe_unpermute

.. autoapifunction:: transformer_engine.pytorch.initialize_ub

.. autoapifunction:: transformer_engine.pytorch.destroy_ub
9 changes: 8 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,20 @@ def run(self):

def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library"""
cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())]
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

# Project directory root
root_path = Path(__file__).resolve().parent

return CMakeExtension(
name="transformer_engine",
cmake_path=root_path / Path("transformer_engine/common"),
cmake_flags=["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())],
cmake_flags=cmake_flags,
)


Expand Down
Loading
Loading