Skip to content

Commit

Permalink
Add OutputAdapter sede for custom filters (#6985)
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje authored Feb 13, 2024
1 parent ea72759 commit 6a776e6
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
9 changes: 7 additions & 2 deletions haystack/components/converters/output_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing_extensions import TypeAlias

from haystack import component, default_to_dict, default_from_dict
from haystack.utils.callable_serialization import serialize_callable, deserialize_callable
from haystack.utils.type_serialization import serialize_type, deserialize_type


Expand Down Expand Up @@ -124,13 +125,17 @@ def run(self, **kwargs):
return adapted_outputs

def to_dict(self) -> Dict[str, Any]:
# todo should we serialize the custom filters? And if so, can we do the same as for callback handlers?
return default_to_dict(self, template=self.template, output_type=serialize_type(self.output_type))
se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
return default_to_dict(
self, template=self.template, output_type=serialize_type(self.output_type), custom_filters=se_filters
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OutputAdapter":
init_params = data.get("init_parameters", {})
init_params["output_type"] = deserialize_type(init_params["output_type"])
for name, filter_func in init_params.get("custom_filters", {}).items():
init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None
return default_from_dict(cls, data)

def _extract_variables(self, env: NativeEnvironment) -> Set[str]:
Expand Down
42 changes: 42 additions & 0 deletions test/components/converters/test_output_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
from haystack.components.converters.output_adapter import OutputAdaptationException


def custom_filter_to_sede(value):
return value.upper()


def another_custom_filter(value):
return value.upper()


class TestOutputAdapter:
# OutputAdapter can be initialized with a valid Jinja2 template string and output type.
def test_initialized_with_valid_template_and_output_type(self):
Expand Down Expand Up @@ -84,6 +92,40 @@ def test_sede(self):
assert adapter.template == deserialized_adapter.template
assert adapter.output_type == deserialized_adapter.output_type

# OutputAdapter can be serialized to a dictionary and deserialized along with custom filters
def test_sede_with_custom_filters(self):
# NOTE: filters need to be declared in a namespace visible to the deserialization function
custom_filters = {"custom_filter": custom_filter_to_sede}
adapter = OutputAdapter(
template="{{ documents[0].content|custom_filter }}", output_type=str, custom_filters=custom_filters
)
adapter_dict = adapter.to_dict()
deserialized_adapter = OutputAdapter.from_dict(adapter_dict)

assert adapter.template == deserialized_adapter.template
assert adapter.output_type == deserialized_adapter.output_type
assert adapter.custom_filters == deserialized_adapter.custom_filters == custom_filters

# invoke the custom filter to check if it is deserialized correctly
assert deserialized_adapter.custom_filters["custom_filter"]("test") == "TEST"

# OutputAdapter can be serialized to a dictionary and deserialized along with multiple custom filters
def test_sede_with_multiple_custom_filters(self):
# NOTE: filters need to be declared in a namespace visible to the deserialization function
custom_filters = {"custom_filter": custom_filter_to_sede, "another_custom_filter": another_custom_filter}
adapter = OutputAdapter(
template="{{ documents[0].content|custom_filter }}", output_type=str, custom_filters=custom_filters
)
adapter_dict = adapter.to_dict()
deserialized_adapter = OutputAdapter.from_dict(adapter_dict)

assert adapter.template == deserialized_adapter.template
assert adapter.output_type == deserialized_adapter.output_type
assert adapter.custom_filters == deserialized_adapter.custom_filters == custom_filters

# invoke the custom filter to check if it is deserialized correctly
assert deserialized_adapter.custom_filters["custom_filter"]("test") == "TEST"

def test_output_adapter_in_pipeline(self):
@component
class DocumentProducer:
Expand Down

0 comments on commit 6a776e6

Please sign in to comment.