[ISSUE #20771] adding slices to connector builder read request (#21605)
* [ISSUE #20771] adding slices to connector builder read request

* [ISSUE #20771] formatting

* [ISSUE #20771] set flag when limit requests reached (#21619)

* [ISSUE #20771] set flag when limit requests reached

* [ISSUE #20771] assert proper value on test read objects __init__

* [ISSUE #20771] code review and fix edge case
maxi297 authored Jan 20, 2023
1 parent 6631698 commit cf63ee5
Showing 16 changed files with 417 additions and 85 deletions.
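
The commit replaces the single is_test_read boolean on ModelToComponentFactory with two explicit limits, limit_pages_fetched_per_slice and limit_slices_fetched, and threads them into the test-read decorators. As orientation before the per-file diffs, here is a minimal sketch of how a caller (for example the connector builder's test-read path, which is not part of this excerpt) might construct the factory; the import path and the helper function are assumptions, while the constructor arguments come from the diffs below.

from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ModelToComponentFactory


def build_test_read_factory(max_pages_per_slice: int = 5, max_slices: int = 5) -> ModelToComponentFactory:
    # Sketch only, not part of this commit. With limit_pages_fetched_per_slice set,
    # create_default_paginator wraps its paginator in PaginatorTestReadDecorator; with
    # limit_slices_fetched set, create_simple_retriever builds a SimpleRetrieverTestReadDecorator.
    return ModelToComponentFactory(
        limit_pages_fetched_per_slice=max_pages_per_slice,
        limit_slices_fetched=max_slices,
    )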
9 changes: 7 additions & 2 deletions airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
@@ -2,17 +2,20 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union

from airbyte_cdk.models import (
AirbyteCatalog,
AirbyteConnectionStatus,
AirbyteLogMessage,
AirbyteMessage,
AirbyteStateMessage,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
Level,
Status,
SyncMode,
)
@@ -232,7 +235,8 @@ def _read_incremental(
has_slices = False
for _slice in slices:
has_slices = True
logger.debug("Processing stream slice", extra={"slice": _slice})
if logger.isEnabledFor(logging.DEBUG):
yield AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"slice:{json.dumps(_slice)}"))
records = stream_instance.read_records(
sync_mode=SyncMode.incremental,
stream_slice=_slice,
@@ -281,7 +285,8 @@ def _read_full_refresh(
)
total_records_counter = 0
for _slice in slices:
logger.debug("Processing stream slice", extra={"slice": _slice})
if logger.isEnabledFor(logging.DEBUG):
yield AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"slice:{json.dumps(_slice)}"))
record_data_or_messages = stream_instance.read_records(
stream_slice=_slice,
sync_mode=SyncMode.full_refresh,
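
With a debug-enabled logger, each slice processed by _read_incremental and _read_full_refresh above is now surfaced as a LOG message whose text is the JSON-serialized slice behind a "slice:" prefix. Below is a minimal sketch, not part of the commit, of how a consumer of the read output (such as the connector builder backend) could recover those slices; the helper name is hypothetical, while the message format matches the code above and the tests further down.

import json
from typing import Any, Iterable, List, Mapping

from airbyte_cdk.models import AirbyteMessage
from airbyte_cdk.models import Type as MessageType


def extract_slices(messages: Iterable[AirbyteMessage]) -> List[Mapping[str, Any]]:
    # Hypothetical helper: keep only LOG messages carrying a slice payload and parse them.
    slices = []
    for message in messages:
        if message.type == MessageType.LOG and message.log.message.startswith("slice:"):
            slices.append(json.loads(message.log.message[len("slice:"):]))
    return slices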
@@ -116,9 +116,10 @@


class ModelToComponentFactory:
def __init__(self, is_test_read=False):
def __init__(self, limit_pages_fetched_per_slice: int = None, limit_slices_fetched: int = None):
self._init_mappings()
self._is_test_read = is_test_read
self._limit_pages_fetched_per_slice = limit_pages_fetched_per_slice
self._limit_slices_fetched = limit_slices_fetched

def _init_mappings(self):
self.PYDANTIC_MODEL_TO_CONSTRUCTOR: [Type[BaseModel], Callable] = {
@@ -482,8 +483,8 @@ def create_default_paginator(self, model: DefaultPaginatorModel, config: Config,
config=config,
options=model.options,
)
if self._is_test_read:
return PaginatorTestReadDecorator(paginator)
if self._limit_pages_fetched_per_slice:
return PaginatorTestReadDecorator(paginator, self._limit_pages_fetched_per_slice)
return paginator

def create_dpath_extractor(self, model: DpathExtractorModel, config: Config, **kwargs) -> DpathExtractor:
@@ -681,7 +682,7 @@ def create_simple_retriever(self, model: SimpleRetrieverModel, config: Config, *
self._create_component_from_model(model=model.stream_slicer, config=config) if model.stream_slicer else SingleSlice(options={})
)

if self._is_test_read:
if self._limit_slices_fetched:
return SimpleRetrieverTestReadDecorator(
name=model.name,
paginator=paginator,
@@ -690,6 +691,7 @@ def create_simple_retriever(self, model: SimpleRetrieverModel, config: Config, *
record_selector=record_selector,
stream_slicer=stream_slicer,
config=config,
maximum_number_of_slices=self._limit_slices_fetched,
options=model.options,
)
return SimpleRetriever(
@@ -172,9 +172,11 @@ class PaginatorTestReadDecorator(Paginator):
_DEFAULT_PAGINATION_LIMIT = 5

def __init__(self, decorated, maximum_number_of_pages: int = None):
if maximum_number_of_pages and maximum_number_of_pages < 1:
raise ValueError(f"The maximum number of pages on a test read needs to be strictly positive. Got {maximum_number_of_pages}")
self._maximum_number_of_pages = maximum_number_of_pages if maximum_number_of_pages else self._DEFAULT_PAGINATION_LIMIT
self._decorated = decorated
self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL
self._maximum_number_of_pages = maximum_number_of_pages if maximum_number_of_pages else self._DEFAULT_PAGINATION_LIMIT

def next_page_token(self, response: requests.Response, last_records: List[Mapping[str, Any]]) -> Optional[Mapping[str, Any]]:
if self._page_count >= self._maximum_number_of_pages:
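
The decorator above delegates to the wrapped paginator but stops returning next-page tokens once the configured number of pages has been fetched, falling back to _DEFAULT_PAGINATION_LIMIT (5) when no limit is given. A short usage sketch, not part of the commit; the import path is assumed because the file header is not captured in this excerpt, and the wrapped paginator is mocked to keep the example self-contained.

from unittest.mock import MagicMock

# Assumed import path for the decorator shown in the diff above.
from airbyte_cdk.sources.declarative.requesters.paginators.default_paginator import PaginatorTestReadDecorator

inner_paginator = MagicMock()  # stands in for whatever paginator the factory normally builds
limited = PaginatorTestReadDecorator(inner_paginator, maximum_number_of_pages=2)

# A negative limit is rejected up front; None (or 0, which is falsy) falls back to the default of 5.
try:
    PaginatorTestReadDecorator(inner_paginator, maximum_number_of_pages=-1)
except ValueError as error:
    print(error)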
@@ -417,18 +417,26 @@ def _parse_records_and_emit_request_and_responses(self, request, response, strea
yield from self.parse_response(response, stream_slice=stream_slice, stream_state=stream_state)


@dataclass
class SimpleRetrieverTestReadDecorator(SimpleRetriever):
"""
In some cases, we want to limit the number of requests that are made to the backend source. This class allows for limiting the number of
slices that are queried throughout a read command.
"""

_MAXIMUM_NUMBER_OF_SLICES = 5
maximum_number_of_slices: int = 5

def __post_init__(self, options: Mapping[str, Any]):
super().__post_init__(options)
if self.maximum_number_of_slices and self.maximum_number_of_slices < 1:
raise ValueError(
f"The maximum number of slices on a test read needs to be strictly positive. Got {self.maximum_number_of_slices}"
)

def stream_slices(
self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Optional[StreamState] = None
) -> Iterable[Optional[Mapping[str, Any]]]:
return islice(super().stream_slices(sync_mode=sync_mode, stream_state=stream_state), self._MAXIMUM_NUMBER_OF_SLICES)
return islice(super().stream_slices(sync_mode=sync_mode, stream_state=stream_state), self.maximum_number_of_slices)


def _prepared_request_to_airbyte_message(request: requests.PreparedRequest) -> AirbyteMessage:
@@ -633,19 +633,25 @@ def test_response_to_airbyte_message(test_name, response_body, response_headers,


def test_limit_stream_slices():
maximum_number_of_slices = 4
stream_slicer = MagicMock()
stream_slicer.stream_slices.return_value = [{"date": f"2022-01-0{day}"} for day in range(1, 10)]
stream_slicer.stream_slices.return_value = _generate_slices(maximum_number_of_slices * 2)
retriever = SimpleRetrieverTestReadDecorator(
name="stream_name",
primary_key=primary_key,
requester=MagicMock(),
paginator=MagicMock(),
record_selector=MagicMock(),
stream_slicer=stream_slicer,
maximum_number_of_slices=maximum_number_of_slices,
options={},
config={},
)

truncated_slices = retriever.stream_slices(sync_mode=SyncMode.incremental, stream_state=None)
truncated_slices = list(retriever.stream_slices(sync_mode=SyncMode.incremental, stream_state=None))

assert truncated_slices == [{"date": f"2022-01-0{day}"} for day in range(1, 6)]
assert truncated_slices == _generate_slices(maximum_number_of_slices)


def _generate_slices(number_of_slices):
return [{"date": f"2022-01-0{day + 1}"} for day in range(number_of_slices)]
@@ -12,6 +12,7 @@
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
from jsonschema.exceptions import ValidationError
from unittest.mock import patch

logger = logging.getLogger("airbyte")

@@ -542,6 +543,95 @@ def test_manifest_without_at_least_one_stream(self, construct_using_pydantic_mod
ManifestDeclarativeSource(source_config=manifest, construct_using_pydantic_models=construct_using_pydantic_models)


@patch("airbyte_cdk.sources.declarative.declarative_source.DeclarativeSource.read")
def test_given_debug_when_read_then_set_log_level(self, declarative_source_read):
any_valid_manifest = {
"version": "version",
"definitions": {
"schema_loader": {"name": "{{ options.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ options.name }}.yaml"},
"retriever": {
"paginator": {
"type": "DefaultPaginator",
"page_size": 10,
"page_size_option": {"inject_into": "request_parameter", "field_name": "page_size"},
"page_token_option": {"inject_into": "path"},
"pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"},
},
"requester": {
"path": "/v3/marketing/lists",
"authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"},
"request_parameters": {"page_size": 10},
},
"record_selector": {"extractor": {"field_pointer": ["result"]}},
},
},
"streams": [
{
"type": "DeclarativeStream",
"$options": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"},
"schema_loader": {
"name": "{{ options.stream_name }}",
"file_path": "./source_sendgrid/schemas/{{ options.name }}.yaml",
},
"retriever": {
"paginator": {
"type": "DefaultPaginator",
"page_size": 10,
"page_size_option": {"inject_into": "request_parameter", "field_name": "page_size"},
"page_token_option": {"inject_into": "path"},
"pagination_strategy": {
"type": "CursorPagination",
"cursor_value": "{{ response._metadata.next }}",
"page_size": 10,
},
},
"requester": {
"path": "/v3/marketing/lists",
"authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"},
"request_parameters": {"page_size": 10},
},
"record_selector": {"extractor": {"field_pointer": ["result"]}},
},
},
{
"type": "DeclarativeStream",
"$options": {"name": "stream_with_custom_requester", "primary_key": "id", "url_base": "https://api.sendgrid.com"},
"schema_loader": {
"name": "{{ options.stream_name }}",
"file_path": "./source_sendgrid/schemas/{{ options.name }}.yaml",
},
"retriever": {
"paginator": {
"type": "DefaultPaginator",
"page_size": 10,
"page_size_option": {"inject_into": "request_parameter", "field_name": "page_size"},
"page_token_option": {"inject_into": "path"},
"pagination_strategy": {
"type": "CursorPagination",
"cursor_value": "{{ response._metadata.next }}",
"page_size": 10,
},
},
"requester": {
"type": "CustomRequester",
"class_name": "unit_tests.sources.declarative.external_component.SampleCustomComponent",
"path": "/v3/marketing/lists",
"custom_request_parameters": {"page_size": 10},
},
"record_selector": {"extractor": {"field_pointer": ["result"]}},
},
},
],
"check": {"type": "CheckStream", "stream_names": ["lists"]},
}
source = ManifestDeclarativeSource(source_config=any_valid_manifest, debug=True, construct_using_pydantic_models=True)

debug_logger = logging.getLogger("logger.debug")
list(source.read(debug_logger, {}, {}, {}))

assert debug_logger.isEnabledFor(logging.DEBUG)


def test_generate_schema():
schema_str = ManifestDeclarativeSource.generate_schema()
schema = json.loads(schema_str)
51 changes: 51 additions & 0 deletions airbyte-cdk/python/unit_tests/sources/test_abstract_source.py
@@ -331,6 +331,57 @@ def test_valid_full_refresh_read_with_slices(mocker):
assert expected == messages


def test_read_full_refresh_with_slices_sends_slice_messages(mocker):
"""Given the logger is debug and a full refresh, AirbyteMessages are sent for slices"""
debug_logger = logging.getLogger("airbyte.debug")
debug_logger.setLevel(logging.DEBUG)
slices = [{"1": "1"}, {"2": "2"}]
stream = MockStream(
[({"sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices],
name="s1",
)

mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)

src = MockSource(streams=[stream])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream, SyncMode.full_refresh),
]
)

messages = src.read(debug_logger, {}, catalog)

assert 2 == len(list(filter(lambda message: message.log and message.log.message.startswith("slice:"), messages)))


def test_read_incremental_with_slices_sends_slice_messages(mocker):
"""Given the logger is debug and a incremental, AirbyteMessages are sent for slices"""
debug_logger = logging.getLogger("airbyte.debug")
debug_logger.setLevel(logging.DEBUG)
slices = [{"1": "1"}, {"2": "2"}]
stream = MockStream(
[({"sync_mode": SyncMode.incremental, "stream_slice": s, 'stream_state': {}}, [s]) for s in slices],
name="s1",
)

MockStream.supports_incremental = mocker.PropertyMock(return_value=True)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)

src = MockSource(streams=[stream])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream, SyncMode.incremental),
]
)

messages = src.read(debug_logger, {}, catalog)

assert 2 == len(list(filter(lambda message: message.log and message.log.message.startswith("slice:"), messages)))


class TestIncrementalRead:
@pytest.mark.parametrize(
"use_legacy",
22 changes: 13 additions & 9 deletions airbyte-cdk/python/unit_tests/sources/test_source.py
@@ -390,6 +390,7 @@ def test_internal_config_limit(abstract_source, catalog):
logger_mock.level = logging.DEBUG
del catalog.streams[1]
STREAM_LIMIT = 2
SLICE_DEBUG_LOG_COUNT = 1
FULL_RECORDS_NUMBER = 3
streams = abstract_source.streams(None)
http_stream = streams[0]
@@ -398,7 +399,7 @@

catalog.streams[0].sync_mode = SyncMode.full_refresh
records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
assert len(records) == STREAM_LIMIT
assert len(records) == STREAM_LIMIT + SLICE_DEBUG_LOG_COUNT
logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
# Check if log line matches number of limit
read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
@@ -407,13 +408,13 @@
# No limit, check if state record produced for incremental stream
catalog.streams[0].sync_mode = SyncMode.incremental
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == FULL_RECORDS_NUMBER + 1
assert len(records) == FULL_RECORDS_NUMBER + SLICE_DEBUG_LOG_COUNT + 1
assert records[-1].type == Type.STATE

# Set limit and check if state is produced when limit is set for incremental stream
logger_mock.reset_mock()
records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
assert len(records) == STREAM_LIMIT + 1
assert len(records) == STREAM_LIMIT + SLICE_DEBUG_LOG_COUNT + 1
assert records[-1].type == Type.STATE
logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
@@ -425,40 +426,43 @@

def test_source_config_no_transform(abstract_source, catalog):
logger_mock = MagicMock()
SLICE_DEBUG_LOG_COUNT = 1
logger_mock.level = logging.DEBUG
streams = abstract_source.streams(None)
http_stream, non_http_stream = streams
http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
http_stream.read_records.return_value, non_http_stream.read_records.return_value = [[{"value": 23}] * 5] * 2
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2 * 5
assert [r.record.data for r in records] == [{"value": 23}] * 2 * 5
assert len(records) == 2 * (5 + SLICE_DEBUG_LOG_COUNT)
assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": 23}] * 2 * 5
assert http_stream.get_json_schema.call_count == 5
assert non_http_stream.get_json_schema.call_count == 5


def test_source_config_transform(abstract_source, catalog):
logger_mock = MagicMock()
logger_mock.level = logging.DEBUG
SLICE_DEBUG_LOG_COUNT = 2
streams = abstract_source.streams(None)
http_stream, non_http_stream = streams
http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
non_http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}]
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2
assert [r.record.data for r in records] == [{"value": "23"}] * 2
assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT
assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": "23"}] * 2


def test_source_config_transform_and_no_transform(abstract_source, catalog):
logger_mock = MagicMock()
logger_mock.level = logging.DEBUG
SLICE_DEBUG_LOG_COUNT = 2
streams = abstract_source.streams(None)
http_stream, non_http_stream = streams
http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}]
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2
assert [r.record.data for r in records] == [{"value": "23"}, {"value": 23}]
assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT
assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": "23"}, {"value": 23}]
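
All of the count adjustments in this file follow from the same fact: with a debug-enabled logger, every slice that is read adds one slice LOG message to the output. Roughly, for each stream in these fixtures:

# Sketch of the arithmetic behind the updated assertions (not in this commit):
#   messages per stream == records + slices logged (+ 1 trailing STATE message for incremental)
# which is why each expected count gains a SLICE_DEBUG_LOG_COUNT term.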