Skip to content

Commit

Permalink
continue work on overloading
Browse files Browse the repository at this point in the history
  • Loading branch information
goatshriek committed Aug 30, 2024
1 parent d210b22 commit ad9e07a
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 20 deletions.
19 changes: 19 additions & 0 deletions lib/wrapture/class_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ def child?
@spec.key?('parent')
end

# A list of constructor functions for the class.
def constructors
@functions.select(&:constructor?)
end

# A list of includes needed for the declaration of the class.
def declaration_includes
includes = @spec['includes'].dup
Expand Down Expand Up @@ -203,6 +208,11 @@ def definition_includes
includes.uniq
end

# The destructor function for the class, or nil if there isn't one.
def destructor
@functions.select(&:destructor?).first
end

# Calls the given block for each line of the class documentation.
def documentation(&block)
@doc&.format_as_doxygen(max_line_length: 78) { |line| block.call(line) }
Expand Down Expand Up @@ -235,6 +245,15 @@ def libraries
@functions.flat_map(&:libraries).concat(@spec['libraries'])
end

# An array of methods of the class. This is a subset of the list of
# functions without the constructors and destructors.
#
# Named with a specs suffix to avoid conflicts with Ruby's "methods"
# instance method.
def method_specs
@functions.select { |spec| !spec.constructor? && !spec.destructor? }
end

# The name of the class.
def name
@spec['name']
Expand Down
151 changes: 132 additions & 19 deletions lib/wrapture/python_wrapper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,23 @@ def class_functions(class_spec)
functions
end

# The functions of the given spec where functions that are overloads of each
# other are grouped together. Functions that are overloaded are represented
# as an array of function specs. Functions that are not overloaded are in an
# array by themselves.
def class_function_groups(class_spec)
groups = []

funcs = class_functions(class_spec)
groups.append(funcs.select(&:constructor?))
groups.append(funcs.select(&:destructor?))
methods = funcs.select do |func_spec|
!func_spec.constructor? && !func_spec.destructor?
end.group_by(&:name).values

groups.concat(methods)
end

# Creates a Python object using a variable with the given name and type.
def create_python_object(type, name)
if type.name == 'int'
Expand All @@ -245,7 +262,7 @@ def create_python_object(type, name)
"PyBool_FromLong(#{name})"
else
# TODO: default case
''
'// TODO default case'
end
end

Expand Down Expand Up @@ -292,7 +309,9 @@ def define_class_methods(class_spec)
snake_name = class_spec.snake_case_name
yield "static PyMethodDef #{snake_name}_methods[] = {"

class_spec.functions.each do |func_spec|
# class_spec.functions.each do |func_spec|
# class_function_groups(class_spec).map(&:first).each do |func_spec|
class_spec.method_specs.each do |func_spec|
wrapper_name = function_wrapper_name(func_spec)
yield " { .ml_name = \"#{func_spec.name}\","
yield " .ml_meth = ( PyCFunction ) #{wrapper_name},"
Expand All @@ -310,8 +329,12 @@ def define_class_type_object(class_spec, &block)
define_class_type_struct(class_spec) { |line| block.call(line) }
yield ''

class_functions(class_spec).each do |func_spec|
define_function_wrapper(func_spec, &block)
class_function_groups(class_spec).each do |func_group|
if func_group.length == 1
define_function_wrapper(func_group[0], &block)
else
define_function_group_wrapper(func_group, &block)
end
yield ''
end

Expand Down Expand Up @@ -432,10 +455,87 @@ def define_enum_constructor(enum_spec)
yield '}'
end

# Defines a function that parses and validates parameters of a function.
def define_function_arg_parser(func_spec, name = nil)
yield 'static int'
name = "parse_#{function_wrapper_name(func_spec)}" if name.nil?
signature_declarations = []
func_spec.params.each do |param_spec|
param_type_spec = func_spec.resolve_type(param_spec.type)
param_type = if func_spec.owner.scope.type?(param_type_spec)
"#{self.class.type_struct_name(param_type_spec)} *"
else
"#{param_type_spec} *"
end
signature_declarations << "#{param_type} #{param_spec.name}"
end
wrapped_args = signature_declarations.join(', ')
yield "#{name}( PyObject *args, PyObject *kwds, #{wrapped_args} ) {"
format_str = "\"#{function_args_format(func_spec)}\""
params = func_spec.param_names.join(', ')
yield " return PyArg_ParseTuple( args, #{format_str}, #{params} );"
yield '}'
end

