Skip to content

Commit

Permalink
Merge pull request #37 from LuxDL/ap/set_device
Browse files Browse the repository at this point in the history
Move things around a bit
  • Loading branch information
avik-pal authored Mar 27, 2024
2 parents aac7f4f + 0faa572 commit 0a4e5e9
Show file tree
Hide file tree
Showing 16 changed files with 321 additions and 155 deletions.
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ indent = 4
format_docstrings = true
separate_kwargs_with_semicolon = true
always_for_in = true
join_lines_based_on_source = false
2 changes: 1 addition & 1 deletion .github/workflows/Downgrade.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
version: ['1.9']
version: ['1']
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
28 changes: 20 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.18"
version = "0.1.19"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
Expand All @@ -23,6 +26,8 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
LuxDeviceUtilsAMDGPUExt = "AMDGPU"
LuxDeviceUtilsCUDAExt = "CUDA"
LuxDeviceUtilsFillArraysExt = "FillArrays"
LuxDeviceUtilsGPUArraysExt = "GPUArrays"
LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU"
Expand All @@ -33,10 +38,14 @@ LuxDeviceUtilsSparseArraysExt = "SparseArrays"
LuxDeviceUtilsZygoteExt = "Zygote"

[compat]
AMDGPU = "0.8.4"
Adapt = "4"
Aqua = "0.8"
Aqua = "0.8.4"
CUDA = "5.2"
ChainRulesCore = "1.20"
ComponentArrays = "0.15.8"
ExplicitImports = "1.4.1"
FastClosures = "0.3.2"
FillArrays = "1"
Functors = "0.4.4"
GPUArrays = "10"
Expand All @@ -46,28 +55,31 @@ LuxCore = "0.1.4"
Metal = "1"
PrecompileTools = "1.2"
Preferences = "1.4"
Random = "1.9"
RecursiveArrayTools = "3"
Random = "1.10"
RecursiveArrayTools = "3.8"
SafeTestsets = "0.1"
SparseArrays = "1.9"
Test = "1.9"
SparseArrays = "1.10"
Test = "1.10"
TestSetExtensions = "3"
Zygote = "0.6.69"
julia = "1.9"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ComponentArrays", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "SafeTestsets", "Test", "Zygote", "TestSetExtensions"]
test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote"]
79 changes: 79 additions & 0 deletions ext/LuxDeviceUtilsAMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
module LuxDeviceUtilsAMDGPUExt

using Adapt: Adapt
using AMDGPU: AMDGPU
using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUAdaptor, LuxAMDGPUDevice, LuxCPUAdaptor
using Random: Random, AbstractRNG

# Device selection
# `nothing` means "no particular device": defer to whatever is currently active.
function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing)
    return LuxAMDGPUDevice(nothing)
end
# Bind a `LuxAMDGPUDevice` to the 1-based ordinal `id`, restoring the
# previously active device before returning.
function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Int)
    if id > length(AMDGPU.devices())
        throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))"))
    end
    prev_dev = AMDGPU.device()
    AMDGPU.device!(AMDGPU.devices()[id])
    selected = LuxAMDGPUDevice(AMDGPU.device())
    AMDGPU.device!(prev_dev)
    return selected
end

LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.device)

# Default RNG for this backend.
LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng()

# Query the device owning a ROC array.
LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x))

# Set the globally active ROCm device (no-op with a warning when the
# AMDGPU stack is not usable on this machine).
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice)
    if AMDGPU.functional()
        AMDGPU.device!(dev)
    else
        @warn "AMDGPU is not functional."
    end
    return
end
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int)
    if AMDGPU.functional()
        AMDGPU.device!(id)
    else
        @warn "AMDGPU is not functional."
    end
    return
end
# Round-robin assignment of a distributed rank onto the available devices.
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Int)
    return LuxDeviceUtils.set_device!(
        LuxAMDGPUDevice, mod1(rank + 1, length(AMDGPU.devices())))
end

# Device Transfer
## To GPU
Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, x) = AMDGPU.roc(x)
function Adapt.adapt_storage(to::LuxAMDGPUAdaptor, x)
    prev_dev = AMDGPU.device()  # restored before every return path
    if x isa AMDGPU.AnyROCArray
        # Already on the requested device: hand back as-is.
        AMDGPU.device_id(AMDGPU.device(x)) == AMDGPU.device_id(to.device) && return x
        # On the GPU but on another device: copy across devices.
        AMDGPU.device!(to.device)
        migrated = copy(x)
        AMDGPU.device!(prev_dev)
        return migrated
    end
    # Host data: upload to the target device.
    AMDGPU.device!(to.device)
    migrated = AMDGPU.roc(x)
    AMDGPU.device!(prev_dev)
    return migrated
end
Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng
Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng
# The task-local RNG is replaced with the backend RNG on device transfer.
function Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG)
    return AMDGPU.rocrand_rng()
end
Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng()

# Moving back to the CPU swaps the device RNG for the host default RNG.
Adapt.adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng()

end
86 changes: 86 additions & 0 deletions ext/LuxDeviceUtilsCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
module LuxDeviceUtilsCUDAExt

using Adapt: Adapt
using CUDA: CUDA, CUSPARSE
using LuxDeviceUtils: LuxDeviceUtils, LuxCUDAAdaptor, LuxCUDADevice, LuxCPUAdaptor
using Random: Random, AbstractRNG

# Bind a `LuxCUDADevice` to the 1-based ordinal `id` (CUDA itself is
# 0-indexed, hence `id - 1`), restoring the previously active device.
function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Int)
    id > length(CUDA.devices()) &&
        throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))"))
    old_dev = CUDA.device()
    CUDA.device!(id - 1)
    device = LuxCUDADevice(CUDA.device())
    CUDA.device!(old_dev)
    return device
