From ba47c81c82bffbcbe23096096b7e35851134ac8b Mon Sep 17 00:00:00 2001 From: Giuseppe Villani Date: Wed, 29 May 2024 18:57:34 +0200 Subject: [PATCH] Fixes #4090: The apoc.vectordb.*.get/query procedures should search for nodes/relationships with mapping config (#4092) * Fixes #4090: The apoc.vectordb.*.get/query procedures should search for nodes/relationships with mapping config * changed create mode from boolean to enum --- .../database-integration/vectordb/chroma.adoc | 47 +++++++++++----- .../database-integration/vectordb/milvus.adoc | 27 ++++++++-- .../vectordb/pinecone.adoc | 27 ++++++++-- .../database-integration/vectordb/qdrant.adoc | 26 +++++++-- .../vectordb/weaviate.adoc | 28 ++++++++-- .../test/java/apoc/vectordb/ChromaDbTest.java | 52 +++++++++--------- .../test/java/apoc/vectordb/MilvusTest.java | 54 +++++++++++++------ .../test/java/apoc/vectordb/QdrantTest.java | 49 +++++++++-------- .../test/java/apoc/vectordb/WeaviateTest.java | 54 +++++++++++-------- .../src/main/java/apoc/vectordb/ChromaDb.java | 23 ++++---- .../src/main/java/apoc/vectordb/Milvus.java | 31 +++++------ .../src/main/java/apoc/vectordb/Pinecone.java | 26 ++++----- .../src/main/java/apoc/vectordb/Qdrant.java | 24 ++++----- .../src/main/java/apoc/vectordb/VectorDb.java | 39 ++++++++++---- .../main/java/apoc/vectordb/VectorDbUtil.java | 11 ++-- .../apoc/vectordb/VectorMappingConfig.java | 21 ++++---- .../src/main/java/apoc/vectordb/Weaviate.java | 26 ++++----- .../test/java/apoc/vectordb/PineconeTest.java | 53 ++++++++++++------ .../java/apoc/vectordb/VectorDbTestUtil.java | 18 +++++++ 19 files changed, 408 insertions(+), 228 deletions(-) diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc index 46654f92f1..f773e83976 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc @@ -119,17 +119,6 @@ CALL apoc.vectordb.chroma.queryAndUpdate($host, | ... |=== -[NOTE] -==== -We can use mapping with `apoc.vectordb.chroma.getAndUpdate` procedure as well -==== - -[NOTE] -==== -To optimize performances, we can choose what to `YIELD` with the apoc.vectordb.chroma.query and the `apoc.vectordb.chroma.get` procedures. -For example, by executing a `CALL apoc.vectordb.chroma.query(...) YIELD metadata, score, id`, the RestAPI request will have an {"include": ["metadatas", "documents", "distances"]}, -so that we do not return the other values that we do not need. -==== We can define a mapping, to fetch the associated nodes and relationships and optionally create them, by leveraging the vector metadata. @@ -157,7 +146,7 @@ which will be returned in the `entity` column result. -Or else, we can create a node if not exists, via `create: true`: +We can also set the mapping configuration `mode` to `CREATE_IF_MISSING` (which creates nodes if not exist), `READ_ONLY` (to search for nodes/rels, without making updates) or `UPDATE_EXISTING` (default behavior): [source,cypher] ---- @@ -166,7 +155,7 @@ CALL apoc.vectordb.chroma.queryAndUpdate($host, '', {}, 5, { mapping: { - create: true, + mode: "CREATE_IF_MISSING", embeddingKey: "vect", nodeLabel: "Test", entityKey: "myId", @@ -200,6 +189,38 @@ and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`, which will be returned in the `entity` column result. +We can also use mapping for `apoc.vectordb.chroma.query` procedure, to search for nodes/rels fitting label/type and metadataKey, without making updates +(i.e. equivalent to `*.queryOrUpdate` procedure with mapping config having `mode: "READ_ONLY"`). + +For example, with the previous relationships, we can execute the following procedure, which just return the relationships in the column `rel`: + +[source,cypher] +---- +CALL apoc.vectordb.weaviate.query($host, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { fields: ["city", "foo"], + mapping: { + relType: "TEST", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + +[NOTE] +==== +We can use mapping with `apoc.vectordb.chroma.get*` procedures as well +==== + +[NOTE] +==== +To optimize performances, we can choose what to `YIELD` with the apoc.vectordb.chroma.query and the `apoc.vectordb.chroma.get` procedures. +For example, by executing a `CALL apoc.vectordb.chroma.query(...) YIELD metadata, score, id`, the RestAPI request will have an {"include": ["metadatas", "documents", "distances"]}, +so that we do not return the other values that we do not need. +==== + .Delete vectors (it leverages https://docs.trychroma.com/usage-guide#deleting-data-from-a-collection[this API]) [source,cypher] ---- diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc index 8f1ee8198a..02f3bd7db5 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc @@ -147,7 +147,7 @@ which populates the two nodes as: `(:Test {myId: 'one', city: 'Berlin', vect: [v which will be returned in the `entity` column result. -Or else, we can create a node if not exists, via `create: true`: +We can also set the mapping configuration `mode` to `CREATE_IF_MISSING` (which creates nodes if not exist), `READ_ONLY` (to search for nodes/rels, without making updates) or `UPDATE_EXISTING` (default behavior): [source,cypher] ---- @@ -156,7 +156,7 @@ CALL apoc.vectordb.milvus.queryAndUpdate('http://localhost:19531', 'test_collect {}, 5, { mapping: { - create: true, + mode: "CREATE_IF_MISSING", embeddingKey: "vect", nodeLabel: "Test", entityKey: "myId", @@ -189,9 +189,30 @@ which populates the two relationships as: `()-[:TEST {myId: 'one', city: 'Berlin and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`, which will be returned in the `entity` column result. + +We can also use mapping for `apoc.vectordb.milvus.query` procedure, to search for nodes/rels fitting label/type and metadataKey, without making updates +(i.e. equivalent to `*.queryOrUpdate` procedure with mapping config having `mode: "READ_ONLY"`). + +For example, with the previous relationships, we can execute the following procedure, which just return the relationships in the column `rel`: + +[source,cypher] +---- +CALL apoc.vectordb.milvus.query('http://localhost:19531', 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingKey: "vect", + relType: "TEST", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + [NOTE] ==== -We can use mapping with `apoc.vectordb.milvus.getAndUpdate` procedure as well +We can use mapping with `apoc.vectordb.milvus.get*` procedures as well ==== [NOTE] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc index 4bf59e6322..d397507319 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc @@ -161,7 +161,7 @@ which populates the two nodes as: `(:Test {myId: 'one', city: 'Berlin', vect: [v which will be returned in the `entity` column result. -Or else, we can create a node if not exists, via `create: true`: +We can also set the mapping configuration `mode` to `CREATE_IF_MISSING` (which creates nodes if not exist), `READ_ONLY` (to search for nodes/rels, without making updates) or `UPDATE_EXISTING` (default behavior): [source,cypher] ---- @@ -170,7 +170,7 @@ CALL apoc.vectordb.pinecone.queryAndUpdate($host, 'test-index', {}, 5, { mapping: { - create: true, + mode: "CREATE_IF_MISSING", embeddingKey: "vect", nodeLabel: "Test", entityKey: "myId", @@ -203,9 +203,30 @@ which populates the two relationships as: `()-[:TEST {myId: 'one', city: 'Berlin and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`, which will be returned in the `entity` column result. + +We can also use mapping for `apoc.vectordb.pinecone.query` procedure, to search for nodes/rels fitting label/type and metadataKey, without making updates +(i.e. equivalent to `*.queryOrUpdate` procedure with mapping config having `mode: "READ_ONLY"`). + +For example, with the previous relationships, we can execute the following procedure, which just return the relationships in the column `rel`: + +[source,cypher] +---- +CALL apoc.vectordb.pinecone.query($host, 'test-index', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingKey: "vect", + relType: "TEST", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + [NOTE] ==== -We can use mapping with `apoc.vectordb.pinecone.getAndUpdate` procedure as well +We can use mapping with `apoc.vectordb.pinecone.get*` procedures as well ==== [NOTE] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc index 1f6cce4b97..c604766abf 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc @@ -149,7 +149,7 @@ which populates the two nodes as: `(:Test {myId: 'one', city: 'Berlin', vect: [v which will be returned in the `entity` column result. -Or else, we can create a node if not exists, via `create: true`: +We can also set the mapping configuration `mode` to `CREATE_IF_MISSING` (which creates nodes if not exist), `READ_ONLY` (to search for nodes/rels, without making updates) or `UPDATE_EXISTING` (default behavior): [source,cypher] ---- @@ -158,7 +158,7 @@ CALL apoc.vectordb.qdrant.queryAndUpdate($hostOrKey, 'test_collection', {}, 5, { mapping: { - create: true, + mode: "CREATE_IF_MISSING", embeddingKey: "vect", nodeLabel: "Test", entityKey: "myId", @@ -191,9 +191,29 @@ which populates the two relationships as: `()-[:TEST {myId: 'one', city: 'Berlin and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`, which will be returned in the `entity` column result. + +We can also use mapping for `apoc.vectordb.qdrant.query` procedure, to search for nodes/rels fitting label/type and metadataKey, without making updates +(i.e. equivalent to `*.queryOrUpdate` procedure with mapping config having `mode: "READ_ONLY"`). + +For example, with the previous relationships, we can execute the following procedure, which just return the relationships in the column `rel`: + +[source,cypher] +---- +CALL apoc.vectordb.qdrant.query($hostOrKey, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + relType: "TEST", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + [NOTE] ==== -We can use mapping with `apoc.vectordb.qdrant.getAndUpdate` procedure as well +We can use mapping with `apoc.vectordb.qdrant.get*` procedures as well ==== [NOTE] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc index 3ab4540403..5064a28975 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc @@ -160,7 +160,7 @@ and `(:Test {myId: 'two', city: 'London', vect: [vector2]})`, which will be returned in the `entity` column result. -Or else, we can create a node if not exists, via `create: true`: +We can also set the mapping configuration `mode` to `CREATE_IF_MISSING` (which creates nodes if not exist), `READ_ONLY` (to search for nodes/rels, without making updates) or `UPDATE_EXISTING` (default behavior): [source,cypher] ---- @@ -170,7 +170,7 @@ CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'test_collection', 5, { fields: ["city", "foo"], mapping: { - create: true, + mode: "CREATE_IF_MISSING", embeddingKey: "vect", nodeLabel: "Test", entityKey: "myId", @@ -205,9 +205,31 @@ and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`, which will be returned in the `entity` column result. +We can also use mapping for `apoc.vectordb.weaviate.query` procedure, to search for nodes/rels fitting label/type and metadataKey, without making updates +(i.e. equivalent to `*.queryOrUpdate` procedure with mapping config having `mode: "READ_ONLY"`). + +For example, with the previous relationships, we can execute the following procedure, which just return the relationships in the column `rel`: + +[source,cypher] +---- +CALL apoc.vectordb.weaviate.query($host, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { fields: ["city", "foo"], + mapping: { + relType: "TEST", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + + + [NOTE] ==== -We can use mapping with `apoc.vectordb.weaviate.getAndUpdate` procedure as well +We can use mapping with `apoc.vectordb.weaviate.get*` procedures as well ==== [NOTE] diff --git a/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java b/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java index 3b0b136701..7904f12dba 100644 --- a/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java +++ b/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java @@ -1,8 +1,6 @@ package apoc.vectordb; import apoc.util.TestUtil; -import apoc.util.Util; -import org.assertj.core.api.Assertions; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -25,10 +23,10 @@ import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; import static apoc.vectordb.VectorDbTestUtil.EntityType.*; -import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; import static apoc.vectordb.VectorMappingConfig.*; @@ -231,7 +229,8 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo", - CREATE_KEY, true) + MODE_KEY, MappingMode.CREATE_IF_MISSING.toString() + ) ); testResult(db, "CALL apoc.vectordb.chroma.queryAndUpdate($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", @@ -294,19 +293,22 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", assertNodesCreated(db); } + @Test public void getReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Test {readID: 'one'}), (:Test {readID: 'two'})"); + Map conf = map(ALL_RESULTS_KEY, true, - MAPPING_KEY, map(EMBEDDING_KEY, "vect")); - - try { - testCall(db, "CALL apoc.vectordb.chroma.get($host, $collection, [1, 2], $conf)", - map("host", HOST, "collection", COLL_ID.get(), "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + MAPPING_KEY, map(NODE_LABEL, "Test", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, "CALL apoc.vectordb.chroma.get($host, $collection, ['1', '2'], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "collection", COLL_ID.get(), "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "node") + ); } @Test @@ -338,17 +340,19 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", @Test public void queryReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Start)-[:TEST {readID: 'one'}]->(:End), (:Start)-[:TEST {readID: 'two'}]->(:End)"); + Map conf = map(ALL_RESULTS_KEY, true, - MAPPING_KEY, map(EMBEDDING_KEY, "vect")); - - try { - testCall(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", - map("host", HOST, "collection", COLL_ID.get(), "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + MAPPING_KEY, map( + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "collection", COLL_ID.get(), "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "rel") + ); } @Test diff --git a/extended-it/src/test/java/apoc/vectordb/MilvusTest.java b/extended-it/src/test/java/apoc/vectordb/MilvusTest.java index 513b6167ec..8ab4d342cb 100644 --- a/extended-it/src/test/java/apoc/vectordb/MilvusTest.java +++ b/extended-it/src/test/java/apoc/vectordb/MilvusTest.java @@ -2,7 +2,6 @@ import apoc.util.TestUtil; import apoc.util.Util; -import org.assertj.core.api.Assertions; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -27,22 +26,22 @@ import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; -import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; -import static apoc.vectordb.VectorMappingConfig.CREATE_KEY; +import static apoc.vectordb.VectorMappingConfig.MODE_KEY; import static apoc.vectordb.VectorMappingConfig.EMBEDDING_KEY; import static apoc.vectordb.VectorMappingConfig.ENTITY_KEY; import static apoc.vectordb.VectorMappingConfig.METADATA_KEY; import static apoc.vectordb.VectorMappingConfig.NODE_LABEL; import static apoc.vectordb.VectorMappingConfig.REL_TYPE; +import static apoc.vectordb.VectorMappingConfig.MappingMode; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; -import static org.junit.Assert.fail; import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; @@ -231,7 +230,8 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo", - CREATE_KEY, true) + MODE_KEY, MappingMode.CREATE_IF_MISSING.toString() + ) ); testResult(db, "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", map("host", HOST, "conf", conf), @@ -297,6 +297,24 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", assertNodesCreated(db); } + @Test + public void getReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Test {readID: 'one'}), (:Test {readID: 'two'})"); + + Map conf = map(ALL_RESULTS_KEY, true, + FIELDS_KEY, FIELDS, + MAPPING_KEY, map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "readID", + METADATA_KEY, "foo")); + + testResult(db, "CALL apoc.vectordb.milvus.get($host, 'test_collection', [1, 2], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "node") + ); + } + @Test public void queryVectorsWithCreateNodeUsingExistingNode() { @@ -336,7 +354,8 @@ public void queryVectorsWithCreateRel() { MAPPING_KEY, map(EMBEDDING_KEY, "vect", REL_TYPE, "TEST", ENTITY_KEY, "myId", - METADATA_KEY, "foo")); + METADATA_KEY, "foo") + ); testResult(db, "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", map("host", HOST, "conf", conf), r -> { @@ -356,17 +375,20 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", @Test public void queryReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Start)-[:TEST {readID: 'one'}]->(:End), (:Start)-[:TEST {readID: 'two'}]->(:End)"); + Map conf = map(ALL_RESULTS_KEY, true, - MAPPING_KEY, map(EMBEDDING_KEY, "vect")); - - try { - testCall(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", - map("host", HOST, "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + FIELDS_KEY, FIELDS, + MAPPING_KEY, map( + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", + map("host", HOST, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "rel") + ); } @Test diff --git a/extended-it/src/test/java/apoc/vectordb/QdrantTest.java b/extended-it/src/test/java/apoc/vectordb/QdrantTest.java index 9f15093526..5d1d4c01cc 100644 --- a/extended-it/src/test/java/apoc/vectordb/QdrantTest.java +++ b/extended-it/src/test/java/apoc/vectordb/QdrantTest.java @@ -2,7 +2,6 @@ import apoc.util.TestUtil; import apoc.util.Util; -import org.assertj.core.api.Assertions; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -27,15 +26,16 @@ import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; import static apoc.vectordb.VectorDbTestUtil.getAuthHeader; -import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; import static apoc.vectordb.VectorMappingConfig.*; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; @@ -263,7 +263,8 @@ MAPPING_KEY, map( NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo", - CREATE_KEY, true) + MODE_KEY, MappingMode.CREATE_IF_MISSING.toString() + ) ); testResult(db, "CALL apoc.vectordb.qdrant.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", map("host", HOST, "conf", conf), @@ -331,17 +332,20 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", @Test public void getReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Test {readID: 'one'}), (:Test {readID: 'two'})"); + Map conf = map(ALL_RESULTS_KEY, true, - MAPPING_KEY, map(EMBEDDING_KEY, "vect")); + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, map(NODE_LABEL, "Test", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); - try { - testCall(db, "CALL apoc.vectordb.qdrant.get($host, 'test_collection', [1, 2], $conf)", - map("host", HOST, "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + testResult(db, "CALL apoc.vectordb.qdrant.get($host, 'test_collection', [1, 2], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "node") + ); } @Test @@ -405,17 +409,20 @@ MAPPING_KEY, map( @Test public void queryReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Start)-[:TEST {readID: 'one'}]->(:End), (:Start)-[:TEST {readID: 'two'}]->(:End)"); + Map conf = map(ALL_RESULTS_KEY, true, - MAPPING_KEY, map(EMBEDDING_KEY, "vect")); + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, map( + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); - try { - testCall(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", - map("host", HOST, "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "rel") + ); } @Test diff --git a/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java b/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java index ff5a31d751..4461960ba7 100644 --- a/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java +++ b/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java @@ -2,7 +2,6 @@ import apoc.util.MapUtil; import apoc.util.TestUtil; -import org.assertj.core.api.Assertions; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -27,7 +26,6 @@ import static apoc.vectordb.VectorDbHandler.Type.WEAVIATE; import static apoc.vectordb.VectorDbTestUtil.*; import static apoc.vectordb.VectorDbTestUtil.EntityType.*; -import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; @@ -255,8 +253,9 @@ public void queryVectorsWithCreateNode() { MAPPING_KEY, map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", - METADATA_KEY, "foo", - CREATE_KEY, true) + METADATA_KEY, "foo", + MODE_KEY, MappingMode.CREATE_IF_MISSING.toString() + ) ); testResult(db, "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + "YIELD score, vector, id, metadata, node RETURN * ORDER BY id", @@ -353,24 +352,28 @@ public void getVectorsWithCreateNodeUsingExistingNode() { assertNodesCreated(db); } + @Test public void getReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Test {readID: 'one'}), (:Test {readID: 'two'})"); + Map conf = MapUtil.map(ALL_RESULTS_KEY, true, - MAPPING_KEY, MapUtil.map(EMBEDDING_KEY, "vect")); + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, MapUtil.map( + NODE_LABEL, "Test", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); - try { - testCall(db, "CALL apoc.vectordb.weaviate.get($host, 'TestCollection', [1, 2], $conf)", - map("host", HOST, "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + testResult(db, "CALL apoc.vectordb.weaviate.get($host, 'TestCollection', [$id1, $id2], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + MapUtil.map("host", HOST, "id1", ID_1, "id2", ID_2, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "node") + ); } @Test public void queryVectorsWithCreateRel() { - db.executeTransactionally("CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)"); Map conf = map(ALL_RESULTS_KEY, true, @@ -400,17 +403,22 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", @Test public void queryReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Start)-[:TEST {readID: 'one'}]->(:End), (:Start)-[:TEST {readID: 'two'}]->(:End)"); + Map conf = MapUtil.map(ALL_RESULTS_KEY, true, - MAPPING_KEY, MapUtil.map(EMBEDDING_KEY, "vect")); + FIELDS_KEY, FIELDS, + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, MapUtil.map( + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); - try { - testCall(db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", - MapUtil.map("host", HOST, "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + testResult(db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + " YIELD score, vector, id, metadata, rel RETURN * ORDER BY id", + MapUtil.map("host", HOST, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "rel") + ); } @Test diff --git a/extended/src/main/java/apoc/vectordb/ChromaDb.java b/extended/src/main/java/apoc/vectordb/ChromaDb.java index 11d07bc2cd..23d8d1e996 100644 --- a/extended/src/main/java/apoc/vectordb/ChromaDb.java +++ b/extended/src/main/java/apoc/vectordb/ChromaDb.java @@ -129,7 +129,8 @@ public Stream get(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, true); + setReadOnlyMappingMode(configuration); + return getCommon(hostOrKey, collection, ids, configuration); } @Procedure(value = "apoc.vectordb.chroma.getAndUpdate", mode = Mode.WRITE) @@ -138,18 +139,15 @@ public Stream getAndUpdate(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, false); + return getCommon(hostOrKey, collection, ids, configuration); } - private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) throws Exception { + private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration) throws Exception { String url = "%s/api/v1/collections/%s/get"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.chroma.getAndUpdate"); - } - VectorEmbeddingConfig apiConfig = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); + return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, tx, v -> listToMap((Map) v).stream()); } @@ -162,7 +160,8 @@ public Stream query(@Name("hostOrKey") String hostOrKey, @Name(value = "filter", defaultValue = "{}") Map filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + setReadOnlyMappingMode(configuration); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration); } @Procedure(value = "apoc.vectordb.chroma.queryAndUpdate", mode = Mode.WRITE) @@ -173,16 +172,12 @@ public Stream queryAndUpdate(@Name("hostOrKey") String hostOrKe @Name(value = "filter", defaultValue = "{}") Map filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration); } - private Stream queryCommon(String hostOrKey, String collection, List vector, Map filter, long limit, Map configuration, boolean readOnly) throws Exception { + private Stream queryCommon(String hostOrKey, String collection, List vector, Map filter, long limit, Map configuration) throws Exception { String url = "%s/api/v1/collections/%s/query"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.chroma.queryAndUpdate"); - } VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, diff --git a/extended/src/main/java/apoc/vectordb/Milvus.java b/extended/src/main/java/apoc/vectordb/Milvus.java index c97d45b4e6..05ce11468e 100644 --- a/extended/src/main/java/apoc/vectordb/Milvus.java +++ b/extended/src/main/java/apoc/vectordb/Milvus.java @@ -127,13 +127,14 @@ public Stream delete( .map(MapResult::new); } - @Procedure(value = "apoc.vectordb.milvus.get", mode = Mode.WRITE) + @Procedure(value = "apoc.vectordb.milvus.get") @Description("apoc.vectordb.milvus.get(hostOrKey, collection, ids, $configuration) - Get the vectors with the specified `ids`") public Stream get(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, true); + setReadOnlyMappingMode(configuration); + return getCommon(hostOrKey, collection, ids, configuration); } @Procedure(value = "apoc.vectordb.milvus.getAndUpdate", mode = Mode.WRITE) @@ -142,17 +143,13 @@ public Stream getAndUpdate(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, false); + return getCommon(hostOrKey, collection, ids, configuration/*, true*/); } - private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) throws Exception { + private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration) throws Exception { String url = "%s/entities/get"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.milvus.getAndUpdate"); - } - VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> getMapStream((Map) v)); @@ -166,8 +163,8 @@ public Stream query(@Name("hostOrKey") String hostOrKey, @Name(value = "filter", defaultValue = "null") Object filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + setReadOnlyMappingMode(configuration); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration); } @Procedure(value = "apoc.vectordb.milvus.queryAndUpdate", mode = Mode.WRITE) @@ -178,8 +175,7 @@ public Stream queryAndUpdate(@Name("hostOrKey") String hostOrKe @Name(value = "filter", defaultValue = "null") Object filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration); } private Stream getMapStream(Map v) { @@ -198,16 +194,13 @@ private Stream getMapStream(Map v) { }); } - private Stream queryCommon(String hostOrKey, String collection, List vector, Object filter, long limit, Map configuration, boolean readOnly) throws Exception { + private Stream queryCommon(String hostOrKey, String collection, List vector, Object filter, long limit, Map configuration) throws Exception { String url = "%s/entities/search"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.milvus.queryAndUpdate"); - } - VectorEmbeddingConfig apiConfig = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); - return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, tx, + VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> getMapStream((Map) v)); } diff --git a/extended/src/main/java/apoc/vectordb/Pinecone.java b/extended/src/main/java/apoc/vectordb/Pinecone.java index 036261143f..896213b11b 100644 --- a/extended/src/main/java/apoc/vectordb/Pinecone.java +++ b/extended/src/main/java/apoc/vectordb/Pinecone.java @@ -22,8 +22,8 @@ import static apoc.vectordb.VectorDb.executeRequest; import static apoc.vectordb.VectorDb.getEmbeddingResultStream; import static apoc.vectordb.VectorDbHandler.Type.PINECONE; -import static apoc.vectordb.VectorDbUtil.checkMappingConf; import static apoc.vectordb.VectorDbUtil.getCommonVectorDbInfo; +import static apoc.vectordb.VectorDbUtil.setReadOnlyMappingMode; @Extended public class Pinecone { @@ -133,7 +133,8 @@ public Stream get(@Name("hostOrKey") String hostOr @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, true); + setReadOnlyMappingMode(configuration); + return getCommon(hostOrKey, collection, ids, configuration); } @Procedure(value = "apoc.vectordb.pinecone.getAndUpdate", mode = Mode.WRITE) @@ -142,18 +143,15 @@ public Stream getAndUpdate(@Name("hostOrKey") Stri @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, false); + return getCommon(hostOrKey, collection, ids, configuration); } - private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) throws Exception { + private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration) throws Exception { String url = "%s/vectors/fetch"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.pinecone.getAndUpdate"); - } VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> { Object vectors = ((Map) v).get("vectors"); @@ -170,7 +168,8 @@ public Stream query(@Name("hostOrKey") String host @Name(value = "filter", defaultValue = "{}") Map filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + setReadOnlyMappingMode(configuration); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration); } @Procedure(value = "apoc.vectordb.pinecone.queryAndUpdate", mode = Mode.WRITE) @@ -181,18 +180,15 @@ public Stream queryAndUpdate(@Name("hostOrKey") St @Name(value = "filter", defaultValue = "{}") Map filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration); } - private Stream queryCommon(String hostOrKey, String collection, List vector, Map filter, long limit, Map configuration, boolean readOnly) throws Exception { + private Stream queryCommon(String hostOrKey, String collection, List vector, Map filter, long limit, Map configuration) throws Exception { String url = "%s/query"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.pinecone.queryAndUpdate"); - } - VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> { Map map = (Map) v; diff --git a/extended/src/main/java/apoc/vectordb/Qdrant.java b/extended/src/main/java/apoc/vectordb/Qdrant.java index da7f42903d..f381e2b158 100644 --- a/extended/src/main/java/apoc/vectordb/Qdrant.java +++ b/extended/src/main/java/apoc/vectordb/Qdrant.java @@ -131,7 +131,8 @@ public Stream get(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, true); + setReadOnlyMappingMode(configuration); + return getCommon(hostOrKey, collection, ids, configuration); } @Procedure(value = "apoc.vectordb.qdrant.getAndUpdate", mode = Mode.WRITE) @@ -140,18 +141,15 @@ public Stream getAndUpdate(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return getCommon(hostOrKey, collection, ids, configuration, false); + return getCommon(hostOrKey, collection, ids, configuration); } - private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) throws Exception { + private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration) throws Exception { String url = "%s/collections/%s/points"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.qdrant.getAndUpdate"); - } - VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx); } @@ -163,7 +161,8 @@ public Stream query(@Name("hostOrKey") String hostOrKey, @Name(value = "filter", defaultValue = "{}") Map filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + setReadOnlyMappingMode(configuration); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration); } @Procedure(value = "apoc.vectordb.qdrant.queryAndUpdate", mode = Mode.WRITE) @@ -174,18 +173,15 @@ public Stream queryAndUpdate(@Name("hostOrKey") String hostOrKe @Name(value = "filter", defaultValue = "{}") Map filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration); } - private Stream queryCommon(String hostOrKey, String collection, List vector, Map filter, long limit, Map configuration, boolean readOnly) throws Exception { + private Stream queryCommon(String hostOrKey, String collection, List vector, Map filter, long limit, Map configuration) throws Exception { String url = "%s/collections/%s/points/search"; Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); - - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.qdrant.queryAndUpdate"); - } VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx); } diff --git a/extended/src/main/java/apoc/vectordb/VectorDb.java b/extended/src/main/java/apoc/vectordb/VectorDb.java index dabf699c5e..6123e84966 100644 --- a/extended/src/main/java/apoc/vectordb/VectorDb.java +++ b/extended/src/main/java/apoc/vectordb/VectorDb.java @@ -152,15 +152,21 @@ private static Entity handleMappingNode(Transaction transaction, VectorMappingCo Node node; Object propValue = metaProps.get(mapping.getMetadataKey()); node = transaction.findNode(Label.label(mapping.getNodeLabel()), mapping.getEntityKey(), propValue); - if (node == null && mapping.isCreate()) { - node = transaction.createNode(Label.label(mapping.getNodeLabel())); - node.setProperty(mapping.getEntityKey(), propValue); + switch (mapping.getMode()) { + case READ_ONLY -> { + // do nothing, just return the entity + } + case UPDATE_EXISTING -> { + setPropsIfEntityExists(mapping, metaProps, embedding, node); + } + case CREATE_IF_MISSING -> { + if (node == null) { + node = transaction.createNode(Label.label(mapping.getNodeLabel())); + node.setProperty(mapping.getEntityKey(), propValue); + } + setPropsIfEntityExists(mapping, metaProps, embedding, node); + } } - if (node != null) { - setProperties(node, metaProps); - setVectorProp(mapping, embedding, node); - } - return node; } catch (MultipleFoundException e) { throw new RuntimeException("Multiple nodes found"); @@ -173,9 +179,13 @@ private static Entity handleMappingRel(Transaction transaction, VectorMappingCon Relationship rel; Object propValue = metaProps.get(mapping.getMetadataKey()); rel = transaction.findRelationship(RelationshipType.withName(mapping.getRelType()), mapping.getEntityKey(), propValue); - if (rel != null) { - setProperties(rel, metaProps); - setVectorProp(mapping, embedding, rel); + switch (mapping.getMode()) { + case READ_ONLY -> { + // do nothing, just return the entity + } + case UPDATE_EXISTING, CREATE_IF_MISSING -> { + setPropsIfEntityExists(mapping, metaProps, embedding, rel); + } } return rel; @@ -184,6 +194,13 @@ private static Entity handleMappingRel(Transaction transaction, VectorMappingCon } } + private static void setPropsIfEntityExists(VectorMappingConfig mapping, Map metaProps, List embedding, Entity entity) { + if (entity != null) { + setProperties(entity, metaProps); + setVectorProp(mapping, embedding, entity); + } + } + private static void setVectorProp(VectorMappingConfig mapping, List embedding, T entity) { if (mapping.getEmbeddingKey() == null) { return; diff --git a/extended/src/main/java/apoc/vectordb/VectorDbUtil.java b/extended/src/main/java/apoc/vectordb/VectorDbUtil.java index cea4af7117..9a16cd1d12 100644 --- a/extended/src/main/java/apoc/vectordb/VectorDbUtil.java +++ b/extended/src/main/java/apoc/vectordb/VectorDbUtil.java @@ -17,6 +17,8 @@ import static apoc.ml.RestAPIConfig.ENDPOINT_KEY; import static apoc.util.SystemDbUtil.withSystemDb; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; +import static apoc.vectordb.VectorMappingConfig.MODE_KEY; +import static apoc.vectordb.VectorMappingConfig.MappingMode.READ_ONLY; public class VectorDbUtil { @@ -80,11 +82,8 @@ private static String getUrl(String hostOrKey, VectorDbHandler handler, Map configuration, String procName) { - if (configuration.containsKey(MAPPING_KEY)) { - throw new RuntimeException(ERROR_READONLY_MAPPING + "\n" + - "Try the equivalent procedure, which is the " + procName); - } + public static void setReadOnlyMappingMode(Map configuration) { + Map mappingConf = (Map) configuration.getOrDefault(MAPPING_KEY, new HashMap<>()); + mappingConf.put(MODE_KEY, READ_ONLY.toString()); } - } diff --git a/extended/src/main/java/apoc/vectordb/VectorMappingConfig.java b/extended/src/main/java/apoc/vectordb/VectorMappingConfig.java index 2b91d049c0..3850544179 100644 --- a/extended/src/main/java/apoc/vectordb/VectorMappingConfig.java +++ b/extended/src/main/java/apoc/vectordb/VectorMappingConfig.java @@ -1,18 +1,20 @@ package apoc.vectordb; -import apoc.util.Util; - import java.util.Collections; import java.util.Map; public class VectorMappingConfig { + enum MappingMode { + READ_ONLY, UPDATE_EXISTING, CREATE_IF_MISSING + } + public static final String METADATA_KEY = "metadataKey"; public static final String ENTITY_KEY = "entityKey"; public static final String NODE_LABEL = "nodeLabel"; public static final String REL_TYPE = "relType"; public static final String EMBEDDING_KEY = "embeddingKey"; public static final String SIMILARITY_KEY = "similarity"; - public static final String CREATE_KEY = "create"; + public static final String MODE_KEY = "mode"; private final String metadataKey; private final String entityKey; @@ -22,7 +24,7 @@ public class VectorMappingConfig { private final String embeddingKey; private final String similarity; - private final boolean create; + private MappingMode mode; public VectorMappingConfig(Map mapping) { if (mapping == null) { @@ -37,7 +39,8 @@ public VectorMappingConfig(Map mapping) { this.similarity = (String) mapping.getOrDefault(SIMILARITY_KEY, "cosine"); - this.create = Util.toBoolean(mapping.get(CREATE_KEY)); + String modeValue = (String) mapping.getOrDefault(MODE_KEY, MappingMode.UPDATE_EXISTING.toString() ); + this.mode = MappingMode.valueOf( modeValue.toUpperCase() ); } public String getMetadataKey() { @@ -60,11 +63,11 @@ public String getEmbeddingKey() { return embeddingKey; } - public boolean isCreate() { - return create; - } - public String getSimilarity() { return similarity; } + + public MappingMode getMode() { + return mode; + } } diff --git a/extended/src/main/java/apoc/vectordb/Weaviate.java b/extended/src/main/java/apoc/vectordb/Weaviate.java index 7ea0b19463..7653c32e46 100644 --- a/extended/src/main/java/apoc/vectordb/Weaviate.java +++ b/extended/src/main/java/apoc/vectordb/Weaviate.java @@ -144,7 +144,7 @@ public Stream getAndUpdate(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) { - return getCommon(hostOrKey, collection, ids, configuration, false); + return getCommon(hostOrKey, collection, ids, configuration); } @Procedure(value = "apoc.vectordb.weaviate.get") @@ -153,15 +153,13 @@ public Stream get(@Name("hostOrKey") String hostOrKey, @Name("collection") String collection, @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) { - return getCommon(hostOrKey, collection, ids, configuration, true); + setReadOnlyMappingMode(configuration); + return getCommon(hostOrKey, collection, ids, configuration); } - private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) { + private Stream getCommon(String hostOrKey, String collection, List ids, Map configuration) { Map config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema"); - - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.chroma.getAndUpdate"); - } + /** * TODO: we put method: null as a workaround, it should be "GET": https://weaviate.io/developers/weaviate/api/rest#tag/objects/get/objects/{className}/{id} @@ -172,6 +170,7 @@ private Stream getCommon(String hostOrKey, String collection, L List fields = procedureCallContext.outputFields().toList(); VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); + boolean hasEmbedding = fields.contains("vector") && conf.isAllResults(); boolean hasMetadata = fields.contains("metadata"); VectorMappingConfig mapping = conf.getMapping(); @@ -200,8 +199,8 @@ public Stream query(@Name("hostOrKey") String hostOrKey, @Name(value = "filter", defaultValue = "null") Object filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - checkMappingConf(configuration, "apoc.vectordb.weaviate.queryAndUpdate"); - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + setReadOnlyMappingMode(configuration); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration); } @@ -213,17 +212,14 @@ public Stream queryAndUpdate(@Name("hostOrKey") String hostOrKe @Name(value = "filter", defaultValue = "null") Object filter, @Name(value = "limit", defaultValue = "10") long limit, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration); } - private Stream queryCommon(String hostOrKey, String collection, List vector, Object filter, long limit, Map configuration, boolean readOnly) throws Exception { + private Stream queryCommon(String hostOrKey, String collection, List vector, Object filter, long limit, Map configuration) throws Exception { Map config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/graphql"); - if (readOnly) { - checkMappingConf(configuration, "apoc.vectordb.weaviate.queryAndUpdate"); - } - VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> { Object getValue = ((Map) v).get("data").get("Get"); diff --git a/extended/src/test/java/apoc/vectordb/PineconeTest.java b/extended/src/test/java/apoc/vectordb/PineconeTest.java index fc8e2b4934..ead9db3e53 100644 --- a/extended/src/test/java/apoc/vectordb/PineconeTest.java +++ b/extended/src/test/java/apoc/vectordb/PineconeTest.java @@ -3,7 +3,6 @@ import apoc.util.MapUtil; import apoc.util.TestUtil; import apoc.util.Util; -import org.assertj.core.api.Assertions; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -29,9 +28,9 @@ import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; -import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; import static apoc.vectordb.VectorMappingConfig.*; @@ -39,7 +38,6 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; -import static org.junit.Assert.fail; import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; @@ -270,7 +268,9 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo", - CREATE_KEY, true)); + MODE_KEY, MappingMode.CREATE_IF_MISSING.toString() + ) + ); testResult(db, "CALL apoc.vectordb.pinecone.queryAndUpdate($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", map("host", HOST, "coll", collName, "conf", conf), r -> { @@ -363,20 +363,23 @@ public void getVectorsWithCreateNodeUsingExistingNode() { assertNodesCreated(db); } - + @Test public void getReadOnlyVectorsWithMapping() { - Map conf = MapUtil.map(ALL_RESULTS_KEY, true, - MAPPING_KEY, MapUtil.map(EMBEDDING_KEY, "vect")); - - try { - testCall(db, "CALL apoc.vectordb.pinecone.get($host, 'TestCollection', [1, 2], $conf)", - Util.map("host", HOST, "conf", conf), - r -> fail() - ); - } catch (RuntimeException e) { - Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); - } + db.executeTransactionally("CREATE (:Test {readID: 'one'}), (:Test {readID: 'two'})"); + + Map conf = map(ALL_RESULTS_KEY, true, + HEADERS_KEY, ADMIN_AUTHORIZATION, + MAPPING_KEY, map( + NODE_LABEL, "Test", + ENTITY_KEY, "readID", + METADATA_KEY, "foo")); + + testResult(db, "CALL apoc.vectordb.pinecone.get($host, 'TestCollection', [1, 2], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + Util.map("host", HOST, "coll", collName, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "node") + ); } @Test @@ -407,6 +410,24 @@ MAPPING_KEY, map(EMBEDDING_KEY, "vect", assertRelsCreated(db); } + @Test + public void queryReadOnlyVectorsWithMapping() { + db.executeTransactionally("CREATE (:Start)-[:TEST {readID: 'one'}]->(:End), (:Start)-[:TEST {readID: 'two'}]->(:End)"); + + Map conf = map(ALL_RESULTS_KEY, true, + HEADERS_KEY, ADMIN_AUTHORIZATION, + MAPPING_KEY, map( + REL_TYPE, "TEST", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, "CALL apoc.vectordb.pinecone.query($host, $coll, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "coll", collName, "conf", conf), + r -> assertReadOnlyProcWithMappingResults(r, "rel") + ); + } + @Test public void queryVectorsWithSystemDbStorage() { String keyConfig = "pinecone-config-foo"; diff --git a/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java b/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java index 0d7247b66e..ab68f98d9f 100644 --- a/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java +++ b/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java @@ -1,5 +1,6 @@ package apoc.vectordb; +import apoc.util.MapUtil; import org.neo4j.graphdb.Entity; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.ResourceIterator; @@ -11,6 +12,7 @@ import static apoc.util.Util.map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; public class VectorDbTestUtil { @@ -82,4 +84,20 @@ private static void assertBerlinProperties(Map props) { public static Map getAuthHeader(String key) { return map("Authorization", "Bearer " + key); } + + public static void assertReadOnlyProcWithMappingResults(Result r, String node) { + Map row = r.next(); + Map props = ((Entity) row.get(node)).getAllProperties(); + assertEquals(MapUtil.map("readID", "one"), props); + assertNotNull(row.get("vector")); + assertNotNull(row.get("id")); + + row = r.next(); + props = ((Entity) row.get(node)).getAllProperties(); + assertEquals(MapUtil.map("readID", "two"), props); + assertNotNull(row.get("vector")); + assertNotNull(row.get("id")); + + assertFalse(r.hasNext()); + } }