Skip to content

Commit

Permalink
Use precompiled CUDA runtime libraries from NVIDIA
Browse files Browse the repository at this point in the history
  • Loading branch information
marella committed Aug 7, 2023
1 parent 5be9ecc commit 2bf4fa7
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 3 deletions.
38 changes: 36 additions & 2 deletions ctransformers/lib.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional
import platform
from ctypes import CDLL
from typing import List, Optional
from pathlib import Path

from .logger import logger
Expand Down Expand Up @@ -29,7 +30,8 @@ def find_library(path: Optional[str] = None, cuda: bool = False) -> str:
flags = get_cpu_info()["flags"]
except:
logger.warning(
"Unable to detect CPU features. Please report at https://github.com/marella/ctransformers/issues"
"Unable to detect CPU features. "
"Please report at https://github.com/marella/ctransformers/issues"
)
flags = []

Expand Down Expand Up @@ -59,3 +61,35 @@ def find_library(path: Optional[str] = None, cuda: bool = False) -> str:
f" {'CT_CUBLAS=1 ' if cuda else ''}pip install ctransformers --no-binary ctransformers\n\n"
)
return str(path)


def load_cuda() -> bool:
try:
import nvidia
except ImportError:
return False
if not nvidia.__file__:
logger.warning(
"CUDA libraries might not be installed properly. "
"Please report at https://github.com/marella/ctransformers/issues"
)
return False
path = Path(nvidia.__file__).parent
system = platform.system()
if system == "Windows":
libs = [
path / "cuda_runtime" / "bin" / "cudart64_12.dll",
path / "cublas" / "bin" / "cublas64_12.dll",
]
else:
libs = [
path / "cuda_runtime" / "lib" / "libcudart.so.12",
path / "cublas" / "lib" / "libcublas.so.12",
]
for lib in libs:
if not lib.is_file():
return False
libs = [str(lib.resolve()) for lib in libs]
for lib in libs:
CDLL(lib)
return True
4 changes: 3 additions & 1 deletion ctransformers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
Union,
)

from .lib import find_library
from .lib import find_library, load_cuda
from .utils import Vector, utf8_split_incomplete

c_int_p = POINTER(c_int)
Expand Down Expand Up @@ -99,6 +99,8 @@ def load_library(path: Optional[str] = None, cuda: bool = False) -> Any:
os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin"))

path = find_library(path, cuda=cuda)
if "cuda" in path:
load_cuda()
lib = CDLL(path)

lib.ctransformers_llm_create.argtypes = [
Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
"py-cpuinfo>=9.0.0,<10.0.0",
],
extras_require={
"cuda": [
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
],
"gptq": [
"exllama==0.1.0",
],
Expand Down

0 comments on commit 2bf4fa7

Please sign in to comment.