# Defines a function that determines which function in the provided group to
# call based on the parameters, and then calls it in the python interpreter.
def define_function_group_wrapper(func_group, &block)
base_name = function_wrapper_name(func_group[0])
no_args = nil

func_group.each_with_index do |func_spec, i|
wrapper_name = "#{base_name}_#{i}"

if func_spec.params?
parser_name = "parse_#{wrapper_name}"
define_function_arg_parser(func_spec, parser_name, &block)
yield ''
else
no_args = wrapper_name
end

define_function_wrapper(func_spec, wrapper_name, &block)
yield ''
end

yield 'static PyObject *'
yield "#{base_name}( #{function_params(func_group[0]).join(', ')} ) {"
yield ' int parse_result;'
func_group.each_with_index do |func_spec, i|
function_param_locals(func_spec, "_#{i}") do |stmt|
yield " #{stmt}".rstrip
end
end
yield ''

func_group.each_with_index do |func_spec, i|
next unless func_spec.params?

wrapper_name = "#{base_name}_#{i}"
parser_name = "parse_#{base_name}_#{i}"
parsed_args = "&#{func_spec.param_names.join("_#{i}, &")}_#{i}"
yield " parse_result = #{parser_name}( args, kwds, #{parsed_args} );"
yield ' if( parse_result ){'
yield " return #{wrapper_name}( type, args, kwds );"
yield ' }'
yield ''
end

if no_args
yield ' // TODO check to make sure there are no args for this call'
yield " return #{no_args}( type, args, kwds );"
end

yield ' // TODO throw an error for no overload found'
yield '}'
end

# Defines the function that the python interpreter will call for the given
# function spec.
def define_function_wrapper(func_spec, &block)
name = function_wrapper_name(func_spec)
# function spec. If +name+ is provided it will be used as the name of the
# function instead of deriving it from the spec.
def define_function_wrapper(func_spec, name = nil, &block)
name = function_wrapper_name(func_spec) if name.nil?

owner_snake_name = func_spec.owner.snake_case_name
type_struct_name = "#{owner_snake_name}_type_struct"

Expand All @@ -445,17 +545,31 @@ def define_function_wrapper(func_spec, &block)
wrapped_call(func_spec, &block)
yield ' Py_TYPE( self )->tp_free( ( PyObject * ) self );'
else
if func_spec.params?
define_function_arg_parser(func_spec, &block)
yield ''
end
yield 'static PyObject *'
yield "#{name}( #{function_params(func_spec).join(', ')} ) {"

function_locals(func_spec) { |declaration| yield " #{declaration}" }

if func_spec.params?
parsed_args = "&#{func_spec.param_names.join(', &')}"
yield " if( !parse_#{name}( args, kwds, #{parsed_args} ) ){"
yield ' return NULL;'
yield ' }'
yield ''
end

if func_spec.constructor?
yield " self = ( #{type_struct_name} * ) type->tp_alloc( type, 0 );"
yield ' if( !self ) {'
yield ' return NULL;'
yield ' }'
yield ''

yield '' unless func_spec.owner.constants.empty?

func_spec.owner.constants.each do |constant_spec|
field_name = constant_spec.snake_case_name
field_value = constant_spec.value
Expand All @@ -477,7 +591,7 @@ def define_function_wrapper(func_spec, &block)
def define_module(&block)
yield '#define PY_SSIZE_T_CLEAN'
yield '#include <Python.h>'
yield '#include <stddef.h>' # TODO: for offsetof(), only add if needed
yield '#include <stddef.h> // for offsetof()' # TODO: only add if needed
yield '#if PY_VERSION_HEX < 0x30C00F0 // under Python 3.12.0'
yield ' #include <structmember.h> // for PyMemberDef'
yield ' #define Py_T_INT T_INT'
Expand Down Expand Up @@ -585,7 +699,7 @@ def function_locals(spec, &block)

