diff --git a/Project.toml b/Project.toml index f78a118..00db75a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.16" +version = "0.1.17" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index 07397b7..f7dd062 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -322,20 +322,26 @@ get_device(::AbstractArray) = LuxCPUDevice() # Adapt Interface abstract type AbstractLuxDeviceAdaptor end +abstract type AbstractLuxGPUDeviceAdaptor <: AbstractLuxDeviceAdaptor end struct LuxCPUAdaptor <: AbstractLuxDeviceAdaptor end -struct LuxCUDAAdaptor{D} <: AbstractLuxDeviceAdaptor +struct LuxCUDAAdaptor{D} <: AbstractLuxGPUDeviceAdaptor device::D end -struct LuxAMDGPUAdaptor{D} <: AbstractLuxDeviceAdaptor +struct LuxAMDGPUAdaptor{D} <: AbstractLuxGPUDeviceAdaptor device::D end -struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end +struct LuxMetalAdaptor <: AbstractLuxGPUDeviceAdaptor end adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng +# Prevent Ambiguity +for T in (LuxAMDGPUAdaptor, LuxCUDAAdaptor, LuxMetalAdaptor) + @eval adapt_storage(to::$(T), x::AbstractRange) = adapt(to, collect(x)) +end + _isbitsarray(::AbstractArray{<:Number}) = true _isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) _isbitsarray(x) = false