Skip to content

Commit

Permalink
feat: add 'from_service_account_info' factory to clients (#706)
Browse files Browse the repository at this point in the history
Closes #705
  • Loading branch information
tseaver authored Dec 16, 2020
1 parent f02a125 commit 94d5f0c
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,22 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
DEFAULT_ENDPOINT
)

@classmethod
def from_service_account_info(cls, info: dict, *args, **kwargs):
"""Creates an instance of this client using the provided credentials info.

Args:
info (dict): The service account private key info.
args: Additional arguments to pass to the constructor.
kwargs: Additional arguments to pass to the constructor.

Returns:
{@api.name}: The constructed client.
"""
credentials = service_account.Credentials.from_service_account_info(info)
kwargs["credentials"] = credentials
return cls(*args, **kwargs)

@classmethod
def from_service_account_file(cls, filename: str, *args, **kwargs):
"""Creates an instance of this client using the provided credentials
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ def test__get_default_mtls_endpoint():
assert {{ service.client_name }}._get_default_mtls_endpoint(non_googleapi) == non_googleapi


def test_{{ service.client_name|snake_case }}_from_service_account_info():
creds = credentials.AnonymousCredentials()
with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory:
factory.return_value = creds
info = {"valid": True}
client = {{ service.client_name }}.from_service_account_info(info)
assert client.transport._credentials == creds

{% if service.host %}assert client.transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %}


def test_{{ service.client_name|snake_case }}_from_service_account_file():
creds = credentials.AnonymousCredentials()
with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class {{ service.async_client_name }}:
parse_common_{{ resource_msg.message_type.resource_type|snake_case }}_path = staticmethod({{ service.client_name }}.parse_common_{{ resource_msg.message_type.resource_type|snake_case }}_path)
{% endfor %}

from_service_account_info = {{ service.client_name }}.from_service_account_info
from_service_account_file = {{ service.client_name }}.from_service_account_file
from_service_account_json = from_service_account_file

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,22 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
DEFAULT_ENDPOINT
)

@classmethod
def from_service_account_info(cls, info: dict, *args, **kwargs):
"""Creates an instance of this client using the provided credentials info.

Args:
info (dict): The service account private key info.
args: Additional arguments to pass to the constructor.
kwargs: Additional arguments to pass to the constructor.

Returns:
{@api.name}: The constructed client.
"""
credentials = service_account.Credentials.from_service_account_info(info)
kwargs["credentials"] = credentials
return cls(*args, **kwargs)

@classmethod
def from_service_account_file(cls, filename: str, *args, **kwargs):
"""Creates an instance of this client using the provided credentials
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ def test__get_default_mtls_endpoint():
assert {{ service.client_name }}._get_default_mtls_endpoint(non_googleapi) == non_googleapi


def test_{{ service.client_name|snake_case }}_from_service_account_info():
creds = credentials.AnonymousCredentials()
with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory:
factory.return_value = creds
info = {"valid": True}
client = {{ service.client_name }}.from_service_account_info(info)
assert client.transport._credentials == creds

{% if service.host %}assert client.transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %}


@pytest.mark.parametrize("client_class", [{{ service.client_name }}, {{ service.async_client_name }}])
def test_{{ service.client_name|snake_case }}_from_service_account_file(client_class):
creds = credentials.AnonymousCredentials()
Expand Down

0 comments on commit 94d5f0c

Please sign in to comment.