added additional loss against data for NNODE #666

Merged: 53 commits, Apr 4, 2023 (changes shown from 51 commits)

Commits
fc75bad
added additional loss against data for NNODE(still needs reviewing)
AstitvaAggarwal Mar 30, 2023
4f76ee3
Added additional Loss function Feature for NNODE
AstitvaAggarwal Apr 1, 2023
a15f6eb
Added Tests for additional_loss function feature
AstitvaAggarwal Apr 1, 2023
69ecc9f
formatted files
AstitvaAggarwal Apr 1, 2023
6508b35
docs
AstitvaAggarwal Apr 2, 2023
74af932
Docs for new NNODE Argument:additional_loss
AstitvaAggarwal Apr 2, 2023
9b3d249
fixed optimizations choice
AstitvaAggarwal Apr 2, 2023
4197cdb
rebased
AstitvaAggarwal Apr 3, 2023
68de9ef
added additional loss against data for NNODE(still needs reviewing)
AstitvaAggarwal Mar 30, 2023
b0d4718
Added additional Loss function Feature for NNODE
AstitvaAggarwal Apr 1, 2023
45c573e
Added Tests for additional_loss function feature
AstitvaAggarwal Apr 1, 2023
0d658d6
docs
AstitvaAggarwal Apr 2, 2023
c810ee3
Mixed my PR with PR #635 by sdesai1287
AstitvaAggarwal Apr 2, 2023
b9a4606
fixed optimizations choice
AstitvaAggarwal Apr 2, 2023
9221e6a
added additional loss against data for NNODE(still needs reviewing)
AstitvaAggarwal Mar 30, 2023
05afb08
Added additional Loss function Feature for NNODE
AstitvaAggarwal Apr 1, 2023
9c7f3f2
Added Tests for additional_loss function feature
AstitvaAggarwal Apr 1, 2023
a96aef3
docs
AstitvaAggarwal Apr 2, 2023
2847912
Mixed my PR with PR #635 by sdesai1287
AstitvaAggarwal Apr 2, 2023
b173314
fixed optimizations choice
AstitvaAggarwal Apr 2, 2023
1dcdc96
added additional loss against data for NNODE(still needs reviewing)
AstitvaAggarwal Mar 30, 2023
c0a5703
Added additional Loss function Feature for NNODE
AstitvaAggarwal Apr 1, 2023
df085ee
Added Tests for additional_loss function feature
AstitvaAggarwal Apr 1, 2023
2352c23
formatted files
AstitvaAggarwal Apr 1, 2023
e544d1f
docs
AstitvaAggarwal Apr 2, 2023
3e43a16
Mixed my PR with PR #635 by sdesai1287
AstitvaAggarwal Apr 2, 2023
79dd8fe
fixed optimizations choice
AstitvaAggarwal Apr 2, 2023
e28a3f2
rebased
AstitvaAggarwal Apr 3, 2023
d8e06cc
Actually performed Rebase and Formatted some files
AstitvaAggarwal Apr 3, 2023
2d20181
fixed ode_solve.jl rebase issues
AstitvaAggarwal Apr 4, 2023
c25a233
added additional loss against data for NNODE(still needs reviewing)
AstitvaAggarwal Mar 30, 2023
a5e200b
Added additional Loss function Feature for NNODE
AstitvaAggarwal Apr 1, 2023
568f934
Added Tests for additional_loss function feature
AstitvaAggarwal Apr 1, 2023
96f20bd
formatted files
AstitvaAggarwal Apr 1, 2023
08b6c6a
docs
AstitvaAggarwal Apr 2, 2023
8e0b7c8
Mixed my PR with PR #635 by sdesai1287
AstitvaAggarwal Apr 2, 2023
5b2ee57
fixed optimizations choice
AstitvaAggarwal Apr 2, 2023
d1216d4
rebased
AstitvaAggarwal Apr 3, 2023
f369613
Actually performed Rebase and Formatted some files
AstitvaAggarwal Apr 3, 2023
2f625f6
fixed ode_solve.jl rebase issues
AstitvaAggarwal Apr 4, 2023
e3938ed
.
AstitvaAggarwal Apr 4, 2023
44a3e90
added additional loss against data for NNODE(still needs reviewing)
AstitvaAggarwal Mar 30, 2023
f23b141
Added additional Loss function Feature for NNODE
AstitvaAggarwal Apr 1, 2023
f47159e
Added Tests for additional_loss function feature
AstitvaAggarwal Apr 1, 2023
b734a59
docs
AstitvaAggarwal Apr 2, 2023
18439e7
Mixed my PR with PR #635 by sdesai1287
AstitvaAggarwal Apr 2, 2023
c0f397e
fixed optimizations choice
AstitvaAggarwal Apr 2, 2023
43d9efb
Actually performed Rebase and Formatted some files
AstitvaAggarwal Apr 3, 2023
b6b31b2
rebase Fr Fr
AstitvaAggarwal Apr 4, 2023
09c9cf8
stuff happened
AstitvaAggarwal Apr 4, 2023
64f1464
Fixed Tests line 208 NNODE_tests
AstitvaAggarwal Apr 4, 2023
a2eef3d
changes from review
AstitvaAggarwal Apr 4, 2023
9495d96
Update src/ode_solve.jl
ChrisRackauckas Apr 4, 2023
66 changes: 52 additions & 14 deletions src/ode_solve.jl
@@ -3,7 +3,7 @@ abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
"""
```julia
NNODE(chain, opt=OptimizationPolyalgorithms.PolyOpt(), init_params = nothing;
autodiff=false, batch=0, kwargs...)
autodiff=false, batch=0,additional_loss=nothing,kwargs...)
Suggested change (Member):
autodiff=false, batch=0,additional_loss=nothing,kwargs...)
autodiff=false, batch=0,additional_loss=nothing,
kwargs...)

```

Algorithm for solving ordinary differential equations using a neural network. This is a specialization
@@ -23,7 +23,14 @@ of the physics-informed neural network which is used as a solver for a standard
which thus uses the random initialization provided by the neural network library.

## Keyword Arguments

* `additional_loss`: A function additional_loss(phi, θ) where phi are the neural network trial solutions,
θ are the weights of the neural network(s).
example:
ts=[t for t in 1:100]
(u_, t_) = (analytical_func(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
end
Member: This isn't in the right spot. Make an example section.

Contributor Author: under this argument's description? or along with the lower example?

* `autodiff`: The switch between automatic and numerical differentiation for
the PDE operators. The reverse mode of the loss function is always
automatic differentiation (via Zygote), this is only for the derivative
@@ -63,20 +70,23 @@ is an accurate interpolation (up to the neural network training result). In addi
Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks for solving
ordinary and partial differential equations." IEEE Transactions on Neural Networks 9, no. 5 (1998): 987-1000.
"""
struct NNODE{C, O, P, B, K, S <: Union{Nothing, AbstractTrainingStrategy}} <:
struct NNODE{C, O, P, B, K, AL <: Union{Nothing, Function},
S <: Union{Nothing, AbstractTrainingStrategy}
} <:
NeuralPDEAlgorithm
chain::C
opt::O
init_params::P
autodiff::Bool
batch::B
strategy::S
additional_loss::AL
kwargs::K
end
function NNODE(chain, opt, init_params = nothing;
strategy = nothing,
autodiff = false, batch = nothing, kwargs...)
NNODE(chain, opt, init_params, autodiff, batch, strategy, kwargs)
autodiff = false, batch = nothing, additional_loss = nothing, kwargs...)
NNODE(chain, opt, init_params, autodiff, batch, strategy, additional_loss, kwargs)
end

"""
@@ -236,7 +246,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
end

"""
Representation of the loss function, paramtric on the training strategy `strategy`
Representation of the loss function, parametric on the training strategy `strategy`
"""
function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
batch)
@@ -250,38 +260,38 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp
sol.u
end

# Default this to ForwardDiff until Integrals.jl autodiff is sorted out
OptimizationFunction(loss, Optimization.AutoForwardDiff())
return loss
end

function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch)
ts = tspan[1]:(strategy.dx):tspan[2]

# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken

function loss(θ, _)
if batch
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
else
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
end
end
optf = OptimizationFunction(loss, Optimization.AutoZygote())

return loss
Suggested change (Member):
return loss
return loss

end

function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
batch)
# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
function loss(θ, _)
ts = adapt(parameterless_type(θ),
[(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])

if batch
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
else
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
end
end
optf = OptimizationFunction(loss, Optimization.AutoZygote())

return loss
Suggested change (Member):
return loss
return loss

end

function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
@@ -312,7 +322,8 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
end
end
optf = OptimizationFunction(loss, Optimization.AutoZygote())

return loss
ChrisRackauckas marked this conversation as resolved.
end

function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan)
@@ -407,7 +418,34 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
alg.batch
end

optf = generate_loss(strategy, phi, f, autodiff::Bool, tspan, p, batch)
# additional loss
additional_loss = alg.additional_loss

# Creates OptimizationFunction Object from total_loss
function total_loss(θ, _)
L2_loss = generate_loss(strategy, phi, f, autodiff, tspan, p, batch)(θ, phi)
Suggested change (Member):
# additional loss
additional_loss = alg.additional_loss
# Creates OptimizationFunction Object from total_loss
function total_loss(θ, _)
L2_loss = generate_loss(strategy, phi, f, autodiff, tspan, p, batch)(θ, phi)
inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch)
additional_loss = alg.additional_loss
# Creates OptimizationFunction Object from total_loss
function total_loss(θ, _)
L2_loss = inner_f(θ, phi)

if !(additional_loss isa Nothing)
return additional_loss(phi, θ) + L2_loss
end
L2_loss
end

# Choice of Optimization Algo for Training Strategies
opt_algo = if strategy isa QuadratureTraining
Optimization.AutoForwardDiff()
elseif strategy isa StochasticTraining
Optimization.AutoZygote()
elseif strategy isa WeightedIntervalTraining
Optimization.AutoZygote()
else
# by default GridTraining choice of Optimization
# if adding new training algorithms we can extend this,
# if-elseif-else block for choices of optimization algos
Optimization.AutoZygote()
end
Suggested change (Member):
opt_algo = if strategy isa QuadratureTraining
Optimization.AutoForwardDiff()
elseif strategy isa StochasticTraining
Optimization.AutoZygote()
elseif strategy isa WeightedIntervalTraining
Optimization.AutoZygote()
else
# by default GridTraining choice of Optimization
# if adding new training algorithms we can extend this,
# if-elseif-else block for choices of optimization algos
Optimization.AutoZygote()
end
opt_algo = if strategy isa QuadratureTraining
Optimization.AutoForwardDiff()
else
Optimization.AutoZygote()
end


# Creates OptimizationFunction Object from total_loss
optf = OptimizationFunction(total_loss, opt_algo)

iteration = 0
callback = function (p, l)
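To summarize the restructuring shown above: each `generate_loss` method now returns a plain loss closure, `__solve` wraps that closure together with any user-supplied `additional_loss`, and only then builds a single `OptimizationFunction` with the strategy-appropriate AD backend. A minimal standalone sketch of that composition follows; `pde_loss`, `data_loss`, and `make_total_loss` are illustrative stand-ins, not names from the package.

```julia
# Illustrative stand-ins only (assumed names, not NeuralPDE internals):
# `pde_loss` plays the role of the closure returned by generate_loss(...),
# `data_loss` plays the role of a user-supplied additional_loss(phi, θ).
pde_loss(θ) = sum(abs2, θ)
data_loss(phi, θ) = sum(abs2, θ .- 1)
phi = nothing  # the trial solution is not needed for this toy composition

function make_total_loss(additional_loss)
    return function total_loss(θ, _)
        L2_loss = pde_loss(θ)
        additional_loss === nothing ? L2_loss : L2_loss + additional_loss(phi, θ)
    end
end

θ = rand(5)
make_total_loss(data_loss)(θ, nothing)  # PDE residual loss plus data-fitting loss
make_total_loss(nothing)(θ, nothing)    # previous behaviour: residual loss only
```

Because the user callback is simply added to the residual loss, any relative weighting between the two terms has to be done inside `additional_loss` itself.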
95 changes: 95 additions & 0 deletions test/NNODE_tests.jl
@@ -207,6 +207,7 @@ sol = solve(prob, NeuralPDE.NNODE(luxchain, opt; batch = true), verbose = true,
abstol = 1.0f-8, dt = 1 / 5.0f0)
@test sol.errors[:l2] < 0.5

# WeightedIntervalTraining(Lux Chain)
function f(u, p, t)
[p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
end
@@ -228,3 +229,97 @@ alg = NeuralPDE.NNODE(chain, opt, autodiff = false,
sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01)

@test abs(mean(sol) - mean(true_sol)) < 0.2

# Checking if additional_loss feature works for NNODE
linear = (u, p, t) -> cos(2pi * t)
linear_analytic = (u, p, t) -> (1 / (2pi)) * sin(2pi * t)
tspan = (0.0f0, 1.0f0)
dt = (tspan[2] - tspan[1]) / 99
ts = collect(tspan[1]:dt:tspan[2])
prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), 0.0f0, (0.0f0, 1.0f0))
opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))

# Analytical solution
u_analytical(x) = (1 / (2pi)) .* sin.(2pi .* x)

# GridTraining (Flux Chain)
chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))

(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
end

alg1 = NeuralPDE.NNODE(chain, opt, strategy = GridTraining(0.01),
additional_loss = additional_loss)

sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-8, maxiters = 500)
@test sol1.errors[:l2] < 0.5

# GridTraining (Lux Chain)
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))

(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
end

alg1 = NeuralPDE.NNODE(luxchain, opt, strategy = GridTraining(0.01),
additional_loss = additional_loss)

sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-8, maxiters = 500)
@test sol1.errors[:l2] < 0.5

# QuadratureTraining (Flux Chain)
chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))

(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
end

alg1 = NeuralPDE.NNODE(chain, opt, additional_loss = additional_loss)

sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-10, maxiters = 200)
@test sol1.errors[:l2] < 0.5

# QuadratureTraining (Lux Chain)
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))

(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
end

alg1 = NeuralPDE.NNODE(luxchain, opt, additional_loss = additional_loss)

sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-10, maxiters = 200)
@test sol1.errors[:l2] < 0.5

# StochasticTraining(Flux Chain)
chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))

(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
end

alg1 = NeuralPDE.NNODE(chain, opt, strategy = StochasticTraining(1000),
additional_loss = additional_loss)

sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-8, maxiters = 500)
@test sol1.errors[:l2] < 0.5

# StochasticTraining (Lux Chain)
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))

(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
end

alg1 = NeuralPDE.NNODE(luxchain, opt, strategy = StochasticTraining(1000),
additional_loss = additional_loss)

sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-8, maxiters = 500)
@test sol1.errors[:l2] < 0.5
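
The review of the docstring above asks for the `additional_loss` usage to live in a dedicated example section. A hedged sketch of what such a section could contain, mirroring the GridTraining test above; the `using` lines are assumptions about the surrounding test/doc environment, not part of this diff:

```julia
using NeuralPDE, Lux, OptimizationOptimisers
using OrdinaryDiffEq  # assumed for ODEProblem/solve, as in the test suite

# ODE: u' = cos(2πt), with a known analytic solution used to build a data loss
linear = (u, p, t) -> cos(2pi * t)
tspan = (0.0f0, 1.0f0)
prob = ODEProblem(linear, 0.0f0, tspan)

ts = collect(range(tspan[1], tspan[2], length = 100))
u_analytical(x) = (1 / (2pi)) .* sin.(2pi .* x)
(u_, t_) = (u_analytical(ts), ts)

# extra data-fitting term added on top of the physics-informed residual loss
function additional_loss(phi, θ)
    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
end

chain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))
alg = NeuralPDE.NNODE(chain, opt, strategy = GridTraining(0.01),
                      additional_loss = additional_loss)
sol = solve(prob, alg, verbose = true, abstol = 1.0f-8, maxiters = 500)
```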