Skip to content

Commit

Permalink
feat: Add ByteStream to_string method (#7009)
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje authored Feb 17, 2024
1 parent 3f85a63 commit 3ce6b97
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
10 changes: 10 additions & 0 deletions haystack/dataclasses/byte_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,13 @@ def from_string(
:param meta: Additional metadata to be stored with the ByteStream.
"""
return cls(data=text.encode(encoding), mime_type=mime_type, meta=meta or {})

def to_string(self, encoding: str = "utf-8") -> str:
"""
Convert the ByteStream to a string, metadata will not be included.
:param encoding: The encoding used to convert the bytes to a string. Defaults to "utf-8".
:return: The string representation of the ByteStream.
:raises UnicodeDecodeError: If the ByteStream data cannot be decoded with the specified encoding.
"""
return self.data.decode(encoding)
26 changes: 26 additions & 0 deletions test/dataclasses/test_byte_stream.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from haystack.dataclasses import ByteStream


Expand Down Expand Up @@ -35,6 +37,30 @@ def test_from_string():
assert b.meta == {"foo": "bar"}


def test_to_string():
test_string = "Hello, world!"
b = ByteStream.from_string(test_string)
assert b.to_string() == test_string


def test_to_from_string_encoding():
test_string = "Hello Baščaršija!"
with pytest.raises(UnicodeEncodeError):
ByteStream.from_string(test_string, encoding="ISO-8859-1")

bs = ByteStream.from_string(test_string) # default encoding is utf-8

assert bs.to_string(encoding="ISO-8859-1") != test_string
assert bs.to_string(encoding="utf-8") == test_string


def test_to_string_encoding_error():
# test that it raises ValueError if the encoding is not valid
b = ByteStream.from_string("Hello, world!")
with pytest.raises(UnicodeDecodeError):
b.to_string("utf-16")


def test_to_file(tmp_path, request):
test_str = "Hello, world!\n"
test_path = tmp_path / request.node.name
Expand Down

0 comments on commit 3ce6b97

Please sign in to comment.