From f5bb85b435e6fe3db57fae1e25e09914015ef957 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 14 Jun 2024 14:47:45 -0700 Subject: [PATCH] [Core][Distributed] improve p2p cache generation (#5528) --- .../device_communicators/cuda_wrapper.py | 146 ++++++++++++ .../custom_all_reduce_utils.py | 215 ++++++++++-------- 2 files changed, 265 insertions(+), 96 deletions(-) create mode 100644 vllm/distributed/device_communicators/cuda_wrapper.py diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py new file mode 100644 index 0000000000000..24308235c4a48 --- /dev/null +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -0,0 +1,146 @@ +"""This file is a pure Python wrapper for the cudart library. +It avoids the need to compile a separate shared library, and is +convenient for use when we just need to call a few functions. +""" + +import ctypes +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +# this line makes it possible to directly load `libcudart.so` using `ctypes` +import torch # noqa + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# === export types and functions from cudart to Python === +# for the original cudart definition, please check +# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html + +cudaError_t = ctypes.c_int +cudaMemcpyKind = ctypes.c_int + + +class cudaIpcMemHandle_t(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class CudaRTLibrary: + exported_functions = [ + # ​cudaError_t cudaSetDevice ( int device ) + Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), + # cudaError_t cudaDeviceSynchronize ( void ) + Function("cudaDeviceSynchronize", cudaError_t, []), + # ​cudaError_t cudaDeviceReset ( void ) + Function("cudaDeviceReset", cudaError_t, []), + + # const char* cudaGetErrorString ( cudaError_t error ) + Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), + + # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) + Function("cudaMalloc", cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), + # ​cudaError_t cudaFree ( void* devPtr ) + Function("cudaFree", cudaError_t, [ctypes.c_void_p]), + # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) + Function("cudaMemset", cudaError_t, + [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), + # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa + Function("cudaMemcpy", cudaError_t, [ + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind + ]), + + # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa + Function("cudaIpcGetMemHandle", cudaError_t, + [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), + # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa + Function("cudaIpcOpenMemHandle", cudaError_t, [ + ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint + ]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + if so_file is None: + assert 
torch.version.cuda is not None + major_version = torch.version.cuda.split(".")[0] + so_file = f"libcudart.so.{major_version}" + if so_file not in CudaRTLibrary.path_to_library_cache: + lib = ctypes.CDLL(so_file) + CudaRTLibrary.path_to_library_cache[so_file] = lib + self.lib = CudaRTLibrary.path_to_library_cache[so_file] + + if so_file not in CudaRTLibrary.path_to_dict_mapping: + _funcs = {} + for func in CudaRTLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs + self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] + + def CUDART_CHECK(self, result: cudaError_t) -> None: + if result != 0: + error_str = self.cudaGetErrorString(result) + raise RuntimeError(f"CUDART error: {error_str}") + + def cudaGetErrorString(self, error: cudaError_t) -> str: + return self.funcs["cudaGetErrorString"](error).decode("utf-8") + + def cudaSetDevice(self, device: int) -> None: + self.CUDART_CHECK(self.funcs["cudaSetDevice"](device)) + + def cudaDeviceSynchronize(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]()) + + def cudaDeviceReset(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceReset"]()) + + def cudaMalloc(self, size: int) -> ctypes.c_void_p: + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size)) + return devPtr + + def cudaFree(self, devPtr: ctypes.c_void_p) -> None: + self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) + + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, + count: int) -> None: + self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) + + def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, + count: int) -> None: + cudaMemcpyDefault = 4 + kind = cudaMemcpyDefault + self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) + + def cudaIpcGetMemHandle(self, + devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: + handle = cudaIpcMemHandle_t() + self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"]( + ctypes.byref(handle), devPtr)) + return handle + + def cudaIpcOpenMemHandle(self, + handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: + cudaIpcMemLazyEnablePeerAccess = 1 + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"]( + ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) + return devPtr diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py index c9573edb08f33..e6957b1196969 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -1,87 +1,98 @@ +import ctypes import json import os -import sys -import tempfile -import time -from contextlib import contextmanager -from typing import Callable, Dict, List, Optional +from itertools import product +from typing import Dict, Optional, Sequence -import torch import torch.distributed as dist import torch.multiprocessing as mp import vllm.envs as envs +from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.logger import init_logger from vllm.utils import cuda_device_count_stateless logger = init_logger(__name__) -@contextmanager -def mute_output(): - with open(os.devnull, "w") as f: - sys.stderr = f - sys.stdout = f - yield - - -def producer(i: int, - init_method: str, +def producer(batch_src: Sequence[int], + producer_queue, + 
consumer_queue, + result_queue, cuda_visible_devices: Optional[str] = None): if cuda_visible_devices is not None: os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices - with mute_output(): - dist.init_process_group( - backend="gloo", - init_method=init_method, - world_size=2, - rank=0, - ) - # produce a tensor in GPU i - data = torch.zeros((128, ), device=f"cuda:{i}") - # get the information to reconstruct the shared tensor - func, args = torch.multiprocessing.reductions.reduce_tensor(data) - args = list(args) - dist.broadcast_object_list([(func, args)], src=0) - dist.barrier() - torch.cuda.synchronize() - assert torch.all(data == 1).item() - - -def consumer(j: int, - init_method: str, + + lib = CudaRTLibrary() + for i in batch_src: + lib.cudaSetDevice(i) + pointer = lib.cudaMalloc(1024) + lib.cudaMemset(pointer, 1, 1024) + lib.cudaDeviceSynchronize() + handle = lib.cudaIpcGetMemHandle(pointer) + producer_queue.put(handle) + open_success = consumer_queue.get() + if open_success: + # use two queues to simulate barrier + producer_queue.put(0) + consumer_queue.get() + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def consumer(batch_tgt: Sequence[int], + producer_queue, + consumer_queue, + result_queue, cuda_visible_devices: Optional[str] = None): if cuda_visible_devices is not None: os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices - with mute_output(): - dist.init_process_group( - backend="gloo", - init_method=init_method, - world_size=2, - rank=1, - ) - torch.cuda.set_device(j) - recv = [None] - dist.broadcast_object_list(recv, src=0) - func: Callable - args: List - func, args = recv[0] # type: ignore - # `args[6]` is the device id - # by default pytorch will use `i` from the producer - # here we need to set it to `j` to test P2P access - args[6] = j - data = func(*args) - data += 1 - dist.barrier() - torch.cuda.synchronize() - assert torch.all(data == 1).item() - - -def can_actually_p2p(i, j): + + lib = CudaRTLibrary() + for j in batch_tgt: + lib.cudaSetDevice(j) + handle = producer_queue.get() + open_success = False + try: + pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore + open_success = True + except RuntimeError: + # cannot error out here, because the producer process + # is still waiting for the response. + pass + consumer_queue.put(open_success) + if open_success: + # modify the memory + lib.cudaMemset(pointer, 2, 1024) + # use two queues to simulate barrier + producer_queue.get() + consumer_queue.put(0) + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def can_actually_p2p( + batch_src: Sequence[int], + batch_tgt: Sequence[int], +): """ Usually, checking if P2P access is enabled can be done by - `torch.cuda.can_device_access_peer(i, j)`. However, sometimes - the driver might be broken, and `torch.cuda.can_device_access_peer(i, j)` + `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes + the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)` returns `True` even if P2P access is not actually possible. 
See https://github.com/vllm-project/vllm/issues/2728 and https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 @@ -90,41 +101,50 @@ def can_actually_p2p(i, j): Note on p2p and cuda IPC: Usually, one process uses one GPU: - GPU i --> cuda context i --> tensor i --> process i + GPU src --> cuda context src --> tensor src --> process src We need to combine p2p and cuda IPC, so that: - GPU i --> cuda context i --> tensor i --> process i - |shared| - GPU j --> cuda context j --> tensor j --> process j - That is to say, process i creates a tensor in GPU i, passes IPC handle to - process j, and process j accesses the tensor in GPU j. Any operation on the - tensor in process j will be reflected in the tensor in process i, because + GPU src --> cuda context src --> tensor src --> process src + |shared| + GPU tgt --> cuda context tgt --> tensor tgt --> process tgt + That is to say, process src creates a tensor in GPU src, passes IPC handle to + process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the + tensor in process tgt will be reflected in the tensor in process src, because they are the same memory segment. - It is important to note that process j accesses the tensor in GPU j, not - GPU i. That's why we need p2p access. # noqa - """ + It is important to note that process tgt accesses the tensor in GPU tgt, not + GPU src. That's why we need p2p access. + + The most time-consuming part is the process creation. To avoid creating + processes for every pair of GPUs, we use batched testing. We create two + processes for testing all pairs of GPUs in batch. The trick is to reset + the device after each test (which is not available in PyTorch). + """ # noqa cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None) # pass the CUDA_VISIBLE_DEVICES to the child process # to make sure they see the same set of GPUs - # make sure the temp file is not the same across different calls - temp_path = tempfile.mktemp() + str(time.time()) - # create an empty file - with open(temp_path, "w"): - pass - init_method = f"file://{temp_path}" - # make sure the processes are spawned smp = mp.get_context("spawn") - pi = smp.Process(target=producer, - args=(i, init_method, cuda_visible_devices)) - pj = smp.Process(target=consumer, - args=(j, init_method, cuda_visible_devices)) - pi.start() - pj.start() - pi.join() - pj.join() - return pi.exitcode == 0 and pj.exitcode == 0 + producer_queue = smp.Queue() + consumer_queue = smp.Queue() + result_queue = smp.Queue() + p_src = smp.Process(target=producer, + args=(batch_src, producer_queue, consumer_queue, + result_queue, cuda_visible_devices)) + p_tgt = smp.Process(target=consumer, + args=(batch_tgt, producer_queue, consumer_queue, + result_queue, cuda_visible_devices)) + p_src.start() + p_tgt.start() + p_src.join() + p_tgt.join() + result = [] + for src, tgt in zip(batch_src, batch_tgt): + a = result_queue.get() + b = result_queue.get() + assert a == b + result.append(a) + return result # why do we need this cache? 
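The "use two queues to simulate barrier" comments in producer/consumer above implement a two-process rendezvous: each side puts a token into its own queue and then blocks on the other side's queue, so neither can proceed until both have arrived. Below is a minimal sketch of the same handshake in plain Python, with no CUDA involved; the function and variable names are illustrative, not part of the patch.

import multiprocessing as mp


def producer_side(producer_queue, consumer_queue):
    # ... producer work (e.g. export the IPC handle) ...
    producer_queue.put(0)   # announce arrival
    consumer_queue.get()    # wait until the consumer has arrived too
    # ... now safe to read what the consumer wrote ...


def consumer_side(producer_queue, consumer_queue):
    # ... consumer work (e.g. write through the IPC-mapped pointer) ...
    producer_queue.get()    # wait for the producer
    consumer_queue.put(0)   # release the producer


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    producer_queue, consumer_queue = ctx.Queue(), ctx.Queue()
    p_src = ctx.Process(target=producer_side,
                        args=(producer_queue, consumer_queue))
    p_tgt = ctx.Process(target=consumer_side,
                        args=(producer_queue, consumer_queue))
    p_src.start()
    p_tgt.start()
    p_src.join()
    p_tgt.join()

In the patch, this rendezvous is what guarantees the producer only copies the shared buffer back to the host after the consumer has finished writing the value 2 through the IPC-mapped pointer.
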
@@ -142,14 +162,14 @@ def can_actually_p2p(i, j): _gpu_p2p_access_cache: Optional[Dict[str, bool]] = None -def gpu_p2p_access_check(i: int, j: int) -> bool: - """Check if GPU i can access GPU j.""" +def gpu_p2p_access_check(src: int, tgt: int) -> bool: + """Check if GPU src can access GPU tgt.""" # if the cache variable is already calculated, # read from the cache instead of checking it again global _gpu_p2p_access_cache if _gpu_p2p_access_cache is not None: - return _gpu_p2p_access_cache[f"{i}->{j}"] + return _gpu_p2p_access_cache[f"{src}->{tgt}"] is_distributed = dist.is_initialized() @@ -169,9 +189,12 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: # enter this block to calculate the cache logger.info("generating GPU P2P access cache in %s", path) cache = {} - for _i in range(num_dev): - for _j in range(num_dev): - cache[f"{_i}->{_j}"] = can_actually_p2p(_i, _j) + ids = list(range(num_dev)) + # batch of all pairs of GPUs + batch_src, batch_tgt = zip(*list(product(ids, ids))) + result = can_actually_p2p(batch_src, batch_tgt) + for _i, _j, r in zip(batch_src, batch_tgt, result): + cache[f"{_i}->{_j}"] = r with open(path, "w") as f: json.dump(cache, f, indent=4) if is_distributed: @@ -180,7 +203,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: with open(path, "r") as f: cache = json.load(f) _gpu_p2p_access_cache = cache - return _gpu_p2p_access_cache[f"{i}->{j}"] + return _gpu_p2p_access_cache[f"{src}->{tgt}"] __all__ = ["gpu_p2p_access_check"]
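
A minimal smoke test for the new wrapper and the batched check, not part of the patch: it assumes a CUDA build of PyTorch, a libcudart matching torch.version.cuda on the loader path, and at least two visible GPUs for the final call.

import ctypes

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
    gpu_p2p_access_check)

lib = CudaRTLibrary()            # loads libcudart.so.<CUDA major version>
lib.cudaSetDevice(0)
ptr = lib.cudaMalloc(1024)       # ctypes.c_void_p to 1 KiB of device memory
lib.cudaMemset(ptr, 1, 1024)
lib.cudaDeviceSynchronize()

host = (ctypes.c_char * 1024)()  # host buffer for read-back
lib.cudaMemcpy(host, ptr, 1024)  # wrapper always passes cudaMemcpyDefault
assert host.raw == b"\x01" * 1024
lib.cudaFree(ptr)

# Generates (or reads) the JSON cache and reports whether GPU 0 can really
# reach GPU 1 over P2P.
print(gpu_p2p_access_check(0, 1))

With the batched path, the process that generates the cache now tests all GPU pairs with a single producer/consumer process pair, resetting the CUDA context between pairs, instead of spawning two fresh processes per pair as the old gloo/file-store implementation did.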