Skip to content

Commit

Permalink
Remove unsafe @cfunction wrappers CVRhsFn_wrapper and similar
Browse files Browse the repository at this point in the history
Addresses #381

This is theoretically a breaking change, however any application code
that was using the low-level API in this way (passing a Julia function into
eg CVodeInit and relying on CVRhsFn_wrapper) would likely segfault randomly
due to garbage collection of the wrapper.

The problem with CVRhsFn_wrapper and similar is that a temporary
Base.CFunction closure is returned by  @cfunction($f, ...),
which may be garbage collected hence it is not safe
to return the enclosed pointer.
This is somewhat similar to the problem with
NVector addressed by #380,
but more fundamental as the RHS function
needs to persist for the whole lifetime of the solver use.

Fix here is to remove these wrappers (which were only used by a few tests
of the low-level interface), and require that the caller either uses
the @cfunction(f, ...) form to generate a C pointer (ie f known at
compile time and the runtime Julia RHS function passed in to
Sundials C code as part of the opaque user data, this is what the
high-level interface does), or manages the lifetime of the @cfunction explicitly.
  • Loading branch information
sjdaines authored and ChrisRackauckas committed Jan 13, 2023
1 parent f38f804 commit a86c408
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 76 deletions.
58 changes: 30 additions & 28 deletions lib/libsundials_api.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
# Modified from autogenerated code to fix N_Vector https://github.com/SciML/Sundials.jl/issues/367
# Modified from autogenerated code to fix memory safety:
#
# N_Vector https://github.com/SciML/Sundials.jl/issues/367
# Global edits:
# convert(N_Vector, x) -> x
# some_arg::::N_Vector -> some_arg::Union{N_Vector, NVector}
#
# Remove automatic use of @cfunction by CVRhsFn_wrapper and similar
# (this is unsafe as a C ptr is returned from the temporary @cfunction closure which may then be garbage collected)

function ARKStepCreate(fe::ARKRhsFn, fi::ARKRhsFn, t0::realtype, y0::Union{N_Vector, NVector})
ccall((:ARKStepCreate, libsundials_arkode), ARKStepMemPtr,
(ARKRhsFn, ARKRhsFn, realtype, N_Vector), fe, fi, t0, y0)
end

function ARKStepCreate(fe, fi, t0, y0)
function ARKStepCreate(fe::ARKRhsFn, fi::ARKRhsFn, t0, y0)
__y0 = convert(NVector, y0)
ARKStepCreate(ARKRhsFn_wrapper(fe), ARKRhsFn_wrapper(fi), t0, __y0)
ARKStepCreate(fe, fi, t0, __y0)
end

function ARKStepResize(arkode_mem, ynew::Union{N_Vector, NVector}, hscale::realtype, t0::realtype,
Expand All @@ -31,10 +36,9 @@ function ARKStepReInit(arkode_mem, fe::ARKRhsFn, fi::ARKRhsFn, t0::realtype, y0:
y0)
end

function ARKStepReInit(arkode_mem, fe, fi, t0, y0)
function ARKStepReInit(arkode_mem, fe::ARKRhsFn, fi::ARKRhsFn, t0, y0)
__y0 = convert(NVector, y0)
ARKStepReInit(arkode_mem, ARKRhsFn_wrapper(fe), ARKRhsFn_wrapper(fi), t0,
__y0)
ARKStepReInit(arkode_mem, fe, fi, t0, __y0)
end

function ARKStepSStolerances(arkode_mem, reltol::realtype, abstol::realtype)
Expand Down Expand Up @@ -940,9 +944,9 @@ function ERKStepCreate(f::ARKRhsFn, t0::realtype, y0::Union{N_Vector, NVector})
(ARKRhsFn, realtype, N_Vector), f, t0, y0)
end

function ERKStepCreate(f, t0, y0)
function ERKStepCreate(f::ARKRhsFn, t0, y0)
__y0 = convert(NVector, y0)
ERKStepCreate(ARKRhsFn_wrapper(f), t0, __y0)
ERKStepCreate(f, t0, __y0)
end

