Skip to content

Commit

Permalink
Low-code: Pass stream_slice to read_records when reading from CheckSt…
Browse files Browse the repository at this point in the history
…ream (#17804)

* Implement a test

* Implement fix

* rename

* extract method

* bump
  • Loading branch information
girarda authored Oct 17, 2022
1 parent 3f2af7d commit df72bbd
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 8 deletions.
4 changes: 4 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 0.1.100

- Low-code: Pass stream_slice to read_records when reading from CheckStream

## 0.1.99

- Low-code: Fix default stream schema loader
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,22 @@ def check_connection(self, source: Source, logger: logging.Logger, config: Mappi
if stream_name in stream_name_to_stream.keys():
stream = stream_name_to_stream[stream_name]
try:
records = stream.read_records(sync_mode=SyncMode.full_refresh)
# Some streams need a stream slice to read records (eg if they have a SubstreamSlicer)
stream_slice = self._get_stream_slice(stream)
records = stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice)
next(records)
except Exception as error:
return False, f"Unable to connect to stream {stream_name} - {error}"
else:
raise ValueError(f"{stream_name} is not part of the catalog. Expected one of {stream_name_to_stream.keys()}")
return True, None

def _get_stream_slice(self, stream):
slices = stream.stream_slices(
cursor_field=stream.cursor_field,
sync_mode=SyncMode.full_refresh,
)
try:
return next(slices)
except StopIteration:
return {}
2 changes: 1 addition & 1 deletion airbyte-cdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name="airbyte-cdk",
version="0.1.99",
version="0.1.100",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@


@pytest.mark.parametrize(
"test_name, record, streams_to_check, expectation",
"test_name, record, streams_to_check, stream_slice, expectation",
[
("test success check", record, stream_names, (True, None)),
("test fail check", None, stream_names, (True, None)),
("test try to check invalid stream", record, ["invalid_stream_name"], None),
("test_success_check", record, stream_names, {}, (True, None)),
("test_success_check_stream_slice", record, stream_names, {"slice": "slice_value"}, (True, None)),
("test_fail_check", None, stream_names, {}, (True, None)),
("test_try_to_check_invalid stream", record, ["invalid_stream_name"], {}, None),
],
)
def test_check_stream(test_name, record, streams_to_check, expectation):
def test_check_stream(test_name, record, streams_to_check, stream_slice, expectation):
stream = MagicMock()
stream.name = "s1"
stream.read_records.return_value = iter([record])
stream.stream_slices.return_value = iter([stream_slice])
stream.read_records.side_effect = mock_read_records({frozenset(stream_slice): iter([record])})

source = MagicMock()
source.streams.return_value = [stream]
Expand All @@ -38,3 +40,7 @@ def test_check_stream(test_name, record, streams_to_check, expectation):
else:
with pytest.raises(ValueError):
check_stream.check_connection(source, logger, config)


def mock_read_records(responses, default_response=None, **kwargs):
return lambda stream_slice, sync_mode: responses[frozenset(stream_slice)] if frozenset(stream_slice) in responses else default_response

0 comments on commit df72bbd

Please sign in to comment.