Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update to support Lux 1.0 #522

Merged
merged 12 commits into from
Sep 21, 2024
3 changes: 2 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
style = "sciml"
format_markdown = true
format_docstrings = true
format_docstrings = true
annotate_untyped_fields_with_any = false
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ jobs:
- name: CompatHelper.main()
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: julia -e 'using CompatHelper; CompatHelper.main(;subdirs=["", "docs", "lib/DataDrivenDMD", "lib/DataDrivenSparse", "lib/DataDrivenSR"])'
run: julia -e 'using CompatHelper; CompatHelper.main(;subdirs=["", "docs", "lib/DataDrivenDMD", "lib/DataDrivenSparse", "lib/DataDrivenSR", "lib/DataDrivenLux"])'
4 changes: 2 additions & 2 deletions lib/DataDrivenDMD/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ end

function (algorithm::AbstractKoopmanAlgorithm)(prob::InternalDataDrivenProblem;
control_input = nothing, kwargs...)
@unpack traindata, testdata, control_idx, options = prob
@unpack abstol = options
(; traindata, testdata, control_idx, options) = prob
(; abstol) = options
# Preprocess control idx, indicates if any control is active in a single basis atom
control_idx = map(any, eachrow(control_idx))
no_controls = .!control_idx
Expand Down
34 changes: 22 additions & 12 deletions lib/DataDrivenLux/Project.toml
Original file line number Diff line number Diff line change
@@ -1,48 +1,58 @@
name = "DataDrivenLux"
uuid = "47881146-99d0-492a-8425-8f2f33327637"
authors = ["JuliusMartensen <[email protected]>"]
version = "0.1.1"
version = "0.2.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataDrivenDiffEq = "2445eb08-9709-466a-b3fc-47e12bd697a2"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IntervalArithmetic = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[compat]
AbstractDifferentiation = "0.4"
AbstractDifferentiation = "0.6"
ChainRulesCore = "1.15"
ComponentArrays = "0.13"
CommonSolve = "0.2.4"
ComponentArrays = "0.15"
ConcreteStructs = "0.2.3"
DataDrivenDiffEq = "1"
Distributions = "0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.9.3"
ForwardDiff = "0.10"
IntervalArithmetic = "0.20"
IntervalArithmetic = "0.22"
InverseFunctions = "0.1"
Lux = "0.4"
NNlib = "0.8"
Lux = "1"
LuxCore = "1"
Optim = "1.7"
Optimisers = "0.2"
Optimisers = "0.3"
ProgressMeter = "1.7"
Reexport = "1.2"
TransformVariables = "0.7"
julia = "1.6"
Setfield = "1"
StatsBase = "0.34.3"
TransformVariables = "0.8"
WeightInitializers = "1"
julia = "1.10"

[extras]
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Expand Down
89 changes: 48 additions & 41 deletions lib/DataDrivenLux/src/DataDrivenLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,43 @@ module DataDrivenLux
using DataDrivenDiffEq

# Load specific (abstract) types
using DataDrivenDiffEq: AbstractBasis
using DataDrivenDiffEq: AbstractDataDrivenAlgorithm
using DataDrivenDiffEq: AbstractDataDrivenResult
using DataDrivenDiffEq: AbstractDataDrivenProblem
using DataDrivenDiffEq: DDReturnCode, ABSTRACT_CONT_PROB, ABSTRACT_DISCRETE_PROB
using DataDrivenDiffEq: InternalDataDrivenProblem
using DataDrivenDiffEq: is_implicit, is_controlled

using DataDrivenDiffEq.DocStringExtensions
using DataDrivenDiffEq.CommonSolve
using DataDrivenDiffEq.CommonSolve: solve!
using DataDrivenDiffEq.StatsBase
using DataDrivenDiffEq.Parameters
using DataDrivenDiffEq.Setfield

using Reexport
@reexport using Optim
using Lux

using InverseFunctions
using TransformVariables
using NNlib
using Distributions
using DistributionsAD

using ChainRulesCore
using ComponentArrays

using IntervalArithmetic
using Random
using Distributed
using ProgressMeter
using Logging
using AbstractDifferentiation, ForwardDiff
using Optimisers
using DataDrivenDiffEq: AbstractBasis, AbstractDataDrivenAlgorithm,
AbstractDataDrivenResult, AbstractDataDrivenProblem, DDReturnCode,
ABSTRACT_CONT_PROB, ABSTRACT_DISCRETE_PROB,
InternalDataDrivenProblem, is_implicit, is_controlled

