diff --git a/base/array.jl b/base/array.jl index 91e9f800f8cb4..76d1883d17af5 100644 --- a/base/array.jl +++ b/base/array.jl @@ -664,6 +664,14 @@ end _array_for(::Type{T}, itr, ::HasLength) where {T} = Array{T,1}(Int(length(itr)::Integer)) _array_for(::Type{T}, itr, ::HasShape) where {T} = similar(Array{T}, indices(itr)) +function _isconcreteunion(::Type{T}) where T + if T isa Union + _isconcreteunion(T.a) && _isconcreteunion(T.b) + else + isconcrete(T) + end +end + function collect(itr::Generator) isz = iteratorsize(itr.iter) et = _default_eltype(typeof(itr)) @@ -675,7 +683,9 @@ function collect(itr::Generator) return _array_for(et, itr.iter, isz) end v1, st = next(itr, st) - collect_to_with_first!(_array_for(typeof(v1), itr.iter, isz), v1, itr, st) + S = _default_eltype(typeof(itr)) + T = _isconcreteunion(S) ? S : typeof(v1) + collect_to_with_first!(_array_for(T, itr.iter, isz), v1, itr, st) end end @@ -688,7 +698,9 @@ function _collect(c, itr, ::EltypeUnknown, isz::Union{HasLength,HasShape}) return _similar_for(c, _default_eltype(typeof(itr)), itr, isz) end v1, st = next(itr, st) - collect_to_with_first!(_similar_for(c, typeof(v1), itr, isz), v1, itr, st) + S = _default_eltype(typeof(itr)) + T = _isconcreteunion(S) ? S : typeof(v1) + collect_to_with_first!(_similar_for(c, T, itr, isz), v1, itr, st) end function collect_to_with_first!(dest::AbstractArray, v1, itr, st) diff --git a/test/arrayops.jl b/test/arrayops.jl index 70997260272ea..627afde46758f 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -1207,6 +1207,10 @@ end @test isequal([1,2,3], [a for (a,b) in enumerate(2:4)]) @test isequal([2,3,4], [b for (a,b) in enumerate(2:4)]) + @test [s for s in Union{String, Void}["a", nothing]] isa Vector{Union{String, Void}} + @test [s for s in Union{String, Void}["a"]] isa Vector{Union{String, Void}} + @test [s for s in Vector{Union{String, Void}}()] isa Vector{Union{String, Void}} + @testset "comprehension in let-bound function" begin let x⊙y = sum([x[i]*y[i] for i=1:length(x)]) @test [1,2] ⊙ [3,4] == 11 diff --git a/test/functional.jl b/test/functional.jl index 5256a54167473..e86dad67186f4 100644 --- a/test/functional.jl +++ b/test/functional.jl @@ -142,3 +142,45 @@ end for n = 0:5:100-q-d for p = 100-q-d-n if p < n < d < q] == [(50,30,15,5), (50,30,20,0), (50,40,10,0), (75,20,5,0)] + +@testset "return type of map() and collect() on generators" begin + x = ["a", "b"] + res = @inferred collect(s for s in x) + @test res isa Vector{String} + res = @inferred map(identity, x) + @test res isa Vector{String} + res = @inferred collect(s === nothing for s in x) + @test res isa Vector{Bool} + res = @inferred map(s -> s === nothing, x) + @test res isa Vector{Bool} + + y = Union{String, Void}["a", nothing] + f(::Void) = nothing + f(s::String) = s == "a" + res = @inferred collect(s for s in y) + @test res isa Vector{Union{String, Void}} + res = @inferred map(identity, y) + @test res isa Vector{Union{String, Void}} + res = @inferred collect(s === nothing for s in y) + @test res isa Vector{Bool} + res = @inferred map(s -> s === nothing, y) + @test res isa Vector{Bool} + res = @inferred collect(f(s) for s in y) + @test res isa Vector{Union{Bool, Void}} + res = @inferred map(f, y) + @test res isa Vector{Union{Bool, Void}} + + y[2] = "c" + res = @inferred collect(s for s in y) + @test res isa Vector{Union{String, Void}} + res = @inferred map(identity, y) + @test res isa Vector{Union{String, Void}} + res = @inferred collect(s === nothing for s in y) + @test res isa Vector{Bool} + res = @inferred map(s -> s === nothing, y) + @test res isa Vector{Bool} + res = @inferred collect(f(s) for s in y) + @test res isa Vector{Union{Bool, Void}} + res = @inferred map(f, y) + @test res isa Vector{Union{Bool, Void}} +end