Skip to content

Commit

Permalink
Added support for ADDSCORES modifier
Browse files Browse the repository at this point in the history
  • Loading branch information
vladvildanov committed Jul 22, 2024
1 parent fd0b0d3 commit e82dc8e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
11 changes: 11 additions & 0 deletions redis/commands/search/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(self, query: str = "*") -> None:
self._verbatim = False
self._cursor = []
self._dialect = None
self._add_scores = False

def load(self, *fields: List[str]) -> "AggregateRequest":
"""
Expand Down Expand Up @@ -292,6 +293,13 @@ def with_schema(self) -> "AggregateRequest":
self._with_schema = True
return self

def add_scores(self) -> "AggregateRequest":
"""
If set, includes the score as an ordinary field of the row.
"""
self._add_scores = True
return self

def verbatim(self) -> "AggregateRequest":
self._verbatim = True
return self
Expand All @@ -315,6 +323,9 @@ def build_args(self) -> List[str]:
if self._verbatim:
ret.append("VERBATIM")

if self._add_scores:
ret.append("ADDSCORES")

if self._cursor:
ret += self._cursor

Expand Down
17 changes: 17 additions & 0 deletions tests/test_asyncio/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,6 +1530,23 @@ async def test_withsuffixtrie(decoded_r: redis.Redis):
assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"]


@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
async def test_aggregations_add_scores(decoded_r: redis.Redis):
assert await decoded_r.ft().create_index(
(TextField("name", sortable=True, weight=5.0), NumericField("age", sortable=True))
)

assert await decoded_r.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"})
assert await decoded_r.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"})

req = (aggregations.AggregateRequest("*").add_scores())
res = await decoded_r.ft().aggregate(req)
assert len(res.rows) == 2
assert res.rows[0] == ["__score", "0.2"]
assert res.rows[1] == ["__score", "0.2"]


@pytest.mark.redismod
@skip_if_redis_enterprise()
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,6 +1440,23 @@ def test_aggregations_filter(client):
assert res["results"][1]["extra_attributes"] == {"age": "25"}


@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
def test_aggregations_add_scores(client):
client.ft().create_index(
(TextField("name", sortable=True, weight=5.0), NumericField("age", sortable=True))
)

client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"})
client.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"})

req = (aggregations.AggregateRequest("*").add_scores())
res = client.ft().aggregate(req)
assert len(res.rows) == 2
assert res.rows[0] == ["__score", "0.2"]
assert res.rows[1] == ["__score", "0.2"]


@pytest.mark.redismod
@skip_ifmodversion_lt("2.0.0", "search")
def test_index_definition(client):
Expand Down

0 comments on commit e82dc8e

Please sign in to comment.