using DocStringExtensions: DocStringExtensions, FIELDS, TYPEDEF, SIGNATURES
using CommonSolve: CommonSolve, solve!
using ConcreteStructs: @concrete
using Setfield: Setfield, @set!

using Optim: Optim, LBFGS
using Optimisers: Optimisers, Adam

using Lux: Lux, logsoftmax, softmax!
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer
using WeightInitializers: WeightInitializers, ones32, zeros32

using InverseFunctions: InverseFunctions, NoInverse
using TransformVariables: TransformVariables, as, transform_logdensity
using Distributions: Distributions, Distribution, Normal, Uniform, Univariate, dof,
loglikelihood, logpdf, mean, mode, quantile, scale, truncated
using DistributionsAD: DistributionsAD
using StatsBase: StatsBase, aicc, nobs, nullloglikelihood, r2, rss, sum, weights

using ChainRulesCore: @ignore_derivatives
using ComponentArrays: ComponentArrays, ComponentVector

using IntervalArithmetic: IntervalArithmetic, Interval, interval, isempty
using ProgressMeter: ProgressMeter
using AbstractDifferentiation: AbstractDifferentiation
using ForwardDiff: ForwardDiff

using Logging: Logging, NullLogger, with_logger
using Random: Random, AbstractRNG
using Distributed: Distributed, pmap

const AD = AbstractDifferentiation

abstract type AbstractAlgorithmCache <: AbstractDataDrivenResult end
abstract type AbstractDAGSRAlgorithm <: AbstractDataDrivenAlgorithm end
Expand All @@ -62,17 +64,20 @@ export AdditiveError, MultiplicativeError
export ObservedModel

# Simplex
include("./lux/simplex.jl")
include("lux/simplex.jl")
export Softmax, GumbelSoftmax, DirectSimplex

# Nodes and Layers
include("./lux/path_state.jl")
include("lux/path_state.jl")
export PathState
include("./lux/node.jl")

include("lux/node.jl")
export FunctionNode
include("./lux/layer.jl")

include("lux/layer.jl")
export FunctionLayer
include("./lux/graph.jl")

include("lux/graph.jl")
export LayeredDAG

include("caches/dataset.jl")
Expand All @@ -87,6 +92,8 @@ export SearchCache
include("algorithms/rewards.jl")
export RelativeReward, AbsoluteReward

include("algorithms/common.jl")

include("algorithms/randomsearch.jl")
export RandomSearch

Expand All @@ -98,4 +105,4 @@ export CrossEntropy

include("solve.jl")

end # module DataDrivenLux
end
19 changes: 19 additions & 0 deletions lib/DataDrivenLux/src/algorithms/common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
Shared keyword options consumed by the DAG-based symbolic regression algorithms
(`RandomSearch`, `Reinforce`, `CrossEntropy`, ...). Each algorithm's keyword
constructor forwards into this struct; field meanings below are taken from the
per-algorithm docstrings this struct replaced.
"""
@kwdef @concrete struct CommonAlgOptions
    # The number of candidates to track
    populationsize::Int = 100
    # The functions to include in the search
    functions = (sin, exp, cos, log, +, -, /, *)
    # The arities of the functions (must align one-to-one with `functions`)
    arities = (1, 1, 1, 1, 2, 2, 2, 2)
    # The number of layers
    n_layers::Int = 1
    # Include skip layers
    skip::Bool = true
    # Simplex used to parameterize node-selection probabilities
    # (`Softmax`, `GumbelSoftmax`, or `DirectSimplex`; CrossEntropy forces
    # `DirectSimplex()` in its constructor)
    simplex <: AbstractSimplex = Softmax()
    # Evaluation function to sort the samples
    loss = aicc
    # The number of candidates to keep in each iteration
    # (presumably an `Int` count or a `Real` fraction of the population —
    # TODO confirm against the cache implementation)
    keep <: Union{Real, Int} = 0.1
    # Use protected operators
    use_protected::Bool = true
    # Use distributed optimization and resampling
    distributed::Bool = false
    # Use threaded optimization and resampling - not implemented right now.
    threaded::Bool = false
    # Random number generator used for sampling candidates
    rng <: AbstractRNG = Random.default_rng()
    # Optim optimiser
    optimizer = LBFGS()
    # Optim options
    optim_options <: Optim.Options = Optim.Options()
    # Optional Optimisers.jl rule; `nothing` when the algorithm does not use
    # one (e.g. CrossEntropy always passes `nothing`)
    optimiser <: Union{Nothing, Optimisers.AbstractRule} = nothing
    # Observed model - if `nothing` is used, a normally distributed additive
    # error with fixed variance is assumed.
    observed <: Union{ObservedModel, Nothing} = nothing
    # Update parameter for smoothness
    alpha::Real = 0.999f0
end
76 changes: 25 additions & 51 deletions lib/DataDrivenLux/src/algorithms/crossentropy.jl
Original file line number Diff line number Diff line change
@@ -1,59 +1,34 @@
# NOTE(review): this span was garbled scrape residue interleaving the removed
# `@with_kw` definition with its replacement; reconstructed below is the
# post-merge (added-lines) version.
@concrete struct CrossEntropy <: AbstractDAGSRAlgorithm
    options <: CommonAlgOptions
end

"""
$(SIGNATURES)

