diff --git a/tests/integration/io/parquet/test_reads_public_data.py b/tests/integration/io/parquet/test_reads_public_data.py index 713c56d821..bbac19dc1c 100644 --- a/tests/integration/io/parquet/test_reads_public_data.py +++ b/tests/integration/io/parquet/test_reads_public_data.py @@ -7,6 +7,7 @@ from pyarrow import parquet as pq import daft +from daft.exceptions import ConnectTimeoutError, ReadTimeoutError from daft.filesystem import get_filesystem, get_protocol_from_path from daft.table import MicroPartition, Table @@ -413,7 +414,7 @@ def test_connect_timeout(multithreaded_io): ) ) - with pytest.raises(ValueError, match="HTTP connect timeout"): + with pytest.raises((ReadTimeoutError, ConnectTimeoutError), match=f"timed out when trying to connect to {url}"): MicroPartition.read_parquet(url, io_config=connect_timeout_config, multithreaded_io=multithreaded_io).to_arrow() @@ -434,5 +435,5 @@ def test_read_timeout(multithreaded_io): ) ) - with pytest.raises(ValueError, match="HTTP read timeout"): + with pytest.raises((ReadTimeoutError, ConnectTimeoutError), match=f"timed out when trying to connect to {url}"): MicroPartition.read_parquet(url, io_config=read_timeout_config, multithreaded_io=multithreaded_io).to_arrow() diff --git a/tests/integration/io/test_url_download_public_aws_s3.py b/tests/integration/io/test_url_download_public_aws_s3.py index c4ba79d16e..19b7e55b5e 100644 --- a/tests/integration/io/test_url_download_public_aws_s3.py +++ b/tests/integration/io/test_url_download_public_aws_s3.py @@ -3,6 +3,7 @@ import pytest import daft +from daft.exceptions import ConnectTimeoutError, ReadTimeoutError @pytest.mark.integration() @@ -81,7 +82,7 @@ def test_url_download_aws_s3_public_bucket_native_downloader_with_connect_timeou ) ) - with pytest.raises(ValueError, match="HTTP connect timeout"): + with pytest.raises((ReadTimeoutError, ConnectTimeoutError), match="timed out when trying to connect to"): df = df.with_column( "data", df["urls"].url.download(io_config=connect_timeout_config, use_native_downloader=True) ).collect() @@ -101,7 +102,7 @@ def test_url_download_aws_s3_public_bucket_native_downloader_with_read_timeout(s ) ) - with pytest.raises(ValueError, match="HTTP read timeout"): + with pytest.raises((ReadTimeoutError, ConnectTimeoutError), match="timed out when trying to connect to"): df = df.with_column( "data", df["urls"].url.download(io_config=read_timeout_config, use_native_downloader=True) ).collect()