[misc][cuda] use nvml query to avoid accidentally cuda initialization #6007

Merged
4 commits merged on Jul 1, 2024
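For context: `torch.cuda.get_device_capability()` initializes a CUDA context in the calling process as a side effect, while an equivalent NVML query does not. A minimal sketch of the difference (assuming `pynvml` is installed and at least one NVIDIA GPU is visible):

```python
import torch
import pynvml

# NVML query: reads the compute capability without creating a CUDA
# context in this process.
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
print(pynvml.nvmlDeviceGetCudaComputeCapability(handle))  # e.g. (8, 0)
pynvml.nvmlShutdown()
print(torch.cuda.is_initialized())  # still False

# torch query: initializes CUDA in this process as a side effect.
print(torch.cuda.get_device_capability())
print(torch.cuda.is_initialized())  # now True
```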
3 changes: 2 additions & 1 deletion tests/kernels/test_cutlass.py
@@ -8,12 +8,13 @@
import torch

from vllm import _custom_ops as ops
+from vllm.utils import get_device_capability_stateless

CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

-capability = torch.cuda.get_device_capability()
+capability = get_device_capability_stateless()
capability = capability[0] * 10 + capability[1]


3 changes: 2 additions & 1 deletion tests/quantization/utils.py
@@ -1,14 +1,15 @@
import torch

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.utils import get_device_capability_stateless


def is_quant_method_supported(quant_method: str) -> bool:
# Currently, all quantization methods require Nvidia or AMD GPUs
if not torch.cuda.is_available():
return False

-capability = torch.cuda.get_device_capability()
+capability = get_device_capability_stateless()
capability = capability[0] * 10 + capability[1]
return (capability >=
QUANTIZATION_METHODS[quant_method].get_min_capability())
6 changes: 3 additions & 3 deletions vllm/attention/ops/blocksparse_attention/interface.py
@@ -2,13 +2,13 @@

import torch

-from vllm.utils import is_cpu, is_hip
+from vllm.utils import get_device_capability_stateless, is_cpu, is_hip

from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)

IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
-and torch.cuda.get_device_capability()[0] >= 8)
+and get_device_capability_stateless()[0] >= 8)

if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
@@ -235,4 +235,4 @@ def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
v,
cu_seqlens_k,
cu_seqlens_q=cu_seqlens_q,
-sm_scale=sm_scale)
+sm_scale=sm_scale)
4 changes: 3 additions & 1 deletion vllm/attention/ops/prefix_prefill.py
@@ -5,6 +5,8 @@
import triton
import triton.language as tl

+from vllm.utils import get_device_capability_stateless

if triton.__version__ >= "2.1.0":

@triton.jit
@@ -683,7 +685,7 @@ def context_attention_fwd(q,
alibi_slopes=None,
sliding_window=None):

-cap = torch.cuda.get_device_capability()
+cap = get_device_capability_stateless()
BLOCK = 128 if cap[0] >= 8 else 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
58 changes: 5 additions & 53 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -11,66 +11,18 @@
gpu_p2p_access_check)
from vllm.distributed.parallel_state import is_in_the_same_node
from vllm.logger import init_logger
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils import cuda_device_count_stateless, is_full_nvlink

try:
-import pynvml
-
-# Simulate ImportError if custom_ar ops are not supported.
-if not ops.is_custom_op_supported("_C_custom_ar::meta_size"):
-raise ImportError("custom_ar", __file__)
-
+assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
custom_ar = True
-
-@contextmanager
-def _nvml():
-try:
-pynvml.nvmlInit()
-yield
-finally:
-pynvml.nvmlShutdown()
-
-except ImportError:
-# For AMD GPUs
+except Exception:
+# For AMD GPUs and CPUs
custom_ar = False
-pynvml = None
-
-@contextmanager
-def _nvml():
-try:
-yield
-finally:
-pass
-

logger = init_logger(__name__)


-@_nvml()
-def _is_full_nvlink(device_ids: List[int]) -> bool:
-"""
-query if the set of gpus are fully connected by nvlink (1 hop)
-Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
-so it works on real physical device ids.
-"""
-handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
-for i, handle in enumerate(handles):
-for j, peer_handle in enumerate(handles):
-if i < j:
-try:
-p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-return False
-except pynvml.NVMLError as error:
-logger.error(
-"NVLink detection failed. This is normal if your"
-" machine has no NVLink equipped.",
-exc_info=error)
-return False
-return True
-
-

def _can_p2p(rank: int, world_size: int) -> bool:
for i in range(world_size):
if i == rank:
@@ -161,7 +113,7 @@ def __init__(self,
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
-full_nvlink = _is_full_nvlink(physical_device_ids)
+full_nvlink = is_full_nvlink(physical_device_ids)
if world_size > 2 and not full_nvlink:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
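A note on the relocated helper: `is_full_nvlink` (now in `vllm.utils`) operates on physical device ids, since NVML ignores `CUDA_VISIBLE_DEVICES`. A hedged sketch of how a caller can map logical ranks to physical ids before the check; the fallback to `range(world_size)` and the assumption that `CUDA_VISIBLE_DEVICES` holds integer indices are illustrative, not part of this PR:

```python
import os
from vllm.utils import is_full_nvlink

def full_nvlink_for_ranks(world_size: int) -> bool:
    # NVML works on physical device ids, so translate logical ranks
    # through CUDA_VISIBLE_DEVICES when it is set.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible:
        physical_ids = [int(x) for x in visible.split(",")][:world_size]
    else:
        physical_ids = list(range(world_size))
    return is_full_nvlink(physical_ids)
```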
3 changes: 2 additions & 1 deletion vllm/lora/punica.py
@@ -5,13 +5,14 @@
import torch

from vllm import _custom_ops as ops
+from vllm.utils import get_device_capability_stateless


def _check_punica_support():
if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
return

-if torch.cuda.get_device_capability() < (8, 0):
+if get_device_capability_stateless() < (8, 0):
raise ImportError(
"punica LoRA kernels require compute capability >= 8.0")
else:
@@ -14,6 +14,7 @@
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy,
find_first_name_or_class_match)
+from vllm.utils import get_device_capability_stateless


class CompressedTensorsConfig(QuantizationConfig):
@@ -84,7 +85,7 @@ def get_config_filenames(cls) -> List[str]:
return []

def _check_gptq_and_marlin_can_run(self):
-capability = torch.cuda.get_device_capability()
+capability = get_device_capability_stateless()
capability = capability[0] * 10 + capability[1]
if capability < 80:
raise RuntimeError("The quantization config is not supported for ",
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -10,15 +10,15 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import print_warning_once
+from vllm.utils import get_device_capability_stateless, print_warning_once

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)


def cutlass_fp8_supported() -> bool:
-capability = torch.cuda.get_device_capability()
+capability = get_device_capability_stateless()
capability = capability[0] * 10 + capability[1]

return ops.cutlass_scaled_mm_supports_fp8(capability)
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -11,6 +11,7 @@
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
+from vllm.utils import get_device_capability_stateless

logger = init_logger(__name__)

@@ -165,7 +166,7 @@ def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
return False

# If the capability of the device is too low, cannot convert.
-major, minor = torch.cuda.get_device_capability()
+major, minor = get_device_capability_stateless()
device_capability = major * 10 + minor
if device_capability < cls.get_min_capability():
return False
@@ -12,8 +12,9 @@
marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_pack_factor, quantize_weights, sort_weights)
+from vllm.utils import get_device_capability_stateless

-__cuda_arch = torch.cuda.get_device_capability()
+__cuda_arch = get_device_capability_stateless()

MARLIN_TILE = 16

4 changes: 2 additions & 2 deletions vllm/model_executor/model_loader/loader.py
@@ -35,7 +35,7 @@
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import is_tpu
+from vllm.utils import get_device_capability_stateless, is_tpu

logger = init_logger(__name__)

@@ -46,7 +46,7 @@ def _get_quantization_config(
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
-capability = torch.cuda.get_device_capability()
+capability = get_device_capability_stateless()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(
57 changes: 57 additions & 0 deletions vllm/utils.py
@@ -816,6 +816,63 @@ def cuda_device_count_stateless() -> int:
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA

try:
import pynvml
except ImportError:
# For non-NV devices
pynvml = None


def with_nvml_context(fn):

@wraps(fn)
def wrapper(*args, **kwargs):
if pynvml is not None:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
finally:
if pynvml is not None:
pynvml.nvmlShutdown()

return wrapper


@with_nvml_context
def is_full_nvlink(device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError as error:
logger.error(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped.",
exc_info=error)
return False
return True


@lru_cache(maxsize=8)
@with_nvml_context
def get_device_capability_stateless(device_id: int = 0) -> Tuple[int, int]:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return pynvml.nvmlDeviceGetCudaComputeCapability(handle)


#From: https://stackoverflow.com/a/4104188/2749989
def run_once(f):

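A usage sketch of the helpers added above, mirroring the call sites elsewhere in this diff (the threshold 80 and the two-GPU NVLink check are illustrative assumptions, not part of the PR):

```python
from vllm.utils import get_device_capability_stateless, is_full_nvlink

# Compute capability via NVML, so the query does not initialize CUDA
# in the calling process.
major, minor = get_device_capability_stateless()
capability = major * 10 + minor  # e.g. 80 for compute capability (8, 0)
print("compute 8.0 or newer:", capability >= 80)

# NVLink topology check on physical device ids 0 and 1.
print("fully NVLink-connected:", is_full_nvlink([0, 1]))
```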
3 changes: 2 additions & 1 deletion vllm/worker/worker.py
@@ -16,6 +16,7 @@
from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.sequence import ExecuteModelRequest
+from vllm.utils import get_device_capability_stateless
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -322,7 +323,7 @@ def init_worker_distributed_environment(
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
-compute_capability = torch.cuda.get_device_capability()
+compute_capability = get_device_capability_stateless()
if compute_capability[0] < 8:
gpu_name = torch.cuda.get_device_name()
raise ValueError(