function ERKStepResize(arkode_mem, ynew::Union{N_Vector, NVector}, hscale::realtype, t0::realtype,
Expand All @@ -962,9 +966,9 @@ function ERKStepReInit(arkode_mem, f::ARKRhsFn, t0::realtype, y0::Union{N_Vector
(ERKStepMemPtr, ARKRhsFn, realtype, N_Vector), arkode_mem, f, t0, y0)
end

function ERKStepReInit(arkode_mem, f, t0, y0)
function ERKStepReInit(arkode_mem, f::ARKRhsFn, t0, y0)
__y0 = convert(NVector, y0)
ERKStepReInit(arkode_mem, ARKRhsFn_wrapper(f), t0, __y0)
ERKStepReInit(arkode_mem, f, t0, __y0)
end

function ERKStepSStolerances(arkode_mem, reltol::realtype, abstol::realtype)
Expand Down Expand Up @@ -1395,10 +1399,9 @@ function MRIStepCreate(fs::ARKRhsFn, t0::realtype, y0::Union{N_Vector, NVector},
inner_step_mem)
end

function MRIStepCreate(fs, t0, y0, inner_step_id, inner_step_mem)
function MRIStepCreate(fs::ARKRhsFn, t0, y0, inner_step_id, inner_step_mem)
__y0 = convert(NVector, y0)
MRIStepCreate(ARKRhsFn_wrapper(fs), t0, __y0, inner_step_id,
inner_step_mem)
MRIStepCreate(fs, t0, __y0, inner_step_id, inner_step_mem)
end

function MRIStepResize(arkode_mem, ynew::Union{N_Vector, NVector}, t0::realtype, resize::ARKVecResizeFn,
Expand All @@ -1418,9 +1421,9 @@ function MRIStepReInit(arkode_mem, fs::ARKRhsFn, t0::realtype, y0::Union{N_Vecto
(MRIStepMemPtr, ARKRhsFn, realtype, N_Vector), arkode_mem, fs, t0, y0)
end

function MRIStepReInit(arkode_mem, fs, t0, y0)
function MRIStepReInit(arkode_mem, fs::ARKRhsFn, t0, y0)
__y0 = convert(NVector, y0)
MRIStepReInit(arkode_mem, ARKRhsFn_wrapper(fs), t0, __y0)
MRIStepReInit(arkode_mem, fs, t0, __y0)
end

function MRIStepRootInit(arkode_mem, nrtfn::Cint, g::ARKRootFn)
Expand Down Expand Up @@ -1670,9 +1673,9 @@ function CVodeInit(cvode_mem, f::CVRhsFn, t0::realtype, y0::Union{N_Vector, NVec
(CVODEMemPtr, CVRhsFn, realtype, N_Vector), cvode_mem, f, t0, y0)
end

function CVodeInit(cvode_mem, f, t0, y0)
function CVodeInit(cvode_mem, f::CVRhsFn, t0, y0)
__y0 = convert(NVector, y0)
CVodeInit(cvode_mem, CVRhsFn_wrapper(f), t0, __y0)
CVodeInit(cvode_mem, f, t0, __y0)
end

function CVodeReInit(cvode_mem, t0::realtype, y0::Union{N_Vector, NVector})
Expand Down Expand Up @@ -1828,8 +1831,8 @@ function CVodeRootInit(cvode_mem, nrtfn::Cint, g::CVRootFn)
cvode_mem, nrtfn, g)
end

function CVodeRootInit(cvode_mem, nrtfn, g)
CVodeRootInit(cvode_mem, convert(Cint, nrtfn), CVRootFn_wrapper(g))
function CVodeRootInit(cvode_mem, nrtfn, g::CVRootFn)
CVodeRootInit(cvode_mem, convert(Cint, nrtfn), g)
end

function CVodeSetRootDirection(cvode_mem, rootdir)
Expand Down Expand Up @@ -3246,11 +3249,10 @@ function IDAInit(ida_mem, res::IDAResFn, t0::realtype, yy0::Union{N_Vector, NVec
(IDAMemPtr, IDAResFn, realtype, N_Vector, N_Vector), ida_mem, res, t0, yy0, yp0)
end

function IDAInit(ida_mem, res, t0, yy0, yp0)
function IDAInit(ida_mem, res::IDAResFn, t0, yy0, yp0)
__yy0 = convert(NVector, yy0)
__yp0 = convert(NVector, yp0)
IDAInit(ida_mem, IDAResFn_wrapper(res), t0, __yy0,
__yp0)
IDAInit(ida_mem, res, t0, __yy0, __yp0)
end

function IDAReInit(ida_mem, t0::realtype, yy0::Union{N_Vector, NVector}, yp0::Union{N_Vector, NVector})
Expand Down Expand Up @@ -3456,8 +3458,8 @@ function IDARootInit(ida_mem, nrtfn::Cint, g::IDARootFn)
nrtfn, g)
end

function IDARootInit(ida_mem, nrtfn, g)
IDARootInit(ida_mem, convert(Cint, nrtfn), IDARootFn_wrapper(g))
function IDARootInit(ida_mem, nrtfn, g::IDARootFn)
IDARootInit(ida_mem, convert(Cint, nrtfn), g)
end

function IDASetRootDirection(ida_mem, rootdir)
Expand Down Expand Up @@ -4875,9 +4877,9 @@ function KINInit(kinmem, func::KINSysFn, tmpl::Union{N_Vector, NVector})
func, tmpl)
end

function KINInit(kinmem, func, tmpl)
function KINInit(kinmem, func::KINSysFn, tmpl)
__tmpl = convert(NVector, tmpl)
KINInit(kinmem, KINSysFn_wrapper(func), __tmpl)
KINInit(kinmem, func, __tmpl)
end

function KINSol(kinmem, uu::Union{N_Vector, NVector}, strategy::Cint, u_scale::Union{N_Vector, NVector}, f_scale::Union{N_Vector, NVector})
Expand Down Expand Up @@ -5063,8 +5065,8 @@ function KINSetSysFunc(kinmem, func::KINSysFn)
ccall((:KINSetSysFunc, libsundials_kinsol), Cint, (KINMemPtr, KINSysFn), kinmem, func)
end

function KINSetSysFunc(kinmem, func)
KINSetSysFunc(kinmem, KINSysFn_wrapper(func))
function KINSetSysFunc(kinmem, func::KINSysFn)
KINSetSysFunc(kinmem, func)
end

function KINGetWorkSpace(kinmem, lenrw, leniw)
Expand Down
33 changes: 0 additions & 33 deletions src/types_and_consts_additions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,6 @@ function Base.convert(::Type{Matrix}, J::DlsMat)
unsafe_wrap(Array, _dlsmat.data, (_dlsmat.M, _dlsmat.N); own = false)
end

CVRhsFn_wrapper(fp::CVRhsFn) = fp
CVRhsFn_wrapper(f) = @cfunction($f, Cint, (realtype, N_Vector, N_Vector, Ptr{Cvoid})).ptr

ARKRhsFn_wrapper(fp::ARKRhsFn) = fp
ARKRhsFn_wrapper(f) = @cfunction($f, Cint, (realtype, N_Vector, N_Vector, Ptr{Cvoid})).ptr

CVRootFn_wrapper(fp::CVRootFn) = fp
function CVRootFn_wrapper(f)
@cfunction($f, Cint,
(realtype, N_Vector, Ptr{realtype}, Ptr{Cvoid})).ptr
end

CVQuadRhsFn_wrapper(fp::CVQuadRhsFn) = fp
function CVQuadRhsFn_wrapper(f)
@cfunction($f, Cint,
(realtype, N_Vector, N_Vector, Ptr{Cvoid})).ptr
end

IDAResFn_wrapper(fp::IDAResFn) = fp
function IDAResFn_wrapper(f)
@cfunction($f, Cint,
(realtype, N_Vector, N_Vector, N_Vector, Ptr{Cvoid})).ptr
end

IDARootFn_wrapper(fp::IDARootFn) = fp
function IDARootFn_wrapper(f)
@cfunction($f, Cint,
(realtype, N_Vector, N_Vector, Ptr{realtype}, Ptr{Cvoid})).ptr
end

KINSysFn_wrapper(fp::KINSysFn) = fp
KINSysFn_wrapper(f) = @cfunction($f, Cint, (N_Vector, N_Vector, Ptr{Cvoid})).ptr

function Base.convert(::Type{Matrix}, J::SUNMatrix)
_sunmat = unsafe_load(J)
_mat = convert(SUNMatrixContent_Dense, _sunmat.content)
Expand Down
4 changes: 3 additions & 1 deletion test/arkstep_Roberts_dns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ function f(t, y_nv, ydot_nv, user_data)
return Sundials.ARK_SUCCESS
end

f_C = @cfunction(f, Cint, (Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Ptr{Cvoid}))

neq = 3

t0 = 0.0
Expand All @@ -23,7 +25,7 @@ abstol = 1e-11
userdata = nothing
h0 = 1e-4 * reltol

mem_ptr = Sundials.ARKStepCreate(C_NULL, f, t0, y0)
mem_ptr = Sundials.ARKStepCreate(C_NULL, f_C, t0, y0)
arkStep_mem = Sundials.Handle(mem_ptr)
Sundials.@checkflag Sundials.ARKStepSetInitStep(arkStep_mem, h0)
Sundials.@checkflag Sundials.ARKStepSetMaxErrTestFails(arkStep_mem, 20)
Expand Down
9 changes: 6 additions & 3 deletions test/cvode_Roberts_dns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ function g(t, y_nv, gout_ptr, user_data)
return Sundials.CV_SUCCESS
end

g_C = @cfunction(g, Cint,
(Sundials.realtype, Sundials.N_Vector, Ptr{Sundials.realtype}, Ptr{Cvoid}))

## Jacobian routine. Compute J(t,y) = df/dy.
# broken -- needs a wrapper from Sundials._DlsMat to Matrix and Jac user function wrapper
function Jac(N, t, ny, fy, Jptr, user_data, tmp1, tmp2, tmp3)
Expand Down Expand Up @@ -61,10 +64,10 @@ function getcfunrob(userfun::T) where {T}
(Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Ref{T}))
end

Sundials.CVodeInit(cvode_mem, getcfunrob(userfun), t1, convert(Sundials.NVector, y0))
Sundials.@checkflag Sundials.CVodeInit(cvode_mem, f, t0, y0)
Sundials.@checkflag Sundials.CVodeInit(cvode_mem, getcfunrob(userfun), t1, convert(Sundials.NVector, y0))
Sundials.@checkflag Sundials.CVodeInit(cvode_mem, getcfunrob(userfun), t0, y0)
Sundials.@checkflag Sundials.CVodeSVtolerances(cvode_mem, reltol, abstol)
Sundials.@checkflag Sundials.CVodeRootInit(cvode_mem, 2, g)
Sundials.@checkflag Sundials.CVodeRootInit(cvode_mem, 2, g_C)
A = Sundials.SUNDenseMatrix(neq, neq)
mat_handle = Sundials.MatrixHandle(A, Sundials.DenseMatrix())
LS = Sundials.SUNLinSol_Dense(convert(Sundials.NVector, y0), A)
Expand Down
4 changes: 3 additions & 1 deletion test/erkstep_nonlin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ function f(t, y, ydot, user_data)
return Sundials.ARK_SUCCESS
end

mem_ptr = Sundials.ERKStepCreate(f, t0, y0)
f_C = @cfunction(f, Cint, (Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Ptr{Cvoid}))

mem_ptr = Sundials.ERKStepCreate(f_C, t0, y0)
erkStep_mem = Sundials.Handle(mem_ptr)
Sundials.@checkflag Sundials.ERKStepSStolerances(erkStep_mem, reltol, abstol)

Expand Down
10 changes: 8 additions & 2 deletions test/ida_Roberts_dns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ function resrob(tres, yy_nv, yp_nv, rr_nv, user_data)
return Sundials.IDA_SUCCESS
end

resrob_C = @cfunction(resrob, Cint,
(Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ptr{Cvoid}))

## Root function routine. Compute functions g_i(t,y) for i = 0,1.
function grob(t, yy_nv, yp_nv, gout_ptr, user_data)
yy = convert(Vector, yy_nv)
Expand All @@ -53,6 +56,9 @@ function grob(t, yy_nv, yp_nv, gout_ptr, user_data)
return Sundials.IDA_SUCCESS
end

grob_C = @cfunction(grob, Cint,
(Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Ptr{Sundials.realtype}, Ptr{Cvoid}))

## Define the Jacobian function. BROKEN - JJ is wrong
function jacrob(Neq, tt, cj, yy, yp, resvec, JJ, user_data, tempv1, tempv2, tempv3)
JJ = pointer_to_array(convert(Ptr{Float64}, JJ), (3, 3))
Expand All @@ -78,11 +84,11 @@ avtol = [1e-8, 1e-14, 1e-6]
tout1 = 0.4

mem = Sundials.IDACreate()
Sundials.@checkflag Sundials.IDAInit(mem, resrob, t0, yy0, yp0)
Sundials.@checkflag Sundials.IDAInit(mem, resrob_C, t0, yy0, yp0)
Sundials.@checkflag Sundials.IDASVtolerances(mem, rtol, avtol)

## Call IDARootInit to specify the root function grob with 2 components
Sundials.@checkflag Sundials.IDARootInit(mem, 2, grob)
Sundials.@checkflag Sundials.IDARootInit(mem, 2, grob_C)

## Call IDADense and set up the linear solver.
A = Sundials.SUNDenseMatrix(length(y0), length(y0))
Expand Down
4 changes: 3 additions & 1 deletion test/kinsol_mkinTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ function sysfn(y_nv, fy_nv, a_in)
return Sundials.KIN_SUCCESS
end

sysfn_C = @cfunction(sysfn, Cint, (Sundials.N_Vector, Sundials.N_Vector, Ptr{Cvoid}))

## Initialize problem
neq = 2
kmem = Sundials.KINCreate()
Sundials.@checkflag Sundials.KINSetFuncNormTol(kmem, 1.0e-5)
Sundials.@checkflag Sundials.KINSetScaledStepTol(kmem, 1.0e-4)
Sundials.@checkflag Sundials.KINSetMaxSetupCalls(kmem, 1)
y = ones(neq)
Sundials.@checkflag Sundials.KINInit(kmem, sysfn, y)
Sundials.@checkflag Sundials.KINInit(kmem, sysfn_C, y)
A = Sundials.SUNDenseMatrix(length(y), length(y))
LS = Sundials.SUNLinSol_Dense(y, A)
Sundials.@checkflag Sundials.KINDlsSetLinearSolver(kmem, LS, A)
Expand Down
11 changes: 4 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,31 @@
using Sundials
using Test

# Some tests use @cfunction, but at the moment that isn't supported on non-Intel platforms
const SUPPORT_CFUNCTION = Sys.ARCH (:aarch64, :ppc64le, :powerpc64le) ||
startswith(lowercase(String(Sys.ARCH)), "arm")
@testset "Generator" begin include("generator.jl") end

@testset "CVODE" begin
@testset "Roberts CVODE Simplified" begin include("cvode_Roberts_simplified.jl") end
SUPPORT_CFUNCTION && @testset "Roberts CVODE Direct" begin include("cvode_Roberts_dns.jl") end
@testset "Roberts CVODE Direct" begin include("cvode_Roberts_dns.jl") end
#@testset "CVODES Direct" begin include("cvodes_dns.jl") end
end

@testset "IDA" begin
@testset "Roberts IDA Simplified" begin include("ida_Roberts_simplified.jl") end
SUPPORT_CFUNCTION && @testset "Roberts IDA Direct" begin include("ida_Roberts_dns.jl") end
@testset "Roberts IDA Direct" begin include("ida_Roberts_dns.jl") end
@testset "Heat IDA Direct" begin include("ida_Heat2D.jl") end
# Commented out because still uses the syntax from Grid which is a deprecated package
#@testset "Cable IDA Direct" begin include("ida_Cable.jl") end
end

SUPPORT_CFUNCTION && @testset "ARK" begin
@testset "ARK" begin
@testset "Roberts ARKStep Direct" begin include("arkstep_Roberts_dns.jl") end
@testset "NonLinear ERKStep Direct" begin include("erkstep_nonlin.jl") end
#@testset "MRI two way couple" begin include("mri_twowaycouple.jl") end
end

@testset "Kinsol" begin
@testset "Kinsol Simplified" begin include("kinsol_mkin_simplified.jl") end
SUPPORT_CFUNCTION && @testset "Kinsol MKin" begin include("kinsol_mkinTest.jl") end
@testset "Kinsol MKin" begin include("kinsol_mkinTest.jl") end
@testset "Kinsol Banded" begin include("kinsol_banded.jl") end
end
@testset "Handle Tests" begin include("handle_tests.jl") end
Expand Down

0 comments on commit a86c408

Please sign in to comment.