-
-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] ExpectationProblem interface #55
Changes from all commits
ee274b8
78dac26
b1d2139
87283d4
0bcc3ca
4cf45df
86aa308
a612d20
e7273e2
2971537
0c6b810
677c5ab
89485ab
2ca17cd
2062cf9
054cd7c
5038752
66c331c
b7ebe8a
7365134
48999eb
d5d5a64
17af0ff
e741324
6b7f279
759aa46
1dbe5a0
9b28bbc
e564fea
cc023a7
8bfaa89
03bbc40
d7a5278
61e5779
df42709
235b5fd
50024a8
d26129d
e553fd6
8b97077
b82abdd
2c990d2
a3bd13a
c94c730
bd745e6
2f5d90b
8abd919
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
style = "sciml" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
*.jl.*.cov | ||
*.jl.mem | ||
Manifest.toml | ||
tests/Manifest.toml |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,38 @@ | ||
module DiffEqUncertainty | ||
|
||
using DiffEqBase, Statistics, Distributions, Reexport | ||
# LinearAlgebra | ||
using DiffEqBase, Statistics, Reexport, RecursiveArrayTools, StaticArrays, | ||
Distributions, KernelDensity, Zygote, LinearAlgebra, Random | ||
using Parameters: @unpack | ||
|
||
@reexport using Integrals | ||
using KernelDensity | ||
import DiffEqBase: solve | ||
|
||
include("expectation/system_utils.jl") | ||
include("expectation/distribution_utils.jl") | ||
include("expectation/problem_types.jl") | ||
include("expectation/solution_types.jl") | ||
include("expectation/expectation.jl") | ||
|
||
include("probints.jl") | ||
include("koopman.jl") | ||
|
||
# Type Piracy, should upstream | ||
Base.eltype(K::UnivariateKDE) = eltype(K.density) | ||
Base.eltype(K::UnivariateKDE) = eltype(K.density) | ||
Base.minimum(K::UnivariateKDE) = minimum(K.x) | ||
Base.maximum(K::UnivariateKDE) = maximum(K.x) | ||
Base.extrema(K::UnivariateKDE) = minimum(K), maximum(K) | ||
|
||
Base.minimum(d::AbstractMvNormal) = fill(-Inf, length(d)) | ||
Base.maximum(d::AbstractMvNormal) = fill(Inf, length(d)) | ||
Base.extrema(d::AbstractMvNormal) = minimum(d), maximum(d) | ||
|
||
Base.minimum(d::Product) = minimum.(d.v) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are already defined in Distributions There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had these here as a stop gap until this PR when through JuliaStats/Distributions.jl#1319 |
||
Base.maximum(d::Product) = maximum.(d.v) | ||
Base.extrema(d::Product) = minimum(d), maximum(d) | ||
|
||
export ProbIntsUncertainty, AdaptiveProbIntsUncertainty | ||
|
||
export ProbIntsUncertainty,AdaptiveProbIntsUncertainty | ||
export expectation, centralmoment, Koopman, MonteCarlo | ||
export Koopman, MonteCarlo, PrefusedAD, PostfusedAD, NonfusedAD | ||
export GenericDistribution, SystemMap, ExpectationProblem, build_integrand | ||
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
""" | ||
`GenericDistribution(d, ds...)` | ||
|
||
Defines a generic distribution that just wraps functions for pdf function, rand and bounds. | ||
User can use this for define any arbitray joint pdf. Included b/c Distributions.jl Product | ||
method of mixed distirbutions are type instable | ||
""" | ||
struct GenericDistribution{TF, TRF, TLB, TUB} | ||
pdf_func::TF | ||
rand_func::TRF | ||
lb::TLB | ||
ub::TUB | ||
end | ||
|
||
function GenericDistribution(d, ds...) | ||
dists = (d, ds...) | ||
pdf_func(x) = exp(sum(logpdf(f, y) for (f, y) in zip(dists, x))) | ||
rand_func() = [rand(d) for d in dists] | ||
lb = SVector(map(minimum, dists)...) | ||
ub = SVector(map(maximum, dists)...) | ||
|
||
GenericDistribution(pdf_func, rand_func, lb, ub) | ||
end | ||
|
||
Distributions.pdf(d::GenericDistribution, x) = d.pdf_func(x) | ||
Base.minimum(d::GenericDistribution) = d.lb | ||
Base.maximum(d::GenericDistribution) = d.ub | ||
Base.extrema(d::GenericDistribution) = minimum(d), maximum(d) | ||
Base.rand(d::GenericDistribution) = d.rand_func() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
abstract type AbstractExpectationADAlgorithm end | ||
struct NonfusedAD <: AbstractExpectationADAlgorithm end | ||
struct PrefusedAD <: AbstractExpectationADAlgorithm | ||
norm_partials::Bool | ||
end | ||
PrefusedAD() = PrefusedAD(true) | ||
struct PostfusedAD <: AbstractExpectationADAlgorithm | ||
norm_partials::Bool | ||
end | ||
PostfusedAD() = PostfusedAD(true) | ||
|
||
abstract type AbstractExpectationAlgorithm <: DiffEqBase.DEAlgorithm end | ||
struct Koopman{TS} <: | ||
AbstractExpectationAlgorithm where {TS <: AbstractExpectationADAlgorithm} | ||
sensealg::TS | ||
end | ||
Koopman() = Koopman(NonfusedAD()) | ||
struct MonteCarlo <: AbstractExpectationAlgorithm | ||
trajectories::Int | ||
end | ||
|
||
# Builds integrand for arbitrary functions | ||
function build_integrand(prob::ExpectationProblem, ::Koopman, ::Val{false}) | ||
@unpack g, d = prob | ||
function (x, p) | ||
g(x, p) * pdf(d, x) | ||
end | ||
end | ||
|
||
# Builds integrand for DEProblems | ||
function build_integrand(prob::ExpectationProblem{F}, ::Koopman, | ||
::Val{false}) where {F <: SystemMap} | ||
@unpack S, g, h, d = prob | ||
function (x, p) | ||
uΜ, pΜ = h(x, p.x[1], p.x[2]) | ||
g(S(uΜ, pΜ), pΜ) * pdf(d, x) | ||
end | ||
end | ||
|
||
function _make_view(x::Union{Vector{T}, Adjoint{T, Vector{T}}}, i) where {T} | ||
@view x[i] | ||
end | ||
|
||
function _make_view(x, i) | ||
@view x[:, i] | ||
end | ||
|
||
function build_integrand(prob::ExpectationProblem{F}, ::Koopman, | ||
::Val{true}) where {F <: SystemMap} | ||
@unpack S, g, h, d = prob | ||
|
||
if prob.nout == 1 # TODO fix upstream in quadrature, expected sizes depend on quadrature method is requires different copying based on nout > 1 | ||
set_result! = @inline function (dx, sol) | ||
dx[:] .= sol[:] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure about the semantics of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This also highlights an interface issue with Quadrature (which I am probably guilty of creating), i.e. we shouldn't really need separate |
||
end | ||
else | ||
set_result! = @inline function (dx, sol) | ||
dx .= reshape(sol[:, :], size(dx)) | ||
end | ||
end | ||
|
||
prob_func = function (prob, i, repeat, x) # TODO is it better to make prob/output funcs outside of integrand, then call w/ closure? | ||
u0, p = h((_make_view(x, i)), prob.u0, prob.p) | ||
remake(prob, u0 = u0, p = p) | ||
end | ||
|
||
output_func(sol, i, x) = (g(sol, sol.prob.p) * pdf(d, (_make_view(x, i))), false) | ||
|
||
function (dx, x, p) where {T} | ||
trajectories = size(x, 2) | ||
# TODO How to inject ensemble method in solve? currently in SystemMap, but does that make sense? | ||
ensprob = EnsembleProblem(S.prob; output_func = (sol, i) -> output_func(sol, i, x), | ||
prob_func = (prob, i, repeat) -> prob_func(prob, i, | ||
repeat, x)) | ||
sol = solve(ensprob, S.args...; trajectories = trajectories, S.kwargs...) | ||
set_result!(dx, sol) | ||
nothing | ||
end | ||
end | ||
|
||
# solve expectation problem of generic callable functions via MonteCarlo | ||
function DiffEqBase.solve(exprob::ExpectationProblem, expalg::MonteCarlo) | ||
params = parameters(exprob) | ||
dist = distribution(exprob) | ||
g = observable(exprob) | ||
ExpectationSolution(mean(g(rand(dist), params) for _ in 1:(expalg.trajectories)), nothing, nothing) | ||
end | ||
|
||
# solve expectation over DEProblem via MonteCarlo | ||
function DiffEqBase.solve(exprob::ExpectationProblem{F}, | ||
expalg::MonteCarlo) where {F <: SystemMap} | ||
d = distribution(exprob) | ||
cov = input_cov(exprob) | ||
S = mapping(exprob) | ||
g = observable(exprob) | ||
|
||
prob_func = function (prob, i, repeat) | ||
u0, p = cov(rand(d), prob.u0, prob.p) | ||
remake(prob, u0 = u0, p = p) | ||
end | ||
|
||
output_func(sol, i) = (g(sol, sol.prob.p), false) | ||
|
||
monte_prob = EnsembleProblem(S.prob; | ||
output_func = output_func, | ||
prob_func = prob_func) | ||
sol = solve(monte_prob, S.args...; trajectories = expalg.trajectories, S.kwargs...) | ||
ExpectationSolution(mean(sol.u),nothing,nothing) | ||
end | ||
|
||
# Solve Koopman expectation | ||
function DiffEqBase.solve(prob::ExpectationProblem, expalg::Koopman, args...; | ||
maxiters = 1000000, | ||
batch = 0, | ||
quadalg = HCubatureJL(), | ||
ireltol = 1e-2, iabstol = 1e-2, | ||
kwargs...) where {A <: AbstractExpectationADAlgorithm} | ||
integrand = build_integrand(prob, expalg, Val(batch > 1)) | ||
lb, ub = extrema(prob.d) | ||
|
||
sol = integrate(quadalg, expalg.sensealg, integrand, lb, ub, prob.params; | ||
reltol = ireltol, abstol = iabstol, maxiters = maxiters, | ||
nout = prob.nout, batch = batch, | ||
kwargs...) | ||
|
||
return ExpectationSolution(sol.u,sol.resid,sol) | ||
end | ||
|
||
# Integrate function to test new Adjoints, will need to roll up to Integrals.jl | ||
function integrate(quadalg, adalg::AbstractExpectationADAlgorithm, f, lb::TB, ub::TB, p; | ||
nout = 1, batch = 0, | ||
kwargs...) where {TB} | ||
#TODO check batch iip type stability w/ IntegralProblem{XXXX} | ||
prob = IntegralProblem{batch > 1}(f, lb, ub, p; nout = nout, batch = batch) | ||
solve(prob, quadalg; kwargs...) | ||
end | ||
|
||
# defines adjoint via β«β/βp f(x,p) dx | ||
Zygote.@adjoint function integrate(quadalg, adalg::NonfusedAD, f::F, lb::T, ub::T, | ||
params::P; | ||
nout = 1, batch = 0, norm = norm, | ||
kwargs...) where {F, T, P} | ||
primal = integrate(quadalg, adalg, f, lb, ub, params; | ||
norm = norm, nout = nout, batch = batch, | ||
kwargs...) | ||
|
||
function integrate_pullbacks(Ξ) | ||
function dfdp(x, params) | ||
_, back = Zygote.pullback(p -> f(x, p), params) | ||
back(Ξ)[1] | ||
end | ||
βp = integrate(quadalg, adalg, dfdp, lb, ub, params; | ||
norm = norm, nout = nout * length(params), batch = batch, | ||
kwargs...) | ||
# βlb = -f(lb,params) #needs correct for dim > 1 | ||
# βub = f(ub,params) | ||
return nothing, nothing, nothing, nothing, nothing, βp | ||
end | ||
primal, integrate_pullbacks | ||
end | ||
|
||
# defines adjoint via β«[f(x,p; β/βp f(x,p)] dx, ie it fuses the primal, post the primal calculation | ||
# has flag to only compute quad norm with respect to only the primal in the pull-back. Gives same quadrature points as doing forwarddiff | ||
Zygote.@adjoint function integrate(quadalg, adalg::PostfusedAD, f::F, lb::T, ub::T, | ||
params::P; | ||
nout = 1, batch = 0, norm = norm, | ||
kwargs...) where {F, T, P} | ||
primal = integrate(quadalg, adalg, f, lb, ub, params; | ||
norm = norm, nout = nout, batch = batch, | ||
kwargs...) | ||
|
||
_norm = adalg.norm_partials ? norm : primalnorm(nout, norm) | ||
|
||
function integrate_pullbacks(Ξ) | ||
function dfdp(x, params) | ||
y, back = Zygote.pullback(p -> f(x, p), params) | ||
[y; back(Ξ)[1]] #TODO need to match proper arrray type? promote_type??? | ||
end | ||
βp = integrate(quadalg, adalg, dfdp, lb, ub, params; | ||
norm = _norm, nout = nout + nout * length(params), batch = batch, | ||
kwargs...) | ||
return nothing, nothing, nothing, nothing, nothing, @view βp[(nout + 1):end] | ||
end | ||
primal, integrate_pullbacks | ||
end | ||
|
||
# Fuses primal and partials prior to pullback, I doubt this will stick around based on required system evals. | ||
Zygote.@adjoint function integrate(quadalg, adalg::PrefusedAD, f::F, lb::T, ub::T, | ||
params::P; | ||
nout = 1, batch = 0, norm = norm, | ||
kwargs...) where {F, T, P} | ||
# from Seth Axen via Slack | ||
# Does not work w/ ArrayPartition unless with following hack | ||
# Base.similar(A::ArrayPartition, ::Type{T}, dims::NTuple{N,Int}) where {T,N} = similar(Array(A), T, dims) | ||
# TODO add ArrayPartition similar fix upstream, see https://github.com/SciML/RecursiveArrayTools.jl/issues/135 | ||
βf_βparams(x, params) = only(Zygote.jacobian(p -> f(x, p), params)) | ||
f_augmented(x, params) = [f(x, params); βf_βparams(x, params)...] #TODO need to match proper arrray type? promote_type??? | ||
_norm = adalg.norm_partials ? norm : primalnorm(nout, norm) | ||
|
||
res = integrate(quadalg, adalg, f_augmented, lb, ub, params; | ||
norm = _norm, nout = nout + nout * length(params), batch = batch, | ||
kwargs...) | ||
primal = first(res) | ||
function integrate_pullback(Ξy) | ||
βparams = Ξy .* conj.(@view(res[(nout + 1):end])) | ||
return nothing, nothing, nothing, nothing, nothing, βparams | ||
end | ||
primal, integrate_pullback | ||
end | ||
|
||
# define norm function based only on primal part of fused integrand | ||
function primalnorm(nout, fnorm) | ||
x -> fnorm(@view x[1:nout]) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
abstract type AbstractUncertaintyProblem end | ||
|
||
struct ExpectationProblem{TS, TG, TH, TF, TP} <: AbstractUncertaintyProblem | ||
# defines β« g(S(h(x,u0,p)))*f(x)dx | ||
# π = uncertainty space, π = Initial condition space, β = model parameter space, | ||
S::TS # mapping, S: π Γ β β π | ||
g::TG # observable(output_func), g: π Γ β β ββΏα΅α΅α΅ | ||
h::TH # cov(input_func), h: π Γ π Γ β β π Γ β | ||
d::TF # distribution, pdf(d,x): π β β | ||
params::TP | ||
nout::Int | ||
end | ||
|
||
# Constructor for general maps/functions | ||
function ExpectationProblem(g, pdist, params; nout = 1) | ||
h(x, u, p) = x, p | ||
S(x, p) = x | ||
ExpectationProblem(S, g, h, pdist, params, nout) | ||
end | ||
|
||
# Constructor for DEProblems | ||
function ExpectationProblem(sm::SystemMap, g, h, d; nout = 1) | ||
ExpectationProblem(sm, g, h, d, | ||
ArrayPartition(deepcopy(sm.prob.u0),deepcopy(sm.prob.p)), | ||
nout) | ||
end | ||
|
||
distribution(prob::ExpectationProblem) = prob.d | ||
mapping(prob::ExpectationProblem) = prob.S | ||
observable(prob::ExpectationProblem) = prob.g | ||
input_cov(prob::ExpectationProblem) = prob.h | ||
parameters(prob::ExpectationProblem) = prob.params | ||
|
||
## | ||
# struct CentralMomentProblem | ||
# ns::NTuple{Int,N} | ||
# altype::Union{NestedExpectation, BinomialExpansion} #Should rely be in solve | ||
# exp_prob::ExpectationProblem | ||
# end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
struct ExpectationSolution{uType,R,O} | ||
u::uType | ||
resid::R | ||
original::O | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#Callable wrapper for DE solves. Enables seperation of args/kwargs... | ||
struct SystemMap{DT<:DiffEqBase.DEProblem,A,K} | ||
prob::DT | ||
args::A | ||
kwargs::K | ||
end | ||
SystemMap(prob, args...; kwargs...) = SystemMap(prob, args, kwargs) | ||
|
||
function (sm::SystemMap{DT})(u0,p) where DT | ||
prob::DT = remake(sm.prob, | ||
u0 = convert(typeof(sm.prob.u0),u0), | ||
p = convert(typeof(sm.prob.p), p)) | ||
solve(prob, sm.args...; sm.kwargs...) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These lines are technically type piracy right? Also, they might not be correct for any
AbstractNormal
that accepts singular covariance matrices, e.g.,Diagonal([0, 1])
does not have infinite support for the first variable.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the
UnivariateKDE
parts should be upstreamed.Base.minimum(d::AbstractMvNormal)
was upstreamed in Distributions in JuliaStats/Distributions.jl#1319, probably need to double check the PR discussion to make sure a bug wasn't introduced.