From f832dca0d00ed71035e5752da0f5ce8cdef926e3 Mon Sep 17 00:00:00 2001 From: Pol Febrer Date: Wed, 24 Apr 2024 13:42:14 +0200 Subject: [PATCH] Make sure annotations don't make workflow fail --- pyproject.toml | 2 +- src/nodify/parse.py | 42 +++++++++++++++++++++++++++---- src/nodify/tests/test_workflow.py | 14 ++++++++++- src/nodify/workflow.py | 6 ++--- 4 files changed, 54 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6eb8882..a39ed08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ name = "nodify" description = "Supercharge your functional application with a powerful node system." readme = "README.md" license = {file = "LICENSE"} -version = "0.0.4" +version = "0.0.5" dependencies = [] diff --git a/src/nodify/parse.py b/src/nodify/parse.py index 5be9e8e..0cac48b 100644 --- a/src/nodify/parse.py +++ b/src/nodify/parse.py @@ -32,6 +32,7 @@ def __init__( constant_cls: str = "ConstantNode", nodify_constants: bool = False, nodify_constant_assignments: bool = False, + remove_function_annotations: bool = False, **kwargs, ): super().__init__(*args, **kwargs) @@ -41,6 +42,7 @@ def __init__( self.constant_cls = constant_cls self.nodify_constants = nodify_constants self.nodify_constant_assignments = nodify_constant_assignments + self.remove_function_annotations = remove_function_annotations def visit_Call(self, node): """Converts some_module.some_attr(some_args) into Node.from_func(some_module.some_attr)(some_args)""" @@ -177,7 +179,6 @@ def visit_IfExp(self, node: ast.IfExp) -> Any: def visit_Constant(self, node: ast.Constant) -> Any: if self.nodify_constants: - print(ast.dump(node)) new_node = ast.Call( func=ast.Name(id=self.constant_cls, ctx=ast.Load()), args=[node], @@ -186,12 +187,30 @@ def visit_Constant(self, node: ast.Constant) -> Any: ast.fix_missing_locations(new_node) - print(ast.dump(new_node)) - return new_node else: return self.generic_visit(node) + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: + + if self.remove_function_annotations: + + def _remove_annotation(arg): + arg.annotation = None + return arg + + node.args.args = [_remove_annotation(arg) for arg in node.args.args] + node.args.kwonlyargs = [ + _remove_annotation(arg) for arg in node.args.kwonlyargs + ] + node.args.posonlyargs = [ + _remove_annotation(arg) for arg in node.args.posonlyargs + ] + + node.returns = None + + return self.generic_visit(node) + # def visit_Compare(self, node: ast.Compare) -> Any: # """Converts the comparison syntax into CompareNode call.""" # if len(node.ops) > 1: @@ -264,14 +283,25 @@ def nodify_func( # Make sure the first line is at the 0 indentation level. code = textwrap.dedent(code) + old_signature = inspect.signature(func) + code_obj = nodify_code( - code, transformer_cls, assign_fn, node_cls, namespace=func_namespace + code, + transformer_cls, + assign_fn, + node_cls, + remove_function_annotations=True, + namespace=func_namespace, ) # Execute the code, and retrieve the new function from the namespace. exec(code_obj, func_namespace) - return func_namespace[func.__name__] + new_func = func_namespace[func.__name__] + + new_func.__signature__ = old_signature + + return new_func def nodify_code( @@ -281,6 +311,7 @@ def nodify_code( node_cls: Type[Node] = Node, nodify_constants: bool = False, nodify_constant_assignments: bool = False, + remove_function_annotations: bool = False, namespace: Optional[dict] = None, ): if namespace is None: @@ -320,6 +351,7 @@ def nodify_code( node_cls_name=node_cls_name, nodify_constants=nodify_constants, nodify_constant_assignments=nodify_constant_assignments, + remove_function_annotations=remove_function_annotations, ) new_tree = transformer.visit(tree) diff --git a/src/nodify/tests/test_workflow.py b/src/nodify/tests/test_workflow.py index cc3b8e1..1f6fd8d 100644 --- a/src/nodify/tests/test_workflow.py +++ b/src/nodify/tests/test_workflow.py @@ -13,7 +13,8 @@ @pytest.fixture( - scope="module", params=["from_func", "explicit_class", "input_operations"] + scope="module", + params=["from_func", "explicit_class", "input_operations", "with_annotations"], ) def triple_sum(request) -> Type[Workflow]: """Returns a workflow that computes a triple sum. @@ -56,6 +57,17 @@ def triple_sum(a, b, c): triple_sum._sum_key = "BinaryOperationNode" + elif request.param == "with_annotations": + + with pytest.warns(): + + @Workflow.from_func + def triple_sum(a: int, b: int, c: int) -> int: + first_sum = a + b + return first_sum + c + + triple_sum._sum_key = "BinaryOperationNode" + return triple_sum diff --git a/src/nodify/workflow.py b/src/nodify/workflow.py index e708c98..d6672a4 100644 --- a/src/nodify/workflow.py +++ b/src/nodify/workflow.py @@ -918,12 +918,12 @@ def assign_workflow_var(value: Any, var_name: str): # Get the workflow function work_func = cls.function - # Nodify it, passing the middleware function that will assign the variables to the workflow. - work_func = nodify_func(work_func, assign_fn=assign_workflow_var) - # Get the signature of the function. sig = inspect.signature(work_func) + # Nodify it, passing the middleware function that will assign the variables to the workflow. + work_func = nodify_func(work_func, assign_fn=assign_workflow_var) + # Run a dryrun of the workflow, so that we can understand how the nodes are connected. # To this end, nodes must behave lazily. with temporal_context(lazy=True):