A simple library for performing algebraic manipulations on neural network
models/layers defined in Lux.jl. For example, if l1 and l2 are Lux layers, then
l3 = @lift_nn 2*l1 + l2 defines a Lux layer l3 such that, for an input x, the
layer produces the output l3(x) = 2*l1(x) + l2(x).
Given two real-valued functions f and g which have the same domain, we can
define an algebra over the real numbers such that h = f + g defines another
real-valued function h, given pointwise by h(x) = f(x) + g(x). Because
parameterized neural networks can be defined as real-valued functions over the
real numbers (excluding their parameters), we can "lift" this algebra to an
algebra over neural networks.
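The pointwise algebra described above can be sketched in plain Julia, without any neural networks. The functions f, g, h, and k below are illustrative names chosen for this sketch, not part of the library:

```julia
# Two real-valued functions sharing the same domain (the reals).
f(x) = x^2
g(x) = 3x + 1

# Pointwise addition and scaling produce new functions on that domain.
h(x) = f(x) + g(x)        # h = f + g
k(x) = 2 * f(x) + g(x)    # k = 2f + g

h(2.0)  # 4.0 + 7.0 = 11.0
k(2.0)  # 8.0 + 7.0 = 15.0
```

Replacing f and g with layers is the same idea, with the extra bookkeeping of parameters and states handled by Lux.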
using AlgebraOfNNs
using Lux
l1 = Dense(16=>32)
l2 = Dense(16=>32)
l3 = @lift_nn 2*l1+l2
Here, @lift_nn calls the constructors for Lux's Parallel and Chain layers, so
that l3 is the following layer:
julia> l3
Parallel(
    +
    Parallel(
        *
        WrappedFunction(#1),
        Dense(16 => 32),                # 544 parameters
    ),
    Dense(16 => 32),                    # 544 parameters
)         # Total: 1_088 parameters,
          #        plus 0 states.
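To make the printed structure concrete, here is a rough hand-built stand-in that uses only Lux, not AlgebraOfNNs. It assumes the WrappedFunction shown as #1 is a constant function returning 2, which the printout does not confirm, and the name manual is hypothetical:

```julia
using Lux, Random

l1 = Dense(16 => 32)
l2 = Dense(16 => 32)

# Assumption: the constant branch ignores its input and returns the scalar 2.
manual = Parallel(+, Parallel(*, WrappedFunction(_ -> 2), l1), l2)

rng = Random.default_rng()
ps, st = Lux.setup(rng, manual)      # parameters and states for the whole tree
x = rand(rng, Float32, 16, 4)        # batch of 4 inputs with 16 features each

y, _ = manual(x, ps, st)

# Evaluate the two Dense layers on their own slices of the parameter tree.
y1, _ = l1(x, ps.layer_1.layer_2, st.layer_1.layer_2)
y2, _ = l2(x, ps.layer_2, st.layer_2)

@assert y ≈ 2 .* y1 .+ y2
```

The only learnable parameters in this stand-in belong to the two Dense layers, which matches the total of 1_088 reported in the printout.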
- Works generically for any function call, not just algebraic operators
  (recall that in Julia even arithmetic operations like a+b+c are themselves
  simply n-ary function calls, +(a,b,c))
- Does not introduce any new state or learnable parameters