diff --git a/src/layers/hamiltonian.jl b/src/layers/hamiltonian.jl index 5cad99d..66e8dd9 100644 --- a/src/layers/hamiltonian.jl +++ b/src/layers/hamiltonian.jl @@ -64,6 +64,11 @@ end hamiltonian_forward(::AutoForwardDiff, model, x) = ForwardDiff.gradient(sum ∘ model, x) +function (hnn::HamiltonianNN)(x::AbstractVector, ps, st) + y, stₙ = hnn(reshape(x, :, 1), ps, st) + return vec(y), stₙ +end + function (hnn::HamiltonianNN{FST})(x::AbstractArray{T, N}, ps, st) where {FST, T, N} model = StatefulLuxLayer{FST}(hnn.model, ps, st.model)