Skip to content

Commit

Permalink
TE with threading build (NVIDIA#1092)
Browse files Browse the repository at this point in the history
* added threading build back

* integrating threading for pytorch and paddle extensions

* added messages

---------

Signed-off-by: Phuong Nguyen <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
  • Loading branch information
2 people authored and mgoldfarb-nvidia committed Aug 14, 2024
1 parent c32f69a commit b2b77fd
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
3 changes: 2 additions & 1 deletion build_tools/paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path

import setuptools
import os

from .utils import cuda_version

Expand Down Expand Up @@ -62,7 +63,7 @@ def setup_paddle_extension(
print("Could not determine CUDA Toolkit version")
else:
if version >= (11, 2):
nvcc_flags.extend(["--threads", "4"])
nvcc_flags.extend(["--threads", os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1")])
if version >= (11, 0):
nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
if version >= (11, 8):
Expand Down
2 changes: 1 addition & 1 deletion build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def setup_pytorch_extension(
print("Could not determine CUDA Toolkit version")
else:
if version >= (11, 2):
nvcc_flags.extend(["--threads", "4"])
nvcc_flags.extend(["--threads", os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1")])
if version >= (11, 0):
nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
if version >= (11, 8):
Expand Down
4 changes: 2 additions & 2 deletions build_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def get_max_jobs_for_parallel_build() -> int:
num_jobs = 0

# Check environment variable
if os.getenv("NVTE_MAX_BUILD_JOBS"):
num_jobs = int(os.getenv("NVTE_MAX_BUILD_JOBS"))
if os.getenv("NVTE_BUILD_MAX_JOBS"):
num_jobs = int(os.getenv("NVTE_BUILD_MAX_JOBS"))
elif os.getenv("MAX_JOBS"):
num_jobs = int(os.getenv("MAX_JOBS"))

Expand Down
16 changes: 16 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ set(CMAKE_CUDA_STANDARD_REQUIRED ON)

project(transformer_engine LANGUAGES CUDA CXX)

# Number of nvcc threads per compilation job, passed to nvcc as `--threads`.
# Taken from NVTE_BUILD_THREADS_PER_JOB; falls back to 1 when the variable is
# unset or empty. The check is an explicit empty-string comparison because
# `if(NOT <var>)` treats "0" as a false constant and would silently replace an
# explicit 0 with 1, while the Python extension builds (build_tools/pytorch.py,
# build_tools/paddle.py) forward the value to `--threads` unchanged.
set(BUILD_THREADS_PER_JOB "$ENV{NVTE_BUILD_THREADS_PER_JOB}")
if("${BUILD_THREADS_PER_JOB}" STREQUAL "")
  set(BUILD_THREADS_PER_JOB 1)
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}")

# Job count reported in the status message below.
# NVTE_BUILD_MAX_JOBS is checked before MAX_JOBS, and empty values are ignored,
# so the reported number follows the same precedence as
# get_max_jobs_for_parallel_build() in build_tools/utils.py.
if(NOT "$ENV{NVTE_BUILD_MAX_JOBS}" STREQUAL "")
  set(JOBS "$ENV{NVTE_BUILD_MAX_JOBS}")
elseif(NOT "$ENV{MAX_JOBS}" STREQUAL "")
  set(JOBS "$ENV{MAX_JOBS}")
else()
  # No explicit limit requested; the wording completes the message below.
  set(JOBS "max number of")
endif()

message(STATUS "Parallel build with ${JOBS} jobs and ${BUILD_THREADS_PER_JOB} threads per job")

if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
endif()
Expand Down

0 comments on commit b2b77fd

Please sign in to comment.