diff --git a/.gitignore b/.gitignore
index 4d8e20aa6..187a76d96 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,4 +20,5 @@ metastore_db
 out/
 localRepo/
 .vscode/
-*.jar
\ No newline at end of file
+*.jar
+.DS_Store
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 89a52ce31..5639991a9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 
 ## [Unreleased]
 ### Added
 - Added basic support for HTTP compression when writing to OpenSearch ([#451](https://github.com/opensearch-project/opensearch-hadoop/pull/451))
+- Added support for k-NN vectors ([#424](https://github.com/opensearch-project/opensearch-hadoop/pull/489))
 
 ### Changed
diff --git a/mr/src/main/java/org/opensearch/hadoop/mr/WritableValueReader.java b/mr/src/main/java/org/opensearch/hadoop/mr/WritableValueReader.java
index 77d1ccf64..93933e7ce 100644
--- a/mr/src/main/java/org/opensearch/hadoop/mr/WritableValueReader.java
+++ b/mr/src/main/java/org/opensearch/hadoop/mr/WritableValueReader.java
@@ -72,6 +72,7 @@ public Object createArray(FieldType type) {
         case HALF_FLOAT:
         case SCALED_FLOAT:
         case FLOAT:
+        case KNN_VECTOR:
            arrayType = FloatWritable.class;
            break;
        case DOUBLE:
diff --git a/mr/src/main/java/org/opensearch/hadoop/serialization/FieldType.java b/mr/src/main/java/org/opensearch/hadoop/serialization/FieldType.java
index 10f0e658e..913a99d67 100644
--- a/mr/src/main/java/org/opensearch/hadoop/serialization/FieldType.java
+++ b/mr/src/main/java/org/opensearch/hadoop/serialization/FieldType.java
@@ -54,6 +54,7 @@ public enum FieldType {
     TOKEN_COUNT, TEXT, KEYWORD,
     HALF_FLOAT, SCALED_FLOAT,
     WILDCARD,
+    KNN_VECTOR,
 
     GEO_POINT,
     GEO_SHAPE,
@@ -101,6 +102,7 @@ public enum FieldType {
         CAST_HIERARCHY.put(JOIN, new LinkedHashSet());
         CAST_HIERARCHY.put(IP, new LinkedHashSet(Collections.singletonList(KEYWORD)));
         CAST_HIERARCHY.put(COMPLETION, new LinkedHashSet());
+        CAST_HIERARCHY.put(KNN_VECTOR, new LinkedHashSet(Arrays.asList(FLOAT)));
     }
 
     public static FieldType parse(String name) {
@@ -137,4 +139,4 @@ public LinkedHashSet getCastingTypes() {
         }
         return types;
     }
-}
\ No newline at end of file
+}
diff --git a/mr/src/main/java/org/opensearch/hadoop/serialization/builder/JdkValueReader.java b/mr/src/main/java/org/opensearch/hadoop/serialization/builder/JdkValueReader.java
index ecb3f3008..ea4eeb5c9 100644
--- a/mr/src/main/java/org/opensearch/hadoop/serialization/builder/JdkValueReader.java
+++ b/mr/src/main/java/org/opensearch/hadoop/serialization/builder/JdkValueReader.java
@@ -82,6 +82,7 @@ public Object readValue(Parser parser, String value, FieldType esType) {
             return longValue(value, parser);
         case HALF_FLOAT:
         case FLOAT:
+        case KNN_VECTOR:
             return floatValue(value, parser);
         case SCALED_FLOAT:
         case DOUBLE:
diff --git a/mr/src/test/java/org/opensearch/hadoop/serialization/dto/mapping/MappingTest.java b/mr/src/test/java/org/opensearch/hadoop/serialization/dto/mapping/MappingTest.java
index 58be85449..edada4090 100644
--- a/mr/src/test/java/org/opensearch/hadoop/serialization/dto/mapping/MappingTest.java
+++ b/mr/src/test/java/org/opensearch/hadoop/serialization/dto/mapping/MappingTest.java
@@ -57,6 +57,7 @@
 import static org.opensearch.hadoop.serialization.FieldType.INTEGER;
 import static org.opensearch.hadoop.serialization.FieldType.JOIN;
 import static org.opensearch.hadoop.serialization.FieldType.KEYWORD;
+import static org.opensearch.hadoop.serialization.FieldType.KNN_VECTOR;
 import static org.opensearch.hadoop.serialization.FieldType.LONG;
 import static org.opensearch.hadoop.serialization.FieldType.NESTED;
 import static org.opensearch.hadoop.serialization.FieldType.OBJECT;
@@ -147,7 +148,7 @@ public void testPrimitivesParsing() throws Exception {
         MappingSet mappings = getMappingsForResource("primitives.json");
         Mapping mapping = ensureAndGet("index", "primitives", mappings);
         Field[] props = mapping.getFields();
-        assertEquals(16, props.length);
+        assertEquals(17, props.length);
         assertEquals("field01", props[0].name());
         assertEquals(BOOLEAN, props[0].type());
         assertEquals("field02", props[1].name());
@@ -180,6 +181,8 @@ public void testPrimitivesParsing() throws Exception {
         assertEquals(DATE_NANOS, props[14].type());
         assertEquals("field16", props[15].name());
         assertEquals(WILDCARD, props[15].type());
+        assertEquals("field17", props[16].name());
+        assertEquals(KNN_VECTOR, props[16].type());
     }
 
     @Test
diff --git a/mr/src/test/resources/org/opensearch/hadoop/serialization/dto/mapping/typed/primitives.json b/mr/src/test/resources/org/opensearch/hadoop/serialization/dto/mapping/typed/primitives.json
index 26879c939..1cbc8d0c4 100644
--- a/mr/src/test/resources/org/opensearch/hadoop/serialization/dto/mapping/typed/primitives.json
+++ b/mr/src/test/resources/org/opensearch/hadoop/serialization/dto/mapping/typed/primitives.json
@@ -51,6 +51,9 @@
         },
         "field16" : {
           "type" : "wildcard"
+        },
+        "field17" : {
+          "type" : "knn_vector"
         }
       }
     }
diff --git a/mr/src/test/resources/org/opensearch/hadoop/serialization/dto/mapping/typeless/primitives.json b/mr/src/test/resources/org/opensearch/hadoop/serialization/dto/mapping/typeless/primitives.json
index 1c2f5358c..dbe73dbdf 100644
--- a/mr/src/test/resources/org/opensearch/hadoop/serialization/dto/mapping/typeless/primitives.json
+++ b/mr/src/test/resources/org/opensearch/hadoop/serialization/dto/mapping/typeless/primitives.json
@@ -50,6 +50,9 @@
         },
         "field16" : {
           "type" : "wildcard"
+        },
+        "field17" : {
+          "type" : "knn_vector"
         }
       }
     }
