refactor!: rename round 2 to MLDataDevices #62

Merged (3 commits) on Jul 24, 2024
26 changes: 13 additions & 13 deletions Project.toml
@@ -1,4 +1,4 @@
name = "DeviceUtils"
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.0.0"
@@ -26,18 +26,18 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[extensions]
DeviceUtilsAMDGPUExt = "AMDGPU"
DeviceUtilsCUDAExt = "CUDA"
DeviceUtilsFillArraysExt = "FillArrays"
DeviceUtilsGPUArraysExt = "GPUArrays"
DeviceUtilsMetalExt = ["GPUArrays", "Metal"]
DeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools"
DeviceUtilsReverseDiffExt = "ReverseDiff"
DeviceUtilsSparseArraysExt = "SparseArrays"
DeviceUtilsTrackerExt = "Tracker"
DeviceUtilsZygoteExt = "Zygote"
DeviceUtilscuDNNExt = ["CUDA", "cuDNN"]
DeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"]
MLDataDevicesAMDGPUExt = "AMDGPU"
MLDataDevicesCUDAExt = "CUDA"
MLDataDevicesFillArraysExt = "FillArrays"
MLDataDevicesGPUArraysExt = "GPUArrays"
MLDataDevicesMetalExt = ["GPUArrays", "Metal"]
MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools"
MLDataDevicesReverseDiffExt = "ReverseDiff"
MLDataDevicesSparseArraysExt = "SparseArrays"
MLDataDevicesTrackerExt = "Tracker"
MLDataDevicesZygoteExt = "Zygote"
MLDataDevicescuDNNExt = ["CUDA", "cuDNN"]
MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]

[compat]
AMDGPU = "0.9.6"
12 changes: 6 additions & 6 deletions README.md
@@ -1,18 +1,18 @@
# DeviceUtils
# MLDataDevices

[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning)
[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/LuxDeviceUtils)
[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/LuxDeviceUtils)

[![CI](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml)
[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/DeviceUtils-dot-jl)
[![codecov](https://codecov.io/gh/LuxDL/DeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/DeviceUtils.jl)
[![CI](https://github.com/LuxDL/MLDataDevices.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/MLDataDevices.jl/actions/workflows/CI.yml)
[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/MLDataDevices-dot-jl)
[![codecov](https://codecov.io/gh/LuxDL/MLDataDevices.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/MLDataDevices.jl)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)

`DeviceUtils.jl` is a lightweight package defining rules for transferring data across
`MLDataDevices.jl` is a lightweight package defining rules for transferring data across
devices. It is used in deep learning frameworks such as [Lux.jl](https://lux.csail.mit.edu/).
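As a quick illustration of the transfer rules this package defines, device objects are callable and move data onto their device. This is a hedged sketch, not part of the diff: it assumes the `cpu_device`/`gpu_device` constructors exported by the package, and that with no GPU backend package loaded, `gpu_device()` falls back to the CPU device.

```julia
using MLDataDevices

# With no GPU package (CUDA.jl, AMDGPU.jl, ...) loaded in the session,
# gpu_device() is expected to fall back to the CPU device.
cdev = cpu_device()
gdev = gpu_device()

x = rand(Float32, 4, 4)
x_dev = gdev(x)    # move (or copy) x onto the chosen device
x_back = cdev(x_dev)  # bring it back to the CPU
```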

Currently we provide support for the following backends:
@@ -24,6 +24,6 @@ Currently we provide support for the following backends:

## Updating to v1.0

* Package was renamed from `LuxDeviceUtils.jl` to `DeviceUtils.jl`.
* Package was renamed from `LuxDeviceUtils.jl` to `MLDataDevices.jl`.
* `Lux(***)Device` has been renamed to `(***)Device`.
* `Lux(***)Adaptor` objects have been removed. Use `(***)Device` objects instead.
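A before/after sketch of the v1.0 migration described in the bullets above. The commented-out "before" names follow the old `Lux(***)Device`/`Lux(***)Adaptor` convention from `LuxDeviceUtils.jl`; treat the exact identifiers as illustrative, and note the snippet assumes CUDA.jl is loaded and functional.

```julia
# Before (LuxDeviceUtils.jl, pre-1.0):
#   using LuxDeviceUtils, Adapt
#   dev  = LuxCUDADevice()
#   x_gpu = adapt(LuxCUDAAdaptor(), x)

# After (MLDataDevices.jl v1.0): the Lux prefix is dropped and the
# Adaptor objects are gone; the Device object itself performs the transfer.
using MLDataDevices

dev = CUDADevice()   # was LuxCUDADevice()
x = rand(Float32, 4)
x_gpu = dev(x)       # was adapt(LuxCUDAAdaptor(), x)
```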
27 changes: 0 additions & 27 deletions ext/DeviceUtilsMetalExt.jl

This file was deleted.

17 changes: 0 additions & 17 deletions ext/DeviceUtilsReverseDiffExt.jl

This file was deleted.

34 changes: 17 additions & 17 deletions ext/DeviceUtilsAMDGPUExt.jl → ext/MLDataDevicesAMDGPUExt.jl
@@ -1,8 +1,8 @@
module DeviceUtilsAMDGPUExt
module MLDataDevicesAMDGPUExt

using Adapt: Adapt
using AMDGPU: AMDGPU
using DeviceUtils: DeviceUtils, AMDGPUDevice, CPUDevice, reset_gpu_device!
using MLDataDevices: MLDataDevices, AMDGPUDevice, CPUDevice, reset_gpu_device!
using Random: Random

__init__() = reset_gpu_device!()
@@ -21,16 +21,16 @@
return
end

DeviceUtils.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true
function DeviceUtils.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool
MLDataDevices.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true
function MLDataDevices.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool
_check_use_amdgpu!()
return USE_AMD_GPU[]
end

function DeviceUtils._with_device(::Type{AMDGPUDevice}, ::Nothing)
function MLDataDevices._with_device(::Type{AMDGPUDevice}, ::Nothing)
return AMDGPUDevice(nothing)
end
function DeviceUtils._with_device(::Type{AMDGPUDevice}, id::Integer)
function MLDataDevices._with_device(::Type{AMDGPUDevice}, id::Integer)
id > length(AMDGPU.devices()) &&
throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))"))
old_dev = AMDGPU.device()
@@ -40,38 +40,38 @@
return device
end

DeviceUtils._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device)
MLDataDevices._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device)

# Default RNG
DeviceUtils.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng()
MLDataDevices.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng()

# Query Device from Array
function DeviceUtils._get_device(x::AMDGPU.AnyROCArray)
function MLDataDevices._get_device(x::AMDGPU.AnyROCArray)
parent_x = parent(x)
parent_x === x && return AMDGPUDevice(AMDGPU.device(x))
return DeviceUtils._get_device(parent_x)
return MLDataDevices._get_device(parent_x)
end

DeviceUtils._get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice
MLDataDevices._get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice

# Set Device
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice)
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice)
return AMDGPU.device!(dev)
end
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, id::Integer)
return DeviceUtils.set_device!(AMDGPUDevice, AMDGPU.devices()[id])
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, id::Integer)
return MLDataDevices.set_device!(AMDGPUDevice, AMDGPU.devices()[id])
end
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer)
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer)
id = mod1(rank + 1, length(AMDGPU.devices()))
return DeviceUtils.set_device!(AMDGPUDevice, id)
return MLDataDevices.set_device!(AMDGPUDevice, id)
end

# Device Transfer
## To GPU
Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x)
function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray)
old_dev = AMDGPU.device() # remember the current device
dev = DeviceUtils.get_device(x)
dev = MLDataDevices.get_device(x)
if !(dev isa AMDGPUDevice)
AMDGPU.device!(to.device)
x_new = AMDGPU.roc(x)
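The distributed `set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer)` method in the diff above maps an MPI-style 0-based rank onto Julia's 1-based device ids round-robin via `mod1`. A self-contained sketch of that mapping in plain Julia, where `ndevices` stands in for `length(AMDGPU.devices())`:

```julia
# Round-robin mapping of distributed process ranks (0-based) onto 1-based
# GPU device ids, mirroring `mod1(rank + 1, length(AMDGPU.devices()))`.
rank_to_device_id(rank::Integer, ndevices::Integer) = mod1(rank + 1, ndevices)

# With 4 devices, ranks 0,1,2,3,4,5 map to ids 1,2,3,4,1,2:
ids = [rank_to_device_id(r, 4) for r in 0:5]
```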
34 changes: 17 additions & 17 deletions ext/DeviceUtilsCUDAExt.jl → ext/MLDataDevicesCUDAExt.jl
@@ -1,12 +1,12 @@
module DeviceUtilsCUDAExt
module MLDataDevicesCUDAExt

using Adapt: Adapt
using CUDA: CUDA
using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector
using DeviceUtils: DeviceUtils, CUDADevice, CPUDevice
using MLDataDevices: MLDataDevices, CUDADevice, CPUDevice
using Random: Random

function DeviceUtils._with_device(::Type{CUDADevice}, id::Integer)
function MLDataDevices._with_device(::Type{CUDADevice}, id::Integer)
id > length(CUDA.devices()) &&
throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))"))
old_dev = CUDA.device()
@@ -16,47 +16,47 @@
return device
end

function DeviceUtils._with_device(::Type{CUDADevice}, ::Nothing)
function MLDataDevices._with_device(::Type{CUDADevice}, ::Nothing)
return CUDADevice(nothing)
end

DeviceUtils._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1
MLDataDevices._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1

# Default RNG
DeviceUtils.default_device_rng(::CUDADevice) = CUDA.default_rng()
MLDataDevices.default_device_rng(::CUDADevice) = CUDA.default_rng()

# Query Device from Array
function DeviceUtils._get_device(x::CUDA.AnyCuArray)
function MLDataDevices._get_device(x::CUDA.AnyCuArray)
parent_x = parent(x)
parent_x === x && return CUDADevice(CUDA.device(x))
return DeviceUtils.get_device(parent_x)
return MLDataDevices.get_device(parent_x)
end
function DeviceUtils._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray)
function MLDataDevices._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray)
return CUDADevice(CUDA.device(x.nzVal))
end

function DeviceUtils._get_device_type(::Union{
function MLDataDevices._get_device_type(::Union{
<:CUDA.AnyCuArray, <:CUDA.CUSPARSE.AbstractCuSparseArray})
return CUDADevice
end

# Set Device
function DeviceUtils.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice)
function MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice)
return CUDA.device!(dev)
end
function DeviceUtils.set_device!(::Type{CUDADevice}, id::Integer)
return DeviceUtils.set_device!(CUDADevice, collect(CUDA.devices())[id])
function MLDataDevices.set_device!(::Type{CUDADevice}, id::Integer)
return MLDataDevices.set_device!(CUDADevice, collect(CUDA.devices())[id])
end
function DeviceUtils.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer)
function MLDataDevices.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer)
id = mod1(rank + 1, length(CUDA.devices()))
return DeviceUtils.set_device!(CUDADevice, id)
return MLDataDevices.set_device!(CUDADevice, id)
end

# Device Transfer
Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x)
function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray)
old_dev = CUDA.device() # remember the current device
dev = DeviceUtils.get_device(x)
dev = MLDataDevices.get_device(x)
if !(dev isa CUDADevice)
CUDA.device!(to.device)
x_new = CUDA.cu(x)
@@ -84,7 +84,7 @@
end
else
@warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \
an issue in DeviceUtils.jl repository."
an issue in MLDataDevices.jl repository."
end

end
@@ -1,8 +1,8 @@
module DeviceUtilsFillArraysExt
module MLDataDevicesFillArraysExt

using Adapt: Adapt
using FillArrays: FillArrays, AbstractFill
using DeviceUtils: DeviceUtils, CPUDevice, AbstractDevice
using MLDataDevices: MLDataDevices, CPUDevice, AbstractDevice

Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x
Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x))
@@ -1,8 +1,8 @@
module DeviceUtilsGPUArraysExt
module MLDataDevicesGPUArraysExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using DeviceUtils: CPUDevice
using MLDataDevices: CPUDevice
using Random: Random

Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng()
27 changes: 27 additions & 0 deletions ext/MLDataDevicesMetalExt.jl
@@ -0,0 +1,27 @@
module MLDataDevicesMetalExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using MLDataDevices: MLDataDevices, MetalDevice, reset_gpu_device!
using Metal: Metal, MtlArray

__init__() = reset_gpu_device!()

MLDataDevices.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true
function MLDataDevices.functional(::Union{MetalDevice, Type{<:MetalDevice}})
return Metal.functional()
end

# Default RNG
MLDataDevices.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray)

# Query Device from Array
MLDataDevices._get_device(::MtlArray) = MetalDevice()

MLDataDevices._get_device_type(::MtlArray) = MetalDevice

# Device Transfer
## To GPU
Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x)

end
@@ -1,7 +1,7 @@
module DeviceUtilsRecursiveArrayToolsExt
module MLDataDevicesRecursiveArrayToolsExt

using Adapt: Adapt, adapt
using DeviceUtils: DeviceUtils, AbstractDevice
using MLDataDevices: MLDataDevices, AbstractDevice
using RecursiveArrayTools: VectorOfArray, DiffEqArray

# We want to preserve the structure
@@ -15,9 +15,9 @@ function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray)
end

for op in (:_get_device, :_get_device_type)
@eval function DeviceUtils.$op(x::Union{VectorOfArray, DiffEqArray})
@eval function MLDataDevices.$op(x::Union{VectorOfArray, DiffEqArray})
length(x.u) == 0 && return $(op == :_get_device ? nothing : Nothing)
return mapreduce(DeviceUtils.$op, DeviceUtils.__combine_devices, x.u)
return mapreduce(MLDataDevices.$op, MLDataDevices.__combine_devices, x.u)
end
end

17 changes: 17 additions & 0 deletions ext/MLDataDevicesReverseDiffExt.jl
@@ -0,0 +1,17 @@
module MLDataDevicesReverseDiffExt

using MLDataDevices: MLDataDevices
using ReverseDiff: ReverseDiff

for op in (:_get_device, :_get_device_type)
@eval begin
function MLDataDevices.$op(x::ReverseDiff.TrackedArray)
return MLDataDevices.$op(ReverseDiff.value(x))
end
function MLDataDevices.$op(x::AbstractArray{<:ReverseDiff.TrackedReal})
return MLDataDevices.$op(ReverseDiff.value.(x))
end
end
end

end
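The `for op in (...)` plus `@eval` pattern in the ReverseDiff extension above stamps out one method per operation name instead of writing near-identical definitions by hand. A self-contained illustration of the same metaprogramming idiom, using toy function names (nothing here comes from ReverseDiff):

```julia
module EvalLoopDemo

# Generate one method per name, as the extension does for
# :_get_device and :_get_device_type.
for op in (:describe_device, :describe_device_type)
    @eval function $op(x)
        # QuoteNode keeps the symbol itself available in the generated body.
        return string($(QuoteNode(op)), " called on ", typeof(x))
    end
end

end # module

EvalLoopDemo.describe_device(1.0)
```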
@@ -1,7 +1,7 @@
module DeviceUtilsSparseArraysExt
module MLDataDevicesSparseArraysExt

using Adapt: Adapt
using DeviceUtils: CPUDevice
using MLDataDevices: CPUDevice
using SparseArrays: AbstractSparseArray

Adapt.adapt_storage(::CPUDevice, x::AbstractSparseArray) = x