diff --git a/haystack/components/converters/output_adapter.py b/haystack/components/converters/output_adapter.py index f6ed4c9ca7..ea7efac8e4 100644 --- a/haystack/components/converters/output_adapter.py +++ b/haystack/components/converters/output_adapter.py @@ -6,6 +6,7 @@ from typing_extensions import TypeAlias from haystack import component, default_to_dict, default_from_dict +from haystack.utils.callable_serialization import serialize_callable, deserialize_callable from haystack.utils.type_serialization import serialize_type, deserialize_type @@ -124,13 +125,17 @@ def run(self, **kwargs): return adapted_outputs def to_dict(self) -> Dict[str, Any]: - # todo should we serialize the custom filters? And if so, can we do the same as for callback handlers? - return default_to_dict(self, template=self.template, output_type=serialize_type(self.output_type)) + se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()} + return default_to_dict( + self, template=self.template, output_type=serialize_type(self.output_type), custom_filters=se_filters + ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "OutputAdapter": init_params = data.get("init_parameters", {}) init_params["output_type"] = deserialize_type(init_params["output_type"]) + for name, filter_func in init_params.get("custom_filters", {}).items(): + init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None return default_from_dict(cls, data) def _extract_variables(self, env: NativeEnvironment) -> Set[str]: diff --git a/test/components/converters/test_output_adapter.py b/test/components/converters/test_output_adapter.py index 795cd30d97..7e07c9d58d 100644 --- a/test/components/converters/test_output_adapter.py +++ b/test/components/converters/test_output_adapter.py @@ -7,6 +7,14 @@ from haystack.components.converters.output_adapter import OutputAdaptationException +def custom_filter_to_sede(value): + return value.upper() + + +def another_custom_filter(value): + return value.upper() + + class TestOutputAdapter: # OutputAdapter can be initialized with a valid Jinja2 template string and output type. def test_initialized_with_valid_template_and_output_type(self): @@ -84,6 +92,40 @@ def test_sede(self): assert adapter.template == deserialized_adapter.template assert adapter.output_type == deserialized_adapter.output_type + # OutputAdapter can be serialized to a dictionary and deserialized along with custom filters + def test_sede_with_custom_filters(self): + # NOTE: filters need to be declared in a namespace visible to the deserialization function + custom_filters = {"custom_filter": custom_filter_to_sede} + adapter = OutputAdapter( + template="{{ documents[0].content|custom_filter }}", output_type=str, custom_filters=custom_filters + ) + adapter_dict = adapter.to_dict() + deserialized_adapter = OutputAdapter.from_dict(adapter_dict) + + assert adapter.template == deserialized_adapter.template + assert adapter.output_type == deserialized_adapter.output_type + assert adapter.custom_filters == deserialized_adapter.custom_filters == custom_filters + + # invoke the custom filter to check if it is deserialized correctly + assert deserialized_adapter.custom_filters["custom_filter"]("test") == "TEST" + + # OutputAdapter can be serialized to a dictionary and deserialized along with multiple custom filters + def test_sede_with_multiple_custom_filters(self): + # NOTE: filters need to be declared in a namespace visible to the deserialization function + custom_filters = {"custom_filter": custom_filter_to_sede, "another_custom_filter": another_custom_filter} + adapter = OutputAdapter( + template="{{ documents[0].content|custom_filter }}", output_type=str, custom_filters=custom_filters + ) + adapter_dict = adapter.to_dict() + deserialized_adapter = OutputAdapter.from_dict(adapter_dict) + + assert adapter.template == deserialized_adapter.template + assert adapter.output_type == deserialized_adapter.output_type + assert adapter.custom_filters == deserialized_adapter.custom_filters == custom_filters + + # invoke the custom filter to check if it is deserialized correctly + assert deserialized_adapter.custom_filters["custom_filter"]("test") == "TEST" + def test_output_adapter_in_pipeline(self): @component class DocumentProducer: