From 38eb87b885c27344e3b139eda44968976c9a9867 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 11:26:50 -0400 Subject: [PATCH] fix: allow numbers as input to spline layer --- ext/BoltzDataInterpolationsExt.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/BoltzDataInterpolationsExt.jl b/ext/BoltzDataInterpolationsExt.jl index 8dcf15c..ffb81af 100644 --- a/ext/BoltzDataInterpolationsExt.jl +++ b/ext/BoltzDataInterpolationsExt.jl @@ -4,13 +4,14 @@ using DataInterpolations: AbstractInterpolation using Boltz: Boltz, Layers, Utils -for train_grid in (true, false) +for train_grid in (true, false), tType in (AbstractVector, Number) grid_expr = train_grid ? :(grid = ps.grid) : :(grid = st.grid) + sol_expr = tType === Number ? :(sol = interp(t)) : :(sol = interp.(t)) @eval function (spl::Layers.SplineLayer{$(train_grid), Basis})( - t::AbstractVector, ps, st) where {Basis <: AbstractInterpolation} + t::$(tType), ps, st) where {Basis <: AbstractInterpolation} $(grid_expr) interp = construct_basis(Basis, ps.saved_points, grid; extrapolate=true) - sol = interp.(t) + $(sol_expr) spl.in_dims == () && return sol, st return Utils.mapreduce_stack(sol), st end