Skip to content

Commit

Permalink
Avoid using cfunction closure (#336)
Browse files Browse the repository at this point in the history
  • Loading branch information
metab0t authored Sep 27, 2023
1 parent dce0b6c commit 7a00040
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/SQLite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ mutable struct DB <: DBInterface.Connection
file::String
handle::DBHandle
stmt_wrappers::WeakKeyDict{StmtWrapper,Nothing} # opened prepared statements
registered_UDFs::Vector{Any} # keep registered UDFs alive and not garbage collected
registered_UDF_data::Vector{Any} # keep registered UDFs alive and not garbage collected

function DB(f::AbstractString)
handle_ptr = Ref{DBHandle}()
Expand Down
64 changes: 40 additions & 24 deletions src/UDF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,25 @@ end
sqlreturn(context, val::Bool) = sqlreturn(context, Int(val))
sqlreturn(context, val) = sqlreturn(context, sqlserialize(val))

mutable struct ScalarUDFData
func::Function
end

mutable struct AggregateUDFData
init::Any
step::Function
final::Function
end

function wrap_scalarfunc(
func,
context::Ptr{Cvoid},
nargs::Cint,
values::Ptr{Ptr{Cvoid}},
)
udf_data =
unsafe_pointer_to_objref(C.sqlite3_user_data(context))::ScalarUDFData
func = udf_data.func

args = [sqlvalue(values, i) for i in 1:nargs]
ret = func(args...)
sqlreturn(context, ret)
Expand All @@ -70,12 +83,15 @@ function bytestoint(ptr::Ptr{UInt8}, start::Int, len::Int)
end

function wrap_stepfunc(
init,
func,
context::Ptr{Cvoid},
nargs::Cint,
values::Ptr{Ptr{Cvoid}},
)
udf_data =
unsafe_pointer_to_objref(C.sqlite3_user_data(context))::AggregateUDFData
init = udf_data.init
func = udf_data.step

args = [sqlvalue(values, i) for i in 1:nargs]

intsize = sizeof(Int)
Expand Down Expand Up @@ -143,12 +159,15 @@ function wrap_stepfunc(
end

function wrap_finalfunc(
init,
func,
context::Ptr{Cvoid},
nargs::Cint,
values::Ptr{Ptr{Cvoid}},
)
udf_data =
unsafe_pointer_to_objref(C.sqlite3_user_data(context))::AggregateUDFData
init = udf_data.init
func = udf_data.final

acptr = convert(Ptr{UInt8}, C.sqlite3_aggregate_context(context, 0))

# step function wasn't run
Expand Down Expand Up @@ -201,7 +220,7 @@ Register a scalar (first method) or aggregate (second method) function
with a [`SQLite.DB`](@ref).
"""
function register(
db,
db::DB,
func::Function;
nargs::Int = -1,
name::AbstractString = string(func),
Expand All @@ -212,11 +231,12 @@ function register(
nargs < -1 && (nargs = -1)
@assert sizeof(name) <= 255 "size of function name must be <= 255"

f =
(context, nargs, values) ->
wrap_scalarfunc(func, context, nargs, values)
cfunc = @cfunction($f, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
push!(db.registered_UDFs, cfunc)
udf_data = ScalarUDFData(func)
push!(db.registered_UDF_data, udf_data)
udf_data_ptr = pointer_from_objref(udf_data)

cfunc =
@cfunction(wrap_scalarfunc, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))

# TODO: allow the other encodings
enc = C.SQLITE_UTF8
Expand All @@ -227,7 +247,7 @@ function register(
name,
nargs,
enc,
C_NULL,
udf_data_ptr,
cfunc,
C_NULL,
C_NULL,
Expand All @@ -237,7 +257,7 @@ end

# as above but for aggregate functions
function register(
db,
db::DB,
init,
step::Function,
final::Function = identity;
Expand All @@ -249,16 +269,12 @@ function register(
nargs < -1 && (nargs = -1)
@assert sizeof(name) <= 255 "size of function name must be <= 255 chars"

s =
(context, nargs, values) ->
wrap_stepfunc(init, step, context, nargs, values)
cs = @cfunction($s, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
f =
(context, nargs, values) ->
wrap_finalfunc(init, final, context, nargs, values)
cf = @cfunction($f, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
push!(db.registered_UDFs, cs)
push!(db.registered_UDFs, cf)
udf_data = AggregateUDFData(init, step, final)
push!(db.registered_UDF_data, udf_data)
udf_data_ptr = pointer_from_objref(udf_data)

cs = @cfunction(wrap_stepfunc, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
cf = @cfunction(wrap_finalfunc, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))

enc = C.SQLITE_UTF8
enc = isdeterm ? enc | C.SQLITE_DETERMINISTIC : enc
Expand All @@ -268,7 +284,7 @@ function register(
name,
nargs,
enc,
C_NULL,
udf_data_ptr,
C_NULL,
cs,
cf,
Expand Down

0 comments on commit 7a00040

Please sign in to comment.