diff --git a/spark/core/src/main/scala/org/opensearch/spark/serialization/ScalaValueReader.scala b/spark/core/src/main/scala/org/opensearch/spark/serialization/ScalaValueReader.scala
index 3ffee8ce3..7d8dad5f7 100644
--- a/spark/core/src/main/scala/org/opensearch/spark/serialization/ScalaValueReader.scala
+++ b/spark/core/src/main/scala/org/opensearch/spark/serialization/ScalaValueReader.scala
@@ -28,12 +28,14 @@
  */
 package org.opensearch.spark.serialization
 
+import java.io.IOException
 import java.util.Collections
 import java.util.Date
 import java.util.{List => JList}
 import scala.collection.JavaConverters.asScalaBufferConverter
 import scala.collection.Seq
 import scala.collection.mutable.LinkedHashMap
+import scala.collection.mutable.ListBuffer
 import scala.collection.mutable.Map
 import org.opensearch.hadoop.serialization.FieldType.BINARY
 import org.opensearch.hadoop.serialization.FieldType.BOOLEAN
@@ -56,6 +58,7 @@ import org.opensearch.hadoop.serialization.FieldType.STRING
 import org.opensearch.hadoop.serialization.FieldType.TEXT
 import org.opensearch.hadoop.serialization.FieldType.TOKEN_COUNT
 import org.opensearch.hadoop.serialization.FieldType.WILDCARD
