Skip to content

Commit

Permalink
Merge pull request #814 from AayushSabharwal/as/interp-empty-idxs
Browse files Browse the repository at this point in the history
fix: handle empty `idxs` in interpolation
  • Loading branch information
ChrisRackauckas authored Oct 9, 2024
2 parents c843174 + 661cafa commit e683d52
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,9 @@ end
function (sol::AbstractODESolution)(t::Number, ::Type{deriv},
idxs::AbstractVector{<:Integer},
continuity) where {deriv}
if isempty(idxs)
return eltype(eltype(sol.u))[]
end
if eltype(sol.u) <: Number
idxs = only(idxs)
end
Expand All @@ -259,6 +262,9 @@ end
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
idxs::AbstractVector{<:Integer},
continuity) where {deriv}
if isempty(idxs)
return map(_ -> eltype(eltype(sol.u))[], t)
end
if eltype(sol.u) <: Number
idxs = only(idxs)
end
Expand Down Expand Up @@ -295,6 +301,9 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
any(isequal(NotSymbolic()), symbolic_type.(idxs))
error("Incorrect specification of `idxs`")
end
if isempty(idxs)
return eltype(eltype(sol.u))[]
end
error_if_observed_derivative(sol, idxs, deriv)
ps = parameter_values(sol)
if is_parameter_timeseries(sol) == Timeseries() && is_discrete_expression(sol, idxs)
Expand Down Expand Up @@ -335,6 +344,9 @@ end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
idxs::AbstractVector, continuity) where {deriv}
if isempty(idxs)
return map(_ -> eltype(eltype(sol.u))[], t)
end
error_if_observed_derivative(sol, idxs, deriv)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
getter = getu(sol, idxs)
Expand Down
14 changes: 14 additions & 0 deletions test/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,17 @@ end
@test_throws ErrorException sol(-0.5)
@test_throws ErrorException sol([0, -0.5, 0])
end

@testset "interpolate with empty idxs" begin
f = (u, p, t) -> u
sol1 = SciMLBase.build_solution(
ODEProblem(f, 1.0, (0.0, 1.0)), :NoAlgorithm, 0.0:0.1:1.0, exp.(0.0:0.1:1.0))
sol2 = SciMLBase.build_solution(ODEProblem(f, [1.0, 2.0], (0.0, 1.0)), :NoAlgorithm,
0.0:0.1:1.0, vcat.(exp.(0.0:0.1:1.0), 2exp.(0.0:0.1:1.0)))
for sol in [sol1, sol2]
@test sol(0.15; idxs = []) == Float64[]
@test sol(0.15; idxs = Int[]) == Float64[]
@test sol([0.15, 0.25]; idxs = []) == [Float64[], Float64[]]
@test sol([0.15, 0.25]; idxs = Int[]) == [Float64[], Float64[]]
end
end

0 comments on commit e683d52

Please sign in to comment.