Skip to content

Commit

Permalink
Compactify storage (#20)
Browse files Browse the repository at this point in the history
* implement compactification

* add link to top of README

* try to resolve printing problems on early julia versions

* fix alignment

* cleanup coverage

* more tests

* put ptrs at the start for better alignment

* move macro body to function for better testing

* whoops, that killed performance

* pad smarter

* rely on method inlining for unwrap

* make this v0.4.0 since the extra parameters may be a breaking change
  • Loading branch information
MasonProtter authored Apr 3, 2023
1 parent b0423ba commit 297b2ac
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 122 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SumTypes"
uuid = "8e1ec7a9-0e02-4297-b0fe-6433085c89f2"
authors = ["MasonProtter <[email protected]>"]
version = "0.3.8"
version = "0.4.0"

[deps]
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand Down
91 changes: 65 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- [Basics](https://github.com/MasonProtter/SumTypes.jl#basics)
- [Destructuring sum types](https://github.com/MasonProtter/SumTypes.jl#destructuring-sum-types)
- [Using `full_type` to get the concrete type of a Sum Type](https://github.com/MasonProtter/SumTypes.jl/tree/compute-storage#using-full_type-to-get-the-concrete-type-of-a-sum-type)
- [Avoiding namespace clutter](https://github.com/MasonProtter/SumTypes.jl#avoiding-namespace-clutter)
- [Custom printing](https://github.com/MasonProtter/SumTypes.jl#custom-printing)
- [Performance](https://github.com/MasonProtter/SumTypes.jl#performance)
Expand Down Expand Up @@ -162,6 +163,45 @@ The `@cases` macro still falls far short of a full on pattern matching system, l

<!-- </details> -->

## Using `full_type` to get the concrete type of a Sum Type

<details>
<summary>Click to expand</summary>

SumTypes.jl generates structs with a compactified memory layout which is computed on demand for parametric types. Because of this,
every SumTypes actually has two extra type parameters related to its memory layout. This means that for instance, `Either{Int, Int}`:

``` julia
julia> @sum_type Either{A, B} begin
Left{A}(::A)
Right{B}(::B)
end

julia> isconcretetype(Either{Int, Int})
false
```

In order to get the proper, concrete type corresponding to `Either{Int, Int}`, one can just use the `full_type` function exported by SumTypes.jl:

``` julia
julia> full_type(Either{Int, Int})
Either{Int64, Int64, 8, 0}

julia> full_type(Either{Int, String})
Either{Int64, String, 8, 1}

julia> full_type(Either{Tuple{Int, Int, Int}, String})
Either{Tuple{Int64, Int64, Int64}, String, 24, 1}

julia> isconcretetype(ans)
true
```

Avoiding these extra parameters would require https://github.com/JuliaLang/julia/issues/8472 to be implemented.

</details>


## Avoiding namespace clutter

<details>
Expand Down Expand Up @@ -307,15 +347,15 @@ end

```
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 393.061 μs … 6.725 ms ┊ GC (min … max): 0.00% … 90.92%
Time (median): 434.257 μs ┊ GC (median): 0.00%
Time (mean ± σ): 483.461 μs ± 435.758 μs ┊ GC (mean ± σ): 9.38% ± 9.39%
Range (min … max): 267.399 μs … 3.118 ms ┊ GC (min … max): 0.00% … 90.36%
Time (median): 278.904 μs ┊ GC (median): 0.00%
Time (mean ± σ): 316.971 μs ± 306.290 μs ┊ GC (mean ± σ): 11.68% ± 10.74%
▅▃▁
████▆▆▃▃▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▃▅
393 μs Histogram: log(frequency) by time 4.2 ms <
▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▇
267 μs Histogram: log(frequency) by time 2.77 ms <
Memory estimate: 654.16 KiB, allocs estimate: 21950.
Memory estimate: 654.75 KiB, allocs estimate: 21952.
```

SumTypes.jl
Expand Down Expand Up @@ -358,13 +398,13 @@ end

```
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 61.309 μs … 83.300 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 62.350 μs ┊ GC (median): 0.00%
Time (mean ± σ): 62.376 μs ± 528.152 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
Range (min … max): 54.890 μs … 73.650 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 55.750 μs ┊ GC (median): 0.00%
Time (mean ± σ): 55.908 μs ± 655.652 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▃█▂ ▁▄▃▂
▁▁▁▁▁▁▁▁▂▁▂▃▅▅▇███▆▄▃▃▄▄▇████▅▄▃▂▂▂▂▂▂▁▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
61.3 μs Histogram: frequency by time 64 μs <
▁▄▇██▇▆▅▄ ▂▁
▁▁▁▁▂▃▄▇████████████▇▆▅▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
54.9 μs Histogram: frequency by time 58.4 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
```
Expand Down Expand Up @@ -423,25 +463,24 @@ end

```
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 69.355 μs … 234.343 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 75.388 μs ┊ GC (median): 0.00%
Time (mean ± σ): 77.753 μs ± 13.757 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
Range (min … max): 54.470 μs … 67.920 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 55.640 μs ┊ GC (median): 0.00%
Time (mean ± σ): 55.692 μs ± 498.787 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
█▃▄▃▇▆▆▃▄
██████████▇▇█▇▇█████▆▆▆▆▆▆▇▆▅▇▆▆▅▄▆▇▆▆▆▅▇▆▅▆▅▄▅▄▄▅▄▆▅▅▅▅▅▅▅▅ █
69.4 μs Histogram: log(frequency) by time 149 μs <
▁▂▄▅▆▆▇▇▇█▅▅▃▂▂
▁▁▁▁▁▁▂▃▄▃▄▅▆▇▇███████████████▇▆▆▅▄▃▃▂▂▂▂▂▂▂▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁ ▄
54.5 μs Histogram: frequency by time 57.5 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
```

SumTypes.jl is able to slightly beat Unityper.jl in this benckmark, though there are cases where the roles are reversed.
SumTypes.jl has some other advantages relative to Unityper.jl too, such as:
- SumTypes.jl allows [parametric types](https://docs.julialang.org/en/v1/manual/types/#Parametric-Types) for much greater container flexibility (Unityper does some memory layout optimizations that won't work with parametric types).
- SumTypes.jl does not require default values for every field of the struct
SumTypes.jl and Unityper.jl are about equal in this benchmark, though there are cases where there are differences.
SumTypes.jl has some other advantages relative to Unityper.jl such as:
- SumTypes.jl allows [parametric types](https://docs.julialang.org/en/v1/manual/types/#Parametric-Types) for much greater container flexibility.
- SumTypes.jl does not require default values for every field of the struct.
- SumTypes.jl's `@cases` macro is more powerful and flexible than Unityper's `@compactified`.
- SumTypes.jl allows you to hide its variants from the namespace (opt in).

Whereas some advantages of Unityper.jl are:
- A `@compactified` type from Unityper.jl will often have a smaller memory footprint than a corresponding type from SumTypes.jl
- If we had used `D(;common_field=1, b="hi")` in our benchmarks, SumTypes.jl could have incurred an allocation whereas Unitypeper.jl would not. This allocation is due to the compiler heuristics involved in `::Union{T, Nothing}` fields of structs and may be fixed in future versions of julia.
One advantage of Unityper.jl is:
- Because Unityper.jl doesn't allow parameterized types and needs to know all type information at macroexpansion time, their structs have a fixed layout for boxed variables that lets them avoid an allocation when storing heap allocated objects (this allocation would be in addition to the heap allocation for the object itself). If we had used `D(;common_field=1, b="hi")` in our benchmarks, SumTypes.jl could have incurred an allocation whereas Unityper.jl would not. As far as I know, this would requre https://github.com/JuliaLang/julia/issues/8472 in order to avoid in SumTypes.jl

23 changes: 16 additions & 7 deletions src/SumTypes.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
module SumTypes

export @sum_type, @cases, Uninit
export @sum_type, @cases, Uninit, full_type

using MacroTools: MacroTools

function parent end
function constructors end
function constructor end
function constructors_Union end
function variants_Tuple end
function unwrap end
function tags end
function deparameterize end
Expand All @@ -16,6 +17,9 @@ function flagtype end
function flag_to_symbol end
function symbol_to_flag end
function tags_flags_nt end
function variants_Tuple end
function strip_size_params end
function full_type end


struct Unsafe end
Expand All @@ -26,8 +30,6 @@ struct Uninit end
struct Variant{fieldnames, Tup <: Tuple}
data::Tup
Variant{fieldnames, Tup}(::Unsafe) where {fieldnames, Tup} = new{fieldnames, Tup}()
# Variant(::Unsafe, nt::NamedTuple{names, Tup}) where {names, Tup} = new{fieldnames, Tup}(Tuple(nt))
# Variant{fieldnames}(t::Tup) where {fieldnames, Tup <: Tuple} = new{fieldnames, Tup}(t)
Variant{fieldnames, Tup}(t::Tuple) where {fieldnames, Tup <: Tuple} = new{fieldnames, Tup}(t)
end
Base.:(==)(v1::Variant, v2::Variant) = v1.data == v2.data
Expand All @@ -37,20 +39,27 @@ Base.indexed_iterate(x::Variant, i::Int, state=1) = (Base.@_inline_meta; (getfie

const tag = Symbol("#tag#")
get_tag(x) = getfield(x, tag)
get_tag_sym(x::T) where {T} = keys(tags_flags_nt(T))[Int(get_tag(x))]
get_tag_sym(x::T) where {T} = keys(tags_flags_nt(T))[Int(get_tag(x)) + 1]

show_sumtype(io::IO, m::MIME, x) = show_sumtype(io, x)
function show_sumtype(io::IO, x::T) where {T}
tag = get_tag(x)
sym = flag_to_symbol(T, tag)
if getfield(x, sym) isa Variant{(), Tuple{}}
print(io, String(sym), "::", typeof(x))
T_stripped = if length(T.parameters) == 2
String(T.name.name)
else
print(io, String(sym), '(', join((repr(data) for data getfield(x, sym)), ", "), ")::", typeof(x))
string(String(T.name.name), "{", join(repr.(T.parameters[1:end-2]), ", "), "}")
end
if unwrap(x) isa Variant{(), Tuple{}}
print(io, String(sym), "::", T_stripped)
else
print(io, String(sym), '(', join((repr(data) for data unwrap(x)), ", "), ")::", T_stripped)
end
end

include("compute_storage.jl")
include("sum_type.jl") # @sum_type defined here
include("cases.jl") # @cases defined here


end # module
5 changes: 2 additions & 3 deletions src/cases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ macro cases(to_match, block)

ex = :(if $get_tag($data) === $symbol_to_flag($Typ, $(QuoteNode(stmts[1].variant)));
$(stmts[1].iscall ? :(($(stmts[1].fieldnames...),) =
$getfield($data, $(QuoteNode(stmts[1].variant))) :: $constructor($Typ, $Val{$(QuoteNode(stmts[1].variant))} ) ) : nothing);
$unwrap($data, $constructor($Typ, $Val{$(QuoteNode(stmts[1].variant))}), $variants_Tuple($Typ)) ) : nothing);
$(stmts[1].rhs)
end)
Base.remove_linenums!(ex)
Expand All @@ -67,7 +67,7 @@ macro cases(to_match, block)
for i 2:length(stmts)
_if = :(if $get_tag($data) === $symbol_to_flag($Typ, $(QuoteNode(stmts[i].variant)));
$(stmts[i].iscall ? :(($(stmts[i].fieldnames...),) =
$getfield($data, $(QuoteNode(stmts[i].variant))) :: $constructor($Typ, $Val{$(QuoteNode(stmts[i].variant))} )) : nothing);
$unwrap($data, $constructor($Typ, $Val{$(QuoteNode(stmts[i].variant))}), $variants_Tuple($Typ))) : nothing);
$(stmts[i].rhs)
end)
_if.head = :elseif
Expand All @@ -82,7 +82,6 @@ macro cases(to_match, block)
let $data = $to_match
$Typ = $typeof($data)
$check_sum_type($Typ)
# $nt = $tags_flags_nt($Typ)
$assert_exhaustive(Val{$tags($Typ)}, Val{$(Expr(:tuple, QuoteNode.(deparameterize.(variants))...))})
$ex
end
Expand Down
123 changes: 123 additions & 0 deletions src/compute_storage.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
struct PlaceHolder end

macro assume_effects(args...)
if isdefined(Base, Symbol("@assume_effects"))
ex = :($Base.@assume_effects($(args...)))
else
ex = args[end]
end
esc(ex)
end

@assume_effects :consistent :foldable function unsafe_padded_reinterpret(::Type{T}, x::U) where {T, U}
@assert isbitstype(T) && isbitstype(U)
n, m = sizeof(T), sizeof(U)
if sizeof(U) < sizeof(T)
payload = (x, ntuple(_ -> zero(UInt8), Val(n-m)), )
else
payload = x
end
let r = Ref(payload)
GC.@preserve r begin
p = pointer_from_objref(r)
unsafe_load(Ptr{T}(p))
end
end
end

function extract_info(::Type{ST}, variants) where {ST}

data = map(variants) do variant
(names, store_types) = variant.parameters
bits = []
ptrs = []
@assert length(names) == length(store_types.parameters)
foreach(zip(names, store_types.parameters)) do (name, T)
if isbitstype(T)
push!(bits, name => T)
else
push!(bits, name => SumTypes.PlaceHolder)
push!(ptrs, name => T)
end
end
bits, ptrs
end
bitss = map(x -> x[1], data)
ptrss = map(x -> x[2], data)
nptrs = maximum(length, ptrss)
ptr_names = map(v -> map(x -> x[1], v), ptrss)
bit_names = map(v -> map(x -> x[1], v), bitss)
bit_sigs = map(v -> map(x -> x[2], v), bitss)

FT = fieldtype(ST, 3)
bit_size = if nptrs == 0
maximum(v -> sizeof(Tuple{map(x -> x[2], v)..., FT}), bitss) - sizeof(FT)
else
maximum(v -> sizeof(Tuple{map(x -> x[2], v)..., }), bitss)
end

(;
bitss = bitss,
ptrss = ptrss,
nptrs = nptrs,
ptr_names = ptr_names,
bit_size = bit_size,
bit_names = bit_names,
bit_sigs = bit_sigs,
)
end


make(::Type{ST}, to_make, tag) where {ST} = make(ST, to_make, tag, variants_Tuple(ST))
@generated function make(::Type{ST}, to_make::Var, tag, ::Type{var_Tuple}) where {ST, Var <: Variant, var_Tuple <: Tuple}
variants = var_Tuple.parameters
i = findfirst(==(Var), variants)
nt = extract_info(ST, variants)

nptrs = nt.nptrs
ptr_names = nt.ptr_names
bit_size = nt.bit_size
bit_names = nt.bit_names
bit_sigs = nt.bit_sigs

bitvariant = :(SumTypes.Variant{($(QuoteNode.(bit_names[i])...),), Tuple{$(bit_sigs[i]...)}}(
($(([bit_sigs[i][j] == PlaceHolder ? PlaceHolder() : :(to_make.data[$j]) for j eachindex(bit_sigs[i]) ])...),) ))
ptr_args = [:(to_make.data[$j]) for j eachindex(bit_names[i]) if bit_names[i][j] ptr_names[i]]
con = Expr(
:new,
ST{bit_size, nptrs},
:(unsafe_padded_reinterpret(NTuple{$bit_size, UInt8}, $bitvariant)),
Expr(:tuple, ptr_args..., (nothing for _ 1:(nptrs-length(ptr_args)))...),
:tag,
)
end



unwrap(x::ST, var) where {ST} = unwrap(x, var, variants_Tuple(ST))
@generated function unwrap(x::ST, ::Type{Var}, ::Type{var_Tuple}) where {ST, Var, var_Tuple}
variants = var_Tuple.parameters
i = findfirst(==(Var), variants)
nt = extract_info(ST, variants)
ptrss = nt.ptrss
nptrs = nt.nptrs
ptr_names = nt.ptr_names
bit_size = nt.bit_size
bit_names = nt.bit_names
bit_sigs = nt.bit_sigs
quote
names = ($(QuoteNode.(bit_names[i])...),)
bits = unsafe_padded_reinterpret(Variant{names, Tuple{$(bit_sigs[i]...)}}, x.bits)
args = $(Expr(:tuple,
(bit_names[i][j] ptr_names[i] ? let k = findfirst(x -> x == bit_names[i][j], ptr_names[i])
:(x.ptrs[$k]:: $(ptrss[i][k][2]))
end : :(bits.data[$j]) for j eachindex(bit_names[i]))...))
Variant{names, $(Var.parameters[2])}(args)
end
end

Base.@generated function full_type(::Type{ST}, ::Type{var_Tuple}) where {ST, var_Tuple}
variants = var_Tuple.parameters
nt = extract_info(ST, variants)
:($ST{$(nt.bit_size), $(nt.nptrs)})
end
Loading

0 comments on commit 297b2ac

Please sign in to comment.