From a86c40845d00e94164fe857133a54bc1d8fee9af Mon Sep 17 00:00:00 2001 From: Daines Date: Fri, 13 Jan 2023 15:53:03 +0000 Subject: [PATCH] Remove unsafe @cfunction wrappers CVRhsFn_wrapper and similar Addresses https://github.com/SciML/Sundials.jl/issues/381 This is theoretically a breaking change, however any application code that was using the low-level API in this way (passing a Julia function into eg CVodeInit and relying on CVRhsFn_wrapper) would likely segfault randomly due to garbage collection of the wrapper. The problem with CVRhsFn_wrapper and similar is that a temporary Base.CFunction closure is returned by @cfunction($f, ...), which may be garbage collected hence it is not safe to return the enclosed pointer. This is somewhat similar to the problem with NVector addressed by https://github.com/SciML/Sundials.jl/pull/380, but more fundamental as the RHS function needs to persist for the whole lifetime of the solver use. Fix here is to remove these wrappers (which were only used by a few tests of the low-level interface), and require that the caller either uses the @cfunction(f, ...) form to generate a C pointer (ie f known at compile time and the runtime Julia RHS function passed in to Sundials C code as part of the opaque user data, this is what the high-level interface does), or manages the lifetime of the @cfunction explicitly. --- lib/libsundials_api.jl | 58 ++++++++++++++++--------------- src/types_and_consts_additions.jl | 33 ------------------ test/arkstep_Roberts_dns.jl | 4 ++- test/cvode_Roberts_dns.jl | 9 +++-- test/erkstep_nonlin.jl | 4 ++- test/ida_Roberts_dns.jl | 10 ++++-- test/kinsol_mkinTest.jl | 4 ++- test/runtests.jl | 11 +++--- 8 files changed, 57 insertions(+), 76 deletions(-) diff --git a/lib/libsundials_api.jl b/lib/libsundials_api.jl index faa69e3..e32fb8b 100644 --- a/lib/libsundials_api.jl +++ b/lib/libsundials_api.jl @@ -1,16 +1,21 @@ -# Modified from autogenerated code to fix N_Vector https://github.com/SciML/Sundials.jl/issues/367 +# Modified from autogenerated code to fix memory safety: +# +# N_Vector https://github.com/SciML/Sundials.jl/issues/367 # Global edits: # convert(N_Vector, x) -> x # some_arg::::N_Vector -> some_arg::Union{N_Vector, NVector} +# +# Remove automatic use of @cfunction by CVRhsFn_wrapper and similar +# (this is unsafe as a C ptr is returned from the temporary @cfunction closure which may then be garbage collected) function ARKStepCreate(fe::ARKRhsFn, fi::ARKRhsFn, t0::realtype, y0::Union{N_Vector, NVector}) ccall((:ARKStepCreate, libsundials_arkode), ARKStepMemPtr, (ARKRhsFn, ARKRhsFn, realtype, N_Vector), fe, fi, t0, y0) end -function ARKStepCreate(fe, fi, t0, y0) +function ARKStepCreate(fe::ARKRhsFn, fi::ARKRhsFn, t0, y0) __y0 = convert(NVector, y0) - ARKStepCreate(ARKRhsFn_wrapper(fe), ARKRhsFn_wrapper(fi), t0, __y0) + ARKStepCreate(fe, fi, t0, __y0) end function ARKStepResize(arkode_mem, ynew::Union{N_Vector, NVector}, hscale::realtype, t0::realtype, @@ -31,10 +36,9 @@ function ARKStepReInit(arkode_mem, fe::ARKRhsFn, fi::ARKRhsFn, t0::realtype, y0: y0) end -function ARKStepReInit(arkode_mem, fe, fi, t0, y0) +function ARKStepReInit(arkode_mem, fe::ARKRhsFn, fi::ARKRhsFn, t0, y0) __y0 = convert(NVector, y0) - ARKStepReInit(arkode_mem, ARKRhsFn_wrapper(fe), ARKRhsFn_wrapper(fi), t0, - __y0) + ARKStepReInit(arkode_mem, fe, fi, t0, __y0) end function ARKStepSStolerances(arkode_mem, reltol::realtype, abstol::realtype) @@ -940,9 +944,9 @@ function ERKStepCreate(f::ARKRhsFn, t0::realtype, y0::Union{N_Vector, NVector}) (ARKRhsFn, realtype, N_Vector), f, t0, y0) end -function ERKStepCreate(f, t0, y0) +function ERKStepCreate(f::ARKRhsFn, t0, y0) __y0 = convert(NVector, y0) - ERKStepCreate(ARKRhsFn_wrapper(f), t0, __y0) + ERKStepCreate(f, t0, __y0) end function ERKStepResize(arkode_mem, ynew::Union{N_Vector, NVector}, hscale::realtype, t0::realtype, @@ -962,9 +966,9 @@ function ERKStepReInit(arkode_mem, f::ARKRhsFn, t0::realtype, y0::Union{N_Vector (ERKStepMemPtr, ARKRhsFn, realtype, N_Vector), arkode_mem, f, t0, y0) end -function ERKStepReInit(arkode_mem, f, t0, y0) +function ERKStepReInit(arkode_mem, f::ARKRhsFn, t0, y0) __y0 = convert(NVector, y0) - ERKStepReInit(arkode_mem, ARKRhsFn_wrapper(f), t0, __y0) + ERKStepReInit(arkode_mem, f, t0, __y0) end function ERKStepSStolerances(arkode_mem, reltol::realtype, abstol::realtype) @@ -1395,10 +1399,9 @@ function MRIStepCreate(fs::ARKRhsFn, t0::realtype, y0::Union{N_Vector, NVector}, inner_step_mem) end -function MRIStepCreate(fs, t0, y0, inner_step_id, inner_step_mem) +function MRIStepCreate(fs::ARKRhsFn, t0, y0, inner_step_id, inner_step_mem) __y0 = convert(NVector, y0) - MRIStepCreate(ARKRhsFn_wrapper(fs), t0, __y0, inner_step_id, - inner_step_mem) + MRIStepCreate(fs, t0, __y0, inner_step_id, inner_step_mem) end function MRIStepResize(arkode_mem, ynew::Union{N_Vector, NVector}, t0::realtype, resize::ARKVecResizeFn, @@ -1418,9 +1421,9 @@ function MRIStepReInit(arkode_mem, fs::ARKRhsFn, t0::realtype, y0::Union{N_Vecto (MRIStepMemPtr, ARKRhsFn, realtype, N_Vector), arkode_mem, fs, t0, y0) end -function MRIStepReInit(arkode_mem, fs, t0, y0) +function MRIStepReInit(arkode_mem, fs::ARKRhsFn, t0, y0) __y0 = convert(NVector, y0) - MRIStepReInit(arkode_mem, ARKRhsFn_wrapper(fs), t0, __y0) + MRIStepReInit(arkode_mem, fs, t0, __y0) end function MRIStepRootInit(arkode_mem, nrtfn::Cint, g::ARKRootFn) @@ -1670,9 +1673,9 @@ function CVodeInit(cvode_mem, f::CVRhsFn, t0::realtype, y0::Union{N_Vector, NVec (CVODEMemPtr, CVRhsFn, realtype, N_Vector), cvode_mem, f, t0, y0) end -function CVodeInit(cvode_mem, f, t0, y0) +function CVodeInit(cvode_mem, f::CVRhsFn, t0, y0) __y0 = convert(NVector, y0) - CVodeInit(cvode_mem, CVRhsFn_wrapper(f), t0, __y0) + CVodeInit(cvode_mem, f, t0, __y0) end function CVodeReInit(cvode_mem, t0::realtype, y0::Union{N_Vector, NVector}) @@ -1828,8 +1831,8 @@ function CVodeRootInit(cvode_mem, nrtfn::Cint, g::CVRootFn) cvode_mem, nrtfn, g) end -function CVodeRootInit(cvode_mem, nrtfn, g) - CVodeRootInit(cvode_mem, convert(Cint, nrtfn), CVRootFn_wrapper(g)) +function CVodeRootInit(cvode_mem, nrtfn, g::CVRootFn) + CVodeRootInit(cvode_mem, convert(Cint, nrtfn), g) end function CVodeSetRootDirection(cvode_mem, rootdir) @@ -3246,11 +3249,10 @@ function IDAInit(ida_mem, res::IDAResFn, t0::realtype, yy0::Union{N_Vector, NVec (IDAMemPtr, IDAResFn, realtype, N_Vector, N_Vector), ida_mem, res, t0, yy0, yp0) end -function IDAInit(ida_mem, res, t0, yy0, yp0) +function IDAInit(ida_mem, res::IDAResFn, t0, yy0, yp0) __yy0 = convert(NVector, yy0) __yp0 = convert(NVector, yp0) - IDAInit(ida_mem, IDAResFn_wrapper(res), t0, __yy0, - __yp0) + IDAInit(ida_mem, res, t0, __yy0, __yp0) end function IDAReInit(ida_mem, t0::realtype, yy0::Union{N_Vector, NVector}, yp0::Union{N_Vector, NVector}) @@ -3456,8 +3458,8 @@ function IDARootInit(ida_mem, nrtfn::Cint, g::IDARootFn) nrtfn, g) end -function IDARootInit(ida_mem, nrtfn, g) - IDARootInit(ida_mem, convert(Cint, nrtfn), IDARootFn_wrapper(g)) +function IDARootInit(ida_mem, nrtfn, g::IDARootFn) + IDARootInit(ida_mem, convert(Cint, nrtfn), g) end function IDASetRootDirection(ida_mem, rootdir) @@ -4875,9 +4877,9 @@ function KINInit(kinmem, func::KINSysFn, tmpl::Union{N_Vector, NVector}) func, tmpl) end -function KINInit(kinmem, func, tmpl) +function KINInit(kinmem, func::KINSysFn, tmpl) __tmpl = convert(NVector, tmpl) - KINInit(kinmem, KINSysFn_wrapper(func), __tmpl) + KINInit(kinmem, func, __tmpl) end function KINSol(kinmem, uu::Union{N_Vector, NVector}, strategy::Cint, u_scale::Union{N_Vector, NVector}, f_scale::Union{N_Vector, NVector}) @@ -5063,8 +5065,8 @@ function KINSetSysFunc(kinmem, func::KINSysFn) ccall((:KINSetSysFunc, libsundials_kinsol), Cint, (KINMemPtr, KINSysFn), kinmem, func) end -function KINSetSysFunc(kinmem, func) - KINSetSysFunc(kinmem, KINSysFn_wrapper(func)) +function KINSetSysFunc(kinmem, func::KINSysFn) + KINSetSysFunc(kinmem, func) end function KINGetWorkSpace(kinmem, lenrw, leniw) diff --git a/src/types_and_consts_additions.jl b/src/types_and_consts_additions.jl index cc14673..c9618bc 100644 --- a/src/types_and_consts_additions.jl +++ b/src/types_and_consts_additions.jl @@ -6,39 +6,6 @@ function Base.convert(::Type{Matrix}, J::DlsMat) unsafe_wrap(Array, _dlsmat.data, (_dlsmat.M, _dlsmat.N); own = false) end -CVRhsFn_wrapper(fp::CVRhsFn) = fp -CVRhsFn_wrapper(f) = @cfunction($f, Cint, (realtype, N_Vector, N_Vector, Ptr{Cvoid})).ptr - -ARKRhsFn_wrapper(fp::ARKRhsFn) = fp -ARKRhsFn_wrapper(f) = @cfunction($f, Cint, (realtype, N_Vector, N_Vector, Ptr{Cvoid})).ptr - -CVRootFn_wrapper(fp::CVRootFn) = fp -function CVRootFn_wrapper(f) - @cfunction($f, Cint, - (realtype, N_Vector, Ptr{realtype}, Ptr{Cvoid})).ptr -end - -CVQuadRhsFn_wrapper(fp::CVQuadRhsFn) = fp -function CVQuadRhsFn_wrapper(f) - @cfunction($f, Cint, - (realtype, N_Vector, N_Vector, Ptr{Cvoid})).ptr -end - -IDAResFn_wrapper(fp::IDAResFn) = fp -function IDAResFn_wrapper(f) - @cfunction($f, Cint, - (realtype, N_Vector, N_Vector, N_Vector, Ptr{Cvoid})).ptr -end - -IDARootFn_wrapper(fp::IDARootFn) = fp -function IDARootFn_wrapper(f) - @cfunction($f, Cint, - (realtype, N_Vector, N_Vector, Ptr{realtype}, Ptr{Cvoid})).ptr -end - -KINSysFn_wrapper(fp::KINSysFn) = fp -KINSysFn_wrapper(f) = @cfunction($f, Cint, (N_Vector, N_Vector, Ptr{Cvoid})).ptr - function Base.convert(::Type{Matrix}, J::SUNMatrix) _sunmat = unsafe_load(J) _mat = convert(SUNMatrixContent_Dense, _sunmat.content) diff --git a/test/arkstep_Roberts_dns.jl b/test/arkstep_Roberts_dns.jl index 923c0cf..7255deb 100644 --- a/test/arkstep_Roberts_dns.jl +++ b/test/arkstep_Roberts_dns.jl @@ -11,6 +11,8 @@ function f(t, y_nv, ydot_nv, user_data) return Sundials.ARK_SUCCESS end +f_C = @cfunction(f, Cint, (Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Ptr{Cvoid})) + neq = 3 t0 = 0.0 @@ -23,7 +25,7 @@ abstol = 1e-11 userdata = nothing h0 = 1e-4 * reltol -mem_ptr = Sundials.ARKStepCreate(C_NULL, f, t0, y0) +mem_ptr = Sundials.ARKStepCreate(C_NULL, f_C, t0, y0) arkStep_mem = Sundials.Handle(mem_ptr) Sundials.@checkflag Sundials.ARKStepSetInitStep(arkStep_mem, h0) Sundials.@checkflag Sundials.ARKStepSetMaxErrTestFails(arkStep_mem, 20) diff --git a/test/cvode_Roberts_dns.jl b/test/cvode_Roberts_dns.jl index 5426514..a9bce46 100644 --- a/test/cvode_Roberts_dns.jl +++ b/test/cvode_Roberts_dns.jl @@ -21,6 +21,9 @@ function g(t, y_nv, gout_ptr, user_data) return Sundials.CV_SUCCESS end +g_C = @cfunction(g, Cint, + (Sundials.realtype, Sundials.N_Vector, Ptr{Sundials.realtype}, Ptr{Cvoid})) + ## Jacobian routine. Compute J(t,y) = df/dy. # broken -- needs a wrapper from Sundials._DlsMat to Matrix and Jac user function wrapper function Jac(N, t, ny, fy, Jptr, user_data, tmp1, tmp2, tmp3) @@ -61,10 +64,10 @@ function getcfunrob(userfun::T) where {T} (Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Ref{T})) end -Sundials.CVodeInit(cvode_mem, getcfunrob(userfun), t1, convert(Sundials.NVector, y0)) -Sundials.@checkflag Sundials.CVodeInit(cvode_mem, f, t0, y0) +Sundials.@checkflag Sundials.CVodeInit(cvode_mem, getcfunrob(userfun), t1, convert(Sundials.NVector, y0)) +Sundials.@checkflag Sundials.CVodeInit(cvode_mem, getcfunrob(userfun), t0, y0) Sundials.@checkflag Sundials.CVodeSVtolerances(cvode_mem, reltol, abstol) -Sundials.@checkflag Sundials.CVodeRootInit(cvode_mem, 2, g) +Sundials.@checkflag Sundials.CVodeRootInit(cvode_mem, 2, g_C) A = Sundials.SUNDenseMatrix(neq, neq) mat_handle = Sundials.MatrixHandle(A, Sundials.DenseMatrix()) LS = Sundials.SUNLinSol_Dense(convert(Sundials.NVector, y0), A) diff --git a/test/erkstep_nonlin.jl b/test/erkstep_nonlin.jl index 475b174..fbcced5 100644 --- a/test/erkstep_nonlin.jl +++ b/test/erkstep_nonlin.jl @@ -45,7 +45,9 @@ function f(t, y, ydot, user_data) return Sundials.ARK_SUCCESS end -mem_ptr = Sundials.ERKStepCreate(f, t0, y0) +f_C = @cfunction(f, Cint, (Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Ptr{Cvoid})) + +mem_ptr = Sundials.ERKStepCreate(f_C, t0, y0) erkStep_mem = Sundials.Handle(mem_ptr) Sundials.@checkflag Sundials.ERKStepSStolerances(erkStep_mem, reltol, abstol) diff --git a/test/ida_Roberts_dns.jl b/test/ida_Roberts_dns.jl index 391aba5..ef519ca 100644 --- a/test/ida_Roberts_dns.jl +++ b/test/ida_Roberts_dns.jl @@ -44,6 +44,9 @@ function resrob(tres, yy_nv, yp_nv, rr_nv, user_data) return Sundials.IDA_SUCCESS end +resrob_C = @cfunction(resrob, Cint, + (Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ptr{Cvoid})) + ## Root function routine. Compute functions g_i(t,y) for i = 0,1. function grob(t, yy_nv, yp_nv, gout_ptr, user_data) yy = convert(Vector, yy_nv) @@ -53,6 +56,9 @@ function grob(t, yy_nv, yp_nv, gout_ptr, user_data) return Sundials.IDA_SUCCESS end +grob_C = @cfunction(grob, Cint, + (Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Ptr{Sundials.realtype}, Ptr{Cvoid})) + ## Define the Jacobian function. BROKEN - JJ is wrong function jacrob(Neq, tt, cj, yy, yp, resvec, JJ, user_data, tempv1, tempv2, tempv3) JJ = pointer_to_array(convert(Ptr{Float64}, JJ), (3, 3)) @@ -78,11 +84,11 @@ avtol = [1e-8, 1e-14, 1e-6] tout1 = 0.4 mem = Sundials.IDACreate() -Sundials.@checkflag Sundials.IDAInit(mem, resrob, t0, yy0, yp0) +Sundials.@checkflag Sundials.IDAInit(mem, resrob_C, t0, yy0, yp0) Sundials.@checkflag Sundials.IDASVtolerances(mem, rtol, avtol) ## Call IDARootInit to specify the root function grob with 2 components -Sundials.@checkflag Sundials.IDARootInit(mem, 2, grob) +Sundials.@checkflag Sundials.IDARootInit(mem, 2, grob_C) ## Call IDADense and set up the linear solver. A = Sundials.SUNDenseMatrix(length(y0), length(y0)) diff --git a/test/kinsol_mkinTest.jl b/test/kinsol_mkinTest.jl index 43b6fa1..442f861 100644 --- a/test/kinsol_mkinTest.jl +++ b/test/kinsol_mkinTest.jl @@ -20,6 +20,8 @@ function sysfn(y_nv, fy_nv, a_in) return Sundials.KIN_SUCCESS end +sysfn_C = @cfunction(sysfn, Cint, (Sundials.N_Vector, Sundials.N_Vector, Ptr{Cvoid})) + ## Initialize problem neq = 2 kmem = Sundials.KINCreate() @@ -27,7 +29,7 @@ Sundials.@checkflag Sundials.KINSetFuncNormTol(kmem, 1.0e-5) Sundials.@checkflag Sundials.KINSetScaledStepTol(kmem, 1.0e-4) Sundials.@checkflag Sundials.KINSetMaxSetupCalls(kmem, 1) y = ones(neq) -Sundials.@checkflag Sundials.KINInit(kmem, sysfn, y) +Sundials.@checkflag Sundials.KINInit(kmem, sysfn_C, y) A = Sundials.SUNDenseMatrix(length(y), length(y)) LS = Sundials.SUNLinSol_Dense(y, A) Sundials.@checkflag Sundials.KINDlsSetLinearSolver(kmem, LS, A) diff --git a/test/runtests.jl b/test/runtests.jl index 7911c11..ff423d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,26 +1,23 @@ using Sundials using Test -# Some tests use @cfunction, but at the moment that isn't supported on non-Intel platforms -const SUPPORT_CFUNCTION = Sys.ARCH ∉ (:aarch64, :ppc64le, :powerpc64le) || - startswith(lowercase(String(Sys.ARCH)), "arm") @testset "Generator" begin include("generator.jl") end @testset "CVODE" begin @testset "Roberts CVODE Simplified" begin include("cvode_Roberts_simplified.jl") end - SUPPORT_CFUNCTION && @testset "Roberts CVODE Direct" begin include("cvode_Roberts_dns.jl") end + @testset "Roberts CVODE Direct" begin include("cvode_Roberts_dns.jl") end #@testset "CVODES Direct" begin include("cvodes_dns.jl") end end @testset "IDA" begin @testset "Roberts IDA Simplified" begin include("ida_Roberts_simplified.jl") end - SUPPORT_CFUNCTION && @testset "Roberts IDA Direct" begin include("ida_Roberts_dns.jl") end + @testset "Roberts IDA Direct" begin include("ida_Roberts_dns.jl") end @testset "Heat IDA Direct" begin include("ida_Heat2D.jl") end # Commented out because still uses the syntax from Grid which is a deprecated package #@testset "Cable IDA Direct" begin include("ida_Cable.jl") end end -SUPPORT_CFUNCTION && @testset "ARK" begin +@testset "ARK" begin @testset "Roberts ARKStep Direct" begin include("arkstep_Roberts_dns.jl") end @testset "NonLinear ERKStep Direct" begin include("erkstep_nonlin.jl") end #@testset "MRI two way couple" begin include("mri_twowaycouple.jl") end @@ -28,7 +25,7 @@ end @testset "Kinsol" begin @testset "Kinsol Simplified" begin include("kinsol_mkin_simplified.jl") end - SUPPORT_CFUNCTION && @testset "Kinsol MKin" begin include("kinsol_mkinTest.jl") end + @testset "Kinsol MKin" begin include("kinsol_mkinTest.jl") end @testset "Kinsol Banded" begin include("kinsol_banded.jl") end end @testset "Handle Tests" begin include("handle_tests.jl") end