Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented rfft!, irfft! and brfft! #222

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/FFTW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ end

include("fft.jl")
include("dct.jl")
include("rfft!.jl")

include("precompile.jl")
_precompile_()
Expand Down
266 changes: 266 additions & 0 deletions src/rfft!.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
import Base: IndexStyle, getindex, setindex!, eltype, \, similar, copy, real, read!

export PaddedRFFTArray, plan_rfft!, rfft!, plan_irfft!, plan_brfft!, brfft!, irfft!

_check_valid_strides(a::AbstractArray) = strides(a) == colmajorstrides(size(a))

# This struct reinterprets the `data` array to a Complex or Float array, depending on `eltype(data)`
# It is used internally with the PaddedRFFTArray in place of `Base.ReinterpretArray`
# ReinterpretArray has some performance issues when reinterprreting a Complex array to Real
struct ComplexOrRealReinterpretArray{T<:fftwNumber,N,A<:AbstractArray{<:fftwNumber,N},B<:Ptr} <: DenseArray{T,N}
data::A # Either a real or complex array
_unsafe_pointer::B # Pointer to the `data` array, but converted to a different type representation.

function ComplexOrRealReinterpretArray(rarray::AbstractArray{T,N}) where {T<:fftwReal,N}
_check_valid_strides(rarray) || throw(
ArgumentError("Only contiguous strided arrays are supported"))
ptr = unsafe_convert(Ptr{Complex{T}}, pointer(rarray))
return new{Complex{T},N,typeof(rarray),typeof(ptr)}(rarray,ptr)
end

function ComplexOrRealReinterpretArray(carray::AbstractArray{T,N}) where {T<:fftwComplex,N}
_check_valid_strides(carray) || throw(
ArgumentError("Only contiguous strided arrays are supported"))
FT = T === ComplexF64 ? Float64 : Float32
ptr = unsafe_convert(Ptr{FT}, pointer(carray))
return new{FT,N,typeof(carray),typeof(ptr)}(carray,ptr)
end
end

const RealReinterpretArray{N} = ComplexOrRealReinterpretArray{<:fftwReal,N,<:AbstractArray{<:fftwComplex,N}}
const ComplexReinterpretArray{N} = ComplexOrRealReinterpretArray{<:fftwComplex,N,<:AbstractArray{<:fftwReal,N}}

@inline size_convertion(::RealReinterpretArray,i::Integer) = 2i
@inline size_convertion(::ComplexReinterpretArray,i::Integer) = i÷2

IndexStyle(::Type{T}) where {T<:ComplexOrRealReinterpretArray} = IndexLinear()

Base.size(a::ComplexOrRealReinterpretArray) =
ntuple(i->(i == 1 ? size_convertion(a,size(a.data)[i]) : size(a.data)[i]),Val(ndims(a.data)))

Base.@propagate_inbounds function getindex(a::ComplexOrRealReinterpretArray,i::Integer)
data = a.data
@boundscheck checkbounds(a,i)
GC.@preserve data r = unsafe_load(a._unsafe_pointer, i)
favba marked this conversation as resolved.
Show resolved Hide resolved
return r
end

Base.@propagate_inbounds function setindex!(a::ComplexOrRealReinterpretArray,v,i::Integer)
data = a.data
@boundscheck checkbounds(a,i)
GC.@preserve data unsafe_store!(a._unsafe_pointer,v, i)
return a
end

Base.unsafe_convert(p::Type{Ptr{T}}, a::ComplexOrRealReinterpretArray{T,N}) where {T,N} = Base.unsafe_convert(p,a.data)

Base.elsize(::Type{<:ComplexOrRealReinterpretArray{T,N}}) where {T,N} = sizeof(T)

complex_or_real_reinterpret(a::AbstractArray) = ComplexOrRealReinterpretArray(a)
complex_or_real_reinterpret(a::ComplexOrRealReinterpretArray) = a.data # Avoid nesting of ComplexOrRealReinterpretArrays

# At the time this code was written the new `ReinterpretArray` in Base had some performace issues.
# Those issues were bypassed with the usage of our simplified version of ReinterpretArray above.
# Hopefully, once the performance issues with ReinterpretArray
# are solved we can just use Base.ReinterpretArray directly.

struct PaddedRFFTArray{T<:fftwReal,N,R,C,L,Nm1} <: DenseArray{Complex{T},N}
data::R
r::SubArray{T,N,R,Tuple{Base.OneTo{Int},Vararg{Base.Slice{Base.OneTo{Int}},Nm1}},L} # Real view skipping padding
c::C

