Skip to content

Commit

Permalink
fix: more lux tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 19, 2024
1 parent 7fc7bcd commit 8c7bf52
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 50 deletions.
61 changes: 20 additions & 41 deletions lib/DataDrivenLux/src/lux/layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ It accumulates all outputs of the nodes.
# Fields
$(FIELDS)
"""
struct FunctionLayer{skip, T, output_dimension} <: AbstractLuxWrapperLayer{:nodes}
nodes::T
@concrete struct FunctionLayer <: AbstractLuxWrapperLayer{:nodes}
nodes
skip
end

function FunctionLayer(
Expand All @@ -17,40 +18,30 @@ function FunctionLayer(
nodes = map(eachindex(arities)) do i
# We check if we have an inverse here
return FunctionNode(fs[i], arities[i], in_dimension, (id_offset, i);
input_functions = input_functions, kwargs...)
input_functions, kwargs...)
end

output_dimension = length(arities)
output_dimension += skip ? in_dimension : 0

names = map(gensym string, fs)
nodes = NamedTuple{names}(nodes)
return FunctionLayer{skip, typeof(nodes), output_dimension}(nodes)
end

function (r::FunctionLayer)(x, ps, st)
return _apply_layer(r.nodes, x, ps, st)
inner_model = Lux.Chain(
Lux.BranchLayer(nodes...), Lux.WrappedFunction(Base.Fix1(reduce, vcat)))
return FunctionLayer(
skip ? Lux.Parallel(vcat, inner_model, Lux.NoOpLayer()) : inner_model, skip)
end

function (r::FunctionLayer{true})(x, ps, st)
y, st = _apply_layer(r.nodes, x, ps, st)
return vcat(y, x), st
end

Base.keys(m::FunctionLayer) = Base.keys(getfield(m, :nodes))

Base.getindex(c::FunctionLayer, i::Int) = c.nodes[i]

Base.length(c::FunctionLayer) = length(c.nodes)
Base.lastindex(c::FunctionLayer) = lastindex(c.nodes)
Base.firstindex(c::FunctionLayer) = firstindex(c.nodes)

function get_loglikelihood(r::FunctionLayer, ps, st)
return _get_layer_loglikelihood(r.nodes, ps, st)
if r.skip
return _get_layer_loglikelihood(
r.nodes.layers[1].layers[1], ps.layer_1.layer_1, st.layer_1.layer_1)
else
return _get_layer_loglikelihood(r.nodes.layers[1].layers, ps.layer_1, st.layer_1)
end
end

function get_configuration(r::FunctionLayer, ps, st)
return _get_configuration(r.nodes, ps, st)
if r.skip
return _get_configuration(
r.nodes.layers[1].layers[1], ps.layer_1.layer_1, st.layer_1.layer_1)
else
return _get_configuration(r.nodes.layers[1].layers, ps.layer_1, st.layer_1)
end
end

@generated function _get_layer_loglikelihood(
Expand All @@ -72,15 +63,3 @@ end
push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),)))))
return Expr(:block, calls...)
end

@generated function _apply_layer(
layers::NamedTuple{fields}, x, ps, st::NamedTuple{fields}) where {fields}
N = length(fields)
y_symbols = vcat([gensym() for _ in 1:N])
st_symbols = [gensym() for _ in 1:N]
calls = [:(($(y_symbols[i]), $(st_symbols[i])) = Lux.apply(
layers.$(fields[i]), x, ps.$(fields[i]), st.$(fields[i]))) for i in 1:N]
push!(calls, :(st = NamedTuple{$fields}(($(Tuple(st_symbols)...),))))
push!(calls, :(return vcat($(y_symbols...)), st))
return Expr(:block, calls...)
end
10 changes: 5 additions & 5 deletions lib/DataDrivenLux/src/lux/node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function FunctionNode(f::F, arity::Int, input_dimension::Int,
to input_mask"

internal_node = InternalFunctionNode{id}(f, arity, input_dimension, simplex, input_mask)
node = skip ? Lux.Parallel(vcat, internal_node, Lux, NoOpLayer()) : internal_node
node = skip ? Lux.Parallel(vcat, internal_node, Lux.NoOpLayer()) : internal_node
return FunctionNode(node)
end

Expand Down Expand Up @@ -78,13 +78,13 @@ end
end

function (l::InternalFunctionNode)(x::AbstractMatrix, ps, st)
return mapreduce(hcat, eachcol(x)) do xi
return LuxCore.apply(l, xi, ps, st)
end
m = Lux.StatefulLuxLayer{true}(l, ps, st)
z = map(m, eachcol(x))
return reduce(hcat, z), m.st
end

function (l::InternalFunctionNode)(x::AbstractVector, ps, st)
return l.f(get_masked_inputs(l, x, ps, st)...)
return l.f(get_masked_inputs(l, x, ps, st)...), st
end

function (l::InternalFunctionNode)(x::AbstractVector{<:AbstractPathState}, ps, st)
Expand Down
12 changes: 8 additions & 4 deletions lib/DataDrivenLux/test/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,24 @@ using StableRNGs
arities = (1, 2, 3)
x = randn(3)
X = randn(3, 10)

layer = FunctionLayer(3, arities, fs, id_offset = 2)
rng = StableRNG(43)
ps, st = Lux.setup(rng, layer)
layer_states, new_st = layer(states, ps, st)
@test all(exp.(values(DataDrivenLux.get_loglikelihood(layer, ps, new_st))) .≈
(1 / 3, 1 / 9, 1 / 27))
@test map(DataDrivenLux.get_interval, layer_states) ==
[interval(-1, 1), interval(-20, 20), interval(-110, 110)]
@test length(layer) == 3
@test length(keys(layer)) == 3

intervals = map(DataDrivenLux.get_interval, layer_states)
@test isequal_interval(intervals[1], interval(-1, 1))
@test isequal_interval(intervals[2], interval(-20, 20))
@test isequal_interval(intervals[3], interval(-110, 110))

y, _ = layer(x, ps, new_st)
Y, _ = layer(X, ps, new_st)
@test y == [sin(x[1]); x[3] + x[1]; x[1] * x[3] - x[3]]
@test Y == [sin.(X[1:1, :]); X[3:3, :] + X[1:1, :]; X[1:1, :] .* X[3:3, :] - X[3:3, :]]

fs = (sin, cos, log, exp, +, -, *)
@test DataDrivenLux.mask_inverse(log, 1, collect(fs)) == [1, 1, 1, 0, 1, 1, 1]
@test DataDrivenLux.mask_inverse(exp, 1, collect(fs)) == [1, 1, 0, 1, 1, 1, 1]
Expand Down

0 comments on commit 8c7bf52

Please sign in to comment.