From 1ef8a14ae59295f05f67c21343dfaf4b219ce7cf Mon Sep 17 00:00:00 2001 From: vga91 Date: Thu, 14 Mar 2024 10:44:04 +0100 Subject: [PATCH] Fixes #3971: Check how to integrate vector databases via rest APIs --- docs/asciidoc/modules/ROOT/nav.adoc | 1 + .../pages/database-integration/index.adoc | 1 + .../pages/database-integration/vectordb.adoc | 447 ++++++++++++++++++ extended-it/build.gradle | 10 +- .../test/java/apoc/vectordb/ChromaDbTest.java | 268 +++++++++++ .../test/java/apoc/vectordb/QdrantDbTest.java | 263 +++++++++++ .../src/main/java/apoc/ml/RestAPIConfig.java | 69 +++ .../src/main/java/apoc/util/ExtendedUtil.java | 42 ++ .../src/main/java/apoc/vectordb/ChromaDb.java | 236 +++++++++ .../src/main/java/apoc/vectordb/Qdrant.java | 177 +++++++ .../src/main/java/apoc/vectordb/VectorDb.java | 225 +++++++++ .../main/java/apoc/vectordb/VectorDbUtil.java | 23 + .../java/apoc/vectordb/VectorEmbedding.java | 138 ++++++ .../apoc/vectordb/VectorEmbeddingConfig.java | 62 +++ .../apoc/vectordb/VectorMappingConfig.java | 62 +++ .../test/java/apoc/vectordb/PineconeTest.java | 119 +++++ .../java/apoc/vectordb/VectorDbTestUtil.java | 109 +++++ 17 files changed, 2251 insertions(+), 1 deletion(-) create mode 100644 docs/asciidoc/modules/ROOT/pages/database-integration/vectordb.adoc create mode 100644 extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java create mode 100644 extended-it/src/test/java/apoc/vectordb/QdrantDbTest.java create mode 100644 extended/src/main/java/apoc/ml/RestAPIConfig.java create mode 100644 extended/src/main/java/apoc/vectordb/ChromaDb.java create mode 100644 extended/src/main/java/apoc/vectordb/Qdrant.java create mode 100644 extended/src/main/java/apoc/vectordb/VectorDb.java create mode 100644 extended/src/main/java/apoc/vectordb/VectorDbUtil.java create mode 100644 extended/src/main/java/apoc/vectordb/VectorEmbedding.java create mode 100644 extended/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java create mode 100644 
extended/src/main/java/apoc/vectordb/VectorMappingConfig.java create mode 100644 extended/src/test/java/apoc/vectordb/PineconeTest.java create mode 100644 extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java diff --git a/docs/asciidoc/modules/ROOT/nav.adoc b/docs/asciidoc/modules/ROOT/nav.adoc index ae2f30addc..d848526a36 100644 --- a/docs/asciidoc/modules/ROOT/nav.adoc +++ b/docs/asciidoc/modules/ROOT/nav.adoc @@ -39,6 +39,7 @@ include::partial$generated-documentation/nav.adoc[] ** xref::database-integration/bolt-neo4j.adoc[] ** xref::database-integration/load-ldap.adoc[] ** xref::database-integration/redis.adoc[] + ** xref::database-integration/vectordb.adoc[] * xref:graph-updates/index.adoc[] ** xref::graph-updates/uuid.adoc[] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/index.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/index.adoc index 37a061603f..c12016700a 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/index.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/index.adoc @@ -17,4 +17,5 @@ For more information on how to use these procedures, see: * xref::database-integration/bolt-neo4j.adoc[] * xref::database-integration/load-ldap.adoc[] * xref::database-integration/redis.adoc[] +* xref::database-integration/vectordb.adoc[] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb.adoc new file mode 100644 index 0000000000..ee25b93fe0 --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb.adoc @@ -0,0 +1,447 @@ +[[vectordb]] += Vector Databases +:description: This section describes procedures that can be used to interact with Vector Databases. 
+ +APOC provides this set of procedures, which leverage REST APIs, to interact with Vector Databases: + +- `apoc.vectordb.qdrant.*` (to interact with https://qdrant.tech/documentation/overview/[Qdrant]) +- `apoc.vectordb.chroma.*` (to interact with https://docs.trychroma.com/getting-started[Chroma]) +- `apoc.vectordb.custom.*` (to interact with other vector databases) + + + +All the procedures can have, as a final parameter, a configuration map with these possible parameters: + +.config parameters + +|=== +| key | description +| headers | additional HTTP headers +| method | HTTP method +| endpoint | endpoint key, + can be used to override the default endpoint created via the 1st parameter of the `apoc.vectordb.qdrant.*` and `apoc.vectordb.chroma.*` procedures, + to handle potential endpoint changes. +| body | HTTP request body +| jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response. The default is `null`. +|=== + + +Besides the above config, the `apoc.vectordb.*.get` and the `apoc.vectordb.*.query` procedures can have these additional parameters: + +.embeddingConfig parameters + +|=== +| key | description +| mapping | to auto-create indexes/entities. See examples below. +| vectorKey, metadataKey, scoreKey, textKey | used with the `apoc.vectordb.custom.get` procedure. + To let the procedure know which key in the REST API response (if present) corresponds to the one that should be populated as the vector/metadata/score/text result, respectively. + Defaults are "vector", "metadata", "score", "text". + See examples below. +|=== + + +== Qdrant + +Here is a list of all available Qdrant procedures: + +[opts=header, cols="1, 3"] +|=== +| name | description +| apoc.vectordb.qdrant.createCollection(hostOrKey, collection, similarity, size, $config) | + Creates a collection, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`. 
+| apoc.vectordb.qdrant.deleteCollection(hostOrKey, collection, $config) | + Deletes a collection with the name specified in the 2nd parameter +| apoc.vectordb.qdrant.upsert(hostOrKey, collection, vectors, $config) | + Upserts, in the collection with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '', metadata: ''}] +| apoc.vectordb.qdrant.delete(hostOrKey, collection, ids, $config) | + Delete the vectors with the specified `ids`. +| apoc.vectordb.qdrant.get(hostOrKey, collection, ids, $config) | + Get the vectors with the specified `ids`. +| apoc.vectordb.qdrant.query(hostOrKey, collection, vector, filter, limit, $config) | + Retrieve the vectors closest to the defined `vector`, up to `limit` results, in the collection with the name specified in the 2nd parameter. +|=== + +where the 1st parameter can be a key defined by the apoc config `apoc.qdrant.<key>.host=myHost`. + + +=== Examples + +.Create a collection +[source,cypher] +---- +CALL apoc.vectordb.qdrant.createCollection('localhost:6333', 'test_collection', 'Cosine', 4, {}) +---- + + +.Delete a collection +[source,cypher] +---- +CALL apoc.vectordb.qdrant.deleteCollection('localhost:6333', 'test_collection', {}) +---- + + +.Upsert vectors +[source,cypher] +---- +CALL apoc.vectordb.qdrant.upsert('localhost:6333', 'test_collection', + [ + {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}}, + {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}} + ], + {}) +---- + + +.Get vectors +[source,cypher] +---- +CALL apoc.vectordb.qdrant.get('localhost:6333', 'test_collection', [1,2], {}) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text +| null | {city: "Berlin", foo: "one"} | 1 | [...] | null +| null | {city: "London", foo: "two"} | 2 | [...] | null +| ... 
+|=== + +.Query vectors +[source,cypher] +---- +CALL apoc.vectordb.qdrant.query('localhot:6333', + 'test_collection', + [0.2, 0.1, 0.9, 0.7], + { must: + [ { key: "city", match: { value: "London" } } ] + }, + 5, + {}) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text +| 1, | {city: "Berlin", foo: "one"} | 1 | [...] | null +| 0.1 | {city: "Berlin", foo: "two"} | 2 | [...] | null +| ... +|=== + + +We can define a mapping, to auto-create an index (if not exists), a constraint, and one/multiple nodes and relationships, +by leveraging the vector metadata. + +For example, if we have created 2 vectors with the above upsert procedures, +we can populate some existing nodes (i.e. `(:Test {myId: 'one'})` and `(:Test {myId: 'two'})`): + + +[source,cypher] +---- +CALL apoc.vectordb.qdrant.query($host, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingProp: "vect", + label: "Test", + prop: "myId", + id: "foo" + } + }) +---- + +which creates an index `VECTOR INDEX FOR (n:Test) ON (n.vect) OPTIONS {indexConfig: {`vector.dimensions`: 4 `vector.similarity_function`: 'cosine'}}`, + a constraint `CREATE CONSTRAINT IF NOT EXISTS FOR (n:Test) REQUIRE n.myId IS UNIQUE` + and populates the two nodes as: `(:Test {myId: 'one', city: 'Berlin', vect: [vector1]})` + and `(:Test {myId: 'two', city: 'London', vect: [vector2]})`. + + +Or else, we can create a node if not exists, via `create: true`: + +[source,cypher] +---- +CALL apoc.vectordb.qdrant.query($host, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + create: true, + embeddingProp: "vect", + label: "Test", + prop: "myId", + id: "foo" + } + }) +---- + +which creates an index, a constraint and 2 new nodes as above. + +Or, we can populate an existing relationship (i.e. 
`(:Start)-[:TEST {myId: 'one'}]->(:End)` and `(:Start)-[:TEST {myId: 'two'}]->(:End)`): + + +[source,cypher] +---- +CALL apoc.vectordb.qdrant.query($host, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingProp: "vect", + type: "TEST", + prop: "myId", + id: "foo" + } + }) +---- + +which creates an index `VECTOR INDEX FOR ()-[n:TEST]-() ON (n.vect) OPTIONS {indexConfig: {`vector.dimensions`: 4 `vector.similarity_function`: 'cosine'}}`, +a constraint `CREATE CONSTRAINT IF NOT EXISTS FOR ()-[n:TEST]-() REQUIRE n.myId IS UNIQUE` +and populates the two relationships as: `()-[:TEST {myId: 'one', city: 'Berlin', vect: [vector1]}]-()` +and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`. + + +[NOTE] +==== +To optimize performances, we can choose what to `YIELD` with the apoc.vectordb.qdrant.query and the `apoc.vectordb.qdrant.get` procedures. + +For example, by executing a `CALL apoc.vectordb.qdrant.query(...) YIELD metadata, score, id`, the RestAPI request will have an {"with_payload": false, "with_vectors": false}, +so that we do not return the other values that we do not need. +==== + + + +.Delete vectors +[source,cypher] +---- +CALL apoc.vectordb.qdrant.delete(, 'test_collection', [1,2], {}) +---- + + +== Chroma + +The list and the signature procedures are consistent with the Qdrant ones: + + +[opts=header, cols="1, 3"] +|=== +| name | description +| apoc.vectordb.chroma.createCollection(hostOrKey, collection, similarity, size, $config) | + Creates a collection, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`. 
+| apoc.vectordb.chroma.deleteCollection(hostOrKey, collection, $config) | + Deletes a collection with the name specified in the 2nd parameter +| apoc.vectordb.chroma.upsert(hostOrKey, collection, vectors, $config) | + Upserts, in the collection with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '', metadata: ''}] +| apoc.vectordb.chroma.delete(hostOrKey, collection, ids, $config) | + Delete the vectors with the specified `ids`. +| apoc.vectordb.chroma.get(hostOrKey, collection, ids, $config) | + Get the vectors with the specified `ids`. +| apoc.vectordb.chroma.query(hostOrKey, collection, vector, filter, limit, $config) | + Retrieve the vectors closest to the defined `vector`, up to `limit` results, in the collection with the name specified in the 2nd parameter. +|=== + +where the 1st parameter can be a key defined by the apoc config `apoc.chroma.<key>.host=myHost`. + +=== Examples + +.Create a collection +[source,cypher] +---- +CALL apoc.vectordb.chroma.createCollection('localhost:8000', 'test_collection', 'Cosine', 4, {}) +---- + + +.Delete a collection +[source,cypher] +---- +CALL apoc.vectordb.chroma.deleteCollection('localhost:8000', 'test_collection', {}) +---- + + +.Upsert vectors +[source,cypher] +---- +CALL apoc.vectordb.chroma.upsert('localhost:8000', '<collection_id>', + [ + {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}, text: 'ajeje'}, + {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}, text: 'brazorf'} + ], + {}) +---- + + +.Get vectors +[source,cypher] +---- +CALL apoc.vectordb.chroma.get('localhost:8000', '<collection_id>', ['1','2'], {}) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text +| null | {city: "Berlin", foo: "one"} | 1 | [...] | ajeje +| null | {city: "London", foo: "two"} | 2 | [...] | brazorf +| ... 
+|=== + + +.Query vectors +[source,cypher] +---- +CALL apoc.vectordb.chroma.query('localhost:8000', + '<collection_id>', + [0.2, 0.1, 0.9, 0.7], + { must: + [ { key: "city", match: { value: "London" } } ] + }, + 5, + {}) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text +| 1 | {city: "Berlin", foo: "one"} | 1 | [...] | ajeje +| 0.1 | {city: "London", foo: "two"} | 2 | [...] | brazorf +| ... +|=== + + +[NOTE] +==== +To optimize performance, we can choose what to `YIELD` with the `apoc.vectordb.chroma.query` and the `apoc.vectordb.chroma.get` procedures. +For example, by executing a `CALL apoc.vectordb.chroma.query(...) YIELD metadata, score, id`, the REST API request will have an {"include": ["metadatas", "documents", "distances"]}, +so that we do not return the other values that we do not need. +==== + + +In the same way as other procedures, we can define a mapping, to auto-create an index (if not exists) and one/multiple nodes and relationships, +by leveraging the vector metadata. For example: + +.Query vectors +[source,cypher] +---- +CALL apoc.vectordb.chroma.query($host, '<collection_id>', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingProp: "vect", + label: "Test", + prop: "myId", + id: "foo" + } + }) +---- + + + +.Delete vectors +[source,cypher] +---- +CALL apoc.vectordb.chroma.delete('localhost:8000', '<collection_id>', [1,2], {}) +---- + + +== Custom (i.e. other vector databases) + +Here is a list of all available custom procedures: + +[opts=header, cols="1, 3"] +|=== +| name | description +| apoc.vectordb.custom.get(host, $embeddingConfig) | Customizable get / query procedure, + returning a result like the other `apoc.vectordb.*.get` procedures +| apoc.vectordb.custom(host, $config) | Fully customizable procedure, returns generic object results. 
+|=== + + +=== Examples + + +The `apoc.vectordb.custom.get` procedure can be used with any API that returns something like this +(note that the call does not need to return all keys): + +``` +[ + "<idKey>": "value", + "<scoreKey>": scoreValue, + "<embeddingKey>": [ ... ] + "<metadataKey>": { .. }, + "<textKey>": "..." +], +[ + ... +] +``` + +where we can customize idKey, scoreKey, embeddingKey, metadataKey and textKey via the homonymous config parameters. + + +Let's look at some examples using https://docs.pinecone.io/guides/getting-started/overview[Pinecone]. + + +.apoc.vectordb.custom.get example +[source,cypher] +---- +CALL apoc.vectordb.custom.get('https://<INDEX-ID>.svc.gcp-starter.pinecone.io/query', { + body: { + namespace: $namespace, + vector: $vector, + topK: 3, + includeValues: true, + includeMetadata: true + }, + headers: {`Api-Key`: $apiKey}, + method: null, + jsonPath: "matches", + // the REST API returns the vectors under the `values` key + embeddingKey: 'values' +}) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text +| 1 | {a: 1} | 1 | [1,2,3,4] | null +| 0.1 | {a: 2} | 2 | [1,2,3,4] | null +| ... +|=== + + + +.apoc.vectordb.custom example +[source,cypher] +---- +CALL apoc.vectordb.custom('https://<INDEX-ID>.svc.gcp-starter.pinecone.io/query', { + body: { + namespace: $namespace, + vector: $vector, + topK: 3, + includeValues: true, + includeMetadata: true + }, + headers: {`Api-Key`: $apiKey}, + method: null, + jsonPath: "matches" +}) +---- + + +.Example results +[opts="header"] +|=== +| value +| {score: <score>, metadata: <metadata>, id: <id>, vector: <vector>} +| {score: <score>, metadata: <metadata>, id: <id>, vector: <vector>} +| ... 
+|=== diff --git a/extended-it/build.gradle b/extended-it/build.gradle index 6398d4087e..624b513001 100644 --- a/extended-it/build.gradle +++ b/extended-it/build.gradle @@ -20,7 +20,7 @@ dependencies { def withoutJacksons = { exclude group: 'com.fasterxml.jackson.core', module: 'jackson-annotations' exclude group: 'com.fasterxml.jackson.core', module: 'jackson-databind' - } + } def withoutServers = { exclude group: 'org.eclipse.jetty' exclude group: 'org.eclipse.jetty.aggregate' @@ -49,6 +49,14 @@ dependencies { exclude group: 'io.netty' } testImplementation group: 'org.apache.parquet', name: 'parquet-hadoop', version: '1.13.1', withoutServers + + testImplementation group: 'org.testcontainers', name: 'qdrant', version: '1.19.7' + testImplementation group: 'org.testcontainers', name: 'chromadb', version: '1.19.7' + + // https://mvnrepository.com/artifact/io.qdrant/client + implementation group: 'io.qdrant', name: 'client', version: '1.8.0' + + configurations.all { exclude group: 'org.slf4j', module: 'slf4j-nop' diff --git a/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java b/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java new file mode 100644 index 0000000000..82c9a0bebc --- /dev/null +++ b/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java @@ -0,0 +1,268 @@ +package apoc.vectordb; + +import apoc.util.TestUtil; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; +import org.testcontainers.chromadb.ChromaDBContainer; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import static apoc.util.MapUtil.map; +import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testResult; +import static apoc.vectordb.VectorDbTestUtil.assertBerlinVector; +import static 
apoc.vectordb.VectorDbTestUtil.assertLondonVector; +import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertOtherNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertRelsAndIndexesCreated; +import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; +import static java.util.Collections.emptyMap; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +public class ChromaDbTest { + @ClassRule + public static DbmsRule db = new ImpermanentDbmsRule(); + + private static final ChromaDBContainer chroma = new ChromaDBContainer("chromadb/chroma:0.4.25.dev137"); + private static final AtomicReference collId = new AtomicReference<>(); + + public static String HOST; + + @BeforeClass + public static void setUp() throws Exception { + chroma.start(); + + HOST = "localhost:" + chroma.getMappedPort(8000); + TestUtil.registerProcedure(db, ChromaDb.class); + + testCall(db, "CALL apoc.vectordb.chroma.createCollection($host, 'test_collection', 'cosine', 4)", + map("host", HOST), + r -> { + Map value = (Map) r.get("value"); + collId.set((String) value.get("id")); + }); + + testCall(db, """ + CALL apoc.vectordb.chroma.upsert($host, $collection, + [ + {id: '1', vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}, text: 'ajeje'}, + {id: '2', vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}, text: 'brazorf'} + ]) + """, + map("host", HOST, "collection", collId.get()), + r -> { + assertNull(r.get("value")); + }); + } + + @AfterClass + public static void tearDown() throws Exception { + testCall(db, "CALL apoc.vectordb.chroma.deleteCollection($host, 'test_collection')", + map("host", HOST), + r -> { + Map value = (Map) r.get("value"); + assertNull(value); + }); + } + + @Before + public void before() { + dropAndDeleteAll(db); + } 
+ + @Test + public void getEmbeddings() { + testResult(db, "CALL apoc.vectordb.chroma.get($host, $collection, ['1']) ", + Map.of("host", HOST, "collection", collId.get()), + r -> { + Map row = r.next(); + assertBerlinVector(row); + assertNotNull(row.get("vector")); + assertEquals("ajeje", row.get("text")); + }); + } + + @Test + public void deleteVector() { + testCall(db, """ + CALL apoc.vectordb.chroma.upsert($host, $collection, + [ + {id: 3, embedding: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}} + ]) + """, + map("host", HOST, "collection", collId.get()), + r -> { + assertNull(r.get("value")); + }); + + testCall(db, "CALL apoc.vectordb.chroma.delete($host, $collection, [3]) ", + Map.of("host", HOST, "collection", collId.get()), + r -> { + assertEquals(List.of("3"), r.get("value")); + }); + } + + @Test + public void createAndDeleteVector() { + testResult(db, "CALL apoc.vectordb.chroma.get($host, $collection, ['1']) ", + Map.of("host", HOST, "collection", collId.get()), + r -> { + Map row = r.next(); + assertBerlinVector(row); + assertNotNull(row.get("vector")); + }); + } + + @Test + public void getEmbedding() { + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5)", + Map.of("host", HOST, "collection", collId.get(), "conf", emptyMap()), + r -> { + Map row = r.next(); + assertBerlinVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + } + + @Test + public void getEmbeddingWithYield() { + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5) YIELD metadata, id", + Map.of("host", HOST, "collection", collId.get(), "conf", emptyMap()), + r -> { + assertBerlinVector(r.next()); + assertLondonVector(r.next()); + }); + } + + @Test + public void getEmbeddingWithFilter() { + testResult(db, """ + CALL 
apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {city: 'London'}, 5) YIELD metadata, id""", + Map.of("host", HOST, "collection", collId.get(), "conf", emptyMap()), + r -> { + assertLondonVector(r.next()); + }); + } + + @Test + public void getEmbeddingWithLimit() { + testResult(db, """ + CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 1) YIELD metadata, id""", + Map.of("host", HOST, "collection", collId.get(), "conf", emptyMap()), + r -> { + assertBerlinVector(r.next()); + }); + } + + @Test + public void getEmbeddingWithCreateIndex() { + Map conf = Map.of(MAPPING_KEY, Map.of("embeddingProp", "vect", + "label", "Test", + "prop", "myId", + "id", "foo", + "create", true)); + + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + Map.of("host", HOST, "collection", collId.get(), "conf", conf), + r -> { + Map row = r.next(); + assertBerlinVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db, true); + + + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.22, 0.11, 0.99, 0.17], {}, 5, $conf) " + + " YIELD score, vector, id, metadata RETURN * ORDER BY id", + Map.of("host", HOST, "collection", collId.get(), "conf", conf), + r -> { + Map row = r.next(); + assertBerlinVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertOtherNodesCreated(db); + } + + @Test + public void getEmbeddingWithCreateIndexUsingExistingNode() { + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + Map conf = Map.of(MAPPING_KEY, Map.of("embeddingProp", "vect", + "label", "Test", + "prop", "myId", + "id", 
"foo")); + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + Map.of("host", HOST, "collection", collId.get(), "conf", conf), + r -> { + Map row = r.next(); + assertBerlinVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db, false); + } + + @Test + public void getEmbeddingWithCreateRelIndex() { + + db.executeTransactionally("CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)"); + + Map conf = Map.of(MAPPING_KEY, Map.of("embeddingProp", "vect", + "type", "TEST", + "prop", "myId", + "id", "foo", + "create", true)); + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + Map.of("host", HOST, "collection", collId.get(), "conf", conf), + r -> { + Map row = r.next(); + assertBerlinVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertRelsAndIndexesCreated(db); + } +} diff --git a/extended-it/src/test/java/apoc/vectordb/QdrantDbTest.java b/extended-it/src/test/java/apoc/vectordb/QdrantDbTest.java new file mode 100644 index 0000000000..12773b61a9 --- /dev/null +++ b/extended-it/src/test/java/apoc/vectordb/QdrantDbTest.java @@ -0,0 +1,263 @@ +package apoc.vectordb; + +import apoc.util.TestUtil; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; +import org.testcontainers.qdrant.QdrantContainer; + +import java.util.Map; + +import static apoc.util.MapUtil.map; +import static apoc.util.TestUtil.testCall; +import static 
apoc.util.TestUtil.testResult; +import static apoc.vectordb.VectorDbTestUtil.assertBerlinVector; +import static apoc.vectordb.VectorDbTestUtil.assertLondonVector; +import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertOtherNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertRelsAndIndexesCreated; +import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbTestUtil.vectorEntityAssertions; +import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; +import static java.util.Collections.emptyMap; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +public class QdrantDbTest { + + @ClassRule + public static DbmsRule db = new ImpermanentDbmsRule(); + + private static final QdrantContainer qdrant = new QdrantContainer("qdrant/qdrant:v1.7.4"); + public static String HOST; + + @BeforeClass + public static void setUp() throws Exception { + qdrant.start(); + + HOST = "localhost:" + qdrant.getMappedPort(6333); + TestUtil.registerProcedure(db, Qdrant.class); + + testCall(db, "CALL apoc.vectordb.qdrant.createCollection($host, 'test_collection', 'Cosine', 4)", + map("host", HOST), + r -> { + Map value = (Map) r.get("value"); + assertEquals("ok", value.get("status")); + }); + + testCall(db, """ + CALL apoc.vectordb.qdrant.upsert($host, 'test_collection', + [ + {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}}, + {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}} + ]) + """, + map("host", HOST), + r -> { + Map value = (Map) r.get("value"); + assertEquals("ok", value.get("status")); + }); + + } + + @AfterClass + public static void tearDown() throws Exception { + testCall(db, "CALL apoc.vectordb.qdrant.deleteCollection($host, 'test_collection')", + map("host", HOST), + r -> { + Map value = (Map) r.get("value"); + assertEquals(true, value.get("result")); + }); 
+ } + + @Before + public void before() { + dropAndDeleteAll(db); + } + + @Test + public void getEmbeddings() { + testResult(db, "CALL apoc.vectordb.qdrant.get($host, 'test_collection', [1]) ", + Map.of("host", HOST), + r -> { + Map row = r.next(); + assertBerlinVector(row); + assertNotNull(row.get("vector")); + }); + } + + @Test + public void deleteVector() { + testCall(db, """ + CALL apoc.vectordb.qdrant.upsert($host, 'test_collection', + [ + {id: 3, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}}, + {id: 4, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}} + ]) + """, + map("host", HOST), + r -> { + Map value = (Map) r.get("value"); + assertEquals("ok", value.get("status")); + }); + + testCall(db, "CALL apoc.vectordb.qdrant.delete($host, 'test_collection', [3, 4]) ", + Map.of("host", HOST), + r -> { + Map value = (Map) r.get("value"); + assertEquals("ok", value.get("status")); + }); + } + + @Test + public void getEmbedding() { + testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5)", + Map.of("host", HOST, "conf", emptyMap()), + r -> { + Map row = r.next(); + assertBerlinVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + } + + @Test + public void getEmbeddingWithYield() { + testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5) YIELD metadata, id", + Map.of("host", HOST, "conf", emptyMap()), + r -> { + assertBerlinVector(r.next()); + assertLondonVector(r.next()); + }); + } + + @Test + public void getEmbeddingWithFilter() { + testResult(db, """ + CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], + { must: + [ { key: "city", match: { value: "London" } } ] + }, + 5) YIELD metadata, id""", + Map.of("host", HOST, "conf", emptyMap()), + r -> { + 
assertLondonVector(r.next()); + }); + } + + @Test + public void getEmbeddingWithLimit() { + testResult(db, """ + CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 1) YIELD metadata, id""", + Map.of("host", HOST, "conf", emptyMap()), + r -> { + assertBerlinVector(r.next()); + }); + } + + @Test + public void getEmbeddingWithCreateIndex() { + + Map conf = Map.of(MAPPING_KEY, Map.of("embeddingProp", "vect", + "label", "Test", + "prop", "myId", + "id", "foo", + "create", true)); + testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + Map.of("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db, true); + + testResult(db, "MATCH (n:Test) RETURN properties(n) AS props ORDER BY n.myId", + r -> vectorEntityAssertions(r, true)); + + testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + Map.of("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertOtherNodesCreated(db); + } + + @Test + public void getEmbeddingWithCreateIndexUsingExistingNode() { + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + Map conf = Map.of(MAPPING_KEY, Map.of("embeddingProp", "vect", + "label", "Test", + "prop", "myId", + "id", "foo")); + testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + Map.of("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + 
assertBerlinVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db, false); + } + + @Test + public void getEmbeddingWithCreateRelIndex() { + + db.executeTransactionally("CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)"); + + Map conf = Map.of(MAPPING_KEY, Map.of("embeddingProp", "vect", + "type", "TEST", + "prop", "myId", + "id", "foo")); + testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + Map.of("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonVector(row); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertRelsAndIndexesCreated(db); + } + +} diff --git a/extended/src/main/java/apoc/ml/RestAPIConfig.java b/extended/src/main/java/apoc/ml/RestAPIConfig.java new file mode 100644 index 0000000000..0f512a4f33 --- /dev/null +++ b/extended/src/main/java/apoc/ml/RestAPIConfig.java @@ -0,0 +1,69 @@ +package apoc.ml; + + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +// TODO - maybe move to `apoc.util` package? 
+public class RestAPIConfig { + public static final String HEADERS_KEY = "headers"; + public static final String METHOD_KEY = "method"; + public static final String ENDPOINT_KEY = "endpoint"; + public static final String JSON_PATH_KEY = "jsonPath"; + public static final String BODY_KEY = "body"; + + private final Map headers; + private final Map body; + private final String endpoint; + private final String jsonPath; + + public RestAPIConfig(Map config) { + this(config, Map.of(), Map.of()); + } + + public RestAPIConfig(Map config, Map additionalHeaders, Map additionalBodies) { + if (config == null) { + config = Collections.emptyMap(); + } + + String httpMethod = (String) config.getOrDefault(METHOD_KEY, "POST"); + Map headerConf = (Map) config.getOrDefault(HEADERS_KEY, new HashMap<>()); + headerConf.putIfAbsent("content-type", "application/json"); + headerConf.putIfAbsent(METHOD_KEY, httpMethod); + additionalHeaders.forEach( (k,v)-> headerConf.putIfAbsent(k,v) ); + + this.headers = headerConf; + + this.endpoint = getEndpoint(config); + + this.jsonPath = (String) config.get(JSON_PATH_KEY); + Map bodyConf = (Map) config.getOrDefault(BODY_KEY, new HashMap<>()); + additionalBodies.forEach( (k,v)-> bodyConf.putIfAbsent(k,v) ); + this.body = bodyConf; + } + + private String getEndpoint(Map config) { + String endpointConfig = (String) config.get(ENDPOINT_KEY); + if (endpointConfig == null) { + throw new RuntimeException("Endpoint must be specified"); + } + return endpointConfig; + } + + public Map getHeaders() { + return headers; + } + + public Map getBody() { + return body; + } + + public String getEndpoint() { + return endpoint; + } + + public String getJsonPath() { + return jsonPath; + } +} diff --git a/extended/src/main/java/apoc/util/ExtendedUtil.java b/extended/src/main/java/apoc/util/ExtendedUtil.java index c7c2e55430..10d38e2c4c 100644 --- a/extended/src/main/java/apoc/util/ExtendedUtil.java +++ b/extended/src/main/java/apoc/util/ExtendedUtil.java @@ -30,6 +30,7 @@ 
import java.time.ZoneId; import java.time.ZonedDateTime; import java.time.temporal.TemporalAccessor; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; @@ -287,4 +288,45 @@ private static void retryRunnable(long maxRetries, long retry, Runnable consumer retryRunnable(maxRetries, retry, consumer); } } + + /** Calls entity.setProperty(key, value) for every entry of props. */ public static void setProperties(Entity entity, Map props) { + for (var entry: props.entrySet()) { + entity.setProperty(entry.getKey(), entry.getValue()); + } + } + + /** + * Transform a list like: [ {key1: valueFoo1, key2: valueFoo2}, {key1: valueBar1, key2: valueBar2} ] + * to a map like: { keyNew1: [valueFoo1, valueBar1], keyNew2: [valueFoo2, valueBar2] }, + * + * where mapKeys is e.g. {key1: keyNew1, key2: keyNew2} + * + * A key of mapKeys that is missing (or mapped to null) in one of the input maps contributes no element for that map, + * so the resulting lists are not guaranteed to have the same length or to stay index-aligned. + * A new HashMap is returned on every call; mapKeys and vectors are only read. + */ + public static Map listOfMapToMapOfLists(Map mapKeys, List> vectors) { + Map additionalBodies = new HashMap(); + for (var vector: vectors) { + mapKeys.forEach((from, to) -> { + mapEntryToList(additionalBodies, vector, from, to); + }); + } + return additionalBodies; + } + + /** Appends vector.get(keyFrom) to the list stored under keyTo in map, creating that list on first use; a null or missing value is skipped. */ private static void mapEntryToList(Map map, Map vector, Object keyFrom, Object keyTo) { + Object item = vector.get(keyFrom); + if (item == null) { + return; + } + + map.compute(keyTo, (k, v) -> { + if (v == null) { + List list = new ArrayList<>(); + list.add(item); + return list; + } + List list = (List) v; + list.add(item); + return list; + }); + } + } diff --git a/extended/src/main/java/apoc/vectordb/ChromaDb.java b/extended/src/main/java/apoc/vectordb/ChromaDb.java new file mode 100644 index 0000000000..cfa09f98b9 --- /dev/null +++ b/extended/src/main/java/apoc/vectordb/ChromaDb.java @@ -0,0 +1,236 @@ +package apoc.vectordb; + +import apoc.ml.RestAPIConfig; +import apoc.result.ListResult; +import apoc.result.MapResult; +import apoc.util.UrlResolver; +import org.apache.commons.collections4.CollectionUtils; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.Transaction; +import
org.neo4j.graphdb.security.URLAccessChecker; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import static apoc.util.ExtendedUtil.listOfMapToMapOfLists; +import static apoc.util.MapUtil.map; +import static apoc.vectordb.VectorDb.executeRequest; +import static apoc.vectordb.VectorDb.getEmbeddingResultStream; +import static apoc.vectordb.VectorDbUtil.getEndpoint; +import static apoc.vectordb.VectorEmbedding.Type.CHROMA; +import static apoc.vectordb.VectorEmbeddingConfig.*; + +/** Procedures registered under apoc.vectordb.chroma.* that build requests against the /api/v1/collections endpoints of a Chroma server. */ public class ChromaDb { + + @Context + public ProcedureCallContext procedureCallContext; + + @Context + public Transaction tx; + + @Context + public GraphDatabaseService db; + + @Context + public URLAccessChecker urlAccessChecker; + + /** Creates a collection: sends name=collection and metadata {hnsw:space: similarity, size: size} to [chroma url]/api/v1/collections. The configuration map is copied first; an endpoint or method (default POST) given in it takes precedence. Map.of rejects nulls, so a null collection, similarity or size fails with NullPointerException. NOTE(review): the local variable is called qdrantUrl although it holds the Chroma url. */ @Procedure("apoc.vectordb.chroma.createCollection") + @Description("apoc.vectordb.chroma.createCollection") + public Stream createCollection(@Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("similarity") String similarity, + @Name("size") Long size, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + var config = new HashMap<>(configuration); + + String qdrantUrl = getChromaUrl(hostOrKey); + String endpoint = "%s/api/v1/collections".formatted(qdrantUrl); + getEndpoint(config, endpoint); + config.putIfAbsent(METHOD_KEY, "POST"); + + Map metadata = Map.of("hnsw:space", similarity, + "size", size); + Map additionalBodies = Map.of("name", collection, "metadata", metadata); + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig, urlAccessChecker) + .map(v -> (Map)v) +
.map(MapResult::new); + } + + /** Deletes a collection: issues a request (method DELETE unless configuration overrides it) to [chroma url]/api/v1/collections/[collection] with no additional body entries and returns the response as a map. */ @Procedure("apoc.vectordb.chroma.deleteCollection") + @Description("apoc.vectordb.chroma.deleteCollection") + public Stream deleteCollection( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + var config = new HashMap<>(configuration); + + String qdrantUrl = getChromaUrl(hostOrKey); + String endpoint = "%s/api/v1/collections/%s".formatted(qdrantUrl, collection); + getEndpoint(config, endpoint); + config.putIfAbsent(METHOD_KEY, "DELETE"); + + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), Map.of()); + return executeRequest(restAPIConfig, urlAccessChecker) + .map(v -> (Map)v) + .map(MapResult::new); + } + + /** Upserts vectors into [chroma url]/api/v1/collections/[collection]/upsert. Each input map is read through the keys id, vector, metadata and text, which are regrouped into the body lists ids, embeddings, metadatas and documents; the ids are then converted to strings. No method is set here, so the RestAPIConfig default applies. NOTE(review): when no input map carries an id, compute passes null to getStringIds, which calls ids.stream() and therefore throws NullPointerException. NOTE(review): the inline comment below mentions a vectors key, while mapKeys maps vector to embeddings. */ @Procedure("apoc.vectordb.chroma.upsert") + @Description("apoc.vectordb.chroma.upsert") + public Stream upsert( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("vectors") List> vectors, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + var config = new HashMap<>(configuration); + + String qdrantUrl = getChromaUrl(hostOrKey); + String endpoint = "%s/api/v1/collections/%s/upsert".formatted(qdrantUrl, collection); + getEndpoint(config, endpoint); + + Map mapKeys = Map.of("id", "ids", + "vector", "embeddings", + "metadata", "metadatas", + "text", "documents"); + + // transform to format digestible by RestAPI, + // that is from [{id: , vector: ,,,}, {id: , vector: ,,,}] + // to {ids: [, ], vectors: [, ]} + Map additionalBodies = listOfMapToMapOfLists(mapKeys, vectors); + additionalBodies.compute( "ids", (k,v) -> getStringIds(v) ); + + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig, urlAccessChecker) + .map(v -> (Map)v) + .map(MapResult::new); + } + + @Procedure(value = "apoc.vectordb.chroma.delete", mode = Mode.SCHEMA) +
@Description("apoc.vectordb.chroma.delete()") + /** Deletes the given ids through [chroma url]/api/v1/collections/[collection]/delete. The request is built by CHROMA.get().fromGet with the ids converted to strings, and each response element is cast to a List and wrapped in a ListResult. */ public Stream delete(@Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + var config = new HashMap<>(configuration); + + String qdrantUrl = getChromaUrl(hostOrKey); + String endpoint = "%s/api/v1/collections/%s/delete".formatted(qdrantUrl, collection); + getEndpoint(config, endpoint); + + VectorEmbeddingConfig apiConfig = CHROMA.get().fromGet(config, procedureCallContext, getStringIds(ids)); + return executeRequest(apiConfig, urlAccessChecker) + .map(v -> (List) v) + .map(ListResult::new); + } + + /** Registered as apoc.vectordb.chroma.get although the Java method is named query. Fetches the given ids from [chroma url]/api/v1/collections/[collection]/get and converts the response with listToMap into one row per id. NOTE(review): ids are passed as-is here, while delete and upsert convert them with getStringIds - confirm this is intended. */ @Procedure(value = "apoc.vectordb.chroma.get", mode = Mode.SCHEMA) + @Description("apoc.vectordb.chroma.get()") + public Stream query(@Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + var config = new HashMap<>(configuration); + + String qdrantUrl = getChromaUrl(hostOrKey); + String endpoint = "%s/api/v1/collections/%s/get".formatted(qdrantUrl, collection); + getEndpoint(config, endpoint); + + VectorEmbeddingConfig apiConfig = CHROMA.get().fromGet(config, procedureCallContext, ids); + return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, db, tx, + v -> listToMap((Map) v).stream()); + } + + /** Similarity query against [chroma url]/api/v1/collections/[collection]/query. The body is built by CHROMA.get().fromQuery from vector, filter and limit, and the response is converted with listOfListsToMap into one row per returned id. */ @Procedure(value = "apoc.vectordb.chroma.query", mode = Mode.SCHEMA) + @Description("apoc.vectordb.chroma.query()") + public Stream query(@Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "{}") Map filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + + var config = new HashMap<>(configuration);
+ + String qdrantUrl = getChromaUrl(hostOrKey); + String endpoint = "%s/api/v1/collections/%s/query".formatted(qdrantUrl, collection); + getEndpoint(config, endpoint); + + VectorEmbeddingConfig apiConfig = CHROMA.get().fromQuery(config, procedureCallContext, vector, filter, limit); + return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, db, tx, + v -> listOfListsToMap((Map) v).stream()); + } + + /** Converts a query response whose distances, metadatas, documents, embeddings and ids entries are lists of lists: only element 0 of each entry is read (presumably one inner list per query embedding - confirm against the Chroma API). Optional entries may be null; ids is required and a missing ids entry fails with NullPointerException. */ private static List listOfListsToMap(Map startMap) { + List distances = startMap.get("distances") == null + ? null + : ((List) startMap.get("distances")) + .get(0); + List metadatas = startMap.get("metadatas") == null + ? null + : ((List) startMap.get("metadatas")) + .get(0); + List documents = startMap.get("documents") == null + ? null + : ((List) startMap.get("documents")) + .get(0); + List embeddings = startMap.get("embeddings") == null + ? null + : ((List) startMap.get("embeddings")) + .get(0); + + List ids = ((List) startMap.get("ids")).get(0); + + return getMaps(distances, metadatas, documents, embeddings, ids); + } + + /** Converts a get response whose distances, metadatas, documents, embeddings and ids entries are flat lists into one map per id. */ private static List listToMap(Map startMap) { + List distances = (List) startMap.get("distances"); + List metadatas = (List) startMap.get("metadatas"); + List documents = (List) startMap.get("documents"); + List embeddings = (List) startMap.get("embeddings"); + + List ids = (List) startMap.get("ids"); + + return getMaps(distances, metadatas, documents, embeddings, ids); + } + + /** Builds one map per element of ids, keyed by DEFAULT_ID, and copies the element at the same index of every optional list that is neither null nor empty under DEFAULT_SCORE, DEFAULT_METADATA, DEFAULT_TEXT and DEFAULT_VECTOR. NOTE(review): a non-empty optional list shorter than ids would throw IndexOutOfBoundsException. */ private static List getMaps(List distances, List metadatas, List documents, List embeddings, List ids) { + final List result = new ArrayList<>(); + for (int i = 0; i < ids.size(); i++) { + Map map = map(DEFAULT_ID, ids.get(i)); + if (CollectionUtils.isNotEmpty(distances)) { + map.put(DEFAULT_SCORE, distances.get(i)); + } + if (CollectionUtils.isNotEmpty(metadatas)) { + map.put(DEFAULT_METADATA, metadatas.get(i)); + } + if (CollectionUtils.isNotEmpty(documents)) { + map.put(DEFAULT_TEXT, documents.get(i)); + } + if
(CollectionUtils.isNotEmpty(embeddings)) { + map.put(DEFAULT_VECTOR, embeddings.get(i)); + } + result.add(map); + } + + return result; + } + + /** Returns a list with the toString() form of every id. A null list or a null element fails with NullPointerException. */ private List getStringIds(List ids) { + return ids.stream().map(Object::toString).toList(); + } + + /** Resolves the base url for hostOrKey through a UrlResolver created with http, localhost and 8000 and the key chroma (presumably the fallback scheme, host and port - confirm in UrlResolver). Protected, so subclasses can override it. */ protected String getChromaUrl(String hostOrKey) { + return new UrlResolver("http", "localhost", 8000).getUrl("chroma", hostOrKey); + } +} diff --git a/extended/src/main/java/apoc/vectordb/Qdrant.java b/extended/src/main/java/apoc/vectordb/Qdrant.java new file mode 100644 index 0000000000..43492e8335 --- /dev/null +++ b/extended/src/main/java/apoc/vectordb/Qdrant.java @@ -0,0 +1,177 @@ +package apoc.vectordb; + +import apoc.ml.RestAPIConfig; +import apoc.result.MapResult; +import apoc.util.UrlResolver; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.Transaction; +import org.neo4j.graphdb.security.URLAccessChecker; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import static apoc.ml.RestAPIConfig.ENDPOINT_KEY; +import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.vectordb.VectorDb.executeRequest; +import static apoc.vectordb.VectorDb.getEmbeddingResultStream; +import static apoc.vectordb.VectorDbUtil.getEndpoint; +import static apoc.vectordb.VectorEmbedding.Type.QDRANT; +import static apoc.vectordb.VectorEmbeddingConfig.EMBEDDING_KEY; + +/** Procedures registered under apoc.vectordb.qdrant.* that build requests against the /collections endpoints of a Qdrant server. */ public class Qdrant { + + @Context + public ProcedureCallContext procedureCallContext; + + @Context + public Transaction tx; + + @Context + public GraphDatabaseService db; + + @Context + public URLAccessChecker urlAccessChecker; + + @Procedure("apoc.vectordb.qdrant.createCollection") +
@Description("apoc.vectordb.qdrant.createCollection(hostOrKey, collection, similarity, size, $config)") + /** Creates a collection: sends vectors {size: size, distance: similarity} to [qdrant url]/collections/[collection] with method PUT unless configuration overrides endpoint or method. The configuration map is copied before being modified. Map.of rejects nulls, so a null size or similarity fails with NullPointerException. */ public Stream createCollection(@Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("similarity") String similarity, + @Name("size") Long size, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + var config = new HashMap<>(configuration); + + String qdrantUrl = getQdrantUrl(hostOrKey); + String endpoint = "%s/collections/%s".formatted(qdrantUrl, collection); + getEndpoint(config, endpoint); + config.putIfAbsent(METHOD_KEY, "PUT"); + + Map additionalBodies = Map.of("vectors", Map.of( + "size", size, + "distance", similarity + )); + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig, urlAccessChecker) + .map(v -> (Map)v) + .map(MapResult::new); + } + + /** Deletes a collection: issues a request (method DELETE unless configuration overrides it) to [qdrant url]/collections/[collection] without additional body entries and returns the response as a map. */ @Procedure("apoc.vectordb.qdrant.deleteCollection") + @Description("apoc.vectordb.qdrant.deleteCollection(hostOrKey, collection, $config)") + public Stream deleteCollection( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + + var config = new HashMap<>(configuration); + + String qdrantUrl = getQdrantUrl(hostOrKey); + String endpoint = "%s/collections/%s".formatted(qdrantUrl, collection); + getEndpoint(config, endpoint); + config.putIfAbsent(METHOD_KEY, "DELETE"); + + RestAPIConfig restAPIConfig = new RestAPIConfig(config); + return executeRequest(restAPIConfig, urlAccessChecker) + .map(v -> (Map)v) + .map(MapResult::new); + } + + /** Upserts points into [qdrant url]/collections/[collection]/points with method PUT unless configuration overrides it. Each input map is copied and its metadata entry is moved to payload before the list is sent under the points key. */ @Procedure("apoc.vectordb.qdrant.upsert") + @Description("apoc.vectordb.qdrant.upsert(hostOrKey, collection, vectors, $config)") + public Stream upsert( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("vectors") List> vectors, + @Name(value = "configuration", defaultValue = "{}") Map
configuration) throws Exception { + + var config = new HashMap<>(configuration); + + String qdrantUrl = getQdrantUrl(hostOrKey); + String endpoint = "%s/collections/%s/points".formatted(qdrantUrl, collection); + getEndpoint(config, endpoint); + config.putIfAbsent(METHOD_KEY, "PUT"); + + List> point = vectors.stream() + .map(i -> { + Map map = new HashMap<>(i); + /* NOTE(review): this removes vector and puts it back under the same key, so it has no renaming effect. */ map.putIfAbsent("vector", map.remove("vector")); + /* Moves metadata to payload; the metadata entry is removed first, so it is dropped when the map already has a non-null payload. */ map.putIfAbsent("payload", map.remove("metadata")); + return map; + }) + .toList(); + Map additionalBodies = Map.of("points", point); + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig, urlAccessChecker) + .map(v -> (Map)v) + .map(MapResult::new); + } + + /** Deletes points: sends the ids under the points key to [qdrant url]/collections/[collection]/points/delete with method POST unless configuration overrides it. NOTE(review): the procedure argument is declared with the name vectors although it carries ids and the description calls it ids. */ @Procedure("apoc.vectordb.qdrant.delete") + @Description("apoc.vectordb.qdrant.delete(hostOrKey, collection, ids, $config)") + public Stream delete( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("vectors") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + + var config = new HashMap<>(configuration); + + String qdrantUrl = getQdrantUrl(hostOrKey); + String endpoint = "%s/collections/%s/points/delete".formatted(qdrantUrl, collection); + getEndpoint(config, endpoint); + config.putIfAbsent(METHOD_KEY, "POST"); + + Map additionalBodies = Map.of("points", ids); + RestAPIConfig apiConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(apiConfig, urlAccessChecker) + .map(v -> (Map)v) + .map(MapResult::new); + } + + /** Registered as apoc.vectordb.qdrant.get although the Java method is named query. Retrieves the given ids from [qdrant url]/collections/[collection]/points using the request built by QDRANT.get().fromGet. */ @Procedure(value = "apoc.vectordb.qdrant.get", mode = Mode.SCHEMA) + @Description("apoc.vectordb.qdrant.get(hostOrKey, collection, ids, $config)") + public Stream query(@Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + var config = new
HashMap<>(configuration); + + String qdrantUrl = getQdrantUrl(hostOrKey); + String endpoint = "%s/collections/%s/points".formatted(qdrantUrl, collection); + getEndpoint(config, endpoint); + + VectorEmbeddingConfig apiConfig = QDRANT.get().fromGet(config, procedureCallContext, ids); + return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, db, tx); + } + + /** Similarity search against [qdrant url]/collections/[collection]/points/search. The body is built by QDRANT.get().fromQuery from vector, filter and limit; the rows are produced by getEmbeddingResultStream with its default list mapper. */ @Procedure(value = "apoc.vectordb.qdrant.query", mode = Mode.SCHEMA) + @Description("apoc.vectordb.qdrant.query(hostOrKey, collection, vector, filter, limit, $config)") + public Stream query(@Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "{}") Map filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + + var config = new HashMap<>(configuration); + + String qdrantUrl = getQdrantUrl(hostOrKey); + String endpoint = "%s/collections/%s/points/search".formatted(qdrantUrl, collection); + getEndpoint(config, endpoint); + + VectorEmbeddingConfig apiConfig = QDRANT.get().fromQuery(config, procedureCallContext, vector, filter, limit); + return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, db, tx); + } + + /** Resolves the base url for hostOrKey through a UrlResolver created with http, localhost and 6333 and the key qdrant (presumably the fallback scheme, host and port - confirm in UrlResolver). Protected, so subclasses can override it. */ protected String getQdrantUrl(String hostOrKey) { + return new UrlResolver("http", "localhost", 6333).getUrl("qdrant", hostOrKey); + } +} diff --git a/extended/src/main/java/apoc/vectordb/VectorDb.java b/extended/src/main/java/apoc/vectordb/VectorDb.java new file mode 100644 index 0000000000..2f12baf251 --- /dev/null +++ b/extended/src/main/java/apoc/vectordb/VectorDb.java @@ -0,0 +1,225 @@ +package apoc.vectordb; + +import apoc.ml.RestAPIConfig; +import apoc.result.ObjectResult; +import apoc.util.JsonUtil; +import apoc.util.Util; +import org.apache.commons.collections4.MapUtils; +import org.neo4j.graphdb.Entity; +import
org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.Label; +import org.neo4j.graphdb.MultipleFoundException; +import org.neo4j.graphdb.Node; +import org.neo4j.graphdb.Relationship; +import org.neo4j.graphdb.RelationshipType; +import org.neo4j.graphdb.Transaction; +import org.neo4j.graphdb.security.URLAccessChecker; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Stream; + +import static apoc.util.ExtendedUtil.setProperties; +import static apoc.util.JsonUtil.OBJECT_MAPPER; +import static apoc.vectordb.VectorDbUtil.*; + +/** + * Base class: holds the generic apoc.vectordb.custom procedures plus the static helpers + * (executeRequest, getEmbeddingResultStream) that the ChromaDb and Qdrant procedure classes import. + */ +public class VectorDb { + + @Context + public URLAccessChecker urlAccessChecker; + + @Context + public GraphDatabaseService db; + + @Context + public Transaction tx; + + @Context + public ProcedureCallContext procedureCallContext; + + /** + * We can use this procedure with every API that returns something like this: + * ``` + * [ + * "idKey": "idValue", + * "scoreKey": 1, + * "embeddingKey": [ ] + * "metadataKey": { .. }, + * "textKey": "..." + * ], + * [ + * ... + * ] + * ``` + * + * Otherwise, if the result is different (e.g.
the Chroma result), we have to leverage the apoc.vectordb.custom, + * which returns an Object, but we can't use it to filter result via `ProcedureCallContext procedureCallContext` + * and mapping data to auto-create neo4j vector indexes and properties + * + * NOTE(review): getEndpoint calls putIfAbsent directly on the configuration argument; no copy is made here, + * unlike in the ChromaDb and Qdrant procedures - presumably this fails for an unmodifiable map, verify. + */ + @Procedure(value = "apoc.vectordb.custom.get", mode = Mode.SCHEMA) + @Description("apoc.vectordb.custom.get(host, $configuration) - Customizable get / query procedure") + public Stream get(@Name("host") String host, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + + getEndpoint(configuration, host); + VectorEmbeddingConfig restAPIConfig = new VectorEmbeddingConfig(configuration, Map.of(), Map.of()); + return getEmbeddingResultStream(restAPIConfig, procedureCallContext, urlAccessChecker, db, tx); + } + + /** Variant for responses that are already a list of row objects: every response element is cast to a List and streamed as-is. */ public static Stream getEmbeddingResultStream(VectorEmbeddingConfig conf, + ProcedureCallContext procedureCallContext, + URLAccessChecker urlAccessChecker, + GraphDatabaseService db, + Transaction tx) throws Exception { + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, db, tx, v -> ((List) v).stream()); + } + + /** Executes the request described by conf, turns every response element into row maps with objectMapper and converts each row into an EmbeddingResult using the key names configured in conf. The vector and metadata values are only read when the caller yields the vector / metadata output fields; handleMapping is invoked for every row. */ public static Stream getEmbeddingResultStream(VectorEmbeddingConfig conf, + ProcedureCallContext procedureCallContext, + URLAccessChecker urlAccessChecker, + GraphDatabaseService db, + Transaction tx, + Function> objectMapper) throws Exception { + List fields = procedureCallContext.outputFields().toList(); + + boolean hasEmbedding = fields.contains("vector"); + boolean hasMetadata = fields.contains("metadata"); + Stream resultStream = executeRequest(conf, urlAccessChecker); + + VectorMappingConfig mapping = conf.getMapping(); + + return resultStream + .flatMap(objectMapper) + .map(m -> { + Object id = m.get(conf.getIdKey()); + List embedding = hasEmbedding ? (List) m.get(conf.getVectorKey()) : null; + Map metadata = hasMetadata ? (Map) m.get(conf.getMetadataKey()) : null; + // in case of get operation, e.g.
http://localhost:52798/collections/{coll_name}/points with Qdrant db, + // score is not present + Double score = Util.toDouble(m.get(conf.getScoreKey())); + String text = (String) m.get(conf.getTextKey()); + + handleMapping(tx, db, mapping, metadata, embedding); + return new EmbeddingResult(id, score, embedding, metadata, text); + }); + } + + /** Applies the optional mapping config to one result row. Does nothing when mapping.getProp() is null; otherwise requires non-empty metadata and dispatches on a copy of it to the node handler (label set) or the relationship handler (type set), failing when neither is configured. */ private static void handleMapping(Transaction tx, GraphDatabaseService db, VectorMappingConfig mapping, Map metadata, List embedding) { + if (mapping.getProp() == null) { + return; + } + if (MapUtils.isEmpty(metadata)) { + throw new RuntimeException("To use mapping config, the metadata should not be empty. Make sure you execute `YIELD metadata` on the procedure"); + } + Map metaProps = new HashMap<>(metadata); + if (mapping.getLabel() != null) { + handleMappingNode(tx, db, mapping, metaProps, embedding); + } else if (mapping.getType() != null) { + handleMappingRel(tx, db, mapping, metaProps, embedding); + } else { + throw new RuntimeException("Mapping conf has to contain either label or type key"); + } + } + + /** Ensures a uniqueness constraint on label/prop, then, in a separate transaction, looks up the node whose prop equals the metadata value stored under mapping.getId() (that entry is removed from metaProps), optionally creates it, and copies the remaining metadata onto it; finally delegates to setVectorProp. NOTE(review): this runs once per result row, so the constraint statement is re-issued for every row. NOTE(review): label and prop are formatted into the Cypher text without escaping. NOTE(review): a newly created node only receives metaProps, from which the id entry was removed; the lookup property itself is not set in this method - confirm. NOTE(review): the MultipleFoundException cause is not passed to the RuntimeException. */ private static void handleMappingNode(Transaction tx, GraphDatabaseService db, VectorMappingConfig mapping, Map metaProps, List embedding) { + String query = "CREATE CONSTRAINT IF NOT EXISTS FOR (n:%s) REQUIRE n.%s IS UNIQUE" + .formatted(mapping.getLabel(), mapping.getProp()); + db.executeTransactionally(query); + + try { + Node node; + try (Transaction transaction = db.beginTx()) { + Object propValue = metaProps.remove(mapping.getId()); + node = transaction.findNode(Label.label(mapping.getLabel()), mapping.getProp(), propValue); + if (node == null && mapping.isCreate()) { + node = transaction.createNode(Label.label(mapping.getLabel())); + } + if (node != null) { + setProperties(node, metaProps); + } + transaction.commit(); + } + + String indexQuery = "CREATE VECTOR INDEX IF NOT EXISTS FOR (n:%s) ON (n.%s) OPTIONS {indexConfig: {`vector.dimensions`: %s, `vector.similarity_function`: '%s'}}"; + String setVectorQuery =
"CALL db.create.setNodeVectorProperty($entity, $key, $vector)"; + setVectorProp(tx, db, mapping, embedding, node, indexQuery, setVectorQuery); + + } catch (MultipleFoundException e) { + throw new RuntimeException("Multiple nodes found"); + } + } + + /** Relationship counterpart of handleMappingNode: ensures the constraint on type/prop, looks up the relationship whose prop equals the metadata value stored under mapping.getId(), copies the remaining metadata onto it when found (relationships are never created here) and delegates to setVectorProp. The same review notes apply: statement issued per row, unescaped type/prop in the Cypher text, exception cause dropped. */ private static void handleMappingRel(Transaction tx, GraphDatabaseService db, VectorMappingConfig mapping, Map metaProps, List embedding) { + try { + String query = "CREATE CONSTRAINT IF NOT EXISTS FOR ()-[r:%s]-() REQUIRE (r.%s) IS UNIQUE" + .formatted(mapping.getType(), mapping.getProp()); + db.executeTransactionally(query); + + // in this case we cannot auto-create the rel, since we should have to define start and end node as well + Relationship rel; + try (Transaction transaction = db.beginTx()) { + Object propValue = metaProps.remove(mapping.getId()); + rel = transaction.findRelationship(RelationshipType.withName(mapping.getType()), mapping.getProp(), propValue); + if (rel != null) { + setProperties(rel, metaProps); + } + transaction.commit(); + } + + String indexQuery ="CREATE VECTOR INDEX IF NOT EXISTS FOR ()-[r:%s]-() ON (r.%s) OPTIONS {indexConfig: {`vector.dimensions`: %s, `vector.similarity_function`: '%s'}}"; + String setVectorQuery = "CALL db.create.setRelationshipVectorProperty($entity, $key, $vector)"; + setVectorProp(tx, db, mapping, embedding, rel, indexQuery, setVectorQuery); + + } catch (MultipleFoundException e) { + throw new RuntimeException("Multiple relationships found"); + } + } + + /** Creates the vector index (if missing) for the entity label or type and stores the embedding under mapping.getEmbeddingProp(); returns silently when the entity or the embedding property name is null. NOTE(review): the error message asks for YIELD embedding, while getEmbeddingResultStream checks for an output field named vector. */ private static void setVectorProp(Transaction tx, GraphDatabaseService db, VectorMappingConfig mapping, List embedding, T entity, String indexQuery, String setVectorQuery) { + if (entity == null || mapping.getEmbeddingProp() == null) { + return; + } + + if (embedding == null) { + throw new RuntimeException("The embedding value is null. Make sure you execute `YIELD embedding` on the procedure"); + } + + String labelOrType = entity instanceof Node + ?
mapping.getLabel() + : mapping.getType(); + /* the index dimension is taken from the size of this embedding */ String vectorIndex = indexQuery + .formatted(labelOrType, mapping.getEmbeddingProp(), embedding.size(), mapping.getSimilarity()); + db.executeTransactionally(vectorIndex); + /* the entity was loaded in another transaction, so it is rebound to tx before being passed as a parameter */ db.executeTransactionally(setVectorQuery, + Map.of("entity", Util.rebind(tx, entity), "key", mapping.getEmbeddingProp(), "vector", embedding)); + } + + // TODO - evaluate. It could be renamed e.g. to `apoc.util.restapi.custom` or `apoc.restapi.custom`, + // since it can potentially be used as a generic method to call any RestAPI + /** Generic REST call: uses host as endpoint unless configuration already defines one and returns every response element as an ObjectResult. NOTE(review): getEndpoint calls putIfAbsent on the configuration argument itself, no copy is made. */ @Procedure("apoc.vectordb.custom") + @Description("apoc.vectordb.custom(host, $config) - fully customizable vector db procedure, returns generic object results") + public Stream custom(@Name("host") String host, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + + getEndpoint(configuration, host); + RestAPIConfig restAPIConfig = new RestAPIConfig(configuration); + return executeRequest(restAPIConfig, urlAccessChecker) + .map(ObjectResult::new); + } + + /** Serializes apiConfig.getBody() to JSON with OBJECT_MAPPER and passes endpoint, headers, body and JSON path to JsonUtil.loadJson, returning its stream. */ public static Stream executeRequest(RestAPIConfig apiConfig, URLAccessChecker urlAccessChecker) throws Exception { + Map headers = apiConfig.getHeaders(); + String body = OBJECT_MAPPER.writeValueAsString(apiConfig.getBody()); + return JsonUtil.loadJson(apiConfig.getEndpoint(), headers, body, apiConfig.getJsonPath(), true, List.of(), urlAccessChecker); + } +} diff --git a/extended/src/main/java/apoc/vectordb/VectorDbUtil.java b/extended/src/main/java/apoc/vectordb/VectorDbUtil.java new file mode 100644 index 0000000000..c7727804c7 --- /dev/null +++ b/extended/src/main/java/apoc/vectordb/VectorDbUtil.java @@ -0,0 +1,23 @@ +package apoc.vectordb; + +import java.util.List; +import java.util.Map; + +import static apoc.ml.RestAPIConfig.ENDPOINT_KEY; + +/** Small helpers and the result record shared by the vector db procedures. */ public class VectorDbUtil { + + /** + * we can configure the endpoint via config map or via hostOrKey parameter, + * to handle potential endpoint changes.
+ * For example, in Qdrant `BASE_URL/collections/COLLECTION_NAME/points` could change in the future. + * + * Despite the get prefix nothing is returned: the endpoint is stored in the config map under ENDPOINT_KEY, + * and only when the map has no endpoint entry yet (putIfAbsent). The map is modified in place. + */ + public static void getEndpoint(Map config, String endpoint) { + config.putIfAbsent(ENDPOINT_KEY, endpoint); + } + + /** + * Result of `apoc.vectordb.*.get` and `apoc.vectordb.*.query` procedures + */ + public record EmbeddingResult(Object id, Double score, List vector, Map metadata, String text) {} +} diff --git a/extended/src/main/java/apoc/vectordb/VectorEmbedding.java b/extended/src/main/java/apoc/vectordb/VectorEmbedding.java new file mode 100644 index 0000000000..168be829e3 --- /dev/null +++ b/extended/src/main/java/apoc/vectordb/VectorEmbedding.java @@ -0,0 +1,138 @@ +package apoc.vectordb; + +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static apoc.ml.RestAPIConfig.JSON_PATH_KEY; +import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.util.MapUtil.map; +import static apoc.vectordb.VectorEmbeddingConfig.EMBEDDING_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY; + +/** Strategy that turns procedure arguments into the VectorEmbeddingConfig (request body and key names) expected by a specific vector database. */ public interface VectorEmbedding { + + /** The supported databases; each constant holds one shared VectorEmbedding instance returned by get(). */ enum Type { + CHROMA(new ChromaEmbeddingType()), + QDRANT(new QdrantEmbeddingType()); + + private final VectorEmbedding embedding; + + Type(VectorEmbedding embedding) { + this.embedding = embedding; + } + + public VectorEmbedding get() { + return embedding; + } + } + + /** Builds the configuration for a request that fetches points by id. */ public VectorEmbeddingConfig fromGet(Map config, + ProcedureCallContext procedureCallContext, + List ids); + + /** Builds the configuration for a similarity query with the given vector, filter and result limit. */ public VectorEmbeddingConfig fromQuery(Map config, + ProcedureCallContext procedureCallContext, + List vector, + Map filter, + long limit); + + // + // -- implementations + // + + public static class QdrantEmbeddingType implements VectorEmbedding { + + @Override + public VectorEmbeddingConfig fromGet(Map config, ProcedureCallContext procedureCallContext, List ids) { + List fields = procedureCallContext.outputFields().toList(); +
config.putIfAbsent(METHOD_KEY, "POST"); + + Map additionalBodies = map("ids", ids); + + return getVectorEmbeddingConfig(config, fields, additionalBodies); + } + + @Override + public VectorEmbeddingConfig fromQuery(Map config, ProcedureCallContext procedureCallContext, + List vector, Map filter, long limit) { + List fields = procedureCallContext.outputFields().toList(); + + Map additionalBodies = map("vector", vector, + "filter", filter, + "limit", limit); + + return getVectorEmbeddingConfig(config, fields, additionalBodies); + } + + /** Adds the with_payload / with_vectors flags according to the yielded output fields and defaults the embedding key to vector, the metadata key to payload and the JSON path to result. Both config and additionalBodies are modified in place. */ // "with_payload": and "with_vectors": return the metadata and vector, if true + // therefore is the RestAPI itself that doesn't return the data if `YIELD ` has not metadata/embedding + private static VectorEmbeddingConfig getVectorEmbeddingConfig(Map config, List fields, Map additionalBodies) { + additionalBodies.put("with_payload", fields.contains("metadata")); + additionalBodies.put("with_vectors", fields.contains("vector")); + + config.putIfAbsent(EMBEDDING_KEY, "vector"); + config.putIfAbsent(METADATA_KEY, "payload"); + config.putIfAbsent(JSON_PATH_KEY, "result"); + + return new VectorEmbeddingConfig(config, Map.of(), additionalBodies); + } + } + + public static class ChromaEmbeddingType implements VectorEmbedding { + + @Override + public VectorEmbeddingConfig fromGet(Map config, + ProcedureCallContext procedureCallContext, + List ids) { + + List fields = procedureCallContext.outputFields().toList(); + + Map additionalBodies = map("ids", ids); + + return getVectorEmbeddingConfig(config, fields, additionalBodies); + } + + @Override + public VectorEmbeddingConfig fromQuery(Map config, + ProcedureCallContext procedureCallContext, + List vector, + Map filter, + long limit) { + + List fields = procedureCallContext.outputFields().toList(); + + /* the single query vector is wrapped in a list under query_embeddings; List.of rejects a null vector */ Map additionalBodies = map("query_embeddings", List.of(vector), + "where", filter, + "n_results", limit); + + return getVectorEmbeddingConfig(config, fields, additionalBodies); + } + + /** Adds the include list to additionalBodies: metadatas, documents, embeddings and distances are requested only when the metadata, text, vector and score output fields are yielded. */ //
"include": [metadatas, embeddings, ...] return the metadata/embeddings/... if included in the list + // therefore is the RestAPI itself that doesn't return the data if `YIELD ` has not metadata/embedding + private static VectorEmbeddingConfig getVectorEmbeddingConfig(Map config, + List fields, + Map additionalBodies) { + ArrayList include = new ArrayList<>(); + if (fields.contains("metadata")) { + include.add("metadatas"); + } + if (fields.contains("text")) { + include.add("documents"); + } + if (fields.contains("vector")) { + include.add("embeddings"); + } + if (fields.contains("score")) { + include.add("distances"); + } + + additionalBodies.put("include", include); + + return new VectorEmbeddingConfig(config, Map.of(), additionalBodies); + } + } +} diff --git a/extended/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java b/extended/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java new file mode 100644 index 0000000000..263352197a --- /dev/null +++ b/extended/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java @@ -0,0 +1,62 @@ +package apoc.vectordb; + +import apoc.ml.RestAPIConfig; + +import java.util.Map; + +/** RestAPIConfig extended with the names of the response keys that hold id, text, vector, metadata and score, plus the optional graph mapping. */ public class VectorEmbeddingConfig extends RestAPIConfig { + /* config keys through which the caller can rename the response keys */ public static final String EMBEDDING_KEY = "embeddingKey"; + public static final String METADATA_KEY = "metadataKey"; + public static final String SCORE_KEY = "scoreKey"; + public static final String TEXT_KEY = "textKey"; + public static final String ID_KEY = "idKey"; + public static final String MAPPING_KEY = "mapping"; + + /* response key names used when the config does not override them */ public static final String DEFAULT_ID = "id"; + public static final String DEFAULT_TEXT = "text"; + public static final String DEFAULT_VECTOR = "vector"; + public static final String DEFAULT_METADATA = "metadata"; + public static final String DEFAULT_SCORE = "score"; + + private final String idKey; + private final String textKey; + private final String vectorKey; + private final String metadataKey; + private final String scoreKey; + + private final
VectorMappingConfig mapping; + + public VectorEmbeddingConfig(Map config, Map additionalHeaders, Map additionalBodies) { + super(config, additionalHeaders, additionalBodies); + this.vectorKey = (String) config.getOrDefault(EMBEDDING_KEY, DEFAULT_VECTOR); + this.metadataKey = (String) config.getOrDefault(METADATA_KEY, DEFAULT_METADATA); + this.scoreKey = (String) config.getOrDefault(SCORE_KEY, DEFAULT_SCORE); + this.idKey = (String) config.getOrDefault(ID_KEY, DEFAULT_ID); + this.textKey = (String) config.getOrDefault(TEXT_KEY, DEFAULT_TEXT); + this.mapping = new VectorMappingConfig((Map) config.getOrDefault(MAPPING_KEY, Map.of())); + } + + public String getIdKey() { + return idKey; + } + + public String getVectorKey() { + return vectorKey; + } + + public String getMetadataKey() { + return metadataKey; + } + + public String getScoreKey() { + return scoreKey; + } + + public String getTextKey() { + return textKey; + } + + public VectorMappingConfig getMapping() { + return mapping; + } +} diff --git a/extended/src/main/java/apoc/vectordb/VectorMappingConfig.java b/extended/src/main/java/apoc/vectordb/VectorMappingConfig.java new file mode 100644 index 0000000000..c84a0c1218 --- /dev/null +++ b/extended/src/main/java/apoc/vectordb/VectorMappingConfig.java @@ -0,0 +1,62 @@ +package apoc.vectordb; + +import apoc.util.Util; + +import java.util.Collections; +import java.util.Map; + +public class VectorMappingConfig { + private final Object id; + private final String prop; + + private final String label; + private final String type; + private final String embeddingProp; + private final String similarity; + + private final boolean create; + + public VectorMappingConfig(Map mapping) { + if (mapping == null) { + mapping = Collections.emptyMap(); + } + this.id = mapping.get("id"); + this.prop = (String) mapping.get("prop"); + + this.label = (String) mapping.get("label"); + this.type = (String) mapping.get("type"); + this.embeddingProp = (String) mapping.get("embeddingProp"); + + 
this.similarity = (String) mapping.getOrDefault("similarity", "cosine"); + + this.create = Util.toBoolean(mapping.get("create")); + } + + public Object getId() { + return id; + } + + public String getProp() { + return prop; + } + + public String getLabel() { + return label; + } + + public String getType() { + return type; + } + + public String getEmbeddingProp() { + return embeddingProp; + } + + public boolean isCreate() { + return create; + } + + public String getSimilarity() { + return similarity; + } +} diff --git a/extended/src/test/java/apoc/vectordb/PineconeTest.java b/extended/src/test/java/apoc/vectordb/PineconeTest.java new file mode 100644 index 0000000000..b0c26d79b8 --- /dev/null +++ b/extended/src/test/java/apoc/vectordb/PineconeTest.java @@ -0,0 +1,119 @@ +package apoc.vectordb; + +import apoc.util.TestUtil; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +import java.net.URL; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static apoc.ml.RestAPIConfig.BODY_KEY; +import static apoc.ml.RestAPIConfig.HEADERS_KEY; +import static apoc.ml.RestAPIConfig.JSON_PATH_KEY; +import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testResult; +import static apoc.util.Util.map; +import static apoc.vectordb.VectorEmbeddingConfig.EMBEDDING_KEY; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * It leverages `apoc.vectordb.custom*` procedures + * * + * * + * Example of Pinecone RestAPI: + * PINECONE_HOST: `https://INDEX-ID.svc.gcp-starter.pinecone.io` + * PINECONE_KEY: `API Key` + * PINECONE_NAMESPACE: `the one to be specified in body: {.. 
"ns": NAMESPACE}` + * PINECONE_DIMENSION: vector dimension + */ +public class PineconeTest { + private static String apiKey; + private static String host; + private static String size; + private static String namespace; + + @ClassRule + public static DbmsRule db = new ImpermanentDbmsRule(); + + + @BeforeClass + public static void setUp() throws Exception { + apiKey = extracted("PINECONE_KEY"); + host = extracted("PINECONE_HOST"); + size = extracted("PINECONE_DIMENSION"); + namespace = extracted("PINECONE_NAMESPACE"); + + TestUtil.registerProcedure(db, VectorDb.class); + } + + private static String extracted(String envKey) { + String size = System.getenv(envKey); + Assume.assumeNotNull("No %s environment configured".formatted(envKey), size); + return size; + } + + + @Test + public void callQueryEndpointViaCustomGetProc() { + + Map conf = getConf(); + conf.put(EMBEDDING_KEY, "values"); + + testResult(db, "CALL apoc.vectordb.custom.get($host, $conf)", + map("host", host + "/query", "conf", conf), + r -> { + r.forEachRemaining(i -> { + assertNotNull(i.get("score")); + assertNotNull(i.get("metadata")); + assertNotNull(i.get("id")); + assertNotNull(i.get("vector")); + }); + }); + } + + @Test + public void callQueryEndpointViaCustomProc() { + testCall(db, "CALL apoc.vectordb.custom($host, $conf)", + map("host", host + "/query", "conf", getConf()), + r -> { + List value = (List) r.get("value"); + value.forEach(i -> { + assertTrue(i.containsKey("score")); + assertTrue(i.containsKey("metadata")); + assertTrue(i.containsKey("id")); + }); + }); + } + + /** + * TODO: "method" is null as a workaround. 
+ * Since with `method: POST` the {@link apoc.util.Util#openUrlConnection(URL, Map)} has a `setChunkedStreamingMode` + * that makes the request to respond 200 OK, but returns an empty result + */ + private static Map getConf() { + List vector = Collections.nCopies(Integer.parseInt(size), 0.1); + + Map body = map( + "namespace", namespace, + "vector", vector, + "topK", 3, + "includeValues", true, + "includeMetadata", true + ); + + Map header = map("Api-Key", apiKey); + + return map(BODY_KEY, body, + HEADERS_KEY, header, + METHOD_KEY, null, + JSON_PATH_KEY, "matches"); + } +} diff --git a/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java b/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java new file mode 100644 index 0000000000..7774fa662d --- /dev/null +++ b/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java @@ -0,0 +1,109 @@ +package apoc.vectordb; + +import apoc.util.collection.Iterables; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.Label; +import org.neo4j.graphdb.RelationshipType; +import org.neo4j.graphdb.ResourceIterator; +import org.neo4j.graphdb.Result; +import org.neo4j.graphdb.Transaction; +import org.neo4j.graphdb.schema.ConstraintDefinition; +import org.neo4j.graphdb.schema.IndexDefinition; +import org.neo4j.graphdb.schema.IndexType; + +import java.util.List; +import java.util.Map; + +import static apoc.util.TestUtil.testCallCount; +import static apoc.util.TestUtil.testResult; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class VectorDbTestUtil { + + public static void dropAndDeleteAll(GraphDatabaseService db) { + try (Transaction tx = db.beginTx()) { + tx.schema().getConstraints().forEach(ConstraintDefinition::drop); + tx.schema().getIndexes().forEach(IndexDefinition::drop); + tx.commit(); + } + db.executeTransactionally("MATCH (n) DETACH DELETE n"); + } + + public static void assertBerlinVector(Map row) 
{ + assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); + assertEquals("1", row.get("id").toString()); + } + + public static void assertLondonVector(Map row) { + assertEquals(Map.of("city", "London", "foo", "two"), row.get("metadata")); + assertEquals("2", row.get("id").toString()); + } + + public static void assertOtherNodesCreated(GraphDatabaseService db) { + assertIndexNodesCreated(db); + + testCallCount(db, "MATCH (n:Test) RETURN n", 4); + } + + public static void assertNodesCreated(GraphDatabaseService db, boolean isNew) { + assertIndexNodesCreated(db); + + testResult(db, "MATCH (n:Test) RETURN properties(n) AS props ORDER BY n.myId", + r -> vectorEntityAssertions(r, isNew)); + } + + public static void assertIndexNodesCreated(GraphDatabaseService db) { + try (Transaction tx = db.beginTx()) { + List indexes = Iterables.stream(tx.schema().getIndexes()) + .filter(i -> i.getIndexType().equals(IndexType.VECTOR)) + .toList(); + assertEquals(1, indexes.size()); + assertEquals(List.of(Label.label("Test")), indexes.get(0).getLabels()); + assertEquals(List.of("vect"), indexes.get(0).getPropertyKeys()); + + List constraints = Iterables.asList(tx.schema().getConstraints()); + assertEquals(1, constraints.size()); + assertEquals(Label.label("Test"), constraints.get(0).getLabel()); + assertEquals(List.of("myId"), constraints.get(0).getPropertyKeys()); + } + } + + public static void assertRelsAndIndexesCreated(GraphDatabaseService db) { + try (Transaction tx = db.beginTx()) { + List indexes = Iterables.stream(tx.schema().getIndexes()) + .filter(i -> i.getIndexType().equals(IndexType.VECTOR)) + .toList(); + assertEquals(1, indexes.size()); + assertEquals(List.of(RelationshipType.withName("TEST")), indexes.get(0).getRelationshipTypes()); + assertEquals(List.of("vect"), indexes.get(0).getPropertyKeys()); + + List constraints = Iterables.asList(tx.schema().getConstraints()); + assertEquals(1, constraints.size()); + 
assertEquals(RelationshipType.withName("TEST"), constraints.get(0).getRelationshipType()); + assertEquals(List.of("myId"), constraints.get(0).getPropertyKeys()); + } + + testResult(db, "MATCH (:Start)-[r:TEST]->(:End) RETURN properties(r) AS props ORDER BY r.myId", + r -> vectorEntityAssertions(r, false)); + } + + public static void vectorEntityAssertions(Result r, boolean isNew) { + ResourceIterator props = r.columnAs("props"); + Map next = props.next(); + assertEquals("Berlin", next.get("city")); + if (!isNew) { + assertEquals("one", next.get("myId")); + } + assertTrue(next.get("vect") instanceof float[]); + next = props.next(); + assertEquals("London", next.get("city")); + if (!isNew) { + assertEquals("two", next.get("myId")); + } + assertTrue(next.get("vect") instanceof float[]); + + assertFalse(props.hasNext()); + } +}