Issue 424: Add support for handling of vector fields. #489

Merged
.gitignore (2 additions, 1 deletion)

@@ -20,4 +20,5 @@ metastore_db
 out/
 localRepo/
 .vscode/
-*.jar
+*.jar
+.DS_Store
CHANGELOG.md (1 addition, 0 deletions)

@@ -4,6 +4,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 ## [Unreleased]
 ### Added
 - Added basic support for HTTP compression when writing to OpenSearch ([#451](https://github.com/opensearch-project/opensearch-hadoop/pull/451))
+- Added support for k-NN vectors ([#424](https://github.com/opensearch-project/opensearch-hadoop/pull/489))
 
 ### Changed
 

@@ -72,6 +72,7 @@ public Object createArray(FieldType type) {
         case HALF_FLOAT:
         case SCALED_FLOAT:
         case FLOAT:
+        case KNN_VECTOR:
             arrayType = FloatWritable.class;
             break;
         case DOUBLE:
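For MapReduce users, the effect of the new case is that each element of a k-NN vector is materialized as a FloatWritable, exactly as HALF_FLOAT and FLOAT elements are. A minimal sketch of the resulting shape using the stock Hadoop Writable types (the two-element vector is illustrative, not from this PR):

import org.apache.hadoop.io.{ArrayWritable, FloatWritable, Writable}

// A knn_vector of dimension 2, in the shape the MR layer produces:
// an array Writable whose element class is FloatWritable.
val vector = new ArrayWritable(
  classOf[FloatWritable],
  Array[Writable](new FloatWritable(-0.013f), new FloatWritable(0.009f))
)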

@@ -54,6 +54,7 @@ public enum FieldType {
     TOKEN_COUNT,
     TEXT, KEYWORD, HALF_FLOAT, SCALED_FLOAT,
     WILDCARD,
+    KNN_VECTOR,
 
     GEO_POINT,
     GEO_SHAPE,
@@ -92,6 +93,7 @@ public enum FieldType {
         CAST_HIERARCHY.put(TEXT, new LinkedHashSet<FieldType>(Collections.singletonList(KEYWORD)));
         CAST_HIERARCHY.put(KEYWORD, new LinkedHashSet<FieldType>());
         CAST_HIERARCHY.put(WILDCARD, new LinkedHashSet<FieldType>(Collections.singletonList(KEYWORD)));
-        CAST_HIERARCHY.put(HALF_FLOAT, new LinkedHashSet<FieldType>(Arrays.asList(FLOAT)));
+        CAST_HIERARCHY.put(HALF_FLOAT, new LinkedHashSet<FieldType>(Arrays.asList(FLOAT, DOUBLE, KEYWORD)));
         CAST_HIERARCHY.put(SCALED_FLOAT, new LinkedHashSet<FieldType>(Arrays.asList(DOUBLE, KEYWORD)));
         CAST_HIERARCHY.put(GEO_POINT, new LinkedHashSet<FieldType>());
@@ -101,6 +103,7 @@ public enum FieldType {
         CAST_HIERARCHY.put(JOIN, new LinkedHashSet<FieldType>());
         CAST_HIERARCHY.put(IP, new LinkedHashSet<FieldType>(Collections.singletonList(KEYWORD)));
         CAST_HIERARCHY.put(COMPLETION, new LinkedHashSet<FieldType>());
+        CAST_HIERARCHY.put(KNN_VECTOR, new LinkedHashSet<FieldType>(Arrays.asList(FLOAT)));
     }
 
     public static FieldType parse(String name) {
@@ -137,4 +140,4 @@ public LinkedHashSet<FieldType> getCastingTypes() {
         }
         return types;
     }
-}
+}
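Taken together, these hunks give knn_vector its own FieldType constant and register FLOAT as its cast target. A quick sketch of how that surfaces through the FieldType API (assuming parse accepts the lowercase mapping name, as it does for the other types):

import org.opensearch.hadoop.serialization.FieldType

val vectorType = FieldType.parse("knn_vector")
assert(vectorType == FieldType.KNN_VECTOR)
// FLOAT is the direct cast target registered in CAST_HIERARCHY above.
assert(vectorType.getCastingTypes().contains(FieldType.FLOAT))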

@@ -82,6 +82,7 @@ public Object readValue(Parser parser, String value, FieldType esType) {
             return longValue(value, parser);
         case HALF_FLOAT:
         case FLOAT:
+        case KNN_VECTOR:
             return floatValue(value, parser);
         case SCALED_FLOAT:
         case DOUBLE:

@@ -57,6 +57,7 @@
 import static org.opensearch.hadoop.serialization.FieldType.INTEGER;
 import static org.opensearch.hadoop.serialization.FieldType.JOIN;
 import static org.opensearch.hadoop.serialization.FieldType.KEYWORD;
+import static org.opensearch.hadoop.serialization.FieldType.KNN_VECTOR;
 import static org.opensearch.hadoop.serialization.FieldType.LONG;
 import static org.opensearch.hadoop.serialization.FieldType.NESTED;
 import static org.opensearch.hadoop.serialization.FieldType.OBJECT;
@@ -147,7 +148,7 @@ public void testPrimitivesParsing() throws Exception {
         MappingSet mappings = getMappingsForResource("primitives.json");
         Mapping mapping = ensureAndGet("index", "primitives", mappings);
         Field[] props = mapping.getFields();
-        assertEquals(16, props.length);
+        assertEquals(17, props.length);
         assertEquals("field01", props[0].name());
         assertEquals(BOOLEAN, props[0].type());
         assertEquals("field02", props[1].name());
@@ -180,6 +181,8 @@ public void testPrimitivesParsing() throws Exception {
         assertEquals(DATE_NANOS, props[14].type());
         assertEquals("field16", props[15].name());
         assertEquals(WILDCARD, props[15].type());
+        assertEquals("field17", props[16].name());
+        assertEquals(KNN_VECTOR, props[16].type());
     }
 
     @Test

@@ -51,6 +51,9 @@
     },
     "field16" : {
       "type" : "wildcard"
+    },
+    "field17" : {
+      "type" : "knn_vector"
     }
   }
 }

@@ -50,6 +50,9 @@
     },
     "field16" : {
       "type" : "wildcard"
+    },
+    "field17" : {
+      "type" : "knn_vector"
     }
   }
 }

@@ -28,12 +28,14 @@
  */
 package org.opensearch.spark.serialization
 
+import java.io.IOException
 import java.util.Collections
 import java.util.Date
 import java.util.{List => JList}
 import scala.collection.JavaConverters.asScalaBufferConverter
 import scala.collection.Seq
 import scala.collection.mutable.LinkedHashMap
+import scala.collection.mutable.ListBuffer
 import scala.collection.mutable.Map
 import org.opensearch.hadoop.serialization.FieldType.BINARY
 import org.opensearch.hadoop.serialization.FieldType.BOOLEAN
@@ -56,6 +58,7 @@ import org.opensearch.hadoop.serialization.FieldType.STRING
 import org.opensearch.hadoop.serialization.FieldType.TEXT
 import org.opensearch.hadoop.serialization.FieldType.TOKEN_COUNT
 import org.opensearch.hadoop.serialization.FieldType.WILDCARD
+import org.opensearch.hadoop.serialization.FieldType.KNN_VECTOR
 import org.opensearch.hadoop.serialization.Parser.Token.VALUE_BOOLEAN
 import org.opensearch.hadoop.serialization.Parser.Token.VALUE_NULL
 import org.opensearch.hadoop.serialization.Parser.Token.VALUE_NUMBER
@@ -103,6 +106,7 @@ class ScalaValueReader extends AbstractValueReader with SettingsAware {
       case BINARY => binaryValue(Option(parser.binaryValue()).getOrElse(value.getBytes()))
       case DATE => date(value, parser)
       case DATE_NANOS => dateNanos(value, parser)
+      case KNN_VECTOR => floatValue(value, parser)
       // GEO is ambiguous so use the JSON type instead to differentiate between doubles (a lot in GEO_SHAPE) and strings
       case GEO_POINT | GEO_SHAPE => {
         if (parser.currentToken() == VALUE_NUMBER) doubleValue(value, parser) else textValue(value, parser)

@@ -52,6 +52,7 @@ import org.apache.spark.sql.types.ArrayType
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.sql.types.DecimalType
 import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.FloatType
 import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.sql.types.MapType
 import org.apache.spark.sql.types.StringType
@@ -2450,6 +2451,50 @@ class AbstractScalaOpenSearchScalaSparkSQL(prefix: String, readMetadata: jl.Bool
     assertEquals("again", samples.get(0).asInstanceOf[Row].get(0))
   }
 
+  /**
+   * Tests the handling of k-nn vector fields.
+   */
+  @Test
+  @Ignore("k-NN plugin is currently missing")
+  def testKnnVectorAsArrayOfFloats(): Unit = {
+    val mapping = wrapMapping("data", s"""{
+      |  "properties": {
+      |    "name": {
+      |      "type": "$keyword"
+      |    },
+      |    "vector": {
+      |      "type": "knn_vector",
+      |      "dimension": 2
+      |    }
+      |  }
+      |}
+    """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-knnvector-array-knnvector")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+    val arrayOfFloats = """{ "name": "Mini Munchies Pizza", "vector": [ -0.013, 0.009 ] }"""
+    sc.makeRDD(Seq(arrayOfFloats)).saveJsonToOpenSearch(target)
+
+    RestUtils.refresh(index)
+
+    val df = sqc.read.format("opensearch").load(index)
+
+    val dataType = df.schema("vector").dataType
+    assertEquals("array", dataType.typeName)
+    val array = dataType.asInstanceOf[ArrayType]
+    assertEquals(FloatType, array.elementType)
+
+    val head = df.head()
+    val vector = head.getSeq(1)
+    assertEquals(2, vector.length)
+    assertEquals(-0.013f, vector(0))
+    assertEquals(0.009f, vector(1))
+  }
+
   /**
    * Take advantage of the fixed method order and clear out all created indices.
    * The indices will last in OpenSearch for all parameters of this test suite.
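Outside the test harness, the same round trip is what a Spark user would write. A hypothetical session against a cluster with the k-NN plugin installed and the connector on the classpath (the index and field names are illustrative):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("knn-vector-read").getOrCreate()

// "my-knn-index" is assumed to map "vector" as a knn_vector of dimension 2.
val df = spark.read.format("opensearch").load("my-knn-index")

df.printSchema()
// root
//  |-- name: string (nullable = true)
//  |-- vector: array (nullable = true)
//  |    |-- element: float (containsNull = true)

val firstVector = df.head().getSeq[Float](df.schema.fieldIndex("vector"))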

@@ -66,6 +66,7 @@ import org.opensearch.hadoop.serialization.FieldType.GEO_SHAPE
 import org.opensearch.hadoop.serialization.FieldType.INTEGER
 import org.opensearch.hadoop.serialization.FieldType.JOIN
 import org.opensearch.hadoop.serialization.FieldType.KEYWORD
+import org.opensearch.hadoop.serialization.FieldType.KNN_VECTOR
 import org.opensearch.hadoop.serialization.FieldType.LONG
 import org.opensearch.hadoop.serialization.FieldType.NESTED
 import org.opensearch.hadoop.serialization.FieldType.NULL
@@ -169,6 +170,7 @@ private[sql] object SchemaUtils {
       case WILDCARD => StringType
       case DATE => if (cfg.getMappingDateRich) TimestampType else StringType
       case DATE_NANOS => if (cfg.getMappingDateRich) TimestampType else StringType
+      case KNN_VECTOR => DataTypes.createArrayType(FloatType)
       case OBJECT => convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg)
       case NESTED => DataTypes.createArrayType(convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg))
       case JOIN => convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg)
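The Catalyst type chosen here is a plain float array, so knn_vector columns compose with the rest of Spark SQL (explode, UDFs, and so on). A minimal illustration of the equivalence, using only public Spark APIs:

import org.apache.spark.sql.types.{ArrayType, DataTypes, FloatType}

// DataTypes.createArrayType(FloatType) is the Java-friendly spelling of
// ArrayType(FloatType, containsNull = true); both describe array<float>.
assert(DataTypes.createArrayType(FloatType) == ArrayType(FloatType, containsNull = true))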

@@ -52,6 +52,7 @@ import org.apache.spark.sql.types.ArrayType
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.sql.types.DecimalType
 import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.FloatType
 import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.sql.types.MapType
 import org.apache.spark.sql.types.StringType
@@ -2466,6 +2467,50 @@ class AbstractScalaOpenSearchScalaSparkSQL(prefix: String, readMetadata: jl.Bool
     df.count()
   }
 
+  /**
+   * Tests the handling of k-nn vector fields.
+   */
+  @Test
+  @Ignore("k-NN plugin is currently missing")
+  def testKnnVectorAsArrayOfFloats(): Unit = {
+    val mapping = wrapMapping("data", s"""{
+      |  "properties": {
+      |    "name": {
+      |      "type": "$keyword"
+      |    },
+      |    "vector": {
+      |      "type": "knn_vector",
+      |      "dimension": 2
+      |    }
+      |  }
+      |}
+    """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-knnvector-array-knnvector")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+    val arrayOfFloats = """{ "name": "Mini Munchies Pizza", "vector": [ -0.013, 0.009 ] }"""
+    sc.makeRDD(Seq(arrayOfFloats)).saveJsonToOpenSearch(target)
+
+    RestUtils.refresh(index)
+
+    val df = sqc.read.format("opensearch").load(index)
+
+    val dataType = df.schema("vector").dataType
+    assertEquals("array", dataType.typeName)
+    val array = dataType.asInstanceOf[ArrayType]
+    assertEquals(FloatType, array.elementType)
+
+    val head = df.head()
+    val vector = head.getSeq(1)
+    assertEquals(2, vector.length)
+    assertEquals(-0.013f, vector(0))
+    assertEquals(0.009f, vector(1))
+  }
+
   /**
    * Take advantage of the fixed method order and clear out all created indices.
    * The indices will last in OpenSearch for all parameters of this test suite.

@@ -67,6 +67,7 @@ import org.opensearch.hadoop.serialization.FieldType.GEO_SHAPE
 import org.opensearch.hadoop.serialization.FieldType.INTEGER
 import org.opensearch.hadoop.serialization.FieldType.JOIN
 import org.opensearch.hadoop.serialization.FieldType.KEYWORD
+import org.opensearch.hadoop.serialization.FieldType.KNN_VECTOR
 import org.opensearch.hadoop.serialization.FieldType.LONG
 import org.opensearch.hadoop.serialization.FieldType.NESTED
 import org.opensearch.hadoop.serialization.FieldType.NULL
@@ -169,6 +170,7 @@ private[sql] object SchemaUtils {
       case WILDCARD => StringType
       case DATE => if (cfg.getMappingDateRich) TimestampType else StringType
       case DATE_NANOS => if (cfg.getMappingDateRich) TimestampType else StringType
+      case KNN_VECTOR => DataTypes.createArrayType(FloatType)
       case OBJECT => convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg)
       case NESTED => DataTypes.createArrayType(convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg))
       case JOIN => convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg)