diff --git a/hamilton/function_modifiers.py b/hamilton/function_modifiers.py index a2c75bba..72c08a27 100644 --- a/hamilton/function_modifiers.py +++ b/hamilton/function_modifiers.py @@ -431,8 +431,7 @@ def expand_node( fn = node_.callable base_doc = node_.documentation - @functools.wraps(fn) - def df_generator(*args, **kwargs): + def df_generator(*args, **kwargs) -> pd.DataFrame: df_generated = fn(*args, **kwargs) if self.fill_with is not None: for col in self.columns: @@ -441,12 +440,8 @@ def df_generator(*args, **kwargs): return df_generated output_nodes = [ - node.Node( - node_.name, - typ=pd.DataFrame, - doc_string=base_doc, + node_.copy_with( callabl=df_generator, - tags=node_.tags.copy(), ) ] @@ -553,7 +548,6 @@ def expand_node( fn = node_.callable base_doc = node_.documentation - @functools.wraps(fn) def dict_generator(*args, **kwargs): dict_generated = fn(*args, **kwargs) if self.fill_with is not None: @@ -562,15 +556,7 @@ def dict_generator(*args, **kwargs): dict_generated[field] = self.fill_with return dict_generated - output_nodes = [ - node.Node( - node_.name, - typ=dict, - doc_string=base_doc, - callabl=dict_generator, - tags=node_.tags.copy(), - ) - ] + output_nodes = [node_.copy_with(callabl=dict_generator)] for field, field_type in self.fields.items(): doc_string = base_doc # default doc string of base function. @@ -744,7 +730,8 @@ def generate_node(self, fn: Callable, config) -> node.Node: and the same parameters/types as the original function. """ - def replacing_function(__fn=fn, **kwargs): + # @functools.wraps(fn) + def wrapper_function(**kwargs): final_kwarg_values = { key: param_spec.default for key, param_spec in inspect.signature(fn).parameters.items() @@ -754,7 +741,7 @@ def replacing_function(__fn=fn, **kwargs): final_kwarg_values = does.map_kwargs(final_kwarg_values, self.argument_mapping) return self.replacing_function(**final_kwarg_values) - return node.Node.from_fn(fn).copy_with(callabl=replacing_function) + return node.Node.from_fn(fn).copy_with(callabl=wrapper_function) class dynamic_transform(function_modifiers_base.NodeCreator): diff --git a/tests/resources/multiple_decorators_together.py b/tests/resources/multiple_decorators_together.py new file mode 100644 index 00000000..d911d967 --- /dev/null +++ b/tests/resources/multiple_decorators_together.py @@ -0,0 +1,37 @@ +import pandas as pd + +from hamilton.function_modifiers import does, extract_columns, extract_fields, tag + + +def _sum_multiply(param0: int, param1: int, param2: int) -> pd.DataFrame: + return pd.DataFrame([{"param0a": param0, "param1b": param1, "param2c": param2}]) + + +def _sum(param0: int, param1: int, param2: int) -> dict: + return {"total": param0 + param1 + param2} + + +@extract_columns("param1b") +@does(_sum_multiply) +def to_modify(param0: int, param1: int, param2: int = 2) -> pd.DataFrame: + """This is a dummy function showing extract_columns with does.""" + pass + + +@extract_fields({"total": int}) +@does(_sum) +def to_modify_2(param0: int, param1: int, param2: int = 2) -> dict: + """This is a dummy function showing extract_fields with does.""" + pass + + +def _dummy(**values) -> dict: + return {f"out_{k.split('_')[1]}": v for k, v in values.items()} + + +@extract_fields({"out_value1": int, "out_value2": str}) +@tag(test_key="test-value") +# @check_output(data_type=dict, importance="fail") To fix see https://github.com/stitchfix/hamilton/issues/249 +@does(_dummy) +def uber_decorated_function(in_value1: int, in_value2: str) -> dict: + pass diff --git a/tests/test_graph.py b/tests/test_graph.py index 3d131630..c556ab09 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -14,6 +14,7 @@ import tests.resources.extract_columns_execution_count import tests.resources.functions_with_generics import tests.resources.layered_decorators +import tests.resources.multiple_decorators_together import tests.resources.optional_dependencies import tests.resources.parametrized_inputs import tests.resources.parametrized_nodes @@ -359,6 +360,55 @@ def test_end_to_end_with_column_extractor_nodes(): ) +def test_end_to_end_with_multiple_decorators(): + """Tests that a simple function graph with multiple decorators on a function works end-to-end""" + fg = graph.FunctionGraph( + tests.resources.multiple_decorators_together, + config={"param0": 3, "param1": 1, "in_value1": 42, "in_value2": "string_value"}, + ) + nodes = fg.get_nodes() + # To help debug issues: + # nodez, user_nodes = fg.get_upstream_nodes([n.name for n in nodes], + # {"param0": 3, "param1": 1, + # "in_value1": 42, "in_value2": "string_value"}) + # fg.display( + # nodez, + # user_nodes, + # "all_multiple_decorators", + # render_kwargs=None, + # graphviz_kwargs=None, + # ) + results = fg.execute(nodes, {}, {}) + df_expected = tests.resources.multiple_decorators_together._sum_multiply(3, 1, 2) + dict_expected = tests.resources.multiple_decorators_together._sum(3, 1, 2) + pd.testing.assert_series_equal(results["param1b"], df_expected["param1b"]) + pd.testing.assert_frame_equal(results["to_modify"], df_expected) + assert results["total"] == dict_expected["total"] + assert results["to_modify_2"] == dict_expected + node_dict = {n.name: n for n in nodes} + print(sorted(list(node_dict.keys()))) + assert ( + node_dict["to_modify"].documentation + == "This is a dummy function showing extract_columns with does." + ) + assert ( + node_dict["to_modify_2"].documentation + == "This is a dummy function showing extract_fields with does." + ) + # tag only applies right now to outer most node layer + assert node_dict["uber_decorated_function"].tags == { + "module": "tests.resources.multiple_decorators_together" + } # tags are not propagated + assert node_dict["out_value1"].tags == { + "module": "tests.resources.multiple_decorators_together", + "test_key": "test-value", + } + assert node_dict["out_value2"].tags == { + "module": "tests.resources.multiple_decorators_together", + "test_key": "test-value", + } + + def test_end_to_end_with_config_modifier(): config = { "fn_1_version": 1,