Skip to content

Commit

Permalink
Special checks for FP64 on Intel
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 6, 2024
1 parent c15ce87 commit 5bc20fd
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
1 change: 0 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
style = "sciml"
whitespace_in_kwargs = false
always_use_return = true
margin = 92
indent = 4
format_docstrings = true
Expand Down
25 changes: 22 additions & 3 deletions ext/LuxDeviceUtilsoneAPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,18 @@ module LuxDeviceUtilsoneAPIExt
using Adapt: Adapt
using GPUArrays: GPUArrays
using LuxDeviceUtils: LuxDeviceUtils, LuxoneAPIDevice, reset_gpu_device!
using oneAPI: oneAPI, oneArray
using oneAPI: oneAPI, oneArray, oneL0

__init__() = reset_gpu_device!()
const SUPPORTS_FP64 = Dict{oneL0.ZeDevice, Bool}()

function __init__()
reset_gpu_device!()
for dev in oneAPI.devices()
SUPPORTS_FP64[dev] = oneL0.module_properties(dev).fp64flags &
oneL0.ZE_DEVICE_MODULE_FLAG_FP64 ==
oneL0.ZE_DEVICE_MODULE_FLAG_FP64
end
end

LuxDeviceUtils.__is_loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true
function LuxDeviceUtils.__is_functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}})
Expand All @@ -20,6 +29,16 @@ LuxDeviceUtils.get_device(::oneArray) = LuxoneAPIDevice()

# Device Transfer
## To GPU
Adapt.adapt_storage(::LuxoneAPIDevice, x) = oneArray(x)
for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32))
@eval function Adapt.adapt_storage(::LuxoneAPIDevice, x::AbstractArray{$(T1)})
if !SUPPORTS_FP64[oneAPI.device()]
@warn LazyString(
"Double type is not supported on this device. Using `", $(T2), "` instead.")
return oneArray{$(T2)}(x)
end
return oneArray(x)
end
end
Adapt.adapt_storage(::LuxoneAPIDevice, x::AbstractArray) = oneArray(x)

end
1 change: 0 additions & 1 deletion test/oneapi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ using oneAPI
@test LuxDeviceUtils.GPU_DEVICE[] !== nothing
end


using FillArrays, Zygote # Extensions

@testset "Data Transfer" begin
Expand Down

0 comments on commit 5bc20fd

Please sign in to comment.