From 2bf4fa7566fd0f8d00a139966a16c80ee22dc334 Mon Sep 17 00:00:00 2001
From: Ravindra Marella
Date: Mon, 7 Aug 2023 23:56:09 +0530
Subject: [PATCH] Use precompiled CUDA runtime libraries from NVIDIA

---
Review notes (this section is not part of the commit message):
- Line structure was restored from a whitespace-collapsed copy of the patch.
  Indentation of context lines was inferred from the Python syntax, and
  whitespace inside string literals on context lines could not be recovered;
  verify the context against the pre-image before applying.
- load_cuda() now locates the package directory through `nvidia.__path__`
  instead of `nvidia.__file__`, which is None or unset when `nvidia` is a
  namespace package. Hunk sizes and the diffstat below reflect that change;
  the post-image blob id on the `index` line of ctransformers/lib.py is the
  one from the original patch.

 ctransformers/lib.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++++--
 ctransformers/llm.py |  4 +++-
 setup.py             |  4 ++++
 3 files changed, 55 insertions(+), 3 deletions(-)

diff --git a/ctransformers/lib.py b/ctransformers/lib.py
index 3eb7b13..ae94e31 100644
--- a/ctransformers/lib.py
+++ b/ctransformers/lib.py
@@ -1,5 +1,6 @@
-from typing import Optional
 import platform
+from ctypes import CDLL
+from typing import List, Optional
 from pathlib import Path

 from .logger import logger
@@ -29,7 +30,8 @@ def find_library(path: Optional[str] = None, cuda: bool = False) -> str:
                 flags = get_cpu_info()["flags"]
             except:
                 logger.warning(
-                    "Unable to detect CPU features. Please report at https://github.com/marella/ctransformers/issues"
+                    "Unable to detect CPU features. "
+                    "Please report at https://github.com/marella/ctransformers/issues"
                 )
                 flags = []

@@ -59,3 +61,47 @@ def find_library(path: Optional[str] = None, cuda: bool = False) -> str:
             f" {'CT_CUBLAS=1 ' if cuda else ''}pip install ctransformers --no-binary ctransformers\n\n"
         )
     return str(path)
+
+
+def load_cuda() -> bool:
+    """Preload the CUDA runtime and cuBLAS libraries from the `nvidia` package.
+
+    Looks for the CUDA 12 runtime and cuBLAS shared libraries inside the
+    installed `nvidia` Python package and opens them with `CDLL`.
+
+    Returns:
+        True if both libraries were found and loaded. False if `nvidia`
+        cannot be imported, exposes no package directory, or a library file
+        is missing.
+    """
+    try:
+        import nvidia
+    except ImportError:
+        return False
+    # Use __path__ rather than __file__: a namespace package (no __init__.py)
+    # has __file__ unset or None, but still lists its directories in __path__.
+    paths = list(getattr(nvidia, "__path__", None) or [])
+    if not paths:
+        logger.warning(
+            "CUDA libraries might not be installed properly. "
+            "Please report at https://github.com/marella/ctransformers/issues"
+        )
+        return False
+    path = Path(paths[0])
+    system = platform.system()
+    if system == "Windows":
+        libs = [
+            path / "cuda_runtime" / "bin" / "cudart64_12.dll",
+            path / "cublas" / "bin" / "cublas64_12.dll",
+        ]
+    else:
+        libs = [
+            path / "cuda_runtime" / "lib" / "libcudart.so.12",
+            path / "cublas" / "lib" / "libcublas.so.12",
+        ]
+    if not all(lib.is_file() for lib in libs):
+        return False
+    libs = [str(lib.resolve()) for lib in libs]
+    for lib in libs:
+        CDLL(lib)
+    return True
diff --git a/ctransformers/llm.py b/ctransformers/llm.py
index 09c2221..9cf45fe 100644
--- a/ctransformers/llm.py
+++ b/ctransformers/llm.py
@@ -24,7 +24,7 @@
     Union,
 )

-from .lib import find_library
+from .lib import find_library, load_cuda
 from .utils import Vector, utf8_split_incomplete

 c_int_p = POINTER(c_int)
@@ -99,6 +99,8 @@ def load_library(path: Optional[str] = None, cuda: bool = False) -> Any:
         os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin"))

     path = find_library(path, cuda=cuda)
+    if "cuda" in path:
+        load_cuda()
     lib = CDLL(path)

     lib.ctransformers_llm_create.argtypes = [
diff --git a/setup.py b/setup.py
index ba64286..a0d630d 100644
--- a/setup.py
+++ b/setup.py
@@ -36,6 +36,10 @@
         "py-cpuinfo>=9.0.0,<10.0.0",
     ],
     extras_require={
+        "cuda": [
+            "nvidia-cuda-runtime-cu12",
+            "nvidia-cublas-cu12",
+        ],
         "gptq": [
             "exllama==0.1.0",
         ],