function PaddedRFFTArray{T}(rr::AbstractArray{T,N},d::Int) where {T<:fftwReal,N}
fsize = size(rr)[1]
iseven(fsize) || throw(
ArgumentError("First dimension of allocated array must have even number of elements"))
(d == fsize-2 || d == fsize-1) || throw(
ArgumentError("Number of elements on the first dimension of array must be either 1 or 2 less than the number of elements on the first dimension of the allocated array"))
c = complex_or_real_reinterpret(rr)
r = view(rr, Base.OneTo(d), ntuple(i->Colon(),Val(N-1))...)
return new{T, N, typeof(rr), typeof(c), N===1, N-1}(rr,r,c)
end # function

function PaddedRFFTArray{T}(c::AbstractArray{Complex{T},N},d::Int) where {T<:fftwReal,N}
rr = complex_or_real_reinterpret(c)
fsize = size(rr)[1]
(d == fsize-2 || d == fsize-1) || throw(
ArgumentError("Given first dimension of real array d=$d incompatible with size of complex array $(size(c)). Valid values are d=$(fsize-2) or d=$(fsize-1)"))
r = view(rr, Base.OneTo(d), ntuple(i->Colon(),Val(N-1))...)
return new{T, N, typeof(rr), typeof(c), N===1, N-1}(rr,r,c)
end # function

end # struct

PaddedRFFTArray(a::AbstractArray{<:Union{T,Complex{T}},N},d::Int) where {T<:fftwReal,N} =
PaddedRFFTArray{T}(a,d)

function PaddedRFFTArray{T}(ndims::Vararg{Integer}) where {T}
fsize = (ndims[1]÷2 + 1)*2
a = zeros(T,(fsize, ndims[2:end]...))
PaddedRFFTArray{T}(a, ndims[1])
end

PaddedRFFTArray{T}(ndims::NTuple{N,Integer}) where {T,N} =
PaddedRFFTArray{T}(ndims...)

PaddedRFFTArray(ndims::Vararg{Integer,N}) where N =

Check warning on line 106 in src/rfft!.jl

View check run for this annotation

Codecov / codecov/patch

src/rfft!.jl#L106

Added line #L106 was not covered by tests
PaddedRFFTArray{Float64}(ndims...)

PaddedRFFTArray(ndims::NTuple{N,Integer}) where N =
PaddedRFFTArray{Float64}(ndims...)

function PaddedRFFTArray(a::AbstractArray{T,N}) where {T<:fftwReal,N}
t = PaddedRFFTArray{T}(size(a))
@inbounds copyto!(t.r, a)
return t
end

copy(S::PaddedRFFTArray) = PaddedRFFTArray(copy(S.data),size(S.r,1))

similar(f::PaddedRFFTArray,::Type{T},dims::Tuple{Vararg{Int,N}}) where {T, N} =
PaddedRFFTArray{T}(dims)

similar(f::PaddedRFFTArray{T,N,L},dims::NTuple{N2,Int}) where {T,N,L,N2} =
PaddedRFFTArray{T}(dims)

similar(f::PaddedRFFTArray,::Type{T}) where {T} =
PaddedRFFTArray{T}(size(f.r))

similar(f::PaddedRFFTArray{T}) where {T} =
PaddedRFFTArray{T}(similar(f.data), size(f.r,1))

size(S::PaddedRFFTArray) =
size(S.c)

IndexStyle(::Type{T}) where {T<:PaddedRFFTArray} =
IndexLinear()

Base.@propagate_inbounds getindex(A::PaddedRFFTArray,i::Integer) =
getindex(A.c,i)

Base.@propagate_inbounds setindex!(A::PaddedRFFTArray,x, i::Integer) =
setindex!(A.c,x,i)

Base.unsafe_convert(p::Type{Ptr{Complex{T}}}, a::PaddedRFFTArray{T,N}) where {T,N} = Base.unsafe_convert(p,a.c)

Check warning on line 144 in src/rfft!.jl

View check run for this annotation

Codecov / codecov/patch

src/rfft!.jl#L144

Added line #L144 was not covered by tests

Base.elsize(::Type{<:PaddedRFFTArray{T,N}}) where {T,N} = sizeof(Complex{T})

Check warning on line 146 in src/rfft!.jl

View check run for this annotation

Codecov / codecov/patch

src/rfft!.jl#L146

Added line #L146 was not covered by tests


function PaddedRFFTArray(stream::IO, dims)
field = PaddedRFFTArray(dims)
return read!(stream,field)
end

