diff --git a/Project.toml b/Project.toml index a353b43..f793e81 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SumTypes" uuid = "8e1ec7a9-0e02-4297-b0fe-6433085c89f2" authors = ["MasonProtter "] -version = "0.3.8" +version = "0.4.0" [deps] MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" diff --git a/README.md b/README.md index cd8e9af..ec37bcd 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ - [Basics](https://github.com/MasonProtter/SumTypes.jl#basics) - [Destructuring sum types](https://github.com/MasonProtter/SumTypes.jl#destructuring-sum-types) +- [Using `full_type` to get the concrete type of a Sum Type](https://github.com/MasonProtter/SumTypes.jl#using-full_type-to-get-the-concrete-type-of-a-sum-type) - [Avoiding namespace clutter](https://github.com/MasonProtter/SumTypes.jl#avoiding-namespace-clutter) - [Custom printing](https://github.com/MasonProtter/SumTypes.jl#custom-printing) - [Performance](https://github.com/MasonProtter/SumTypes.jl#performance) @@ -162,6 +163,45 @@ The `@cases` macro still falls far short of a full on pattern matching system, l +## Using `full_type` to get the concrete type of a Sum Type + +
+Click to expand + +SumTypes.jl generates structs with a compactified memory layout which is computed on demand for parametric types. Because of this, +every sum type actually has two extra type parameters related to its memory layout. This means that, for instance, `Either{Int, Int}` is not a concrete type: + +``` julia +julia> @sum_type Either{A, B} begin + Left{A}(::A) + Right{B}(::B) + end + +julia> isconcretetype(Either{Int, Int}) +false +``` + +In order to get the proper, concrete type corresponding to `Either{Int, Int}`, one can just use the `full_type` function exported by SumTypes.jl: + +``` julia +julia> full_type(Either{Int, Int}) +Either{Int64, Int64, 15, 0} + +julia> full_type(Either{Int, String}) +Either{Int64, String, 8, 1} + +julia> full_type(Either{Tuple{Int, Int, Int}, String}) +Either{Tuple{Int64, Int64, Int64}, String, 24, 1} + +julia> isconcretetype(ans) +true +``` + +Avoiding these extra parameters would require https://github.com/JuliaLang/julia/issues/8472 to be implemented. + +
+ + ## Avoiding namespace clutter
@@ -307,15 +347,15 @@ end ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. - Range (min … max): 393.061 μs … 6.725 ms ┊ GC (min … max): 0.00% … 90.92% - Time (median): 434.257 μs ┊ GC (median): 0.00% - Time (mean ± σ): 483.461 μs ± 435.758 μs ┊ GC (mean ± σ): 9.38% ± 9.39% + Range (min … max): 267.399 μs … 3.118 ms ┊ GC (min … max): 0.00% … 90.36% + Time (median): 278.904 μs ┊ GC (median): 0.00% + Time (mean ± σ): 316.971 μs ± 306.290 μs ┊ GC (mean ± σ): 11.68% ± 10.74% - █▅▃▁ ▁ - █████▆▆▃▃▁▁▃▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▃▅ █ - 393 μs Histogram: log(frequency) by time 4.2 ms < + █ ▁ + █▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▇ █ + 267 μs Histogram: log(frequency) by time 2.77 ms < - Memory estimate: 654.16 KiB, allocs estimate: 21950. + Memory estimate: 654.75 KiB, allocs estimate: 21952. ``` SumTypes.jl @@ -358,13 +398,13 @@ end ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. - Range (min … max): 61.309 μs … 83.300 μs ┊ GC (min … max): 0.00% … 0.00% - Time (median): 62.350 μs ┊ GC (median): 0.00% - Time (mean ± σ): 62.376 μs ± 528.152 ns ┊ GC (mean ± σ): 0.00% ± 0.00% + Range (min … max): 54.890 μs … 73.650 μs ┊ GC (min … max): 0.00% … 0.00% + Time (median): 55.750 μs ┊ GC (median): 0.00% + Time (mean ± σ): 55.908 μs ± 655.652 ns ┊ GC (mean ± σ): 0.00% ± 0.00% - ▃█▂ ▁▄▃▂ - ▂▁▁▁▁▁▁▁▁▂▁▂▃▅▅▇███▆▄▃▃▄▄▇████▅▄▃▂▂▂▁▂▂▂▁▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▃ - 61.3 μs Histogram: frequency by time 64 μs < + ▁▄▇██▇▆▅▄ ▂▁ + ▁▁▁▁▂▃▄▇████████████▇▆▅▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃ + 54.9 μs Histogram: frequency by time 58.4 μs < Memory estimate: 0 bytes, allocs estimate: 0. ``` @@ -423,25 +463,24 @@ end ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. 
- Range (min … max): 69.355 μs … 234.343 μs ┊ GC (min … max): 0.00% … 0.00% - Time (median): 75.388 μs ┊ GC (median): 0.00% - Time (mean ± σ): 77.753 μs ± 13.757 μs ┊ GC (mean ± σ): 0.00% ± 0.00% + Range (min … max): 54.470 μs … 67.920 μs ┊ GC (min … max): 0.00% … 0.00% + Time (median): 55.640 μs ┊ GC (median): 0.00% + Time (mean ± σ): 55.692 μs ± 498.787 ns ┊ GC (mean ± σ): 0.00% ± 0.00% - █▃▄▃▇▆▆▃▄ ▁ ▂ - ██████████▇▇█▇▇█████▆▆▆▆▆▆▇▆▅▇▆▆▅▄▆▇▆▆▆▅▇▆▅▆▅▄▅▄▄▅▄▆▅▅▅▅▅▅▅▅ █ - 69.4 μs Histogram: log(frequency) by time 149 μs < + ▁▂▄▅▆▆▇▇▇█▅▅▃▂▂ + ▁▁▁▁▁▁▂▃▄▃▄▅▆▇▇████████████████▇▆▆▅▄▃▃▂▂▂▂▂▂▂▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁ ▄ + 54.5 μs Histogram: frequency by time 57.5 μs < Memory estimate: 0 bytes, allocs estimate: 0. ``` -SumTypes.jl is able to slightly beat Unityper.jl in this benckmark, though there are cases where the roles are reversed. -SumTypes.jl has some other advantages relative to Unityper.jl too, such as: -- SumTypes.jl allows [parametric types](https://docs.julialang.org/en/v1/manual/types/#Parametric-Types) for much greater container flexibility (Unityper does some memory layout optimizations that won't work with parametric types). -- SumTypes.jl does not require default values for every field of the struct +SumTypes.jl and Unityper.jl are about equal in this benchmark, though there are cases where there are differences. +SumTypes.jl has some other advantages relative to Unityper.jl such as: +- SumTypes.jl allows [parametric types](https://docs.julialang.org/en/v1/manual/types/#Parametric-Types) for much greater container flexibility. +- SumTypes.jl does not require default values for every field of the struct. - SumTypes.jl's `@cases` macro is more powerful and flexible than Unityper's `@compactified`. - SumTypes.jl allows you to hide its variants from the namespace (opt in). 
-Whereas some advantages of Unityper.jl are: -- A `@compactified` type from Unityper.jl will often have a smaller memory footprint than a corresponding type from SumTypes.jl -- If we had used `D(;common_field=1, b="hi")` in our benchmarks, SumTypes.jl could have incurred an allocation whereas Unitypeper.jl would not. This allocation is due to the compiler heuristics involved in `::Union{T, Nothing}` fields of structs and may be fixed in future versions of julia. +One advantage of Unityper.jl is: +- Because Unityper.jl doesn't allow parameterized types and needs to know all type information at macroexpansion time, their structs have a fixed layout for boxed variables that lets them avoid an allocation when storing heap allocated objects (this allocation would be in addition to the heap allocation for the object itself). If we had used `D(;common_field=1, b="hi")` in our benchmarks, SumTypes.jl could have incurred an allocation whereas Unityper.jl would not. As far as I know, this would requre https://github.com/JuliaLang/julia/issues/8472 in order to avoid in SumTypes.jl diff --git a/src/SumTypes.jl b/src/SumTypes.jl index 1abad69..605b9c4 100644 --- a/src/SumTypes.jl +++ b/src/SumTypes.jl @@ -1,6 +1,6 @@ module SumTypes -export @sum_type, @cases, Uninit +export @sum_type, @cases, Uninit, full_type using MacroTools: MacroTools @@ -8,6 +8,7 @@ function parent end function constructors end function constructor end function constructors_Union end +function variants_Tuple end function unwrap end function tags end function deparameterize end @@ -16,6 +17,9 @@ function flagtype end function flag_to_symbol end function symbol_to_flag end function tags_flags_nt end +function variants_Tuple end +function strip_size_params end +function full_type end struct Unsafe end @@ -26,8 +30,6 @@ struct Uninit end struct Variant{fieldnames, Tup <: Tuple} data::Tup Variant{fieldnames, Tup}(::Unsafe) where {fieldnames, Tup} = new{fieldnames, Tup}() - # Variant(::Unsafe, 
nt::NamedTuple{names, Tup}) where {names, Tup} = new{fieldnames, Tup}(Tuple(nt)) - # Variant{fieldnames}(t::Tup) where {fieldnames, Tup <: Tuple} = new{fieldnames, Tup}(t) Variant{fieldnames, Tup}(t::Tuple) where {fieldnames, Tup <: Tuple} = new{fieldnames, Tup}(t) end Base.:(==)(v1::Variant, v2::Variant) = v1.data == v2.data @@ -37,20 +39,27 @@ Base.indexed_iterate(x::Variant, i::Int, state=1) = (Base.@_inline_meta; (getfie const tag = Symbol("#tag#") get_tag(x) = getfield(x, tag) -get_tag_sym(x::T) where {T} = keys(tags_flags_nt(T))[Int(get_tag(x))] +get_tag_sym(x::T) where {T} = keys(tags_flags_nt(T))[Int(get_tag(x)) + 1] show_sumtype(io::IO, m::MIME, x) = show_sumtype(io, x) function show_sumtype(io::IO, x::T) where {T} tag = get_tag(x) sym = flag_to_symbol(T, tag) - if getfield(x, sym) isa Variant{(), Tuple{}} - print(io, String(sym), "::", typeof(x)) + T_stripped = if length(T.parameters) == 2 + String(T.name.name) else - print(io, String(sym), '(', join((repr(data) for data ∈ getfield(x, sym)), ", "), ")::", typeof(x)) + string(String(T.name.name), "{", join(repr.(T.parameters[1:end-2]), ", "), "}") + end + if unwrap(x) isa Variant{(), Tuple{}} + print(io, String(sym), "::", T_stripped) + else + print(io, String(sym), '(', join((repr(data) for data ∈ unwrap(x)), ", "), ")::", T_stripped) end end +include("compute_storage.jl") include("sum_type.jl") # @sum_type defined here include("cases.jl") # @cases defined here + end # module diff --git a/src/cases.jl b/src/cases.jl index f140ea6..0b61fbc 100644 --- a/src/cases.jl +++ b/src/cases.jl @@ -58,7 +58,7 @@ macro cases(to_match, block) ex = :(if $get_tag($data) === $symbol_to_flag($Typ, $(QuoteNode(stmts[1].variant))); $(stmts[1].iscall ? 
:(($(stmts[1].fieldnames...),) = - $getfield($data, $(QuoteNode(stmts[1].variant))) :: $constructor($Typ, $Val{$(QuoteNode(stmts[1].variant))} ) ) : nothing); + $unwrap($data, $constructor($Typ, $Val{$(QuoteNode(stmts[1].variant))}), $variants_Tuple($Typ)) ) : nothing); $(stmts[1].rhs) end) Base.remove_linenums!(ex) @@ -67,7 +67,7 @@ macro cases(to_match, block) for i ∈ 2:length(stmts) _if = :(if $get_tag($data) === $symbol_to_flag($Typ, $(QuoteNode(stmts[i].variant))); $(stmts[i].iscall ? :(($(stmts[i].fieldnames...),) = - $getfield($data, $(QuoteNode(stmts[i].variant))) :: $constructor($Typ, $Val{$(QuoteNode(stmts[i].variant))} )) : nothing); + $unwrap($data, $constructor($Typ, $Val{$(QuoteNode(stmts[i].variant))}), $variants_Tuple($Typ))) : nothing); $(stmts[i].rhs) end) _if.head = :elseif @@ -82,7 +82,6 @@ macro cases(to_match, block) let $data = $to_match $Typ = $typeof($data) $check_sum_type($Typ) - # $nt = $tags_flags_nt($Typ) $assert_exhaustive(Val{$tags($Typ)}, Val{$(Expr(:tuple, QuoteNode.(deparameterize.(variants))...))}) $ex end diff --git a/src/compute_storage.jl b/src/compute_storage.jl new file mode 100644 index 0000000..5e360c2 --- /dev/null +++ b/src/compute_storage.jl @@ -0,0 +1,123 @@ +struct PlaceHolder end + +macro assume_effects(args...) 
+ if isdefined(Base, Symbol("@assume_effects")) + ex = :($Base.@assume_effects($(args...))) + else + ex = args[end] + end + esc(ex) +end + +@assume_effects :consistent :foldable function unsafe_padded_reinterpret(::Type{T}, x::U) where {T, U} + @assert isbitstype(T) && isbitstype(U) + n, m = sizeof(T), sizeof(U) + if sizeof(U) < sizeof(T) + payload = (x, ntuple(_ -> zero(UInt8), Val(n-m)), ) + else + payload = x + end + let r = Ref(payload) + GC.@preserve r begin + p = pointer_from_objref(r) + unsafe_load(Ptr{T}(p)) + end + end +end + +function extract_info(::Type{ST}, variants) where {ST} + + data = map(variants) do variant + (names, store_types) = variant.parameters + bits = [] + ptrs = [] + @assert length(names) == length(store_types.parameters) + foreach(zip(names, store_types.parameters)) do (name, T) + if isbitstype(T) + push!(bits, name => T) + else + push!(bits, name => SumTypes.PlaceHolder) + push!(ptrs, name => T) + end + end + bits, ptrs + end + bitss = map(x -> x[1], data) + ptrss = map(x -> x[2], data) + nptrs = maximum(length, ptrss) + ptr_names = map(v -> map(x -> x[1], v), ptrss) + bit_names = map(v -> map(x -> x[1], v), bitss) + bit_sigs = map(v -> map(x -> x[2], v), bitss) + + FT = fieldtype(ST, 3) + bit_size = if nptrs == 0 + maximum(v -> sizeof(Tuple{map(x -> x[2], v)..., FT}), bitss) - sizeof(FT) + else + maximum(v -> sizeof(Tuple{map(x -> x[2], v)..., }), bitss) + end + + (; + bitss = bitss, + ptrss = ptrss, + nptrs = nptrs, + ptr_names = ptr_names, + bit_size = bit_size, + bit_names = bit_names, + bit_sigs = bit_sigs, + ) +end + + +make(::Type{ST}, to_make, tag) where {ST} = make(ST, to_make, tag, variants_Tuple(ST)) +@generated function make(::Type{ST}, to_make::Var, tag, ::Type{var_Tuple}) where {ST, Var <: Variant, var_Tuple <: Tuple} + variants = var_Tuple.parameters + i = findfirst(==(Var), variants) + nt = extract_info(ST, variants) + + nptrs = nt.nptrs + ptr_names = nt.ptr_names + bit_size = nt.bit_size + bit_names = nt.bit_names + 
bit_sigs = nt.bit_sigs + + bitvariant = :(SumTypes.Variant{($(QuoteNode.(bit_names[i])...),), Tuple{$(bit_sigs[i]...)}}( + ($(([bit_sigs[i][j] == PlaceHolder ? PlaceHolder() : :(to_make.data[$j]) for j ∈ eachindex(bit_sigs[i]) ])...),) )) + ptr_args = [:(to_make.data[$j]) for j ∈ eachindex(bit_names[i]) if bit_names[i][j] ∈ ptr_names[i]] + con = Expr( + :new, + ST{bit_size, nptrs}, + :(unsafe_padded_reinterpret(NTuple{$bit_size, UInt8}, $bitvariant)), + Expr(:tuple, ptr_args..., (nothing for _ ∈ 1:(nptrs-length(ptr_args)))...), + :tag, + ) +end + + + +unwrap(x::ST, var) where {ST} = unwrap(x, var, variants_Tuple(ST)) +@generated function unwrap(x::ST, ::Type{Var}, ::Type{var_Tuple}) where {ST, Var, var_Tuple} + variants = var_Tuple.parameters + i = findfirst(==(Var), variants) + nt = extract_info(ST, variants) + ptrss = nt.ptrss + nptrs = nt.nptrs + ptr_names = nt.ptr_names + bit_size = nt.bit_size + bit_names = nt.bit_names + bit_sigs = nt.bit_sigs + quote + names = ($(QuoteNode.(bit_names[i])...),) + bits = unsafe_padded_reinterpret(Variant{names, Tuple{$(bit_sigs[i]...)}}, x.bits) + args = $(Expr(:tuple, + (bit_names[i][j] ∈ ptr_names[i] ? 
let k = findfirst(x -> x == bit_names[i][j], ptr_names[i]) + :(x.ptrs[$k]:: $(ptrss[i][k][2])) + end : :(bits.data[$j]) for j ∈ eachindex(bit_names[i]))...)) + Variant{names, $(Var.parameters[2])}(args) + end +end + +Base.@generated function full_type(::Type{ST}, ::Type{var_Tuple}) where {ST, var_Tuple} + variants = var_Tuple.parameters + nt = extract_info(ST, variants) + :($ST{$(nt.bit_size), $(nt.nptrs)}) +end diff --git a/src/sum_type.jl b/src/sum_type.jl index 9679c28..1401a80 100644 --- a/src/sum_type.jl +++ b/src/sum_type.jl @@ -1,5 +1,9 @@ macro sum_type(T, blk, _hide_variants=:(hide_variants = false)) + esc(_sum_type(T, blk, _hide_variants)) +end + +function _sum_type(T, blk, _hide_variants=:(hide_variants = false)) if _hide_variants isa Expr && _hide_variants.head == :(=) && _hide_variants.args[1] == :hide_variants hide_variants = _hide_variants.args[2] else @@ -7,28 +11,28 @@ macro sum_type(T, blk, _hide_variants=:(hide_variants = false)) end @assert blk isa Expr && blk.head == :block - T_name, T_params, T_params_constrained = if T isa Symbol - T, [], [] + T_name, T_params, T_params_constrained, T_param_bounds = if T isa Symbol + T, [], [], [] elseif T isa Expr && T.head == :curly - T.args[1], (x -> x isa Expr && x.head == :(<:) ? x.args[1] : x).(T.args[2:end]), T.args[2:end] + T.args[1], (x -> x isa Expr && x.head == :(<:) ? x.args[1] : x).(T.args[2:end]), T.args[2:end], (x -> x isa Expr && x.head == :(<:) ? x.args[2] : Any).(T.args[2:end]) end T_nameparam = isempty(T_params) ? 
T : :($T_name{$(T_params...)}) filter!(x -> !(x isa LineNumberNode), blk.args) - constructors = generate_constructor_data(T_name, T_params, T_params_constrained, T_nameparam, hide_variants, hash(__module__), blk) + constructors = generate_constructor_data(T_name, T_params, T_params_constrained, T_nameparam, hide_variants, blk) if !allunique(map(x -> x.name, constructors)) error("constructors must have unique names, got $(map(x -> x.name, constructors))") end con_expr = generate_constructor_exprs(T_name, T_params, T_params_constrained, T_nameparam, constructors) - out = generate_sum_struct_expr(T, T_name, T_params, T_params_constrained, T_nameparam, constructors) - Expr(:toplevel, out, con_expr) |> esc + out = generate_sum_struct_expr(T, T_name, T_params, T_params_constrained, T_param_bounds, T_nameparam, constructors) + Expr(:toplevel, out, con_expr) end #------------------------------------------------------ -function generate_constructor_data(T_name, T_params, T_params_constrained, T_nameparam, hide_variants, hsh, blk::Expr) +function generate_constructor_data(T_name, T_params, T_params_constrained, T_nameparam, hide_variants, blk::Expr) constructors = [] for con_ ∈ blk.args con_ isa LineNumberNode && continue @@ -56,14 +60,11 @@ function generate_constructor_data(T_name, T_params, T_params_constrained, T_nam push!(constructors, nt) else con::Expr = con_ - if con.head != :call - error("Malformed variant $con_") - end + con.head == :call || throw(ArgumentError("Malformed variant $con_")) con_name = con.args[1] isa Expr && con.args[1].head == :curly ? con.args[1].args[1] : con.args[1] con_params = (con.args[1] isa Expr && con.args[1].head == :curly) ? 
con.args[1].args[2:end] : [] - if !issubset(con_params, T_params) + issubset(con_params, T_params) || error("constructor parameters ($con_params) for $con_name, not a subset of sum type parameters $T_params") - end con_params_uninit = let v = copy(con_params) for i ∈ eachindex(T_params) if T_params[i] ∉ con_params @@ -84,6 +85,8 @@ function generate_constructor_data(T_name, T_params, T_params_constrained, T_nam field.args[1] end end + unique(con_field_names) == con_field_names || error("constructor field names must be unique, got $(con_field_names) for constructor $con_name") + con_field_types = map(con.args[2:end]) do field @assert field isa Symbol || (field isa Expr && field.head == :(::)) "malformed constructor field $field" if field isa Symbol @@ -151,38 +154,22 @@ function generate_constructor_exprs(T_name, T_params, T_params_constrained, T_na T_uninit = isempty(T_params) ? T_name : :($T_name{$(params_uninit...)}) T_init = isempty(T_params) ? T_name : :($T_name{$(T_params...)}) if value - T_con_fields = map(constructors) do nt - if nt.value - :($(nt.store_type_uninit)($unsafe)) - else - nothing - end - end ex = quote - const $gname = $(Expr(:new, T_uninit, T_con_fields..., Expr(:call, symbol_to_flag, T_name, QuoteNode(name)) )) - end + const $gname = $(Expr(:call, make, T_uninit, :($(nt.store_type_uninit)($unsafe)), Expr(:call, symbol_to_flag, T_name, QuoteNode(name)) )) + end push!(out.args, ex) else field_names_typed = map(((name, type),) -> :($name :: $type), zip(field_names, field_types)) - - T_con_fields = map(constructors) do nt#(_name, _, _nameparam, _, _, _, _, value, _gname, _gnameparam) - - default = nt.value ? :($(nt.store_type_uninit)($unsafe)) : nothing - name == nt.name ? 
:($store_type(($(field_names...),))) : default - end T_con = :($gouter_type($(field_names_typed...)) where {$(params_constrained...)} = - $(Expr(:new, T_uninit, T_con_fields..., Expr(:call, symbol_to_flag, T_name, QuoteNode(name)) ) )) + $(Expr(:call, make, T_uninit, :($store_type(($(field_names...),))), Expr(:call, symbol_to_flag, T_name, QuoteNode(name)) ))) - T_con_fields2 = map(constructors) do nt - default = nt.value ? :($(nt.store_type_uninit)($unsafe)) : nothing + T_con2 = if !all(x -> x ∈ (Any, :Any) ,field_types) s = Expr(:call, store_type, Expr(:tuple, [:($convert($field_type, $field_name)) for (field_type, field_name) ∈ zip(field_types, field_names)]...)) - nt.name == name ? s : default - end - T_con2 = if !all(x -> x ∈ (Any, :Any) ,field_types) + :($gouter_type($(field_names...)) where {$(params_constrained...)} = - $(Expr(:new, T_uninit, T_con_fields2..., Expr(:call, symbol_to_flag, T_name, QuoteNode(name)) ))) + $(Expr(:call, make, T_uninit, s, Expr(:call, symbol_to_flag, T_name, QuoteNode(name))))) end maybe_no_param = if !isempty(params) :($gname($(field_names_typed...)) where {$(params...)} = $gouter_type($(field_names...))) @@ -198,19 +185,11 @@ function generate_constructor_exprs(T_name, T_params, T_params_constrained, T_na push!(out.args, ex) end enumerate_constructors = collect(enumerate(constructors)) - if_nest = mapfoldr(((cond, data), old) -> Expr(:if, cond, data, old), enumerate_constructors, init=:(error("invalid tag"))) do (i , nt) - name = nt.name - data = map(constructors) do nt - default = nt.value ? :($(nt.store_type_uninit)($unsafe)) : nothing - nt.name == name ? 
:($getfield(x, $(QuoteNode(name))) :: $(nt.store_type)) : default - end - :(tag == $i), Expr(:new, T_init, data..., :tag) - end + if true push!(converts, T_uninit => quote - $Base.convert(::Type{$T_init}, x::$T_uninit) where {$(T_params...)} = $(Expr(:block, - :(tag = getfield(x, $(QuoteNode(tag)) )), - if_nest )) + $Base.convert(::Type{$T_init}, x::$T_uninit) where {$(T_params...)} = + $make($T_init, $unwrap(x), $getfield(x, $(QuoteNode(tag)) )) $T_init(x::$T_uninit) where {$(T_params...)} = $convert($T_init, x) end) end @@ -225,68 +204,69 @@ end #------------------------------------------------------ -function generate_sum_struct_expr(T, T_name, T_params, T_params_constrained, T_nameparam, constructors) +function generate_sum_struct_expr(T, T_name, T_params, T_params_constrained, T_param_bounds, T_nameparam, constructors) con_outer_types = (x -> x.outer_type ).(constructors) con_gouter_types = (x -> x.gouter_type).(constructors) con_names = (x -> x.name ).(constructors) con_gnames = (x -> x.gname ).(constructors) - flagtype = length(constructors) <= typemax(UInt8) ? UInt8 : length(constructors) < typemax(UInt16) ? UInt16 : length(constructors) <= typemax(UInt32) ? UInt32 : + flagtype = length(constructors) < typemax(UInt8) ? UInt8 : length(constructors) < typemax(UInt16) ? UInt16 : + length(constructors) <= typemax(UInt32) ? UInt32 : error("Too many variants in SumType, got $(length(constructors)). The current maximum number is $(typemax(UInt32) |> Int)") - - data_fields = map(constructors) do nt - name = nt.name - store_type = nt.store_type - if nt.value - :($name :: $store_type) - else - :($name :: Union{$Nothing, $store_type}) - end - end - - sum_struct_def = Expr(:struct, false, T, Expr(:block, data_fields..., :($tag :: $flagtype), :(1 + 1))) + + N = Symbol("#N#") + M = Symbol("#M#") + T_full = T isa Expr && T.head == :curly ? 
Expr(:curly, T.args..., N, M) : Expr(:curly, T, N, M) + sum_struct_def = Expr(:struct, false, T_full, + Expr(:block, :(bits :: $NTuple{$N, $UInt8}), :(ptrs :: $NTuple{$M, $Any}), :($tag :: $flagtype), :(1 + 1))) enumerate_constructors = collect(enumerate(constructors)) if_nest_unwrap = mapfoldr(((cond, data), old) -> Expr(:if, cond, data, old), enumerate_constructors, init=:(error("invalid tag"))) do (i, nt) - :(tag == $i), :($getfield(x, $(QuoteNode(nt.name)))) + :(tag == $(flagtype(i-1))), :($unwrap(x, $(nt.store_type))) end only_define_with_params = if !isempty(T_params) quote - $SumTypes.constructors(::Type{$T_nameparam}) where {$(T_params...)} = + $SumTypes.constructors(::Type{<:$T_nameparam}) where {$(T_params...)} = $NamedTuple{$tags($T_name)}($(Expr(:tuple, (nt.store_type for nt ∈ constructors)...))) - $Base.adjoint(::Type{$T_nameparam}) where {$(T_params...)} = + $Base.adjoint(::Type{<:$T_nameparam}) where {$(T_params...)} = $NamedTuple{$tags($T_name)}($(Expr(:tuple, (nt.value ? 
:($T_nameparam($(nt.gname))) : nt.gouter_type for nt ∈ constructors)...))) - end + $SumTypes.variants_Tuple(::Type{<:$T_nameparam}) where {$(T_params...)} = + $Tuple{$((nt.store_type for nt ∈ constructors)...)} + $SumTypes.full_type(::Type{$T_name}) = $full_type($T_name{$(T_param_bounds...)}, $variants_Tuple($T_nameparam{$(T_param_bounds...)})) + end end - ex = quote $sum_struct_def $SumTypes.is_sumtype(::Type{<:$T_name}) = true - + $SumTypes.strip_size_params(::Type{$T_name{$(T_params...), $N, $M}}) where {$(T_params...), $N, $M} = $T_nameparam $SumTypes.flagtype(::Type{<:$T_name}) = $flagtype $SumTypes.symbol_to_flag(::Type{<:$T_name}, sym::Symbol) = $(foldr(collect(enumerate(con_names)), init=:(error("Invalid tag symbol $sym"))) do (i, _sym), old - Expr(:if, :(sym == $(QuoteNode(_sym))), flagtype(i), old) + Expr(:if, :(sym == $(QuoteNode(_sym))), flagtype(i-1), old) end) $SumTypes.flag_to_symbol(::Type{<:$T_name}, flag::$flagtype) = $(foldr(collect(enumerate(con_names)), init=:(error("Invalid tag symbol $sym"))) do (i, sym), old - Expr(:if, :(flag == $i), QuoteNode(sym), old) + Expr(:if, :(flag == $(i-1)), QuoteNode(sym), old) end) $SumTypes.tags_flags_nt(::Type{<:$T_name}) = $(Expr(:tuple, Expr(:parameters, (Expr(:kw, name, flagtype(i)) for (i, name) ∈ enumerate(con_names))...))) $SumTypes.tags(::Type{<:$T_name}) = $(Expr(:tuple, map(x -> QuoteNode(x.name), constructors)...)) - $SumTypes.constructors(::Type{$T_name}) = + $SumTypes.constructors(::Type{<:$T_name}) = $NamedTuple{$tags($T_name)}($(Expr(:tuple, (nt.store_type_uninit for nt ∈ constructors)...))) - $SumTypes.unwrap(x::$T_name) = let tag = $get_tag(x) + $SumTypes.variants_Tuple(::Type{<:$T_name}) = + $Tuple{$((nt.store_type_uninit for nt ∈ constructors)...)} + + $SumTypes.unwrap(x::$T_nameparam) where {$(T_params...)}= let tag = $get_tag(x) $if_nest_unwrap end - $Base.adjoint(::Type{$T_name}) = + $Base.adjoint(::Type{<:$T_name}) = $NamedTuple{$tags($T_name)}($(Expr(:tuple, (nt.gname for nt ∈ 
constructors)...))) - + $SumTypes.full_type(::Type{$T_nameparam}) where {$(T_params...)} = $full_type($T_nameparam, $variants_Tuple($T_nameparam)) + $Base.show(io::IO, x::$T_name) = $show_sumtype(io, x) $Base.show(io::IO, m::MIME"text/plain", x::$T_name) = $show_sumtype(io, m, x) @@ -294,10 +274,9 @@ function generate_sum_struct_expr(T, T_name, T_params, T_params_constrained, T_n $only_define_with_params end foreach(constructors) do nt - - con1 = :($SumTypes.constructor(::Type{$T_name}, ::Type{Val{$(QuoteNode(nt.name))}}) = $(nt.store_type_uninit)) + con1 = :($SumTypes.constructor(::Type{<:$T_name}, ::Type{Val{$(QuoteNode(nt.name))}}) = $(nt.store_type_uninit)) con2 = if !isempty(T_params) - :($SumTypes.constructor(::Type{$T_nameparam}, ::Type{Val{$(QuoteNode(nt.name))}}) where {$(T_params...)} = $(nt.store_type)) + :($SumTypes.constructor(::Type{<:$T_nameparam}, ::Type{Val{$(QuoteNode(nt.name))}}) where {$(T_params...)} = $(nt.store_type)) end push!(ex.args, con1, con2) end diff --git a/test/runtests.jl b/test/runtests.jl index 95b21b8..452bbdb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,18 +56,26 @@ end Right(x) => x end)) - @test_throws Exception macroexpand(@__MODULE__(), - :(@sum_type Blah begin - duplicate_field - duplicate_field - end)) - + @test_throws Exception SumTypes._sum_type( + :Blah, :(begin + duplicate_field + duplicate_field + end)) - @test_throws Exception macroexpand(@__MODULE__(), - :(@sum_type Blah begin - duplicate_field - duplicate_field - end some_option=false)) + @test_throws Exception SumTypes._sum_type( + :Blah, :(begin + duplicate_field + end), :(some_option=false)) + + @test_throws Exception SumTypes._sum_type( + :Blah, :(begin + x * field^2 -1 + end)) + + @test_throws Exception SumTypes._sum_type( + :(Blah{T}), :(begin + foo{U}(::U) + end )) let x = Left([1]), y = Left([1.0]), z = Right([1]) @test x == y @@ -78,6 +86,10 @@ end @test_throws MethodError Left{Int}("hi") @test_throws MethodError Right{String}(1) @test 
Left{Int}(0x01) === Left{Int}(1) + + @test full_type(Either{Nothing, Nothing}) == Either{Nothing, Nothing, 0, 0} + @test full_type(Either{Int, Int}) == Either{Int, Int, 15, 0} + @test full_type(Either{Int, String}) == Either{Int, String, 8, 1} end #--------------------------------------------------------