Skip to content

Commit

Permalink
Remove SparseArrays + CUDA ext
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 6, 2024
1 parent dd82335 commit 1b949a3
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
[extensions]
LuxDeviceUtilsAMDGPUExt = "AMDGPU"
LuxDeviceUtilsCUDAExt = "CUDA"
LuxDeviceUtilsCUDASparseArraysExt = ["CUDA", "SparseArrays"]
LuxDeviceUtilsFillArraysExt = "FillArrays"
LuxDeviceUtilsGPUArraysExt = "GPUArrays"
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
Expand Down
14 changes: 14 additions & 0 deletions ext/LuxDeviceUtilsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module LuxDeviceUtilsCUDAExt

using Adapt: Adapt
using CUDA: CUDA
using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector
using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, LuxCPUDevice
using Random: Random

Expand Down Expand Up @@ -73,4 +74,17 @@ end

Adapt.adapt_storage(::LuxCPUDevice, rng::CUDA.RNG) = Random.default_rng()

# Defining as extensions seems to case precompilation errors
@static if isdefined(CUDA.CUSPARSE, :SparseArrays)
function Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseMatrix)
return CUDA.CUSPARSE.SparseArrays.SparseMatrixCSC(x)
end
function Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseVector)
return CUDA.CUSPARSE.SparseArrays.SparseVector(x)
end
else
@warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \
an issue in LuxDeviceUtils.jl repository."
end

end
11 changes: 0 additions & 11 deletions ext/LuxDeviceUtilsCUDASparseArraysExt.jl

This file was deleted.

0 comments on commit 1b949a3

Please sign in to comment.