Skip to content

Commit

Permalink
Merge branch 'main' into fix-async-transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche authored Oct 23, 2023
2 parents 3a710c6 + 82122e7 commit af9a950
Show file tree
Hide file tree
Showing 32 changed files with 1,975 additions and 443 deletions.
4 changes: 2 additions & 2 deletions .github/.OwlBot.lock.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
# limitations under the License.
docker:
image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest
digest: sha256:fac304457974bb530cc5396abd4ab25d26a469cd3bc97cbfb18c8d4324c584eb
# created: 2023-10-02T21:31:03.517640371Z
digest: sha256:4f9b3b106ad0beafc2c8a415e3f62c1a0cc23cabea115dbe841b848f581cfe99
# created: 2023-10-18T20:26:37.410353675Z
6 changes: 3 additions & 3 deletions .kokoro/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,9 @@ typing-extensions==4.4.0 \
--hash=sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa \
--hash=sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e
# via -r requirements.in
urllib3==1.26.12 \
--hash=sha256:3fa96cf423e6987997fc326ae8df396db2a8b7c667747d47ddd8ecba91f4a74e \
--hash=sha256:b930dd878d5a8afb066a637fbb35144fe7901e3b209d1cd4f524bd0e9deee997
urllib3==1.26.18 \
--hash=sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07 \
--hash=sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0
# via
# requests
# twine
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- repo: https://github.com/psf/black
rev: 22.3.0
rev: 23.7.0
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2053,7 +2053,6 @@ def __call__(
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> None:

r"""Call the cancel operation method over HTTP.
Args:
Expand Down Expand Up @@ -2119,7 +2118,6 @@ def __call__(
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> None:

r"""Call the delete operation method over HTTP.
Args:
Expand Down Expand Up @@ -2182,7 +2180,6 @@ def __call__(
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operations_pb2.Operation:

r"""Call the get operation method over HTTP.
Args:
Expand Down Expand Up @@ -2249,7 +2246,6 @@ def __call__(
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operations_pb2.ListOperationsResponse:

r"""Call the list operations method over HTTP.
Args:
Expand Down
6 changes: 0 additions & 6 deletions google/cloud/firestore_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,6 @@ def extract_fields(
yield prefix_path, _EmptyDict
else:
for key, value in sorted(document_data.items()):

if expand_dots:
sub_key = FieldPath.from_string(key)
else:
Expand Down Expand Up @@ -503,7 +502,6 @@ def __init__(self, document_data) -> None:
iterator = self._get_document_iterator(prefix_path)

for field_path, value in iterator:

if field_path == prefix_path and value is _EmptyDict:
self.empty_document = True

Expand Down Expand Up @@ -565,7 +563,6 @@ def _get_update_mask(self, allow_empty_mask=False) -> None:
def get_update_pb(
self, document_path, exists=None, allow_empty_mask=False
) -> types.write.Write:

if exists is not None:
current_document = common.Precondition(exists=exists)
else:
Expand Down Expand Up @@ -762,7 +759,6 @@ def _normalize_merge_paths(self, merge) -> list:
return merge_paths

def _apply_merge_paths(self, merge) -> None:

if self.empty_document:
raise ValueError("Cannot merge specific fields with empty document.")

Expand All @@ -773,7 +769,6 @@ def _apply_merge_paths(self, merge) -> None:
self.merge = merge_paths

for merge_path in merge_paths:

if merge_path in self.transform_paths:
self.transform_merge.append(merge_path)

Expand Down Expand Up @@ -1187,7 +1182,6 @@ def deserialize_bundle(
bundle: Optional[FirestoreBundle] = None
data: Dict
for data in _parse_bundle_elements_data(serialized):

# BundleElements are serialized as JSON containing one key outlining
# the type, with all further data nested under that key
keys: List[str] = list(data.keys())
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from google.cloud.firestore_v1.transaction import Transaction


class AsyncCollectionReference(BaseCollectionReference):
class AsyncCollectionReference(BaseCollectionReference[async_query.AsyncQuery]):
"""A reference to a collection in a Firestore database.
The collection may already exist or this class can facilitate creation
Expand Down
53 changes: 43 additions & 10 deletions google/cloud/firestore_v1/async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@
)

from google.cloud.firestore_v1 import async_document
from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from typing import AsyncGenerator, List, Optional, Type

# Types needed only for Type Hints
from google.cloud.firestore_v1.transaction import Transaction
from typing import AsyncGenerator, List, Optional, Type, TYPE_CHECKING

from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery
if TYPE_CHECKING: # pragma: NO COVER
# Types needed only for Type Hints
from google.cloud.firestore_v1.transaction import Transaction
from google.cloud.firestore_v1.field_path import FieldPath


class AsyncQuery(BaseQuery):
Expand Down Expand Up @@ -222,15 +223,47 @@ def count(
"""Adds a count over the nested query.
Args:
alias
(Optional[str]): The alias for the count
alias(Optional[str]): Optional name of the field to store the result of the aggregation into.
If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
Returns:
:class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`:
An instance of an AsyncAggregationQuery object
"""
return AsyncAggregationQuery(self).count(alias=alias)

def sum(
    self, field_ref: str | FieldPath, alias: str | None = None
) -> "firestore_v1.async_aggregation.AsyncAggregationQuery":
    """Adds a sum over the nested query.

    Args:
        field_ref(Union[str, google.cloud.firestore_v1.field_path.FieldPath]): The field to aggregate across.
        alias(Optional[str]): Optional name of the field to store the result of the aggregation into.
            If not provided, Firestore will pick a default name following the format field_<incremental_id++>.

    Returns:
        :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`:
            An instance of an AsyncAggregationQuery object
    """
    # NOTE: the previous annotation Type[...] was wrong — this returns an
    # AsyncAggregationQuery *instance* (the fluent builder), not the class.
    return AsyncAggregationQuery(self).sum(field_ref, alias=alias)

def avg(
    self, field_ref: str | FieldPath, alias: str | None = None
) -> "firestore_v1.async_aggregation.AsyncAggregationQuery":
    """Adds an avg over the nested query.

    Args:
        field_ref(Union[str, google.cloud.firestore_v1.field_path.FieldPath]): The field to aggregate across.
        alias(Optional[str]): Optional name of the field to store the result of the aggregation into.
            If not provided, Firestore will pick a default name following the format field_<incremental_id++>.

    Returns:
        :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`:
            An instance of an AsyncAggregationQuery object
    """
    # NOTE: the previous annotation Type[...] was wrong — this returns an
    # AsyncAggregationQuery *instance* (the fluent builder), not the class.
    return AsyncAggregationQuery(self).avg(field_ref, alias=alias)

async def stream(
self,
transaction=None,
Expand Down Expand Up @@ -292,9 +325,9 @@ async def stream(
yield snapshot

@staticmethod
def _get_collection_reference_class() -> (
    Type["firestore_v1.async_collection.AsyncCollectionReference"]
):
    """Return the collection-reference class paired with this query type.

    The import is deferred to call time to avoid a circular import with
    ``async_collection`` at module load.
    """
    # The scraped diff carried both the old and new signature; only the
    # parenthesized (black 23.x style) post-image form is kept here.
    from google.cloud.firestore_v1.async_collection import AsyncCollectionReference

    return AsyncCollectionReference
Expand Down
73 changes: 63 additions & 10 deletions google/cloud/firestore_v1/base_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from google.api_core import retry as retries


from google.cloud.firestore_v1.field_path import FieldPath
from google.cloud.firestore_v1.types import RunAggregationQueryResponse

from google.cloud.firestore_v1.types import StructuredAggregationQuery
from google.cloud.firestore_v1 import _helpers

Expand All @@ -60,14 +60,17 @@ def __repr__(self):


class BaseAggregation(ABC):
    """Abstract base for a single aggregation operation over a query.

    Args:
        alias: Optional name under which the server reports this
            aggregation's result; ``None`` lets the backend pick one.
    """

    def __init__(self, alias: str | None = None):
        # Stored verbatim; subclasses forward their alias here via super().
        self.alias = alias

    @abc.abstractmethod
    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""


class CountAggregation(BaseAggregation):
def __init__(self, alias: str | None = None):
self.alias = alias
super(CountAggregation, self).__init__(alias=alias)

def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
Expand All @@ -77,13 +80,48 @@ def _to_protobuf(self):
return aggregation_pb


class SumAggregation(BaseAggregation):
    """Sum aggregation over a single document field.

    Args:
        field_ref: The field to aggregate across, given either as a dotted
            string or as a ``FieldPath`` (normalized to its string form).
        alias: Optional name under which the result is returned.
    """

    def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
        # FieldPath instances are normalized to their string API representation.
        self.field_ref = (
            field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref
        )
        super().__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        pb = StructuredAggregationQuery.Aggregation()
        pb.alias = self.alias
        pb.sum = StructuredAggregationQuery.Aggregation.Sum()
        pb.sum.field.field_path = self.field_ref
        return pb


class AvgAggregation(BaseAggregation):
    """Average aggregation over a single document field.

    Args:
        field_ref: The field to aggregate across, given either as a dotted
            string or as a ``FieldPath`` (normalized to its string form).
        alias: Optional name under which the result is returned.
    """

    def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
        # FieldPath instances are normalized to their string API representation.
        self.field_ref = (
            field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref
        )
        super().__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        pb = StructuredAggregationQuery.Aggregation()
        pb.alias = self.alias
        pb.avg = StructuredAggregationQuery.Aggregation.Avg()
        pb.avg.field.field_path = self.field_ref
        return pb


def _query_response_to_result(
response_pb: RunAggregationQueryResponse,
) -> List[AggregationResult]:
results = [
AggregationResult(
alias=key,
value=response_pb.result.aggregate_fields[key].integer_value,
value=response_pb.result.aggregate_fields[key].integer_value
or response_pb.result.aggregate_fields[key].double_value,
read_time=response_pb.read_time,
)
for key in response_pb.result.aggregate_fields.pb.keys()
Expand All @@ -95,11 +133,9 @@ def _query_response_to_result(
class BaseAggregationQuery(ABC):
"""Represents an aggregation query to the Firestore API."""

def __init__(
self,
nested_query,
) -> None:
def __init__(self, nested_query, alias: str | None = None) -> None:
    """Wrap *nested_query* and start with an empty list of aggregations."""
    self._nested_query = nested_query
    self._collection_ref = nested_query._parent
    self._alias = alias
    # Populated by count()/sum()/avg()/add_aggregation().
    self._aggregations: List[BaseAggregation] = []

Expand All @@ -115,6 +151,22 @@ def count(self, alias: str | None = None):
self._aggregations.append(count_aggregation)
return self

def sum(self, field_ref: str | FieldPath, alias: str | None = None):
    """
    Adds a sum over the nested query
    """
    # Fluent builder: record the aggregation and return self for chaining.
    self._aggregations.append(SumAggregation(field_ref, alias=alias))
    return self

def avg(self, field_ref: str | FieldPath, alias: str | None = None):
    """
    Adds an avg over the nested query
    """
    # Fluent builder: record the aggregation and return self for chaining.
    self._aggregations.append(AvgAggregation(field_ref, alias=alias))
    return self

def add_aggregation(self, aggregation: BaseAggregation) -> None:
"""
Adds an aggregation operation to the nested query
Expand Down Expand Up @@ -196,9 +248,10 @@ def stream(
retries.Retry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
timeout: float | None = None,
) -> Generator[List[AggregationResult], Any, None] | AsyncGenerator[
List[AggregationResult], None
]:
) -> (
Generator[List[AggregationResult], Any, None]
| AsyncGenerator[List[AggregationResult], None]
):
"""Runs the aggregation query.
This sends a``RunAggregationQuery`` RPC and returns an iterator in the stream of ``RunAggregationQueryResponse`` messages.
Expand Down
12 changes: 7 additions & 5 deletions google/cloud/firestore_v1/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,15 @@ def _rpc_metadata(self):

return self._rpc_metadata_internal

def collection(self, *collection_path) -> BaseCollectionReference[BaseQuery]:
    """Return a reference to the collection at *collection_path*.

    Abstract on the base client; sync/async subclasses provide the
    concrete implementation.

    Raises:
        NotImplementedError: Always, on the base class.
    """
    # The scraped diff carried both the old and new signature; only the
    # generic BaseCollectionReference[BaseQuery] post-image form is kept.
    raise NotImplementedError

def collection_group(self, collection_id: str) -> BaseQuery:
    """Create a query over all collections named *collection_id*.

    Abstract on the base client; sync/async subclasses provide the
    concrete implementation.

    Raises:
        NotImplementedError: Always, on the base class.
    """
    raise NotImplementedError

def _get_collection_reference(self, collection_id: str) -> BaseCollectionReference:
def _get_collection_reference(
self, collection_id: str
) -> BaseCollectionReference[BaseQuery]:
"""Checks validity of collection_id and then uses subclasses collection implementation.
Args:
Expand Down Expand Up @@ -325,7 +327,7 @@ def _document_path_helper(self, *document_path) -> List[str]:

def recursive_delete(
    self,
    reference: Union[BaseCollectionReference[BaseQuery], BaseDocumentReference],
    bulk_writer: Optional["BulkWriter"] = None,  # type: ignore
) -> int:
    """Delete *reference* and everything nested under it.

    Abstract on the base client; sync/async subclasses provide the
    concrete implementation.

    Raises:
        NotImplementedError: Always, on the base class.
    """
    # The scraped diff carried both the old and new `reference` parameter
    # lines; only the generic BaseCollectionReference[BaseQuery] form is kept.
    raise NotImplementedError
Expand Down Expand Up @@ -459,8 +461,8 @@ def collections(
retry: retries.Retry = None,
timeout: float = None,
) -> Union[
AsyncGenerator[BaseCollectionReference, Any],
Generator[BaseCollectionReference, Any, Any],
AsyncGenerator[BaseCollectionReference[BaseQuery], Any],
Generator[BaseCollectionReference[BaseQuery], Any, Any],
]:
raise NotImplementedError

Expand Down
Loading

0 comments on commit af9a950

Please sign in to comment.