Skip to content

Commit

Permalink
fix tests to use new exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Apr 30, 2024
1 parent b92c515 commit 54cdaeb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions tests/integration/io/parquet/test_reads_public_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pyarrow import parquet as pq

import daft
from daft.exceptions import ConnectTimeoutError, ReadTimeoutError
from daft.filesystem import get_filesystem, get_protocol_from_path
from daft.table import MicroPartition, Table

Expand Down Expand Up @@ -413,7 +414,7 @@ def test_connect_timeout(multithreaded_io):
)
)

with pytest.raises(ValueError, match="HTTP connect timeout"):
with pytest.raises((ReadTimeoutError, ConnectTimeoutError), match=f"timed out when trying to connect to {url}"):
MicroPartition.read_parquet(url, io_config=connect_timeout_config, multithreaded_io=multithreaded_io).to_arrow()


Expand All @@ -434,5 +435,5 @@ def test_read_timeout(multithreaded_io):
)
)

with pytest.raises(ValueError, match="HTTP read timeout"):
with pytest.raises((ReadTimeoutError, ConnectTimeoutError), match=f"timed out when trying to connect to {url}"):
MicroPartition.read_parquet(url, io_config=read_timeout_config, multithreaded_io=multithreaded_io).to_arrow()
5 changes: 3 additions & 2 deletions tests/integration/io/test_url_download_public_aws_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import daft
from daft.exceptions import ConnectTimeoutError, ReadTimeoutError


@pytest.mark.integration()
Expand Down Expand Up @@ -81,7 +82,7 @@ def test_url_download_aws_s3_public_bucket_native_downloader_with_connect_timeou
)
)

with pytest.raises(ValueError, match="HTTP connect timeout"):
with pytest.raises((ReadTimeoutError, ConnectTimeoutError), match="timed out when trying to connect to"):
df = df.with_column(
"data", df["urls"].url.download(io_config=connect_timeout_config, use_native_downloader=True)
).collect()
Expand All @@ -101,7 +102,7 @@ def test_url_download_aws_s3_public_bucket_native_downloader_with_read_timeout(s
)
)

with pytest.raises(ValueError, match="HTTP read timeout"):
with pytest.raises((ReadTimeoutError, ConnectTimeoutError), match="timed out when trying to connect to"):
df = df.with_column(
"data", df["urls"].url.download(io_config=read_timeout_config, use_native_downloader=True)
).collect()

0 comments on commit 54cdaeb

Please sign in to comment.