Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Correct AsyncTransaction Return Types #751

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions google/cloud/firestore_v1/async_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,10 @@ async def get_all(
query, or :data:`None` if the document does not exist.
"""
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
return await self._client.get_all(references, transaction=self, **kwargs)
async for snapshot in self._client.get_all(
references, transaction=self, **kwargs
):
yield snapshot

async def get(
self,
Expand All @@ -195,11 +198,13 @@ async def get(
"""
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
if isinstance(ref_or_query, AsyncDocumentReference):
return await self._client.get_all(
async for snapshot in self._client.get_all(
[ref_or_query], transaction=self, **kwargs
)
):
yield snapshot
elif isinstance(ref_or_query, AsyncQuery):
return await ref_or_query.stream(transaction=self, **kwargs)
async for snapshot in ref_or_query.stream(transaction=self, **kwargs):
yield snapshot
else:
raise ValueError(
'Value for argument "ref_or_query" must be a AsyncDocumentReference or a AsyncQuery.'
Expand Down
39 changes: 22 additions & 17 deletions tests/unit/v1/test_async_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

import mock
import pytest
import types

from tests.unit.v1.test__helpers import AsyncMock
from tests.unit.v1.test__helpers import AsyncMock, AsyncIter


def _make_async_transaction(*args, **kwargs):
Expand Down Expand Up @@ -286,19 +287,21 @@ async def test_asynctransaction__commit_failure():
async def _get_all_helper(retry=None, timeout=None):
from google.cloud.firestore_v1 import _helpers

client = AsyncMock(spec=["get_all"])
client = mock.Mock(spec=["get_all"])
response_iterator = AsyncIter([])
client.get_all.return_value = response_iterator
transaction = _make_async_transaction(client)
ref1, ref2 = mock.Mock(), mock.Mock()
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

result = await transaction.get_all([ref1, ref2], **kwargs)

snapshots = transaction.get_all([ref1, ref2], **kwargs)
assert isinstance(snapshots, types.AsyncGeneratorType)
_ = [s async for s in snapshots]
client.get_all.assert_called_once_with(
[ref1, ref2],
transaction=transaction,
**kwargs,
)
assert result is client.get_all.return_value


@pytest.mark.asyncio
Expand All @@ -319,15 +322,17 @@ async def _get_w_document_ref_helper(retry=None, timeout=None):
from google.cloud.firestore_v1.async_document import AsyncDocumentReference
from google.cloud.firestore_v1 import _helpers

client = AsyncMock(spec=["get_all"])
client = mock.Mock(spec=["get_all"])
response_iterator = AsyncIter([])
client.get_all.return_value = response_iterator
transaction = _make_async_transaction(client)
ref = AsyncDocumentReference("documents", "doc-id")
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

result = await transaction.get(ref, **kwargs)

snapshots = transaction.get(ref, **kwargs)
assert isinstance(snapshots, types.AsyncGeneratorType)
_ = [s async for s in snapshots]
client.get_all.assert_called_once_with([ref], transaction=transaction, **kwargs)
assert result is client.get_all.return_value


@pytest.mark.asyncio
Expand All @@ -351,19 +356,18 @@ async def _get_w_query_helper(retry=None, timeout=None):
client = AsyncMock(spec=[])
transaction = _make_async_transaction(client)
query = AsyncQuery(parent=AsyncMock(spec=[]))
query.stream = AsyncMock()
query.stream = mock.Mock()
response_iterator = AsyncIter([])
query.stream.return_value = response_iterator
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

result = await transaction.get(
query,
**kwargs,
)

snapshots = transaction.get(query, **kwargs)
assert isinstance(snapshots, types.AsyncGeneratorType)
_ = [s async for s in snapshots]
query.stream.assert_called_once_with(
transaction=transaction,
**kwargs,
)
assert result is query.stream.return_value


@pytest.mark.asyncio
Expand All @@ -382,7 +386,8 @@ async def test_asynctransaction_get_failure():
transaction = _make_async_transaction(client)
ref_or_query = object()
with pytest.raises(ValueError):
await transaction.get(ref_or_query)
async for _ in transaction.get(ref_or_query):
pass


def _make_async_transactional(*args, **kwargs):
Expand Down
Loading