From be94ece266bf02a0a4fecaa4dc294b1a1c67792e Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sat, 30 Sep 2017 19:26:12 +0200 Subject: [PATCH] Preserve concrete Union types in collect() with generators When return_type() gives a Union of concrete types, use it as the element type of the array instead of the type of the first element and progressively making it broader as needed (often ending with element type Any). This means that no reallocation will be needed anymore if/when encoutering an element with a new type. Note that the inferred element type may still be broader than the actual contents of the array, for example with (x for x in Union{Int, Void}[1]). Using a Union element type also has the advantage of using the more efficient layout for isbitsunion types and of allowing for more efficient generated code when using the resulting array. This is particularly noticeable with array comprehensions, which inherit the behavior of collect(). --- base/array.jl | 16 ++++++++++++++-- test/arrayops.jl | 4 ++++ test/functional.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 2 deletions(-) diff --git a/base/array.jl b/base/array.jl index 91e9f800f8cb4..76d1883d17af5 100644 --- a/base/array.jl +++ b/base/array.jl @@ -664,6 +664,14 @@ end _array_for(::Type{T}, itr, ::HasLength) where {T} = Array{T,1}(Int(length(itr)::Integer)) _array_for(::Type{T}, itr, ::HasShape) where {T} = similar(Array{T}, indices(itr)) +function _isconcreteunion(::Type{T}) where T + if T isa Union + _isconcreteunion(T.a) && _isconcreteunion(T.b) + else + isconcrete(T) + end +end + function collect(itr::Generator) isz = iteratorsize(itr.iter) et = _default_eltype(typeof(itr)) @@ -675,7 +683,9 @@ function collect(itr::Generator) return _array_for(et, itr.iter, isz) end v1, st = next(itr, st) - collect_to_with_first!(_array_for(typeof(v1), itr.iter, isz), v1, itr, st) + S = _default_eltype(typeof(itr)) + T = _isconcreteunion(S) ? S : typeof(v1) + collect_to_with_first!(_array_for(T, itr.iter, isz), v1, itr, st) end end @@ -688,7 +698,9 @@ function _collect(c, itr, ::EltypeUnknown, isz::Union{HasLength,HasShape}) return _similar_for(c, _default_eltype(typeof(itr)), itr, isz) end v1, st = next(itr, st) - collect_to_with_first!(_similar_for(c, typeof(v1), itr, isz), v1, itr, st) + S = _default_eltype(typeof(itr)) + T = _isconcreteunion(S) ? S : typeof(v1) + collect_to_with_first!(_similar_for(c, T, itr, isz), v1, itr, st) end function collect_to_with_first!(dest::AbstractArray, v1, itr, st) diff --git a/test/arrayops.jl b/test/arrayops.jl index 70997260272ea..627afde46758f 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -1207,6 +1207,10 @@ end @test isequal([1,2,3], [a for (a,b) in enumerate(2:4)]) @test isequal([2,3,4], [b for (a,b) in enumerate(2:4)]) + @test [s for s in Union{String, Void}["a", nothing]] isa Vector{Union{String, Void}} + @test [s for s in Union{String, Void}["a"]] isa Vector{Union{String, Void}} + @test [s for s in Vector{Union{String, Void}}()] isa Vector{Union{String, Void}} + @testset "comprehension in let-bound function" begin let x⊙y = sum([x[i]*y[i] for i=1:length(x)]) @test [1,2] ⊙ [3,4] == 11 diff --git a/test/functional.jl b/test/functional.jl index 5256a54167473..e86dad67186f4 100644 --- a/test/functional.jl +++ b/test/functional.jl @@ -142,3 +142,45 @@ end for n = 0:5:100-q-d for p = 100-q-d-n if p < n < d < q] == [(50,30,15,5), (50,30,20,0), (50,40,10,0), (75,20,5,0)] + +@testset "return type of map() and collect() on generators" begin + x = ["a", "b"] + res = @inferred collect(s for s in x) + @test res isa Vector{String} + res = @inferred map(identity, x) + @test res isa Vector{String} + res = @inferred collect(s === nothing for s in x) + @test res isa Vector{Bool} + res = @inferred map(s -> s === nothing, x) + @test res isa Vector{Bool} + + y = Union{String, Void}["a", nothing] + f(::Void) = nothing + f(s::String) = s == "a" + res = @inferred collect(s for s in y) + @test res isa Vector{Union{String, Void}} + res = @inferred map(identity, y) + @test res isa Vector{Union{String, Void}} + res = @inferred collect(s === nothing for s in y) + @test res isa Vector{Bool} + res = @inferred map(s -> s === nothing, y) + @test res isa Vector{Bool} + res = @inferred collect(f(s) for s in y) + @test res isa Vector{Union{Bool, Void}} + res = @inferred map(f, y) + @test res isa Vector{Union{Bool, Void}} + + y[2] = "c" + res = @inferred collect(s for s in y) + @test res isa Vector{Union{String, Void}} + res = @inferred map(identity, y) + @test res isa Vector{Union{String, Void}} + res = @inferred collect(s === nothing for s in y) + @test res isa Vector{Bool} + res = @inferred map(s -> s === nothing, y) + @test res isa Vector{Bool} + res = @inferred collect(f(s) for s in y) + @test res isa Vector{Union{Bool, Void}} + res = @inferred map(f, y) + @test res isa Vector{Union{Bool, Void}} +end