From f914452914b41a3128d65a521c79f65981eeb822 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= <16805946+edgarrmondragon@users.noreply.github.com> Date: Tue, 9 Jan 2024 13:33:30 -0600 Subject: [PATCH] feat(taps): End RESTStream pagination if an empty page is received (#1918) feat: Auto-detect end of pagination --- docs/guides/pagination-classes.md | 5 -- singer_sdk/pagination.py | 25 --------- singer_sdk/streams/rest.py | 15 +++++- tests/core/rest/test_pagination.py | 87 ++++++++++++++++++++++-------- 4 files changed, 80 insertions(+), 52 deletions(-) diff --git a/docs/guides/pagination-classes.md b/docs/guides/pagination-classes.md index 897712423..c77ef579e 100644 --- a/docs/guides/pagination-classes.md +++ b/docs/guides/pagination-classes.md @@ -83,11 +83,6 @@ class can be used to handle this pattern. from singer_sdk.pagination import BaseOffsetPaginator -class MyPaginator(BaseOffsetPaginator): - def has_more(self, response): - data = response.json() - return data.get("has_more", False) - class MyStream(RESTStream): def get_new_paginator(self): diff --git a/singer_sdk/pagination.py b/singer_sdk/pagination.py index 238740768..5bf55ca4c 100644 --- a/singer_sdk/pagination.py +++ b/singer_sdk/pagination.py @@ -337,19 +337,6 @@ def get_next(self, response: Response) -> str | None: class BasePageNumberPaginator(BaseAPIPaginator[int], metaclass=ABCMeta): """Paginator class for APIs that use page number.""" - @abstractmethod - def has_more(self, response: Response) -> bool: - """Override this method to check if the endpoint has any pages left. - - Args: - response: API response object. - - Returns: - Boolean flag used to indicate if the endpoint has more pages. - - """ - ... - def get_next(self, response: Response) -> int | None: # noqa: ARG002 """Get the next page number. @@ -383,18 +370,6 @@ def __init__( super().__init__(start_value, *args, **kwargs) self._page_size = page_size - @abstractmethod - def has_more(self, response: Response) -> bool: - """Override this method to check if the endpoint has any pages left. - - Args: - response: API response object. - - Returns: - Boolean flag used to indicate if the endpoint has more pages. - """ - ... - def get_next(self, response: Response) -> int | None: # noqa: ARG002 """Get the next page offset. diff --git a/singer_sdk/streams/rest.py b/singer_sdk/streams/rest.py index f8dbeadc9..a9cfa568a 100644 --- a/singer_sdk/streams/rest.py +++ b/singer_sdk/streams/rest.py @@ -384,6 +384,7 @@ def request_records(self, context: dict | None) -> t.Iterable[dict]: """ paginator = self.get_new_paginator() decorated_request = self.request_decorator(self._request) + pages = 0 with metrics.http_request_counter(self.name, self.path) as request_counter: request_counter.context = context @@ -396,7 +397,19 @@ def request_records(self, context: dict | None) -> t.Iterable[dict]: resp = decorated_request(prepared_request, context) request_counter.increment() self.update_sync_costs(prepared_request, resp, context) - yield from self.parse_response(resp) + records = iter(self.parse_response(resp)) + try: + first_record = next(records) + except StopIteration: + self.logger.info( + "Pagination stopped after %d pages because no records were " + "found in the last response", + pages, + ) + break + yield first_record + yield from records + pages += 1 paginator.advance(resp) diff --git a/tests/core/rest/test_pagination.py b/tests/core/rest/test_pagination.py index 09e9d04b2..6ef34f139 100644 --- a/tests/core/rest/test_pagination.py +++ b/tests/core/rest/test_pagination.py @@ -4,9 +4,10 @@ import json import typing as t +from urllib.parse import parse_qs, urlparse import pytest -from requests import Response +from requests import PreparedRequest, Response from singer_sdk.helpers.jsonpath import extract_jsonpath from singer_sdk.pagination import ( @@ -20,6 +21,10 @@ SinglePagePaginator, first, ) +from singer_sdk.streams.rest import RESTStream + +if t.TYPE_CHECKING: + from singer_sdk.tap_base import Tap def test_paginator_base_missing_implementation(): @@ -47,26 +52,6 @@ def test_single_page_paginator(): assert paginator.count == 1 -def test_paginator_page_number_missing_implementation(): - """Validate that `BasePageNumberPaginator` implementation requires `has_more`.""" - - with pytest.raises( - TypeError, - match="Can't instantiate abstract class .* '?has_more'?", - ): - BasePageNumberPaginator(1) - - -def test_paginator_offset_missing_implementation(): - """Validate that `BaseOffsetPaginator` implementation requires `has_more`.""" - - with pytest.raises( - TypeError, - match="Can't instantiate abstract class .* '?has_more'?", - ): - BaseOffsetPaginator(0, 100) - - def test_paginator_hateoas_missing_implementation(): """Validate that `BaseHATEOASPaginator` implementation requires `get_next_url`.""" @@ -352,3 +337,63 @@ def get_next_url(self, response: Response) -> str | None: paginator.advance(response) assert paginator.finished assert paginator.count == 3 + + +def test_break_pagination(tap: Tap, caplog: pytest.LogCaptureFixture): + class MyAPIStream(RESTStream[int]): + """My API stream.""" + + name = "my-api-stream" + url_base = "https://my.api.test" + path = "/path/to/resource" + schema = {"type": "object", "properties": {"id": {"type": "integer"}}} # noqa: RUF012 + + def parse_response(self, response: Response) -> t.Iterable[dict]: + return response.json() + + def get_new_paginator(self) -> BasePageNumberPaginator: + return BasePageNumberPaginator(1) + + def get_url_params( + self, + context: dict | None, # noqa: ARG002 + next_page_token: int | None, + ) -> dict[str, t.Any] | str: + params = {} + if next_page_token: + params["page"] = next_page_token + return params + + def _request( + self, + prepared_request: PreparedRequest, + context: dict | None, # noqa: ARG002 + ) -> Response: + r = Response() + r.status_code = 200 + + parsed = urlparse(prepared_request.url) + query = parse_qs(parsed.query) + + if query.get("page", ["1"]) == ["1"]: + r._content = json.dumps( + [ + {"id": 1}, + {"id": 2}, + ] + ).encode() + else: + r._content = json.dumps([]).encode() + + return r + + stream = MyAPIStream(tap=tap) + + records_iter = stream.request_records(context=None) + + next(records_iter) + next(records_iter) + with pytest.raises(StopIteration): + next(records_iter) + + assert "Pagination stopped after 1 pages" in caplog.text