From 9379374a25abb1fdb9c053e221bbaa051d2c6096 Mon Sep 17 00:00:00 2001
From: Dmitry Goldenberg
Date: Sat, 20 Jul 2024 16:02:04 -0400
Subject: [PATCH] Issue 424: Add support for handling of vector fields. Added
 two unit tests.

Signed-off-by: Dmitry Goldenberg
---
 .../AbstractScalaOpenSearchSparkSQL.scala       | 40 +++++++++++++++++++
 .../AbstractScalaOpenSearchSparkSQL.scala       | 40 +++++++++++++++++++
 2 files changed, 80 insertions(+)

diff --git a/spark/sql-20/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala b/spark/sql-20/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala
index 33eeb7955..b5e4db7f9 100644
--- a/spark/sql-20/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala
+++ b/spark/sql-20/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala
@@ -2451,6 +2451,46 @@ class AbstractScalaOpenSearchScalaSparkSQL(prefix: String, readMetadata: jl.Bool
     assertEquals("again", samples.get(0).asInstanceOf[Row].get(0))
   }
 
+  @Test
+  def testKnnVectorAsArrayOfFloats(): Unit = {
+    val mapping = wrapMapping("data", s"""{
+      |  "properties": {
+      |    "name": {
+      |      "type": "$keyword"
+      |    },
+      |    "vector": {
+      |      "type": "knn_vector",
+      |      "dimension": 2
+      |    }
+      |  }
+      |}
+    """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-knnvector-array-knnvector")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+
+    val arrayOfFloats = """{ "name": "Mini Munchies Pizza", "vector": [ -0.013f, 0.009f ]}""".stripMargin
+    sc.makeRDD(Seq(arrayOfFloats)).saveJsonToOpenSearch(target)
+
+    RestUtils.refresh(index)
+
+    val df = sqc.read.format("opensearch").load(index)
+
+    val dataType = df.schema("vector").dataType
+    assertEquals("array", dataType.typeName)
+    val array = dataType.asInstanceOf[ArrayType]
+    assertEquals(FloatType, array.elementType)
+
+    val head = df.head()
+    val vector = head.getSeq(1)
+    assertEquals(2, vector.length)
+    assertEquals(-0.013f, vector(0))
+    assertEquals(0.009f, vector(1))
+  }
+
   /**
    * Tests the handling of k-nn vector fields.
    */
diff --git a/spark/sql-30/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala b/spark/sql-30/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala
index 47c0090a6..189ffeea2 100644
--- a/spark/sql-30/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala
+++ b/spark/sql-30/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala
@@ -2467,6 +2467,46 @@ class AbstractScalaOpenSearchScalaSparkSQL(prefix: String, readMetadata: jl.Bool
     df.count()
   }
 
+  @Test
+  def testKnnVectorAsArrayOfFloats(): Unit = {
+    val mapping = wrapMapping("data", s"""{
+      |  "properties": {
+      |    "name": {
+      |      "type": "$keyword"
+      |    },
+      |    "vector": {
+      |      "type": "knn_vector",
+      |      "dimension": 2
+      |    }
+      |  }
+      |}
+    """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-knnvector-array-knnvector")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+
+    val arrayOfFloats = """{ "name": "Mini Munchies Pizza", "vector": [ -0.013f, 0.009f ]}""".stripMargin
+    sc.makeRDD(Seq(arrayOfFloats)).saveJsonToOpenSearch(target)
+
+    RestUtils.refresh(index)
+
+    val df = sqc.read.format("opensearch").load(index)
+
+    val dataType = df.schema("vector").dataType
+    assertEquals("array", dataType.typeName)
+    val array = dataType.asInstanceOf[ArrayType]
+    assertEquals(FloatType, array.elementType)
+
+    val head = df.head()
+    val vector = head.getSeq(1)
+    assertEquals(2, vector.length)
+    assertEquals(-0.013f, vector(0))
+    assertEquals(0.009f, vector(1))
+  }
+
   /**
    * Tests the handling of k-nn vector fields.
    */