From 048666ed34d6e71e5e65db253dc548c213c17fb1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 20 Apr 2021 17:05:32 -0400 Subject: [PATCH] Squashed 'wrap/' changes from 903694b77..b2144a712 b2144a712 Merge pull request #95 from borglab/feature/empty-str-default-arg 9f1e727d8 Merge pull request #96 from borglab/fix/cmake 97ee2ff0c fix CMake typo 64a599827 support empty strings as default args 7b14ed542 Merge pull request #94 from borglab/fix/cmake-messages 0978641fe clean up 5b9272557 Merge pull request #91 from borglab/feature/enums 56e6f48b3 Merge pull request #93 from borglab/feature/better-template 27cc7cebf better cmake messages a6318b567 fix tests b7f60463f remove export_values() 38304fe0a support for class nested enums 348160740 minor fixes 5b6d66a97 use cpp_class and correct module name 2f7ae0676 add newlines and formatting 6e7cecc50 remove support for enum value assignment c1dc925a6 formatting 798732598 better pybind template f6dad2959 pybind_wrapper fixes with formatting 7b4a06560 Merge branch 'master' into feature/enums 1982b7131 more comprehensive tests for enums 3a0eafd66 code for wrapping enums 398780982 tests for enum support git-subtree-dir: wrap git-subtree-split: b2144a712953dcc3e001c97c2ace791149c97278 --- CMakeLists.txt | 38 +++--- gtwrap/interface_parser/__init__.py | 2 + gtwrap/interface_parser/classes.py | 42 +++--- gtwrap/interface_parser/enum.py | 70 ++++++++++ gtwrap/interface_parser/function.py | 3 + gtwrap/interface_parser/module.py | 10 +- gtwrap/interface_parser/namespace.py | 4 +- gtwrap/interface_parser/tokens.py | 1 + gtwrap/interface_parser/utils.py | 26 ++++ gtwrap/interface_parser/variable.py | 2 + gtwrap/pybind_wrapper.py | 139 ++++++++++++++------ gtwrap/template_instantiator.py | 6 +- templates/pybind_wrapper.tpl.example | 1 + tests/expected/matlab/functions_wrapper.cpp | 5 +- tests/expected/python/enum_pybind.cpp | 51 +++++++ tests/expected/python/functions_pybind.cpp | 2 +- tests/fixtures/enum.i | 23 ++++ tests/fixtures/functions.i | 2 +- tests/fixtures/special_cases.i | 8 ++ tests/test_interface_parser.py | 62 ++++++--- tests/test_pybind_wrapper.py | 11 ++ 21 files changed, 399 insertions(+), 109 deletions(-) create mode 100644 gtwrap/interface_parser/enum.py create mode 100644 gtwrap/interface_parser/utils.py create mode 100644 tests/expected/python/enum_pybind.cpp create mode 100644 tests/fixtures/enum.i diff --git a/CMakeLists.txt b/CMakeLists.txt index 91fbaec645..9e03da0607 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,17 +35,19 @@ configure_package_config_file( INSTALL_INCLUDE_DIR INSTALL_PREFIX ${CMAKE_INSTALL_PREFIX}) -message(STATUS "Package config : ${CMAKE_INSTALL_PREFIX}/${INSTALL_CMAKE_DIR}") +# Set all the install paths +set(GTWRAP_CMAKE_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/${INSTALL_CMAKE_DIR}) +set(GTWRAP_LIB_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}) +set(GTWRAP_BIN_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/${INSTALL_BIN_DIR}) +set(GTWRAP_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/${INSTALL_INCLUDE_DIR}) # ############################################################################## # Install the package -message(STATUS "CMake : ${CMAKE_INSTALL_PREFIX}/${INSTALL_CMAKE_DIR}") # Install CMake scripts to the standard CMake script directory. -install( - FILES ${CMAKE_CURRENT_BINARY_DIR}/cmake/gtwrapConfig.cmake - cmake/MatlabWrap.cmake cmake/PybindWrap.cmake cmake/GtwrapUtils.cmake - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_CMAKE_DIR}") +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/cmake/gtwrapConfig.cmake + cmake/MatlabWrap.cmake cmake/PybindWrap.cmake + cmake/GtwrapUtils.cmake DESTINATION "${GTWRAP_CMAKE_INSTALL_DIR}") # Configure the include directory for matlab.h This allows the #include to be # either gtwrap/matlab.h, wrap/matlab.h or something custom. @@ -60,24 +62,26 @@ configure_file(${PROJECT_SOURCE_DIR}/templates/matlab_wrapper.tpl.in # Install the gtwrap python package as a directory so it can be found by CMake # for wrapping. -message(STATUS "Lib path : ${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}") -install(DIRECTORY gtwrap - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}") +install(DIRECTORY gtwrap DESTINATION "${GTWRAP_LIB_INSTALL_DIR}") # Install pybind11 directory to `CMAKE_INSTALL_PREFIX/lib/gtwrap/pybind11` This # will allow the gtwrapConfig.cmake file to load it later. -install(DIRECTORY pybind11 - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_LIB_DIR}") +install(DIRECTORY pybind11 DESTINATION "${GTWRAP_LIB_INSTALL_DIR}") # Install wrapping scripts as binaries to `CMAKE_INSTALL_PREFIX/bin` so they can # be invoked for wrapping. We use DESTINATION (instead of TYPE) so we can # support older CMake versions. -message(STATUS "Bin path : ${CMAKE_INSTALL_PREFIX}/${INSTALL_BIN_DIR}") install(PROGRAMS scripts/pybind_wrap.py scripts/matlab_wrap.py - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_BIN_DIR}") + DESTINATION "${GTWRAP_BIN_INSTALL_DIR}") # Install the matlab.h file to `CMAKE_INSTALL_PREFIX/lib/gtwrap/matlab.h`. -message( - STATUS "Header path : ${CMAKE_INSTALL_PREFIX}/${INSTALL_INCLUDE_DIR}") -install(FILES matlab.h - DESTINATION "${CMAKE_INSTALL_PREFIX}/${INSTALL_INCLUDE_DIR}") +install(FILES matlab.h DESTINATION "${GTWRAP_INCLUDE_INSTALL_DIR}") + +string(ASCII 27 Esc) +set(gtwrap "${Esc}[1;36mgtwrap${Esc}[m") +message(STATUS "${gtwrap} Package config : ${GTWRAP_CMAKE_INSTALL_DIR}") +message(STATUS "${gtwrap} version : ${PROJECT_VERSION}") +message(STATUS "${gtwrap} CMake path : ${GTWRAP_CMAKE_INSTALL_DIR}") +message(STATUS "${gtwrap} library path : ${GTWRAP_LIB_INSTALL_DIR}") +message(STATUS "${gtwrap} binary path : ${GTWRAP_BIN_INSTALL_DIR}") +message(STATUS "${gtwrap} header path : ${GTWRAP_INCLUDE_INSTALL_DIR}") diff --git a/gtwrap/interface_parser/__init__.py b/gtwrap/interface_parser/__init__.py index 8bb1fc7ddc..0f87eaaa9d 100644 --- a/gtwrap/interface_parser/__init__.py +++ b/gtwrap/interface_parser/__init__.py @@ -11,10 +11,12 @@ """ import sys + import pyparsing from .classes import * from .declaration import * +from .enum import * from .function import * from .module import * from .namespace import * diff --git a/gtwrap/interface_parser/classes.py b/gtwrap/interface_parser/classes.py index 9c83821b89..ee4a9725cb 100644 --- a/gtwrap/interface_parser/classes.py +++ b/gtwrap/interface_parser/classes.py @@ -12,13 +12,15 @@ from typing import Iterable, List, Union -from pyparsing import Optional, ZeroOrMore, Literal +from pyparsing import Literal, Optional, ZeroOrMore +from .enum import Enum from .function import ArgumentList, ReturnType from .template import Template -from .tokens import (CLASS, COLON, CONST, IDENT, LBRACE, LPAREN, RBRACE, - RPAREN, SEMI_COLON, STATIC, VIRTUAL, OPERATOR) -from .type import TemplatedType, Type, Typename +from .tokens import (CLASS, COLON, CONST, IDENT, LBRACE, LPAREN, OPERATOR, + RBRACE, RPAREN, SEMI_COLON, STATIC, VIRTUAL) +from .type import TemplatedType, Typename +from .utils import collect_namespaces from .variable import Variable @@ -200,21 +202,6 @@ def __repr__(self) -> str: ) -def collect_namespaces(obj): - """ - Get the chain of namespaces from the lowest to highest for the given object. - - Args: - obj: Object of type Namespace, Class or InstantiatedClass. - """ - namespaces = [] - ancestor = obj.parent - while ancestor and ancestor.name: - namespaces = [ancestor.name] + namespaces - ancestor = ancestor.parent - return [''] + namespaces - - class Class: """ Rule to parse a class defined in the interface file. @@ -230,9 +217,13 @@ class Members: """ Rule for all the members within a class. """ - rule = ZeroOrMore(Constructor.rule ^ StaticMethod.rule ^ Method.rule - ^ Variable.rule ^ Operator.rule).setParseAction( - lambda t: Class.Members(t.asList())) + rule = ZeroOrMore(Constructor.rule # + ^ StaticMethod.rule # + ^ Method.rule # + ^ Variable.rule # + ^ Operator.rule # + ^ Enum.rule # + ).setParseAction(lambda t: Class.Members(t.asList())) def __init__(self, members: List[Union[Constructor, Method, StaticMethod, @@ -242,6 +233,7 @@ def __init__(self, self.static_methods = [] self.properties = [] self.operators = [] + self.enums = [] for m in members: if isinstance(m, Constructor): self.ctors.append(m) @@ -253,6 +245,8 @@ def __init__(self, self.properties.append(m) elif isinstance(m, Operator): self.operators.append(m) + elif isinstance(m, Enum): + self.enums.append(m) _parent = COLON + (TemplatedType.rule ^ Typename.rule)("parent_class") rule = ( @@ -275,6 +269,7 @@ def __init__(self, t.members.static_methods, t.members.properties, t.members.operators, + t.members.enums )) def __init__( @@ -288,6 +283,7 @@ def __init__( static_methods: List[StaticMethod], properties: List[Variable], operators: List[Operator], + enums: List[Enum], parent: str = '', ): self.template = template @@ -312,6 +308,8 @@ def __init__( self.static_methods = static_methods self.properties = properties self.operators = operators + self.enums = enums + self.parent = parent # Make sure ctors' names and class name are the same. diff --git a/gtwrap/interface_parser/enum.py b/gtwrap/interface_parser/enum.py new file mode 100644 index 0000000000..fca7080ef2 --- /dev/null +++ b/gtwrap/interface_parser/enum.py @@ -0,0 +1,70 @@ +""" +GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Parser class and rules for parsing C++ enums. + +Author: Varun Agrawal +""" + +from pyparsing import delimitedList + +from .tokens import ENUM, IDENT, LBRACE, RBRACE, SEMI_COLON +from .type import Typename +from .utils import collect_namespaces + + +class Enumerator: + """ + Rule to parse an enumerator inside an enum. + """ + rule = ( + IDENT("enumerator")).setParseAction(lambda t: Enumerator(t.enumerator)) + + def __init__(self, name): + self.name = name + + def __repr__(self): + return "Enumerator: ({0})".format(self.name) + + +class Enum: + """ + Rule to parse enums defined in the interface file. + + E.g. + ``` + enum Kind { + Dog, + Cat + }; + ``` + """ + + rule = (ENUM + IDENT("name") + LBRACE + + delimitedList(Enumerator.rule)("enumerators") + RBRACE + + SEMI_COLON).setParseAction(lambda t: Enum(t.name, t.enumerators)) + + def __init__(self, name, enumerators, parent=''): + self.name = name + self.enumerators = enumerators + self.parent = parent + + def namespaces(self) -> list: + """Get the namespaces which this class is nested under as a list.""" + return collect_namespaces(self) + + def cpp_typename(self): + """ + Return a Typename with the namespaces and cpp name of this + class. + """ + namespaces_name = self.namespaces() + namespaces_name.append(self.name) + return Typename(namespaces_name) + + def __repr__(self): + return "Enum: {0}".format(self.name) diff --git a/gtwrap/interface_parser/function.py b/gtwrap/interface_parser/function.py index 64c7b176bb..bf9b15256b 100644 --- a/gtwrap/interface_parser/function.py +++ b/gtwrap/interface_parser/function.py @@ -50,6 +50,9 @@ def __init__(self, # This means a tuple has been passed so we convert accordingly elif len(default) > 1: default = tuple(default.asList()) + else: + # set to None explicitly so we can support empty strings + default = None self.default = default self.parent: Union[ArgumentList, None] = None diff --git a/gtwrap/interface_parser/module.py b/gtwrap/interface_parser/module.py index 2a564ec9b5..6412098b8a 100644 --- a/gtwrap/interface_parser/module.py +++ b/gtwrap/interface_parser/module.py @@ -12,14 +12,11 @@ # pylint: disable=unnecessary-lambda, unused-import, expression-not-assigned, no-else-return, protected-access, too-few-public-methods, too-many-arguments -import sys - -import pyparsing # type: ignore -from pyparsing import (ParserElement, ParseResults, ZeroOrMore, - cppStyleComment, stringEnd) +from pyparsing import ParseResults, ZeroOrMore, cppStyleComment, stringEnd from .classes import Class from .declaration import ForwardDeclaration, Include +from .enum import Enum from .function import GlobalFunction from .namespace import Namespace from .template import TypedefTemplateInstantiation @@ -44,7 +41,8 @@ class Module: ^ Class.rule # ^ TypedefTemplateInstantiation.rule # ^ GlobalFunction.rule # - ^ Variable.rule # + ^ Enum.rule # + ^ Variable.rule # ^ Namespace.rule # ).setParseAction(lambda t: Namespace('', t.asList())) + stringEnd) diff --git a/gtwrap/interface_parser/namespace.py b/gtwrap/interface_parser/namespace.py index 502064a2f0..8aa2e71cc1 100644 --- a/gtwrap/interface_parser/namespace.py +++ b/gtwrap/interface_parser/namespace.py @@ -18,6 +18,7 @@ from .classes import Class, collect_namespaces from .declaration import ForwardDeclaration, Include +from .enum import Enum from .function import GlobalFunction from .template import TypedefTemplateInstantiation from .tokens import IDENT, LBRACE, NAMESPACE, RBRACE @@ -68,7 +69,8 @@ class Namespace: ^ Class.rule # ^ TypedefTemplateInstantiation.rule # ^ GlobalFunction.rule # - ^ Variable.rule # + ^ Enum.rule # + ^ Variable.rule # ^ rule # )("content") # BR + RBRACE # diff --git a/gtwrap/interface_parser/tokens.py b/gtwrap/interface_parser/tokens.py index 5d2bdeaf3c..c6a40bc311 100644 --- a/gtwrap/interface_parser/tokens.py +++ b/gtwrap/interface_parser/tokens.py @@ -46,6 +46,7 @@ "#include", ], ) +ENUM = Keyword("enum") ^ Keyword("enum class") ^ Keyword("enum struct") NAMESPACE = Keyword("namespace") BASIS_TYPES = map( Keyword, diff --git a/gtwrap/interface_parser/utils.py b/gtwrap/interface_parser/utils.py new file mode 100644 index 0000000000..78c97edeae --- /dev/null +++ b/gtwrap/interface_parser/utils.py @@ -0,0 +1,26 @@ +""" +GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Various common utilities. + +Author: Varun Agrawal +""" + + +def collect_namespaces(obj): + """ + Get the chain of namespaces from the lowest to highest for the given object. + + Args: + obj: Object of type Namespace, Class, InstantiatedClass, or Enum. + """ + namespaces = [] + ancestor = obj.parent + while ancestor and ancestor.name: + namespaces = [ancestor.name] + namespaces + ancestor = ancestor.parent + return [''] + namespaces diff --git a/gtwrap/interface_parser/variable.py b/gtwrap/interface_parser/variable.py index 80dd5030bc..dffa2de126 100644 --- a/gtwrap/interface_parser/variable.py +++ b/gtwrap/interface_parser/variable.py @@ -46,6 +46,8 @@ def __init__(self, self.name = name if default: self.default = default[0] + else: + self.default = None self.parent = parent diff --git a/gtwrap/pybind_wrapper.py b/gtwrap/pybind_wrapper.py index 88bd05a494..7d0244f068 100755 --- a/gtwrap/pybind_wrapper.py +++ b/gtwrap/pybind_wrapper.py @@ -47,11 +47,15 @@ def _py_args_names(self, args_list): if names: py_args = [] for arg in args_list.args_list: - if arg.default and isinstance(arg.default, str): - arg.default = "\"{arg.default}\"".format(arg=arg) + if isinstance(arg.default, str) and arg.default is not None: + # string default arg + arg.default = ' = "{arg.default}"'.format(arg=arg) + elif arg.default: # Other types + arg.default = ' = {arg.default}'.format(arg=arg) + else: + arg.default = '' argument = 'py::arg("{name}"){default}'.format( - name=arg.name, - default=' = {0}'.format(arg.default) if arg.default else '') + name=arg.name, default='{0}'.format(arg.default)) py_args.append(argument) return ", " + ", ".join(py_args) else: @@ -61,7 +65,10 @@ def _method_args_signature_with_names(self, args_list): """Define the method signature types with the argument names.""" cpp_types = args_list.to_cpp(self.use_boost) names = args_list.args_names() - types_names = ["{} {}".format(ctype, name) for ctype, name in zip(cpp_types, names)] + types_names = [ + "{} {}".format(ctype, name) + for ctype, name in zip(cpp_types, names) + ] return ', '.join(types_names) @@ -69,14 +76,20 @@ def wrap_ctors(self, my_class): """Wrap the constructors.""" res = "" for ctor in my_class.ctors: - res += (self.method_indent + '.def(py::init<{args_cpp_types}>()' - '{py_args_names})'.format( - args_cpp_types=", ".join(ctor.args.to_cpp(self.use_boost)), - py_args_names=self._py_args_names(ctor.args), - )) + res += ( + self.method_indent + '.def(py::init<{args_cpp_types}>()' + '{py_args_names})'.format( + args_cpp_types=", ".join(ctor.args.to_cpp(self.use_boost)), + py_args_names=self._py_args_names(ctor.args), + )) return res - def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): + def _wrap_method(self, + method, + cpp_class, + prefix, + suffix, + method_suffix=""): py_method = method.name + method_suffix cpp_method = method.to_cpp() @@ -92,17 +105,20 @@ def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): if cpp_method == "pickle": if not cpp_class in self._serializing_classes: - raise ValueError("Cannot pickle a class which is not serializable") + raise ValueError( + "Cannot pickle a class which is not serializable") pickle_method = self.method_indent + \ ".def(py::pickle({indent} [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent} [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast(), obj); return obj; }}))" - return pickle_method.format(cpp_class=cpp_class, indent=self.method_indent) + return pickle_method.format(cpp_class=cpp_class, + indent=self.method_indent) is_method = isinstance(method, instantiator.InstantiatedMethod) is_static = isinstance(method, parser.StaticMethod) return_void = method.return_type.is_void() args_names = method.args.args_names() py_args_names = self._py_args_names(method.args) - args_signature_with_names = self._method_args_signature_with_names(method.args) + args_signature_with_names = self._method_args_signature_with_names( + method.args) caller = cpp_class + "::" if not is_method else "self->" function_call = ('{opt_return} {caller}{function_name}' @@ -136,7 +152,9 @@ def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): if method.name == 'print': # Redirect stdout - see pybind docs for why this is a good idea: # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream - ret = ret.replace('self->print', 'py::scoped_ostream_redirect output; self->print') + ret = ret.replace( + 'self->print', + 'py::scoped_ostream_redirect output; self->print') # Make __repr__() call print() internally ret += '''{prefix}.def("__repr__", @@ -156,7 +174,11 @@ def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): return ret - def wrap_methods(self, methods, cpp_class, prefix='\n' + ' ' * 8, suffix=''): + def wrap_methods(self, + methods, + cpp_class, + prefix='\n' + ' ' * 8, + suffix=''): """ Wrap all the methods in the `cpp_class`. @@ -169,7 +191,8 @@ def wrap_methods(self, methods, cpp_class, prefix='\n' + ' ' * 8, suffix=''): if method.name == 'insert' and cpp_class == 'gtsam::Values': name_list = method.args.args_names() type_list = method.args.to_cpp(self.use_boost) - if type_list[0].strip() == 'size_t': # inserting non-wrapped value types + # inserting non-wrapped value types + if type_list[0].strip() == 'size_t': method_suffix = '_' + name_list[1].strip() res += self._wrap_method(method=method, cpp_class=cpp_class, @@ -186,15 +209,18 @@ def wrap_methods(self, methods, cpp_class, prefix='\n' + ' ' * 8, suffix=''): return res - def wrap_variable(self, module, module_var, variable, prefix='\n' + ' ' * 8): + def wrap_variable(self, + module, + module_var, + variable, + prefix='\n' + ' ' * 8): """Wrap a variable that's not part of a class (i.e. global) """ return '{prefix}{module_var}.attr("{variable_name}") = {module}{variable_name};'.format( prefix=prefix, module=module, module_var=module_var, - variable_name=variable.name - ) + variable_name=variable.name) def wrap_properties(self, properties, cpp_class, prefix='\n' + ' ' * 8): """Wrap all the properties in the `cpp_class`.""" @@ -203,7 +229,8 @@ def wrap_properties(self, properties, cpp_class, prefix='\n' + ' ' * 8): res += ('{prefix}.def_{property}("{property_name}", ' '&{cpp_class}::{property_name})'.format( prefix=prefix, - property="readonly" if prop.ctype.is_const else "readwrite", + property="readonly" + if prop.ctype.is_const else "readwrite", cpp_class=cpp_class, property_name=prop.name, )) @@ -227,7 +254,8 @@ def wrap_operators(self, operators, cpp_class, prefix='\n' + ' ' * 8): op.operator)) return res - def wrap_instantiated_class(self, instantiated_class): + def wrap_instantiated_class( + self, instantiated_class: instantiator.InstantiatedClass): """Wrap the class.""" module_var = self._gen_module_var(instantiated_class.namespaces()) cpp_class = instantiated_class.cpp_class() @@ -287,6 +315,18 @@ def wrap_stl_class(self, stl_class): stl_class.properties, cpp_class), )) + def wrap_enum(self, enum, prefix='\n' + ' ' * 8): + """Wrap an enum.""" + module_var = self._gen_module_var(enum.namespaces()) + cpp_class = enum.cpp_typename().to_cpp() + res = '\n py::enum_<{cpp_class}>({module_var}, "{enum.name}", py::arithmetic())'.format( + module_var=module_var, enum=enum, cpp_class=cpp_class) + for enumerator in enum.enumerators: + res += '{prefix}.value("{enumerator.name}", {cpp_class}::{enumerator.name})'.format( + prefix=prefix, enumerator=enumerator, cpp_class=cpp_class) + res += ";\n\n" + return res + def _partial_match(self, namespaces1, namespaces2): for i in range(min(len(namespaces1), len(namespaces2))): if namespaces1[i] != namespaces2[i]: @@ -294,6 +334,8 @@ def _partial_match(self, namespaces1, namespaces2): return True def _gen_module_var(self, namespaces): + """Get the Pybind11 module name from the namespaces.""" + # We skip the first value in namespaces since it is empty sub_module_namespaces = namespaces[len(self.top_module_namespaces):] return "m_{}".format('_'.join(sub_module_namespaces)) @@ -317,7 +359,10 @@ def wrap_namespace(self, namespace): if len(namespaces) < len(self.top_module_namespaces): for element in namespace.content: if isinstance(element, parser.Include): - includes += ("{}\n".format(element).replace('<', '"').replace('>', '"')) + include = "{}\n".format(element) + # replace the angle brackets with quotes + include = include.replace('<', '"').replace('>', '"') + includes += include if isinstance(element, parser.Namespace): ( wrapped_namespace, @@ -330,34 +375,40 @@ def wrap_namespace(self, namespace): module_var = self._gen_module_var(namespaces) if len(namespaces) > len(self.top_module_namespaces): - wrapped += (' ' * 4 + 'pybind11::module {module_var} = ' - '{parent_module_var}.def_submodule("{namespace}", "' - '{namespace} submodule");\n'.format( - module_var=module_var, - namespace=namespace.name, - parent_module_var=self._gen_module_var(namespaces[:-1]), - )) + wrapped += ( + ' ' * 4 + 'pybind11::module {module_var} = ' + '{parent_module_var}.def_submodule("{namespace}", "' + '{namespace} submodule");\n'.format( + module_var=module_var, + namespace=namespace.name, + parent_module_var=self._gen_module_var( + namespaces[:-1]), + )) + # Wrap an include statement, namespace, class or enum for element in namespace.content: if isinstance(element, parser.Include): - includes += ("{}\n".format(element).replace('<', '"').replace('>', '"')) + include = "{}\n".format(element) + # replace the angle brackets with quotes + include = include.replace('<', '"').replace('>', '"') + includes += include elif isinstance(element, parser.Namespace): - ( - wrapped_namespace, - includes_namespace, - ) = self.wrap_namespace( # noqa + wrapped_namespace, includes_namespace = self.wrap_namespace( element) wrapped += wrapped_namespace includes += includes_namespace + elif isinstance(element, instantiator.InstantiatedClass): wrapped += self.wrap_instantiated_class(element) elif isinstance(element, parser.Variable): - wrapped += self.wrap_variable( - module=self._add_namespaces('', namespaces), - module_var=module_var, - variable=element, - prefix='\n' + ' ' * 4 - ) + module = self._add_namespaces('', namespaces) + wrapped += self.wrap_variable(module=module, + module_var=module_var, + variable=element, + prefix='\n' + ' ' * 4) + + elif isinstance(element, parser.Enum): + wrapped += self.wrap_enum(element) # Global functions. all_funcs = [ @@ -388,7 +439,8 @@ def wrap(self): cpp_class=cpp_class, new_name=new_name, ) - boost_class_export += "BOOST_CLASS_EXPORT({new_name})\n".format(new_name=new_name, ) + boost_class_export += "BOOST_CLASS_EXPORT({new_name})\n".format( + new_name=new_name, ) holder_type = "PYBIND11_DECLARE_HOLDER_TYPE(TYPE_PLACEHOLDER_DONOTUSE, " \ "{shared_ptr_type}::shared_ptr);" @@ -398,7 +450,8 @@ def wrap(self): include_boost=include_boost, module_name=self.module_name, includes=includes, - holder_type=holder_type.format(shared_ptr_type=('boost' if self.use_boost else 'std')) + holder_type=holder_type.format( + shared_ptr_type=('boost' if self.use_boost else 'std')) if self.use_boost else "", wrapped_namespace=wrapped_namespace, boost_class_export=boost_class_export, diff --git a/gtwrap/template_instantiator.py b/gtwrap/template_instantiator.py index bddaa07a8f..a66fa95445 100644 --- a/gtwrap/template_instantiator.py +++ b/gtwrap/template_instantiator.py @@ -266,7 +266,7 @@ class InstantiatedClass(parser.Class): """ Instantiate the class defined in the interface file. """ - def __init__(self, original, instantiations=(), new_name=''): + def __init__(self, original: parser.Class, instantiations=(), new_name=''): """ Template Instantiations: [T1, U1] @@ -302,6 +302,9 @@ def __init__(self, original, instantiations=(), new_name=''): # Instantiate all operator overloads self.operators = self.instantiate_operators(typenames) + # Set enums + self.enums = original.enums + # Instantiate all instance methods instantiated_methods = \ self.instantiate_class_templates_in_methods(typenames) @@ -330,6 +333,7 @@ def __init__(self, original, instantiations=(), new_name=''): self.static_methods, self.properties, self.operators, + self.enums, parent=self.parent, ) diff --git a/templates/pybind_wrapper.tpl.example b/templates/pybind_wrapper.tpl.example index 8c38ad21c4..bf5b334900 100644 --- a/templates/pybind_wrapper.tpl.example +++ b/templates/pybind_wrapper.tpl.example @@ -5,6 +5,7 @@ #include #include #include +#include #include "gtsam/base/serialization.h" #include "gtsam/nonlinear/utilities.h" // for RedirectCout. diff --git a/tests/expected/matlab/functions_wrapper.cpp b/tests/expected/matlab/functions_wrapper.cpp index b8341b4bae..536733bdcc 100644 --- a/tests/expected/matlab/functions_wrapper.cpp +++ b/tests/expected/matlab/functions_wrapper.cpp @@ -204,9 +204,10 @@ void DefaultFuncInt_8(int nargout, mxArray *out[], int nargin, const mxArray *in } void DefaultFuncString_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("DefaultFuncString",nargout,nargin,1); + checkArguments("DefaultFuncString",nargout,nargin,2); string& s = *unwrap_shared_ptr< string >(in[0], "ptr_string"); - DefaultFuncString(s); + string& name = *unwrap_shared_ptr< string >(in[1], "ptr_string"); + DefaultFuncString(s,name); } void DefaultFuncObj_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { diff --git a/tests/expected/python/enum_pybind.cpp b/tests/expected/python/enum_pybind.cpp new file mode 100644 index 0000000000..5e792b211c --- /dev/null +++ b/tests/expected/python/enum_pybind.cpp @@ -0,0 +1,51 @@ + + +#include +#include +#include +#include +#include "gtsam/nonlinear/utilities.h" // for RedirectCout. + + +#include "wrap/serialization.h" +#include + + + + + +using namespace std; + +namespace py = pybind11; + +PYBIND11_MODULE(enum_py, m_) { + m_.doc() = "pybind11 wrapper of enum_py"; + + + py::enum_(m_, "Kind", py::arithmetic()) + .value("Dog", Kind::Dog) + .value("Cat", Kind::Cat); + + pybind11::module m_gtsam = m_.def_submodule("gtsam", "gtsam submodule"); + + py::enum_(m_gtsam, "VerbosityLM", py::arithmetic()) + .value("SILENT", gtsam::VerbosityLM::SILENT) + .value("SUMMARY", gtsam::VerbosityLM::SUMMARY) + .value("TERMINATION", gtsam::VerbosityLM::TERMINATION) + .value("LAMBDA", gtsam::VerbosityLM::LAMBDA) + .value("TRYLAMBDA", gtsam::VerbosityLM::TRYLAMBDA) + .value("TRYCONFIG", gtsam::VerbosityLM::TRYCONFIG) + .value("DAMPED", gtsam::VerbosityLM::DAMPED) + .value("TRYDELTA", gtsam::VerbosityLM::TRYDELTA); + + + py::class_>(m_gtsam, "Pet") + .def(py::init(), py::arg("name"), py::arg("type")) + .def_readwrite("name", >sam::Pet::name) + .def_readwrite("type", >sam::Pet::type); + + +#include "python/specializations.h" + +} + diff --git a/tests/expected/python/functions_pybind.cpp b/tests/expected/python/functions_pybind.cpp index 2513bcf564..47c540bc09 100644 --- a/tests/expected/python/functions_pybind.cpp +++ b/tests/expected/python/functions_pybind.cpp @@ -31,7 +31,7 @@ PYBIND11_MODULE(functions_py, m_) { m_.def("MultiTemplatedFunctionStringSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction(x, y);}, py::arg("x"), py::arg("y")); m_.def("MultiTemplatedFunctionDoubleSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction(x, y);}, py::arg("x"), py::arg("y")); m_.def("DefaultFuncInt",[](int a){ ::DefaultFuncInt(a);}, py::arg("a") = 123); - m_.def("DefaultFuncString",[](const string& s){ ::DefaultFuncString(s);}, py::arg("s") = "hello"); + m_.def("DefaultFuncString",[](const string& s, const string& name){ ::DefaultFuncString(s, name);}, py::arg("s") = "hello", py::arg("name") = ""); m_.def("DefaultFuncObj",[](const gtsam::KeyFormatter& keyFormatter){ ::DefaultFuncObj(keyFormatter);}, py::arg("keyFormatter") = gtsam::DefaultKeyFormatter); m_.def("TemplatedFunctionRot3",[](const gtsam::Rot3& t){ ::TemplatedFunction(t);}, py::arg("t")); diff --git a/tests/fixtures/enum.i b/tests/fixtures/enum.i new file mode 100644 index 0000000000..97a5383e69 --- /dev/null +++ b/tests/fixtures/enum.i @@ -0,0 +1,23 @@ +enum Kind { Dog, Cat }; + +namespace gtsam { +enum VerbosityLM { + SILENT, + SUMMARY, + TERMINATION, + LAMBDA, + TRYLAMBDA, + TRYCONFIG, + DAMPED, + TRYDELTA +}; + +class Pet { + enum Kind { Dog, Cat }; + + Pet(const string &name, Kind type); + + string name; + Kind type; +}; +} // namespace gtsam diff --git a/tests/fixtures/functions.i b/tests/fixtures/functions.i index 5e774a05a9..2980286913 100644 --- a/tests/fixtures/functions.i +++ b/tests/fixtures/functions.i @@ -29,5 +29,5 @@ typedef TemplatedFunction TemplatedFunctionRot3; // Check default arguments void DefaultFuncInt(int a = 123); -void DefaultFuncString(const string& s = "hello"); +void DefaultFuncString(const string& s = "hello", const string& name = ""); void DefaultFuncObj(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); diff --git a/tests/fixtures/special_cases.i b/tests/fixtures/special_cases.i index da1170c5c0..87efca54c7 100644 --- a/tests/fixtures/special_cases.i +++ b/tests/fixtures/special_cases.i @@ -26,3 +26,11 @@ class SfmTrack { }; } // namespace gtsam + + +// class VariableIndex { +// VariableIndex(); +// // template +// VariableIndex(const T& graph); +// VariableIndex(const T& graph, size_t nVariables); +// }; diff --git a/tests/test_interface_parser.py b/tests/test_interface_parser.py index 28b645201d..70f044f04a 100644 --- a/tests/test_interface_parser.py +++ b/tests/test_interface_parser.py @@ -19,9 +19,10 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from gtwrap.interface_parser import ( - ArgumentList, Class, Constructor, ForwardDeclaration, GlobalFunction, - Include, Method, Module, Namespace, Operator, ReturnType, StaticMethod, - TemplatedType, Type, TypedefTemplateInstantiation, Typename, Variable) + ArgumentList, Class, Constructor, Enum, Enumerator, ForwardDeclaration, + GlobalFunction, Include, Method, Module, Namespace, Operator, ReturnType, + StaticMethod, TemplatedType, Type, TypedefTemplateInstantiation, Typename, + Variable) class TestInterfaceParser(unittest.TestCase): @@ -180,7 +181,7 @@ def test_argument_list_templated(self): def test_default_arguments(self): """Tests any expression that is a valid default argument""" args = ArgumentList.rule.parseString( - "string s=\"hello\", int a=3, " + "string c = \"\", string s=\"hello\", int a=3, " "int b, double pi = 3.1415, " "gtsam::KeyFormatter kf = gtsam::DefaultKeyFormatter, " "std::vector p = std::vector(), " @@ -188,22 +189,21 @@ def test_default_arguments(self): )[0].args_list # Test for basic types - self.assertEqual(args[0].default, "hello") - self.assertEqual(args[1].default, 3) - # '' is falsy so we can check against it - self.assertEqual(args[2].default, '') - self.assertFalse(args[2].default) + self.assertEqual(args[0].default, "") + self.assertEqual(args[1].default, "hello") + self.assertEqual(args[2].default, 3) + # No default argument should set `default` to None + self.assertIsNone(args[3].default) - self.assertEqual(args[3].default, 3.1415) + self.assertEqual(args[4].default, 3.1415) # Test non-basic type - self.assertEqual(repr(args[4].default.typename), + self.assertEqual(repr(args[5].default.typename), 'gtsam::DefaultKeyFormatter') # Test templated type - self.assertEqual(repr(args[5].default.typename), 'std::vector') + self.assertEqual(repr(args[6].default.typename), 'std::vector') # Test for allowing list as default argument - print(args) - self.assertEqual(args[6].default, (1, 2, 'name', "random", 3.1415)) + self.assertEqual(args[7].default, (1, 2, 'name', "random", 3.1415)) def test_return_type(self): """Test ReturnType""" @@ -424,6 +424,17 @@ def test_class_inheritance(self): self.assertEqual(["gtsam"], ret.parent_class.instantiations[0].namespaces) + def test_class_with_enum(self): + """Test for class with nested enum.""" + ret = Class.rule.parseString(""" + class Pet { + Pet(const string &name, Kind type); + enum Kind { Dog, Cat }; + }; + """)[0] + self.assertEqual(ret.name, "Pet") + self.assertEqual(ret.enums[0].name, "Kind") + def test_include(self): """Test for include statements.""" include = Include.rule.parseString( @@ -460,12 +471,33 @@ def test_global_variable(self): self.assertEqual(variable.ctype.typename.name, "string") self.assertEqual(variable.default, 9.81) - variable = Variable.rule.parseString("const string kGravity = 9.81;")[0] + variable = Variable.rule.parseString( + "const string kGravity = 9.81;")[0] self.assertEqual(variable.name, "kGravity") self.assertEqual(variable.ctype.typename.name, "string") self.assertTrue(variable.ctype.is_const) self.assertEqual(variable.default, 9.81) + def test_enumerator(self): + """Test for enumerator.""" + enumerator = Enumerator.rule.parseString("Dog")[0] + self.assertEqual(enumerator.name, "Dog") + + enumerator = Enumerator.rule.parseString("Cat")[0] + self.assertEqual(enumerator.name, "Cat") + + def test_enum(self): + """Test for enums.""" + enum = Enum.rule.parseString(""" + enum Kind { + Dog, + Cat + }; + """)[0] + self.assertEqual(enum.name, "Kind") + self.assertEqual(enum.enumerators[0].name, "Dog") + self.assertEqual(enum.enumerators[1].name, "Cat") + def test_namespace(self): """Test for namespace parsing.""" namespace = Namespace.rule.parseString(""" diff --git a/tests/test_pybind_wrapper.py b/tests/test_pybind_wrapper.py index 5eff554462..fe5e1950e0 100644 --- a/tests/test_pybind_wrapper.py +++ b/tests/test_pybind_wrapper.py @@ -158,6 +158,17 @@ def test_special_cases(self): self.compare_and_diff('special_cases_pybind.cpp', output) + def test_enum(self): + """ + Test if enum generation is correct. + """ + with open(osp.join(self.INTERFACE_DIR, 'enum.i'), 'r') as f: + content = f.read() + + output = self.wrap_content(content, 'enum_py', + self.PYTHON_ACTUAL_DIR) + + self.compare_and_diff('enum_pybind.cpp', output) if __name__ == '__main__': unittest.main()