From cf87c31bbe4e1b9ff982b544924228fb316c2b6e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Jul 2021 14:36:14 -0400 Subject: [PATCH 01/20] allow NamedTuple in Chain, take 1 --- src/layers/basic.jl | 53 ++++++++++++++++++++++++--------------------- src/layers/show.jl | 24 +++++++++++++------- 2 files changed, 44 insertions(+), 33 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 0a53b70415..7360e14084 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -1,11 +1,10 @@ """ Chain(layers...) + Chain(name = layer, ...) -Chain multiple layers / functions together, so that they are called in sequence -on a given input. - -`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`. -`m[1:3](x)` will calculate the output of the first three layers. +Collects multiple layers / functions to be called in sequence +on a given input. Supports indexing and slicing, `m[2]` or `m[1:end-1]`, +and if names are given, `m.name == m[1]` etc. # Examples @@ -15,35 +14,43 @@ julia> m = Chain(x -> x^2, x -> x+1); julia> m(5) == 26 true -julia> m = Chain(Dense(10, 5), Dense(5, 2)); +julia> m = Chain(Dense(10, 5, tanh), Dense(5, 2)); -julia> x = rand(10); +julia> x = rand(10, 32); julia> m(x) == m[2](m[1](x)) true + +julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)), + dec = Dense(5, 2)); + +julia> m2(x) == (m2.dec ∘ m2.enc)(x) +true ``` """ -struct Chain{T<:Tuple} +struct Chain{T} layers::T Chain(xs...) = new{typeof(xs)}(xs) + Chain(; kw...) = new{typeof(values(kw))}(values(kw)) end -@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, - Base.iterate, Base.lastindex +Base.getproperty(c::Chain, s::Symbol) = getproperty(getfield(c, :layers), s) +for fun in (:getindex, :length, :first, :last, :iterate, :lastindex, :propertynames, :keys) + @eval Base.$fun(c::Chain, args...) = Base.$fun(getfield(c, :layers), args...) +end -functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...) +functor(::Type{<:Chain}, c) = getfield(c, :layers), ls -> Chain(ls...) applychain(::Tuple{}, x) = x -applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) +applychain(fs, x) = applychain(tail(Tuple(fs)), first(fs)(x)) -(c::Chain)(x) = applychain(c.layers, x) +(c::Chain)(x) = applychain(getfield(c, :layers), x) -Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) +Base.getindex(c::Chain, i::AbstractArray) = Chain(getfield(c, :layers)[i]...) function Base.show(io::IO, c::Chain) - print(io, "Chain(") - join(io, c.layers, ", ") - print(io, ")") + print(io, "Chain") + show(io, getfield(c, :layers)) end # This is a temporary and naive implementation @@ -56,24 +63,20 @@ end Calculate the forward results of each layers in Chain `c` with `input` as model input. """ -function activations(c::Chain, input) - extraChain(c.layers, input) -end +activations(c::Chain, input) = extraChain(getfield(c, :layers), input) function extraChain(fs::Tuple, x) - res = first(fs)(x) - return (res, extraChain(Base.tail(fs), res)...) + res = first(fs)(x) + return (res, extraChain(Base.tail(Tuple(fs)), res)...) end - extraChain(::Tuple{}, x) = () - """ Dense(in, out, σ=identity; bias=true, init=glorot_uniform) Dense(W::AbstractMatrix, [bias, σ]) -Create a traditional `Dense` layer, whose forward pass is given by: +Create a traditional fully-connected layer, whose forward pass is given by: y = σ.(W * x .+ bias) diff --git a/src/layers/show.jl b/src/layers/show.jl index 40d49dd9d1..90691796a4 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -13,16 +13,23 @@ for T in [ end end -function _big_show(io::IO, obj, indent::Int=0) +function _big_show(io::IO, obj, indent::Int=0, name=nothing) children = trainable(obj) if all(_show_leaflike, children) - _layer_show(io, obj, indent) + _layer_show(io, obj, indent, name) else - println(io, " "^indent, nameof(typeof(obj)), "(") - for c in children - _big_show(io, c, indent+2) + println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(") + if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers) + # then we insert names -- can this be done more generically? + for k in Base.keys(obj) + _big_show(io, obj[k], indent+2, k) + end + else + for c in children + _big_show(io, c, indent+2) + end end - if indent == 0 + if indent == 0 # i.e. this is the outermost container print(io, ")") _big_finale(io, obj) else @@ -49,8 +56,9 @@ for T in [ end end -function _layer_show(io::IO, layer, indent::Int=0) - str = sprint(show, layer, context=io) +function _layer_show(io::IO, layer, indent::Int=0, name=nothing) + _str = isnothing(name) ? "" : "$name = " + str = _str * sprint(show, layer, context=io) print(io, " "^indent, str, indent==0 ? "" : ",") if !isempty(params(layer)) print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) From 45a5c1000ba1cc85d2bc3e4f9e33d172309625cf Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Jul 2021 15:16:50 -0400 Subject: [PATCH 02/20] printing tests --- test/layers/show.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/layers/show.jl b/test/layers/show.jl index 9c689eba49..fa44fb6d1f 100644 --- a/test/layers/show.jl +++ b/test/layers/show.jl @@ -2,7 +2,7 @@ @testset "layer printing" begin # 2-arg show, defined with layes @test repr(Dense(2,3)) == "Dense(2, 3)" - @test repr(Chain(Dense(2,3))) == "Chain(Dense(2, 3))" + @test repr(Chain(Dense(2,3))) == "Chain(Dense(2, 3),)" end @testset "nested model printing" begin # 3-arg show, defined in show.jl @@ -35,12 +35,12 @@ end @test !occursin("# Total:", toplevel_chain) vector_chain = repr("text/plain", [Chain(Dense(2,3)), Chain(Dense(2,3))]) - @test occursin("Chain(Dense(2, 3))", vector_chain) + @test occursin("Chain(Dense(2, 3)", vector_chain) @test occursin("# 9 parameters", vector_chain) @test !occursin("# Total:", vector_chain) matrix_chain = repr("text/plain", fill(Chain(Dense(2,3)), 3,3)) - @test occursin("Chain(Dense(2, 3))", matrix_chain) + @test occursin("Chain(Dense(2, 3)", matrix_chain) @test !occursin("# 9 parameters", matrix_chain) @test !occursin("# Total:", matrix_chain) From d6be7007b77b8da6fcada8cd854fd0d1eb61e535 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Jul 2021 15:44:14 -0400 Subject: [PATCH 03/20] fixup --- src/layers/basic.jl | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 7360e14084..a4d03cb01a 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -31,26 +31,31 @@ true struct Chain{T} layers::T Chain(xs...) = new{typeof(xs)}(xs) - Chain(; kw...) = new{typeof(values(kw))}(values(kw)) + function Chain(; kw...) + :layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`")) + isempty(kw) && return new{Tuple{}}(()) + new{typeof(values(kw))}(values(kw)) + end end -Base.getproperty(c::Chain, s::Symbol) = getproperty(getfield(c, :layers), s) +Base.parent(c::Chain) = getfield(c, :layers) +Base.getproperty(c::Chain, s::Symbol) = s === :layers ? parent(c) : getproperty(parent(c), s) for fun in (:getindex, :length, :first, :last, :iterate, :lastindex, :propertynames, :keys) - @eval Base.$fun(c::Chain, args...) = Base.$fun(getfield(c, :layers), args...) + @eval Base.$fun(c::Chain, args...) = Base.$fun(parent(c), args...) end -functor(::Type{<:Chain}, c) = getfield(c, :layers), ls -> Chain(ls...) +functor(::Type{<:Chain}, c) = parent(c), ls -> Chain(ls...) applychain(::Tuple{}, x) = x applychain(fs, x) = applychain(tail(Tuple(fs)), first(fs)(x)) -(c::Chain)(x) = applychain(getfield(c, :layers), x) +(c::Chain)(x) = applychain(parent(c), x) -Base.getindex(c::Chain, i::AbstractArray) = Chain(getfield(c, :layers)[i]...) +Base.getindex(c::Chain, i::AbstractArray) = Chain(parent(c)[i]...) function Base.show(io::IO, c::Chain) print(io, "Chain") - show(io, getfield(c, :layers)) + show(io, parent(c)) end # This is a temporary and naive implementation @@ -63,11 +68,11 @@ end Calculate the forward results of each layers in Chain `c` with `input` as model input. """ -activations(c::Chain, input) = extraChain(getfield(c, :layers), input) +activations(c::Chain, input) = extraChain(parent(c), input) function extraChain(fs::Tuple, x) res = first(fs)(x) - return (res, extraChain(Base.tail(Tuple(fs)), res)...) + return (res, extraChain(Base.tail(fs), res)...) end extraChain(::Tuple{}, x) = () From 74c8d071ce66b2fda9841e7ca8205c81d513937e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Jul 2021 16:19:23 -0400 Subject: [PATCH 04/20] fixup --- src/layers/basic.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index a4d03cb01a..cc03bc58f0 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -39,10 +39,11 @@ struct Chain{T} end Base.parent(c::Chain) = getfield(c, :layers) +Base.propertynames(c::Chain) = (Base.keys(parent(c))..., :layers) Base.getproperty(c::Chain, s::Symbol) = s === :layers ? parent(c) : getproperty(parent(c), s) -for fun in (:getindex, :length, :first, :last, :iterate, :lastindex, :propertynames, :keys) - @eval Base.$fun(c::Chain, args...) = Base.$fun(parent(c), args...) -end + +@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, + Base.iterate, Base.lastindex, Base.keys functor(::Type{<:Chain}, c) = parent(c), ls -> Chain(ls...) From 75223ef33b0e3aea059ea6065de3dfefe5129c18 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Jul 2021 16:19:35 -0400 Subject: [PATCH 05/20] same idea for Parallel --- src/layers/basic.jl | 24 +++++++++++++++++++++--- src/layers/show.jl | 5 +++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index cc03bc58f0..a3e59976ae 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -387,6 +387,7 @@ end """ Parallel(connection, layers...) + Parallel(connection; name = layer, ...) Create a 'Parallel' layer that passes an input array to each path in `layers`, reducing the output with `connection`. @@ -421,21 +422,38 @@ struct Parallel{F, T} end Parallel(connection, layers...) = Parallel(connection, layers) +function Parallel(connection; kw...) + layers = NamedTuple(kw) + if :layers in Base.keys(layers) || :layers in Base.keys(layers) + throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `layers`")) + elseif isempty(layers) + throw(ArgumentError("can't construct a Parallel layer with no paths")) + end + Parallel(connection, layers) +end @functor Parallel -(m::Parallel)(x::AbstractArray) = mapreduce(f -> f(x), m.connection, m.layers) -(m::Parallel)(xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> f(x), m.connection, m.layers, xs) +(m::Parallel)(x::AbstractArray) = mapreduce(f -> f(x), m.connection, Tuple(m.layers)) +(m::Parallel)(xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> f(x), m.connection, Tuple(m.layers), xs) (m::Parallel)(xs::Tuple) = m(xs...) Base.getindex(m::Parallel, i::Integer) = m.layers[i] +Base.getindex(m::Parallel, s::Symbol) = m.layers[s] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...) +Base.parent(m::Parallel) = getfield(m, :layers) +Base.keys(m::Parallel) = Base.keys(getfield(m, :layers)) +Base.propertynames(m::Parallel) = (Base.keys(getfield(m, :layers))..., :connection, :layers) +Base.getproperty(m::Parallel, s::Symbol) = s === :connection ? getfield(m, :connection) : + s === :layers ? parent(m) : getproperty(parent(m), s) + trainable(m::Parallel) = (m.connection, m.layers...) function Base.show(io::IO, m::Parallel) print(io, "Parallel(", m.connection, ", ") - join(io, m.layers, ", ") + # join(io, m.layers, ", ") + show(io, m.layers) # this is a bit ugly, but should parse. Can trim the brackets with a bit more effort. print(io, ")") end diff --git a/src/layers/show.jl b/src/layers/show.jl index 90691796a4..94afdb92d0 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -24,6 +24,11 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing) for k in Base.keys(obj) _big_show(io, obj[k], indent+2, k) end + elseif obj isa Parallel{<:Any, <:NamedTuple} + _big_show(io, obj.connection, indent+2) + for k in Base.keys(obj) + _big_show(io, obj[k], indent+2, k) + end else for c in children _big_show(io, c, indent+2) From 454b9110bfa54e49c9a954cad9691a1ed4bc10be Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Jul 2021 16:42:52 -0400 Subject: [PATCH 06/20] rm parent --- src/layers/basic.jl | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index a3e59976ae..6dc4a1d968 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -38,25 +38,24 @@ struct Chain{T} end end -Base.parent(c::Chain) = getfield(c, :layers) -Base.propertynames(c::Chain) = (Base.keys(parent(c))..., :layers) -Base.getproperty(c::Chain, s::Symbol) = s === :layers ? parent(c) : getproperty(parent(c), s) +Base.propertynames(c::Chain) = (Base.keys(c)..., :layers) +Base.getproperty(c::Chain, s::Symbol) = s === :layers ? getfield(c, :layers) : getproperty(getfield(c, :layers), s) @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, Base.iterate, Base.lastindex, Base.keys -functor(::Type{<:Chain}, c) = parent(c), ls -> Chain(ls...) +functor(::Type{<:Chain}, c) = getfield(c, :layers), ls -> Chain(ls...) applychain(::Tuple{}, x) = x -applychain(fs, x) = applychain(tail(Tuple(fs)), first(fs)(x)) +applychain(fs, x) = applychain(tail(fs), first(fs)(x)) -(c::Chain)(x) = applychain(parent(c), x) +(c::Chain)(x) = applychain(Tuple(getfield(c, :layers)), x) -Base.getindex(c::Chain, i::AbstractArray) = Chain(parent(c)[i]...) +Base.getindex(c::Chain, i::AbstractArray) = Chain(getfield(c, :layers)[i]...) function Base.show(io::IO, c::Chain) print(io, "Chain") - show(io, parent(c)) + show(io, c.layers) # allows NamedTuple, but prints a trailing comma sometimes end # This is a temporary and naive implementation @@ -69,7 +68,7 @@ end Calculate the forward results of each layers in Chain `c` with `input` as model input. """ -activations(c::Chain, input) = extraChain(parent(c), input) +activations(c::Chain, input) = extraChain(Tuple(c.layers), input) function extraChain(fs::Tuple, x) res = first(fs)(x) @@ -442,11 +441,10 @@ Base.getindex(m::Parallel, i::Integer) = m.layers[i] Base.getindex(m::Parallel, s::Symbol) = m.layers[s] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...) -Base.parent(m::Parallel) = getfield(m, :layers) Base.keys(m::Parallel) = Base.keys(getfield(m, :layers)) Base.propertynames(m::Parallel) = (Base.keys(getfield(m, :layers))..., :connection, :layers) Base.getproperty(m::Parallel, s::Symbol) = s === :connection ? getfield(m, :connection) : - s === :layers ? parent(m) : getproperty(parent(m), s) + s === :layers ? getfield(m, :layers) : getproperty(getfield(m, :layers), s) trainable(m::Parallel) = (m.connection, m.layers...) From 5c24cab993d18d3668ccb6666cc49c323c88aec7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Jul 2021 16:46:10 -0400 Subject: [PATCH 07/20] allow m[1:2] with names --- src/layers/basic.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 6dc4a1d968..ed08278f16 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -52,6 +52,8 @@ applychain(fs, x) = applychain(tail(fs), first(fs)(x)) (c::Chain)(x) = applychain(Tuple(getfield(c, :layers)), x) Base.getindex(c::Chain, i::AbstractArray) = Chain(getfield(c, :layers)[i]...) +Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) = + Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(getfield(c, :layers))[i])...) function Base.show(io::IO, c::Chain) print(io, "Chain") From 23e0811b9a81f95252b270beba1a377caba19f4e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Jul 2021 16:50:31 -0400 Subject: [PATCH 08/20] use _show_layers --- src/layers/basic.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ed08278f16..366b8a7a17 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -56,10 +56,14 @@ Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) = Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(getfield(c, :layers))[i])...) function Base.show(io::IO, c::Chain) - print(io, "Chain") - show(io, c.layers) # allows NamedTuple, but prints a trailing comma sometimes + print(io, "Chain(") + _show_layers(io, c.layers) + print(io, ")") end +_show_layers(io, layers::Tuple) = join(io, layers, ", ") +_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ") + # This is a temporary and naive implementation # it might be replaced in the future for better performance # see issue https://github.com/FluxML/Flux.jl/issues/702 @@ -452,8 +456,7 @@ trainable(m::Parallel) = (m.connection, m.layers...) function Base.show(io::IO, m::Parallel) print(io, "Parallel(", m.connection, ", ") - # join(io, m.layers, ", ") - show(io, m.layers) # this is a bit ugly, but should parse. Can trim the brackets with a bit more effort. + _show_layers(io, m.layers) print(io, ")") end From 1d7c41a0d3a975a7aa8ca0a45bf1b6ed1cdb5f6a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Jul 2021 16:53:26 -0400 Subject: [PATCH 09/20] Revert "printing tests" This reverts commit 45a5c1000ba1cc85d2bc3e4f9e33d172309625cf. --- test/layers/show.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/layers/show.jl b/test/layers/show.jl index fa44fb6d1f..9c689eba49 100644 --- a/test/layers/show.jl +++ b/test/layers/show.jl @@ -2,7 +2,7 @@ @testset "layer printing" begin # 2-arg show, defined with layes @test repr(Dense(2,3)) == "Dense(2, 3)" - @test repr(Chain(Dense(2,3))) == "Chain(Dense(2, 3),)" + @test repr(Chain(Dense(2,3))) == "Chain(Dense(2, 3))" end @testset "nested model printing" begin # 3-arg show, defined in show.jl @@ -35,12 +35,12 @@ end @test !occursin("# Total:", toplevel_chain) vector_chain = repr("text/plain", [Chain(Dense(2,3)), Chain(Dense(2,3))]) - @test occursin("Chain(Dense(2, 3)", vector_chain) + @test occursin("Chain(Dense(2, 3))", vector_chain) @test occursin("# 9 parameters", vector_chain) @test !occursin("# Total:", vector_chain) matrix_chain = repr("text/plain", fill(Chain(Dense(2,3)), 3,3)) - @test occursin("Chain(Dense(2, 3)", matrix_chain) + @test occursin("Chain(Dense(2, 3))", matrix_chain) @test !occursin("# 9 parameters", matrix_chain) @test !occursin("# Total:", matrix_chain) From 537064ad2887e6705d94b2a5b5d1922d76169ab8 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Jul 2021 17:01:52 -0400 Subject: [PATCH 10/20] Parallel constructor --- src/layers/basic.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 366b8a7a17..824a6da3e3 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -429,10 +429,10 @@ end Parallel(connection, layers...) = Parallel(connection, layers) function Parallel(connection; kw...) layers = NamedTuple(kw) - if :layers in Base.keys(layers) || :layers in Base.keys(layers) - throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `layers`")) + if :layers in Base.keys(layers) || :connection in Base.keys(layers) + throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`")) elseif isempty(layers) - throw(ArgumentError("can't construct a Parallel layer with no paths")) + Parallel(connection, ()) end Parallel(connection, layers) end From 533767d41954b7ad599e2cbbea4e289bd126f674 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Jul 2021 17:14:41 -0400 Subject: [PATCH 11/20] tidy --- src/layers/basic.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 824a6da3e3..1a5950b410 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -47,7 +47,7 @@ Base.getproperty(c::Chain, s::Symbol) = s === :layers ? getfield(c, :layers) : g functor(::Type{<:Chain}, c) = getfield(c, :layers), ls -> Chain(ls...) applychain(::Tuple{}, x) = x -applychain(fs, x) = applychain(tail(fs), first(fs)(x)) +applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) (c::Chain)(x) = applychain(Tuple(getfield(c, :layers)), x) @@ -60,7 +60,6 @@ function Base.show(io::IO, c::Chain) _show_layers(io, c.layers) print(io, ")") end - _show_layers(io, layers::Tuple) = join(io, layers, ", ") _show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ") @@ -443,8 +442,7 @@ end (m::Parallel)(xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> f(x), m.connection, Tuple(m.layers), xs) (m::Parallel)(xs::Tuple) = m(xs...) -Base.getindex(m::Parallel, i::Integer) = m.layers[i] -Base.getindex(m::Parallel, s::Symbol) = m.layers[s] +Base.getindex(m::Parallel, i) = m.layers[i] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...) Base.keys(m::Parallel) = Base.keys(getfield(m, :layers)) From ddcee35bb289bc061892e166f2fa8fa26889abfd Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Jul 2021 17:22:00 -0400 Subject: [PATCH 12/20] docstring for Parallel --- src/layers/basic.jl | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 1a5950b410..dc742d4ccd 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -394,11 +394,14 @@ end Parallel(connection; name = layer, ...) Create a 'Parallel' layer that passes an input array to each path in -`layers`, reducing the output with `connection`. +`layers`, before reducing the output with `connection`. Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l in layers])`. If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`. +Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor. +These can be accessed by indexing or with a dot: `m[1] == m[:name] == m.name` is the first layer. + # Examples ```jldoctest @@ -406,18 +409,24 @@ julia> model = Chain(Dense(3, 5), Parallel(vcat, Dense(5, 4), Chain(Dense(5, 7), Dense(7, 4))), Dense(8, 17)); -julia> size(model(rand(3))) +julia> model(rand(3)) |> size (17,) -julia> model = Parallel(+, Dense(10, 2), Dense(5, 2)) +julia> model2 = Parallel(+; α = Dense(10, 2, tanh), β = Dense(5, 2)) Parallel( +, - Dense(10, 2), # 22 parameters - Dense(5, 2), # 12 parameters + α = Dense(10, 2, tanh), # 22 parameters + β = Dense(5, 2), # 12 parameters ) # Total: 4 arrays, 34 parameters, 392 bytes. -julia> size(model(rand(10), rand(5))) +julia> model2(rand(10), rand(5)) |> size +(2,) + +julia> model2.α(rand(10)) |> size (2,) + +julia> model2.β == model2[2] +true ``` """ struct Parallel{F, T} From 9195f9213718458858be2fa8e7b3bfa2ca639e13 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 30 Jul 2021 14:28:05 -0400 Subject: [PATCH 13/20] rm field access via dot, for now --- src/layers/basic.jl | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index dc742d4ccd..fc7ae2a257 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -4,7 +4,7 @@ Collects multiple layers / functions to be called in sequence on a given input. Supports indexing and slicing, `m[2]` or `m[1:end-1]`, -and if names are given, `m.name == m[1]` etc. +and if names are given, `m[:name] == m[1]` etc. (but not yet `m.name`). # Examples @@ -24,7 +24,7 @@ true julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)), dec = Dense(5, 2)); -julia> m2(x) == (m2.dec ∘ m2.enc)(x) +julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x) true ``` """ @@ -38,22 +38,19 @@ struct Chain{T} end end -Base.propertynames(c::Chain) = (Base.keys(c)..., :layers) -Base.getproperty(c::Chain, s::Symbol) = s === :layers ? getfield(c, :layers) : getproperty(getfield(c, :layers), s) - @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, Base.iterate, Base.lastindex, Base.keys -functor(::Type{<:Chain}, c) = getfield(c, :layers), ls -> Chain(ls...) +functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...) applychain(::Tuple{}, x) = x applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) -(c::Chain)(x) = applychain(Tuple(getfield(c, :layers)), x) +(c::Chain)(x) = applychain(Tuple(c.layers), x) -Base.getindex(c::Chain, i::AbstractArray) = Chain(getfield(c, :layers)[i]...) +Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) = - Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(getfield(c, :layers))[i])...) + Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...) function Base.show(io::IO, c::Chain) print(io, "Chain(") @@ -400,7 +397,7 @@ Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`. Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor. -These can be accessed by indexing or with a dot: `m[1] == m[:name] == m.name` is the first layer. +These can be accessed by indexing or with a dot: `m[1] == m[:name]` is the first layer. # Examples @@ -422,10 +419,10 @@ Parallel( julia> model2(rand(10), rand(5)) |> size (2,) -julia> model2.α(rand(10)) |> size +julia> model2[:α](rand(10)) |> size (2,) -julia> model2.β == model2[2] +julia> model2[:β] == model2[2] true ``` """ @@ -455,9 +452,6 @@ Base.getindex(m::Parallel, i) = m.layers[i] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...) Base.keys(m::Parallel) = Base.keys(getfield(m, :layers)) -Base.propertynames(m::Parallel) = (Base.keys(getfield(m, :layers))..., :connection, :layers) -Base.getproperty(m::Parallel, s::Symbol) = s === :connection ? getfield(m, :connection) : - s === :layers ? getfield(m, :layers) : getproperty(getfield(m, :layers), s) trainable(m::Parallel) = (m.connection, m.layers...) From 7cb756de7ec56826fa4562af60c12426a7fa0d48 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 30 Jul 2021 14:28:17 -0400 Subject: [PATCH 14/20] add tests --- test/layers/basic.jl | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 0c0560ced6..3d6bbf7dcf 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -19,12 +19,28 @@ import Flux: activations @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10)) @test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10)) # numeric test should be put into testset of corresponding layer + + @test_nowarn Chain(first = Dense(10, 5, σ), second = Dense(5, 2))(randn(10)) + m = Chain(first = Dense(10, 5, σ), second = Dense(5, 2)) + @test m[:first] == m[1] + @test m[1:2] == m + + @test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name + + mt = Chain((Dense(2,2),)) # constructor accepts Any + @test_throws MethodError mt(rand(2)) # but Tuples aren't callable + mnt = Chain((; name = Dense(2,2),)) + @test_throws MethodError mnt(rand(2)) end @testset "Activations" begin c = Chain(Dense(3,5,relu), Dense(5,1,relu)) X = Float32.([1.0; 1.0; 1.0]) @test_nowarn gradient(()->Flux.activations(c, X)[2][1], params(c)) + + c2 = Chain(enc = c[1], dec = c[2]) + @test Flux.activations(c, X) == Flux.activations(c2, X) + @test_nowarn gradient(()->Flux.activations(c2, X)[2][1], params(c2)) end @testset "Dense" begin @@ -184,11 +200,21 @@ import Flux: activations @testset "concat size" begin input = randn(10, 2) @test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4) + @test size(Parallel(hcat, one = Dense(10, 10), two = identity)(input)) == (10, 4) end @testset "vararg input" begin inputs = randn(10), randn(5), randn(4) @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,) + @test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,) + end + + @testset "named access" begin + m = Parallel(hcat, one = Dense(10, 10), two = identity) + @test m[1] == m[:one] + + @test_throws ArgumentError Parallel(hcat, layers = Dense(10, 10), two = identity) # reserved names + @test_throws ArgumentError Parallel(hcat, connection = Dense(10, 10), two = identity) end end From b391a5ca444e367373cb6078fb673ef58c663c24 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 31 Jul 2021 15:41:34 -0400 Subject: [PATCH 15/20] add NEWS item --- NEWS.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/NEWS.md b/NEWS.md index 6a39614437..e72dcbd76a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,11 @@ # Flux Release Notes +## v0.12.7 +* The layers within `Chain` and `Parallel` may now [have names](https://github.com/FluxML/Flux.jl/issues/1680). + ## v0.12.5 * Added option to configure [`groups`](https://github.com/FluxML/Flux.jl/pull/1531) in `Conv`. +* REPL printing via [`show`](https://github.com/FluxML/Flux.jl/pull/1467) displays parameter counts. ## v0.12.4 * Implemented an [`Embedding layer`](https://github.com/FluxML/Flux.jl/pull/1516) From 9abaf54fd2cdbf6ca40c0c09b51a711887aac9df Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 31 Jul 2021 15:41:55 -0400 Subject: [PATCH 16/20] a test --- test/layers/show.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/layers/show.jl b/test/layers/show.jl index 9c689eba49..c551bad978 100644 --- a/test/layers/show.jl +++ b/test/layers/show.jl @@ -3,6 +3,7 @@ @test repr(Dense(2,3)) == "Dense(2, 3)" @test repr(Chain(Dense(2,3))) == "Chain(Dense(2, 3))" + @test repr(Chain(lay=Dense(2,3))) == "Chain(lay = Dense(2, 3))" end @testset "nested model printing" begin # 3-arg show, defined in show.jl From 300c95752f1bef7c4837fb1c4e0152d0727ad3f4 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 1 Aug 2021 11:07:03 -0400 Subject: [PATCH 17/20] Apply suggestions from code review Co-authored-by: Kyle Daruwalla --- src/layers/basic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index fc7ae2a257..8e1f931fe0 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -4,7 +4,7 @@ Collects multiple layers / functions to be called in sequence on a given input. Supports indexing and slicing, `m[2]` or `m[1:end-1]`, -and if names are given, `m[:name] == m[1]` etc. (but not yet `m.name`). +and if names are given, `m[:name] == m[1]` etc. (but not `m.name` at present). # Examples @@ -397,7 +397,7 @@ Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`. Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor. -These can be accessed by indexing or with a dot: `m[1] == m[:name]` is the first layer. +These can be accessed by indexing: `m[1] == m[:name]` is the first layer. # Examples From ede4307c5f8c4b8b071870279f38337985f81870 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 1 Aug 2021 11:34:51 -0400 Subject: [PATCH 18/20] Update src/layers/basic.jl --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 8e1f931fe0..9c40c60818 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -4,7 +4,7 @@ Collects multiple layers / functions to be called in sequence on a given input. Supports indexing and slicing, `m[2]` or `m[1:end-1]`, -and if names are given, `m[:name] == m[1]` etc. (but not `m.name` at present). +and if names are given, `m[:name] == m[1]` etc. # Examples From 894c32baddcf93a89d061e35ff1d77ff7c3a97f0 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 1 Aug 2021 13:40:12 -0400 Subject: [PATCH 19/20] let's leave these unwanted constructors untested, call them bugs --- test/layers/basic.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 3d6bbf7dcf..cebd123692 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -26,11 +26,6 @@ import Flux: activations @test m[1:2] == m @test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name - - mt = Chain((Dense(2,2),)) # constructor accepts Any - @test_throws MethodError mt(rand(2)) # but Tuples aren't callable - mnt = Chain((; name = Dense(2,2),)) - @test_throws MethodError mnt(rand(2)) end @testset "Activations" begin From fec920b7439a1421bf0ec99ed39872f3ad3557f9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 4 Aug 2021 11:42:12 -0400 Subject: [PATCH 20/20] Update src/layers/basic.jl Co-authored-by: Dhairya Gandhi --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 9c40c60818..31ad324416 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -83,7 +83,7 @@ extraChain(::Tuple{}, x) = () Dense(in, out, σ=identity; bias=true, init=glorot_uniform) Dense(W::AbstractMatrix, [bias, σ]) -Create a traditional fully-connected layer, whose forward pass is given by: +Create a traditional `Dense` layer, whose forward pass is given by: y = σ.(W * x .+ bias)