Skip to content

Commit

Permalink
Do not unconditionally load MKL_jll but match platforms
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Sep 28, 2023
1 parent 453aa11 commit 4bc920c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
2 changes: 0 additions & 2 deletions src/FFTW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import AbstractFFTs: Plan, ScaledPlan,
fftshift, ifftshift,
rfft_output_size, brfft_output_size,
plan_inv, normalization
import FFTW_jll
import MKL_jll

export dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct!

Expand Down
40 changes: 33 additions & 7 deletions src/providers.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
const valid_fftw_providers = if FFTW_jll.is_available() && MKL_jll.is_available()
("fftw", "mkl")
elseif FFTW_jll.is_available()
("fftw",)
elseif MKL_jll.is_available()
("mkl",)
else
# Hardcoded list of supported platforms
# In principle, we could check FFTW_jll.is_available() and MKL_jll.is_available()
# but then we would have to load MKL_jll which we want to avoid (lazy artifacts!)
const platforms_providers = Dict(
Base.BinaryPlatforms.Platform("aarch64", "macos") => ("fftw",),
Base.BinaryPlatforms.Platform("aarch64", "linux"; libc = "glibc") => ("fftw",),
Base.BinaryPlatforms.Platform("aarch64", "linux"; libc = "musl") => ("fftw",),
Base.BinaryPlatforms.Platform("armv7l", "linux"; libc = "glibc", call_abi = "eabihf") => ("fftw",),
Base.BinaryPlatforms.Platform("armv7l", "linux"; libc = "musl", call_abi = "eabihf") => ("fftw",),
Base.BinaryPlatforms.Platform("i686", "linux"; libc = "glibc") => ("fftw", "mkl"),
Base.BinaryPlatforms.Platform("i686", "linux"; libc = "musl") => ("fftw",),
Base.BinaryPlatforms.Platform("i686", "windows") => ("fftw", "mkl"),
Base.BinaryPlatforms.Platform("powerpc64le", "linux"; libc = "glibc") => ("fftw",),
Base.BinaryPlatforms.Platform("x86_64", "macos") => ("fftw", "mkl"),
Base.BinaryPlatforms.Platform("x86_64", "linux"; libc = "glibc") => ("fftw",),
Base.BinaryPlatforms.Platform("x86_64", "linux"; libc = "musl") => ("fftw",),
Base.BinaryPlatforms.Platform("x86_64", "freebsd") => ("fftw",),
Base.BinaryPlatforms.Platform("x86_64", "windows") => ("fftw", "mkl"),
)
const valid_fftw_providers = Base.BinaryPlatforms.select_platform(platforms_providers, Base.BinaryPlatforms.HostPlatform())
if valid_fftw_providers === nothing
error("no valid FFTW library available")
end

Expand Down Expand Up @@ -56,6 +70,12 @@ end

# If we're using fftw_jll, load it in
@static if fftw_provider == "fftw"
import FFTW_jll
if !FFTW_jll.is_available()
# more descriptive error message if FFTW is not available
# (should not be possible to reach this)
error("FFTW library cannot be loaded: please switch to the `mkl` provider for FFTW.jl")
end
libfftw3[] = FFTW_jll.libfftw3_path
libfftw3f[] = FFTW_jll.libfftw3f_path

Expand Down Expand Up @@ -90,6 +110,12 @@ end

# If we're using MKL, load it in and set library paths appropriately.
@static if fftw_provider == "mkl"
import MKL_jll
if !MKL_jll.is_available()
# more descriptive error message if MKL is not available
# (should not be possible to reach this)
error("MKL library cannot be loaded: please switch to the `fftw` provider for FFTW.jl")
end
libfftw3[] = MKL_jll.libmkl_rt_path
libfftw3f[] = MKL_jll.libmkl_rt_path
const _last_num_threads = Ref(Cint(1))
Expand Down

0 comments on commit 4bc920c

Please sign in to comment.