Skip to content

Commit

Permalink
CDK: VCR -> requests_cache + SQLite (#17990)
Browse files Browse the repository at this point in the history
Signed-off-by: Sergey Chvalyuk <[email protected]>
  • Loading branch information
grubberr authored Oct 19, 2022
1 parent e21d54f commit 258b23c
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 40 deletions.
4 changes: 4 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 0.2.0

- Replace caching method: VCR.py -> requests-cache with SQLite backend

## 0.1.104

- Protocol change: `supported_sync_modes` is now a required properties on AirbyteStream. [#15591](https://github.com/airbytehq/airbyte/pull/15591)
Expand Down
52 changes: 20 additions & 32 deletions airbyte-cdk/python/airbyte_cdk/sources/streams/http/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
import logging
import os
from abc import ABC, abstractmethod
from contextlib import suppress
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
from urllib.parse import urljoin

import requests
import vcr
import vcr.cassette as Cassette
import requests_cache
from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.streams.core import Stream
from requests.auth import AuthBase
from requests_cache.session import CachedSession

from .auth.core import HttpAuthenticator, NoAuth
from .exceptions import DefaultBackoffException, RequestBodyException, UserDefinedBackoffException
Expand All @@ -23,8 +24,6 @@
# list of all possible HTTP methods which can be used for sending of request bodies
BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH")

logging.getLogger("vcr").setLevel(logging.ERROR)


class HttpStream(Stream, ABC):
"""
Expand All @@ -36,25 +35,23 @@ class HttpStream(Stream, ABC):

# TODO: remove legacy HttpAuthenticator authenticator references
def __init__(self, authenticator: Union[AuthBase, HttpAuthenticator] = None):
self._session = requests.Session()
if self.use_cache:
self._session = self.request_cache()
else:
self._session = requests.Session()

self._authenticator: HttpAuthenticator = NoAuth()
if isinstance(authenticator, AuthBase):
self._session.auth = authenticator
elif authenticator:
self._authenticator = authenticator

if self.use_cache:
self.cache_file = self.request_cache()
# we need this attr to get metadata about cassettes, such as record play count, all records played, etc.
self.cassete = None

@property
def cache_filename(self):
"""
Override if needed. Return the name of cache file
"""
return f"{self.name}.yml"
return f"{self.name}.sqlite"

@property
def use_cache(self):
Expand All @@ -63,19 +60,19 @@ def use_cache(self):
"""
return False

def request_cache(self) -> Cassette:
def request_cache(self) -> CachedSession:
self.clear_cache()
return requests_cache.CachedSession(self.cache_filename)

def clear_cache(self):
"""
Builds VCR instance.
It deletes file everytime we create it, normally should be called only once.
We can't use NamedTemporaryFile here because yaml serializer doesn't work well with empty files.
remove cache file only once
"""

try:
os.remove(self.cache_filename)
except FileNotFoundError:
pass

return vcr.use_cassette(self.cache_filename, record_mode="new_episodes", serializer="yaml")
STREAM_CACHE_FILES = globals().setdefault("STREAM_CACHE_FILES", set())
if self.cache_filename not in STREAM_CACHE_FILES:
with suppress(FileNotFoundError):
os.remove(self.cache_filename)
STREAM_CACHE_FILES.add(self.cache_filename)

@property
@abstractmethod
Expand Down Expand Up @@ -415,16 +412,7 @@ def read_records(
)
request_kwargs = self.request_kwargs(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)

if self.use_cache:
# use context manager to handle and store cassette metadata
with self.cache_file as cass:
self.cassete = cass
# vcr tries to find records based on the request, if such records exist, return from cache file
# else make a request and save record in cache file
response = self._send_request(request, request_kwargs)

else:
response = self._send_request(request, request_kwargs)
response = self._send_request(request, request_kwargs)
yield from self.parse_response(response, stream_state=stream_state, stream_slice=stream_slice)

next_page_token = self.next_page_token(response)
Expand Down
4 changes: 2 additions & 2 deletions airbyte-cdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name="airbyte-cdk",
version="0.1.104",
version="0.2.0",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",
Expand Down Expand Up @@ -53,7 +53,7 @@
"pydantic~=1.9.2",
"PyYAML~=5.4",
"requests",
"vcrpy",
"requests_cache",
"Deprecated~=1.2",
"Jinja2~=3.1.2",
],
Expand Down
25 changes: 19 additions & 6 deletions airbyte-cdk/python/unit_tests/sources/streams/http/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,14 +373,15 @@ def path(self, **kwargs) -> str:

def test_caching_filename():
stream = CacheHttpStream()
assert stream.cache_filename == f"{stream.name}.yml"
assert stream.cache_filename == f"{stream.name}.sqlite"


def test_caching_cassettes_are_different():
def test_caching_sessions_are_different():
stream_1 = CacheHttpStream()
stream_2 = CacheHttpStream()

assert stream_1.cache_file != stream_2.cache_file
assert stream_1._session != stream_2._session
assert stream_1.cache_filename == stream_2.cache_filename


def test_parent_attribute_exist():
Expand All @@ -395,7 +396,7 @@ def test_cache_response(mocker):
mocker.patch.object(stream, "url_base", "https://google.com/")
list(stream.read_records(sync_mode=SyncMode.full_refresh))

with open(stream.cache_filename, "r") as f:
with open(stream.cache_filename, "rb") as f:
assert f.read()


Expand All @@ -414,19 +415,31 @@ def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapp


@patch("airbyte_cdk.sources.streams.core.logging", MagicMock())
def test_using_cache(mocker):
def test_using_cache(mocker, requests_mock):
requests_mock.register_uri("GET", "https://google.com/", text="text")
requests_mock.register_uri("GET", "https://google.com/search", text="text")

parent_stream = CacheHttpStreamWithSlices()
mocker.patch.object(parent_stream, "url_base", "https://google.com/")

assert requests_mock.call_count == 0
assert parent_stream._session.cache.response_count() == 0

for _slice in parent_stream.stream_slices():
list(parent_stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=_slice))

assert requests_mock.call_count == 2
assert parent_stream._session.cache.response_count() == 2

child_stream = CacheHttpSubStream(parent=parent_stream)

for _slice in child_stream.stream_slices(sync_mode=SyncMode.full_refresh):
pass

assert parent_stream.cassete.play_count != 0
assert requests_mock.call_count == 2
assert parent_stream._session.cache.response_count() == 2
assert parent_stream._session.cache.has_url("https://google.com/")
assert parent_stream._session.cache.has_url("https://google.com/search")


class AutoFailTrueHttpStream(StubBasicReadHttpStream):
Expand Down

0 comments on commit 258b23c

Please sign in to comment.