From 3ce6b9768ec1e449ea2f7c07717f276b011b4b3f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Sat, 17 Feb 2024 12:57:42 +0100 Subject: [PATCH] feat: Add ByteStream to_string method (#7009) --- haystack/dataclasses/byte_stream.py | 10 ++++++++++ test/dataclasses/test_byte_stream.py | 26 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/haystack/dataclasses/byte_stream.py b/haystack/dataclasses/byte_stream.py index 80b1c50c3b..ee736c001d 100644 --- a/haystack/dataclasses/byte_stream.py +++ b/haystack/dataclasses/byte_stream.py @@ -49,3 +49,13 @@ def from_string( :param meta: Additional metadata to be stored with the ByteStream. """ return cls(data=text.encode(encoding), mime_type=mime_type, meta=meta or {}) + + def to_string(self, encoding: str = "utf-8") -> str: + """ + Convert the ByteStream to a string, metadata will not be included. + + :param encoding: The encoding used to convert the bytes to a string. Defaults to "utf-8". + :return: The string representation of the ByteStream. + :raises UnicodeDecodeError: If the ByteStream data cannot be decoded with the specified encoding. + """ + return self.data.decode(encoding) diff --git a/test/dataclasses/test_byte_stream.py b/test/dataclasses/test_byte_stream.py index 57d444b038..4e4199ba19 100644 --- a/test/dataclasses/test_byte_stream.py +++ b/test/dataclasses/test_byte_stream.py @@ -1,3 +1,5 @@ +import pytest + from haystack.dataclasses import ByteStream @@ -35,6 +37,30 @@ def test_from_string(): assert b.meta == {"foo": "bar"} +def test_to_string(): + test_string = "Hello, world!" + b = ByteStream.from_string(test_string) + assert b.to_string() == test_string + + +def test_to_from_string_encoding(): + test_string = "Hello Baščaršija!" + with pytest.raises(UnicodeEncodeError): + ByteStream.from_string(test_string, encoding="ISO-8859-1") + + bs = ByteStream.from_string(test_string) # default encoding is utf-8 + + assert bs.to_string(encoding="ISO-8859-1") != test_string + assert bs.to_string(encoding="utf-8") == test_string + + +def test_to_string_encoding_error(): + # test that it raises ValueError if the encoding is not valid + b = ByteStream.from_string("Hello, world!") + with pytest.raises(UnicodeDecodeError): + b.to_string("utf-16") + + def test_to_file(tmp_path, request): test_str = "Hello, world!\n" test_path = tmp_path / request.node.name