feat: update to support Lux 1.0
avik-pal committed Sep 19, 2024
1 parent 08d630d commit ffda4aa
Showing 7 changed files with 38 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
@@ -23,4 +23,4 @@ jobs:
       - name: CompatHelper.main()
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: julia -e 'using CompatHelper; CompatHelper.main(;subdirs=["", "docs", "lib/DataDrivenDMD", "lib/DataDrivenSparse", "lib/DataDrivenSR"])'
+        run: julia -e 'using CompatHelper; CompatHelper.main(;subdirs=["", "docs", "lib/DataDrivenDMD", "lib/DataDrivenSparse", "lib/DataDrivenSR", "lib/DataDrivenLux"])'
18 changes: 12 additions & 6 deletions lib/DataDrivenLux/Project.toml
@@ -1,12 +1,13 @@
 name = "DataDrivenLux"
 uuid = "47881146-99d0-492a-8425-8f2f33327637"
 authors = ["JuliusMartensen <[email protected]>"]
-version = "0.1.1"
+version = "0.2.0"
 
 [deps]
 AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 DataDrivenDiffEq = "2445eb08-9709-466a-b3fc-47e12bd697a2"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -17,32 +18,37 @@ InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
+WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
 
 [compat]
 AbstractDifferentiation = "0.4"
 ChainRulesCore = "1.15"
-ComponentArrays = "0.13"
+ComponentArrays = "0.15"
+ConcreteStructs = "0.2.3"
 DataDrivenDiffEq = "1"
 Distributions = "0.25"
 DistributionsAD = "0.6"
 ForwardDiff = "0.10"
 IntervalArithmetic = "0.20"
 InverseFunctions = "0.1"
-Lux = "0.4"
-NNlib = "0.8"
+Lux = "1"
+LuxCore = "1"
+NNlib = "0.9"
 Optim = "1.7"
-Optimisers = "0.2"
+Optimisers = "0.3"
 ProgressMeter = "1.7"
 Reexport = "1.2"
 TransformVariables = "0.7"
-julia = "1.6"
+WeightInitializers = "1.0.3"
+julia = "1.10"
 
 [extras]
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
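Downstream environments pick this release up through the bumped [compat] bounds above. A hedged sketch of updating a depending project (the version numbers mirror the new Project.toml; nothing here is part of the commit itself):

```julia
using Pkg

# DataDrivenLux 0.2.0 requires Julia >= 1.10 and the Lux 1.x stack
Pkg.add(; name = "DataDrivenLux", version = "0.2.0")
```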
5 changes: 5 additions & 0 deletions lib/DataDrivenLux/src/DataDrivenLux.jl
@@ -18,9 +18,14 @@ using DataDrivenDiffEq.StatsBase
 using DataDrivenDiffEq.Parameters
 using DataDrivenDiffEq.Setfield
 
+using ConcreteStructs: @concrete
+
 using Reexport
 @reexport using Optim
 
 using Lux
+using LuxCore
+using WeightInitializers
+
 using InverseFunctions
 using TransformVariables
6 changes: 3 additions & 3 deletions lib/DataDrivenLux/src/lux/graph.jl
@@ -7,13 +7,13 @@ different [`DecisionLayer`](@ref)s.
 # Fields
 $(FIELDS)
 """
-struct LayeredDAG{T} <: Lux.AbstractExplicitContainerLayer{(:layers,)}
-    layers::T
+@concrete struct LayeredDAG <: AbstractLuxWrapperLayer{:layers}
+    layers
 end
 
 function LayeredDAG(in_dimension::Int, out_dimension::Int, n_layers::Int,
         fs::Vector{Pair{Function, Int}}; kwargs...)
-    LayeredDAG(in_dimension, out_dimension, n_layers, tuple(last.(fs)...),
+    return LayeredDAG(in_dimension, out_dimension, n_layers, tuple(last.(fs)...),
         tuple(first.(fs)...); kwargs...)
 end
 
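For readers migrating their own layers: in Lux 1.0 a single-field container subtypes `AbstractLuxWrapperLayer{:field}` instead of `Lux.AbstractExplicitContainerLayer{(:field,)}`, and ConcreteStructs' `@concrete` supplies the type parameter that was previously spelled out by hand. A minimal sketch of the same pattern with a hypothetical `MyWrapper` layer (not from this commit):

```julia
using Lux, ConcreteStructs, Random

# AbstractLuxWrapperLayer{:layers} forwards parameter/state setup and the
# forward pass to the `layers` field; @concrete generates the type parameter.
@concrete struct MyWrapper <: AbstractLuxWrapperLayer{:layers}
    layers
end

model = MyWrapper(Chain(Dense(2 => 8, tanh), Dense(8 => 1)))
ps, st = Lux.setup(Random.default_rng(), model)
y, _ = model(randn(Float32, 2, 16), ps, st)
```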
3 changes: 1 addition & 2 deletions lib/DataDrivenLux/src/lux/layer.jl
@@ -7,8 +7,7 @@ It accumulates all outputs of the nodes.
 # Fields
 $(FIELDS)
 """
-struct FunctionLayer{skip, T, output_dimension} <:
-       Lux.AbstractExplicitContainerLayer{(:nodes,)}
+struct FunctionLayer{skip, T, output_dimension} <: AbstractLuxWrapperLayer{:nodes}
     nodes::T
 end
 
34 changes: 14 additions & 20 deletions lib/DataDrivenLux/src/lux/node.jl
@@ -9,15 +9,15 @@ and a latent array of weights representing a probability distribution over the i
 $(FIELDS)
 """
-struct FunctionNode{skip, ID, F, S} <: Lux.AbstractExplicitLayer
+@concrete struct FunctionNode{skip, ID} <: AbstractLuxLayer
     "Function which should map from in_dims ↦ R"
-    f::F
+    f
     "Arity of the function"
     arity::Int
     "Input dimensions of the signal"
     in_dims::Int
     "Mapping to the unit simplex"
-    simplex::S
+    simplex
     "Masking of the input values"
     input_mask::Vector{Bool}
 end
 
@@ -53,21 +53,19 @@ end
 
 get_id(::FunctionNode{<:Any, id}) where {id} = id
 
-function Lux.initialparameters(rng::AbstractRNG, l::FunctionNode)
+function LuxCore.initialparameters(rng::AbstractRNG, l::FunctionNode)
     return (; weights = init_weights(l.simplex, rng, sum(l.input_mask), l.arity))
 end
 
-function Lux.initialstates(rng::AbstractRNG, p::FunctionNode)
-    begin
-        rand(rng)
-        rng_ = Lux.replicate(rng)
-        # Call once
-        (;
-            priors = init_weights(p.simplex, rng, sum(p.input_mask), p.arity),
-            active_inputs = zeros(Int, p.arity),
-            temperature = 1.0f0,
-            rng = rng_)
-    end
+function LuxCore.initialstates(rng::AbstractRNG, p::FunctionNode)
+    rand(rng)
+    rng_ = LuxCore.replicate(rng)
+    # Call once
+    return (;
+        priors = init_weights(p.simplex, rng, sum(p.input_mask), p.arity),
+        active_inputs = zeros(Int, p.arity),
+        temperature = 1.0f0,
+        rng = rng_)
 end
 
 function update_state(p::FunctionNode, ps, st)
@@ -79,11 +77,7 @@ function update_state(p::FunctionNode, ps, st)
         active_inputs[i] = findfirst(rand(rng) .<= cumsum(priors[:, i]))
     end
 
-    (;
-        priors = priors,
-        active_inputs = active_inputs,
-        temperature = temperature,
-        rng = rng)
+    return (; priors, active_inputs, temperature, rng)
 end
 
 function (l::FunctionNode{false})(x::AbstractArray{<:Number}, ps, st::NamedTuple)
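The `FunctionNode` rewrite follows the Lux 1.0 convention for leaf layers: subtype `AbstractLuxLayer` and declare parameters and states by overloading `LuxCore.initialparameters` and `LuxCore.initialstates`. A minimal sketch of that pattern, using a hypothetical `Scale` layer rather than anything from this commit:

```julia
using LuxCore, Random

# Leaf layer with one trainable vector; the state carries a replicated RNG,
# mirroring how FunctionNode keeps its own sampling stream above.
struct Scale <: LuxCore.AbstractLuxLayer
    dims::Int
end

function LuxCore.initialparameters(rng::Random.AbstractRNG, l::Scale)
    return (; scale = randn(rng, Float32, l.dims))
end

function LuxCore.initialstates(rng::Random.AbstractRNG, ::Scale)
    return (; rng = LuxCore.replicate(rng))
end

(l::Scale)(x, ps, st) = (ps.scale .* x, st)

layer = Scale(4)
ps, st = LuxCore.setup(Random.default_rng(), layer)
y, st = layer(ones(Float32, 4), ps, st)
```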
6 changes: 2 additions & 4 deletions lib/DataDrivenLux/src/lux/simplex.jl
@@ -1,6 +1,4 @@
-function init_weights(::AbstractSimplex, rng::Random.AbstractRNG, dims...)
-    Lux.zeros32(rng, dims...)
-end
+init_weights(::AbstractSimplex, rng::Random.AbstractRNG, dims...) = zeros32(rng, dims...)
 
 """
 $(TYPEDEF)
@@ -53,7 +51,7 @@ function (::DirectSimplex)(rng::Random.AbstractRNG, x̂::AbstractVector, x::Abst
 end
 
 function init_weights(::DirectSimplex, rng::Random.AbstractRNG, dims...)
-    w = Lux.ones32(rng, dims...)
+    w = ones32(rng, dims...)
     w ./= first(dims)
     w
 end
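`zeros32` and `ones32` are now taken from WeightInitializers (added as a dependency above) instead of `Lux.zeros32` / `Lux.ones32`. A quick usage sketch, assuming only the RNG-first signature this diff relies on:

```julia
using WeightInitializers, Random

rng = Random.default_rng()
w0 = zeros32(rng, 3, 2)      # 3x2 Matrix{Float32} of zeros
w1 = ones32(rng, 3, 2) ./ 3  # uniform weights, as init_weights does for DirectSimplex
```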