Merge #1675
1675: Adding GRUv3 support. r=darsnack a=mkschleg

As per the starting discussion in #1671, we should provide support for variations on the GRU and LSTM cell. 

In this PR, I added support for the GRU variant described in v3 of the [original GRU paper](https://arxiv.org/abs/1406.1078). Flux currently supports only the v1 formulation. [TensorFlow](https://www.tensorflow.org/api_docs/python/tf/keras/layers/GRU) supports several variations, including this one.
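
For context, a minimal usage sketch of the layer as added here (call semantics mirror the existing `GRU` wrapper):

```julia
using Flux

# GRUv3 is a stateful Recur layer, used like the existing RNN/GRU/LSTM wrappers.
m = GRUv3(3, 5)          # 3 input features, 5 hidden units
x = rand(Float32, 3)     # one time step of input
h = m(x)                 # 5-element hidden state / output for this step
Flux.reset!(m)           # reset the hidden state before the next sequence
```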

While the feature is added and usable in this PR, this is only a first pass at the design and could use further iteration. Some open questions:
- Should we have new types for each variation of these cells? (Another possibility is to use parametric options.)
- Should we have a shared constructor similar to TensorFlow/PyTorch? (It might make sense to rename the current GRU to GRUv1 if we want to do this; a rough sketch of one possible shape follows this list.)
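
For illustration, here is the rough sketch referenced above of what a shared constructor could look like (purely hypothetical, not part of this PR; the keyword name and dispatch are assumptions):

```julia
# Hypothetical shared constructor (not in this PR): select the cell variant via
# a keyword and wrap it in Recur, as the individual constructors already do.
function GRU(in::Integer, out::Integer; version::Symbol = :v1, kwargs...)
    cell = version === :v1 ? GRUCell(in, out; kwargs...) :
           version === :v3 ? GRUv3Cell(in, out; kwargs...) :
           error("unknown GRU version: $version")
    return Flux.Recur(cell)
end
```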

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] API changes require approval from a committer (different from the author, if applicable)


Co-authored-by: Matthew Schlegel <[email protected]>
Co-authored-by: Matthew Schlegel <[email protected]>
3 people authored Aug 2, 2021
2 parents 0a21546 + 434c10e commit 5d2a955
Showing 5 changed files with 65 additions and 11 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
@@ -1,5 +1,8 @@
# Flux Release Notes

## v0.12.7
* Added support for [`GRUv3`](https://github.com/FluxML/Flux.jl/pull/1675)

## v0.12.5
* Added option to configure [`groups`](https://github.com/FluxML/Flux.jl/pull/1531) in `Conv`.

2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient

export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
RNN, LSTM, GRU,
RNN, LSTM, GRU, GRUv3,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
61 changes: 56 additions & 5 deletions src/layers/recurrent.jl
@@ -183,6 +183,15 @@ end

# GRU

function _gru_output(Wi, Wh, b, x, h)
o = size(h, 1)
gx, gh = Wi*x, Wh*h
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))

return gx, gh, r, z
end

struct GRUCell{A,V,S}
Wi::A
Wh::A
@@ -195,9 +204,7 @@ GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =

function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
b, o = m.b, size(h, 1)
gx, gh = m.Wi*x, m.Wh*h
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h)
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
h′ = (1 .- z) .* h̃ .+ z .* h
sz = size(x)
@@ -212,8 +219,9 @@ Base.show(io::IO, l::GRUCell) =
"""
GRU(in::Integer, out::Integer)
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078) layer. Behaves like an
RNN but generally exhibits a longer memory span over sequences.
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v1 of the referenced paper.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
@@ -233,6 +241,49 @@ function Base.getproperty(m::GRUCell, sym::Symbol)
end
end


# GRU v3

struct GRUv3Cell{A,V,S}
Wi::A
Wh::A
b::V
Wh_h̃::A
state0::S
end

GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
init(out, out), init_state(out,1))

function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
b, o = m.b, size(h, 1)
gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h)
h̃ = tanh.(gate(gx, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3))
h′ = (1 .- z) .* h̃ .+ z .* h
sz = size(x)
return h′, reshape(h′, :, sz[2:end]...)
end

@functor GRUv3Cell

Base.show(io::IO, l::GRUv3Cell) =
print(io, "GRUv3Cell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")

"""
GRUv3(in::Integer, out::Integer)
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v3 of the referenced paper.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Recur(m::GRUv3Cell) = Recur(m, m.state0)


@adjoint function Broadcast.broadcasted(f::Recur, args...)
Zygote.∇map(__context__, f, args...)
end
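
As a side note, the only mathematical difference between the two variants is where the reset gate is applied when forming the candidate state. A standalone sketch for comparison (illustrative only; it does not mirror the internal `gate`/`Wh_h̃` layout above):

```julia
# Candidate hidden state h̃ in the two GRU variants.
# Wi, Wh are the candidate-state weight slices, b the bias, r the reset gate.
candidate_v1(Wi, Wh, b, x, h, r) = tanh.(Wi*x .+ r .* (Wh*h) .+ b)  # v1: gate applied after the projection
candidate_v3(Wi, Wh, b, x, h, r) = tanh.(Wi*x .+ Wh*(r .* h) .+ b)  # v3: gate applied before the projection
```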
4 changes: 2 additions & 2 deletions test/cuda/curnn.jl
@@ -1,6 +1,6 @@
using Flux, CUDA, Test

@testset for R in [RNN, GRU, LSTM]
@testset for R in [RNN, GRU, LSTM, GRUv3]
m = R(10, 5) |> gpu
x = gpu(rand(10))
(m̄,) = gradient(m -> sum(m(x)), m)
@@ -12,7 +12,7 @@ using Flux, CUDA, Test
end

@testset "RNN" begin
@testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
@testset for R in [RNN, GRU, LSTM, GRUv3], batch_size in (1, 5)
rnn = R(10, 5)
curnn = fmap(gpu, rnn)

6 changes: 3 additions & 3 deletions test/layers/recurrent.jl
@@ -43,7 +43,7 @@ end
end

@testset "RNN-shapes" begin
@testset for R in [RNN, GRU, LSTM]
@testset for R in [RNN, GRU, LSTM, GRUv3]
m1 = R(3, 5)
m2 = R(3, 5)
x1 = rand(Float32, 3)
@@ -58,10 +58,10 @@ end
end

@testset "RNN-input-state-eltypes" begin
@testset for R in [RNN, GRU, LSTM]
@testset for R in [RNN, GRU, LSTM, GRUv3]
m = R(3, 5)
x = rand(Float64, 3, 1)
Flux.reset!(m)
@test_throws MethodError m(x)
end
end
end