+import org.opensearch.hadoop.serialization.FieldType.KNN_VECTOR
 import org.opensearch.hadoop.serialization.Parser.Token.VALUE_BOOLEAN
 import org.opensearch.hadoop.serialization.Parser.Token.VALUE_NULL
 import org.opensearch.hadoop.serialization.Parser.Token.VALUE_NUMBER
@@ -103,6 +106,7 @@ class ScalaValueReader extends AbstractValueReader with SettingsAware {
       case BINARY => binaryValue(Option(parser.binaryValue()).getOrElse(value.getBytes()))
       case DATE => date(value, parser)
       case DATE_NANOS => dateNanos(value, parser)
+      case KNN_VECTOR => floatValue(value, parser)
       // GEO is ambiguous so use the JSON type instead to differentiate between doubles (a lot in GEO_SHAPE) and strings
       case GEO_POINT | GEO_SHAPE => {
         if (parser.currentToken() == VALUE_NUMBER) doubleValue(value, parser) else textValue(value, parser)
diff --git a/spark/sql-20/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala b/spark/sql-20/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala
index 8298acbd1..33eeb7955 100644
--- a/spark/sql-20/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala
+++ b/spark/sql-20/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala
@@ -52,6 +52,7 @@ import org.apache.spark.sql.types.ArrayType
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.sql.types.DecimalType
 import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.FloatType
 import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.sql.types.MapType
 import org.apache.spark.sql.types.StringType
@@ -2450,6 +2451,50 @@ class AbstractScalaOpenSearchScalaSparkSQL(prefix: String, readMetadata: jl.Bool
     assertEquals("again", samples.get(0).asInstanceOf[Row].get(0))
   }
 
+  /**
+   * Tests the handling of k-NN vector fields.
+   */
+  @Test
+  @Ignore("k-NN plugin is currently missing")
+  def testKnnVectorAsArrayOfFloats(): Unit = {
+    val mapping = wrapMapping("data", s"""{
+      |    "properties": {
+      |      "name": {
+      |        "type": "$keyword"
+      |      },
+      |      "vector": {
+      |        "type": "knn_vector",
+      |        "dimension": 2
+      |      }
+      |    }
+      |  }
+    """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-knnvector-array-knnvector")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+
+    val arrayOfFloats = """{ "name": "Mini Munchies Pizza", "vector": [ -0.013, 0.009 ] }"""
+    sc.makeRDD(Seq(arrayOfFloats)).saveJsonToOpenSearch(target)
+
+    RestUtils.refresh(index)
+
+    val df = sqc.read.format("opensearch").load(index)
+
+    val dataType = df.schema("vector").dataType
+    assertEquals("array", dataType.typeName)
+    val array = dataType.asInstanceOf[ArrayType]
+    assertEquals(FloatType, array.elementType)
+
+    val head = df.head()
+    val vector = head.getSeq(1)
+    assertEquals(2, vector.length)
+    assertEquals(-0.013f, vector(0))
+    assertEquals(0.009f, vector(1))
+  }
+
   /**
    * Take advantage of the fixed method order and clear out all created indices.
    * The indices will last in OpenSearch for all parameters of this test suite.
diff --git a/spark/sql-20/src/main/scala/org/opensearch/spark/sql/SchemaUtils.scala b/spark/sql-20/src/main/scala/org/opensearch/spark/sql/SchemaUtils.scala
index 232aaf710..62f1b3842 100644
--- a/spark/sql-20/src/main/scala/org/opensearch/spark/sql/SchemaUtils.scala
+++ b/spark/sql-20/src/main/scala/org/opensearch/spark/sql/SchemaUtils.scala
@@ -66,6 +66,7 @@ import org.opensearch.hadoop.serialization.FieldType.GEO_SHAPE
 import org.opensearch.hadoop.serialization.FieldType.INTEGER
 import org.opensearch.hadoop.serialization.FieldType.JOIN
 import org.opensearch.hadoop.serialization.FieldType.KEYWORD
+import org.opensearch.hadoop.serialization.FieldType.KNN_VECTOR
 import org.opensearch.hadoop.serialization.FieldType.LONG
 import org.opensearch.hadoop.serialization.FieldType.NESTED
 import org.opensearch.hadoop.serialization.FieldType.NULL
@@ -169,6 +170,7 @@ private[sql] object SchemaUtils {
       case WILDCARD => StringType
       case DATE => if (cfg.getMappingDateRich) TimestampType else StringType
       case DATE_NANOS => if (cfg.getMappingDateRich) TimestampType else StringType
+      case KNN_VECTOR => DataTypes.createArrayType(FloatType)
       case OBJECT => convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg)
       case NESTED => DataTypes.createArrayType(convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg))
       case JOIN => convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg)
diff --git a/spark/sql-30/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala b/spark/sql-30/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala
index 3df08754c..47c0090a6 100644
--- a/spark/sql-30/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala
+++ b/spark/sql-30/src/itest/scala/org/opensearch/spark/integration/AbstractScalaOpenSearchSparkSQL.scala
@@ -52,6 +52,7 @@ import org.apache.spark.sql.types.ArrayType
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.sql.types.DecimalType
 import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.FloatType
 import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.sql.types.MapType
 import org.apache.spark.sql.types.StringType
@@ -2466,6 +2467,50 @@ class AbstractScalaOpenSearchScalaSparkSQL(prefix: String, readMetadata: jl.Bool
     df.count()
   }
 
+  /**
+   * Tests the handling of k-NN vector fields.
+   */
+  @Test
+  @Ignore("k-NN plugin is currently missing")
+  def testKnnVectorAsArrayOfFloats(): Unit = {
+    val mapping = wrapMapping("data", s"""{
+      |    "properties": {
+      |      "name": {
+      |        "type": "$keyword"
+      |      },
+      |      "vector": {
+      |        "type": "knn_vector",
+      |        "dimension": 2
+      |      }
+      |    }
+      |  }
+    """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-knnvector-array-knnvector")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+
+    val arrayOfFloats = """{ "name": "Mini Munchies Pizza", "vector": [ -0.013, 0.009 ] }"""
+    sc.makeRDD(Seq(arrayOfFloats)).saveJsonToOpenSearch(target)
+
+    RestUtils.refresh(index)
+
+    val df = sqc.read.format("opensearch").load(index)
+
+    val dataType = df.schema("vector").dataType
+    assertEquals("array", dataType.typeName)
+    val array = dataType.asInstanceOf[ArrayType]
+    assertEquals(FloatType, array.elementType)
+
+    val head = df.head()
+    val vector = head.getSeq(1)
+    assertEquals(2, vector.length)
+    assertEquals(-0.013f, vector(0))
+    assertEquals(0.009f, vector(1))
+  }
+
   /**
    * Take advantage of the fixed method order and clear out all created indices.
    * The indices will last in OpenSearch for all parameters of this test suite.
diff --git a/spark/sql-30/src/main/scala/org/opensearch/spark/sql/SchemaUtils.scala b/spark/sql-30/src/main/scala/org/opensearch/spark/sql/SchemaUtils.scala
index d05b0fef5..e63ca7b43 100644
--- a/spark/sql-30/src/main/scala/org/opensearch/spark/sql/SchemaUtils.scala
+++ b/spark/sql-30/src/main/scala/org/opensearch/spark/sql/SchemaUtils.scala
@@ -67,6 +67,7 @@ import org.opensearch.hadoop.serialization.FieldType.GEO_SHAPE
 import org.opensearch.hadoop.serialization.FieldType.INTEGER
 import org.opensearch.hadoop.serialization.FieldType.JOIN
 import org.opensearch.hadoop.serialization.FieldType.KEYWORD
+import org.opensearch.hadoop.serialization.FieldType.KNN_VECTOR
 import org.opensearch.hadoop.serialization.FieldType.LONG
 import org.opensearch.hadoop.serialization.FieldType.NESTED
 import org.opensearch.hadoop.serialization.FieldType.NULL
@@ -169,6 +170,7 @@ private[sql] object SchemaUtils {
       case WILDCARD => StringType
       case DATE => if (cfg.getMappingDateRich) TimestampType else StringType
       case DATE_NANOS => if (cfg.getMappingDateRich) TimestampType else StringType
+      case KNN_VECTOR => DataTypes.createArrayType(FloatType)
       case OBJECT => convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg)
       case NESTED => DataTypes.createArrayType(convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg))
       case JOIN => convertToStruct(field, geoInfo, absoluteName, arrayIncludes, arrayExcludes, cfg)