function PaddedRFFTArray{T}(stream::IO, dims) where T
field = PaddedRFFTArray{T}(dims)
return read!(stream,field)
end

function read!(file::AbstractString, field::PaddedRFFTArray)
open(file) do io
return read!(io,field)

Check warning on line 161 in src/rfft!.jl

View check run for this annotation

Codecov / codecov/patch

src/rfft!.jl#L159-L161

Added lines #L159 - L161 were not covered by tests
end
end

# Read a binary file of an unpaded array directly to a PaddedRFFT array, without the need
# of the creation of a intermediary Array. If the data is already padded then the user
# should just use PaddedRFFTArray{T}(read("file",unpaddeddim),d)
function read!(stream::IO, field::PaddedRFFTArray{T,N,L}) where {T,N,L}
rr = field.data
dims = size(field.r)
d = dims[1]
nb = sizeof(T)*d
npencils = prod(dims)÷d
npad = iseven(d) ? 2 : 1
for i=0:(npencils-1)
unsafe_read(stream,Ref(rr,Int((d+npad)*i+1)),nb)
end
return field
end


###########################################################################################
# Foward plans

function plan_rfft!(X::PaddedRFFTArray{T,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT) where {T<:fftwReal,N}

(1 in region) || throw(ArgumentError("The first dimension must always be transformed"))
return rFFTWPlan{T,FORWARD,true,N}(X.r, X.c, region, flags, timelimit)
end

plan_rfft!(f::PaddedRFFTArray;kws...) = plan_rfft!(f, 1:ndims(f); kws...)

*(p::rFFTWPlan{T,FORWARD,true,N},f::PaddedRFFTArray{T,N}) where {T<:fftwReal,N} =
(mul!(f.c, p, f.r); f)

rfft!(f::PaddedRFFTArray, region=1:ndims(f)) = plan_rfft!(f, region) * f

function rfft!(r::SubArray{<:fftwReal}, region=1:ndims(r))
f = PaddedRFFTArray(parent(r),size(r,1))
plan_rfft!(f, region) * f
end

function \(p::rFFTWPlan{T,FORWARD,true,N},f::PaddedRFFTArray{T,N}) where {T<:fftwReal,N}
isdefined(p,:pinv) || (p.pinv = plan_irfft!(f,p.region))
return p.pinv * f
end


##########################################################################################
# Inverse plans

function plan_brfft!(X::PaddedRFFTArray{T,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT) where {T<:fftwReal,N}
(1 in region) || throw(ArgumentError("The first dimension must always be transformed"))
return rFFTWPlan{Complex{T},BACKWARD,true,N}(X.c, X.r, region, flags,timelimit)
end

plan_brfft!(f::PaddedRFFTArray;kws...) = plan_brfft!(f,1:ndims(f);kws...)

Check warning on line 221 in src/rfft!.jl

View check run for this annotation

Codecov / codecov/patch

src/rfft!.jl#L221

Added line #L221 was not covered by tests

*(p::rFFTWPlan{Complex{T},BACKWARD,true,N},f::PaddedRFFTArray{T,N}) where {T<:fftwReal,N} =
(mul!(f.r, p, f.c); f.r)

brfft!(f::PaddedRFFTArray, region=1:ndims(f)) = plan_brfft!(f, region) * f

function brfft!(f::PaddedRFFTArray, i::Integer)
if i == size(f.r,1) # Assume `i` is the same as `d` in the brfft!(c::DenseArray{<:fftComplex}, d::Integer, region) defined below
return brfft!(f,1:ndims(f))
else # Assume `i` is specifying the region. `plan_brfft!` will throw an error if i != 1
return brfft!(f,(i,))

Check warning on line 232 in src/rfft!.jl

View check run for this annotation

Codecov / codecov/patch

src/rfft!.jl#L232

Added line #L232 was not covered by tests
end
end

function brfft!(c::AbstractArray{<:fftwComplex}, d::Integer, region=1:ndims(c))
f = PaddedRFFTArray(c,d)
plan_brfft!(f, region) * f

Check warning on line 238 in src/rfft!.jl

View check run for this annotation

Codecov / codecov/patch

src/rfft!.jl#L236-L238

Added lines #L236 - L238 were not covered by tests
end

function plan_irfft!(x::PaddedRFFTArray{T,N}, region; kws...) where {T,N}
ScaledPlan(plan_brfft!(x, region; kws...),normalization(T, size(x.r), region))
end

plan_irfft!(f::PaddedRFFTArray;kws...) = plan_irfft!(f,1:ndims(f);kws...)

*(p::ScaledPlan{Complex{T},rFFTWPlan{Complex{T},BACKWARD,true,N}},f::PaddedRFFTArray{T,N}) where {T,N} = begin
p.p * f
rmul!(f.data, p.scale)
f.r

Check warning on line 250 in src/rfft!.jl

View check run for this annotation

Codecov / codecov/patch

src/rfft!.jl#L247-L250

Added lines #L247 - L250 were not covered by tests
end

irfft!(f::PaddedRFFTArray, region=1:ndims(f)) = plan_irfft!(f,region) * f

function irfft!(f::PaddedRFFTArray, i::Integer)
if i == size(f.r,1) # Assume `i` is the same as `d` in the irfft!(c::AbstractArray{<:fftComplex}, d::Integer, region) defined below
return irfft!(f,1:ndims(f))

Check warning on line 257 in src/rfft!.jl

View check run for this annotation

Codecov / codecov/patch

src/rfft!.jl#L255-L257

Added lines #L255 - L257 were not covered by tests
else # Assume `i` is specifying the region. `plan_irfft!` will throw an error if i != 1
return irfft!(f,(i,))

Check warning on line 259 in src/rfft!.jl

View check run for this annotation

Codecov / codecov/patch

src/rfft!.jl#L259

Added line #L259 was not covered by tests
end
end

function irfft!(c::AbstractArray{<:fftwComplex}, d::Integer, region=1:ndims(c))
f = PaddedRFFTArray(c,d)
plan_irfft!(f, region) * f
end
74 changes: 74 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,77 @@ end
@test occursin("dft-thr", string(p2))
end
end

let a = rand(Float64,(8,4,4)), b = PaddedRFFTArray(a), c = copy(b)

@testset "PaddedRFFTArray creation" begin
@test a == b.r
@test c == b
@test c.r == b.r
@test typeof(similar(b)) === typeof(b)
@test size(similar(b,Float32)) === size(b)
@test size(similar(b,Float32).r) === size(b.r)
@test size(similar(b,(4,4,4)).r) === (4,4,4)
@test size(similar(b,Float32,(4,4,4)).r) === (4,4,4)
end

@testset "rfft! and irfft!" begin
@test rfft(a) ≈ rfft!(b)
@test a ≈ irfft!(b)
@test rfft(a,1:2) ≈ rfft!(b,1:2)
@test a ≈ irfft!(b,1:2)
@test rfft(a,(1,3)) ≈ rfft!(b,(1,3))
@test a ≈ irfft!(b,(1,3))

p = plan_rfft!(c)
@test p*c ≈ rfft!(b)
@test p\c ≈ irfft!(b)

aa = rand(Float64,(9,4,4))
bb = PaddedRFFTArray(aa)
@test aa == bb.r
@test rfft(aa) ≈ rfft!(bb)
@test aa ≈ irfft!(bb)
@test rfft(aa,1:2) ≈ rfft!(bb,1:2)
@test aa ≈ irfft!(bb,1:2)
@test rfft(aa,(1,3)) ≈ rfft!(bb,(1,3))
@test aa ≈ irfft!(bb,(1,3))
end

@testset "Read binary file to PaddedRFFTArray" begin
for s in ((8,4,4),(9,4,4),(8,),(9,))
aa = rand(Float64,s)
f = IOBuffer()
write(f,aa)
@test aa == (PaddedRFFTArray(seekstart(f),s)).r
aa = rand(Float32,s)
f = IOBuffer()
write(f,aa)
@test aa == PaddedRFFTArray{Float32}(seekstart(f),s).r
end
end

@testset "brfft!" begin
a = rand(Float64,(4,4))
b = PaddedRFFTArray(a)
rfft!(b)
@test (brfft!(b) ./ 16) ≈ a
end

@testset "FFTW MEASURE flag" begin
c = similar(b)
p = plan_rfft!(c,flags=FFTW.MEASURE)
p.pinv = plan_irfft!(c,flags=FFTW.MEASURE)
c .= b
@test c == b
@test p*c ≈ rfft!(b)
@test p\c ≈ irfft!(b)
end

@testset "irfft! and brfft! of complex Array and rfft! of SubArray" begin
r = rand(8,6)
@test brfft!(rfft!(irfft!(rfft(r),8)),8)./48 ≈ r
r2 = rand(Float32,9,3,2)
@test brfft!(rfft!(irfft!(rfft(r2,(1,3)),9,(1,3))),9)./54 ≈ r2
end
end #let block