NeuralODE training failed on GPU with Enzyme #2478

Closed
yunan-l opened this issue Aug 25, 2024 · 8 comments
Labels: bug (Something isn't working), extensions (Stuff about package extensions)

yunan-l commented Aug 25, 2024

Hi, when I try to train a NeuralODE with a DiscreteCallback using sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP(), checkpointing=true), I get:

Enzyme execution failed.
Enzyme: unhandled augmented forward for jl_f_finalizer

Stacktrace:
  [1] finalizer
    @ ./gcutils.jl:87 [inlined]
  [2] _
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:83 [inlined]
  [3] CuArray
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:79 [inlined]
  [4] derive
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:799 [inlined]
  [5] unsafe_contiguous_view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:319 [inlined]
  [6] unsafe_view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:314 [inlined]
  [7] view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:310 [inlined]
  [8] maybeview
    @ ./views.jl:148 [inlined]
  [9] macro expansion
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:0 [inlined]
 [10] _getindex
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:119 [inlined]
 [11] getproperty
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/namedtuple_interface.jl:14 [inlined]
 [12] macro expansion
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:0 [inlined]
 [13] applychain
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:520
 [14] Chain
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:518 [inlined]
 [15] apply
    @ ~/.julia/packages/LuxCore/yzx6E/src/LuxCore.jl:171 [inlined]
 [16] dudt
    @ ./In[92]:24 [inlined]
 [17] dudt
    @ ./In[92]:20 [inlined]
 [18] ODEFunction
    @ ~/.julia/packages/SciMLBase/Q1klk/src/scimlfunctions.jl:2335 [inlined]
 [19] #138
    @ ~/.julia/packages/SciMLSensitivity/se3y4/src/adjoint_common.jl:490 [inlined]
 [20] diffejulia__138_128700_inner_1wrap
    @ ~/.julia/packages/SciMLSensitivity/se3y4/src/adjoint_common.jl:0
 [21] macro expansion
    @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:7049 [inlined]
 [22] enzyme_call
    @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:6658 [inlined]
 [23] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:6535 [inlined]
 [24] autodiff
    @ ~/.julia/packages/Enzyme/XGb4o/src/Enzyme.jl:320 [inlined]
 [25] _vecjacobian!(dλ::CuArray{Float32, 1, CUDA.DeviceMemory}, y::CuArray{Float32, 1, CUDA.DeviceMemory}, λ::CuArray{Float32, 1, CUDA.DeviceMemory}, p::ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, t::Float32, S::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{SciMLSensitivity.AdjointDiffCache{Nothing, SciMLSensitivity.var"#138#142"{ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Tuple{CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, CuArray{Float32, 1, CUDA.DeviceMemory}, CuArray{Float32, 1, CUDA.DeviceMemory}, CuArray{Float32, 1, CUDA.DeviceMemory}, SciMLSensitivity.var"#138#142"{ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}, CuArray{Float32, 1, CUDA.DeviceMemory}, ODESolution{Float32, 2, Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}, Nothing, Nothing, 
Vector{Float32}, Vector{Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}}, Nothing, ODEProblem{CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{callback::CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#109#113", SciMLSensitivity.TrackedAffect{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, DiffEqCallbacks.var"#110#114"{typeof(affect!)}, Nothing, Int64}, DiffEqCallbacks.var"#111#115"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, typeof(affect!)}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}, Vector{Float32}, Vector{Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}}, Nothing, OrdinaryDiffEq.Tsit5ConstantCache, 
Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing}, SciMLSensitivity.CheckpointSolution{ODESolution{Float32, 2, Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}}, Nothing, ODEProblem{CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{callback::CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#109#113", SciMLSensitivity.TrackedAffect{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, DiffEqCallbacks.var"#110#114"{typeof(affect!)}, Nothing, Int64}, DiffEqCallbacks.var"#111#115"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, typeof(affect!)}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, 
Nothing, Nothing}, Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}, Vector{Float32}, Vector{Vector{CuArray{Float32, 1, CUDA.DeviceMemory}}}, Nothing, OrdinaryDiffEq.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing}, Vector{Tuple{Float32, Float32}}, @NamedTuple{reltol::Float64, abstol::Float64}, Nothing}, ODEProblem{CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{callback::CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#109#113", SciMLSensitivity.TrackedAffect{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, DiffEqCallbacks.var"#110#114"{typeof(affect!)}, Nothing, Int64}, DiffEqCallbacks.var"#111#115"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, typeof(affect!)}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}}, SciMLBase.StandardODEProblem}, ODEFunction{false, SciMLBase.FullSpecialize, var"#dudt#52"{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, var"#dudt#51#53"{NeuralODE{Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, LuxCUDADevice{Nothing}, @Kwargs{sensealg::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}}, isautojacvec::EnzymeVJP, dgrad::CuArray{Float32, 1, 
CUDA.DeviceMemory}, dy::Nothing, W::Nothing)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/se3y4/src/derivative_wrappers.jl:728
...

The whole log is attached below:
EnzymeVJP.failed.txt

Here is the main Julia code, which works on CPU but fails on GPU.

# Imports inferred from the snippet (not shown in the original report):
using Lux, LuxCUDA, CUDA, ComponentArrays
using OrdinaryDiffEq, SciMLSensitivity, DiffEqCallbacks
using Optimisers, Zygote, ParameterSchedulers, ProgressMeter
using Random, Statistics
# Helpers such as DetermineDevice, DeviceArray, gdev, and the observed_data array are defined
# elsewhere by the reporter and are not shown here.

struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, D, K} <:
       Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    tspan::T
    device::D
    kwargs::K
end

function NeuralODE(model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), gpu=nothing, kwargs...)
    device = DetermineDevice(gpu=gpu)
    NeuralODE{typeof(model), typeof(solver), typeof(tspan), typeof(device), typeof(kwargs)}(model, solver, tspan, device, kwargs)
end


function (n::NeuralODE)(u0, ps, st, cb)

    function dudt(u, p, t; st=st)
        u_, st = Lux.apply(n.model, u, p, st)
        return u_
    end
    
    prob = ODEProblem{false}(ODEFunction{false}(dudt), u0, n.tspan, ps)
    
    sensealg = get(n.kwargs, :sensealg, InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true))
    
    tsteps = n.tspan[1]:n.tspan[2]
    
    sol = solve(prob, n.solver, saveat=tsteps, callback = cb, sensealg = sensealg)
    
    return DeviceArray(n.device, Array(sol)), st
end

function loss_neuralode(model, u0, p, st, cb)

    pred, st = model(u0, p, st, cb)
    
    loss = mean((pred .- observed_data).^2)
    
    return loss
end


function train_neuralode!(model, u0, p, st, cb, loss_func, opt_state, η_schedule; N_epochs=1, verbose=true, 
compute_initial_error::Bool=true, scheduler_offset::Int=0)
    
    best_p = copy(p)
    results = (i_epoch = Int[], train_loss=Float32[], learning_rate=Float32[], duration=Float32[], valid_loss=Float32[], test_loss=Float32[], loss_min=[Inf32], i_epoch_min=[1])
    
    progress = Progress(N_epochs, 1)
    
    # initial error 
    lowest_train_err = compute_initial_error ? loss_func(model, u0, p, st, cb) : Inf
    
    if verbose 
        println("______________________________")
        println("starting training epoch")
    end

    for i_epoch in 1:N_epochs

        Optimisers.adjust!(opt_state, η_schedule(i_epoch + scheduler_offset)) 

        epoch_start_time = time()

        losses = zeros(Float32, 1)

        loss_p(p) = loss_func(model, u0, p, st, cb)
        println("training loss:", loss_p(p))

        l, gs = Zygote.withgradient(loss_p, p)

        losses = l
        opt_state, p = Optimisers.update(opt_state, p, gs[1])


        train_err = l
        epoch_time = time() - epoch_start_time

        push!(results[:i_epoch], i_epoch)
        push!(results[:train_loss], train_err)
        push!(results[:learning_rate], η_schedule(i_epoch))
        push!(results[:duration], epoch_time)

        if train_err < lowest_train_err
            lowest_train_err = train_err
            best_p = deepcopy(p)
            results[:loss_min] .= lowest_train_err
            results[:i_epoch_min] .= i_epoch
        end

    end


    return model, best_p, st, results
    
end


#callback
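# The PresetTimeCallback below wraps a DiscreteCallback that fires once at t = 274 and zeroes the
# first four components of the state; save_positions = (false, false) avoids saving extra solution
# points at the event time.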
times = [274.0f0]
affect!(integrator) = integrator.u[1:4] .= 0.0f0
cb = PresetTimeCallback(times, affect!; save_positions = (false, false)) #save_positions = (true, true)

const device = DetermineDevice()
CUDA.allowscalar(false) # Makes sure no slow operations are occurring


nn = Chain(
          Dense(4, 16, tanh),
          Dense(16, 16, tanh), 
          Dense(16, 4)
)

rng = Xoshiro(0)
p, st = Lux.setup(rng, nn)
p = ComponentArray(p) |> gdev
st = st |> gdev

u0 = Float32[0.0, 8.0, 0.0, 12.0] |> gdev
tspan = (0.0f0, 365.0f0) 

neural_ode = NeuralODE(nn; solver=Tsit5(), tspan = tspan, sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP(), checkpointing=true))
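# Note: the sensealg keyword above is stored in n.kwargs and forwarded to solve() inside the NeuralODE functor.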

loss = loss_neuralode

opt = Optimisers.AdamW(1f-3, (9f-1, 9.99f-1), 1f-6)
opt_state = Optimisers.setup(opt, p)
η_schedule = SinExp(λ0=1f-3,λ1=1f-5,period=20,decay=0.975f0)

println("starting training...")
neural_de, ps, st, results_ad = train_neuralode!(neural_ode, u0, p, st, cb, loss, opt_state, η_schedule; N_epochs=5, verbose=true)
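For reference, the constructor above falls back to a Zygote-based VJP when no sensealg is passed (the get(n.kwargs, :sensealg, ...) line). Below is a minimal sketch of constructing the same model with that fallback explicitly, which may serve as a stopgap while the EnzymeVJP path errors on GPU (hypothetical usage, same API as the reproducer above):

# Hedged stopgap sketch: reuse the ZygoteVJP default from the constructor instead of EnzymeVJP.
neural_ode_zygote = NeuralODE(nn; solver = Tsit5(), tspan = tspan,
                              sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true))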

See also the Discourse thread: https://discourse.julialang.org/t/neuralode-training-failed-on-gpu-with-enzyme/118537/5 (cc @wsmoses)

yunan-l added the bug label on Aug 25, 2024
wsmoses commented Aug 25, 2024

@maleadt I'd assign myself, but I don't have permissions. For ease, would it be possible to have permissions added?

wsmoses commented Aug 25, 2024

@yunan-l can you paste the versions of CUDA/Enzyme you're using, as well as system information?

yunan-l commented Aug 26, 2024

@yunan-l can you paste the versions of CUDA/Enzyme you're using, as well as system information?

Sure. The CUDA version:

CUDA runtime 12.5, artifact installation
CUDA driver 12.6
NVIDIA driver 535.154.5, originally for CUDA 12.2

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+535.154.5

Julia packages: 
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0

Toolchain:
- Julia: 1.10.0
- LLVM: 15.0.7

1 device:
  0: NVIDIA H100 80GB HBM3 (sm_90, 75.982 GiB / 79.647 GiB available)

The Enzyme version:

Enzyme v0.12.32

The system information:

Platform Info,
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 9554 64-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
  Threads: 2 on 128 virtual cores

maleadt added the extensions label on Aug 26, 2024
yunan-l commented Aug 27, 2024

Hi @wsmoses, there is a new error, which appeared after I updated CUDA.jl. The CUDA version info is now as follows:

CUDA.versioninfo()

CUDA runtime 12.5, artifact installation
CUDA driver 12.6
NVIDIA driver 535.154.5, originally for CUDA 12.2

CUDA libraries: 
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+535.154.5

Julia packages: 
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0

Toolchain:
- Julia: 1.10.0
- LLVM: 15.0.7

1 device:
  0: NVIDIA H100 80GB HBM3 (sm_90, 78.214 GiB / 79.647 GiB available)

The only difference from the version info above is CUBLAS (12.3.4 → 12.5.3).

Error & Stacktrace ⚠️

No augmented forward pass found for cublasLtMatmulDescCreate
 at context:   %173 = call i32 @cublasLtMatmulDescCreate(i64 %bitcast_coercion, i32 %unbox32, i32 0) #469 [ "jl_roots"({} addrspace(10)* %166) ], !dbg !535

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/CUDA/Tl08O/lib/utils/call.jl:218 [inlined]
  [2] macro expansion
    @ ~/.julia/packages/CUDA/Tl08O/lib/cublas/libcublasLt.jl:400 [inlined]
  [3] #1158
    @ ~/.julia/packages/CUDA/Tl08O/lib/utils/call.jl:35 [inlined]
  [4] retry_reclaim
    @ ~/.julia/packages/CUDA/Tl08O/src/memory.jl:434 [inlined]
  [5] check
    @ ~/.julia/packages/CUDA/Tl08O/lib/cublas/libcublas.jl:24 [inlined]
  [6] cublasLtMatmulDescCreate
    @ ~/.julia/packages/CUDA/Tl08O/lib/utils/call.jl:34 [inlined]
  [7] cublaslt_matmul_fused!
    @ ~/.julia/packages/LuxLib/mR6WV/ext/LuxLibCUDAExt/cublaslt.jl:63
  [8] cublaslt_matmul_fused!
    @ ~/.julia/packages/LuxLib/mR6WV/ext/LuxLibCUDAExt/cublaslt.jl:13 [inlined]
  [9] cublasLt_fused_dense!
    @ ~/.julia/packages/LuxLib/mR6WV/ext/LuxLibCUDAExt/cublaslt.jl:196
 [10] cublasLt_fused_dense!
    @ ~/.julia/packages/LuxLib/mR6WV/ext/LuxLibCUDAExt/cublaslt.jl:194 [inlined]
 [11] fused_dense!
    @ ~/.julia/packages/LuxLib/mR6WV/src/impl/dense.jl:38 [inlined]
 [12] fused_dense
    @ ~/.julia/packages/LuxLib/mR6WV/src/impl/dense.jl:24 [inlined]
 [13] fused_dense
    @ ~/.julia/packages/LuxLib/mR6WV/src/impl/dense.jl:11 [inlined]
 [14] fused_dense_bias_activation
    @ ~/.julia/packages/LuxLib/mR6WV/src/api/dense.jl:31 [inlined]
 [15] Dense
    @ ~/.julia/packages/Lux/PsW4M/src/layers/basic.jl:366
 [16] Dense
    @ ~/.julia/packages/Lux/PsW4M/src/layers/basic.jl:356
 [17] apply
    @ ~/.julia/packages/LuxCore/yzx6E/src/LuxCore.jl:171 [inlined]
 [18] macro expansion
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:0 [inlined]
 [19] applychain
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:520
 [20] Chain
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:518 [inlined]
 [21] apply
    @ ~/.julia/packages/LuxCore/yzx6E/src/LuxCore.jl:171 [inlined]
 [22] dudt
    @ ./In[5]:21 [inlined]
 [23] dudt
    @ ./In[5]:18 [inlined]
 [24] ODEFunction
    @ ~/.julia/packages/SciMLBase/HReyK/src/scimlfunctions.jl:2335 [inlined]
 [25] #138
    @ ~/.julia/packages/SciMLSensitivity/se3y4/src/adjoint_common.jl:490 [inlined]
 [26] diffejulia__138_34195_inner_1wrap
    @ ~/.julia/packages/SciMLSensitivity/se3y4/src/adjoint_common.jl:0
 [27] macro expansion
    @ ~/.julia/packages/Enzyme/YWQiS/src/compiler.jl:7099 [inlined]
 [28] enzyme_call
    @ ~/.julia/packages/Enzyme/YWQiS/src/compiler.jl:6708 [inlined]
 [29] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/YWQiS/src/compiler.jl:6585 [inlined]
 [30] autodiff
    @ ~/.julia/packages/Enzyme/YWQiS/src/Enzyme.jl:320 [inlined]
    ....

The whole log is attached:

EnzymeVJP.failed_after_update_CUDA.txt

wsmoses commented Sep 2, 2024

Oh yeah, so the issue here is that we added support for some cuBLAS routines, but I haven't seen this cublasLtMatmulDescCreate stuff before.

wsmoses commented Sep 2, 2024

@avik-pal it looks like this might be best implemented as a custom rule for fused_dense! in LuxLib.jl?
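For context, a custom reverse-mode rule in Enzyme is declared through EnzymeRules.augmented_primal and EnzymeRules.reverse. The sketch below shows the shape such a rule takes for a toy in-place function, not LuxLib's actual fused_dense!, and assumes the Enzyme v0.12-era EnzymeRules API (names such as ConfigWidth differ in later releases):

using Enzyme
import Enzyme.EnzymeRules: augmented_primal, reverse, AugmentedReturn, ConfigWidth, needs_primal

# Toy stand-in for an operation Enzyme cannot differentiate through directly.
square_sum!(y, x) = (y .= x .^ 2; sum(y))

# Forward sweep: run the primal and stash x on the tape for the reverse sweep.
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(square_sum!)},
                          ::Type{<:Active}, y::Duplicated, x::Duplicated)
    primal = func.val(y.val, x.val)
    tape = copy(x.val)
    return AugmentedReturn(needs_primal(config) ? primal : nothing, nothing, tape)
end

# Reverse sweep: d(sum(x .^ 2))/dx = 2x, accumulated into the shadow of x.
function reverse(config::ConfigWidth{1}, func::Const{typeof(square_sum!)},
                 dret::Active, tape, y::Duplicated, x::Duplicated)
    x.dval .+= 2 .* tape .* dret.val   # contribution from the active return value
    x.dval .+= 2 .* tape .* y.dval     # contribution from any derivative already in y's shadow
    y.dval .= 0
    return (nothing, nothing)
end

# Usage: dx should come out as 2 .* x.
x = rand(Float32, 3); dx = zeros(Float32, 3)
y = zeros(Float32, 3); dy = zeros(Float32, 3)
Enzyme.autodiff(Reverse, square_sum!, Active, Duplicated(y, dy), Duplicated(x, dx))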

wsmoses commented Sep 2, 2024

closing in favor of LuxDL/LuxLib.jl#148

wsmoses closed this as completed on Sep 2, 2024
avik-pal commented Sep 8, 2024

This particular error for cuBLASLt is now fixed upstream. (Note that you need to install Lux v1 and LuxLib v1.1 for the patch.)
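For reference, one way to move to those releases from the Julia REPL (a hedged sketch; the exact version pins are illustrative, any Lux ≥ 1 and LuxLib ≥ 1.1 should carry the fix):

using Pkg
Pkg.add(name = "Lux", version = "1.0.0")      # illustrative pin; newer 1.x releases also work
Pkg.add(name = "LuxLib", version = "1.1.0")
Pkg.status(["Lux", "LuxLib"])                 # confirm what actually resolved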
