From ffda4aa098650a0d5d3736392c3549655d597bbc Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 18 Sep 2024 20:53:41 -0400
Subject: [PATCH] feat: update to support Lux 1.0

---
 .github/workflows/CompatHelper.yml     |  2 +-
 lib/DataDrivenLux/Project.toml         | 18 +++++++++-----
 lib/DataDrivenLux/src/DataDrivenLux.jl |  5 ++++
 lib/DataDrivenLux/src/lux/graph.jl     |  6 ++---
 lib/DataDrivenLux/src/lux/layer.jl     |  3 +--
 lib/DataDrivenLux/src/lux/node.jl      | 34 +++++++++++---------------
 lib/DataDrivenLux/src/lux/simplex.jl   |  6 ++---
 7 files changed, 38 insertions(+), 36 deletions(-)

diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml
index 2de58423f..a1461cf05 100644
--- a/.github/workflows/CompatHelper.yml
+++ b/.github/workflows/CompatHelper.yml
@@ -23,4 +23,4 @@ jobs:
       - name: CompatHelper.main()
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: julia -e 'using CompatHelper; CompatHelper.main(;subdirs=["", "docs", "lib/DataDrivenDMD", "lib/DataDrivenSparse", "lib/DataDrivenSR"])'
+        run: julia -e 'using CompatHelper; CompatHelper.main(;subdirs=["", "docs", "lib/DataDrivenDMD", "lib/DataDrivenSparse", "lib/DataDrivenSR", "lib/DataDrivenLux"])'
diff --git a/lib/DataDrivenLux/Project.toml b/lib/DataDrivenLux/Project.toml
index 78591e64b..de66c583d 100644
--- a/lib/DataDrivenLux/Project.toml
+++ b/lib/DataDrivenLux/Project.toml
@@ -1,12 +1,13 @@
 name = "DataDrivenLux"
 uuid = "47881146-99d0-492a-8425-8f2f33327637"
 authors = ["JuliusMartensen "]
-version = "0.1.1"
+version = "0.2.0"
 
 [deps]
 AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 DataDrivenDiffEq = "2445eb08-9709-466a-b3fc-47e12bd697a2"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -17,6 +18,7 @@ InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
@@ -24,25 +26,29 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
+WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
 
 [compat]
 AbstractDifferentiation = "0.4"
 ChainRulesCore = "1.15"
-ComponentArrays = "0.13"
+ComponentArrays = "0.15"
+ConcreteStructs = "0.2.3"
 DataDrivenDiffEq = "1"
 Distributions = "0.25"
 DistributionsAD = "0.6"
 ForwardDiff = "0.10"
 IntervalArithmetic = "0.20"
 InverseFunctions = "0.1"
-Lux = "0.4"
-NNlib = "0.8"
+Lux = "1"
+LuxCore = "1"
+NNlib = "0.9"
 Optim = "1.7"
-Optimisers = "0.2"
+Optimisers = "0.3"
 ProgressMeter = "1.7"
 Reexport = "1.2"
 TransformVariables = "0.7"
-julia = "1.6"
+WeightInitializers = "1.0.3"
+julia = "1.10"
 
 [extras]
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
diff --git a/lib/DataDrivenLux/src/DataDrivenLux.jl b/lib/DataDrivenLux/src/DataDrivenLux.jl
index f34622851..dc9af4523 100644
--- a/lib/DataDrivenLux/src/DataDrivenLux.jl
+++ b/lib/DataDrivenLux/src/DataDrivenLux.jl
@@ -18,9 +18,14 @@ using DataDrivenDiffEq.StatsBase
 using DataDrivenDiffEq.Parameters
 using DataDrivenDiffEq.Setfield
 
+using ConcreteStructs: @concrete
+
 using Reexport
 @reexport using Optim
+
 using Lux
+using LuxCore
+using WeightInitializers
 using InverseFunctions
 using TransformVariables
 
diff --git a/lib/DataDrivenLux/src/lux/graph.jl b/lib/DataDrivenLux/src/lux/graph.jl
index 2f6057293..b4269e318 100644
--- a/lib/DataDrivenLux/src/lux/graph.jl
+++ b/lib/DataDrivenLux/src/lux/graph.jl
@@ -7,13 +7,13 @@ different [`DecisionLayer`](@ref)s.
 # Fields
 $(FIELDS)
 """
-struct LayeredDAG{T} <: Lux.AbstractExplicitContainerLayer{(:layers,)}
-    layers::T
+@concrete struct LayeredDAG <: AbstractLuxWrapperLayer{:layers}
+    layers
 end
 
 function LayeredDAG(in_dimension::Int, out_dimension::Int, n_layers::Int,
         fs::Vector{Pair{Function, Int}}; kwargs...)
-    LayeredDAG(in_dimension, out_dimension, n_layers, tuple(last.(fs)...),
+    return LayeredDAG(in_dimension, out_dimension, n_layers, tuple(last.(fs)...),
         tuple(first.(fs)...); kwargs...)
 end
 
diff --git a/lib/DataDrivenLux/src/lux/layer.jl b/lib/DataDrivenLux/src/lux/layer.jl
index 17034837a..67688da11 100644
--- a/lib/DataDrivenLux/src/lux/layer.jl
+++ b/lib/DataDrivenLux/src/lux/layer.jl
@@ -7,8 +7,7 @@ It accumulates all outputs of the nodes.
 # Fields
 $(FIELDS)
 """
-struct FunctionLayer{skip, T, output_dimension} <:
-       Lux.AbstractExplicitContainerLayer{(:nodes,)}
+struct FunctionLayer{skip, T, output_dimension} <: AbstractLuxWrapperLayer{:nodes}
     nodes::T
 end
 
diff --git a/lib/DataDrivenLux/src/lux/node.jl b/lib/DataDrivenLux/src/lux/node.jl
index da81ebd95..2594a5564 100644
--- a/lib/DataDrivenLux/src/lux/node.jl
+++ b/lib/DataDrivenLux/src/lux/node.jl
@@ -9,15 +9,15 @@ and a latent array of weights representing a probability distribution over the i
 $(FIELDS)
 """
-struct FunctionNode{skip, ID, F, S} <: Lux.AbstractExplicitLayer
+@concrete struct FunctionNode{skip, ID} <: AbstractLuxLayer
     "Function which should map from in_dims ↦ R"
-    f::F
+    f
     "Arity of the function"
     arity::Int
     "Input dimensions of the signal"
     in_dims::Int
     "Mapping to the unit simplex"
-    simplex::S
+    simplex
     "Masking of the input values"
     input_mask::Vector{Bool}
 end
@@ -53,21 +53,19 @@ end
 
 get_id(::FunctionNode{<:Any, id}) where {id} = id
 
-function Lux.initialparameters(rng::AbstractRNG, l::FunctionNode)
+function LuxCore.initialparameters(rng::AbstractRNG, l::FunctionNode)
     return (; weights = init_weights(l.simplex, rng, sum(l.input_mask), l.arity))
 end
 
-function Lux.initialstates(rng::AbstractRNG, p::FunctionNode)
-    begin
-        rand(rng)
-        rng_ = Lux.replicate(rng)
-        # Call once
-        (;
-            priors = init_weights(p.simplex, rng, sum(p.input_mask), p.arity),
-            active_inputs = zeros(Int, p.arity),
-            temperature = 1.0f0,
-            rng = rng_)
-    end
+function LuxCore.initialstates(rng::AbstractRNG, p::FunctionNode)
+    rand(rng)
+    rng_ = LuxCore.replicate(rng)
+    # Call once
+    return (;
+        priors = init_weights(p.simplex, rng, sum(p.input_mask), p.arity),
+        active_inputs = zeros(Int, p.arity),
+        temperature = 1.0f0,
+        rng = rng_)
 end
 
 function update_state(p::FunctionNode, ps, st)
@@ -79,11 +77,7 @@ function update_state(p::FunctionNode, ps, st)
         active_inputs[i] = findfirst(rand(rng) .<= cumsum(priors[:, i]))
     end
 
-    (;
-        priors = priors,
-        active_inputs = active_inputs,
-        temperature = temperature,
-        rng = rng)
+    return (; priors, active_inputs, temperature, rng)
 end
 
 function (l::FunctionNode{false})(x::AbstractArray{<:Number}, ps, st::NamedTuple)
diff --git a/lib/DataDrivenLux/src/lux/simplex.jl b/lib/DataDrivenLux/src/lux/simplex.jl
index 732986f65..b9d58dfdd 100644
--- a/lib/DataDrivenLux/src/lux/simplex.jl
+++ b/lib/DataDrivenLux/src/lux/simplex.jl
@@ -1,6 +1,4 @@
-function init_weights(::AbstractSimplex, rng::Random.AbstractRNG, dims...)
-    Lux.zeros32(rng, dims...)
-end
+init_weights(::AbstractSimplex, rng::Random.AbstractRNG, dims...) = zeros32(rng, dims...)
 
 """
 $(TYPEDEF)
@@ -53,7 +51,7 @@ function (::DirectSimplex)(rng::Random.AbstractRNG, x̂::AbstractVector, x::Abst
 end
 
 function init_weights(::DirectSimplex, rng::Random.AbstractRNG, dims...)
-    w = Lux.ones32(rng, dims...)
+    w = ones32(rng, dims...)
     w ./= first(dims)
     w
 end
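Reviewer note (not part of the patch): the recurring pattern above is the Lux 0.4 -> 1.0
interface migration. `Lux.AbstractExplicitLayer` becomes `LuxCore.AbstractLuxLayer`,
`Lux.AbstractExplicitContainerLayer{(:layers,)}` becomes `AbstractLuxWrapperLayer{:layers}`,
the `initialparameters`/`initialstates` overloads move to LuxCore, and `zeros32`/`ones32`
now come from WeightInitializers rather than Lux. A minimal sketch of the same pattern on a
hypothetical `ScaleLayer` (the layer, its fields, and the usage below are illustrative only):

    using LuxCore, WeightInitializers, Random
    using ConcreteStructs: @concrete

    # Lux 0.4 would spell this `struct ScaleLayer{F} <: Lux.AbstractExplicitLayer`.
    # Under Lux 1.0 we subtype `AbstractLuxLayer`; `@concrete` generates the type
    # parameter for the untyped `init` field, mirroring `FunctionNode` in the patch.
    @concrete struct ScaleLayer <: AbstractLuxLayer
        dims::Int
        init
    end

    # The parameter interface is overloaded on LuxCore, no longer on Lux itself.
    function LuxCore.initialparameters(rng::AbstractRNG, l::ScaleLayer)
        return (; scale = l.init(rng, l.dims))
    end

    # Stateless layer: `initialstates` falls back to an empty NamedTuple.
    (l::ScaleLayer)(x, ps, st) = (ps.scale .* x, st)

    rng = Random.default_rng()
    layer = ScaleLayer(4, ones32)  # `ones32` comes from WeightInitializers, not Lux
    ps, st = LuxCore.setup(rng, layer)
    y, _ = layer(randn(rng, Float32, 4), ps, st)

The same reasoning explains why `FunctionNode` and `LayeredDAG` lose their explicit
`{F, S}`/`{T}` parameters in this patch: `@concrete` supplies them, keeping the fields
concretely typed without the boilerplate.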