From 17bc81a0fcf9875d777ea4bee2fca70fc23c8a0c Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Sun, 13 Mar 2022 17:15:44 -0700 Subject: [PATCH] Add get_num_threads (#171) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add get_num_threads This commit adds `get_num_threads`, which returns the number of threads used by the planner, and is the complement to `set_num_threads`. This simply wraps the function `fftw_planner_nthreads`, which was [newly added to fftw in version 3.3.9](https://github.com/FFTW/fftw3/blob/34082eb5d6ed7dc9436915df69f376c06fc39762/NEWS#L3). * Set FFTW_jll compat to 3.3.9 `get_num_threads` requires FFTW_jll v3.3.9+7, but it doesn't seem possible to specify a particular build in the compat section of Project.toml files. However, this should work in most cases, as the most recent build of `FFTW_jll` should be downloaded upon updating. * bump to 1.3 for the new function * Make test for get_num_threads fftw specific No equivalent function for mkl * Typo... * another typo * Add vendor check to `get_num_threads` * Add a method of `set_num_threads` that restores the original nthreads Additionally, separate previous `set_num_threads` method into a base function, `_set_num_threads`, that wraps the `ccalls`, and `set_num_threads`, which will acquire the `fftwlock`. * Provide support for `get_num_threads` with MKL's FFTW While MKL's FFTW does not provide access to the number of threads available to the planner, this can be simulated by caching the value last passed to `set_num_threads` and returning it with `get_num_threads` if `fftw_vendor == :mkl`. * Implement suggestions of @stevengj * Fix typo in set_num_threads * Add test for set_num_threads method that restores original num_threads * Rename `nthreads` variable to `num_threads` to avoid shadowing Threads.nthreads Since FFTW uses `Base.Threads`, and `nthreads` is a function defined in `Base.Threads`, then the function argument `nthreads` shadows a function already in the namespace of every function. While there is no inherent issue with this, it can make debugging this code more confusing. * Make one-line method of `set_num_threads` one line. * First attempt at adding `num_threads` to `plan_...` functions As suggested by @stevengj, I have add a `num_threads` keyword to the `plan_...` functions. My approach here is fairly naive, and adds a bunch of redundant boiler plate code to every `plan_` function. Co-authored-by: Steven G. Johnson Co-authored-by: Mosè Giordano --- Project.toml | 4 +- src/dct.jl | 8 +-- src/fft.jl | 126 +++++++++++++++++++++++++++++++++++++++++------ src/providers.jl | 1 + test/runtests.jl | 11 +++++ 5 files changed, 128 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index cdb3efa..641b4d1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FFTW" uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -version = "1.4.6" +version = "1.5.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" @@ -12,7 +12,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" [compat] AbstractFFTs = "1.0" -FFTW_jll = "3.3" +FFTW_jll = "3.3.9" MKL_jll = "2019.0.117, 2020, 2021, 2022" Preferences = "1.2" Reexport = "0.2, 1.0" diff --git a/src/dct.jl b/src/dct.jl index 8c9db4a..bc1d252 100644 --- a/src/dct.jl +++ b/src/dct.jl @@ -3,14 +3,14 @@ # (This is part of the FFTW module.) """ - plan_dct!(A [, dims [, flags [, timelimit]]]) + plan_dct!(A [, dims [, flags [, timelimit [, num_threads]]]]) Same as [`plan_dct`](@ref), but operates in-place on `A`. """ function plan_dct! end """ - plan_idct(A [, dims [, flags [, timelimit]]]) + plan_idct(A [, dims [, flags [, timelimit [, num_threads]]]]) Pre-plan an optimized inverse discrete cosine transform (DCT), similar to [`plan_fft`](@ref) except producing a function that computes @@ -20,7 +20,7 @@ Pre-plan an optimized inverse discrete cosine transform (DCT), similar to function plan_idct end """ - plan_dct(A [, dims [, flags [, timelimit]]]) + plan_dct(A [, dims [, flags [, timelimit [, num_threads]]]]) Pre-plan an optimized discrete cosine transform (DCT), similar to [`plan_fft`](@ref) except producing a function that computes @@ -30,7 +30,7 @@ Pre-plan an optimized discrete cosine transform (DCT), similar to function plan_dct end """ - plan_idct!(A [, dims [, flags [, timelimit]]]) + plan_idct!(A [, dims [, flags [, timelimit [, num_threads]]]]) Same as [`plan_idct`](@ref), but operates in-place on `A`. """ diff --git a/src/fft.jl b/src/fft.jl index 9d038e3..c8c0898 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -38,14 +38,14 @@ an array of real or complex floating-point numbers. function r2r! end """ - plan_r2r!(A, kind [, dims [, flags [, timelimit]]]) + plan_r2r!(A, kind [, dims [, flags [, timelimit [, num_threads]]]]) Similar to [`plan_fft`](@ref), but corresponds to [`r2r!`](@ref). """ function plan_r2r! end """ - plan_r2r(A, kind [, dims [, flags [, timelimit]]]) + plan_r2r(A, kind [, dims [, flags [, timelimit [, num_threads]]]]) Pre-plan an optimized r2r transform, similar to [`plan_fft`](@ref) except that the transforms (and the first three arguments) @@ -171,9 +171,33 @@ end # Threads -@exclusive function set_num_threads(nthreads::Integer) - ccall((:fftw_plan_with_nthreads,libfftw3[]), Cvoid, (Int32,), nthreads) - ccall((:fftwf_plan_with_nthreads,libfftw3f[]), Cvoid, (Int32,), nthreads) +# Must only be called after acquiring fftwlock +function _set_num_threads(num_threads::Integer) + @static if fftw_provider == "mkl" + _last_num_threads[] = num_threads + end + ccall((:fftw_plan_with_nthreads,libfftw3[]), Cvoid, (Int32,), num_threads) + ccall((:fftwf_plan_with_nthreads,libfftw3f[]), Cvoid, (Int32,), num_threads) +end + +@exclusive set_num_threads(num_threads::Integer) = _set_num_threads(num_threads) + +function get_num_threads() + @static if fftw_provider == "fftw" + ccall((:fftw_planner_nthreads,libfftw3[]), Cint, ()) + else + _last_num_threads[] + end +end + +@exclusive function set_num_threads(f::Function, num_threads::Integer) + orig_num_threads = get_num_threads() + _set_num_threads(num_threads) + try + f() + finally + _set_num_threads(orig_num_threads) + end end # pointer type for fftw_plan (opaque pointer) @@ -684,14 +708,28 @@ for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD)) @eval begin function $plan_f(X::StridedArray{T,N}, region; flags::Integer=ESTIMATE, - timelimit::Real=NO_TIMELIMIT) where {T<:fftwComplex,N} + timelimit::Real=NO_TIMELIMIT, + num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N} + if num_threads !== nothing + plan = set_num_threads(num_threads) do + $plan_f(X, region; flags = flags, timelimit = timelimit) + end + return plan + end cFFTWPlan{T,$direction,false,N}(X, fakesimilar(flags, X, T), region, flags, timelimit) end function $plan_f!(X::StridedArray{T,N}, region; - flags::Integer=ESTIMATE, - timelimit::Real=NO_TIMELIMIT) where {T<:fftwComplex,N} + flags::Integer=ESTIMATE, + timelimit::Real=NO_TIMELIMIT, + num_threads::Union{Nothing, Integer} = nothing ) where {T<:fftwComplex,N} + if num_threads !== nothing + plan = set_num_threads(num_threads) do + $plan_f!(X, region; flags = flags, timelimit = timelimit) + end + return plan + end cFFTWPlan{T,$direction,true,N}(X, X, region, flags, timelimit) end $plan_f(X::StridedArray{<:fftwComplex}; kws...) = @@ -699,7 +737,14 @@ for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD)) $plan_f!(X::StridedArray{<:fftwComplex}; kws...) = $plan_f!(X, 1:ndims(X); kws...) - function plan_inv(p::cFFTWPlan{T,$direction,inplace,N}) where {T<:fftwComplex,N,inplace} + function plan_inv(p::cFFTWPlan{T,$direction,inplace,N}; + num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N,inplace} + if num_threads !== nothing + plan = set_num_threads(num_threads) do + plan_inv(p) + end + return plan + end X = Array{T}(undef, p.sz) Y = inplace ? X : fakesimilar(p.flags, X, T) ScaledPlan(cFFTWPlan{T,$idirection,inplace,N}(X, Y, p.region, @@ -735,7 +780,14 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64}))) @eval begin function plan_rfft(X::StridedArray{$Tr,N}, region; flags::Integer=ESTIMATE, - timelimit::Real=NO_TIMELIMIT) where N + timelimit::Real=NO_TIMELIMIT, + num_threads::Union{Nothing, Integer} = nothing) where N + if num_threads !== nothing + plan = set_num_threads(num_threads) do + plan_rfft(X, region; flags = flags, timelimit = timelimit) + end + return plan + end osize = rfft_output_size(X, region) Y = flags&ESTIMATE != 0 ? FakeArray{$Tc}(osize) : Array{$Tc}(undef, osize) rFFTWPlan{$Tr,$FORWARD,false,N}(X, Y, region, flags, timelimit) @@ -743,7 +795,14 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64}))) function plan_brfft(X::StridedArray{$Tc,N}, d::Integer, region; flags::Integer=ESTIMATE, - timelimit::Real=NO_TIMELIMIT) where N + timelimit::Real=NO_TIMELIMIT, + num_threads::Union{Nothing, Integer} = nothing) where N + if num_threads !== nothing + plan = set_num_threads(num_threads) do + plan_brfft(X, d, region; flags = flags, timelimit = timelimit) + end + return plan + end osize = brfft_output_size(X, d, region) Y = flags&ESTIMATE != 0 ? FakeArray{$Tr}(osize) : Array{$Tr}(undef, osize) @@ -763,7 +822,14 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64}))) plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,1:ndims(X);kws...) plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,1:ndims(X);kws...) - function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N}) where N + function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N}, + num_threads::Union{Nothing, Integer} = nothing) where N + if num_threads !== nothing + plan = set_num_threads(num_threads) do + plan_inv(p) + end + return plan + end X = Array{$Tr}(undef, p.sz) Y = p.flags&ESTIMATE != 0 ? FakeArray{$Tc}(p.osz) : Array{$Tc}(undef, p.osz) ScaledPlan(rFFTWPlan{$Tc,$BACKWARD,false,N}(Y, X, p.region, @@ -773,7 +839,14 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64}))) normalization(X, p.region)) end - function plan_inv(p::rFFTWPlan{$Tc,$BACKWARD,false,N}) where N + function plan_inv(p::rFFTWPlan{$Tc,$BACKWARD,false,N}; + num_threads::Union{Nothing, Integer} = nothing) where N + if num_threads !== nothing + plan = set_num_threads(num_threads) do + plan_inv(p) + end + return plan + end X = Array{$Tc}(undef, p.sz) Y = p.flags&ESTIMATE != 0 ? FakeArray{$Tr}(p.osz) : Array{$Tr}(undef, p.osz) ScaledPlan(rFFTWPlan{$Tr,$FORWARD,false,N}(Y, X, p.region, @@ -832,14 +905,28 @@ end function plan_r2r(X::StridedArray{T,N}, kinds, region; flags::Integer=ESTIMATE, - timelimit::Real=NO_TIMELIMIT) where {T<:fftwNumber,N} + timelimit::Real=NO_TIMELIMIT, + num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwNumber,N} + if num_threads !== nothing + plan = set_num_threads(num_threads) do + plan_r2r(X, kinds, region; flags = flags, timelimit = timelimit) + end + return plan + end r2rFFTWPlan{T,Any,false,N}(X, fakesimilar(flags, X, T), region, kinds, flags, timelimit) end function plan_r2r!(X::StridedArray{T,N}, kinds, region; flags::Integer=ESTIMATE, - timelimit::Real=NO_TIMELIMIT) where {T<:fftwNumber,N} + timelimit::Real=NO_TIMELIMIT, + num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwNumber,N} + if num_threads !== nothing + plan = set_num_threads(num_threads) do + plan_r2r(X, kinds, region; flags = flags, timelimit = timelimit) + end + return plan + end r2rFFTWPlan{T,Any,true,N}(X, X, region, kinds, flags, timelimit) end @@ -861,7 +948,14 @@ function logical_size(n::Integer, k::Integer) return 2n end -function plan_inv(p::r2rFFTWPlan{T,K,inplace,N}) where {T<:fftwNumber,K,inplace,N} +function plan_inv(p::r2rFFTWPlan{T,K,inplace,N}; + num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwNumber,K,inplace,N} + if num_threads !== nothing + set_num_threads(num_threads) do + plan = plan_inv(p) + end + return plan + end X = Array{T}(undef, p.sz) iK = fix_kinds(p.region, [inv_kind[k] for k in K]) Y = inplace ? X : fakesimilar(p.flags, X, T) diff --git a/src/providers.jl b/src/providers.jl index 9b08140..58817b1 100644 --- a/src/providers.jl +++ b/src/providers.jl @@ -85,4 +85,5 @@ end import MKL_jll libfftw3[] = MKL_jll.libmkl_rt_path libfftw3f[] = MKL_jll.libmkl_rt_path + const _last_num_threads = Ref(Cint(1)) end diff --git a/test/runtests.jl b/test/runtests.jl index 2291003..965520a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -528,3 +528,14 @@ end @test occursin("dft-thr", string(p2)) end end + +@testset "Setting and getting planner nthreads" begin + FFTW.set_num_threads(1) + @test FFTW.get_num_threads() == 1 + FFTW.set_num_threads(2) + @test FFTW.get_num_threads() == 2 + plan = FFTW.set_num_threads(1) do # Should leave get_num_threads unchanged + plan_rfft(m4, 1) + end + @test FFTW.get_num_threads() == 2 # Unchanged +end