diff --git a/lib/DataDrivenLux/src/lux/layer.jl b/lib/DataDrivenLux/src/lux/layer.jl index 1abc1579..fba54b32 100644 --- a/lib/DataDrivenLux/src/lux/layer.jl +++ b/lib/DataDrivenLux/src/lux/layer.jl @@ -7,8 +7,9 @@ It accumulates all outputs of the nodes. # Fields $(FIELDS) """ -struct FunctionLayer{skip, T, output_dimension} <: AbstractLuxWrapperLayer{:nodes} - nodes::T +@concrete struct FunctionLayer <: AbstractLuxWrapperLayer{:nodes} + nodes + skip end function FunctionLayer( @@ -17,40 +18,30 @@ function FunctionLayer( nodes = map(eachindex(arities)) do i # We check if we have an inverse here return FunctionNode(fs[i], arities[i], in_dimension, (id_offset, i); - input_functions = input_functions, kwargs...) + input_functions, kwargs...) end - - output_dimension = length(arities) - output_dimension += skip ? in_dimension : 0 - - names = map(gensym ∘ string, fs) - nodes = NamedTuple{names}(nodes) - return FunctionLayer{skip, typeof(nodes), output_dimension}(nodes) -end - -function (r::FunctionLayer)(x, ps, st) - return _apply_layer(r.nodes, x, ps, st) + inner_model = Lux.Chain( + Lux.BranchLayer(nodes...), Lux.WrappedFunction(Base.Fix1(reduce, vcat))) + return FunctionLayer( + skip ? 
Lux.Parallel(vcat, inner_model, Lux.NoOpLayer()) : inner_model, skip) end -function (r::FunctionLayer{true})(x, ps, st) -    y, st = _apply_layer(r.nodes, x, ps, st) -    return vcat(y, x), st -end - -Base.keys(m::FunctionLayer) = Base.keys(getfield(m, :nodes)) - -Base.getindex(c::FunctionLayer, i::Int) = c.nodes[i] - -Base.length(c::FunctionLayer) = length(c.nodes) -Base.lastindex(c::FunctionLayer) = lastindex(c.nodes) -Base.firstindex(c::FunctionLayer) = firstindex(c.nodes) - function get_loglikelihood(r::FunctionLayer, ps, st) -    return _get_layer_loglikelihood(r.nodes, ps, st) +    if r.skip +        return _get_layer_loglikelihood( +            r.nodes.layers[1].layers[1].layers, ps.layer_1.layer_1, st.layer_1.layer_1) +    else +        return _get_layer_loglikelihood(r.nodes.layers[1].layers, ps.layer_1, st.layer_1) +    end end function get_configuration(r::FunctionLayer, ps, st) -    return _get_configuration(r.nodes, ps, st) +    if r.skip +        return _get_configuration( +            r.nodes.layers[1].layers[1].layers, ps.layer_1.layer_1, st.layer_1.layer_1) +    else +        return _get_configuration(r.nodes.layers[1].layers, ps.layer_1, st.layer_1) +    end end @generated function _get_layer_loglikelihood( @@ -72,15 +63,3 @@ end push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),))))) return Expr(:block, calls...) end - -@generated function _apply_layer( - layers::NamedTuple{fields}, x, ps, st::NamedTuple{fields}) where {fields} - N = length(fields) - y_symbols = vcat([gensym() for _ in 1:N]) - st_symbols = [gensym() for _ in 1:N] - calls = [:(($(y_symbols[i]), $(st_symbols[i])) = Lux.apply( - layers.$(fields[i]), x, ps.$(fields[i]), st.$(fields[i]))) for i in 1:N] - push!(calls, :(st = NamedTuple{$fields}(($(Tuple(st_symbols)...),)))) - push!(calls, :(return vcat($(y_symbols...)), st)) - return Expr(:block, calls...) 
-end diff --git a/lib/DataDrivenLux/src/lux/node.jl b/lib/DataDrivenLux/src/lux/node.jl index 07fb020b..35bf3248 100644 --- a/lib/DataDrivenLux/src/lux/node.jl +++ b/lib/DataDrivenLux/src/lux/node.jl @@ -49,7 +49,7 @@ function FunctionNode(f::F, arity::Int, input_dimension::Int, to input_mask" internal_node = InternalFunctionNode{id}(f, arity, input_dimension, simplex, input_mask) - node = skip ? Lux.Parallel(vcat, internal_node, Lux, NoOpLayer()) : internal_node + node = skip ? Lux.Parallel(vcat, internal_node, Lux.NoOpLayer()) : internal_node return FunctionNode(node) end @@ -78,13 +78,13 @@ end end function (l::InternalFunctionNode)(x::AbstractMatrix, ps, st) - return mapreduce(hcat, eachcol(x)) do xi - return LuxCore.apply(l, xi, ps, st) - end + m = Lux.StatefulLuxLayer{true}(l, ps, st) + z = map(m, eachcol(x)) + return reduce(hcat, z), m.st end function (l::InternalFunctionNode)(x::AbstractVector, ps, st) - return l.f(get_masked_inputs(l, x, ps, st)...) + return l.f(get_masked_inputs(l, x, ps, st)...), st end function (l::InternalFunctionNode)(x::AbstractVector{<:AbstractPathState}, ps, st) diff --git a/lib/DataDrivenLux/test/layers.jl b/lib/DataDrivenLux/test/layers.jl index ab452800..46173997 100644 --- a/lib/DataDrivenLux/test/layers.jl +++ b/lib/DataDrivenLux/test/layers.jl @@ -13,20 +13,24 @@ using StableRNGs arities = (1, 2, 3) x = randn(3) X = randn(3, 10) + layer = FunctionLayer(3, arities, fs, id_offset = 2) rng = StableRNG(43) ps, st = Lux.setup(rng, layer) layer_states, new_st = layer(states, ps, st) @test all(exp.(values(DataDrivenLux.get_loglikelihood(layer, ps, new_st))) .≈ (1 / 3, 1 / 9, 1 / 27)) - @test map(DataDrivenLux.get_interval, layer_states) == - [interval(-1, 1), interval(-20, 20), interval(-110, 110)] - @test length(layer) == 3 - @test length(keys(layer)) == 3 + + intervals = map(DataDrivenLux.get_interval, layer_states) + @test isequal_interval(intervals[1], interval(-1, 1)) + @test isequal_interval(intervals[2], interval(-20, 20)) + 
@test isequal_interval(intervals[3], interval(-110, 110)) + y, _ = layer(x, ps, new_st) Y, _ = layer(X, ps, new_st) @test y == [sin(x[1]); x[3] + x[1]; x[1] * x[3] - x[3]] @test Y == [sin.(X[1:1, :]); X[3:3, :] + X[1:1, :]; X[1:1, :] .* X[3:3, :] - X[3:3, :]] + fs = (sin, cos, log, exp, +, -, *) @test DataDrivenLux.mask_inverse(log, 1, collect(fs)) == [1, 1, 1, 0, 1, 1, 1] @test DataDrivenLux.mask_inverse(exp, 1, collect(fs)) == [1, 1, 0, 1, 1, 1, 1]