Skip to content

Commit

Permalink
Issue 424: Add support for handling of vector fields. Added two unit …
Browse files Browse the repository at this point in the history
…tests.

Signed-off-by: Dmitry Goldenberg <[email protected]>
  • Loading branch information
dgoldenberg-ias committed Sep 11, 2024
1 parent 5d9dabc commit 9379374
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2451,6 +2451,46 @@ class AbstractScalaOpenSearchScalaSparkSQL(prefix: String, readMetadata: jl.Bool
assertEquals("again", samples.get(0).asInstanceOf[Row].get(0))
}

/**
 * Indexes a single document with a two-dimensional `knn_vector` field, reads
 * the index back through the "opensearch" Spark SQL data source and checks that
 * the field is reported as an array of floats carrying the indexed values.
 */
@Test
def testKnnVectorAsArrayOfFloats(): Unit = {
  val mapping = wrapMapping("data", s"""{
    |  "properties": {
    |    "name": {
    |      "type": "$keyword"
    |    },
    |    "vector": {
    |      "type": "knn_vector",
    |      "dimension": 2
    |    }
    |  }
    |}
    """.stripMargin)

  val index = wrapIndex("sparksql-test-knnvector-array-knnvector")
  val typed = "data"
  val (target, _) = makeTargets(index, typed)
  RestUtils.touch(index)
  RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))

  // JSON numbers carry no type suffix; "-0.013f" is not valid JSON, so the
  // values are written as plain decimals.
  val arrayOfFloats = """{ "name": "Mini Munchies Pizza", "vector": [ -0.013, 0.009 ]}"""
  sc.makeRDD(Seq(arrayOfFloats)).saveJsonToOpenSearch(target)

  RestUtils.refresh(index)

  val df = sqc.read.format("opensearch").load(index)

  val dataType = df.schema("vector").dataType
  assertEquals("array", dataType.typeName)
  // The cast is safe only because the type name was asserted just above.
  val array = dataType.asInstanceOf[ArrayType]
  assertEquals(FloatType, array.elementType)

  val head = df.head()
  // Look the column up by name instead of relying on field order, and give
  // getSeq an explicit element type so it is not inferred as Seq[Nothing].
  val vector = head.getSeq[Float](head.fieldIndex("vector"))
  assertEquals(2, vector.length)
  // Floating-point values are compared with an explicit tolerance.
  val delta = 1e-6f
  assertEquals(-0.013f, vector(0), delta)
  assertEquals(0.009f, vector(1), delta)
}

/**
* Tests the handling of k-nn vector fields.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2467,6 +2467,46 @@ class AbstractScalaOpenSearchScalaSparkSQL(prefix: String, readMetadata: jl.Bool
df.count()
}

/**
 * Indexes a single document with a two-dimensional `knn_vector` field, reads
 * the index back through the "opensearch" Spark SQL data source and checks that
 * the field is reported as an array of floats carrying the indexed values.
 */
@Test
def testKnnVectorAsArrayOfFloats(): Unit = {
  val mapping = wrapMapping("data", s"""{
    |  "properties": {
    |    "name": {
    |      "type": "$keyword"
    |    },
    |    "vector": {
    |      "type": "knn_vector",
    |      "dimension": 2
    |    }
    |  }
    |}
    """.stripMargin)

  val index = wrapIndex("sparksql-test-knnvector-array-knnvector")
  val typed = "data"
  val (target, _) = makeTargets(index, typed)
  RestUtils.touch(index)
  RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))

  // JSON numbers carry no type suffix; "-0.013f" is not valid JSON, so the
  // values are written as plain decimals.
  val arrayOfFloats = """{ "name": "Mini Munchies Pizza", "vector": [ -0.013, 0.009 ]}"""
  sc.makeRDD(Seq(arrayOfFloats)).saveJsonToOpenSearch(target)

  RestUtils.refresh(index)

  val df = sqc.read.format("opensearch").load(index)

  val dataType = df.schema("vector").dataType
  assertEquals("array", dataType.typeName)
  // The cast is safe only because the type name was asserted just above.
  val array = dataType.asInstanceOf[ArrayType]
  assertEquals(FloatType, array.elementType)

  val head = df.head()
  // Look the column up by name instead of relying on field order, and give
  // getSeq an explicit element type so it is not inferred as Seq[Nothing].
  val vector = head.getSeq[Float](head.fieldIndex("vector"))
  assertEquals(2, vector.length)
  // Floating-point values are compared with an explicit tolerance.
  val delta = 1e-6f
  assertEquals(-0.013f, vector(0), delta)
  assertEquals(0.009f, vector(1), delta)
}

/**
* Tests the handling of k-nn vector fields.
*/
Expand Down

0 comments on commit 9379374

Please sign in to comment.