Uses the crossentropy method for discrete optimization to search the space of possible
solutions.

# Keyword Arguments

  - `populationsize`: the number of candidates to track.
  - `functions` / `arities`: the functions to include in the search and their arities.
  - `n_layers`: the number of layers.
  - `skip`: include skip layers.
  - `loss`: evaluation function to sort the samples.
  - `keep`: the number of candidates to keep in each iteration.
  - `use_protected`: use protected operators.
  - `distributed` / `threaded`: use distributed / threaded optimization and resampling
    (threading is not implemented right now).
  - `rng`: random number generator used for sampling.
  - `optimizer` / `optim_options`: Optim optimiser and its options.
  - `observed`: observed model - if `nothing` is used, a normally distributed additive
    error with fixed variance is assumed.
  - `alpha`: update parameter for smoothness.
"""
function CrossEntropy(; populationsize = 100, functions = (sin, exp, cos, log, +, -, /, *),
        arities = (1, 1, 1, 1, 2, 2, 2, 2), n_layers = 1, skip = true, loss = aicc,
        keep = 0.1, use_protected = true, distributed = false, threaded = false,
        rng = Random.default_rng(), optimizer = LBFGS(), optim_options = Optim.Options(),
        observed = nothing, alpha = 0.999f0)
    # The crossentropy method enforces the direct simplex parameterization and
    # performs its own smoothing update, so no `Optimisers` rule is passed.
    return CrossEntropy(CommonAlgOptions(;
        populationsize, functions, arities, n_layers, skip, simplex = DirectSimplex(), loss,
        keep, use_protected, distributed, threaded, rng, optimizer,
        optim_options, optimiser = nothing, observed, alpha))
end

Base.print(io::IO, ::CrossEntropy) = print(io, "CrossEntropy()")
Base.summary(io::IO, x::CrossEntropy) = print(io, x)

function init_model(x::CrossEntropy, basis::Basis, dataset::Dataset, intervals)
@unpack n_layers, arities, functions, use_protected, skip = x

# We enforce the direct simplex here!
simplex = DirectSimplex()
(; n_layers, arities, functions, use_protected, skip) = x.options

# Get the parameter mapping
variable_mask = map(enumerate(equations(basis))) do (i, eq)
any(ModelingToolkit.isvariable, ModelingToolkit.get_variables(eq.rhs)) &&
IntervalArithmetic.iscommon(intervals[i])
return any(ModelingToolkit.isvariable, ModelingToolkit.get_variables(eq.rhs)) &&
IntervalArithmetic.iscommon(intervals[i])
end

variable_mask = Any[variable_mask...]
Expand All @@ -63,15 +38,14 @@ function init_model(x::CrossEntropy, basis::Basis, dataset::Dataset, intervals)
end

return LayeredDAG(length(basis), size(dataset.y, 1), n_layers, arities, functions;
skip = skip, input_functions = variable_mask, simplex = simplex)
skip, input_functions = variable_mask, x.options.simplex)
end

"""
$(SIGNATURES)

Smooth the search distribution parameters of `cache` towards the mean
configuration of the surviving candidates:
`p ← α p + (1 - α) p̄`, where `α = cache.alg.options.alpha`.

Mutates `cache.p` in place and returns `nothing`.
"""
function update_parameters!(cache::SearchCache{<:CrossEntropy})
    # Mean configuration over the kept (elite) candidates.
    p̄ = mean(map(cache.candidates[cache.keeps]) do candidate
        return ComponentVector(get_configuration(candidate.model.model, cache.p, candidate.st))
    end)
    alpha = cache.alg.options.alpha
    # `true` is Julia's strong multiplicative identity inside the fused
    # broadcast, so `true - alpha` promotes to `alpha`'s type.
    @. cache.p = alpha * cache.p + (true - alpha) * p̄
    return
end
Loading
Loading