From 70c9d22b15bb238bfd2b2c6438302ac7cd6f3d86 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 12:02:41 -0400 Subject: [PATCH] fix: allow vectors as input to hamiltonian nn --- src/layers/hamiltonian.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/layers/hamiltonian.jl b/src/layers/hamiltonian.jl index 5cad99d..66e8dd9 100644 --- a/src/layers/hamiltonian.jl +++ b/src/layers/hamiltonian.jl @@ -64,6 +64,11 @@ end hamiltonian_forward(::AutoForwardDiff, model, x) = ForwardDiff.gradient(sum ∘ model, x) +function (hnn::HamiltonianNN)(x::AbstractVector, ps, st) + y, stₙ = hnn(reshape(x, :, 1), ps, st) + return vec(y), stₙ +end + function (hnn::HamiltonianNN{FST})(x::AbstractArray{T, N}, ps, st) where {FST, T, N} model = StatefulLuxLayer{FST}(hnn.model, ps, st.model)