-
-
Notifications
You must be signed in to change notification settings - Fork 195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added additional loss against data for NNODE #666
Changes from 51 commits
fc75bad
4f76ee3
a15f6eb
69ecc9f
6508b35
74af932
9b3d249
4197cdb
68de9ef
b0d4718
45c573e
0d658d6
c810ee3
b9a4606
9221e6a
05afb08
9c7f3f2
a96aef3
2847912
b173314
1dcdc96
c0a5703
df085ee
2352c23
e544d1f
3e43a16
79dd8fe
e28a3f2
d8e06cc
2d20181
c25a233
a5e200b
568f934
96f20bd
08b6c6a
8e0b7c8
5b2ee57
d1216d4
f369613
2f625f6
e3938ed
44a3e90
f23b141
f47159e
b734a59
18439e7
c0f397e
43d9efb
b6b31b2
09c9cf8
64f1464
a2eef3d
9495d96
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -3,7 +3,7 @@ abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end | |||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
```julia | ||||||||||||||||||||||||||||||||||||
NNODE(chain, opt=OptimizationPolyalgorithms.PolyOpt(), init_params = nothing; | ||||||||||||||||||||||||||||||||||||
autodiff=false, batch=0, kwargs...) | ||||||||||||||||||||||||||||||||||||
autodiff=false, batch=0,additional_loss=nothing,kwargs...) | ||||||||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Algorithm for solving ordinary differential equations using a neural network. This is a specialization | ||||||||||||||||||||||||||||||||||||
|
@@ -23,7 +23,14 @@ of the physics-informed neural network which is used as a solver for a standard | |||||||||||||||||||||||||||||||||||
which thus uses the random initialization provided by the neural network library. | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
## Keyword Arguments | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
* `additional_loss`: A function additional_loss(phi, θ) where phi are the neural network trial solutions, | ||||||||||||||||||||||||||||||||||||
θ are the weights of the neural network(s). | ||||||||||||||||||||||||||||||||||||
example: | ||||||||||||||||||||||||||||||||||||
ts=[t for t in 1:100] | ||||||||||||||||||||||||||||||||||||
(u_, t_) = (analytical_func(ts), ts) | ||||||||||||||||||||||||||||||||||||
function additional_loss(phi, θ) | ||||||||||||||||||||||||||||||||||||
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_) | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This isn't in the right spot. Make an example section. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. under this argument's description? or along with the lower example? |
||||||||||||||||||||||||||||||||||||
* `autodiff`: The switch between automatic and numerical differentiation for | ||||||||||||||||||||||||||||||||||||
the PDE operators. The reverse mode of the loss function is always | ||||||||||||||||||||||||||||||||||||
automatic differentiation (via Zygote), this is only for the derivative | ||||||||||||||||||||||||||||||||||||
|
@@ -63,20 +70,23 @@ is an accurate interpolation (up to the neural network training result). In addi | |||||||||||||||||||||||||||||||||||
Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks for solving | ||||||||||||||||||||||||||||||||||||
ordinary and partial differential equations." IEEE Transactions on Neural Networks 9, no. 5 (1998): 987-1000. | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
struct NNODE{C, O, P, B, K, S <: Union{Nothing, AbstractTrainingStrategy}} <: | ||||||||||||||||||||||||||||||||||||
struct NNODE{C, O, P, B, K, AL <: Union{Nothing, Function}, | ||||||||||||||||||||||||||||||||||||
S <: Union{Nothing, AbstractTrainingStrategy} | ||||||||||||||||||||||||||||||||||||
} <: | ||||||||||||||||||||||||||||||||||||
NeuralPDEAlgorithm | ||||||||||||||||||||||||||||||||||||
chain::C | ||||||||||||||||||||||||||||||||||||
opt::O | ||||||||||||||||||||||||||||||||||||
init_params::P | ||||||||||||||||||||||||||||||||||||
autodiff::Bool | ||||||||||||||||||||||||||||||||||||
batch::B | ||||||||||||||||||||||||||||||||||||
strategy::S | ||||||||||||||||||||||||||||||||||||
additional_loss::AL | ||||||||||||||||||||||||||||||||||||
kwargs::K | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
function NNODE(chain, opt, init_params = nothing; | ||||||||||||||||||||||||||||||||||||
strategy = nothing, | ||||||||||||||||||||||||||||||||||||
autodiff = false, batch = nothing, kwargs...) | ||||||||||||||||||||||||||||||||||||
NNODE(chain, opt, init_params, autodiff, batch, strategy, kwargs) | ||||||||||||||||||||||||||||||||||||
autodiff = false, batch = nothing, additional_loss = nothing, kwargs...) | ||||||||||||||||||||||||||||||||||||
NNODE(chain, opt, init_params, autodiff, batch, strategy, additional_loss, kwargs) | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
|
@@ -236,7 +246,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, | |||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
Representation of the loss function, paramtric on the training strategy `strategy` | ||||||||||||||||||||||||||||||||||||
Representation of the loss function, parametric on the training strategy `strategy` | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p, | ||||||||||||||||||||||||||||||||||||
batch) | ||||||||||||||||||||||||||||||||||||
|
@@ -250,38 +260,38 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp | |||||||||||||||||||||||||||||||||||
sol.u | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# Default this to ForwardDiff until Integrals.jl autodiff is sorted out | ||||||||||||||||||||||||||||||||||||
OptimizationFunction(loss, Optimization.AutoForwardDiff()) | ||||||||||||||||||||||||||||||||||||
return loss | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch) | ||||||||||||||||||||||||||||||||||||
ts = tspan[1]:(strategy.dx):tspan[2] | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
function loss(θ, _) | ||||||||||||||||||||||||||||||||||||
if batch | ||||||||||||||||||||||||||||||||||||
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p)) | ||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts]) | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
optf = OptimizationFunction(loss, Optimization.AutoZygote()) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return loss | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p, | ||||||||||||||||||||||||||||||||||||
batch) | ||||||||||||||||||||||||||||||||||||
# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken | ||||||||||||||||||||||||||||||||||||
function loss(θ, _) | ||||||||||||||||||||||||||||||||||||
ts = adapt(parameterless_type(θ), | ||||||||||||||||||||||||||||||||||||
[(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)]) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
if batch | ||||||||||||||||||||||||||||||||||||
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p)) | ||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts]) | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
optf = OptimizationFunction(loss, Optimization.AutoZygote()) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return loss | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p, | ||||||||||||||||||||||||||||||||||||
|
@@ -312,7 +322,8 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo | |||||||||||||||||||||||||||||||||||
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts]) | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
optf = OptimizationFunction(loss, Optimization.AutoZygote()) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return loss | ||||||||||||||||||||||||||||||||||||
ChrisRackauckas marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan) | ||||||||||||||||||||||||||||||||||||
|
@@ -407,7 +418,34 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem, | |||||||||||||||||||||||||||||||||||
alg.batch | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
optf = generate_loss(strategy, phi, f, autodiff::Bool, tspan, p, batch) | ||||||||||||||||||||||||||||||||||||
# additional loss | ||||||||||||||||||||||||||||||||||||
additional_loss = alg.additional_loss | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# Creates OptimizationFunction Object from total_loss | ||||||||||||||||||||||||||||||||||||
function total_loss(θ, _) | ||||||||||||||||||||||||||||||||||||
L2_loss = generate_loss(strategy, phi, f, autodiff, tspan, p, batch)(θ, phi) | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
if !(additional_loss isa Nothing) | ||||||||||||||||||||||||||||||||||||
return additional_loss(phi, θ) + L2_loss | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
L2_loss | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# Choice of Optimization Algo for Training Strategies | ||||||||||||||||||||||||||||||||||||
opt_algo = if strategy isa QuadratureTraining | ||||||||||||||||||||||||||||||||||||
Optimization.AutoForwardDiff() | ||||||||||||||||||||||||||||||||||||
elseif strategy isa StochasticTraining | ||||||||||||||||||||||||||||||||||||
Optimization.AutoZygote() | ||||||||||||||||||||||||||||||||||||
elseif strategy isa WeightedIntervalTraining | ||||||||||||||||||||||||||||||||||||
Optimization.AutoZygote() | ||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||
# by default GridTraining choice of Optimization | ||||||||||||||||||||||||||||||||||||
# if adding new training algorithms we can extend this, | ||||||||||||||||||||||||||||||||||||
# if-elseif-else block for choices of optimization algos | ||||||||||||||||||||||||||||||||||||
Optimization.AutoZygote() | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# Creates OptimizationFunction Object from total_loss | ||||||||||||||||||||||||||||||||||||
optf = OptimizationFunction(total_loss, opt_algo) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
iteration = 0 | ||||||||||||||||||||||||||||||||||||
callback = function (p, l) | ||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.