From f97671af4ec16c69696c7edbba42e2b74f8f5d59 Mon Sep 17 00:00:00 2001 From: vga91 Date: Wed, 29 May 2024 10:45:41 +0200 Subject: [PATCH] Fixes #4090: The apoc.vectordb.*.get/query procedures should search for nodes/relationships with mapping config --- .../database-integration/vectordb/chroma.adoc | 39 ++++++++++----- .../database-integration/vectordb/milvus.adoc | 21 +++++++- .../vectordb/pinecone.adoc | 21 +++++++- .../database-integration/vectordb/qdrant.adoc | 20 +++++++- .../vectordb/weaviate.adoc | 22 ++++++++- .../test/java/apoc/vectordb/ChromaDbTest.java | 47 ++++++++++-------- .../test/java/apoc/vectordb/MilvusTest.java | 47 +++++++++++++----- .../test/java/apoc/vectordb/QdrantTest.java | 48 ++++++++++-------- .../test/java/apoc/vectordb/WeaviateTest.java | 47 +++++++++++------- .../src/main/java/apoc/vectordb/ChromaDb.java | 23 ++++----- .../src/main/java/apoc/vectordb/Milvus.java | 27 +++++----- .../src/main/java/apoc/vectordb/Pinecone.java | 25 ++++------ .../src/main/java/apoc/vectordb/Qdrant.java | 24 ++++----- .../src/main/java/apoc/vectordb/VectorDb.java | 19 +++---- .../main/java/apoc/vectordb/VectorDbUtil.java | 7 --- .../apoc/vectordb/VectorMappingConfig.java | 9 ++++ .../src/main/java/apoc/vectordb/Weaviate.java | 26 +++++----- .../test/java/apoc/vectordb/PineconeTest.java | 49 +++++++++++++------ .../java/apoc/vectordb/VectorDbTestUtil.java | 18 +++++++ 19 files changed, 347 insertions(+), 192 deletions(-) diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc index b37a2e3a38..5b2f3842cd 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc @@ -119,17 +119,6 @@ CALL apoc.vectordb.chroma.queryAndUpdate($host, | ... |=== -[NOTE] -==== -We can use mapping with `apoc.vectordb.chroma.getAndUpdate` procedure as well -==== - -[NOTE] -==== -To optimize performances, we can choose what to `YIELD` with the apoc.vectordb.chroma.query and the `apoc.vectordb.chroma.get` procedures. -For example, by executing a `CALL apoc.vectordb.chroma.query(...) YIELD metadata, score, id`, the RestAPI request will have an {"include": ["metadatas", "documents", "distances"]}, -so that we do not return the other values that we do not need. -==== In the same way as other procedures, we can define a mapping, to fetch the associated nodes and relationships and optionally create them, @@ -151,7 +140,35 @@ CALL apoc.vectordb.chroma.query($host, '', }) ---- +We can also use mapping for `apoc.vectordb.chroma.query` procedure, to search for nodes/rels fitting label/type and metadataKey, without making updates. +For example, with the previous relationships, we can execute the following procedure, which just return the relationships in the column `rel`: +[source,cypher] +---- +CALL apoc.vectordb.chroma.query($host, '', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingKey: "vect", + nodeLabel: "Test", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + +[NOTE] +==== +We can use mapping with `apoc.vectordb.chroma.get*` procedures as well +==== + +[NOTE] +==== +To optimize performances, we can choose what to `YIELD` with the apoc.vectordb.chroma.query and the `apoc.vectordb.chroma.get` procedures. +For example, by executing a `CALL apoc.vectordb.chroma.query(...) YIELD metadata, score, id`, the RestAPI request will have an {"include": ["metadatas", "documents", "distances"]}, +so that we do not return the other values that we do not need. +==== .Delete vectors (it leverages https://docs.trychroma.com/usage-guide#deleting-data-from-a-collection[this API]) [source,cypher] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc index 8f1ee8198a..4ff2ad531e 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc @@ -189,9 +189,28 @@ which populates the two relationships as: `()-[:TEST {myId: 'one', city: 'Berlin and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`, which will be returned in the `entity` column result. + +We can also use mapping for `apoc.vectordb.milvus.query` procedure, to search for nodes/rels fitting label/type and metadataKey, without making updates. +For example, with the previous relationships, we can execute the following procedure, which just return the relationships in the column `rel`: + +[source,cypher] +---- +CALL apoc.vectordb.milvus.query('http://localhost:19531', 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingKey: "vect", + relType: "TEST", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + [NOTE] ==== -We can use mapping with `apoc.vectordb.milvus.getAndUpdate` procedure as well +We can use mapping with `apoc.vectordb.milvus.get*` procedures as well ==== [NOTE] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc index 4bf59e6322..09e815713f 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc @@ -203,9 +203,28 @@ which populates the two relationships as: `()-[:TEST {myId: 'one', city: 'Berlin and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`, which will be returned in the `entity` column result. + +We can also use mapping for `apoc.vectordb.pinecone.query` procedure, to search for nodes/rels fitting label/type and metadataKey, without making updates. +For example, with the previous relationships, we can execute the following procedure, which just return the relationships in the column `rel`: + +[source,cypher] +---- +CALL apoc.vectordb.pinecone.query($host, 'test-index', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingKey: "vect", + relType: "TEST", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + [NOTE] ==== -We can use mapping with `apoc.vectordb.pinecone.getAndUpdate` procedure as well +We can use mapping with `apoc.vectordb.pinecone.get*` procedures as well ==== [NOTE] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc index 1f6cce4b97..ba2cc6dceb 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc @@ -191,9 +191,27 @@ which populates the two relationships as: `()-[:TEST {myId: 'one', city: 'Berlin and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`, which will be returned in the `entity` column result. + +We can also use mapping for `apoc.vectordb.qdrant.query` procedure, to search for nodes/rels fitting label/type and metadataKey, without making updates. +For example, with the previous relationships, we can execute the following procedure, which just return the relationships in the column `rel`: + +[source,cypher] +---- +CALL apoc.vectordb.qdrant.queryAndUpdate($hostOrKey, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + relType: "TEST", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + [NOTE] ==== -We can use mapping with `apoc.vectordb.qdrant.getAndUpdate` procedure as well +We can use mapping with `apoc.vectordb.qdrant.get*` procedures as well ==== [NOTE] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc index 3ab4540403..f438f907a2 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc @@ -205,9 +205,29 @@ and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`, which will be returned in the `entity` column result. +We can also use mapping for `apoc.vectordb.weaviate.query` procedure, to search for nodes/rels fitting label/type and metadataKey, without making updates. +For example, with the previous relationships, we can execute the following procedure, which just return the relationships in the column `rel`: + +[source,cypher] +---- +CALL apoc.vectordb.weaviate.query($host, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { fields: ["city", "foo"], + mapping: { + relType: "TEST", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + + + [NOTE] ==== -We can use mapping with `apoc.vectordb.weaviate.getAndUpdate` procedure as well +We can use mapping with `apoc.vectordb.weaviate.get*` procedures as well ==== [NOTE] diff --git a/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java b/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java index 3b0b136701..ea642b1125 100644 --- a/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java +++ b/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java @@ -18,6 +18,7 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicReference; +import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.MapUtil.map; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; @@ -25,6 +26,7 @@ import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; import static apoc.vectordb.VectorDbTestUtil.EntityType.*; @@ -294,19 +296,22 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", assertNodesCreated(db); } + @Test public void getReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Test {readID: 'one'}), (:Test {readID: 'two'})"); + Map conf = map(ALL_RESULTS_KEY, true, - MAPPING_KEY, map(EMBEDDING_KEY, "vect")); - - try { - testCall(db, "CALL apoc.vectordb.chroma.get($host, $collection, [1, 2], $conf)", - map("host", HOST, "collection", COLL_ID.get(), "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + MAPPING_KEY, map(NODE_LABEL, "Test", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, "CALL apoc.vectordb.chroma.get($host, $collection, ['1', '2'], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "collection", COLL_ID.get(), "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "node") + ); } @Test @@ -338,17 +343,19 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", @Test public void queryReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Start)-[:TEST {readID: 'one'}]->(:End), (:Start)-[:TEST {readID: 'two'}]->(:End)"); + Map conf = map(ALL_RESULTS_KEY, true, - MAPPING_KEY, map(EMBEDDING_KEY, "vect")); - - try { - testCall(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", - map("host", HOST, "collection", COLL_ID.get(), "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + MAPPING_KEY, map( + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "collection", COLL_ID.get(), "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "rel") + ); } @Test diff --git a/extended-it/src/test/java/apoc/vectordb/MilvusTest.java b/extended-it/src/test/java/apoc/vectordb/MilvusTest.java index 513b6167ec..5a78c06a50 100644 --- a/extended-it/src/test/java/apoc/vectordb/MilvusTest.java +++ b/extended-it/src/test/java/apoc/vectordb/MilvusTest.java @@ -2,7 +2,6 @@ import apoc.util.TestUtil; import apoc.util.Util; -import org.assertj.core.api.Assertions; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -27,9 +26,9 @@ import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; -import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; @@ -297,6 +296,24 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", assertNodesCreated(db); } + @Test + public void getReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Test {readID: 'one'}), (:Test {readID: 'two'})"); + + Map conf = map(ALL_RESULTS_KEY, true, + FIELDS_KEY, FIELDS, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "readID", + METADATA_KEY, "foo")); + + testResult(db, "CALL apoc.vectordb.milvus.get($host, 'test_collection', [1, 2], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "node") + ); + } + @Test public void queryVectorsWithCreateNodeUsingExistingNode() { @@ -336,7 +353,8 @@ public void queryVectorsWithCreateRel() { MAPPING_KEY, map(EMBEDDING_KEY, "vect", REL_TYPE, "TEST", ENTITY_KEY, "myId", - METADATA_KEY, "foo")); + METADATA_KEY, "foo") + ); testResult(db, "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", map("host", HOST, "conf", conf), r -> { @@ -356,17 +374,20 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", @Test public void queryReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Start)-[:TEST {readID: 'one'}]->(:End), (:Start)-[:TEST {readID: 'two'}]->(:End)"); + Map conf = map(ALL_RESULTS_KEY, true, - MAPPING_KEY, map(EMBEDDING_KEY, "vect")); - - try { - testCall(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", - map("host", HOST, "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + FIELDS_KEY, FIELDS, + MAPPING_KEY, map( + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", + map("host", HOST, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "rel") + ); } @Test diff --git a/extended-it/src/test/java/apoc/vectordb/QdrantTest.java b/extended-it/src/test/java/apoc/vectordb/QdrantTest.java index 9f15093526..6eb3f8c948 100644 --- a/extended-it/src/test/java/apoc/vectordb/QdrantTest.java +++ b/extended-it/src/test/java/apoc/vectordb/QdrantTest.java @@ -2,7 +2,6 @@ import apoc.util.TestUtil; import apoc.util.Util; -import org.assertj.core.api.Assertions; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -10,7 +9,9 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.neo4j.dbms.api.DatabaseManagementService; +import org.neo4j.graphdb.Entity; import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.Result; import org.neo4j.test.TestDatabaseManagementServiceBuilder; import org.testcontainers.qdrant.QdrantContainer; @@ -27,15 +28,16 @@ import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; import static apoc.vectordb.VectorDbTestUtil.getAuthHeader; -import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; import static apoc.vectordb.VectorMappingConfig.*; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; @@ -331,17 +333,20 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", @Test public void getReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Test {readID: 'one'}), (:Test {readID: 'two'})"); + Map conf = map(ALL_RESULTS_KEY, true, - MAPPING_KEY, map(EMBEDDING_KEY, "vect")); + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, map(NODE_LABEL, "Test", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); - try { - testCall(db, "CALL apoc.vectordb.qdrant.get($host, 'test_collection', [1, 2], $conf)", - map("host", HOST, "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + testResult(db, "CALL apoc.vectordb.qdrant.get($host, 'test_collection', [1, 2], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "node") + ); } @Test @@ -405,17 +410,20 @@ MAPPING_KEY, map( @Test public void queryReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Start)-[:TEST {readID: 'one'}]->(:End), (:Start)-[:TEST {readID: 'two'}]->(:End)"); + Map conf = map(ALL_RESULTS_KEY, true, - MAPPING_KEY, map(EMBEDDING_KEY, "vect")); + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, map( + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); - try { - testCall(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", - map("host", HOST, "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "rel") + ); } @Test diff --git a/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java b/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java index ff5a31d751..4f924f4bd1 100644 --- a/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java +++ b/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java @@ -353,24 +353,28 @@ public void getVectorsWithCreateNodeUsingExistingNode() { assertNodesCreated(db); } + @Test public void getReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Test {readID: 'one'}), (:Test {readID: 'two'})"); + Map conf = MapUtil.map(ALL_RESULTS_KEY, true, - MAPPING_KEY, MapUtil.map(EMBEDDING_KEY, "vect")); + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, MapUtil.map( + NODE_LABEL, "Test", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); - try { - testCall(db, "CALL apoc.vectordb.weaviate.get($host, 'TestCollection', [1, 2], $conf)", - map("host", HOST, "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + testResult(db, "CALL apoc.vectordb.weaviate.get($host, 'TestCollection', [$id1, $id2], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + MapUtil.map("host", HOST, "id1", ID_1, "id2", ID_2, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "node") + ); } @Test public void queryVectorsWithCreateRel() { - db.executeTransactionally("CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)"); Map conf = map(ALL_RESULTS_KEY, true, @@ -400,17 +404,22 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", @Test public void queryReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Start)-[:TEST {readID: 'one'}]->(:End), (:Start)-[:TEST {readID: 'two'}]->(:End)"); + Map conf = MapUtil.map(ALL_RESULTS_KEY, true, - MAPPING_KEY, MapUtil.map(EMBEDDING_KEY, "vect")); + FIELDS_KEY, FIELDS, + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, MapUtil.map( + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); - try { - testCall(db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", - MapUtil.map("host", HOST, "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + testResult(db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + " YIELD score, vector, id, metadata, rel RETURN * ORDER BY id", + MapUtil.map("host", HOST, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "rel") + ); } @Test diff --git a/extended/src/main/java/apoc/vectordb/ChromaDb.java b/extended/src/main/java/apoc/vectordb/ChromaDb.java index 11d07bc2cd..7cf38f64e0 100644 --- a/extended/src/main/java/apoc/vectordb/ChromaDb.java +++ b/extended/src/main/java/apoc/vectordb/ChromaDb.java @@ -129,7 +129,7 @@ public Stream get(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, true); + return getCommon(hostOrKey, collection, ids, configuration, false); } @Procedure(value = "apoc.vectordb.chroma.getAndUpdate", mode = Mode.WRITE) @@ -138,18 +138,16 @@ public Stream getAndUpdate(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, false); + return getCommon(hostOrKey, collection, ids, configuration, true); } - private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) throws Exception { + private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean updateMode) throws Exception { String url = "%s/api/v1/collections/%s/get"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.chroma.getAndUpdate"); - } - VectorEmbeddingConfig apiConfig = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); + apiConfig.getMapping().setUpdateMode(updateMode); + return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, tx, v -> listToMap((Map) v).stream()); } @@ -162,7 +160,7 @@ public Stream query(@Name("hostOrKey") String hostOrKey, @Name(value = "filter", defaultValue = "{}") Map filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); } @Procedure(value = "apoc.vectordb.chroma.queryAndUpdate", mode = Mode.WRITE) @@ -173,18 +171,15 @@ public Stream queryAndUpdate(@Name("hostOrKey") String hostOrKe @Name(value = "filter", defaultValue = "{}") Map filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); } - private Stream queryCommon(String hostOrKey, String collection, List vector, Map filter, long limit, Map configuration, boolean readOnly) throws Exception { + private Stream queryCommon(String hostOrKey, String collection, List vector, Map filter, long limit, Map configuration, boolean updateMode) throws Exception { String url = "%s/api/v1/collections/%s/query"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.chroma.queryAndUpdate"); - } VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + conf.getMapping().setUpdateMode(updateMode); return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> listOfListsToMap((Map) v).stream()); } diff --git a/extended/src/main/java/apoc/vectordb/Milvus.java b/extended/src/main/java/apoc/vectordb/Milvus.java index c97d45b4e6..fdcf62144b 100644 --- a/extended/src/main/java/apoc/vectordb/Milvus.java +++ b/extended/src/main/java/apoc/vectordb/Milvus.java @@ -133,7 +133,7 @@ public Stream get(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, true); + return getCommon(hostOrKey, collection, ids, configuration, false); } @Procedure(value = "apoc.vectordb.milvus.getAndUpdate", mode = Mode.WRITE) @@ -142,18 +142,15 @@ public Stream getAndUpdate(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, false); + return getCommon(hostOrKey, collection, ids, configuration, true); } - private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) throws Exception { + private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean updateMode) throws Exception { String url = "%s/entities/get"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.milvus.getAndUpdate"); - } - VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); + conf.getMapping().setUpdateMode(updateMode); return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> getMapStream((Map) v)); } @@ -167,7 +164,7 @@ public Stream query(@Name("hostOrKey") String hostOrKey, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); } @Procedure(value = "apoc.vectordb.milvus.queryAndUpdate", mode = Mode.WRITE) @@ -179,7 +176,7 @@ public Stream queryAndUpdate(@Name("hostOrKey") String hostOrKe @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); } private Stream getMapStream(Map v) { @@ -198,16 +195,14 @@ private Stream getMapStream(Map v) { }); } - private Stream queryCommon(String hostOrKey, String collection, List vector, Object filter, long limit, Map configuration, boolean readOnly) throws Exception { + private Stream queryCommon(String hostOrKey, String collection, List vector, Object filter, long limit, Map configuration, boolean updateMode) throws Exception { String url = "%s/entities/search"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.milvus.queryAndUpdate"); - } - VectorEmbeddingConfig apiConfig = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); - return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, tx, + VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + conf.getMapping().setUpdateMode(updateMode); + + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> getMapStream((Map) v)); } diff --git a/extended/src/main/java/apoc/vectordb/Pinecone.java b/extended/src/main/java/apoc/vectordb/Pinecone.java index 036261143f..06461e24e5 100644 --- a/extended/src/main/java/apoc/vectordb/Pinecone.java +++ b/extended/src/main/java/apoc/vectordb/Pinecone.java @@ -22,7 +22,6 @@ import static apoc.vectordb.VectorDb.executeRequest; import static apoc.vectordb.VectorDb.getEmbeddingResultStream; import static apoc.vectordb.VectorDbHandler.Type.PINECONE; -import static apoc.vectordb.VectorDbUtil.checkMappingConf; import static apoc.vectordb.VectorDbUtil.getCommonVectorDbInfo; @Extended @@ -133,7 +132,7 @@ public Stream get(@Name("hostOrKey") String hostOr @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, true); + return getCommon(hostOrKey, collection, ids, configuration, false); } @Procedure(value = "apoc.vectordb.pinecone.getAndUpdate", mode = Mode.WRITE) @@ -142,18 +141,16 @@ public Stream getAndUpdate(@Name("hostOrKey") Stri @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, false); + return getCommon(hostOrKey, collection, ids, configuration, true); } - private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) throws Exception { + private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean updateMode) throws Exception { String url = "%s/vectors/fetch"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.pinecone.getAndUpdate"); - } VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); + conf.getMapping().setUpdateMode(updateMode); + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> { Object vectors = ((Map) v).get("vectors"); @@ -170,7 +167,7 @@ public Stream query(@Name("hostOrKey") String host @Name(value = "filter", defaultValue = "{}") Map filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); } @Procedure(value = "apoc.vectordb.pinecone.queryAndUpdate", mode = Mode.WRITE) @@ -181,18 +178,16 @@ public Stream queryAndUpdate(@Name("hostOrKey") St @Name(value = "filter", defaultValue = "{}") Map filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); } - private Stream queryCommon(String hostOrKey, String collection, List vector, Map filter, long limit, Map configuration, boolean readOnly) throws Exception { + private Stream queryCommon(String hostOrKey, String collection, List vector, Map filter, long limit, Map configuration, boolean updateMode) throws Exception { String url = "%s/query"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.pinecone.queryAndUpdate"); - } - VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + conf.getMapping().setUpdateMode(updateMode); + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> { Map map = (Map) v; diff --git a/extended/src/main/java/apoc/vectordb/Qdrant.java b/extended/src/main/java/apoc/vectordb/Qdrant.java index da7f42903d..a245f47721 100644 --- a/extended/src/main/java/apoc/vectordb/Qdrant.java +++ b/extended/src/main/java/apoc/vectordb/Qdrant.java @@ -131,7 +131,7 @@ public Stream get(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, true); + return getCommon(hostOrKey, collection, ids, configuration, false); } @Procedure(value = "apoc.vectordb.qdrant.getAndUpdate", mode = Mode.WRITE) @@ -140,18 +140,16 @@ public Stream getAndUpdate(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, false); + return getCommon(hostOrKey, collection, ids, configuration, true); } - private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) throws Exception { + private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean updateMode) throws Exception { String url = "%s/collections/%s/points"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.qdrant.getAndUpdate"); - } - VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); + conf.getMapping().setUpdateMode(updateMode); + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx); } @@ -163,7 +161,7 @@ public Stream query(@Name("hostOrKey") String hostOrKey, @Name(value = "filter", defaultValue = "{}") Map filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); } @Procedure(value = "apoc.vectordb.qdrant.queryAndUpdate", mode = Mode.WRITE) @@ -174,18 +172,16 @@ public Stream queryAndUpdate(@Name("hostOrKey") String hostOrKe @Name(value = "filter", defaultValue = "{}") Map filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); } - private Stream queryCommon(String hostOrKey, String collection, List vector, Map filter, long limit, Map configuration, boolean readOnly) throws Exception { + private Stream queryCommon(String hostOrKey, String collection, List vector, Map filter, long limit, Map configuration, boolean updateMode) throws Exception { String url = "%s/collections/%s/points/search"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.qdrant.queryAndUpdate"); - } VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + conf.getMapping().setUpdateMode(updateMode); + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx); } diff --git a/extended/src/main/java/apoc/vectordb/VectorDb.java b/extended/src/main/java/apoc/vectordb/VectorDb.java index dabf699c5e..a40de9dbfb 100644 --- a/extended/src/main/java/apoc/vectordb/VectorDb.java +++ b/extended/src/main/java/apoc/vectordb/VectorDb.java @@ -152,15 +152,16 @@ private static Entity handleMappingNode(Transaction transaction, VectorMappingCo Node node; Object propValue = metaProps.get(mapping.getMetadataKey()); node = transaction.findNode(Label.label(mapping.getNodeLabel()), mapping.getEntityKey(), propValue); - if (node == null && mapping.isCreate()) { - node = transaction.createNode(Label.label(mapping.getNodeLabel())); - node.setProperty(mapping.getEntityKey(), propValue); + if (mapping.isUpdateMode()) { + if (node == null && mapping.isCreate()) { + node = transaction.createNode(Label.label(mapping.getNodeLabel())); + node.setProperty(mapping.getEntityKey(), propValue); + } + if (node != null) { + setProperties(node, metaProps); + setVectorProp(mapping, embedding, node); + } } - if (node != null) { - setProperties(node, metaProps); - setVectorProp(mapping, embedding, node); - } - return node; } catch (MultipleFoundException e) { throw new RuntimeException("Multiple nodes found"); @@ -173,7 +174,7 @@ private static Entity handleMappingRel(Transaction transaction, VectorMappingCon Relationship rel; Object propValue = metaProps.get(mapping.getMetadataKey()); rel = transaction.findRelationship(RelationshipType.withName(mapping.getRelType()), mapping.getEntityKey(), propValue); - if (rel != null) { + if (mapping.isUpdateMode() && rel != null) { setProperties(rel, metaProps); setVectorProp(mapping, embedding, rel); } diff --git a/extended/src/main/java/apoc/vectordb/VectorDbUtil.java b/extended/src/main/java/apoc/vectordb/VectorDbUtil.java index cea4af7117..1b455d2952 100644 --- a/extended/src/main/java/apoc/vectordb/VectorDbUtil.java +++ b/extended/src/main/java/apoc/vectordb/VectorDbUtil.java @@ -80,11 +80,4 @@ private static String getUrl(String hostOrKey, VectorDbHandler handler, Map configuration, String procName) { - if (configuration.containsKey(MAPPING_KEY)) { - throw new RuntimeException(ERROR_READONLY_MAPPING + "\n" + - "Try the equivalent procedure, which is the " + procName); - } - } - } diff --git a/extended/src/main/java/apoc/vectordb/VectorMappingConfig.java b/extended/src/main/java/apoc/vectordb/VectorMappingConfig.java index 2b91d049c0..6f996c07d2 100644 --- a/extended/src/main/java/apoc/vectordb/VectorMappingConfig.java +++ b/extended/src/main/java/apoc/vectordb/VectorMappingConfig.java @@ -23,6 +23,7 @@ public class VectorMappingConfig { private final String similarity; private final boolean create; + private boolean updateMode = false; public VectorMappingConfig(Map mapping) { if (mapping == null) { @@ -67,4 +68,12 @@ public boolean isCreate() { public String getSimilarity() { return similarity; } + + public boolean isUpdateMode() { + return updateMode; + } + + public void setUpdateMode(boolean updateMode) { + this.updateMode = updateMode; + } } diff --git a/extended/src/main/java/apoc/vectordb/Weaviate.java b/extended/src/main/java/apoc/vectordb/Weaviate.java index 7ea0b19463..99a6a60efa 100644 --- a/extended/src/main/java/apoc/vectordb/Weaviate.java +++ b/extended/src/main/java/apoc/vectordb/Weaviate.java @@ -144,7 +144,7 @@ public Stream getAndUpdate(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) { - return getCommon(hostOrKey, collection, ids, configuration, false); + return getCommon(hostOrKey, collection, ids, configuration, true); } @Procedure(value = "apoc.vectordb.weaviate.get") @@ -153,15 +153,12 @@ public Stream get(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) { - return getCommon(hostOrKey, collection, ids, configuration, true); + return getCommon(hostOrKey, collection, ids, configuration, false); } - private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) { + private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean updateMode) { Map config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema"); - - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.chroma.getAndUpdate"); - } + /** * TODO: we put method: null as a workaround, it should be "GET": https://weaviate.io/developers/weaviate/api/rest#tag/objects/get/objects/{className}/{id} @@ -172,6 +169,8 @@ private Stream getCommon(String hostOrKey, String collection, L List fields = procedureCallContext.outputFields().toList(); VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); + conf.getMapping().setUpdateMode(updateMode); + boolean hasEmbedding = fields.contains("vector") && conf.isAllResults(); boolean hasMetadata = fields.contains("metadata"); VectorMappingConfig mapping = conf.getMapping(); @@ -200,8 +199,7 @@ public Stream query(@Name("hostOrKey") String hostOrKey, @Name(value = "filter", defaultValue = "null") Object filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - checkMappingConf(configuration, "apoc.vectordb.weaviate.queryAndUpdate"); - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); } @@ -213,17 +211,15 @@ public Stream queryAndUpdate(@Name("hostOrKey") String hostOrKe @Name(value = "filter", defaultValue = "null") Object filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); } - private Stream queryCommon(String hostOrKey, String collection, List vector, Object filter, long limit, Map configuration, boolean readOnly) throws Exception { + private Stream queryCommon(String hostOrKey, String collection, List vector, Object filter, long limit, Map configuration, boolean updateMode) throws Exception { Map config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/graphql"); - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.weaviate.queryAndUpdate"); - } - VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + conf.getMapping().setUpdateMode(updateMode); + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> { Object getValue = ((Map) v).get("data").get("Get"); diff --git a/extended/src/test/java/apoc/vectordb/PineconeTest.java b/extended/src/test/java/apoc/vectordb/PineconeTest.java index fc8e2b4934..d7a4c895be 100644 --- a/extended/src/test/java/apoc/vectordb/PineconeTest.java +++ b/extended/src/test/java/apoc/vectordb/PineconeTest.java @@ -3,7 +3,6 @@ import apoc.util.MapUtil; import apoc.util.TestUtil; import apoc.util.Util; -import org.assertj.core.api.Assertions; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -29,9 +28,9 @@ import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; -import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; import static apoc.vectordb.VectorMappingConfig.*; @@ -39,7 +38,6 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; -import static org.junit.Assert.fail; import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; @@ -363,20 +361,23 @@ public void getVectorsWithCreateNodeUsingExistingNode() { assertNodesCreated(db); } - + @Test public void getReadOnlyVectorsWithMapping() { - Map conf = MapUtil.map(ALL_RESULTS_KEY, true, - MAPPING_KEY, MapUtil.map(EMBEDDING_KEY, "vect")); - - try { - testCall(db, "CALL apoc.vectordb.pinecone.get($host, 'TestCollection', [1, 2], $conf)", - Util.map("host", HOST, "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + db.executeTransactionally("CREATE (:Test {readID: 'one'}), (:Test {readID: 'two'})"); + + Map conf = map(ALL_RESULTS_KEY, true, + HEADERS_KEY, ADMIN_AUTHORIZATION, + MAPPING_KEY, map( + NODE_LABEL, "Test", + ENTITY_KEY, "readID", + METADATA_KEY, "foo")); + + testResult(db, "CALL apoc.vectordb.pinecone.get($host, 'TestCollection', [1, 2], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + Util.map("host", HOST, "coll", collName, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "node") + ); } @Test @@ -407,6 +408,24 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", assertRelsCreated(db); } + @Test + public void queryReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Start)-[:TEST {readID: 'one'}]->(:End), (:Start)-[:TEST {readID: 'two'}]->(:End)"); + + Map conf = map(ALL_RESULTS_KEY, true, + HEADERS_KEY, ADMIN_AUTHORIZATION, + MAPPING_KEY, map( + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, "CALL apoc.vectordb.pinecone.query($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "coll", collName, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "rel") + ); + } + @Test public void queryVectorsWithSystemDbStorage() { String keyConfig = "pinecone-config-foo"; diff --git a/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java b/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java index 0d7247b66e..ab68f98d9f 100644 --- a/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java +++ b/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java @@ -1,5 +1,6 @@ package apoc.vectordb; +import apoc.util.MapUtil; import org.neo4j.graphdb.Entity; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.ResourceIterator; @@ -11,6 +12,7 @@ import static apoc.util.Util.map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; public class VectorDbTestUtil { @@ -82,4 +84,20 @@ private static void assertBerlinProperties(Map props) { public static Map getAuthHeader(String key) { return map("Authorization", "Bearer " + key); } + + public static void assertReadOnlyProcWithMappingResults(Result r, String node) { + Map row = r.next(); + Map props = ((Entity) row.get(node)).getAllProperties(); + assertEquals(MapUtil.map("readID", "one"), props); + assertNotNull(row.get("vector")); + assertNotNull(row.get("id")); + + row = r.next(); + props = ((Entity) row.get(node)).getAllProperties(); + assertEquals(MapUtil.map("readID", "two"), props); + assertNotNull(row.get("vector")); + assertNotNull(row.get("id")); + + assertFalse(r.hasNext()); + } }