diff --git a/Project.toml b/Project.toml index 8864eb4..6c4caec 100644 --- a/Project.toml +++ b/Project.toml @@ -5,9 +5,11 @@ version = "0.1.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6" [compat] Adapt = "1" diff --git a/perf/Project.toml b/perf/Project.toml new file mode 100644 index 0000000..3b5ece3 --- /dev/null +++ b/perf/Project.toml @@ -0,0 +1,8 @@ +[deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6" + +[compat] +julia = "1.3" diff --git a/perf/benchmarks.jl b/perf/benchmarks.jl new file mode 100644 index 0000000..9e0f934 --- /dev/null +++ b/perf/benchmarks.jl @@ -0,0 +1,153 @@ +#function sv_ms_assign_loop(N::Int) +# s1 = MatrixShape(Float64, 5,10) +# s2 = VectorShape(Float64, 15) +# s3 = ScalarShape(Float64) +# +# ms = MultiShape(s1 = s1, s2 = s2, s3=s3) +# src = rand!(allocate(ms, N)) +# dst = zeros(ms, N) +# +# bench = @benchmarkable begin +# sv_ms_assign_loop_bench($ms, $src, $dst) +# end teardown=(@assert $src == $dst) +# +# std = @benchmarkable sv_assign_loop_std($src, $dst) +# +# bench, std +#end +# +#function sv_ms_assign_loop_bench(ms, src, dst) +# @uviews src dst @inbounds for i in axes(src, 2) +# s, d = view(src, :, i), view(dst, :, i) +# src, dst = ms(s), ms(d) +# +# dst.s1 .= src.s1 +# dst.s2 .= src.s2 +# dst.s3 = src.s3 +# end +#end +# +#function sv_ms_assign_loop_std(src, dst) +# @uviews src dst @inbounds for i in axes(src, 2) +# s, d = view(src, :, i), view(dst, :, i) +# s .= d +# end +#end + + +function sv_ms_nested_assign() + s1 = MatrixShape(Float64,5,10) + s2 = VectorShape(Float64,5) + s3 = ScalarShape(Float64) + ms1 = MultiShape(s1 = s1, s2 = s2, s3=s3) + + s4 = VectorShape(Float64,3) + s5 = ScalarShape(Float64) + ms2 = MultiShape(s4=s4, s5=s5) + + s6 = ScalarShape(Float64) + ms3 = MultiShape(s6=s6, ms2=ms2) + + ms = MultiShape(ms1=ms1, ms2=ms2, ms3=ms3) + + src = rand!(allocate(ms)) + shaped_src = ms(rand!(allocate(ms))) + dst = zeros(ms) + shaped_dst = ms(zeros(ms)) + + bench = @benchmarkable begin + sv_ms_nested_assign_bench($shaped_dst, $shaped_src) + end teardown=(@assert $shaped_dst == $shaped_src) + + std = @benchmarkable begin + sv_ms_nested_assign_std($dst, $src) + end teardown = (@assert $dst == $src) + + bench, std +end + +function sv_ms_nested_assign_bench(dst, src) + @uviews dst src @inbounds begin + dst.ms1.s1 .= src.ms1.s1 + dst.ms1.s2 .= src.ms1.s2 + dst.ms1.s3 = src.ms1.s3 + + dst.ms2.s4 .= src.ms2.s4 + dst.ms2.s5 = src.ms2.s5 + + dst.ms3.s6 = src.ms3.s6 + dst.ms3.ms2.s4 .= src.ms3.ms2.s4 + dst.ms3.ms2.s5 = src.ms3.ms2.s5 + end +end + +function sv_ms_nested_assign_std(dst, src) + @uviews dst src @inbounds begin + dst[1:50] .= view(src, 1:50) + dst[51:55] .= view(src, 51:55) + dst[56] = src[56] + dst[57:59] .= view(src, 57:59) + dst[60] = src[60] + dst[61] = src[61] + dst[62:64] .= view(src, 62:64) + dst[65] = src[65] + end +end + + +function sv_ms_assign() + s1 = MatrixShape(Float64, 5,10) + s2 = VectorShape(Float64, 15) + s3 = ScalarShape(Float64) + ms = MultiShape(s1 = s1, s2 = s2, s3=s3) + + src = rand!(allocate(ms)) + shaped_src = ms(rand!(allocate(ms))) + dst = zeros(ms) + shaped_dst = ms(zeros(ms)) + + bench = @benchmarkable begin + sv_ms_assign_bench($shaped_dst, $shaped_src) + end teardown=(@assert $shaped_src == $shaped_dst) + + std = @benchmarkable begin + sv_ms_assign_std($dst, $src) + end teardown = @assert $dst == $src + + bench, std +end + +function sv_ms_assign_bench(src, dst) + @uviews dst src @inbounds begin + dst.s1 .= src.s1 + dst.s2 .= src.s2 + dst.s3 = src.s3 + end +end + +function sv_ms_assign_std(dst, src) + @uviews dst src @inbounds begin + dst[1:50] .= view(src, 1:50) + dst[51:65] .= view(src, 51:65) + dst[66] = src[66] + end +end + + +function sv_vs_assign() + s = VectorShape(Float64, 15) + + src = rand!(allocate(s)) + shaped_src = ShapedView(rand!(allocate(s)), s) + dst = zeros(s) + shaped_dst = ShapedView(zeros(s), s) + + bench = @benchmarkable begin + @inbounds $shaped_dst .= $shaped_src + end teardown=(@assert $shaped_src == $shaped_dst) + std = @benchmarkable begin + @inbounds $dst .= $src + end + bench, std +end + diff --git a/perf/run_benchmarks.jl b/perf/run_benchmarks.jl new file mode 100644 index 0000000..8e8ee93 --- /dev/null +++ b/perf/run_benchmarks.jl @@ -0,0 +1,58 @@ +using Random, BenchmarkTools, UnsafeArrays, Statistics +using BenchmarkTools: prettytime, prettypercent + +push!(LOAD_PATH, joinpath(@__DIR__, "..")) +using Shapes + +include("benchmarks.jl") + +function bgroup!(suite_or_group, name, bench, std) + g = suite_or_group[name] = BenchmarkGroup() + g["bench"] = bench + g["std"] = std + g +end + + + +# Define a parent BenchmarkGroup to contain our suite +suite = BenchmarkGroup() + +suite["sv"] = BenchmarkGroup() +bgroup!(suite["sv"], "vs_assign", sv_vs_assign()...) +bgroup!(suite["sv"], "ms_assign", sv_ms_assign()...) +bgroup!(suite["sv"], "ms_nested_assign", sv_ms_nested_assign()...) + +function run_benchmarks() + tune!(suite) + results = run(suite) + summarize_results(results) + results +end + +summarize_results(suite_or_group) = summarize_results(suite_or_group, "") +function summarize_results(suite_or_group, name) + if haskey(suite_or_group, "bench") + println() + println("Benchmark $name") + bench, std = median.(values(suite_or_group)) + if bench.allocs != 0 + @warn "Non-zero allocs in bench: $(bench.allocs)" + end + if std.allocs != 0 + @error "Non-zero allocs in std: $(std.allocs)" + end + bench_t = time(bench) + std_t = time(std) + ratio = (bench_t - std_t) / std_t + + str = "Bench: $(prettytime(bench_t))" + str *= " Std: $(prettytime(std_t))" + str *= " Percent slower: $(prettypercent(ratio))" + println(str) + else + for (k, v) in suite_or_group + summarize_results(v, k) + end + end +end \ No newline at end of file diff --git a/proto/Project.toml b/proto/Project.toml deleted file mode 100644 index e38b02f..0000000 --- a/proto/Project.toml +++ /dev/null @@ -1,19 +0,0 @@ -name = "proto" -uuid = "66c8cbca-5ec6-4091-aa7b-25ec20c0f1c0" -authors = ["colinxs "] -version = "0.1.0" - -[deps] -ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[compat] -StaticArrays = "^0.12" -julia = "1.1.1" - -[extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test"] \ No newline at end of file diff --git a/proto/nested.jl b/proto/nested.jl deleted file mode 100644 index a714e1c..0000000 --- a/proto/nested.jl +++ /dev/null @@ -1,59 +0,0 @@ -#include("../src/Shapes.jl") -using Shapes - -s1 = MatrixShape{5,10,Real}() -s2 = VectorShape{5,Float32}() -s3 = ScalarShape{Int}() -ms1 = MultiShape(s1 = s1, s2 = s2, s3=s3) - -s4 = VectorShape{3,Float32}() -s5 = ScalarShape{Int}() -ms2 = MultiShape(s4=s4, s5=s5) - -s6 = ScalarShape{Float64}() -ms3 = MultiShape(s6=s6, ms2=ms2) - -ms = MultiShape(ms1=ms1, ms2=ms2, ms3=ms3) - -x = collect(1:length(ms)) -sv = ms(x) - -getmultishape(::ShapedView{MS}) where MS = MS - -gindices(ms::MultiShape, shapename::Symbol) = computeidxsimpl(ms, shapename, getproperty(ms, shapename)) -function computeidxsimpl(ms::MultiShape, shapename::Symbol, shape::AbstractShape) - nt = get(ms) - shapeidx = Base.fieldindex(typeof(nt), shapename) - shapes = values(ms) - from = shapeidx == 1 ? 1 : 1 + sum(i -> length(shapes[i]), 1:(shapeidx-1)) - ifelse(shape isa ScalarShape, from, from:(from+length(shape)-1)) -end - -goffset(ms::MultiShape, shapename::Symbol) = goffsetimpl(ms, shapename, getproperty(ms, shapename)) -function goffsetimpl(ms::MultiShape, shapename::Symbol, shape::AbstractShape) - nt = Shapes.get(ms) - shapeidx = Base.fieldindex(typeof(nt), shapename) - shapes = values(ms) - from = shapeidx == 1 ? 1 : 1 + sum(i -> length(shapes[i]), 1:(shapeidx-1)) - return from - #ifelse(shape isa ScalarShape, from, from:(from+length(shape)-1)) -end - - -gidx(x::ShapedView) = getfield(x, :data) - -function Base.getindex(sv::ShapedView, i1, I...) - ms = getmultishape(sv) - data = getfield(sv, :data) - _getindex(data, ms, i1, I...) -end - -function _getindex(data, ms::MultiShape, i1, I...) - shape = getproperty(ms, i1) - offset = goffset(ms, i1) - data = view(data, offset:(offset + length(shape) - 1)) - _getindex(data, shape, I...) -end -_getindex(data, ::AbstractShape) = data - - diff --git a/proto/src/proto.jl b/proto/src/proto.jl deleted file mode 100644 index 3db235b..0000000 --- a/proto/src/proto.jl +++ /dev/null @@ -1,5 +0,0 @@ -module proto - -greet() = print("Hello World!") - -end # module diff --git a/src/ShapedView.jl b/src/ShapedView.jl index 75e2a93..439a899 100644 --- a/src/ShapedView.jl +++ b/src/ShapedView.jl @@ -1,56 +1,137 @@ -struct ShapedView{MS,D<:AbstractVector} +# TODO require that T <: U? +struct ShapedView{T,N,SH,D<:AbstractVector{T}} <: DenseArray{T,N} data::D - Base.@propagate_inbounds function ShapedView( - data::AbstractVector, - multishape::MultiShape, - ) - @boundscheck if length(data) != length(multishape) - error("length of `data` must be equal to the length of `multishape`") - Base.require_one_based_indexing(data) - end - new{multishape,typeof(data)}(data) + offset::Int + Base.@propagate_inbounds function ShapedView{T,N,SH,D}( + data::D, + offset::Int + ) where {T,N,SH,D<:AbstractVector{T}} + @boundscheck check_offset_shape_inbounds(data, offset, SH) + check_has_unit_axes(data) + SH isa AbstractShape || throw(ArgumentError("Type parameter SH must be <: AbstractShape")) + new{T,N,SH,D}(data, offset) end end -(s::MultiShape)(data) = ShapedView(data, s) - -Base.getproperty(sv::ShapedView, name::Symbol) = getproperty(sv, Val(name)) -@generated function Base.getproperty(sv::ShapedView{MS}, ::Val{name}) where {MS,name} - shape = getproperty(MS, name) - idxs = getindices(MS, name) - if shape isa ScalarShape - return quote - @_inline_meta - getfield(sv, :data)[$idxs] - end - elseif shape isa VectorShape - return quote - @_inline_meta - view(getfield(sv, :data), $idxs) - end - else - return quote - @_inline_meta - reshape(view(getfield(sv, :data), $idxs), $(size(shape))) - end - end +function ShapedView{T,N,SH}(data, offset) where {T,N,SH} + ShapedView{T,N,SH,typeof(data)}(data, offset) +end + +function ShapedView(data::AbstractVector, offset::Int, shape::AbstractShape) + ShapedView{eltype(data), ndims(shape), shape, typeof(data)}(data, offset) +end + +function ShapedView(data::AbstractVector, shape::AbstractShape) + ShapedView{eltype(data), ndims(shape), shape, typeof(data)}(data, 0) +end + +ShapedView(data, shape) = ShapedView(data, 0, shape) + +(s::AbstractShape)(data) = ShapedView(data, 0, s) + + +@pure shapeof(::Type{SV}) where {T,N,SH,SV <: ShapedView{T,N,SH}} = SH +@inline shapeof(A::ShapedView) = shapeof(typeof(A)) + +@pure Size(::Type{SV}) where {SV <: ShapedView} = Size(shapeof(SV)) +@inline Size(A::ShapedView) = Size(typeof(A)) + +@pure Length(::Type{SV}) where {SV <: ShapedView} = Length(shapeof(SV)) +@inline Length(A::ShapedView) = Length(typeof(A)) + +@pure Base.size(::Type{SV}) where {SV <: ShapedView} = size(shapeof(SV)) +@inline Base.size(A::ShapedView) = size(typeof(A)) + +@pure Base.length(::Type{SV}) where {SV <: ShapedView} = length(shapeof(SV)) +@inline Base.length(A::ShapedView) = length(typeof(A)) + +@pure Base.ndims(::Type{SV}) where {SV <: ShapedView} = ndims(shapeof(SV)) +@inline Base.ndims(A::ShapedView) = ndims(typeof(A)) + +@pure Base.axes(::Type{SV}) where {SV <: ShapedView} = axes(shapeof(SV)) +@inline Base.axes(A::ShapedView) = axes(typeof(A)) + +@pure Base.IndexStyle(::Type{SV}) where {SV <: ShapedView} = IndexLinear() +@inline Base.IndexStyle(A::ShapedView) = IndexStyle(typeof(A)) + +@pure Base.propertynames(::Type{SV}) where {SV <: ShapedView} = propertynames(shapeof(SV)) +@inline Base.propertynames(A::ShapedView) = propertynames(typeof(A)) + +function Base.unsafe_convert(::Type{Ptr{T}}, A::ShapedView{T}) where {T} + Base.unsafe_convert(Ptr{T}, _data(A)) +end + +Base.dataids(A::ShapedView) = Base.dataids(_data(A)) + +function Base.copy(A::ShapedView{T,N,SH}) where {T,N,SH} + ShapedView{T,N,SH}(copy(_data(A)), _offset(A)) end -Base.setproperty!(sv::ShapedView, name::Symbol, val) = setproperty!(sv, Val(name), val) -@generated function Base.setproperty!( - sv::ShapedView{MS}, - ::Val{name}, - val, -) where {MS,name} - shape = getproperty(MS, name) - if shape isa ScalarShape - idxs = getindices(MS, name) - return quote - @_inline_meta - setindex!(getfield(sv, :data), val, $idxs) - end - else - msg = "Can only setproperty! on scalars, try ShapedView(data).shapename .= val (note the .=)" - return :(error($msg)) + +@propagate_inbounds function Base.getproperty(A::ShapedView, name::Symbol) + shape = shapeof(A) + offset = _offset(A) + getoffset(shape, name) + innershape = getproperty(shape, name) + _maybe_shapedview(_data(A), offset, innershape) +end + +@propagate_inbounds function _maybe_shapedview(data, offset, ::AbstractScalarShape) + i = offset + firstindex(data) + @boundscheck checkbounds(data, i) + @inbounds getindex(data, i) +end + +@propagate_inbounds function _maybe_shapedview(data, offset, shape) + @boundscheck check_offset_shape_inbounds(data, offset, shape) + @inbounds ShapedView{eltype(data), ndims(shape), shape, typeof(data)}(data, offset) +end + + +@propagate_inbounds function Base.setproperty!(A::ShapedView, name::Symbol, x) + innershape = getproperty(shapeof(A), name) + _maybe_setproperty!(innershape, A, name, x) +end + +@propagate_inbounds function _maybe_setproperty!(shape::AbstractScalarShape, A, name, x) + offset = _offset(A) + getoffset(shapeof(A), name) + data = _data(A) + i = firstindex(data) + offset + @boundscheck checkbounds(data, i) + @inbounds setindex!(data, x, i) +end + +@inline function _maybe_setproperty!(shape::AbstractShape, A, name, x) + error("Cannot call `setproperty!` for shape $name of type $(typeof(shape))") +end + + +@propagate_inbounds function Base.getindex(A::ShapedView, i::Int) + data = _data(A) + i += _offset(A) + @boundscheck checkbounds(data, i) + @inbounds getindex(data, i) +end + +@propagate_inbounds function Base.setindex!(A::ShapedView, val, i::Int) + data = _data(A) + i += _offset(A) + @boundscheck checkbounds(data, i) + @inbounds setindex!(data, val, i) +end + + +function UnsafeArrays.unsafe_uview(A::ShapedView{T,N,SH}) where {T,N,SH} + @inbounds ShapedView{T,N,SH}(UnsafeArrays.unsafe_uview(_data(A)), _offset(A)) +end + +function check_offset_shape_inbounds(data::AbstractVector{T}, offset::Int, shape::AbstractShape{S,U,N}) where {T,S,U,N} + if !(0 <= offset < length(data)) + throw(ArgumentError("offset must be in range [0, length(data))")) + end + if length(data) < offset + length(shape) + throw(ArgumentError("offset + length(shape) cannot be greater than length(data)")) end -end \ No newline at end of file +end + +@inline _data(A::ShapedView) = getfield(A, :data) +@inline _offset(A::ShapedView) = getfield(A, :offset) \ No newline at end of file diff --git a/src/Shapes.jl b/src/Shapes.jl index 56d5310..a755d62 100644 --- a/src/Shapes.jl +++ b/src/Shapes.jl @@ -10,6 +10,8 @@ using Random: AbstractRNG using Requires: @init, @require +import UnsafeArrays + import StaticArrays: Size, Length, similar_type, get using StaticArrays: tuple_prod, tuple_length, @@ -33,7 +35,6 @@ export AbstractShape, concrete_eltype, ShapedView, - getindices, allocate diff --git a/src/traits.jl b/src/traits.jl index c763aba..e513260 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -1,8 +1,7 @@ """ abstract type AbstractShape{S, T, N, L} end -The supertype for the various concrete shapes defined by `Shapes`. The type parameters -exactly match those of [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl). +The supertype for the various concrete shapes defined by `Shapes`. The,/StaticArrays.jl). The `S` parameter is a `Tuple`-type specifying the dimensions, or size, of the `AbstractShape`- such as `Tuple{3,4,5}` for a 3×4×5-sized array. The `T` parameter specifies the underlying data type of the shape (e.g. the element type for an array @@ -12,6 +11,10 @@ Constructors may drop the `L` and `T` parameters if they are inferrable from the """ abstract type AbstractShape{S,T,N,L} end +const AbstractScalarShape{T} = AbstractShape{Tuple{},T} +const AbstractVectorShape{S,T} = AbstractShape{Tuple{S},T} +const AbstractMatrixShape{S1,S2,T} = AbstractShape{Tuple{S1,S2},T} + Size(::Type{SH}) where {SH<:AbstractShape{S}} where {S<:Tuple} = Size(S) Size(shape::SH) where {SH<:AbstractShape} = Size(SH) @@ -41,9 +44,9 @@ Base.axes(::Type{SH}) where {SH<:AbstractShape} = _axes(Size(SH)) @inline concrete_eltype(::Type{SH}) where {SH<:AbstractShape} = default_datatype(eltype(SH)) @inline concrete_eltype(shape::AbstractShape) = concrete_eltype(typeof(shape)) - #ScalarShape{T} = AbstractShape{Tuple{}, T, 0} - #VectorShape{N,T} = AbstractShape{Tuple{N}, T, 1} - #MatrixShape{N,M,T} = AbstractShape{Tuple{N,M}, T, 2} + + + struct Shape{S,T,N,L} <: AbstractShape{S,T,N,L} function Shape{S,T,N,L}() where {S<:Tuple,T,N,L} check_shape_params(S, T, Val{N}, Val{L}) @@ -106,6 +109,8 @@ end Adapt.adapt_storage(T::DataType, sh::Shape{S}) where {S <: Tuple} = Shape{S,T}() + +# TODO unneccesary type params struct MultiShape{S,T,N,L,namedtuple} <: AbstractShape{S,T,N,L} function MultiShape{S,T,N,L,namedtuple}() where {S,T,N,L,namedtuple} check_multishape_params(S, T, Val{N}, Val{L}, Val{namedtuple}) @@ -138,14 +143,6 @@ Base.propertynames(ms::MultiShape) = propertynames(NamedTuple(ms)) Base.values(ms::MultiShape) = values(NamedTuple(ms)) Base.merge(ms1::MultiShape, ms2::MultiShape) = MultiShape(merge(NamedTuple(ms1), NamedTuple(ms2))) -getindices(ms::MultiShape, shapename::Symbol) = computeidxsimpl(ms, shapename, getproperty(ms, shapename)) -function computeidxsimpl(ms::MultiShape, shapename::Symbol, shape::AbstractShape) - nt = get(ms) - shapeidx = Base.fieldindex(typeof(nt), shapename) - shapes = values(ms) - from = shapeidx == 1 ? 1 : 1 + sum(i -> length(shapes[i]), 1:(shapeidx-1)) - ifelse(shape isa ScalarShape, from, from:(from+length(shape)-1)) -end @pure function namedtuple_length(nt::NamedTuple{<:Any,<:Tuple{Vararg{AbstractShape}}}) sum(shape -> get(Length(shape)), values(nt)) @@ -186,3 +183,15 @@ end end @pure Adapt.adapt_storage(T::DataType, ::MS) where {MS <: MultiShape} = MultiShape(map(s->Adapt.adapt(T, s), get(MS))) + +@pure shapeindex(ms::MultiShape, name::Symbol) = Base.fieldindex(typeof(get(ms)), name) + +@pure Base.firstindex(ms::MultiShape) = firstindex(get(ms)) +@pure Base.lastindex(ms::MultiShape) = lastindex(get(ms)) + +@pure Base.getindex(ms::MultiShape, i::Int) = getindex(get(ms), i) + +@pure function getoffset(ms::MultiShape, name::Symbol) + idx = shapeindex(ms, name) + idx == firstindex(ms) ? 0 : sum(i -> length(ms[i]), 1:(idx-1)) +end \ No newline at end of file diff --git a/src/util.jl b/src/util.jl index 3678900..a858330 100644 --- a/src/util.jl +++ b/src/util.jl @@ -48,3 +48,8 @@ function default_datatype end @inline _default_datatype(::Type{>:Int}) = Int @inline _default_datatype(::Type{>:Float64}) = Float64 @inline _default_datatype(::Type{>:Real}) = Float64 + +has_unit_axes(A) = all(ax->ax isa AbstractUnitRange{Int}, axes(A)) +function check_has_unit_axes(A) + has_unit_axes(A) || throw(ArgumentError("The axes of data must be <: AbstractUnitRange{Int}")) +end \ No newline at end of file diff --git a/test/ShapedView.jl b/test/ShapedView.jl new file mode 100644 index 0000000..427312c --- /dev/null +++ b/test/ShapedView.jl @@ -0,0 +1,72 @@ +@testset "basic" begin + let A=rand(10) + s1 = VectorShape(Float64, 10) + s2 = VectorShape(Float64, 11) + @test_throws ArgumentError ShapedView(A, -1, s1) + @test_throws ArgumentError ShapedView(A, 11, s1) + @test_throws ArgumentError ShapedView(A, 0, s2) + end + + + ms = MultiShape( + x = MatrixShape(Int,5,5), + y = ScalarShape(Float64), + z = VectorShape(Float64, 2), + ) + lx = length(ms.x) + ly = length(ms.y) + lz = length(ms.z) + d = Float64.(collect(1:(lx + ly + lz))) + sv = ms(d) + + x(sv) = sv.x + y(sv) = sv.y + z(sv) = sv.z + @test @inferred(x(sv)) == reshape(view(d, 1:lx), size(ms.x)) + @test @inferred(y(sv)) === d[1+lx] + @test @inferred(z(sv)) == view(d, (1 + lx + ly):(lx + ly + lz)) + + d1 = rand(1:100, 5, 5) + d2 = rand() + d3 = rand(2) + + sv.x .= d1 + sv.y = d2 + sv.z .= d3 + + @test sv.x == d1 + @test sv.y == d2 + @test sv.z == d3 +end + +(s1=rand(5,10), s2=rand(5), ) + +@testset "nested" begin + s1 = MatrixShape(Real,5,10) + s2 = VectorShape(Float32,5) + s3 = ScalarShape(Int) + ms1 = MultiShape(s1 = s1, s2 = s2, s3=s3) + + s4 = VectorShape(Float32,3) + s5 = ScalarShape(Int) + ms2 = MultiShape(s4=s4, s5=s5) + + s6 = ScalarShape(Float64) + ms3 = MultiShape(s6=s6, ms2=ms2) + + ms = MultiShape(ms1=ms1, ms2=ms2, ms3=ms3) + x = collect(1:length(ms)) + sv = ms(x) + + @test sv.ms1 == 1:56 + @test vec(sv.ms1.s1) == 1:50 + @test vec(sv.ms1.s2) == 51:55 + @test sv.ms1.s3 == 56 + + @test vec(sv.ms2.s4) == 57:59 + @test sv.ms2.s5 == 60 + + @test sv.ms3.s6 == 61 + @test vec(sv.ms3.ms2.s4) == 62:64 + @test sv.ms3.ms2.s5 == 65 +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 7914aa1..0035e14 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,63 +5,7 @@ using Shapes: concrete_eltype, ShapedView @testset "Shapes.jl" begin - @testset "Shape" begin - @test ScalarShape{Int}() === Shape{Tuple{},Int}() - @test VectorShape{5,Int}() === Shape{Tuple{5},Int}() - @test MatrixShape{5,10,Int}() === Shape{Tuple{5,10},Int}() - - @test ScalarShape(Int) === Shape{Tuple{},Int}() - @test VectorShape(Int, 5) === Shape{Tuple{5},Int}() - @test MatrixShape(Int, 5, 10) === Shape{Tuple{5,10},Int}() - @test Shape(Int, 1,2,3,4,5,6,7) === Shape{Tuple{1,2,3,4,5,6,7}, Int, 7, 5040}() - - s = MatrixShape{5,10,Real}() - @test @inferred(Size(typeof(s))) === @inferred(Size(s)) === Size(5, 10) - @test @inferred(Length(typeof(s))) === @inferred(Length(s)) === Length(50) - @test @inferred(eltype(typeof(s))) === @inferred(eltype(s)) === Real - @test @inferred(length(typeof(s))) === @inferred(length(s)) === 50 - - @test @inferred(size(typeof(s))) === @inferred(size(s)) === (5, 10) - @test @inferred(size(typeof(s), 1)) === @inferred(size(s, 1)) === 5 - @test @inferred(size(typeof(s), 2)) === @inferred(size(s, 2)) === 10 - - @test @inferred(ndims(typeof(s))) === @inferred(ndims(s)) === 2 - @test @inferred(axes(typeof(s))) === @inferred(axes(s)) === (SOneTo(5), SOneTo(10)) - @test @inferred(concrete_eltype(typeof(s))) === @inferred(concrete_eltype(s)) === - Float64 - end - - @testset "MultiShape" begin - s1 = MatrixShape{5,10,Real}() - s2 = VectorShape{5,Float32}() - s3 = ScalarShape{Int}() - s = MultiShape(s1 = s1, s2 = s2, s3=s3) - - @test MultiShape(s, s4=s3).s3 === s3 - let s4 = MatrixShape(Float64, 1,2), ms = MultiShape(s4=s4) - @test merge(s, ms).s4 === s4 - end - - @test s.s1 === s1 - @test s.s2 === s2 - @test s.s3 === s3 - @test NamedTuple(s) isa NamedTuple - @test NamedTuple(s) === typeof(s).parameters[end] - @test @inferred(propertynames(s)) === propertynames(NamedTuple(s)) - @test @inferred(values(s)) == values(NamedTuple(s)) - - @test @inferred(Size(typeof(s))) === @inferred(Size(s)) === Size(56) - @test @inferred(Length(typeof(s))) === @inferred(Length(s)) === Length(56) - @test @inferred(eltype(typeof(s))) === @inferred(eltype(s)) === Real - @test @inferred(length(typeof(s))) === @inferred(length(s)) === 56 - @test @inferred(size(typeof(s))) === @inferred(size(s)) === (56,) - @test @inferred(ndims(typeof(s))) === @inferred(ndims(s)) === 1 - @test @inferred(axes(typeof(s))) === @inferred(axes(s)) === (SOneTo(56),) - @test @inferred(concrete_eltype(typeof(s))) === Float64 - @test @inferred(concrete_eltype(s)) === Float64 - - - end + @testset "traits" begin include("traits.jl") end @testset "StaticArrays support" begin svec = SVector(1, 2, 3) @@ -112,35 +56,6 @@ using Shapes: concrete_eltype, ShapedView end end - @testset "ShapedView" begin - ms = MultiShape( - x = MatrixShape{5,5,Int}(), - y = ScalarShape{Float64}(), - z = VectorShape{2,Float64}(), - ) - lx = length(ms.x) - ly = length(ms.y) - lz = length(ms.z) - d = Float64.(collect(1:(lx + ly + lz))) - sv = ms(d) + @testset "ShapedView" begin include("ShapedView.jl") end - x(sv) = sv.x - y(sv) = sv.y - z(sv) = sv.z - @test @inferred(x(sv)) === reshape(view(d, 1:lx), size(ms.x)) - @test @inferred(y(sv)) === d[1+lx] - @test @inferred(z(sv)) === view(d, (1 + lx + ly):(lx + ly + lz)) - - d1 = rand(1:100, 5, 5) - d2 = rand() - d3 = rand(2) - - sv.x .= d1 - sv.y = d2 - sv.z .= d3 - - @test sv.x == d1 - @test sv.y == d2 - @test sv.z == d3 - end end diff --git a/test/traits.jl b/test/traits.jl new file mode 100644 index 0000000..0310c69 --- /dev/null +++ b/test/traits.jl @@ -0,0 +1,57 @@ +@testset "Shape" begin + @test ScalarShape{Int}() === Shape{Tuple{},Int}() + @test VectorShape{5,Int}() === Shape{Tuple{5},Int}() + @test MatrixShape{5,10,Int}() === Shape{Tuple{5,10},Int}() + + @test ScalarShape(Int) === Shape{Tuple{},Int}() + @test VectorShape(Int, 5) === Shape{Tuple{5},Int}() + @test MatrixShape(Int, 5, 10) === Shape{Tuple{5,10},Int}() + @test Shape(Int, 1,2,3,4,5,6,7) === Shape{Tuple{1,2,3,4,5,6,7}, Int, 7, 5040}() + + s = MatrixShape{5,10,Real}() + @test @inferred(Size(typeof(s))) === @inferred(Size(s)) === Size(5, 10) + @test @inferred(Length(typeof(s))) === @inferred(Length(s)) === Length(50) + @test @inferred(eltype(typeof(s))) === @inferred(eltype(s)) === Real + @test @inferred(length(typeof(s))) === @inferred(length(s)) === 50 + + @test @inferred(size(typeof(s))) === @inferred(size(s)) === (5, 10) + @test @inferred(size(typeof(s), 1)) === @inferred(size(s, 1)) === 5 + @test @inferred(size(typeof(s), 2)) === @inferred(size(s, 2)) === 10 + + @test @inferred(ndims(typeof(s))) === @inferred(ndims(s)) === 2 + @test @inferred(axes(typeof(s))) === @inferred(axes(s)) === (SOneTo(5), SOneTo(10)) + @test @inferred(concrete_eltype(typeof(s))) === @inferred(concrete_eltype(s)) === + Float64 +end + +@testset "MultiShape" begin + s1 = MatrixShape{5,10,Real}() + s2 = VectorShape{5,Float32}() + s3 = ScalarShape{Int}() + s = MultiShape(s1 = s1, s2 = s2, s3=s3) + + @test MultiShape(s, s4=s3).s3 === s3 + let s4 = MatrixShape(Float64, 1,2), ms = MultiShape(s4=s4) + @test merge(s, ms).s4 === s4 + end + + @test s.s1 === s1 + @test s.s2 === s2 + @test s.s3 === s3 + @test NamedTuple(s) isa NamedTuple + @test NamedTuple(s) === typeof(s).parameters[end] + @test @inferred(propertynames(s)) === propertynames(NamedTuple(s)) + @test @inferred(values(s)) == values(NamedTuple(s)) + + @test @inferred(Size(typeof(s))) === @inferred(Size(s)) === Size(56) + @test @inferred(Length(typeof(s))) === @inferred(Length(s)) === Length(56) + @test @inferred(eltype(typeof(s))) === @inferred(eltype(s)) === Real + @test @inferred(length(typeof(s))) === @inferred(length(s)) === 56 + @test @inferred(size(typeof(s))) === @inferred(size(s)) === (56,) + @test @inferred(ndims(typeof(s))) === @inferred(ndims(s)) === 1 + @test @inferred(axes(typeof(s))) === @inferred(axes(s)) === (SOneTo(56),) + @test @inferred(concrete_eltype(typeof(s))) === Float64 + @test @inferred(concrete_eltype(s)) === Float64 + + +end