end

# `nothing` means "no particular device": defer to whatever is currently active.
function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, ::Nothing)
    return LuxCUDADevice(nothing)
end

# Convert CUDA's 0-based ordinal back to the 1-based id used by this package.
LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + 1

# Default RNG
LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng()

# Query Device from Array
LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x))

# Set the globally active CUDA device (no-op with a warning when the CUDA
# stack is not usable on this machine).
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice)
    if !CUDA.functional()
        @warn "CUDA is not functional."
        return
    end
    CUDA.device!(dev)
    return
end
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int)
    if !CUDA.functional()
        @warn "CUDA is not functional."
        return
    end
    CUDA.device!(id - 1)
    return
end
# Round-robin assignment of a distributed rank onto the available devices.
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int)
    id = mod1(rank + 1, length(CUDA.devices()))
    return LuxDeviceUtils.set_device!(LuxCUDADevice, id)
end

# Device Transfer
## To GPU
Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, x) = CUDA.cu(x)
function Adapt.adapt_storage(to::LuxCUDAAdaptor, x)
    old_dev = CUDA.device() # remember the current device
    if !(x isa CUDA.AnyCuArray)
        # Host data: upload to the target device.
        CUDA.device!(to.device)
        x_new = CUDA.cu(x)
        CUDA.device!(old_dev)
        return x_new
    elseif CUDA.device(x) == to.device
        # FIX: was `CUDA.deviceid(x) == to.device`, which compares an integer
        # ordinal against a `CuDevice` object and is therefore never true,
        # forcing a needless copy even when `x` already lives on the target
        # device. Compare device-to-device, mirroring the AMDGPU extension.
        return x
    else
        # On the GPU but on another device: copy across devices.
        CUDA.device!(to.device)
        x_new = copy(x)
        CUDA.device!(old_dev)
        return x_new
    end
end
Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::AbstractRNG) = rng
Adapt.adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng
# The task-local RNG is replaced with the backend RNG on device transfer.
function Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::Random.TaskLocalRNG)
    return CUDA.default_rng()
end
Adapt.adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng()

# Moving back to the CPU swaps the device RNG for the host default RNG.
Adapt.adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng()

## To CPU
## FIXME: Use SparseArrays to preserve the sparsity
function Adapt.adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix)
    return Adapt.adapt(Array, x)
end

end
10 changes: 6 additions & 4 deletions ext/LuxDeviceUtilsFillArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
module LuxDeviceUtilsFillArraysExt

using Adapt, FillArrays, LuxDeviceUtils
using Adapt: Adapt
using FillArrays: FillArrays
using LuxDeviceUtils: LuxDeviceUtils, LuxCPUAdaptor

Adapt.adapt_structure(::LuxCPUAdaptor, x::FillArrays.AbstractFill) = x

function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor,
x::FillArrays.AbstractFill)
return adapt(to, collect(x))
function Adapt.adapt_structure(
to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, x::FillArrays.AbstractFill)
return Adapt.adapt(to, collect(x))
end

end
8 changes: 5 additions & 3 deletions ext/LuxDeviceUtilsGPUArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
module LuxDeviceUtilsGPUArraysExt

using GPUArrays, LuxDeviceUtils, Random
import Adapt: adapt_storage, adapt
using Adapt: Adapt
using GPUArrays: GPUArrays
using LuxDeviceUtils: LuxCPUAdaptor
using Random: Random

adapt_storage(::LuxCPUAdaptor, rng::GPUArrays.RNG) = Random.default_rng()
Adapt.adapt_storage(::LuxCPUAdaptor, rng::GPUArrays.RNG) = Random.default_rng()

end
51 changes: 2 additions & 49 deletions ext/LuxDeviceUtilsLuxAMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LuxDeviceUtilsLuxAMDGPUExt

using LuxAMDGPU, LuxDeviceUtils, Random
import Adapt: adapt_storage, adapt
using LuxAMDGPU: LuxAMDGPU
using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, reset_gpu_device!

__init__() = reset_gpu_device!()

Expand All @@ -10,51 +10,4 @@ function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGP
return LuxAMDGPU.functional()
end

function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing)
return LuxAMDGPUDevice(nothing)
end
function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Int)
id > length(AMDGPU.devices()) &&
throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))"))
old_dev = AMDGPU.device()
AMDGPU.device!(AMDGPU.devices()[id])
device = LuxAMDGPUDevice(AMDGPU.device())
AMDGPU.device!(old_dev)
return device
end

LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.device)

# Default RNG
LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng()

# Query Device from Array
LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x))

# Device Transfer
## To GPU
adapt_storage(::LuxAMDGPUAdaptor{Nothing}, x) = roc(x)
function adapt_storage(to::LuxAMDGPUAdaptor, x)
old_dev = AMDGPU.device() # remember the current device
if !(x isa AMDGPU.AnyROCArray)
AMDGPU.device!(to.device)
x_new = roc(x)
AMDGPU.device!(old_dev)
return x_new
elseif AMDGPU.device_id(AMDGPU.device(x)) == AMDGPU.device_id(to.device)
return x
else
AMDGPU.device!(to.device)
x_new = copy(x)
AMDGPU.device!(old_dev)
return x_new
end
end
adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng
adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng
adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng()
adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng()

adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng()

end
Loading

2 comments on commit 0a4e5e9

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/103740

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.19 -m "<description of version>" 0a4e5e9132f1fa8c2f2118cc91bf29c24bedbae2
git push origin v0.1.19

Please sign in to comment.