Skip to content

Commit

Permalink
Make sure annotations don't make workflow fail
Browse files Browse the repository at this point in the history
  • Loading branch information
pfebrer committed Apr 24, 2024
1 parent e67708b commit f832dca
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ name = "nodify"
description = "Supercharge your functional application with a powerful node system."
readme = "README.md"
license = {file = "LICENSE"}
version = "0.0.4"
version = "0.0.5"

dependencies = []

Expand Down
42 changes: 37 additions & 5 deletions src/nodify/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
constant_cls: str = "ConstantNode",
nodify_constants: bool = False,
nodify_constant_assignments: bool = False,
remove_function_annotations: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
Expand All @@ -41,6 +42,7 @@ def __init__(
self.constant_cls = constant_cls
self.nodify_constants = nodify_constants
self.nodify_constant_assignments = nodify_constant_assignments
self.remove_function_annotations = remove_function_annotations

def visit_Call(self, node):
"""Converts some_module.some_attr(some_args) into Node.from_func(some_module.some_attr)(some_args)"""
Expand Down Expand Up @@ -177,7 +179,6 @@ def visit_IfExp(self, node: ast.IfExp) -> Any:

def visit_Constant(self, node: ast.Constant) -> Any:
if self.nodify_constants:
print(ast.dump(node))
new_node = ast.Call(
func=ast.Name(id=self.constant_cls, ctx=ast.Load()),
args=[node],
Expand All @@ -186,12 +187,30 @@ def visit_Constant(self, node: ast.Constant) -> Any:

ast.fix_missing_locations(new_node)

print(ast.dump(new_node))

return new_node
else:
return self.generic_visit(node)

def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:

if self.remove_function_annotations:

def _remove_annotation(arg):
arg.annotation = None
return arg

node.args.args = [_remove_annotation(arg) for arg in node.args.args]
node.args.kwonlyargs = [
_remove_annotation(arg) for arg in node.args.kwonlyargs
]
node.args.posonlyargs = [
_remove_annotation(arg) for arg in node.args.posonlyargs
]

node.returns = None

return self.generic_visit(node)

# def visit_Compare(self, node: ast.Compare) -> Any:
# """Converts the comparison syntax into CompareNode call."""
# if len(node.ops) > 1:
Expand Down Expand Up @@ -264,14 +283,25 @@ def nodify_func(
# Make sure the first line is at the 0 indentation level.
code = textwrap.dedent(code)

old_signature = inspect.signature(func)

code_obj = nodify_code(
code, transformer_cls, assign_fn, node_cls, namespace=func_namespace
code,
transformer_cls,
assign_fn,
node_cls,
remove_function_annotations=True,
namespace=func_namespace,
)

# Execute the code, and retrieve the new function from the namespace.
exec(code_obj, func_namespace)

return func_namespace[func.__name__]
new_func = func_namespace[func.__name__]

new_func.__signature__ = old_signature

return new_func


def nodify_code(
Expand All @@ -281,6 +311,7 @@ def nodify_code(
node_cls: Type[Node] = Node,
nodify_constants: bool = False,
nodify_constant_assignments: bool = False,
remove_function_annotations: bool = False,
namespace: Optional[dict] = None,
):
if namespace is None:
Expand Down Expand Up @@ -320,6 +351,7 @@ def nodify_code(
node_cls_name=node_cls_name,
nodify_constants=nodify_constants,
nodify_constant_assignments=nodify_constant_assignments,
remove_function_annotations=remove_function_annotations,
)
new_tree = transformer.visit(tree)

Expand Down
14 changes: 13 additions & 1 deletion src/nodify/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@


@pytest.fixture(
scope="module", params=["from_func", "explicit_class", "input_operations"]
scope="module",
params=["from_func", "explicit_class", "input_operations", "with_annotations"],
)
def triple_sum(request) -> Type[Workflow]:
"""Returns a workflow that computes a triple sum.
Expand Down Expand Up @@ -56,6 +57,17 @@ def triple_sum(a, b, c):

triple_sum._sum_key = "BinaryOperationNode"

elif request.param == "with_annotations":

with pytest.warns():

@Workflow.from_func
def triple_sum(a: int, b: int, c: int) -> int:
first_sum = a + b
return first_sum + c

triple_sum._sum_key = "BinaryOperationNode"

return triple_sum


Expand Down
6 changes: 3 additions & 3 deletions src/nodify/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,12 +918,12 @@ def assign_workflow_var(value: Any, var_name: str):

# Get the workflow function
work_func = cls.function
# Nodify it, passing the middleware function that will assign the variables to the workflow.
work_func = nodify_func(work_func, assign_fn=assign_workflow_var)

# Get the signature of the function.
sig = inspect.signature(work_func)

# Nodify it, passing the middleware function that will assign the variables to the workflow.
work_func = nodify_func(work_func, assign_fn=assign_workflow_var)

# Run a dryrun of the workflow, so that we can understand how the nodes are connected.
# To this end, nodes must behave lazily.
with temporal_context(lazy=True):
Expand Down

0 comments on commit f832dca

Please sign in to comment.