# Yields a declaration of each local variable needed for params by a
# function.
def function_param_locals(spec)
def function_param_locals(spec, suffix = '')
return unless spec.params?

spec.params.each do |param_spec|
Expand All @@ -595,10 +709,10 @@ def function_param_locals(spec)
else
param_type_spec.to_s
end
yield "#{param_type} #{param_spec.name};"
yield "#{param_type} #{param_spec.name}#{suffix};"
end

yield ''
yield '' unless spec.optional_params.empty?

spec.optional_params.each do |param_spec|
assignment = "#{param_spec.name} = "
Expand All @@ -611,12 +725,6 @@ def function_param_locals(spec)
end
yield "#{assignment};"
end

format_str = "\"#{function_args_format(spec)}\""
parsed_args = "&#{spec.param_names.join(', &')}"
yield "if( !PyArg_ParseTuple( args, #{format_str}, #{parsed_args} ) ) {"
yield ' return NULL;'
yield '}'
end

# A list of parameters for the given function's wrapper.
Expand Down Expand Up @@ -695,7 +803,12 @@ def return_statement(func_spec)
elsif func_spec.void_return?
'Py_RETURN_NONE;'
else
"return #{create_python_object(func_spec.return_type, 'return_val')};"
return_value = create_python_object(func_spec.return_type, 'return_val')
if return_value.empty?
'return;'
else
"return #{return_value};"
end
end
end

Expand Down
3 changes: 3 additions & 0 deletions sig/class_spec.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ module Wrapture
attr_reader struct: (Wrapture::StructSpec | nil)
def initialize: (spec_hash spec, ?scope: Wrapture::Scope) -> void
def child?: -> bool
def constructors: -> Array[Wrapture::FunctionSpec]
def declaration_includes: -> Array[String]
def definition_includes: -> Array[String]
def destructor: -> Wrapture::FunctionSpec
def documentation: { (String) -> void } -> void
def equivalent_member?: -> bool
def factory?: -> bool
def libraries: -> Array[String]
def method_specs: -> Array[Wrapture::FunctionSpec]
def name: -> String
def namespace: -> String
def overloads?: (untyped parent_spec) -> bool
Expand Down
5 changes: 4 additions & 1 deletion sig/python_wrapper.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ module Wrapture
def cast: (Wrapture::ClassSpec class_spec, String var_name, String to) -> String
def castable?: (spec_hash wrapped_param) -> bool
def class_functions: (Wrapture::ClassSpec) -> Array[Wrapture::FunctionSpec]
def class_function_groups: (Wrapture::ClassSpec) -> Array[Array[Wrapture::FunctionSpec]]
def create_python_object: (Wrapture::TypeSpec, String) -> String
def default_constructor: (Wrapture::ClassSpec) -> Wrapture::FunctionSpec
def default_destructor: (Wrapture::ClassSpec) -> Wrapture::FunctionSpec
Expand All @@ -24,7 +25,9 @@ module Wrapture
def define_class_type_objects: { (String) -> void } -> void
def define_class_type_struct: (Wrapture::ClassSpec) { (String) -> void } -> void
def define_enum_constructor: (Wrapture::EnumSpec) { (String) -> void } -> void
def define_function_wrapper: { (String) -> void } -> void
def define_function_arg_parser: (Wrapture::FunctionSpec) { (String) -> void } -> void
def define_function_group_wrapper: (Array[Wrapture::FunctionSpec]) { (String) -> void } -> void
def define_function_wrapper: (Wrapture::FunctionSpec) { (String) -> void } -> void
def define_module: { (String) -> void } -> void
def define_scope_type_objects: { (String) -> void } -> void
def equivalent_member_declaration: -> String
Expand Down

0 comments on commit ad9e07a

Please sign in to comment.