
Custom model can not be trained #2187

Closed
PokeLu opened this issue Feb 11, 2023 · 10 comments

Comments

@PokeLu

PokeLu commented Feb 11, 2023

I am pretty new to Julia and Flux. I am trying to build a simple neural network that uses an attention layer. The attention model consists of a self-defined attention layer plus some non-trainable hyperparameters stored in the struct. I wrote the code below using Julia 1.8.2 and Flux v0.13.9, and it works fine in inference (feed-forward) mode:

using Flux

struct Attention
    W
    v
end

Attention(vehile_embedding_dim::Integer) = Attention(
    Dense(vehile_embedding_dim => vehile_embedding_dim, tanh),
    Dense(vehile_embedding_dim, 1, bias=false, init=Flux.zeros32)
)

function (a::Attention)(inputs)
    alphas = [a.v(e) for e in a.W.(inputs)]
    alphas = sigmoid.(alphas)
    output = sum([alpha.*input for (alpha, input) in zip(alphas, inputs)])
    return output
end

Flux.@functor Attention

struct AttentionNet 
    embedding
    attention
    fc_output
    vehicle_num::Integer #non-trainable hyperparameter
    vehicle_dim::Integer #non-trainable hyperparameter
end

AttentionNet(vehicle_num::Integer, vehicle_dim::Integer, embedding_dim::Integer) = AttentionNet(
    Dense(vehicle_dim+1 => embedding_dim, relu),
    Attention(embedding_dim),
    Dense(1+embedding_dim => 1),
    vehicle_num,
    vehicle_dim
)

function (a_net::AttentionNet)(x)
    time_idx = x[[1], :]
    vehicle_states = [x[2+a_net.vehicle_dim*(i-1):2+a_net.vehicle_dim*i-1, :] for i in 1:a_net.vehicle_num]
    vehicle_states = [vcat(time_idx, vehicle_state) for vehicle_state in vehicle_states]

    vehicle_embeddings = a_net.embedding.(vehicle_states)
    attention_output = a_net.attention(vehicle_embeddings)
    
    x = a_net.fc_output(vcat(time_idx, attention_output))
    return x
end

Flux.@functor AttentionNet
Flux.trainable(a_net::AttentionNet) = (a_net.embedding, a_net.attention, a_net.fc_output)



fake_inputs = rand(22, 32)
fake_outputs = rand(1, 32)
a_net = AttentionNet(3, 7, 64)|> gpu
opt = Adam(.01)
opt_state = Flux.setup(opt, a_net)

data = Flux.DataLoader((fake_inputs, fake_outputs)|>gpu, batchsize=32, shuffle=true)

Flux.train!(a_net, data, opt_state) do m, x, y
    Flux.mse(m(x), y)
end

But when I trained it, I got the following warning and error message:

┌ Warning: trainable(x) should now return a NamedTuple with the field names, not a Tuple
└ @ Optimisers C:\Users\Herr LU\.julia\packages\Optimisers\SoKJO\src\interface.jl:164
ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::NamedTuple{(:contents,), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  +(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any) at C:\Users\Herr LU\.julia\packages\InitialValues\OWP8V\src\InitialValues.jl:154
  +(::ChainRulesCore.Tangent{P}, ::P) where P at C:\Users\Herr LU\.julia\packages\ChainRulesCore\C73ay\src\tangent_arithmetic.jl:146
  ...
Stacktrace:
  [1] accum(x::Base.RefValue{Any}, y::NamedTuple{(:contents,), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:17
  [2] accum(x::Base.RefValue{Any}, y::NamedTuple{(:contents,), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, zs::Base.RefValue{Any}) (repeats 2 times)
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:22
  [3] Pullback
    @ e:\Master Thesis\lu_jizhou\toy exmaple\dqn_model.jl:39 [inlined]
  [4] (::typeof(∂(λ)))(Δ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
  [5] Pullback
    @ e:\Master Thesis\lu_jizhou\toy exmaple\dqn_model.jl:62 [inlined]
  [6] #208
    @ C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:206 [inlined]
  [7] #2066#back
    @ C:\Users\Herr LU\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
  [8] Pullback
    @ C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:102 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [10] (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float32)
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:45
 [11] withgradient(f::Function, args::AttentionNet)
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:133
 [12] macro expansion
    @ C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:102 [inlined]
 [13] macro expansion
    @ C:\Users\Herr LU\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
 [14] train!(loss::Function, model::AttentionNet, data::MLUtils.DataLoader{Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Random._GLOBAL_RNG, Val{nothing}}, opt::NamedTuple{(:embedding, :attention, :fc_output, :vehicle_num, :vehicle_dim), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:W, :v), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}, Tuple{}}}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, Tuple{}, Tuple{}}}; cb::Nothing)
    @ Flux.Train C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:100
 [15] train!(loss::Function, model::AttentionNet, data::MLUtils.DataLoader{Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Random._GLOBAL_RNG, Val{nothing}}, opt::NamedTuple{(:embedding, :attention, :fc_output, :vehicle_num, :vehicle_dim), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:W, :v), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}, Tuple{}}}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, Tuple{}, Tuple{}}})
    @ Flux.Train C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:97
 [16] top-level scope
    @ e:\Master Thesis\lu_jizhou\toy exmaple\dqn_model.jl:61

I followed the instructions from the official tutorial on custom layers, but it doesn't explain how to get a custom layer properly trained. How should I define a custom model with some non-trainable hyperparameters?

@mcabbott
Member

The immediate problem is that the warning is trying to tell you to define trainable differently:

Flux.trainable(a_net::AttentionNet) = (embedding = a_net.embedding, attention = a_net.attention, fc_output = a_net.fc_output)

It needs to return a NamedTuple whose keys are a subset of fieldnames(AttentionNet).

But in fact it need not be defined at all. Integers and other scalars in the struct will be ignored anyway -- Flux cannot treat them as trainable parameters. Simply deleting that trainable method definition removes the problem.
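For illustration, here is a minimal sketch (with a hypothetical Toy struct, not taken from the code above) of what happens when no trainable method is defined: Flux.setup simply assigns empty state () to the scalar fields, which matches the trailing Tuple{}, Tuple{} entries for :vehicle_num and :vehicle_dim in the optimiser-state type shown in the stack trace.

using Flux

struct Toy
    dense        # trainable layer
    n::Int       # scalar hyperparameter, ignored by the optimiser
end
Flux.@functor Toy

model = Toy(Dense(2 => 2), 5)
state = Flux.setup(Adam(0.01), model)
# state.dense holds Optimisers.Leaf entries for the weight and bias;
# state.n is the empty tuple (), i.e. no optimiser state for the scalar field.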

Sadly, after fixing that, I run into another error:

julia> Flux.train!(a_net, data, opt_state) do m, x, y
           Flux.mse(m(x), y)
       end
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(8 => 64, relu)  # 576 parameters
│   summary(x) = "8×32 Matrix{Float64}"
└ @ Flux ~/.julia/dev/Flux/src/layers/stateless.jl:77
ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::NamedTuple{(:contents,), Tuple{Matrix{Float64}}})

Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...)
   @ Base operators.jl:578
  +(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any)
   @ InitialValues ~/.julia/packages/InitialValues/OWP8V/src/InitialValues.jl:154
  +(::ChainRulesCore.AbstractThunk, ::Any)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_arithmetic.jl:122
  ...

Stacktrace:
  [1] accum(x::Base.RefValue{Any}, y::NamedTuple{(:contents,), Tuple{Matrix{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:17

This one is much deeper in the weeds. Somehow a type instability is causing something to be boxed, and Zygote is trying to add its gradient to another gradient in a different format, and... why do I know these things? You might be able to hack around it with something like this:

julia> using Zygote

julia> Zygote.accum(x::Base.RefValue{Any}, y::NamedTuple{(:contents,)}) = Zygote.accum(x[], y)

julia> Base.:+(x::NamedTuple{(:contents,)}, y::Base.RefValue{Any}) = Zygote.accum(x, y[])

julia> Zygote.refresh()

julia> Flux.train!(a_net, data, opt_state) do m, x, y
           Flux.mse(m(x), y)
       end

I have not looked closely but suspect that your code can be made more Zygote-friendly. Maybe sum([alpha.*input for (alpha, input) in zip(alphas, inputs)]) can be rewritten as one array operation rather than a loop with zip and broadcasting... but I'm not sure.
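For example, one possible shape of that rewrite is the untested sketch below: it stacks the per-vehicle embeddings into a 3-d array and does a single broadcast and sum, assuming every element of inputs is an embedding_dim × batch matrix as in the model above.

function (a::Attention)(inputs)
    stacked = cat(inputs...; dims = 3)               # embedding_dim × batch × n_vehicles
    flat    = reshape(stacked, size(stacked, 1), :)  # embedding_dim × (batch * n_vehicles)
    scores  = sigmoid.(a.v(a.W(flat)))               # 1 × (batch * n_vehicles)
    alphas  = reshape(scores, 1, size(stacked, 2), size(stacked, 3))
    return dropdims(sum(alphas .* stacked; dims = 3); dims = 3)
end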

@ToucheSir
Member

I believe the boxing comes from these two lines:

vehicle_states = [x[2+a_net.vehicle_dim*(i-1):2+a_net.vehicle_dim*i-1, :] for i in 1:a_net.vehicle_num]
vehicle_states = [vcat(time_idx, vehicle_state) for vehicle_state in vehicle_states]

If you rename the second vehicle_states to something else like vehicle_states2, I suspect the error will go away. Probably better still would be to use ntuple and map with tuples instead of array comprehensions (see the sketch below), but that can be addressed after we figure out what is generating this Box.
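For reference, such a rewrite might look like the untested sketch below; the index ranges are assumed to match the original comprehension, and x is never reassigned, so nothing gets boxed:

function (a_net::AttentionNet)(x)
    time_idx = x[[1], :]
    # one range of rows per vehicle, built as a tuple instead of an array comprehension
    ranges = ntuple(i -> (2 + a_net.vehicle_dim * (i - 1)):(1 + a_net.vehicle_dim * i),
                    a_net.vehicle_num)
    vehicle_embeddings = map(r -> a_net.embedding(vcat(time_idx, x[r, :])), ranges)
    attention_output = a_net.attention(vehicle_embeddings)
    return a_net.fc_output(vcat(time_idx, attention_output))
end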

@PokeLu
Author

PokeLu commented Feb 12, 2023

@ToucheSir

I followed your suggestion by changing the variable name:

    vehicle_states = [x[2+a_net.vehicle_dim*(i-1):2+a_net.vehicle_dim*i-1, :] for i in 1:a_net.vehicle_num]
    vehicle_states_with_time = [vcat(time_idx, vehicle_state) for vehicle_state in vehicle_states]

    vehicle_embeddings = a_net.embedding.(vehicle_states_with_time)

But I still get the same error.

@PokeLu
Author

PokeLu commented Feb 12, 2023

I have not looked closely but suspect that your code can be made more Zygote-friendly

@mcabbott Thanks for your answer. I am not very familiar with programming in a Zygote-friendly manner. What is meant by 'Zygote-friendly'? Using array operations and avoiding all broadcasting or vector operations? Sometimes that seems unavoidable.

@ToucheSir
Member

Broadcasting and "vector operations" are fine. Comprehensions over large (in terms of # of elements) arrays which reference multiple variables are often tricky.

@PokeLu
Author

PokeLu commented Feb 12, 2023

Comprehensions over large (in terms of # of elements) arrays

I don't fully understand this. Could you point out which part of my code could be tricky for Zygote, and why?

@ToucheSir
Member

ToucheSir commented Feb 12, 2023

Ok, the error was what I thought it was, but I was looking in the wrong place! The boxing (and the error with it) should go away if you change these two lines:

x = a_net.fc_output(vcat(time_idx, attention_output))
return x

To:

y = a_net.fc_output(vcat(time_idx, attention_output))
return y

Or just:

return a_net.fc_output(vcat(time_idx, attention_output))

To explain what's going on, let's use this simplified example:

function hasbox(x)
  nums = [x + i for i in 1:5]
  x = sum(nums)
  return x
end
julia> @code_warntype hasbox(1)
MethodInstance for hasbox(::Int64)
  from hasbox(x) in Main at REPL[7]:1
Arguments
  #self#::Core.Const(hasbox)
  x@_2::Int64
Locals
  #23::var"#23#24"
  nums::Vector
  x@_5::Union{}
  x@_6::Union{Int64, Core.Box}
Body::Any
1 ─       (x@_6 = x@_2)
│         (x@_6 = Core.Box(x@_6::Int64))
│         (#23 = %new(Main.:(var"#23#24"), x@_6::Core.Box))
│   %4  = #23::var"#23#24"
│   %5  = (1:5)::Core.Const(1:5)
│   %6  = Base.Generator(%4, %5)::Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#23#24"}, Any[var"#23#24", Core.Const(1:5)])
│         (nums = Base.collect(%6))
│   %8  = Main.sum(nums)::Any
│         Core.setfield!(x@_6::Core.Box, :contents, %8)
│   %10 = Core.isdefined(x@_6::Core.Box, :contents)::Bool
└──       goto #3 if not %10
2 ─       goto #4
3 ─       Core.NewvarNode(:(x@_5))
└──       x@_5
4 ┄ %15 = Core.getfield(x@_6::Core.Box, :contents)::Any
└──       return %15

That's a lot to take in, but the important part is the x@_6::Union{Int64, Core.Box}. That means something is triggering the captured-variable performance issue described at https://docs.julialang.org/en/v1/manual/performance-tips/#man-performance-captured (see also JuliaLang/julia#15276).

But there are no anonymous functions or closures in hasbox, so what gives? As it turns out, the [... for ... in ...] syntax actually creates one! #23 = %new(Main.:(var"#23#24"), x@_6::Core.Box) is essentially defining an anonymous function generator_callback = i -> x + i, which is passed to Base.Generator. Because x is reassigned after this function is defined, we hit the captured-variables-in-closures problem. That also explains why renaming the second assignment of x to something else like y solves the issue.

Edit: one way to make the issue more obvious is to use map or any other function that takes a callback instead of an array comprehension:

function hasbox(x)
  nums = map(i -> x + i, 1:5)
  x = sum(nums)
  return x
end

Here it's clear that x is being captured from the enclosing scope. Ultimately the choice of syntax is up to you, but I do find that more people are tripped up by how Flux/Zygote interact with the more "magical" syntax of array comprehensions.
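As a quick sanity check (a sketch of the same toy example, not of the model above), the renamed version should show no Core.Box under @code_warntype, because x is still captured by the comprehension's closure but never reassigned:

function nobox(x)
  nums = [x + i for i in 1:5]  # x is captured by the comprehension's closure...
  y = sum(nums)                # ...but never reassigned, so no Core.Box is created
  return y
end

Running @code_warntype nobox(1) should then report concrete types throughout (Body::Int64).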

@PokeLu
Author

PokeLu commented Feb 12, 2023

Thanks for the reply, @ToucheSir. The explanation is very clear, and the problem is now solved. If no one else has anything to add, we can close this issue. (But I think Flux could be more specific when reporting errors; messages like these make it pretty hard to track down where the problem is.)

@ToucheSir
Member

Glad we could help. Part of the reason the errors aren't better is that they don't actually come from Flux. Flux handles automatic differentiation through Zygote.jl, which relies on Julia compiler internals to function. Unfortunately this can lead to the occasional obscure error, because what Zygote sees is much lower-level than the code you write. I'll have a think about whether there's anything we can do to improve there, but no promises :)

@paulxshen
Contributor

I had a very similar error when switching from implicit parameter definition to explicit. It was resolved when I switched back.
