Skip to content

Commit

Permalink
Add ParamEstim module
Browse files Browse the repository at this point in the history
  • Loading branch information
gerlero committed Jan 3, 2023
1 parent dbc5f80 commit 2de5533
Show file tree
Hide file tree
Showing 9 changed files with 335 additions and 63 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
LsqFit = "2fda8390-95c7-5789-9bda-21331edee243"
PCHIPInterpolation = "afe20452-48d1-4729-9a8b-50fb251f06cd"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
ResumableFunctions = "c5292f4c-5179-55e1-98c5-05642aab7184"
Expand All @@ -17,6 +18,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
ArgCheck = "2"
DiffResults = "1"
ForwardDiff = "0.10"
LsqFit = "0.13"
OrdinaryDiffEq = "6"
PCHIPInterpolation = "0.1"
RecipesBase = "1"
Expand Down
158 changes: 97 additions & 61 deletions docs/Manifest.toml

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Documenter
DocMeta.setdocmeta!(Fronts, :DocTestSetup, :(using Fronts); recursive=true)

makedocs(;
modules=[Fronts],
modules=[Fronts, Fronts.ParamEstim, Fronts.PorousModels],
authors="Gabriel S. Gerlero",
repo="https://github.com/gerlero/Fronts.jl/blob/{commit}{path}#{line}",
sitename="Fronts.jl",
Expand All @@ -22,6 +22,7 @@ makedocs(;
"Solutions" => "solution.md",
"Boltzmann transformation" => "boltzmann.md",
"Inverse problems" => "inverse.md",
"Parameter estimation" => "ParamEstim.md",
"Unsaturated flow models" => "PorousModels.md",
],
)
Expand Down
13 changes: 13 additions & 0 deletions docs/src/ParamEstim.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
```@meta
CurrentModule = Fronts.ParamEstim
```

# `Fronts.ParamEstim` module: parameter estimation support

The `ParamEstim` submodule provides support for optimization-based parameter estimation runs using `Fronts`.

