Skip to content

Commit

Permalink
fix: allow numbers as input to spline layer
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 28, 2024
1 parent 81935dd commit 38eb87b
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions ext/BoltzDataInterpolationsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ using DataInterpolations: AbstractInterpolation

using Boltz: Boltz, Layers, Utils

for train_grid in (true, false)
for train_grid in (true, false), tType in (AbstractVector, Number)
grid_expr = train_grid ? :(grid = ps.grid) : :(grid = st.grid)
sol_expr = tType === Number ? :(sol = interp(t)) : :(sol = interp.(t))
@eval function (spl::Layers.SplineLayer{$(train_grid), Basis})(
t::AbstractVector, ps, st) where {Basis <: AbstractInterpolation}
t::$(tType), ps, st) where {Basis <: AbstractInterpolation}
$(grid_expr)
interp = construct_basis(Basis, ps.saved_points, grid; extrapolate=true)
sol = interp.(t)
$(sol_expr)
spl.in_dims == () && return sol, st
return Utils.mapreduce_stack(sol), st
end
Expand Down

0 comments on commit 38eb87b

Please sign in to comment.