Purge Flux from tests
avik-pal committed Dec 10, 2023
1 parent cc4f56e commit f7c5adc
Showing 25 changed files with 1,952 additions and 487 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -2,4 +2,6 @@
 *.jl.*.cov
 *.jl.mem
 Manifest.toml
-/docs/build/
+/docs/build/
+test/Manifest.toml
+test/gpu/Manifest.toml
2 changes: 0 additions & 2 deletions Project.toml
@@ -41,7 +41,6 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 
 [compat]
 ADTypes = "0.1, 0.2"
@@ -81,5 +80,4 @@ StochasticDiffEq = "6.20"
 Tracker = "0.2"
 TruncatedStacktraces = "1.2"
 Zygote = "0.6"
-ZygoteRules = "0.2"
 julia = "1.9"
5 changes: 2 additions & 3 deletions src/SciMLSensitivity.jl
@@ -9,7 +9,7 @@ using StochasticDiffEq
 import DiffEqNoiseProcess
 import RandomNumbers: Xorshifts
 using Random
-import ZygoteRules, Zygote, ReverseDiff
+import Zygote, ReverseDiff
 import ArrayInterface
 import Enzyme
 import GPUArraysCore
@@ -27,8 +27,7 @@ using FunctionProperties: hasbranching
 using Markdown
 
 using Reexport
-import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ProjectTo,
-    project_type, _eltype_projectto, rrule
+import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented
 abstract type SensitivityFunction end
 abstract type TransformedFunction end
 
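Note on the hunk above: dropping the ZygoteRules dependency means custom differentiation rules go through ChainRulesCore alone, which Zygote consumes directly. A hedged sketch of what such a rule looks like — `mysolve` is a hypothetical stand-in, not a SciMLSensitivity function:

```julia
using ChainRulesCore

# Hypothetical example function; not part of the package.
mysolve(p) = 2 .* p

# ChainRulesCore-style reverse rule (the role a ZygoteRules.@adjoint
# definition used to play): return the primal value plus a pullback that
# maps the output cotangent to cotangents of the inputs.
function ChainRulesCore.rrule(::typeof(mysolve), p)
    y = mysolve(p)
    mysolve_pullback(ȳ) = (NoTangent(), 2 .* unthunk(ȳ))
    return y, mysolve_pullback
end
```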
22 changes: 7 additions & 15 deletions src/steadystate_adjoint.jl
@@ -15,7 +15,7 @@ end
 TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjointSensitivityFunction
 
 function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp, f,
-                                               colorvec, needs_jac)
+        colorvec, needs_jac)
     @unpack p, u0 = sol.prob
 
     diffcache, y = adjointdiffcache(g, sensealg, false, sol, dgdu, dgdp, f, alg;
@@ -30,26 +30,18 @@ function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp
 end
 
 @noinline function SteadyStateAdjointProblem(sol, sensealg::SteadyStateAdjoint, alg,
-                                             dgdu::DG1 = nothing, dgdp::DG2 = nothing,
-                                             g::G = nothing; kwargs...) where {DG1, DG2, G}
+        dgdu::DG1 = nothing, dgdp::DG2 = nothing, g::G = nothing;
+        kwargs...) where {DG1, DG2, G}
     @unpack f, p, u0 = sol.prob
 
-    if sol.prob isa NonlinearProblem
-        f = ODEFunction(f)
-    end
+    sol.prob isa NonlinearProblem && (f = ODEFunction(f))
 
     dgdu === nothing && dgdp === nothing && g === nothing &&
         error("Either `dgdu`, `dgdp`, or `g` must be specified.")
 
-    needs_jac = if has_adjoint(f)
-        false
-        # TODO: What is the correct heuristic? Can we afford to compute Jacobian for
-        # cases where the length(u0) > 50 and if yes till what threshold
-    elseif sensealg.linsolve === nothing
-        length(u0) ≤ 50
-    else
-        LinearSolve.needs_concrete_A(sensealg.linsolve)
-    end
+    needs_jac = ifelse(has_adjoint(f), false,
+        ifelse(sensealg.linsolve === nothing, length(u0) ≤ 50,
+            LinearSolve.needs_concrete_A(sensealg.linsolve)))
 
     p === DiffEqBase.NullParameters() &&
        error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!")
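For readers of this hunk: the new nested `ifelse` packs the old branch logic into one expression. Since Julia's `ifelse` evaluates all of its arguments eagerly, the sketch below is only a rough restatement of the decision it encodes, assuming `has_adjoint` and `LinearSolve.needs_concrete_A` as used in the surrounding code:

```julia
# Hedged sketch of the needs_jac heuristic; not the committed code.
function needs_concrete_jacobian(f, u0, linsolve)
    # An analytic adjoint/vjp attached to f makes an explicit Jacobian unnecessary.
    has_adjoint(f) && return false
    # No linear solver specified: only build a dense Jacobian for small
    # systems (the 50-state cutoff is the commit's heuristic).
    linsolve === nothing && return length(u0) <= 50
    # Otherwise defer to whether the solver needs a materialized matrix A.
    return LinearSolve.needs_concrete_A(linsolve)
end
```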
116 changes: 50 additions & 66 deletions test/HybridNODE.jl
@@ -1,4 +1,5 @@
-using SciMLSensitivity, OrdinaryDiffEq, DiffEqCallbacks, Flux
+using SciMLSensitivity, OrdinaryDiffEq, DiffEqCallbacks, Lux, ComponentArrays
+using Optimization, OptimizationOptimisers
 using Random, Test
 using Zygote
 
@@ -9,10 +10,10 @@ function test_hybridNODE(sensealg)
     t = range(tspan[1], tspan[2], length = datalength)
     target = 3.0 * (1:datalength) ./ datalength # some dummy data to fit to
     cbinput = rand(1, datalength) #some external ODE contribution
-    pmodel = Chain(Dense(2, 10, init = zeros),
-        Dense(10, 2, init = zeros))
-    p, re = Flux.destructure(pmodel)
-    dudt(u, p, t) = re(p)(u)
+    pmodel = Chain(Dense(2, 10, init_weight = zeros32), Dense(10, 2, init_weight = zeros32))
+    ps, st = Lux.setup(Random.default_rng(), pmodel)
+    ps = ComponentArray{Float64}(ps)
+    dudt(u, p, t) = first(pmodel(u, p, st))
 
     # callback changes the first component of the solution every time
     # t is an integer
@@ -27,24 +28,23 @@ function test_hybridNODE(sensealg)
 
     function predict_n_ode(p)
         arr = Array(solve(prob, Tsit5(),
-            p = p, sensealg = sensealg, saveat = 2.0, callback = callback))[1,
-            2:2:end]
+            p = p, sensealg = sensealg, saveat = 2.0, callback = callback))[1, 2:2:end]
         return arr[1:datalength]
     end
 
-    function loss_n_ode()
+    function loss_n_ode(p, _)
         pred = predict_n_ode(p)
         loss = sum(abs2, target .- pred) ./ datalength
     end
 
-    cb = function () #callback function to observe training
-        pred = predict_n_ode(p)
-        display(loss_n_ode())
+    cb = function (p, l) #callback function to observe training
+        @show l
+        return false
     end
     @show sensealg
-    Flux.train!(loss_n_ode, Flux.params(p), Iterators.repeated((), 20), ADAM(0.005),
-        cb = cb)
-    @test loss_n_ode() < 0.5
+    res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()), ps),
+        Adam(0.005); callback = cb, maxiters = 200)
+    @test loss_n_ode(res.u, nothing) < 0.5
     println(" ")
 end
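The recurring pattern in this file is the swap from Flux's destructure/restructure idiom to Lux's explicit parameters. A minimal, self-contained sketch of the new idiom — the model shape here is chosen arbitrarily for illustration:

```julia
using Lux, ComponentArrays, Random

# Old Flux idiom (removed):  p, re = Flux.destructure(model); y = re(p)(x)
# New Lux idiom (added): parameters and state are explicit arguments.
model = Chain(Dense(2, 10, tanh), Dense(10, 2))
ps, st = Lux.setup(Random.default_rng(), model)  # initialize parameters and state
ps = ComponentArray(ps)                          # flat, AD-friendly parameter vector
x = rand(Float32, 2)
y, _ = model(x, ps, st)                          # Lux models return (output, state)
```

Because `ps` is a plain vector-like object rather than implicit globals, it can be handed directly to ODE solvers and optimizers, which is what the rewritten tests rely on.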

@@ -70,14 +70,15 @@ function test_hybridNODE2(sensealg)
     ode_data = Array(sol)[1:2, 1:end]'
 
     ## Make model
-    dudt2 = Chain(Dense(4, 50, tanh),
-        Dense(50, 2))
-    p, re = Flux.destructure(dudt2) # use this p as the initial condition!
+    dudt2 = Chain(Dense(4, 50, tanh), Dense(50, 2))
+    ps, st = Lux.setup(Random.default_rng(), dudt2)
+    ps = ComponentArray{Float32}(ps)
 
     function affect!(integrator)
         integrator.u[3:4] = -3 * integrator.u[1:2]
     end
     function ODEfunc(dx, x, p, t)
-        dx[1:2] .= re(p)(x)
+        dx[1:2] .= first(dudt2(x, p, st))
         dx[3:4] .= 0.0f0
     end
     z0 = u0
@@ -86,34 +87,27 @@ function test_hybridNODE2(sensealg)
         initial_affect = true)
 
     ## Initialize learning functions
-    function predict_n_ode()
-        _prob = remake(prob, p = p)
-        Array(solve(_prob, Tsit5(), u0 = z0, p = p, callback = cb, save_everystep = false,
-            save_start = true, sensealg = sensealg))[1:2,
-            :]
-    end
-    function loss_n_ode()
-        pred = predict_n_ode()[1:2, 1:end]'
+    function predict_n_ode(ps)
+        Array(solve(prob, Tsit5(), u0 = z0, p = ps, callback = cb, save_everystep = false,
+            save_start = true, sensealg = sensealg))[1:2, :]
+    end
+    function loss_n_ode(ps, _)
+        pred = predict_n_ode(ps)[1:2, 1:end]'
         loss = sum(abs2, ode_data .- pred)
         loss
     end
-    loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE
-    cba = function () #callback function to observe training
-        pred = predict_n_ode()[1:2, 1:end]'
-        display(sum(abs2, ode_data .- pred))
+
+    cba = function (p, loss) #callback function to observe training
+        @show loss
+        return false
     end
-    cba()
 
     ## Learn
-    ps = Flux.params(p)
-    data = Iterators.repeated((), 25)
 
     @show sensealg
 
-    Flux.train!(loss_n_ode, ps, data, ADAM(0.0025), cb = cba)
+    res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()), ps),
+        Adam(0.0025); callback = cba, maxiters = 200)
 
-    @test loss_n_ode() < 0.5
+    @test loss_n_ode(res.u, nothing) < 0.5
 
     println(" ")
 end
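Across these tests, `Flux.train!` over implicit `Params` gives way to an Optimization.jl `solve`. A minimal sketch of the replacement loop, assuming a parameters-first loss; the stand-in objective below is arbitrary:

```julia
using Optimization, OptimizationOptimisers, Zygote  # Zygote backs AutoZygote()

# Stand-in objective; the real tests use an ODE-solve-based loss.
loss(p, _) = sum(abs2, p)

optf = OptimizationFunction(loss, Optimization.AutoZygote())
optprob = OptimizationProblem(optf, rand(4))   # second argument: initial parameters
res = solve(optprob, Adam(0.01); maxiters = 100,
    callback = (p, l) -> (println(l); false))  # returning false continues training
res.u  # optimized parameters
```

Unlike `Flux.train!`, the number of iterations is controlled by `maxiters`, and the optimized parameters come back in `res.u` rather than being mutated in place.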
@@ -142,14 +136,16 @@ function test_hybridNODE3(sensealg)
     true_data = reshape(ode_data, (2, length(t), 1))
     true_data = convert.(Float32, true_data)
     callback_data = true_data * 1.0f-3
-    train_dataloader = Flux.Data.DataLoader((true_data = true_data,
-        callback_data = callback_data), batchsize = 1)
-    dudt2 = Chain(Dense(2, 50, tanh),
-        Dense(50, 2))
-    p, re = Flux.destructure(dudt2)
+
+    data = (true_data[:, :, 1], callback_data[:, :, 1])
+    dudt2 = Chain(Dense(2, 50, tanh), Dense(50, 2))
+    ps, st = Lux.setup(Random.default_rng(), dudt2)
+    ps = ComponentArray{Float32}(ps)
 
     function dudt(du, u, p, t)
-        du .= re(p)(u)
+        du .= first(dudt2(u, p, st))
     end
+
     z0 = Float32[2.0; 0.0]
     prob = ODEProblem(dudt, z0, tspan)
 
@@ -159,42 +155,30 @@ function test_hybridNODE3(sensealg)
         DiscreteCallback(condition, affect!, save_positions = (false, false))
     end
 
-    function predict_n_ode(true_data_0, callback_data, sense)
+    function predict_n_ode(p, true_data_0, callback_data, sense)
         _prob = remake(prob, p = p, u0 = true_data_0)
         solve(_prob, Tsit5(), callback = callback_(callback_data), saveat = t,
             sensealg = sense)
     end
 
-    function loss_n_ode(true_data, callback_data)
-        sol = predict_n_ode((vec(true_data[:, 1, :])), callback_data, sensealg)
+    function loss_n_ode(p, (true_data, callback_data))
+        sol = predict_n_ode(p, (vec(true_data[:, 1, :])), callback_data, sensealg)
         pred = Array(sol)
-        loss = Flux.mse((true_data[:, :, 1]), pred)
+        loss = sum(abs2, true_data[:, :, 1] .- pred)
         loss
     end
 
-    ps = Flux.params(p)
-    opt = ADAM(0.1)
-    epochs = 10
-    function cb1(true_data, callback_data)
-        display(loss_n_ode(true_data, callback_data))
+    cba = function (p, loss) #callback function to observe training
+        @show loss
+        return false
     end
 
-    function train!(loss, ps, data, opt, cb)
-        ps = Params(ps)
-        for (true_data, callback_data) in data
-            gs = gradient(ps) do
-                loss(true_data, callback_data)
-            end
-            Flux.update!(opt, ps, gs)
-            cb(true_data, callback_data)
-        end
-        return nothing
-    end
     @show sensealg
 
-    Flux.@epochs epochs train!(loss_n_ode, Params(ps), train_dataloader, opt, cb1)
-    loss = loss_n_ode(true_data[:, :, 1], callback_data)
+    res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()), ps,
+            data), Adam(0.01); maxiters = 200, callback = cba)
+    loss = loss_n_ode(res.u, (true_data, callback_data))
+
     @info loss
     @test loss < 0.5
 end
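One detail worth flagging in the last hunk: the fixed `data` tuple rides along as the `OptimizationProblem`'s third argument and arrives as the second argument of the loss, replacing the old DataLoader-driven loop. A hedged sketch of that plumbing, with a toy loss and arbitrary numbers:

```julia
using Optimization, OptimizationOptimisers, Zygote

# The problem's third argument is handed to the loss as its second
# argument on every evaluation.
loss(p, data) = sum(abs2, p .- data)

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(3), [1.0, 2.0, 3.0])  # params, then data
res = solve(prob, Adam(0.1); maxiters = 500)
res.u  # ≈ [1.0, 2.0, 3.0]
```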

[Diff truncated: the remaining changed files of the 25 are not shown.]
