diff --git a/src/main/java/redis/clients/jedis/CommandObjects.java b/src/main/java/redis/clients/jedis/CommandObjects.java index 7226a014a7..6fef11ea50 100644 --- a/src/main/java/redis/clients/jedis/CommandObjects.java +++ b/src/main/java/redis/clients/jedis/CommandObjects.java @@ -4440,7 +4440,12 @@ public void setDefaultSearchDialect(int dialect) { private class SearchProfileResponseBuilder extends Builder>> { - private static final String PROFILE_STR = "profile"; + private static final String PROFILE_1_STR = "profile"; + + private static final String RESULTS_2_STR = "Results"; + private static final String PROFILE_2_STR = "Profile"; + + private static final String SHARDS_STR = "Shards"; private final Builder replyBuilder; @@ -4454,14 +4459,30 @@ public Map.Entry> build(Object data) { if (list == null || list.isEmpty()) return null; if (list.get(0) instanceof KeyValue) { + Object results = null, profile = null; for (KeyValue keyValue : (List) data) { - if (PROFILE_STR.equals(BuilderFactory.STRING.build(keyValue.getKey()))) { - return KeyValue.of(replyBuilder.build(data), - BuilderFactory.AGGRESSIVE_ENCODED_OBJECT_MAP.build(keyValue.getValue())); + String keyString = BuilderFactory.STRING.build(keyValue.getKey()); + Object valueRaw = keyValue.getValue(); + if (PROFILE_1_STR.equals(keyString)) { + profile = valueRaw; + results = data; + break; + } else if (RESULTS_2_STR.equals(keyString)) { + results = valueRaw; + } else if (PROFILE_2_STR.equals(keyString)) { + profile = valueRaw; } } + if (results != null) { + return KeyValue.of(replyBuilder.build(results), + BuilderFactory.AGGRESSIVE_ENCODED_OBJECT_MAP.build(profile)); + } } +// if (SHARDS_STR.equals(BuilderFactory.STRING.build(((List) list.get(1)).get(0)))) { +// return KeyValue.of(replyBuilder.build(list.get(0)), +// SearchBuilderFactory.SEARCH_PROFILE_PROFILE.build(((List) ((List) list.get(1)).get(1)).get(0))); +// } return KeyValue.of(replyBuilder.build(list.get(0)), SearchBuilderFactory.SEARCH_PROFILE_PROFILE.build(list.get(1))); } diff --git a/src/test/java/redis/clients/jedis/modules/search/AggregationTest.java b/src/test/java/redis/clients/jedis/modules/search/AggregationTest.java index cefdfbe5d3..5136e57d5a 100644 --- a/src/test/java/redis/clients/jedis/modules/search/AggregationTest.java +++ b/src/test/java/redis/clients/jedis/modules/search/AggregationTest.java @@ -134,6 +134,30 @@ public void testAggregations2() { assertEquals("10", rows.get(1).get("sum")); } + private Map getIteratorsProfile(Map profile) { + if (protocol != RedisProtocol.RESP3) { + return (Map) profile.get("Iterators profile"); + } else { + if (!profile.containsKey("Shards")) { + return (Map) ((List) profile.get("Iterators profile")).get(0); + } else { + return (Map) ((Map) ((List) profile.get("Shards")).get(0)).get("Iterators profile"); + } + } + } + + private List> getResultProcessorsProfile(Map profile) { + if (protocol != RedisProtocol.RESP3) { + return (List) profile.get("Result processors profile"); + } else { + if (!profile.containsKey("Shards")) { + return (List) profile.get("Result processors profile"); + } else { + return (List) ((Map) ((List) profile.get("Shards")).get(0)).get("Result processors profile"); + } + } + } + @Test public void testAggregations2Profile() { Schema sc = new Schema(); @@ -168,16 +192,10 @@ public void testAggregations2Profile() { Map profile = reply.getValue(); assertEquals(Arrays.asList("Index", "Grouper", "Sorter"), - ((List>) profile.get("Result processors profile")).stream() + getResultProcessorsProfile(profile).stream() .map(map -> map.get("Type")).collect(Collectors.toList())); - if (protocol != RedisProtocol.RESP3) { - assertEquals("WILDCARD", ((Map) profile.get("Iterators profile")).get("Type")); - } else { - assertEquals(Arrays.asList("WILDCARD"), - ((List>) profile.get("Iterators profile")).stream() - .map(map -> map.get("Type")).collect(Collectors.toList())); - } + assertEquals("WILDCARD", getIteratorsProfile(profile).get("Type")); } @Test diff --git a/src/test/java/redis/clients/jedis/modules/search/SearchTest.java b/src/test/java/redis/clients/jedis/modules/search/SearchTest.java index 1bc6cb345d..3a2f085fc7 100644 --- a/src/test/java/redis/clients/jedis/modules/search/SearchTest.java +++ b/src/test/java/redis/clients/jedis/modules/search/SearchTest.java @@ -1140,6 +1140,30 @@ public void testDialectsWithFTExplain() throws Exception { assertTrue("Should contain '{K=10 nearest vector'", client.ftExplain(index, query).contains("{K=10 nearest vector")); } + private Map getIteratorsProfile(Map profile) { + if (protocol != RedisProtocol.RESP3) { + return (Map) profile.get("Iterators profile"); + } else { + if (!profile.containsKey("Shards")) { + return (Map) ((List) profile.get("Iterators profile")).get(0); + } else { + return (Map) ((Map) ((List) profile.get("Shards")).get(0)).get("Iterators profile"); + } + } + } + + private List> getResultProcessorsProfile(Map profile) { + if (protocol != RedisProtocol.RESP3) { + return (List) profile.get("Result processors profile"); + } else { + if (!profile.containsKey("Shards")) { + return (List) profile.get("Result processors profile"); + } else { + return (List) ((Map) ((List) profile.get("Shards")).get(0)).get("Result processors profile"); + } + } + } + @Test public void searchProfile() { Schema sc = new Schema().addTextField("t1", 1.0).addTextField("t2", 1.0); @@ -1158,14 +1182,7 @@ public void searchProfile() { assertEquals(Collections.singletonList("doc1"), result.getDocuments().stream().map(Document::getId).collect(Collectors.toList())); Map profile = reply.getValue(); - Map iteratorsProfile; - if (protocol != RedisProtocol.RESP3) { - iteratorsProfile = (Map) profile.get("Iterators profile"); - } else { - List iteratorsProfileList = (List) profile.get("Iterators profile"); - assertEquals(1, iteratorsProfileList.size()); - iteratorsProfile = (Map) iteratorsProfileList.get(0); - } + Map iteratorsProfile = getIteratorsProfile(profile); assertEquals("TEXT", iteratorsProfile.get("Type")); assertEquals("foo", iteratorsProfile.get("Term")); assertEquals(1L, iteratorsProfile.get("Counter")); @@ -1173,7 +1190,7 @@ public void searchProfile() { assertSame(Double.class, iteratorsProfile.get("Time").getClass()); assertEquals(Arrays.asList("Index", "Scorer", "Sorter", "Loader"), - ((List>) profile.get("Result processors profile")).stream() + getResultProcessorsProfile(profile).stream() .map(map -> map.get("Type")).collect(Collectors.toList())); } @@ -1206,13 +1223,7 @@ public void testHNSWVVectorSimilarity() { doc1 = reply.getKey().getDocuments().get(0); assertEquals("a", doc1.getId()); assertEquals("0", doc1.get("__v_score")); - if (protocol != RedisProtocol.RESP3) { - assertEquals("VECTOR", ((Map) reply.getValue().get("Iterators profile")).get("Type")); - } else { - assertEquals(Arrays.asList("VECTOR"), - ((List>) reply.getValue().get("Iterators profile")).stream() - .map(map -> map.get("Type")).collect(Collectors.toList())); - } + assertEquals("VECTOR", getIteratorsProfile(reply.getValue()).get("Type")); } @Test @@ -1244,13 +1255,7 @@ public void testFlatVectorSimilarity() { doc1 = reply.getKey().getDocuments().get(0); assertEquals("a", doc1.getId()); assertEquals("0", doc1.get("__v_score")); - if (protocol != RedisProtocol.RESP3) { - assertEquals("VECTOR", ((Map) reply.getValue().get("Iterators profile")).get("Type")); - } else { - assertEquals(Arrays.asList("VECTOR"), - ((List>) reply.getValue().get("Iterators profile")).stream() - .map(map -> map.get("Type")).collect(Collectors.toList())); - } + assertEquals("VECTOR", getIteratorsProfile(reply.getValue()).get("Type")); } @Test diff --git a/src/test/java/redis/clients/jedis/modules/search/SearchWithParamsTest.java b/src/test/java/redis/clients/jedis/modules/search/SearchWithParamsTest.java index 897da8eece..26af00d3fc 100644 --- a/src/test/java/redis/clients/jedis/modules/search/SearchWithParamsTest.java +++ b/src/test/java/redis/clients/jedis/modules/search/SearchWithParamsTest.java @@ -1056,6 +1056,30 @@ public void testFlatVectorSimilarity() { assertEquals("0", doc1.get("__v_score")); } + private Map getIteratorsProfile(Map profile) { + if (protocol != RedisProtocol.RESP3) { + return (Map) profile.get("Iterators profile"); + } else { + if (!profile.containsKey("Shards")) { + return (Map) ((List) profile.get("Iterators profile")).get(0); + } else { + return (Map) ((Map) ((List) profile.get("Shards")).get(0)).get("Iterators profile"); + } + } + } + + private List> getResultProcessorsProfile(Map profile) { + if (protocol != RedisProtocol.RESP3) { + return (List) profile.get("Result processors profile"); + } else { + if (!profile.containsKey("Shards")) { + return (List) profile.get("Result processors profile"); + } else { + return (List) ((Map) ((List) profile.get("Shards")).get(0)).get("Result processors profile"); + } + } + } + @Test public void searchProfile() { assertOK(client.ftCreate(index, TextField.of("t1"), TextField.of("t2"))); @@ -1070,17 +1094,11 @@ public void searchProfile() { SearchResult result = reply.getKey(); assertEquals(1, result.getTotalResults()); - assertEquals(Collections.singletonList("doc1"), result.getDocuments().stream().map(Document::getId).collect(Collectors.toList())); + assertEquals(Collections.singletonList("doc1"), + result.getDocuments().stream().map(Document::getId).collect(Collectors.toList())); Map profile = reply.getValue(); - Map iteratorsProfile; - if (protocol != RedisProtocol.RESP3) { - iteratorsProfile = (Map) profile.get("Iterators profile"); - } else { - List iteratorsProfileList = (List) profile.get("Iterators profile"); - assertEquals(1, iteratorsProfileList.size()); - iteratorsProfile = (Map) iteratorsProfileList.get(0); - } + Map iteratorsProfile = getIteratorsProfile(profile); assertEquals("TEXT", iteratorsProfile.get("Type")); assertEquals("foo", iteratorsProfile.get("Term")); assertEquals(1L, iteratorsProfile.get("Counter")); @@ -1088,7 +1106,7 @@ public void searchProfile() { assertSame(Double.class, iteratorsProfile.get("Time").getClass()); assertEquals(Arrays.asList("Index", "Scorer", "Sorter", "Loader"), - ((List>) profile.get("Result processors profile")).stream() + getResultProcessorsProfile(profile).stream() .map(map -> map.get("Type")).collect(Collectors.toList())); } @@ -1116,16 +1134,9 @@ public void vectorSearchProfile() { Map profile = reply.getValue(); - if (protocol != RedisProtocol.RESP3) { - assertEquals("VECTOR", ((Map) profile.get("Iterators profile")).get("Type")); - } else { - assertEquals(Arrays.asList("VECTOR"), - ((List>) profile.get("Iterators profile")).stream() - .map(map -> map.get("Type")).collect(Collectors.toList())); - } + assertEquals("VECTOR", getIteratorsProfile(profile).get("Type")); - List> resultProcessorsProfile - = (List>) reply.getValue().get("Result processors profile"); + List> resultProcessorsProfile = getResultProcessorsProfile(profile); assertEquals(3, resultProcessorsProfile.size()); assertEquals("Index", resultProcessorsProfile.get(0).get("Type")); assertEquals("Sorter", resultProcessorsProfile.get(2).get("Type")); @@ -1146,13 +1157,8 @@ public void maxPrefixExpansionSearchProfile() { Map.Entry> reply = client.ftProfileSearch(index, FTProfileParams.profileParams(), "foo*", FTSearchParams.searchParams().limit(0, 0)); // Warning=Max prefix expansion reached - if (protocol != RedisProtocol.RESP3) { - assertEquals("Max prefix expansion reached", - ((Map) reply.getValue().get("Iterators profile")).get("Warning")); - } else { - assertEquals("Max prefix expansion reached", - ((Map) ((List) reply.getValue().get("Iterators profile")).get(0)).get("Warning")); - } + assertEquals("Max prefix expansion reached", + getIteratorsProfile(reply.getValue()).get("Warning")); } finally { client.ftConfigSet(configParam, configValue); } @@ -1164,12 +1170,10 @@ public void noContentSearchProfile() { client.hset("1", "t", "foo"); client.hset("2", "t", "bar"); - Map.Entry> profile = client.ftProfileSearch(index, + Map.Entry> reply = client.ftProfileSearch(index, FTProfileParams.profileParams(), "foo -@t:baz", FTSearchParams.searchParams().noContent()); - Map depth0 = protocol != RedisProtocol.RESP3 - ? (Map) profile.getValue().get("Iterators profile") - : ((List>) profile.getValue().get("Iterators profile")).get(0); + Map depth0 = getIteratorsProfile(reply.getValue()); assertEquals("INTERSECT", depth0.get("Type")); List> depth0_children = (List>) depth0.get("Child iterators"); @@ -1191,13 +1195,11 @@ public void deepReplySearchProfile() { client.hset("1", "t", "hello"); client.hset("2", "t", "world"); - Map.Entry> profile + Map.Entry> reply = client.ftProfileSearch(index, FTProfileParams.profileParams(), "hello(hello(hello(hello(hello(hello)))))", FTSearchParams.searchParams().noContent()); - Map depth0 = protocol != RedisProtocol.RESP3 - ? (Map) profile.getValue().get("Iterators profile") - : ((List>) profile.getValue().get("Iterators profile")).get(0); + Map depth0 = getIteratorsProfile(reply.getValue()); AtomicInteger intersectLevelCount = new AtomicInteger(); AtomicInteger textLevelCount = new AtomicInteger(); @@ -1234,12 +1236,10 @@ public void limitedSearchProfile() { client.hset("3", "t", "help"); client.hset("4", "t", "helowa"); - Map.Entry> profile = client.ftProfileSearch(index, + Map.Entry> reply = client.ftProfileSearch(index, FTProfileParams.profileParams().limited(), "%hell% hel*", FTSearchParams.searchParams().noContent()); - Map depth0 = protocol != RedisProtocol.RESP3 - ? (Map) profile.getValue().get("Iterators profile") - : ((List>) profile.getValue().get("Iterators profile")).get(0); + Map depth0 = getIteratorsProfile(reply.getValue()); assertEquals("INTERSECT", depth0.get("Type")); assertEquals(3L, depth0.get("Counter"));