Skip to content

Commit

Permalink
Allow constrained type parameters in variant (#78)
Browse files Browse the repository at this point in the history
* Allow constrained type parameters in variant

* Update sum_type.jl

* Update sum_type.jl

* Update sum_type.jl

* Update runtests.jl

* more tests

* Update runtests.jl
  • Loading branch information
Tortar authored May 21, 2024
1 parent 6da8e18 commit dd4f806
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/sum_type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@ function _sum_type(T, hidden, blk)
T_nameparam = isempty(T_params) ? T : :($T_name{$(T_params...)})
filter!(x -> !(x isa LineNumberNode), blk.args)

constructors = generate_constructor_data(T_name, T_params, T_params_constrained, T_nameparam, hide_variants, blk)
constructors = generate_constructor_data(T_name, T_params, T_nameparam, hide_variants, blk)

variants_params = [nt.params for nt in constructors]
for p in setdiff(union(variants_params...), intersect(variants_params...))
i = findfirst(x -> x == p, T_params)
T_params_constrained[i] isa Symbol && continue
T_p_args = T_params_constrained[i].args
T_p_args[2] = :(Union{Uninit, $(T_p_args[2])})
end

if !allunique(map(x -> x.name, constructors))
error("constructors must have unique names, got $(map(x -> x.name, constructors))")
Expand All @@ -41,7 +49,7 @@ end

#------------------------------------------------------

function generate_constructor_data(T_name, T_params, T_params_constrained, T_nameparam, hide_variants, blk::Expr)
function generate_constructor_data(T_name, T_params, T_nameparam, hide_variants, blk::Expr)
constructors = []
for con_ blk.args
con_ isa LineNumberNode && continue
Expand Down Expand Up @@ -71,7 +79,8 @@ function generate_constructor_data(T_name, T_params, T_params_constrained, T_nam
con::Expr = con_
con.head == :call || throw(ArgumentError("Malformed variant $con_"))
con_name = con.args[1] isa Expr && con.args[1].head == :curly ? con.args[1].args[1] : con.args[1]
con_params = (con.args[1] isa Expr && con.args[1].head == :curly) ? con.args[1].args[2:end] : []
con_params_constrained = (con.args[1] isa Expr && con.args[1].head == :curly) ? con.args[1].args[2:end] : []
con_params = Any[p isa Expr && p.head == :(<:) ? p.args[1] : p for p in con_params_constrained]
issubset(con_params, T_params) ||
error("constructor parameters ($con_params) for $con_name, not a subset of sum type parameters $T_params")
con_params_uninit = let v = copy(con_params)
Expand All @@ -82,7 +91,6 @@ function generate_constructor_data(T_name, T_params, T_params_constrained, T_nam
end
v
end
con_params_constrained = [T_params_constrained[i] for i eachindex(con_params_uninit) if con_params_uninit[i] != Uninit]
con_nameparam = isempty(con_params) ? con_name : :($con_name{$(con_params...)})
con_field_names = map(enumerate(con.args[2:end])) do (i, field)
@assert field isa Symbol || (field isa Expr && field.head == :(::)) "malformed constructor field $field"
Expand Down
18 changes: 18 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,21 @@ end
@test WT{Int,Int}[T(1), W(1)] isa Vector{WT{Int, Int}}
@test [T(1), W(1)] isa Vector{WT}
end

#---------------

@testset "Constrained type parameters in variants" begin
@sum_type QPM{X<:Real, Y<:Real, Z} begin
Q{X<:AbstractFloat, Z}(a::X, b::Z)
P{Y<:Integer}(b::Y)
M{Z<:AbstractArray}(a::Z)
end
@test Q{Float64, Bool}(1.0, true) isa QPM{Float64, SumTypes.Uninit, Bool}
@test Q(1.0, true) isa QPM{Float64, SumTypes.Uninit, Bool}
@test P{Int}(1.0) isa QPM{SumTypes.Uninit, Int, SumTypes.Uninit}
@test P(1) isa QPM{SumTypes.Uninit, Int, SumTypes.Uninit}
@test M(Int[]) isa QPM{SumTypes.Uninit, SumTypes.Uninit, Vector{Int}}
@test_throws TypeError Q{Int, Bool}(1.0, true)
@test_throws TypeError P{Float64}(1.0)
@test_throws TypeError M{Bool}(true)
end

0 comments on commit dd4f806

Please sign in to comment.