Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(taps): Add api costs hook #704

Merged
merged 14 commits into from
Jun 21, 2022
15 changes: 15 additions & 0 deletions singer_sdk/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class Stream(metaclass=abc.ABCMeta):
parent_stream_type: Optional[Type["Stream"]] = None
ignore_parent_replication_key: bool = False

# Internal API cost aggregator
_sync_costs: Dict[str, int] = {}

def __init__(
self,
tap: TapBaseClass,
Expand Down Expand Up @@ -864,6 +867,18 @@ def _write_request_duration_log(
extra_tags["context"] = context
self._write_metric_log(metric=request_duration_metric, extra_tags=extra_tags)

def log_sync_costs(self) -> None:
    """Log a one-line summary of the costs accumulated for this stream.

    The totals are read from `self._sync_costs`; nothing is logged while
    that mapping is empty. Override this method to report the totals in
    a custom format.
    """
    # Guard clause: streams that never recorded a cost stay silent.
    if not self._sync_costs:
        return

    summary = f"Total Sync costs for stream {self.name}: {self._sync_costs}"
    self.logger.info(summary)

def _check_max_record_limit(self, record_count: int) -> None:
"""TODO.

Expand Down
56 changes: 56 additions & 0 deletions singer_sdk/streams/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def request_records(self, context: Optional[dict]) -> Iterable[dict]:
context, next_page_token=next_page_token
)
resp = decorated_request(prepared_request, context)
self.update_sync_costs(prepared_request, resp, context)
yield from self.parse_response(resp)
previous_token = copy.deepcopy(next_page_token)
next_page_token = self.get_next_page_token(
Expand All @@ -343,8 +344,63 @@ def request_records(self, context: Optional[dict]) -> Iterable[dict]:
# Cycle until get_next_page_token() no longer returns a value
finished = not next_page_token

def update_sync_costs(
    self,
    request: requests.PreparedRequest,
    response: requests.Response,
    context: Optional[Dict],
) -> Dict[str, int]:
    """Add the cost of one API call to the stream's running totals.

    The per-call costs come from `calculate_sync_cost` and are summed,
    per cost domain, into `self._sync_costs`. Domains that were recorded
    by earlier calls but are absent from this call keep their total.

    Args:
        request: the Request object that was just called.
        response: the `requests.Response` object
        context: the context passed to the call

    Returns:
        The accumulated costs so far (not just this request's), keyed
        by "cost domain". See `calculate_sync_cost` for details.
    """
    call_costs = self.calculate_sync_cost(request, response, context)

    # Work on a copy and rebind: `_sync_costs` starts out as a dict
    # declared at class level, so mutating it in place would leak costs
    # between stream instances.
    totals = dict(self._sync_costs)
    for domain, cost in call_costs.items():
        totals[domain] = totals.get(domain, 0) + cost

    self._sync_costs = totals
    return self._sync_costs

# Overridable:

def calculate_sync_cost(
    self,
    request: requests.PreparedRequest,
    response: requests.Response,
    context: Optional[Dict],
) -> Dict[str, int]:
    """Return the cost of the API call that was just made.

    Hook for tap developers: override it in a stream to report what a
    single API/network call cost, in whatever units make sense for the
    tap. The request, the response and the context of the call are all
    passed in.

    The result maps arbitrary "cost domains" to the cost of this one
    call in that domain, for instance { "rest": 0, "graphql": 42 } for
    a call to github's graphql API. Every domain should be present in
    every returned dict.

    This default implementation reports no costs.

    Args:
        request: the API Request object that was just called.
        response: the `requests.Response` object
        context: the context passed to the call

    Returns:
        A dict of costs for this single call, keyed by "cost domain".
    """
    no_costs: Dict[str, int] = {}
    return no_costs

def prepare_request_payload(
self, context: Optional[dict], next_page_token: Optional[_TToken]
) -> Optional[dict]:
Expand Down
1 change: 1 addition & 0 deletions singer_sdk/tap_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def sync_all(self) -> None:

stream.sync()
stream.finalize_state_progress_markers()
stream.log_sync_costs()

# Command Line Execution

Expand Down
30 changes: 30 additions & 0 deletions tests/core/test_streams.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Stream tests."""

import logging
from typing import Any, Dict, Iterable, List, Optional, cast

import pendulum
Expand Down Expand Up @@ -357,3 +358,32 @@ def test_cached_jsonpath():

# cached objects should point to the same memory location
assert recompiled is compiled


def test_sync_costs_calculation(tap: SimpleTestTap, caplog):
    """Check that per-call costs are summed and the summary is logged once."""
    stream = RestTestStream(tap)

    def calculate_test_cost(
        request: requests.PreparedRequest,
        response: requests.Response,
        context: Optional[Dict],
    ):
        return {"dim1": 1, "dim2": 2}

    # Replace the hook on this instance only, with a fixed per-call cost.
    stream.calculate_sync_cost = calculate_test_cost

    fake_request = requests.PreparedRequest()
    fake_response = requests.Response()

    # Two identical calls should double every cost dimension.
    for _ in range(2):
        stream.update_sync_costs(fake_request, fake_response, None)
    assert stream._sync_costs == {"dim1": 2, "dim2": 4}

    with caplog.at_level(logging.INFO, logger=tap.name):
        stream.log_sync_costs()

    assert len(caplog.records) == 1

    (record,) = caplog.records
    assert record.levelname == "INFO"
    assert f"Total Sync costs for stream {stream.name}" in record.message