Issue 424: Add support for handling of vector fields. (#489)
* Issue 424: Add support for handling of vector fields.

Signed-off-by: Dmitry Goldenberg <[email protected]>

* Issue 424: Add support for handling of vector fields. Added a line to the CHANGELOG file.

Signed-off-by: Dmitry Goldenberg <[email protected]>

* Issue 424: Add support for handling of vector fields. Added the PR number to the CHANGELOG file.

Signed-off-by: Dmitry Goldenberg <[email protected]>

* Issue 424: Add support for handling of vector fields. Added the missing enum variant for KNN_VECTOR to the FieldType enum.

Signed-off-by: Dmitry Goldenberg <[email protected]>

* Issue 424: Add support for handling of vector fields.

Signed-off-by: Dmitry Goldenberg <[email protected]>

* Issue 424: Add support for handling of vector fields. Added two unit tests.

Signed-off-by: Dmitry Goldenberg <[email protected]>

* Issue 424: Add support for handling of vector fields. Added a couple of comments.

Signed-off-by: Dmitry Goldenberg <[email protected]>

* Issue 424: Add support for handling of vector fields. Added @Ignore to testKnnVectorAsArrayOfFloats for now, as the k-NN plugin is currently missing and needs to be added.

Signed-off-by: Dmitry Goldenberg <[email protected]>

---------

Signed-off-by: Dmitry Goldenberg <[email protected]>
dgoldenberg-ias committed Sep 12, 2024
1 parent 766ae02 commit 0621ef2
Showing 13 changed files with 117 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -20,4 +20,5 @@ metastore_db
 out/
 localRepo/
 .vscode/
-*.jar
+*.jar
+.DS_Store
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 ## [Unreleased]
 ### Added
 - Added basic support for HTTP compression when writing to OpenSearch ([#451](https://github.com/opensearch-project/opensearch-hadoop/pull/451))
+- Added support for k-nn vectors ([#424](https://github.com/opensearch-project/opensearch-hadoop/pull/489))
 
 ### Changed
 
@@ -72,6 +72,7 @@ public Object createArray(FieldType type) {
 case HALF_FLOAT:
 case SCALED_FLOAT:
 case FLOAT:
+case KNN_VECTOR:
     arrayType = FloatWritable.class;
     break;
 case DOUBLE:
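Note: with the createArray change above, a knn_vector field read through the Writable-based MapReduce integration materializes its elements as Hadoop FloatWritable values. A minimal, self-contained sketch of that representation (the ArrayWritable container and the 2-element vector are illustrative assumptions, not code from this commit):

    import org.apache.hadoop.io.{ArrayWritable, FloatWritable, Writable}

    // A 2-dimensional knn_vector as the Writable integration would now carry it:
    // createArray(KNN_VECTOR) resolves the array element type to FloatWritable.
    val vector = new ArrayWritable(
      classOf[FloatWritable],
      Array[Writable](new FloatWritable(-0.013f), new FloatWritable(0.009f)))

    vector.get().foreach(w => println(w.asInstanceOf[FloatWritable].get()))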
@@ -54,6 +54,7 @@ public enum FieldType {
 TOKEN_COUNT,
 TEXT, KEYWORD, HALF_FLOAT, SCALED_FLOAT,
 WILDCARD,
+KNN_VECTOR,
 
 GEO_POINT,
 GEO_SHAPE,
@@ -92,6 +93,7 @@ public enum FieldType {
 CAST_HIERARCHY.put(TEXT, new LinkedHashSet<FieldType>(Collections.singletonList(KEYWORD)));
 CAST_HIERARCHY.put(KEYWORD, new LinkedHashSet<FieldType>());
 CAST_HIERARCHY.put(WILDCARD, new LinkedHashSet<FieldType>(Collections.singletonList(KEYWORD)));
-CAST_HIERARCHY.put(HALF_FLOAT, new LinkedHashSet<FieldType>(Arrays.asList(FLOAT)));
+CAST_HIERARCHY.put(HALF_FLOAT, new LinkedHashSet<FieldType>(Arrays.asList(FLOAT, DOUBLE, KEYWORD)));
 CAST_HIERARCHY.put(SCALED_FLOAT, new LinkedHashSet<FieldType>(Arrays.asList(DOUBLE, KEYWORD)));
 CAST_HIERARCHY.put(GEO_POINT, new LinkedHashSet<FieldType>());
@@ -101,6 +103,7 @@ public enum FieldType {
 CAST_HIERARCHY.put(JOIN, new LinkedHashSet<FieldType>());
 CAST_HIERARCHY.put(IP, new LinkedHashSet<FieldType>(Collections.singletonList(KEYWORD)));
 CAST_HIERARCHY.put(COMPLETION, new LinkedHashSet<FieldType>());
+CAST_HIERARCHY.put(KNN_VECTOR, new LinkedHashSet<FieldType>(Arrays.asList(FLOAT)));
 }
 
 public static FieldType parse(String name) {
@@ -137,4 +140,4 @@ public LinkedHashSet<FieldType> getCastingTypes() {
 }
 return types;
 }
-}
+}
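Taken together, the FieldType changes make knn_vector parse like any other mapping type and declare FLOAT as its safe cast target. A minimal sketch of how callers observe this, using only the parse and getCastingTypes methods visible in the diff above (the demo object name is hypothetical):

    import org.opensearch.hadoop.serialization.FieldType

    object KnnVectorFieldTypeDemo extends App {
      // The lowercase mapping type name is expected to resolve to the new enum variant.
      val t = FieldType.parse("knn_vector")
      assert(t == FieldType.KNN_VECTOR)

      // Per the new CAST_HIERARCHY entry, knn_vector elements widen to FLOAT.
      println(t.getCastingTypes()) // expected: [FLOAT]
    }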
@@ -82,6 +82,7 @@ public Object readValue(Parser parser, String value, FieldType esType) {
 return longValue(value, parser);
 case HALF_FLOAT:
 case FLOAT:
+case KNN_VECTOR:
 return floatValue(value, parser);
 case SCALED_FLOAT:
 case DOUBLE:
@@ -57,6 +57,7 @@
 import static org.opensearch.hadoop.serialization.FieldType.INTEGER;
 import static org.opensearch.hadoop.serialization.FieldType.JOIN;
 import static org.opensearch.hadoop.serialization.FieldType.KEYWORD;
+import static org.opensearch.hadoop.serialization.FieldType.KNN_VECTOR;
 import static org.opensearch.hadoop.serialization.FieldType.LONG;
 import static org.opensearch.hadoop.serialization.FieldType.NESTED;
 import static org.opensearch.hadoop.serialization.FieldType.OBJECT;
@@ -147,7 +148,7 @@ public void testPrimitivesParsing() throws Exception {
 MappingSet mappings = getMappingsForResource("primitives.json");
 Mapping mapping = ensureAndGet("index", "primitives", mappings);
 Field[] props = mapping.getFields();
-assertEquals(16, props.length);
+assertEquals(17, props.length);
 assertEquals("field01", props[0].name());
 assertEquals(BOOLEAN, props[0].type());
 assertEquals("field02", props[1].name());
@@ -180,6 +181,8 @@ public void testPrimitivesParsing() throws Exception {
 assertEquals(DATE_NANOS, props[14].type());
 assertEquals("field16", props[15].name());
 assertEquals(WILDCARD, props[15].type());
+assertEquals("field17", props[16].name());
+assertEquals(KNN_VECTOR, props[16].type());
 }
 
 @Test
@@ -51,6 +51,9 @@
 },
 "field16" : {
 "type" : "wildcard"
+},
+"field17" : {
+"type" : "knn_vector"
 }
 }
 }
@@ -50,6 +50,9 @@
 },
 "field16" : {
 "type" : "wildcard"
+},
+"field17" : {
+"type" : "knn_vector"
 }
 }
 }
@@ -28,12 +28,14 @@
 */
 package org.opensearch.spark.serialization
 
+import java.io.IOException
 import java.util.Collections
 import java.util.Date
 import java.util.{List => JList}
 import scala.collection.JavaConverters.asScalaBufferConverter
 import scala.collection.Seq
 import scala.collection.mutable.LinkedHashMap
+import scala.collection.mutable.ListBuffer
 import scala.collection.mutable.Map
 import org.opensearch.hadoop.serialization.FieldType.BINARY
 import org.opensearch.hadoop.serialization.FieldType.BOOLEAN
@@ -56,6 +58,7 @@ import org.opensearch.hadoop.serialization.FieldType.STRING
 import org.opensearch.hadoop.serialization.FieldType.TEXT
 import org.opensearch.hadoop.serialization.FieldType.TOKEN_COUNT
 import org.opensearch.hadoop.serialization.FieldType.WILDCARD
+import org.opensearch.hadoop.serialization.FieldType.KNN_VECTOR
 import org.opensearch.hadoop.serialization.Parser.Token.VALUE_BOOLEAN
 import org.opensearch.hadoop.serialization.Parser.Token.VALUE_NULL
 import org.opensearch.hadoop.serialization.Parser.Token.VALUE_NUMBER
@@ -103,6 +106,7 @@ class ScalaValueReader extends AbstractValueReader with SettingsAware {
 case BINARY => binaryValue(Option(parser.binaryValue()).getOrElse(value.getBytes()))
 case DATE => date(value, parser)
 case DATE_NANOS => dateNanos(value, parser)
+case KNN_VECTOR => floatValue(value, parser)
 // GEO is ambiguous so use the JSON type instead to differentiate between doubles (a lot in GEO_SHAPE) and strings
 case GEO_POINT | GEO_SHAPE => {
 if (parser.currentToken() == VALUE_NUMBER) doubleValue(value, parser) else textValue(value, parser)
@@ -52,6 +52,7 @@ import org.apache.spark.sql.types.ArrayType
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.sql.types.DecimalType
 import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.FloatType
 import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.sql.types.MapType
 import org.apache.spark.sql.types.StringType
@@ -2450,6 +2451,50 @@ class AbstractScalaOpenSearchScalaSparkSQL(prefix: String, readMetadata: jl.Bool
     assertEquals("again", samples.get(0).asInstanceOf[Row].get(0))
   }
 
+  /**
+   * Tests the handling of k-nn vector fields.
+   */
+  @Test
+  @Ignore("k-NN plugin is currently missing")
+  def testKnnVectorAsArrayOfFloats(): Unit = {
+    val mapping = wrapMapping("data", s"""{
+      | "properties": {
+      | "name": {
+      | "type": "$keyword"
+      | },
+      | "vector": {
+      | "type": "knn_vector",
+      | "dimension": 2
+      | }
+      | }
+      | }
+    """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-knnvector-array-knnvector")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+
+    val arrayOfFloats = """{ "name": "Mini Munchies Pizza", "vector": [ -0.013f, 0.009f ]}""".stripMargin
+    sc.makeRDD(Seq(arrayOfFloats)).saveJsonToOpenSearch(target)
+
+    RestUtils.refresh(index)
+
+    val df = sqc.read.format("opensearch").load(index)
+
+    val dataType = df.schema("vector").dataType
+    assertEquals("array", dataType.typeName)
+    val array = dataType.asInstanceOf[ArrayType]
+    assertEquals(FloatType, array.elementType)
+
+    val head = df.head()
+    val vector = head.getSeq(1)
+    assertEquals(2, vector.length)
+    assertEquals(-0.013f, vector(0))
+    assertEquals(0.009f, vector(1))
+  }
+
 /**
  * Take advantage of the fixed method order and clear out all created indices.
  * The indices will last in OpenSearch for all parameters of this test suite.
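The new test is ignored because running it requires the OpenSearch k-NN plugin, which the test cluster does not yet bundle. For reference, on a k-NN-enabled cluster the index would typically be created with the plugin's per-index switch turned on alongside the mapping used above; a hedged sketch of such a request body, written in the test's own stripMargin style (index.knn is the k-NN plugin's documented setting; field names follow the test):

    // Sketch only: settings + mapping a k-NN-enabled cluster would accept.
    val createIndexBody = """{
      | "settings": { "index.knn": true },
      | "mappings": {
      |   "properties": {
      |     "name":   { "type": "keyword" },
      |     "vector": { "type": "knn_vector", "dimension": 2 }
      |   }
      | }
      |}""".stripMargin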
@@ -66,6 +66,7 @@ import org.opensearch.hadoop.serialization.FieldType.GEO_SHAPE
 import org.opensearch.hadoop.serialization.FieldType.INTEGER
 import org.opensearch.hadoop.serialization.FieldType.JOIN
 import org.opensearch.hadoop.serialization.FieldType.KEYWORD
+import org.opensearch.hadoop.serialization.FieldType.KNN_VECTOR
 import org.opensearch.hadoop.serialization.FieldType.LONG
 import org.opensearch.hadoop.serialization.FieldType.NESTED
 import org.opensearch.hadoop.serialization.FieldType.NULL
@@ -169,6 +170,7 @@ private[sql] object SchemaUtils {
 case WILDCARD => StringType
 case DATE => if (cfg.getMappingDateRich) TimestampType else StringType
 case DATE_NANOS => if (cfg.getMappingDateRich) TimestampType else StringType
+case KNN_VECTOR => DataTypes.createArrayType(FloatType)
 case OBJECT => convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg)
 case NESTED => DataTypes.createArrayType(convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg))
 case JOIN => convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg)
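On the Spark SQL side, the single case added above is what surfaces knn_vector fields in DataFrame schemas as array<float>. A minimal usage sketch (the index name is hypothetical; assumes a configured SparkSession in scope as spark; the "opensearch" format and the vector field name follow the test above):

    import org.apache.spark.sql.types.{ArrayType, FloatType}

    // With the SchemaUtils change, a knn_vector column arrives as array<float>.
    val df = spark.read.format("opensearch").load("my-knn-index")
    df.schema("vector").dataType match {
      case ArrayType(FloatType, _) => println("vector maps to array<float>")
      case other                   => println(s"unexpected type: $other")
    }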
@@ -52,6 +52,7 @@ import org.apache.spark.sql.types.ArrayType
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.sql.types.DecimalType
 import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.FloatType
 import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.sql.types.MapType
 import org.apache.spark.sql.types.StringType
@@ -2466,6 +2467,50 @@ class AbstractScalaOpenSearchScalaSparkSQL(prefix: String, readMetadata: jl.Bool
     df.count()
   }
 
+  /**
+   * Tests the handling of k-nn vector fields.
+   */
+  @Test
+  @Ignore("k-NN plugin is currently missing")
+  def testKnnVectorAsArrayOfFloats(): Unit = {
+    val mapping = wrapMapping("data", s"""{
+      | "properties": {
+      | "name": {
+      | "type": "$keyword"
+      | },
+      | "vector": {
+      | "type": "knn_vector",
+      | "dimension": 2
+      | }
+      | }
+      | }
+    """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-knnvector-array-knnvector")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+
+    val arrayOfFloats = """{ "name": "Mini Munchies Pizza", "vector": [ -0.013f, 0.009f ]}""".stripMargin
+    sc.makeRDD(Seq(arrayOfFloats)).saveJsonToOpenSearch(target)
+
+    RestUtils.refresh(index)
+
+    val df = sqc.read.format("opensearch").load(index)
+
+    val dataType = df.schema("vector").dataType
+    assertEquals("array", dataType.typeName)
+    val array = dataType.asInstanceOf[ArrayType]
+    assertEquals(FloatType, array.elementType)
+
+    val head = df.head()
+    val vector = head.getSeq(1)
+    assertEquals(2, vector.length)
+    assertEquals(-0.013f, vector(0))
+    assertEquals(0.009f, vector(1))
+  }
+
 /**
  * Take advantage of the fixed method order and clear out all created indices.
  * The indices will last in OpenSearch for all parameters of this test suite.
@@ -67,6 +67,7 @@ import org.opensearch.hadoop.serialization.FieldType.GEO_SHAPE
 import org.opensearch.hadoop.serialization.FieldType.INTEGER
 import org.opensearch.hadoop.serialization.FieldType.JOIN
 import org.opensearch.hadoop.serialization.FieldType.KEYWORD
+import org.opensearch.hadoop.serialization.FieldType.KNN_VECTOR
 import org.opensearch.hadoop.serialization.FieldType.LONG
 import org.opensearch.hadoop.serialization.FieldType.NESTED
 import org.opensearch.hadoop.serialization.FieldType.NULL
@@ -169,6 +170,7 @@ private[sql] object SchemaUtils {
 case WILDCARD => StringType
 case DATE => if (cfg.getMappingDateRich) TimestampType else StringType
 case DATE_NANOS => if (cfg.getMappingDateRich) TimestampType else StringType
+case KNN_VECTOR => DataTypes.createArrayType(FloatType)
 case OBJECT => convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg)
 case NESTED => DataTypes.createArrayType(convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg))
 case JOIN => convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg)
