Merge #1681
1681: Support NamedTuples for Chain + Parallel r=darsnack a=mcabbott

Closes #1680, WIP. Todo list includes:

- [x] add Parallel too
- [ ] ~~worry about whether any of this will upset Zygote, like FluxML/Zygote.jl#909~~, or kick that can down the road.
- [x] add tests

Co-authored-by: Michael Abbott <[email protected]>
bors[bot] and mcabbott authored Aug 4, 2021
2 parents 5d2a955 + fec920b commit dbb9f82
Showing 5 changed files with 104 additions and 36 deletions.
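
For orientation, a minimal sketch of the API this commit adds (the names `enc` and `dec` are illustrative, not taken from the diff):

```julia
using Flux

# Chain (and Parallel) now also accept keyword arguments,
# storing the layers as a NamedTuple instead of a Tuple:
m = Chain(enc = Dense(10, 5, tanh), dec = Dense(5, 2))

m[:enc] === m[1]        # true -- named and positional indexing agree
m(rand(Float32, 10))    # the forward pass is unchanged
```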
2 changes: 2 additions & 0 deletions NEWS.md
@@ -2,9 +2,11 @@

## v0.12.7
* Added support for [`GRUv3`](https://github.com/FluxML/Flux.jl/pull/1675)
* The layers within `Chain` and `Parallel` may now [have names](https://github.com/FluxML/Flux.jl/issues/1680).

## v0.12.5
* Added option to configure [`groups`](https://github.com/FluxML/Flux.jl/pull/1531) in `Conv`.
* REPL printing via [`show`](https://github.com/FluxML/Flux.jl/pull/1467) displays parameter counts.

## v0.12.4
* Implemented an [`Embedding layer`](https://github.com/FluxML/Flux.jl/pull/1516)
87 changes: 59 additions & 28 deletions src/layers/basic.jl
@@ -1,11 +1,10 @@
"""
Chain(layers...)
Chain(name = layer, ...)
Chain multiple layers / functions together, so that they are called in sequence
on a given input.
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
`m[1:3](x)` will calculate the output of the first three layers.
Collects multiple layers / functions to be called in sequence
on a given input. Supports indexing and slicing, `m[2]` or `m[1:end-1]`,
and if names are given, `m[:name] == m[1]` etc.
# Examples
@@ -15,36 +14,51 @@ julia> m = Chain(x -> x^2, x -> x+1);
julia> m(5) == 26
true
julia> m = Chain(Dense(10, 5), Dense(5, 2));
julia> m = Chain(Dense(10, 5, tanh), Dense(5, 2));
julia> x = rand(10);
julia> x = rand(10, 32);
julia> m(x) == m[2](m[1](x))
true
julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)),
dec = Dense(5, 2));
julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
true
```
"""
struct Chain{T<:Tuple}
struct Chain{T}
layers::T
Chain(xs...) = new{typeof(xs)}(xs)
function Chain(; kw...)
:layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
isempty(kw) && return new{Tuple{}}(())
new{typeof(values(kw))}(values(kw))
end
end

@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex
Base.iterate, Base.lastindex, Base.keys

functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)

applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))

(c::Chain)(x) = applychain(c.layers, x)
(c::Chain)(x) = applychain(Tuple(c.layers), x)

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)

function Base.show(io::IO, c::Chain)
print(io, "Chain(")
join(io, c.layers, ", ")
_show_layers(io, c.layers)
print(io, ")")
end
_show_layers(io, layers::Tuple) = join(io, layers, ", ")
_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")

# This is a temporary and naive implementation
# it might be replaced in the future for better performance
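
Two details in the hunk above are easy to miss: the struct parameter loosens from `Chain{T<:Tuple}` to `Chain{T}` so that a `NamedTuple` can be stored, and the keyword constructor reserves the name `layers` because it clashes with the struct's field. A short sketch of the resulting behaviour (layer sizes arbitrary):

```julia
using Flux

m = Chain(first = Dense(10, 5), second = Dense(5, 2))

keys(m)       # (:first, :second) -- Base.keys is forwarded to the NamedTuple
m[:second]    # the same layer as m[2]
m[1:1]        # slicing keeps the names: a Chain containing only :first

Chain(layers = Dense(2, 2))   # throws ArgumentError -- `layers` is reserved
```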
Expand All @@ -56,19 +70,15 @@ end
Calculate the forward results of each layer in Chain `c` with `input` as model input.
"""
function activations(c::Chain, input)
extraChain(c.layers, input)
end
activations(c::Chain, input) = extraChain(Tuple(c.layers), input)

function extraChain(fs::Tuple, x)
res = first(fs)(x)
return (res, extraChain(Base.tail(fs), res)...)
res = first(fs)(x)
return (res, extraChain(Base.tail(fs), res)...)
end

extraChain(::Tuple{}, x) = ()



"""
Dense(in, out, σ=identity; bias=true, init=glorot_uniform)
Dense(W::AbstractMatrix, [bias, σ])
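
The `Tuple(c.layers)` conversion in `activations` above lets the `extraChain` recursion work the same way whether the layers are stored as a `Tuple` or a `NamedTuple`. A quick sketch, mirroring the new test added below:

```julia
using Flux

c  = Chain(Dense(3, 5, relu), Dense(5, 1, relu))
c2 = Chain(enc = c[1], dec = c[2])   # same layers, now named

x = Float32[1, 1, 1]
Flux.activations(c, x) == Flux.activations(c2, x)   # true -- names don't affect results
```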
@@ -378,32 +388,42 @@ end

"""
Parallel(connection, layers...)
Parallel(connection; name = layer, ...)
Create a 'Parallel' layer that passes an input array to each path in
`layers`, reducing the output with `connection`.
`layers`, before reducing the output with `connection`.
Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l in layers])`.
If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
# Examples
```jldoctest
julia> model = Chain(Dense(3, 5),
Parallel(vcat, Dense(5, 4), Chain(Dense(5, 7), Dense(7, 4))),
Dense(8, 17));
julia> size(model(rand(3)))
julia> model(rand(3)) |> size
(17,)
julia> model = Parallel(+, Dense(10, 2), Dense(5, 2))
julia> model2 = Parallel(+; α = Dense(10, 2, tanh), β = Dense(5, 2))
Parallel(
+,
Dense(10, 2), # 22 parameters
Dense(5, 2), # 12 parameters
α = Dense(10, 2, tanh), # 22 parameters
β = Dense(5, 2), # 12 parameters
) # Total: 4 arrays, 34 parameters, 392 bytes.
julia> size(model(rand(10), rand(5)))
julia> model2(rand(10), rand(5)) |> size
(2,)
julia> model2[:α](rand(10)) |> size
(2,)
julia> model2[:β] == model2[2]
true
```
"""
struct Parallel{F, T}
@@ -412,21 +432,32 @@ struct Parallel{F, T}
end

Parallel(connection, layers...) = Parallel(connection, layers)
function Parallel(connection; kw...)
layers = NamedTuple(kw)
if :layers in Base.keys(layers) || :connection in Base.keys(layers)
throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
elseif isempty(layers)
Parallel(connection, ())
end
Parallel(connection, layers)
end

@functor Parallel

(m::Parallel)(x::AbstractArray) = mapreduce(f -> f(x), m.connection, m.layers)
(m::Parallel)(xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> f(x), m.connection, m.layers, xs)
(m::Parallel)(x::AbstractArray) = mapreduce(f -> f(x), m.connection, Tuple(m.layers))
(m::Parallel)(xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> f(x), m.connection, Tuple(m.layers), xs)
(m::Parallel)(xs::Tuple) = m(xs...)

Base.getindex(m::Parallel, i::Integer) = m.layers[i]
Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...)

Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))

trainable(m::Parallel) = (m.connection, m.layers...)

function Base.show(io::IO, m::Parallel)
print(io, "Parallel(", m.connection, ", ")
join(io, m.layers, ", ")
_show_layers(io, m.layers)
print(io, ")")
end

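In summary, the `Parallel` changes mirror `Chain`: a keyword constructor with `connection` and `layers` as reserved names, `Tuple(m.layers)` in the forward pass, and named access via `Base.getindex` and `Base.keys`. A brief sketch (sizes arbitrary):

```julia
using Flux

p = Parallel(+; α = Dense(10, 2), β = Dense(5, 2))

p[:α] === p[1]                            # named and positional access agree
p(rand(Float32, 10), rand(Float32, 5))    # two inputs are zipped with the two layers,
                                          # giving Dense(10,2)(x) + Dense(5,2)(y)
```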
29 changes: 21 additions & 8 deletions src/layers/show.jl
@@ -13,16 +13,28 @@ for T in [
end
end

function _big_show(io::IO, obj, indent::Int=0)
function _big_show(io::IO, obj, indent::Int=0, name=nothing)
children = trainable(obj)
if all(_show_leaflike, children)
_layer_show(io, obj, indent)
_layer_show(io, obj, indent, name)
else
println(io, " "^indent, nameof(typeof(obj)), "(")
for c in children
_big_show(io, c, indent+2)
println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(")
if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers)
# then we insert names -- can this be done more generically?
for k in Base.keys(obj)
_big_show(io, obj[k], indent+2, k)
end
elseif obj isa Parallel{<:Any, <:NamedTuple}
_big_show(io, obj.connection, indent+2)
for k in Base.keys(obj)
_big_show(io, obj[k], indent+2, k)
end
else
for c in children
_big_show(io, c, indent+2)
end
end
if indent == 0
if indent == 0 # i.e. this is the outermost container
print(io, ")")
_big_finale(io, obj)
else
@@ -49,8 +61,9 @@ for T in [
end
end

function _layer_show(io::IO, layer, indent::Int=0)
str = sprint(show, layer, context=io)
function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
_str = isnothing(name) ? "" : "$name = "
str = _str * sprint(show, layer, context=io)
print(io, " "^indent, str, indent==0 ? "" : ",")
if !isempty(params(layer))
print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
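The effect of threading `name` through `_big_show` and `_layer_show` is that named sub-layers print as `key = value` pairs at both levels of detail. The compact form is pinned down by the new test below; the multi-line form looks roughly like the docstring example earlier (column alignment approximate):

```julia
julia> repr(Chain(lay = Dense(2, 3)))
"Chain(lay = Dense(2, 3))"

julia> Parallel(+; α = Dense(10, 2, tanh), β = Dense(5, 2))
Parallel(
  +,
  α = Dense(10, 2, tanh),               # 22 parameters
  β = Dense(5, 2),                      # 12 parameters
) # Total: 4 arrays, 34 parameters, 392 bytes.
```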
21 changes: 21 additions & 0 deletions test/layers/basic.jl
@@ -19,12 +19,23 @@ import Flux: activations
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
# numeric test should be put into testset of corresponding layer

@test_nowarn Chain(first = Dense(10, 5, σ), second = Dense(5, 2))(randn(10))
m = Chain(first = Dense(10, 5, σ), second = Dense(5, 2))
@test m[:first] == m[1]
@test m[1:2] == m

@test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name
end

@testset "Activations" begin
c = Chain(Dense(3,5,relu), Dense(5,1,relu))
X = Float32.([1.0; 1.0; 1.0])
@test_nowarn gradient(()->Flux.activations(c, X)[2][1], params(c))

c2 = Chain(enc = c[1], dec = c[2])
@test Flux.activations(c, X) == Flux.activations(c2, X)
@test_nowarn gradient(()->Flux.activations(c2, X)[2][1], params(c2))
end

@testset "Dense" begin
@@ -184,11 +195,21 @@
@testset "concat size" begin
input = randn(10, 2)
@test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4)
@test size(Parallel(hcat, one = Dense(10, 10), two = identity)(input)) == (10, 4)
end

@testset "vararg input" begin
inputs = randn(10), randn(5), randn(4)
@test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
@test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,)
end

@testset "named access" begin
m = Parallel(hcat, one = Dense(10, 10), two = identity)
@test m[1] == m[:one]

@test_throws ArgumentError Parallel(hcat, layers = Dense(10, 10), two = identity) # reserved names
@test_throws ArgumentError Parallel(hcat, connection = Dense(10, 10), two = identity)
end
end

1 change: 1 addition & 0 deletions test/layers/show.jl
@@ -3,6 +3,7 @@

@test repr(Dense(2,3)) == "Dense(2, 3)"
@test repr(Chain(Dense(2,3))) == "Chain(Dense(2, 3))"
@test repr(Chain(lay=Dense(2,3))) == "Chain(lay = Dense(2, 3))"

end
@testset "nested model printing" begin # 3-arg show, defined in show.jl
