Skip to content

Commit

Permalink
auto finding libcudnn8 and libcudnn9 packages
Browse files Browse the repository at this point in the history
  • Loading branch information
prathameshzarkar9 committed Sep 25, 2024
1 parent 7ee889c commit 6f2aeac
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/nvidia-cuda/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,14 @@ fi

echo "Installing CUDA libraries..."
apt-get install -yq "$cuda_pkg"
apt-get update -yq

# auto find recent cudnn version
major_cuda_version=$(echo "${CUDA_VERSION}" | cut -d '.' -f 1)
if [[ "$CUDA_VERSION" < "12.3" ]]; then
CUDNN_VERSION=$(apt-cache policy libcudnn8 | grep "$CUDA_VERSION" | grep -Eo '^[^-1+]*' | sort -V | tail -n1 | xargs)
else
CUDNN_VERSION=$(apt-cache policy libcudnn9-cuda-$major_cuda_version | grep "Candidate" | awk '{print $2}' | grep -Eo '^[^-1+]*')
fi
major_cudnn_version=$(echo "${CUDNN_VERSION}" | cut -d '.' -f 1)

Expand Down
15 changes: 15 additions & 0 deletions test/nvidia-cuda/install_cuda_12_3_version.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash

set -e

# Optional: Import test library
source dev-container-features-test-lib

# # Check installation of libcudnn8 (9.4.0)
check "libcudnn.so.9.4.0" test 1 -eq "$(find /usr -name 'libcudnn.so.9.4.0' | wc -l)"

# Check installation of cuda-nvtx-12-3 (12.3)
check "cuda-12-3+nvtx" test -e '/usr/local/cuda-12.3/targets/x86_64-linux/include/nvtx3/'

# Report result
reportResults
15 changes: 15 additions & 0 deletions test/nvidia-cuda/install_cuda_12_4_version.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash

set -e

# Optional: Import test library
source dev-container-features-test-lib

# # Check installation of libcudnn8 (9.4.0)
check "libcudnn.so.9.4.0" test 1 -eq "$(find /usr -name 'libcudnn.so.9.4.0' | wc -l)"

# Check installation of cuda-nvtx-12-3 (12.3)
check "cuda-12-4+nvtx" test -e '/usr/local/cuda-12.4/targets/x86_64-linux/include/nvtx3/'

# Report result
reportResults
15 changes: 15 additions & 0 deletions test/nvidia-cuda/install_cuda_12_5_version.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash

set -e

# Optional: Import test library
source dev-container-features-test-lib

# # Check installation of libcudnn8 (9.4.0)
check "libcudnn.so.9.4.0" test 1 -eq "$(find /usr -name 'libcudnn.so.9.4.0' | wc -l)"

# Check installation of cuda-nvtx-12-3 (12.3)
check "cuda-12-5+nvtx" test -e '/usr/local/cuda-12.5/targets/x86_64-linux/include/nvtx3/'

# Report result
reportResults

0 comments on commit 6f2aeac

Please sign in to comment.