```@docs
RSSCostFunction
candidate
trysolve
```
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Welcome to the documentation of the [`Fronts` package for Julia](https://github.
## Contents

```@contents
Pages = ["equations.md", "problems.md", "solvers.md", "solution.md", "boltzmann.md", "inverse.md", "PorousModels.md"]
Pages = ["equations.md", "problems.md", "solvers.md", "solution.md", "boltzmann.md", "inverse.md", "ParamEstim.md", "PorousModels.md"]
```

!!! note
Expand Down
2 changes: 2 additions & 0 deletions src/Fronts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,6 @@ export Solution, rb, flux, sorptivity
export SolvingError
export inverse

include("ParamEstim.jl")

end
175 changes: 175 additions & 0 deletions src/ParamEstim.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
module ParamEstim

import ..Fronts
using ..Fronts: Problem, Solution, solve, SolvingError

using LsqFit: curve_fit

"""
RSSCostFunction{fit_D0}(func, ϕ, data[, weights; catch_errors, D0tol, ϕi_hint])
Residual sum of squares cost function for parameter estimation.
# Type parameters
- `fit_D0::Bool`: whether to fit an additional constant factor `D0` that affects the diffusivity. Values
of `D0` can be found with relative efficiency without additional solver calls; so if any such constant
factors affecting the diffusivity are unknown, it is recommended not to fit those factors directly but set
`fit_D0` to `true` instead. Values of `D0` are found internally by local optimization, and they can be
retrieved by calling the `candidate` function.
# Arguments
- `func`: function that takes a vector of parameter values and returns either a `Fronts.Solution` or a
`Fronts.Problem`. If func returns a `Problem`, it is solved with `trysolve`. `func` is also allowed to
return `nothing` to signal that no solution could be found for the parameter values, which will imply an
infinite cost (see also the `catch_errors` keyword argument).
- `ϕ`: vector of values of the Boltzmann variable. See [`Fronts.ϕ`](@ref).
- `data`: data to fit. Must be a vector of the same length as `ϕ`.
- `weights`: optional weights for the data. If given, must be a vector of the same length as `data`.
# Keyword arguments
- `catch_errors=(Fronts.SolvingError,)`: collection of exception types that `func` is allowed to throw;
any of these exceptions will be caught and will result in an infinite cost.
- `D0tol=1e-3`: if `fit_D0` is `true`, a tolerance for `D0`.
- `ϕi_hint=ϕ[end]`: if `fit_D0` is `true`, a hint as to the point in ϕ where the initial condition begins.
The hint will be used as an aid in finding the optimal value for `D0`.
---
(::RSSCostFunction)(p::AbstractVector)
Return the cost of the solution obtained with parameter values `p`.
The `RSSCostFunction` object is meant to be passed to your optimizer of choice for minimization as the
objective function.
If you need to know more than just the cost, call the `candidate` function instead.
See also: [`candidate`](@ref), [`Fronts.Solution`](@ref), [`Fronts.Problem`](@ref), [`trysolve`](@ref)
"""
struct RSSCostFunction{fit_D0, _Tfunc, _Tϕ, _Tdata, _Tweights, _Tcatch_errors, _Tϕi_hint, _TD0tol}
_func::_Tfunc
::_Tϕ
_data::_Tdata
_weights::_Tweights
_catch_errors::_Tcatch_errors
_D0tol::_TD0tol
_ϕi_hint::_Tϕi_hint

function RSSCostFunction{true}(func, ϕ, data, weights=nothing; ϕi_hint=ϕ[end], D0tol=1e-3, catch_errors=(SolvingError,))
new{true,typeof(func),typeof(ϕ),typeof(data),typeof(weights),typeof(catch_errors),typeof(ϕi_hint),typeof(D0tol)}(func, ϕ, data, weights, catch_errors, D0tol, ϕi_hint)
end

function RSSCostFunction{false}(func, ϕ, data, weights=nothing; catch_errors=(SolvingError,))
new{false,typeof(func),typeof(ϕ),typeof(data),typeof(weights),typeof(catch_errors),Nothing,Nothing}(func, ϕ, data, weights, catch_errors, nothing, nothing)
end
end

(cf::RSSCostFunction)(arg) = candidate(cf, arg).cost

"""
trysolve(prob[, catch_errors, kwargs...])::Union{Fronts.Solution, Nothing}
Attempt to solve a problem with `Fronts.solve` and return the solution, but catch any exceptions of
the types included in `catch_errors` and return `nothing` on such failures.
# Arguments
- `prob`: problem to be solved.
# Keyword arguments
- `catch_errors=(Fronts.SolvingError,)`: collection of exception types that should be caught.
- `kwargs...`: any additional keyword arguments are passed to `solve`.
See also: [`Fronts.solve`](@ref), [`Fronts.SolvingError`](@ref)
"""
function trysolve(args...; catch_errors=(SolvingError,), kwargs...)
try
solve(args...; kwargs...)
catch e
if any(e isa err for err in catch_errors)
return nothing
else
rethrow(e)
end
end
end

function trysolve(cf::RSSCostFunction, params::AbstractVector)
try
return trysolve(cf, cf._func(params))
catch e
if any(e isa err for err in cf._catch_errors)
return nothing
else
rethrow(e)
end
end
end

function trysolve(cf::RSSCostFunction, prob::Problem)
return trysolve(prob, catch_errors=cf._catch_errors)
end

trysolve(::RSSCostFunction, sol::Solution) = sol

trysolve(::RSSCostFunction, ::Nothing) = nothing


struct _Candidate
sol::Union{Solution,Nothing}
D0::Float64
cost::Float64
end

"""
candidate(cf::RSSCostFunction, ::AbstractVector)
candidate(cf::RSSCostFunction, ::Fronts.Problem)
candidate(cf::RSSCostFunction, ::Fronts.Solution)
candidate(cf::RSSCostFunction, ::Nothing)
Return the candidate solution (including the cost) for a given cost function and parameter values,
problem, or solution.
The return of this function has the following fields:
- `sol`: the solution, or `nothing` if no solution could be found.
- `D0`: if `cf` has `fit_D0` set to `true` and `sol` is not `nothing`, the found value of `D0`.
- `cost`: the cost of the solution; infinite if `sol` is `nothing`.
"""
candidate(cf::RSSCostFunction, params::AbstractVector) = candidate(cf, trysolve(cf, params))

candidate(cf::RSSCostFunction, prob::Problem) = candidate(cf, trysolve(cf, prob))

candidate(::RSSCostFunction{true}, ::Nothing) = _Candidate(nothing, NaN, Inf)

candidate(::RSSCostFunction{false}, ::Nothing) = _Candidate(nothing, 1, Inf)

function candidate(cf::RSSCostFunction{false}, sol::Solution)
if !isnothing(cf._weights)
return _Candidate(sol, 1, sum(cf._weights.*(sol.(cf._ϕ) .- cf._data).^2))
else
return _Candidate(sol, 1, sum((sol.(cf._ϕ) .- cf._data).^2))
end
end

function candidate(cf::RSSCostFunction{true}, sol::Solution)
scaled!(ret, ϕ, (D0,)) = (ret .= sol.(ϕ./√D0))

scaling = curve_fit(scaled!,
cf._ϕ,
cf._data,
(!isnothing(cf._weights) ? (cf._weights,) : ())...,
[(cf._ϕi_hint/sol.ϕi)^2],
inplace=true,
lower=[0.0],
autodiff=:forwarddiff,
x_tol=cf._D0tol)

if !scaling.converged
@warn "Attempt to fit D0 did not converge"
end

return _Candidate(sol, only(scaling.param), sum(scaling.resid.^2))
end

export RSSCostFunction, candidate, trysolve

end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Fronts
using Fronts.PorousModels
using Fronts.ParamEstim
using Test
using Fronts._Diff: derivative
using OrdinaryDiffEq: ODEFunction, ODEProblem
Expand All @@ -16,6 +17,7 @@ using Plots: plot
include("test_transform.jl")
include("test_isindomain.jl")
include("test_inverse.jl")
include("test_ParamEstim.jl")
include("test_PorousModels.jl")
include("test_plot.jl")
end
41 changes: 41 additions & 0 deletions test/test_ParamEstim.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
@testset "ParamEstim" begin

@testset "fit_D0 false" begin
θ = solve(DirichletProblem-> 2*θ, i=0, b=1))
ϕ = range(0, 20, length=100)

cf = RSSCostFunction{false}(ϕ, θ.(ϕ), catch_errors=(DomainError,)) do (k,)
DirichletProblem-> k*θ, i=0, b=1)
end

@test cf([2]) == 0
@test cf([1]) > 0
@test cf([3]) > 0
@test cf([0]) == Inf

end

@testset "fit_D0 true" begin
θ = solve(DirichletProblem-> 2*θ, i=0, b=1))
ϕ = range(0, 20, length=100)

cf = RSSCostFunction{true}(ϕ, θ.(ϕ)) do (k,)
DirichletProblem-> k*θ, i=0, b=1)
end

cand = candidate(cf, [2])
@test isapprox(cand.D0, 1, atol=1e-3)
@test isapprox(cand.cost, 0, atol=1e-7)

cand = candidate(cf, [1])
@test isapprox(cand.D0, 2, atol=1e-3)
@test isapprox(cand.cost, 0, atol=1e-7)

cand = candidate(cf, [3])
@test isapprox(cand.D0, 2/3, atol=1e-3)
@test isapprox(cand.cost, 0, atol=1e-7)

@test_throws DomainError candidate(cf, [0])
end

end

0 comments on commit 2de5533

Please sign in to comment.