Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Source Salesforce: fix pagination in REST API streams #9151

Merged
merged 10 commits into from
Jan 18, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,29 @@ def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
def url_base(self) -> str:
return self.sf_api.instance_url

def path(self, **kwargs) -> str:
def path(self, next_page_token: Mapping[str, Any] = None, **kwargs) -> str:
if next_page_token:
"""
If `next_page_token` is set, subsequent requests use `nextRecordsUrl`.
"""
return next_page_token
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`nextRecordsUrl` is a relative URL, so we can return it here instead of the original URL.

return f"/services/data/{self.sf_api.version}/queryAll"

def next_page_token(self, response: requests.Response) -> str:
response_data = response.json()
if len(response_data["records"]) == self.page_size and self.primary_key and self.name not in UNSUPPORTED_FILTERING_STREAMS:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main issue was here: `self.page_size` is 2000, but the actual number of records per page was not always 2000, so this method returned None and only the first page was read in that case.

return f"WHERE {self.primary_key} >= '{response_data['records'][-1][self.primary_key]}' "
return response_data.get("nextRecordsUrl")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vitaliizazmic
If a REST API query result has more than one page, the URL of the next page is present in the response body as `nextRecordsUrl`.


def request_params(
self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, any] = None, next_page_token: Mapping[str, Any] = None
) -> MutableMapping[str, Any]:
"""
Salesforce SOQL Query: https://developer.salesforce.com/docs/atlas.en-us.232.0.api_rest.meta/api_rest/dome_queryall.htm
"""
if next_page_token:
"""
If `next_page_token` is set, subsequent requests use `nextRecordsUrl`, and do not include any parameters.
"""
return {}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we request the second and subsequent pages, we don't need to send any params; we just use `nextRecordsUrl` for the API call.


selected_properties = self.get_json_schema().get("properties", {})

Expand All @@ -70,11 +79,9 @@ def request_params(
}

query = f"SELECT {','.join(selected_properties.keys())} FROM {self.name} "
if next_page_token:
query += next_page_token

if self.primary_key and self.name not in UNSUPPORTED_FILTERING_STREAMS:
query += f"ORDER BY {self.primary_key} ASC LIMIT {self.page_size}"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, in the REST API we don't need to LIMIT the result set, because the limit does not work for all streams.
I noticed that when we limit by 2000 (page_size), some streams always return 1000 records per page, while others return 465 per page.

query += f"ORDER BY {self.primary_key} ASC"

return {"q": query}

Expand Down Expand Up @@ -259,6 +266,32 @@ def next_page_token(self, last_record: dict) -> str:
if self.primary_key and self.name not in UNSUPPORTED_FILTERING_STREAMS:
return f"WHERE {self.primary_key} >= '{last_record[self.primary_key]}' "

def request_params(
Copy link
Contributor Author

@augan-rymkhan augan-rymkhan Jan 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before the change, request_params was inherited from SalesforceStream. In this PR, the parent's method was changed to follow the new pagination approach, so I added this method here so as not to break the BULK API functionality.

self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, any] = None, next_page_token: Mapping[str, Any] = None
) -> MutableMapping[str, Any]:
"""
Salesforce SOQL Query: https://developer.salesforce.com/docs/atlas.en-us.232.0.api_rest.meta/api_rest/dome_queryall.htm
"""

selected_properties = self.get_json_schema().get("properties", {})

# Salesforce BULK API currently does not support loading fields with data type base64 and compound data
if self.sf_api.api_type == "BULK":
selected_properties = {
key: value
for key, value in selected_properties.items()
if value.get("format") != "base64" and "object" not in value["type"]
}

query = f"SELECT {','.join(selected_properties.keys())} FROM {self.name} "
if next_page_token:
query += next_page_token

if self.primary_key and self.name not in UNSUPPORTED_FILTERING_STREAMS:
query += f"ORDER BY {self.primary_key} ASC LIMIT {self.page_size}"

return {"q": query}

def read_records(
self,
sync_mode: SyncMode,
Expand Down Expand Up @@ -305,14 +338,15 @@ def format_start_date(start_date: Optional[str]) -> Optional[str]:
if start_date:
return pendulum.parse(start_date).strftime("%Y-%m-%dT%H:%M:%SZ")

def next_page_token(self, response: requests.Response) -> str:
response_data = response.json()
if len(response_data["records"]) == self.page_size and self.name not in UNSUPPORTED_FILTERING_STREAMS:
return response_data["records"][-1][self.cursor_field]

def request_params(
self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, any] = None, next_page_token: Mapping[str, Any] = None
) -> MutableMapping[str, Any]:
if next_page_token:
"""
If `next_page_token` is set, subsequent requests use `nextRecordsUrl`, and do not include any parameters.
"""
return {}

