Skip to content

Commit

Permalink
Improvements to edge cases of CheckStream (#21404)
Browse files Browse the repository at this point in the history
* Add test for failure case

* Except StopIteration - make test pass

* Don't attempt to connect to a stream if we get no stream slices

* Make helper method for getting first record for a slice

* Add comments and exit early if stream to check isn't in list of source streams

* move helpers to helper module

* Clarify what it means when StopIteration is returned by helper methods
  • Loading branch information
erohmensing authored Jan 13, 2023
1 parent 55a4715 commit d378294
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
#

import logging
import traceback
from dataclasses import InitVar, dataclass
from typing import Any, List, Mapping, Tuple

from airbyte_cdk.models.airbyte_protocol import SyncMode
from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker
from airbyte_cdk.sources.source import Source
from airbyte_cdk.sources.streams.utils.stream_helper import get_first_record_for_slice, get_first_stream_slice
from dataclasses_jsonschema import JsonSchemaMixin


Expand All @@ -28,34 +29,33 @@ def __post_init__(self, options: Mapping[str, Any]):
self._options = options

def check_connection(self, source: Source, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Any]:
    """
    Check configuration parameters for a source by attempting to get the first record for each stream
    in the CheckStream's `stream_names` list.

    :param source: source whose streams are looked up with `source.streams(config)`
    :param logger: logger used to report empty streams and connection errors
    :param config: configuration passed through to `source.streams`
    :raises ValueError: if a name in `stream_names` is not among the streams returned by the source
    :return: `(True, None)` when every listed stream could be read (an empty stream counts as readable),
        otherwise `(False, reason)` for the first stream that could not be checked
    """
    streams = source.streams(config)
    stream_name_to_stream = {s.name: s for s in streams}
    if len(streams) == 0:
        return False, f"No streams to connect to from source {source}"

    for stream_name in self.stream_names:
        if stream_name not in stream_name_to_stream:
            raise ValueError(f"{stream_name} is not part of the catalog. Expected one of {stream_name_to_stream.keys()}.")
        stream = stream_name_to_stream[stream_name]

        try:
            # Some streams need a stream slice to read records (e.g. if they have a SubstreamSlicer)
            # Streams that don't need a stream slice will return `None` as their first stream slice.
            stream_slice = get_first_stream_slice(stream)
        except StopIteration:
            # If stream_slices has no `next()` item (Note - this is different from stream_slices returning [None]!)
            # This can happen when a substream's `stream_slices` method does a `for record in parent_records: yield <something>`
            # without accounting for the case in which the parent stream is empty.
            reason = f"Cannot attempt to connect to stream {stream_name} - no stream slices were found, likely because the parent stream is empty."
            return False, reason
        except Exception as error:
            # Computing slices can itself fail (e.g. while reading a parent stream); report it as a failed
            # check instead of letting the exception escape.
            logger.error(f"Encountered an error trying to connect to stream {stream.name}. Error: \n {traceback.format_exc()}")
            return False, f"Unable to connect to stream {stream_name} - {error}"

        try:
            get_first_record_for_slice(stream, stream_slice)
        except StopIteration:
            # The read succeeded but produced nothing: the stream is reachable, just empty.
            logger.info(f"Successfully connected to stream {stream.name}, but got 0 records.")
        except Exception as error:
            logger.error(f"Encountered an error trying to connect to stream {stream.name}. Error: \n {traceback.format_exc()}")
            return False, f"Unable to connect to stream {stream_name} - {error}"

    # Only report success once every requested stream has been checked.
    return True, None
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from typing import Any, Mapping, Optional

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.streams.core import Stream, StreamData


def get_first_stream_slice(stream) -> Optional[Mapping[str, Any]]:
    """
    Returns the first stream slice a stream produces for a full refresh.
    :param stream: stream whose `stream_slices` is queried
    :raises StopIteration: if `stream_slices` produces no slices at all
    :return: the first slice; `None` is a valid slice and is returned unchanged
    """
    all_slices = stream.stream_slices(cursor_field=stream.cursor_field, sync_mode=SyncMode.full_refresh)
    # stream_slices() may return a plain iterable (such as a list or tuple) rather than an iterator,
    # so it is normalised with iter() before next() is called on it.
    slice_iterator = iter(all_slices)
    first_slice = next(slice_iterator)
    return first_slice


def get_first_record_for_slice(stream: Stream, stream_slice: Optional[Mapping[str, Any]]) -> StreamData:
    """
    Gets the first record for a stream_slice of a stream.
    :param stream: stream
    :param stream_slice: stream_slice
    :raises StopIteration: if there is no first record to return (the read_records output is empty)
    :return: StreamData containing the first record in the slice
    """
    # We wrap the return output of read_records() because some implementations return types that are iterable,
    # but not iterators such as lists or tuples; calling next() on those directly would raise TypeError.
    records_for_slice = iter(stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice))
    return next(records_for_slice)
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

import logging
from unittest.mock import MagicMock

import pytest
from airbyte_cdk.sources.declarative.checks.check_stream import CheckStream

# Module-level fixtures shared by the tests in this file.
logger = logging.getLogger("test")
config = {}

stream_names = ["s1"]
Expand Down Expand Up @@ -49,3 +50,31 @@ def test_check_stream_with_slices_as_list(test_name, record, streams_to_check, s

def mock_read_records(responses, default_response=None, **kwargs):
    """
    Build a stand-in for a stream's `read_records` that answers from a lookup table.

    :param responses: mapping from `frozenset` of a slice's keys to the value to return for that slice
    :param default_response: value returned when the slice's key set is not in `responses`
    :param kwargs: accepted for call-site compatibility and ignored
    :return: callable taking `stream_slice` and `sync_mode` and returning the matching response
    """

    def _read_records(stream_slice, sync_mode):
        # `None` is a valid stream slice; treat it like a slice with no keys instead of letting
        # frozenset(None) raise TypeError.
        slice_key = frozenset(stream_slice or ())
        return responses[slice_key] if slice_key in responses else default_response

    return _read_records


def test_check_empty_stream():
    """A stream that yields one (None) slice but no records is still reported as available."""
    stream = MagicMock()
    stream.name = "s1"
    # One slice (None is a valid slice) whose read produces no records at all.
    stream.read_records.return_value = iter([])
    stream.stream_slices.return_value = iter([None])

    source = MagicMock()
    source.streams.return_value = [stream]

    check_stream = CheckStream(["s1"], options={})
    stream_is_available, reason = check_stream.check_connection(source, logger, config)
    assert stream_is_available
    # A successful check must not carry a failure reason.
    assert reason is None


def test_check_stream_with_no_stream_slices_aborts():
    """A stream whose `stream_slices` yields nothing makes the check fail with an explanatory reason."""
    sliceless_stream = MagicMock()
    sliceless_stream.name = "s1"
    # No slices at all - not even a single None slice.
    sliceless_stream.stream_slices.return_value = iter([])

    mock_source = MagicMock()
    mock_source.streams.return_value = [sliceless_stream]

    checker = CheckStream(["s1"], options={})
    is_available, failure_reason = checker.check_connection(mock_source, logger, config)

    assert not is_available
    assert "no stream slices were found, likely because the parent stream is empty" in failure_reason

0 comments on commit d378294

Please sign in to comment.