From 61f3cd8ec2c63889173d80b1ce77625e06125319 Mon Sep 17 00:00:00 2001 From: vga91 Date: Tue, 2 Jul 2024 08:49:41 +0200 Subject: [PATCH] Fixes #4121: Better error messaging with vectordb query/get procedures --- .../database-integration/vectordb/chroma.adoc | 18 ++++----- .../database-integration/vectordb/custom.adoc | 6 +-- .../database-integration/vectordb/milvus.adoc | 35 +++++++++++----- .../vectordb/pinecone.adoc | 18 ++++----- .../database-integration/vectordb/qdrant.adoc | 18 ++++----- .../vectordb/weaviate.adoc | 26 +++++++----- .../test/java/apoc/vectordb/MilvusTest.java | 40 +++++++++++++++++++ .../test/java/apoc/vectordb/WeaviateTest.java | 20 +++++++++- .../src/main/java/apoc/vectordb/Milvus.java | 4 ++ .../src/main/java/apoc/vectordb/VectorDb.java | 9 ++++- .../main/java/apoc/vectordb/VectorDbUtil.java | 3 +- .../apoc/vectordb/VectorEmbeddingConfig.java | 7 ++++ .../src/main/java/apoc/vectordb/Weaviate.java | 11 ++++- 13 files changed, 162 insertions(+), 53 deletions(-) diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc index d135220f49..371cd0f04e 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc @@ -74,9 +74,9 @@ CALL apoc.vectordb.chroma.get($host, '', ['1','2'], {', ['1','2'], {.svc.gcp-starter.pinecone.io/qu .Example results [opts="header"] |=== -| score | metadata | id | vector | text -| 1, | {a: 1} | 1 | [1,2,3,4] -| 0.1 | {a: 2} | 2 | [1,2,3,4] +| score | metadata | id | vector | text | errors +| 1, | {a: 1} | 1 | [1,2,3,4] | null +| 0.1 | {a: 2} | 2 | [1,2,3,4] | null | ... 
|=== diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc index 11fac11124..eb0f1521b8 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc @@ -76,12 +76,21 @@ CALL apoc.vectordb.milvus.get('http://localhost:19531', 'test_collection', [1,2] .Example results [opts="header"] |=== -| score | metadata | id | vector | text | entity -| null | {city: "Berlin", foo: "one"} | null | null | null | null -| null | {city: "Berlin", foo: "two"} | null | null | null | null +| score | metadata | id | vector | text | entity | errors +| null | {city: "Berlin", foo: "one"} | null | null | null | null | null +| null | {city: "Berlin", foo: "two"} | null | null | null | null | null | ... |=== +In case of errors, e.g. due to `apoc.vectordb.milvus.get` with a wrong ID type as a 3rd parameter, the error field will be populated, for example: + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text | errors +| null | null | null | null | null | ..please check the primary key and its' type can only in [int, string], error: unable to cast "wrong" of type string to int64.. +|=== + .Get vectors with `{allResults: true}` [source,cypher] ---- @@ -92,9 +101,9 @@ CALL apoc.vectordb.milvus.get('http://localhost:19531', 'test_collection', [1,2] .Example results [opts="header"] |=== -| score | metadata | id | vector | text | entity -| null | {city: "Berlin", foo: "one"} | 1 | [...] | null | null -| null | {city: "Berlin", foo: "two"} | 2 | [...] | null | null +| score | metadata | id | vector | text | entity | errors +| null | {city: "Berlin", foo: "one"} | 1 | [...] | null | null | null +| null | {city: "Berlin", foo: "two"} | 2 | [...] | null | null | null | ... 
|=== @@ -115,12 +124,20 @@ CALL apoc.vectordb.milvus.query('http://localhost:19531', .Example results [opts="header"] |=== -| score | metadata | id | vector | text | entity -| 1, | {city: "Berlin", foo: "one"} | 1 | [...] | null | null -| 0.1 | {city: "Berlin", foo: "two"} | 2 | [...] | null | null +| score | metadata | id | vector | text | entity | errors +| 1, | {city: "Berlin", foo: "one"} | 1 | [...] | null | null | null +| 0.1 | {city: "Berlin", foo: "two"} | 2 | [...] | null | null | null | ... |=== +In case of errors, e.g. due to `apoc.vectordb.milvus.query` with wrong vector size as a 3rd parameter, the error field will be populated, for example: + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text | errors +| null | null | null | null | null | ..can only accept json format request, error: dimension: 4, but length of []float: 3: invalid parameter[expected=FloatVector][actual=[0.2,0.1,0.9]].. +|=== We can define a mapping, to auto-create one/multiple nodes and relationships, by leveraging the vector metadata. diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc index 8972ce404d..25dd1a4638 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc @@ -92,9 +92,9 @@ CALL apoc.vectordb.pinecone.get($host, 'test-index', [1,2], {}) .Example results [opts="header"] |=== -| score | metadata | id | vector | text | entity -| null | {city: "Berlin", foo: "one"} | null | null | null | null -| null | {city: "Berlin", foo: "two"} | null | null | null | null +| score | metadata | id | vector | text | entity | errors +| null | {city: "Berlin", foo: "one"} | null | null | null | null | null +| null | {city: "Berlin", foo: "two"} | null | null | null | null | null | ... 
|=== @@ -108,9 +108,9 @@ CALL apoc.vectordb.pinecone.get($host, 'test-index', ['1','2'], {allResults: tru .Example results [opts="header"] |=== -| score | metadata | id | vector | text | entity -| null | {city: "Berlin", foo: "one"} | 1 | [...] | null | null -| null | {city: "Berlin", foo: "two"} | 2 | [...] | null | null +| score | metadata | id | vector | text | entity | errors +| null | {city: "Berlin", foo: "one"} | 1 | [...] | null | null | null +| null | {city: "Berlin", foo: "two"} | 2 | [...] | null | null | null | ... |=== @@ -129,9 +129,9 @@ CALL apoc.vectordb.pinecone.query($host, .Example results [opts="header"] |=== -| score | metadata | id | vector | text | entity -| 1, | {city: "Berlin", foo: "one"} | 1 | [...] | null | null -| 0.1 | {city: "Berlin", foo: "two"} | 2 | [...] | null | null +| score | metadata | id | vector | text | entity | errors +| 1, | {city: "Berlin", foo: "one"} | 1 | [...] | null | null | null +| 0.1 | {city: "Berlin", foo: "two"} | 2 | [...] | null | null | null | ... 
|=== diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc index e3e684861d..fd42aa3fb4 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc @@ -75,9 +75,9 @@ CALL apoc.vectordb.qdrant.get($hostOrKey, 'test_collection', [1,2], { { + Map error = (Map) row.get(DEFAULT_ERRORS); + String message = (String) error.get("message"); + String expected = "invalid parameter"; + assertTrue("Actual error message is: " + message, + message.contains(expected) + ); + }); + } + @Test public void queryVectorsWithoutVectorResult() { testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", @@ -322,6 +338,30 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", ); } + @Test + public void getVectorsWithWrongVectorIdFormat() { + db.executeTransactionally("CREATE (:Test {readID: 'one'}), (:Test {readID: 'two'})"); + + Map conf = map(ALL_RESULTS_KEY, true, + FIELDS_KEY, FIELDS, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "readID", + METADATA_KEY, "foo")); + + testCall(db, "CALL apoc.vectordb.milvus.get($host, 'test_collection', ['wrong', 'id'], $conf)", + map("host", HOST, "conf", conf), + r -> { + Map error = (Map) r.get(DEFAULT_ERRORS); + String message = (String) error.get("message"); + String expected = "unable to cast"; + assertTrue("Actual error message is: " + message, + message.contains(expected) + ); + } + ); + } + @Test public void queryVectorsWithCreateNodeUsingExistingNode() { diff --git a/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java b/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java index aa21f6bc4c..640cf18da5 100644 --- a/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java +++ b/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java @@ -40,6 
+40,7 @@ import static apoc.vectordb.VectorDbTestUtil.getAuthHeader; import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.DEFAULT_ERRORS; import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; import static apoc.vectordb.VectorMappingConfig.*; @@ -48,6 +49,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; @@ -200,7 +202,21 @@ public void queryVectors() { assertLondonResult(row, ID_2, FALSE); assertNotNull(row.get("score")); assertNotNull(row.get("vector")); - }); + }); + } + + @Test + public void queryVectorsWithWrongVectorSize() { + testCall(db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9], null, 5, $conf)", + map("host", HOST, "conf", map(ALL_RESULTS_KEY, true, FIELDS_KEY, FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), + row -> { + List errors = (List) row.get(DEFAULT_ERRORS); + String message = (String) errors.get(0).get("message"); + String expected = "vector lengths don't match"; + assertTrue("Actual error message is: " + message, + message.contains(expected) + ); + }); } @Test @@ -377,7 +393,7 @@ public void getReadOnlyVectorsWithMapping() { METADATA_KEY, "foo") ); - testResult(db, "CALL apoc.vectordb.weaviate.get($host, 'TestCollection', [$id1, $id2], $conf) " + + testResult(db, "CALL apoc.vectordb.weaviate.get($host, 'TestCollection', ['$id1', $id2], $conf) " + "YIELD vector, id, metadata, node RETURN * ORDER BY id", MapUtil.map("host", HOST, "id1", ID_1, "id2", ID_2, "conf", conf), r -> 
assertReadOnlyProcWithMappingResults(r, "node") diff --git a/extended/src/main/java/apoc/vectordb/Milvus.java b/extended/src/main/java/apoc/vectordb/Milvus.java index 05ce11468e..ceeb99b18f 100644 --- a/extended/src/main/java/apoc/vectordb/Milvus.java +++ b/extended/src/main/java/apoc/vectordb/Milvus.java @@ -23,6 +23,7 @@ import static apoc.vectordb.VectorDb.getEmbeddingResultStream; import static apoc.vectordb.VectorDbHandler.Type.MILVUS; import static apoc.vectordb.VectorDbUtil.*; +import static apoc.vectordb.VectorEmbeddingConfig.DEFAULT_ERRORS; @Extended public class Milvus { @@ -180,6 +181,9 @@ public Stream queryAndUpdate(@Name("hostOrKey") String hostOrKe private Stream getMapStream(Map v) { var data = v.get("data"); + if (data == null) { + return Stream.of(Map.of(DEFAULT_ERRORS, v)); + } return ((List) data).stream() .map(i -> { diff --git a/extended/src/main/java/apoc/vectordb/VectorDb.java b/extended/src/main/java/apoc/vectordb/VectorDb.java index 6123e84966..9c9539222e 100644 --- a/extended/src/main/java/apoc/vectordb/VectorDb.java +++ b/extended/src/main/java/apoc/vectordb/VectorDb.java @@ -114,6 +114,12 @@ public static Stream getEmbeddingResultStream(VectorEmbeddingCo } public static EmbeddingResult getEmbeddingResult(VectorEmbeddingConfig conf, Transaction tx, boolean hasEmbedding, boolean hasMetadata, VectorMappingConfig mapping, Map m) { + Object errors = m.get(conf.getErrorsKey()); + if (errors != null) { + return new EmbeddingResult(null, null, null, null, null, null, null, + errors); + } + Object id = conf.isAllResults() ? m.get(conf.getIdKey()) : null; List embedding = hasEmbedding ? (List) m.get(conf.getVectorKey()) : null; Map metadata = hasMetadata ? (Map) m.get(conf.getMetadataKey()) : null; @@ -126,7 +132,8 @@ public static EmbeddingResult getEmbeddingResult(VectorEmbeddingConfig conf, Tra if (entity != null) entity = Util.rebind(tx, entity); return new EmbeddingResult(id, score, embedding, metadata, text, mapping.getNodeLabel() == null ? 
null : (Node) entity, - mapping.getNodeLabel() != null ? null : (Relationship) entity + mapping.getNodeLabel() != null ? null : (Relationship) entity, + errors ); } diff --git a/extended/src/main/java/apoc/vectordb/VectorDbUtil.java b/extended/src/main/java/apoc/vectordb/VectorDbUtil.java index 9a16cd1d12..9995e34341 100644 --- a/extended/src/main/java/apoc/vectordb/VectorDbUtil.java +++ b/extended/src/main/java/apoc/vectordb/VectorDbUtil.java @@ -39,7 +39,8 @@ public static void getEndpoint(Map config, String endpoint) { public record EmbeddingResult( Object id, Double score, List vector, Map metadata, String text, Node node, - Relationship rel) {} + Relationship rel, + Object errors) {} public static Map getCommonVectorDbInfo( String hostOrKey, String collection, Map configuration, String templateUrl, VectorDbHandler handler) { diff --git a/extended/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java b/extended/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java index 61e5264a8c..fb8957fa67 100644 --- a/extended/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java +++ b/extended/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java @@ -16,6 +16,7 @@ public class VectorEmbeddingConfig { public static final String DEFAULT_ID = "id"; public static final String DEFAULT_TEXT = "text"; + public static final String DEFAULT_ERRORS = "errors"; public static final String DEFAULT_VECTOR = "vector"; public static final String DEFAULT_METADATA = "metadata"; public static final String DEFAULT_SCORE = "score"; @@ -27,6 +28,7 @@ public class VectorEmbeddingConfig { private final String vectorKey; private final String metadataKey; private final String scoreKey; + private final String errorsKey; private final boolean allResults; private final boolean metaAsSubKey; @@ -40,6 +42,7 @@ public VectorEmbeddingConfig(Map config) { this.scoreKey = (String) config.getOrDefault(SCORE_KEY, DEFAULT_SCORE); this.idKey = (String) config.getOrDefault(ID_KEY, DEFAULT_ID); this.textKey = 
(String) config.getOrDefault(TEXT_KEY, DEFAULT_TEXT); + this.errorsKey = (String) config.getOrDefault(TEXT_KEY, DEFAULT_ERRORS); this.allResults = Util.toBoolean(config.get(ALL_RESULTS_KEY)); this.mapping = new VectorMappingConfig((Map) config.getOrDefault(MAPPING_KEY, Map.of())); @@ -68,6 +71,10 @@ public String getTextKey() { return textKey; } + public String getErrorsKey() { + return errorsKey; + } + public boolean isAllResults() { return allResults; } diff --git a/extended/src/main/java/apoc/vectordb/Weaviate.java b/extended/src/main/java/apoc/vectordb/Weaviate.java index 7653c32e46..6c5b26fb61 100644 --- a/extended/src/main/java/apoc/vectordb/Weaviate.java +++ b/extended/src/main/java/apoc/vectordb/Weaviate.java @@ -4,6 +4,7 @@ import apoc.ml.RestAPIConfig; import apoc.result.ListResult; import apoc.result.MapResult; +import apoc.util.CollectionUtils; import apoc.util.UrlResolver; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.Transaction; @@ -27,6 +28,7 @@ import static apoc.vectordb.VectorDb.getEmbeddingResultStream; import static apoc.vectordb.VectorDbHandler.Type.WEAVIATE; import static apoc.vectordb.VectorDbUtil.*; +import static apoc.vectordb.VectorEmbeddingConfig.DEFAULT_ERRORS; @Extended public class Weaviate { @@ -222,7 +224,14 @@ private Stream queryCommon(String hostOrKey, String collection, return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> { - Object getValue = ((Map) v).get("data").get("Get"); + Map mapResult = (Map) v; + List errors = (List) mapResult.get("errors"); + if (CollectionUtils.isNotEmpty(errors)) { + Map map = new HashMap<>(); + map.put(DEFAULT_ERRORS, errors); + return Stream.of(map); + } + Object getValue = mapResult.get("data").get("Get"); Object collectionValue = ((Map) getValue).get(collection); return ((List) collectionValue).stream() .map(i -> {