
'sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP(), checkpointing=true)' failed #1749

Closed
yunan-l opened this issue Aug 23, 2024 · 1 comment



yunan-l commented Aug 23, 2024

Hi, when running the following code:

struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, D, K} <:
       Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    tspan::T
    device::D
    kwargs::K
end

function NeuralODE(model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), gpu=nothing, kwargs...)
    device = DetermineDevice(gpu=gpu)
    NeuralODE{typeof(model), typeof(solver), typeof(tspan), typeof(device), typeof(kwargs)}(model, solver, tspan, device, kwargs)
end


function (n::NeuralODE)(u0, ps, st, cb, pp)
    
    # Out-of-place ODE right-hand side; closes over the Lux layer state `st`.
    function dudt(u, p, t; st=st)
        u_, st = Lux.apply(n.model, u, p, st)
        return u_
    end
    
    prob = ODEProblem{false}(ODEFunction{false}(dudt), u0, n.tspan, ps)
    
    # Fall back to a ZygoteVJP-based interpolating adjoint unless overridden via kwargs.
    sensealg = get(n.kwargs, :sensealg, InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true))
    
    tsteps = n.tspan[1]:n.tspan[2]
    
    sol = solve(prob, n.solver, saveat=tsteps, callback = cb, sensealg = sensealg)
    
    return DeviceArray(n.device, Array(sol)), st
end
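# Note on the adjoint configuration: when no :sensealg kwarg is passed, the
# functor above falls back to InterpolatingAdjoint(autojacvec=ZygoteVJP(),
# checkpointing=true); the failing run below overrides this with EnzymeVJP().
# A minimal, self-contained sketch of how that kwargs lookup resolves
# (plain NamedTuple stand-ins, for illustration only):
kw = (; sensealg = :enzyme_vjp)   # stand-in for the EnzymeVJP adjoint
get(kw, :sensealg, :zygote_vjp)   # -> :enzyme_vjp (explicit override wins)
get((;), :sensealg, :zygote_vjp)  # -> :zygote_vjp (default is used)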

function train_neuralode!(model, u0, p, st, cb, pp, loss_func, opt_state, η_schedule; N_epochs=1, verbose=true, compute_initial_error::Bool=true, scheduler_offset::Int=0)
    
    best_p = copy(p)  # note: never updated below; the initial copy is what gets returned
    results = (i_epoch = Int[], train_loss=Float32[], learning_rate=Float32[], duration=Float32[], valid_loss=Float32[], test_loss=Float32[], loss_min=[Inf32], i_epoch_min=[1])
    
    progress = Progress(N_epochs, 1)
    
    # initial error 
    lowest_train_err = compute_initial_error ? loss_func(model, u0, p, st, cb, pp) : Inf

    for i_epoch in 1:N_epochs

        Optimisers.adjust!(opt_state, η_schedule(i_epoch + scheduler_offset)) 

        epoch_start_time = time()

        loss_p(p) = loss_func(model, u0, p, st, cb, pp)

        # Loss value and gradient w.r.t. the parameters in a single pass.
        l, gs = Zygote.withgradient(loss_p, p)
        opt_state, p = Optimisers.update(opt_state, p, gs[1])


        train_err = l
        epoch_time = time() - epoch_start_time

        push!(results[:i_epoch], i_epoch)
        push!(results[:train_loss], train_err)
        push!(results[:learning_rate], η_schedule(i_epoch))
        push!(results[:duration], epoch_time)

        if i_epoch % N_epochs == 0
            monitor(model, u0, p, st, cb, pp)
        end

    end
    return model, best_p, st, results
    
end
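# The epoch body above is the standard Zygote + Optimisers pattern: one
# withgradient call per epoch, followed by an optimiser state update. A
# standalone, runnable sketch of just that pattern (hypothetical quadratic
# loss, no ODE solve involved):
using Zygote, Optimisers
θ = randn(Float32, 4)
state = Optimisers.setup(Optimisers.AdamW(1f-3), θ)
quadloss(θ) = sum(abs2, θ)
l, gs = Zygote.withgradient(quadloss, θ)       # loss value and gradient in one pass
state, θ = Optimisers.update(state, θ, gs[1])  # returns updated state and parameters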


using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, Optimisers, Random, Plots, XLSX, DataFrames, SciMLSensitivity, DiffEqCallbacks, Enzyme, CUDA, LuxCUDA, LuxDeviceUtils
Enzyme.API.runtimeActivity!(true)
CUDA.allowscalar(false)


nn = Chain(
          Dense(4, 16, tanh),
          Dense(16, 16, tanh),
          Dense(16, 4))

rng = Xoshiro(0)
gdev = gpu_device()  # GPU device mover from LuxDeviceUtils (assumed; not defined in the original snippet)
p, st = Lux.setup(rng, nn)
p = ComponentArray(p) |> gdev
st = st |> gdev

u0 = Float32[0.0, 8.0, 0.0, 12.0] |> gdev
tspan = (0.0f0, 365.0f0) 

neural_ode = NeuralODE(nn; solver=Tsit5(), tspan = tspan, sensealg=InterpolatingAdjoint(autojacvec=EnzymeVJP(), checkpointing=true))

loss = loss_neuralode  # loss_neuralode is defined elsewhere (not shown here)

opt = Optimisers.AdamW(1f-3, (9f-1, 9.99f-1), 1f-6)
opt_state = Optimisers.setup(opt, p)
η_schedule = SinExp(λ0=1f-3,λ1=1f-5,period=20,decay=0.975f0)

println("starting training...")
neural_de, ps, st, results_ad = train_neuralode!(neural_ode, u0, p, st, cb, pp, loss, opt_state, η_schedule; N_epochs=5, verbose=true)

I get:

Enzyme execution failed.
Enzyme: unhandled augmented forward for jl_f_finalizer
Stacktrace:
  [1] finalizer
    @ ./gcutils.jl:87 [inlined]
  [2] _
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:83 [inlined]
  [3] CuArray
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:79 [inlined]
  [4] derive
    @ ~/.julia/packages/CUDA/Tl08O/src/array.jl:799 [inlined]
  [5] unsafe_contiguous_view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:319 [inlined]
  [6] unsafe_view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:314 [inlined]
  [7] view
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/base.jl:310 [inlined]
  [8] maybeview
    @ ./views.jl:148 [inlined]
  [9] macro expansion
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:0 [inlined]
 [10] _getindex
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/array_interface.jl:119 [inlined]
 [11] getproperty
    @ ~/.julia/packages/ComponentArrays/xO4hy/src/namedtuple_interface.jl:14 [inlined]
 [12] macro expansion
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:0 [inlined]
 [13] applychain
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:520
 [14] Chain
    @ ~/.julia/packages/Lux/PsW4M/src/layers/containers.jl:518 [inlined]
 [15] apply
    @ ~/.julia/packages/LuxCore/yzx6E/src/LuxCore.jl:171 [inlined]
 [16] dudt
    @ ./In[92]:24 [inlined]
 [17] dudt
    @ ./In[92]:20 [inlined]
 [18] ODEFunction
    @ ~/.julia/packages/SciMLBase/Q1klk/src/scimlfunctions.jl:2335 [inlined]
 [19] #138
    @ ~/.julia/packages/SciMLSensitivity/se3y4/src/adjoint_common.jl:490 [inlined]
 [20] diffejulia__138_128700_inner_1wrap
    @ ~/.julia/packages/SciMLSensitivity/se3y4/src/adjoint_common.jl:0
 [21] macro expansion
    @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:7049 [inlined]
 [22] enzyme_call
    @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:6658 [inlined]
 [23] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/XGb4o/src/compiler.jl:6535 [inlined]
 [24] autodiff
    @ ~/.julia/packages/Enzyme/XGb4o/src/Enzyme.jl:320 [inlined]
    ...

For the full log, please see the attached EnzymeVJP.failed.txt. (The failure appears to originate in Enzyme hitting an unhandled jl_f_finalizer while CUDA.jl derives a view of the GPU-resident ComponentArray, per frames [1]-[11] above.)
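For what it's worth, the functor's own default adjoint avoids the Enzyme code path entirely. A sketch of the same construction with the ZygoteVJP default spelled out (same setup as above, untested here):

neural_ode = NeuralODE(nn; solver=Tsit5(), tspan=tspan,
    sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true))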

wsmoses (Member) commented Aug 25, 2024

Closed in favor of JuliaGPU/CUDA.jl#2478

wsmoses closed this as not planned on Aug 25, 2024