From 654a565bcee22c3ff3c184fc88f69b884fe189aa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 5 Jun 2024 19:49:55 -0700 Subject: [PATCH] Special checks for FP64 on Intel --- .JuliaFormatter.toml | 1 - ext/LuxDeviceUtilsoneAPIExt.jl | 24 +++++++++++++++++++++--- test/oneapi.jl | 1 - 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index f1f84c1..22c3407 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,6 +1,5 @@ style = "sciml" whitespace_in_kwargs = false -always_use_return = true margin = 92 indent = 4 format_docstrings = true diff --git a/ext/LuxDeviceUtilsoneAPIExt.jl b/ext/LuxDeviceUtilsoneAPIExt.jl index 8291435..1fc54b8 100644 --- a/ext/LuxDeviceUtilsoneAPIExt.jl +++ b/ext/LuxDeviceUtilsoneAPIExt.jl @@ -3,9 +3,18 @@ module LuxDeviceUtilsoneAPIExt using Adapt: Adapt using GPUArrays: GPUArrays using LuxDeviceUtils: LuxDeviceUtils, LuxoneAPIDevice, reset_gpu_device! -using oneAPI: oneAPI, oneArray +using oneAPI: oneAPI, oneArray, oneL0 -__init__() = reset_gpu_device!() +const SUPPORTS_FP64 = Dict{oneL0.ZeDevice, Bool}() + +function __init__() + reset_gpu_device!() + for dev in oneAPI.devices() + SUPPORTS_FP64[dev] = oneL0.module_properties(dev).fp64flags & + oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == + oneL0.ZE_DEVICE_MODULE_FLAG_FP64 + end +end LuxDeviceUtils.__is_loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true function LuxDeviceUtils.__is_functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) @@ -20,6 +29,15 @@ LuxDeviceUtils.get_device(::oneArray) = LuxoneAPIDevice() # Device Transfer ## To GPU -Adapt.adapt_storage(::LuxoneAPIDevice, x) = oneArray(x) +for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) + function Adapt.adapt_storage(::LuxoneAPIDevice, x::AbstractArray{$(T1)}) + if !SUPPORTS_FP64[oneAPI.device()] + @warn "Double type is not supported on this device. Using `$($(T2))` instead." + return oneArray{$(T2)}(x) + end + return oneArray(x) + end +end +Adapt.adapt_storage(::LuxoneAPIDevice, x::AbstractArray) = oneArray(x) end diff --git a/test/oneapi.jl b/test/oneapi.jl index 7035ddf..418830a 100644 --- a/test/oneapi.jl +++ b/test/oneapi.jl @@ -25,7 +25,6 @@ using oneAPI @test LuxDeviceUtils.GPU_DEVICE[] !== nothing end - using FillArrays, Zygote # Extensions @testset "Data Transfer" begin