diff --git a/CHANGELOG.md b/CHANGELOG.md index 672fd065..44af4328 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Unreleased] ### Added +- Added `AsyncSearch#collapse` ([827](https://github.com/opensearch-project/opensearch-py/pull/827)) ### Changed ### Deprecated ### Removed diff --git a/opensearchpy/_async/helpers/search.py b/opensearchpy/_async/helpers/search.py index e6a52428..c82999ce 100644 --- a/opensearchpy/_async/helpers/search.py +++ b/opensearchpy/_async/helpers/search.py @@ -8,7 +8,7 @@ # GitHub history for details. import copy -from typing import Any, Sequence +from typing import Any, Dict, Sequence, cast from opensearchpy._async.helpers.actions import aiter, async_scan from opensearchpy.connection.async_connections import get_connection @@ -39,6 +39,7 @@ def __init__(self, **kwargs: Any) -> None: self.aggs = AggsProxy(self) self._sort: Sequence[Any] = [] + self._collapse: Dict[str, Any] = {} self._source: Any = None self._highlight: Any = {} self._highlight_opts: Any = {} @@ -111,13 +112,13 @@ def from_dict(cls, d: Any) -> Any: s.update_from_dict(d) return s - def _clone(self) -> Any: + def _clone(self) -> "AsyncSearch": """ Return a clone of the current search request. Performs a shallow copy of all the underlying objects. Used internally by most state modifying APIs. """ - s = super()._clone() + s = cast(AsyncSearch, super()._clone()) s._response_class = self._response_class s._sort = self._sort[:] @@ -126,6 +127,7 @@ def _clone(self) -> Any: s._highlight_opts = self._highlight_opts.copy() s._suggest = self._suggest.copy() s._script_fields = self._script_fields.copy() + s._collapse = self._collapse.copy() for x in ("query", "post_filter"): getattr(s, x)._proxied = getattr(self, x)._proxied @@ -281,6 +283,34 @@ def sort(self, *keys: Any) -> Any: s._sort.append(k) return s + def collapse( + self, + field: Any = None, + inner_hits: Any = None, + max_concurrent_group_searches: Any = None, + ) -> "AsyncSearch": + """ + Add collapsing information to the search request. + + If called without providing ``field``, it will remove all collapse + requirements, otherwise it will replace them with the provided + arguments. + + The API returns a copy of the AsyncSearch object and can thus be chained. + """ + s = self._clone() + s._collapse = {} + + if field is None: + return s + + s._collapse["field"] = field + if inner_hits: + s._collapse["inner_hits"] = inner_hits + if max_concurrent_group_searches: + s._collapse["max_concurrent_group_searches"] = max_concurrent_group_searches + return s + def highlight_options(self, **kwargs: Any) -> Any: """ Update the global highlighting options used for this request. For @@ -376,6 +406,9 @@ def to_dict(self, count: bool = False, **kwargs: Any) -> Any: if self._sort: d["sort"] = self._sort + if self._collapse: + d["collapse"] = self._collapse + d.update(recursive_to_dict(self._extra)) if self._source not in (None, {}): diff --git a/test_opensearchpy/test_async/test_helpers/test_search.py b/test_opensearchpy/test_async/test_helpers/test_search.py index d01f0b80..81ea75a4 100644 --- a/test_opensearchpy/test_async/test_helpers/test_search.py +++ b/test_opensearchpy/test_async/test_helpers/test_search.py @@ -240,6 +240,40 @@ async def test_sort_by_score() -> None: s.sort("-_score") +def test_collapse() -> None: + s = search.AsyncSearch() + + inner_hits = {"name": "most_recent", "size": 5, "sort": [{"@timestamp": "desc"}]} + s = s.collapse( + field="user.id", inner_hits=inner_hits, max_concurrent_group_searches=4 + ) + + assert { + "field": "user.id", + "inner_hits": { + "name": "most_recent", + "size": 5, + "sort": [{"@timestamp": "desc"}], + }, + "max_concurrent_group_searches": 4, + } == s._collapse + assert { + "collapse": { + "field": "user.id", + "inner_hits": { + "name": "most_recent", + "size": 5, + "sort": [{"@timestamp": "desc"}], + }, + "max_concurrent_group_searches": 4, + } + } == s.to_dict() + + s = s.collapse() + assert {} == s._collapse + assert search.AsyncSearch().to_dict() == s.to_dict() + + async def test_slice() -> None: s = search.AsyncSearch() assert {"from": 3, "size": 7} == s[3:10].to_dict() @@ -546,3 +580,19 @@ async def test_rescore_query_to_dict() -> None: }, }, } + + +def test_collapse_chaining() -> None: + s = search.AsyncSearch(index="index_name") + s = s.filter("term", color="red") + s = s.collapse(field="category") + s = s.filter("term", brand="something") + + assert { + "query": { + "bool": { + "filter": [{"term": {"color": "red"}}, {"term": {"brand": "something"}}] + } + }, + "collapse": {"field": "category"}, + } == s.to_dict()