From 1b7b1487e31f113e2ad3c0979430764e9881b6ea Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 9 Oct 2024 12:20:15 +0530 Subject: [PATCH 1/4] refactor: remove Tables.jl impl, fall back to RecursiveArrayTools --- Project.toml | 2 +- src/SciMLBase.jl | 3 +-- src/tabletraits.jl | 66 ---------------------------------------------- test/traits.jl | 16 +++++------ 4 files changed, 10 insertions(+), 77 deletions(-) delete mode 100644 src/tabletraits.jl diff --git a/Project.toml b/Project.toml index e7fc11a22..eda166335 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,6 @@ SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" -Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [weakdeps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" @@ -111,6 +110,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 8e02dd3f3..ee1dc0d79 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -4,7 +4,7 @@ if isdefined(Base, :Experimental) && @eval Base.Experimental.@max_methods 1 end using ConstructionBase -using RecipesBase, RecursiveArrayTools, Tables +using RecipesBase, RecursiveArrayTools using SciMLStructures using SymbolicIndexingInterface using DocStringExtensions @@ -741,7 +741,6 @@ include("ensemble/ensemble_analysis.jl") include("solve.jl") include("interpolation.jl") include("integrator_interface.jl") -include("tabletraits.jl") include("remake.jl") include("callbacks.jl") diff --git a/src/tabletraits.jl b/src/tabletraits.jl deleted file mode 100644 index 0a57e4986..000000000 --- a/src/tabletraits.jl +++ /dev/null @@ -1,66 +0,0 @@ -Tables.isrowtable(::Type{<:AbstractTimeseriesSolution}) = true -Tables.columns(x::AbstractTimeseriesSolution) = Tables.columntable(Tables.rows(x)) - -struct AbstractTimeseriesSolutionRows{T, U} - names::Vector{Symbol} - types::Vector{Type} - lookup::Dict{Symbol, Int} - t::T - u::U -end - -function AbstractTimeseriesSolutionRows(names, types, t, u) - AbstractTimeseriesSolutionRows(names, types, - Dict(nm => i for (i, nm) in enumerate(names)), t, u) -end - -Base.length(x::AbstractTimeseriesSolutionRows) = length(x.u) -function Base.eltype(::Type{AbstractTimeseriesSolutionRows{T, U}}) where {T, U} - AbstractTimeseriesSolutionRow{eltype(T), eltype(U)} -end -function Base.iterate(x::AbstractTimeseriesSolutionRows, st = 1) - st > length(x) ? nothing : - (AbstractTimeseriesSolutionRow(x.names, x.lookup, x.t[st], x.u[st]), st + 1) -end - -function Tables.rows(sol::AbstractTimeseriesSolution) - VT = eltype(sol.u) - syms = variable_symbols(sol) - if VT <: AbstractArray - N = length(sol.u[1]) - names = [ - :timestamp, - (isempty(syms) ? (Symbol("value", i) for i in 1:N) : - (getname(syms[i]) for i in 1:N))... - ] - types = Type[eltype(sol.t), (eltype(sol.u[1]) for i in 1:N)...] - else - names = [:timestamp, isempty(syms) ? :value : syms[1]] - types = Type[eltype(sol.t), VT] - end - return AbstractTimeseriesSolutionRows(names, types, sol.t, sol.u) -end - -Tables.schema(x::AbstractTimeseriesSolutionRows) = Tables.Schema(x.names, x.types) - -struct AbstractTimeseriesSolutionRow{T, U} <: Tables.AbstractRow - names::Vector{Symbol} - lookup::Dict{Symbol, Int} - t::T - u::U -end - -Tables.columnnames(x::AbstractTimeseriesSolutionRow) = getfield(x, :names) -function Tables.getcolumn(x::AbstractTimeseriesSolutionRow, i::Int) - i == 1 ? getfield(x, :t) : getfield(x, :u)[i - 1] -end -function Tables.getcolumn(x::AbstractTimeseriesSolutionRow, nm::Symbol) - nm === :timestamp ? getfield(x, :t) : getfield(x, :u)[getfield(x, :lookup)[nm] - 1] -end - -IteratorInterfaceExtensions.isiterable(sol::AbstractTimeseriesSolution) = true -function IteratorInterfaceExtensions.getiterator(sol::AbstractTimeseriesSolution) - Tables.datavaluerows(Tables.rows(sol)) -end -#TableTraits.isiterabletable(sol::AbstractTimeseriesSolution) = true -Tables.istable(::EnsembleSolution) = false diff --git a/test/traits.jl b/test/traits.jl index e67204200..bb97dae9f 100644 --- a/test/traits.jl +++ b/test/traits.jl @@ -1,13 +1,13 @@ -using SciMLBase, Test +using SciMLBase, Tables, Test using OrdinaryDiffEq, DataFrames, SymbolicIndexingInterface -@test SciMLBase.Tables.isrowtable(ODESolution) -@test SciMLBase.Tables.isrowtable(RODESolution) -@test SciMLBase.Tables.isrowtable(DAESolution) -@test !SciMLBase.Tables.isrowtable(SciMLBase.NonlinearSolution) -@test !SciMLBase.Tables.isrowtable(SciMLBase.LinearSolution) -@test !SciMLBase.Tables.isrowtable(SciMLBase.QuadratureSolution) -@test !SciMLBase.Tables.isrowtable(SciMLBase.OptimizationSolution) +@test Tables.isrowtable(ODESolution) +@test Tables.isrowtable(RODESolution) +@test Tables.isrowtable(DAESolution) +@test !Tables.isrowtable(SciMLBase.NonlinearSolution) +@test !Tables.isrowtable(SciMLBase.LinearSolution) +@test !Tables.isrowtable(SciMLBase.QuadratureSolution) +@test !Tables.isrowtable(SciMLBase.OptimizationSolution) function rhs(u, p, t) return -u From 2a195b2ec55da802bb80dd7e15664a13b5f6da40 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 9 Oct 2024 12:46:34 +0530 Subject: [PATCH 2/4] test: add test for Tables.jl interface using MTK variable names --- test/downstream/tables.jl | 12 ++++++++++++ test/runtests.jl | 3 +++ 2 files changed, 15 insertions(+) create mode 100644 test/downstream/tables.jl diff --git a/test/downstream/tables.jl b/test/downstream/tables.jl new file mode 100644 index 000000000..6535650d5 --- /dev/null +++ b/test/downstream/tables.jl @@ -0,0 +1,12 @@ +using ModelingToolkit, OrdinaryDiffEq, Tables, Test +using ModelingToolkit: t_nounits as t, D_nounits as D + +@variables x(t) y(t) +@parameters p q +@mtkbuild sys = ODESystem( + [D(x) ~ p * x + q * y, 0 ~ y^2 - 2x * y - 4], t, guesses = [y => 1.0]) +prob = ODEProblem(sys, [x => 1.0], (0.0, 5.0), [p => 1.0, q => 2.0]) +sol = solve(prob, Rodas5P()) + +@test sort(collect(Tables.columnnames(Tables.columns(sol)))) == + [:timestamp, Symbol("x(t)"), Symbol("y(t)")] diff --git a/test/runtests.jl b/test/runtests.jl index 03fd38506..8874ccac1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -115,6 +115,9 @@ end @time @safetestset "ODE Solution Stripping" begin include("downstream/ode_stripping.jl") end + @time @safetestset "Tables interface with MTK" begin + include("downstream/tables.jl") + end end if !is_APPVEYOR && (GROUP == "Downstream" || GROUP == "SymbolicIndexingInterface") From 8984269577c3f0aef5db7ea8dce8ff1b1cae716b Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 9 Oct 2024 05:05:23 -0400 Subject: [PATCH 3/4] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index eda166335..40b79cd20 100644 --- a/Project.toml +++ b/Project.toml @@ -116,4 +116,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff"] +test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff", "Tables"] From 17e3b0ee5ac34320b7fd8625d0b9da7595ad63ad Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 9 Oct 2024 15:08:02 +0530 Subject: [PATCH 4/4] test: `AbstractDiffEqArray` is no longer `isrowtable` --- test/traits.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/traits.jl b/test/traits.jl index bb97dae9f..1d7461481 100644 --- a/test/traits.jl +++ b/test/traits.jl @@ -1,9 +1,10 @@ using SciMLBase, Tables, Test using OrdinaryDiffEq, DataFrames, SymbolicIndexingInterface -@test Tables.isrowtable(ODESolution) -@test Tables.isrowtable(RODESolution) -@test Tables.isrowtable(DAESolution) +# https://github.com/SciML/SciMLBase.jl/pull/813#issuecomment-2401803039 +@test !Tables.isrowtable(ODESolution) +@test !Tables.isrowtable(RODESolution) +@test !Tables.isrowtable(DAESolution) @test !Tables.isrowtable(SciMLBase.NonlinearSolution) @test !Tables.isrowtable(SciMLBase.LinearSolution) @test !Tables.isrowtable(SciMLBase.QuadratureSolution)