Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Alternate fix for does #247

Merged
merged 4 commits into from
Dec 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 6 additions & 19 deletions hamilton/function_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,7 @@ def expand_node(
fn = node_.callable
base_doc = node_.documentation

@functools.wraps(fn)
def df_generator(*args, **kwargs):
def df_generator(*args, **kwargs) -> pd.DataFrame:
df_generated = fn(*args, **kwargs)
if self.fill_with is not None:
for col in self.columns:
Expand All @@ -441,12 +440,8 @@ def df_generator(*args, **kwargs):
return df_generated

output_nodes = [
node.Node(
node_.name,
typ=pd.DataFrame,
doc_string=base_doc,
node_.copy_with(
callabl=df_generator,
tags=node_.tags.copy(),
)
]

Expand Down Expand Up @@ -553,7 +548,6 @@ def expand_node(
fn = node_.callable
base_doc = node_.documentation

@functools.wraps(fn)
def dict_generator(*args, **kwargs):
dict_generated = fn(*args, **kwargs)
if self.fill_with is not None:
Expand All @@ -562,15 +556,7 @@ def dict_generator(*args, **kwargs):
dict_generated[field] = self.fill_with
return dict_generated

output_nodes = [
node.Node(
node_.name,
typ=dict,
doc_string=base_doc,
callabl=dict_generator,
tags=node_.tags.copy(),
)
]
output_nodes = [node_.copy_with(callabl=dict_generator)]

for field, field_type in self.fields.items():
doc_string = base_doc # default doc string of base function.
Expand Down Expand Up @@ -744,7 +730,8 @@ def generate_node(self, fn: Callable, config) -> node.Node:
and the same parameters/types as the original function.
"""

def replacing_function(__fn=fn, **kwargs):
# @functools.wraps(fn)
elijahbenizzy marked this conversation as resolved.
Show resolved Hide resolved
def wrapper_function(**kwargs):
final_kwarg_values = {
key: param_spec.default
for key, param_spec in inspect.signature(fn).parameters.items()
Expand All @@ -754,7 +741,7 @@ def replacing_function(__fn=fn, **kwargs):
final_kwarg_values = does.map_kwargs(final_kwarg_values, self.argument_mapping)
return self.replacing_function(**final_kwarg_values)

return node.Node.from_fn(fn).copy_with(callabl=replacing_function)
return node.Node.from_fn(fn).copy_with(callabl=wrapper_function)


class dynamic_transform(function_modifiers_base.NodeCreator):
Expand Down
37 changes: 37 additions & 0 deletions tests/resources/multiple_decorators_together.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pandas as pd

from hamilton.function_modifiers import does, extract_columns, extract_fields, tag


def _sum_multiply(param0: int, param1: int, param2: int) -> pd.DataFrame:
return pd.DataFrame([{"param0a": param0, "param1b": param1, "param2c": param2}])


def _sum(param0: int, param1: int, param2: int) -> dict:
return {"total": param0 + param1 + param2}


@extract_columns("param1b")
@does(_sum_multiply)
def to_modify(param0: int, param1: int, param2: int = 2) -> pd.DataFrame:
"""This is a dummy function showing extract_columns with does."""
pass


@extract_fields({"total": int})
@does(_sum)
def to_modify_2(param0: int, param1: int, param2: int = 2) -> dict:
"""This is a dummy function showing extract_fields with does."""
pass


def _dummy(**values) -> dict:
return {f"out_{k.split('_')[1]}": v for k, v in values.items()}


@extract_fields({"out_value1": int, "out_value2": str})
@tag(test_key="test-value")
# @check_output(data_type=dict, importance="fail") To fix see https://github.com/stitchfix/hamilton/issues/249
@does(_dummy)
def uber_decorated_function(in_value1: int, in_value2: str) -> dict:
pass
50 changes: 50 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import tests.resources.extract_columns_execution_count
import tests.resources.functions_with_generics
import tests.resources.layered_decorators
import tests.resources.multiple_decorators_together
import tests.resources.optional_dependencies
import tests.resources.parametrized_inputs
import tests.resources.parametrized_nodes
Expand Down Expand Up @@ -359,6 +360,55 @@ def test_end_to_end_with_column_extractor_nodes():
)


def test_end_to_end_with_multiple_decorators():
"""Tests that a simple function graph with multiple decorators on a function works end-to-end"""
fg = graph.FunctionGraph(
tests.resources.multiple_decorators_together,
config={"param0": 3, "param1": 1, "in_value1": 42, "in_value2": "string_value"},
)
nodes = fg.get_nodes()
# To help debug issues:
# nodez, user_nodes = fg.get_upstream_nodes([n.name for n in nodes],
# {"param0": 3, "param1": 1,
# "in_value1": 42, "in_value2": "string_value"})
# fg.display(
# nodez,
# user_nodes,
# "all_multiple_decorators",
# render_kwargs=None,
# graphviz_kwargs=None,
# )
results = fg.execute(nodes, {}, {})
df_expected = tests.resources.multiple_decorators_together._sum_multiply(3, 1, 2)
dict_expected = tests.resources.multiple_decorators_together._sum(3, 1, 2)
pd.testing.assert_series_equal(results["param1b"], df_expected["param1b"])
pd.testing.assert_frame_equal(results["to_modify"], df_expected)
assert results["total"] == dict_expected["total"]
assert results["to_modify_2"] == dict_expected
node_dict = {n.name: n for n in nodes}
print(sorted(list(node_dict.keys())))
assert (
node_dict["to_modify"].documentation
== "This is a dummy function showing extract_columns with does."
)
assert (
node_dict["to_modify_2"].documentation
== "This is a dummy function showing extract_fields with does."
)
# tag only applies right now to outer most node layer
assert node_dict["uber_decorated_function"].tags == {
"module": "tests.resources.multiple_decorators_together"
} # tags are not propagated
assert node_dict["out_value1"].tags == {
"module": "tests.resources.multiple_decorators_together",
"test_key": "test-value",
}
assert node_dict["out_value2"].tags == {
"module": "tests.resources.multiple_decorators_together",
"test_key": "test-value",
}


def test_end_to_end_with_config_modifier():
config = {
"fn_1_version": 1,
Expand Down