Skip to content

Commit

Permalink
Improve inferrability of [region...]
Browse files Browse the repository at this point in the history
This fixes the following problem:

```julia
julia> docat(v) = [v...]
docat (generic function with 1 method)

julia> code_typed(docat, (UnitRange{Int},))
1-element Vector{Any}:
 CodeInfo(
1 ─ %1 = Core._apply_iterate(Base.iterate, Base.vect, v)::Vector{_A} where _A
└──      return %1
) => Vector{_A} where _A

julia> code_typed(docat, (Vector{Int},))
1-element Vector{Any}:
 CodeInfo(
1 ─ %1 = Core._apply_iterate(Base.iterate, Base.vect, v)::Union{Vector{Any}, Vector{Int64}}
└──      return %1
) => Union{Vector{Any}, Vector{Int64}}
```

Essentially, because `region...` is implemented via the parser,
it relies on the call `vect(args...)` and the risk is that `args`
might be empty, which causes it to return `Vector{Any}`.
  • Loading branch information
timholy committed Jan 22, 2021
1 parent 59f38c6 commit 18829ab
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ unsafe_execute!(plan::r2rFFTWPlan{T},
# Compute dims and howmany for FFTW guru planner
function dims_howmany(X::StridedArray, Y::StridedArray,
sz::Array{Int,1}, region)
reg = [region...]
reg = Int[region...]
if length(unique(reg)) < length(reg)
throw(ArgumentError("each dimension can be transformed at most once"))
end
Expand Down Expand Up @@ -578,7 +578,7 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",libfftw3),
Y::StridedArray{$Tc,N},
region, flags::Integer, timelimit::Real) where {inplace,N}
R = isa(region, Tuple) ? region : copy(region)
region = circshift([region...],-1) # FFTW halves last dim
region = circshift(Int[region...],-1) # FFTW halves last dim
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, [size(X)...], region)
plan = ccall(($(string(fftw,"_plan_guru64_dft_r2c")),$lib),
Expand All @@ -598,7 +598,7 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",libfftw3),
Y::StridedArray{$Tr,N},
region, flags::Integer, timelimit::Real) where {inplace,N}
R = isa(region, Tuple) ? region : copy(region)
region = circshift([region...],-1) # FFTW halves last dim
region = circshift(Int[region...],-1) # FFTW halves last dim
unsafe_set_timelimit($Tr, timelimit)
dims, howmany = dims_howmany(X, Y, [size(Y)...], region)
plan = ccall(($(string(fftw,"_plan_guru64_dft_c2r")),$lib),
Expand Down

0 comments on commit 18829ab

Please sign in to comment.