Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(taps): End RESTStream pagination if an empty page is received #1918

Merged
merged 1 commit into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions docs/guides/pagination-classes.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,6 @@ class can be used to handle this pattern.

from singer_sdk.pagination import BaseOffsetPaginator

class MyPaginator(BaseOffsetPaginator):
def has_more(self, response):
data = response.json()
return data.get("has_more", False)


class MyStream(RESTStream):
def get_new_paginator(self):
Expand Down
25 changes: 0 additions & 25 deletions singer_sdk/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,19 +337,6 @@ def get_next(self, response: Response) -> str | None:
class BasePageNumberPaginator(BaseAPIPaginator[int], metaclass=ABCMeta):
"""Paginator class for APIs that use page number."""

@abstractmethod
def has_more(self, response: Response) -> bool:
"""Override this method to check if the endpoint has any pages left.

Args:
response: API response object.

Returns:
Boolean flag used to indicate if the endpoint has more pages.

"""
...

def get_next(self, response: Response) -> int | None: # noqa: ARG002
"""Get the next page number.

Expand Down Expand Up @@ -383,18 +370,6 @@ def __init__(
super().__init__(start_value, *args, **kwargs)
self._page_size = page_size

@abstractmethod
def has_more(self, response: Response) -> bool:
"""Override this method to check if the endpoint has any pages left.

Args:
response: API response object.

Returns:
Boolean flag used to indicate if the endpoint has more pages.
"""
...

def get_next(self, response: Response) -> int | None: # noqa: ARG002
"""Get the next page offset.

Expand Down
15 changes: 14 additions & 1 deletion singer_sdk/streams/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def request_records(self, context: dict | None) -> t.Iterable[dict]:
"""
paginator = self.get_new_paginator()
decorated_request = self.request_decorator(self._request)
pages = 0

with metrics.http_request_counter(self.name, self.path) as request_counter:
request_counter.context = context
Expand All @@ -396,7 +397,19 @@ def request_records(self, context: dict | None) -> t.Iterable[dict]:
resp = decorated_request(prepared_request, context)
request_counter.increment()
self.update_sync_costs(prepared_request, resp, context)
yield from self.parse_response(resp)
records = iter(self.parse_response(resp))
try:
first_record = next(records)
except StopIteration:
self.logger.info(
"Pagination stopped after %d pages because no records were "
"found in the last response",
pages,
)
break
yield first_record
yield from records
pages += 1

paginator.advance(resp)

Expand Down
87 changes: 66 additions & 21 deletions tests/core/rest/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import json
import typing as t
from urllib.parse import parse_qs, urlparse

import pytest
from requests import Response
from requests import PreparedRequest, Response

from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import (
Expand All @@ -20,6 +21,10 @@
SinglePagePaginator,
first,
)
from singer_sdk.streams.rest import RESTStream

if t.TYPE_CHECKING:
from singer_sdk.tap_base import Tap


def test_paginator_base_missing_implementation():
Expand Down Expand Up @@ -47,26 +52,6 @@ def test_single_page_paginator():
assert paginator.count == 1


def test_paginator_page_number_missing_implementation():
"""Validate that `BasePageNumberPaginator` implementation requires `has_more`."""

with pytest.raises(
TypeError,
match="Can't instantiate abstract class .* '?has_more'?",
):
BasePageNumberPaginator(1)


def test_paginator_offset_missing_implementation():
"""Validate that `BaseOffsetPaginator` implementation requires `has_more`."""

with pytest.raises(
TypeError,
match="Can't instantiate abstract class .* '?has_more'?",
):
BaseOffsetPaginator(0, 100)


def test_paginator_hateoas_missing_implementation():
"""Validate that `BaseHATEOASPaginator` implementation requires `get_next_url`."""

Expand Down Expand Up @@ -352,3 +337,63 @@ def get_next_url(self, response: Response) -> str | None:
paginator.advance(response)
assert paginator.finished
assert paginator.count == 3


def test_break_pagination(tap: Tap, caplog: pytest.LogCaptureFixture):
class MyAPIStream(RESTStream[int]):
"""My API stream."""

name = "my-api-stream"
url_base = "https://my.api.test"
path = "/path/to/resource"
schema = {"type": "object", "properties": {"id": {"type": "integer"}}} # noqa: RUF012

def parse_response(self, response: Response) -> t.Iterable[dict]:
return response.json()

def get_new_paginator(self) -> BasePageNumberPaginator:
return BasePageNumberPaginator(1)

def get_url_params(
self,
context: dict | None, # noqa: ARG002
next_page_token: int | None,
) -> dict[str, t.Any] | str:
params = {}
if next_page_token:
params["page"] = next_page_token
return params

def _request(
self,
prepared_request: PreparedRequest,
context: dict | None, # noqa: ARG002
) -> Response:
r = Response()
r.status_code = 200

parsed = urlparse(prepared_request.url)
query = parse_qs(parsed.query)

if query.get("page", ["1"]) == ["1"]:
r._content = json.dumps(
[
{"id": 1},
{"id": 2},
]
).encode()
else:
r._content = json.dumps([]).encode()

return r

stream = MyAPIStream(tap=tap)

records_iter = stream.request_records(context=None)

next(records_iter)
next(records_iter)
with pytest.raises(StopIteration):
next(records_iter)

assert "Pagination stopped after 1 pages" in caplog.text