selected_properties = self.get_json_schema().get("properties", {})

# Salesforce BULK API currently does not support loading fields with data type base64 and compound data
Expand All @@ -324,13 +358,13 @@ def request_params(
}

stream_date = stream_state.get(self.cursor_field)
start_date = next_page_token or stream_date or self.start_date
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need next_page_token here, because next_page_token is now a relative URL for the next chunk of results.

start_date = stream_date or self.start_date

query = f"SELECT {','.join(selected_properties.keys())} FROM {self.name} "
if start_date:
query += f"WHERE {self.cursor_field} >= {start_date} "
if self.name not in UNSUPPORTED_FILTERING_STREAMS:
query += f"ORDER BY {self.cursor_field} ASC LIMIT {self.page_size}"
query += f"ORDER BY {self.cursor_field} ASC"
return {"q": query}

@property
Expand All @@ -352,3 +386,26 @@ class BulkIncrementalSalesforceStream(BulkSalesforceStream, IncrementalSalesforc
def next_page_token(self, last_record: dict) -> str:
    """Return the cursor-field value of `last_record` as the next page token.

    Streams whose name is listed in UNSUPPORTED_FILTERING_STREAMS get no
    token (None is returned).
    """
    if self.name in UNSUPPORTED_FILTERING_STREAMS:
        return None
    return last_record[self.cursor_field]

def request_params(
Copy link
Contributor Author

@augan-rymkhan augan-rymkhan Jan 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before the change, this method was inherited from IncrementalSalesforceStream. In this PR, request_params is overridden so as not to break existing functionality, because the parent's methods were changed.

self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, any] = None, next_page_token: Mapping[str, Any] = None
) -> MutableMapping[str, Any]:
selected_properties = self.get_json_schema().get("properties", {})

# Salesforce BULK API currently does not support loading fields with data type base64 and compound data
if self.sf_api.api_type == "BULK":
selected_properties = {
key: value
for key, value in selected_properties.items()
if value.get("format") != "base64" and "object" not in value["type"]
}

stream_date = stream_state.get(self.cursor_field)
start_date = next_page_token or stream_date or self.start_date

query = f"SELECT {','.join(selected_properties.keys())} FROM {self.name} "
if start_date:
query += f"WHERE {self.cursor_field} >= {start_date} "
if self.name not in UNSUPPORTED_FILTERING_STREAMS:
query += f"ORDER BY {self.cursor_field} ASC LIMIT {self.page_size}"
return {"q": query}
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,45 @@ def test_discover_with_streams_criteria_param(streams_criteria, predicted_filter
)
filtered_streams = sf_object.get_validated_streams(config=updated_config)
assert sorted(filtered_streams) == sorted(predicted_filtered_streams)


def test_pagination_rest(stream_rest_config, stream_rest_api):
    """Check that a REST stream reads the records of every page.

    The first mocked response advertises a follow-up page through
    `nextRecordsUrl`; the second response carries no such key. The test
    expects the records of both pages (four in total) to be read.
    """
    stream: SalesforceStream = _generate_stream("Account", stream_rest_config, stream_rest_api)
    stream._wait_timeout = 0.1  # maximum wait timeout will be 6 seconds
    next_page_url = "/services/data/v52.0/query/012345"

    # First page: not done yet, points at the next page.
    first_page = {
        "done": False,
        "totalSize": 4,
        "nextRecordsUrl": next_page_url,
        "records": [
            {
                "ID": 1,
                "LastModifiedDate": "2021-11-15",
            },
            {
                "ID": 2,
                "LastModifiedDate": "2021-11-16",
            },
        ],
    }
    # Last page: done, no `nextRecordsUrl` key.
    last_page = {
        "done": True,
        "totalSize": 4,
        "records": [
            {
                "ID": 3,
                "LastModifiedDate": "2021-11-17",
            },
            {
                "ID": 4,
                "LastModifiedDate": "2021-11-18",
            },
        ],
    }

    with requests_mock.Mocker() as m:
        m.register_uri("GET", stream.path(), json=first_page)
        m.register_uri("GET", next_page_url, json=last_page)

        records = list(stream.read_records(sync_mode=SyncMode.full_refresh))
        assert len(records) == 4