From bd5788287957d8610a6d19c273b75bd4cdd2d166 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 5 May 2017 11:08:26 -0700 Subject: [PATCH 01/30] [SPARK-20603][SS][TEST] Set default number of topic partitions to 1 to reduce the load ## What changes were proposed in this pull request? I checked the logs of https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.2-test-maven-hadoop-2.7/47/ and found it took several seconds to create Kafka internal topic `__consumer_offsets`. As Kafka creates this topic lazily, the topic creation happens in the first test `deserialization of initial offset with Spark 2.1.0` and causes it timeout. This PR changes `offsets.topic.num.partitions` from the default value 50 to 1 to make creating `__consumer_offsets` (50 partitions -> 1 partition) much faster. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17863 from zsxwing/fix-kafka-flaky-test. --- .../scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 2ce2760b7f463..f86b8f586d2a0 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -292,6 +292,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") props.put("delete.topic.enable", "true") + props.put("offsets.topic.num.partitions", "1") props.putAll(withBrokerProps.asJava) props } From b31648c081e8db34e0d6c71875318f7b0b047c8b Mon Sep 17 00:00:00 2001 From: Jannik Arndt Date: Fri, 5 May 2017 11:42:55 -0700 Subject: [PATCH 02/30] [SPARK-20557][SQL] Support for db column type TIMESTAMP WITH TIME ZONE ## What changes were proposed in this pull request? SparkSQL can now read from a database table with column type [TIMESTAMP WITH TIME ZONE](https://docs.oracle.com/javase/8/docs/api/java/sql/Types.html#TIMESTAMP_WITH_TIMEZONE). ## How was this patch tested? Tested against Oracle database. JoshRosen, you seem to know the class, would you look at this? Thanks! Author: Jannik Arndt Closes #17832 from JannikArndt/spark-20557-timestamp-with-timezone. --- .../spark/sql/jdbc/OracleIntegrationSuite.scala | 13 +++++++++++++ .../sql/execution/datasources/jdbc/JdbcUtils.scala | 3 +++ 2 files changed, 16 insertions(+) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 1bb89a361ca75..85d4a4a791e6b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -70,6 +70,12 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo """.stripMargin.replaceAll("\n", " ")).executeUpdate() conn.commit() + conn.prepareStatement("CREATE TABLE ts_with_timezone (id NUMBER(10), t TIMESTAMP WITH TIME ZONE)") + .executeUpdate() + conn.prepareStatement("INSERT INTO ts_with_timezone VALUES (1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))") + .executeUpdate() + conn.commit() + sql( s""" |CREATE TEMPORARY VIEW datetime @@ -185,4 +191,11 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo sql("INSERT INTO TABLE datetime1 SELECT * FROM datetime where id = 1") checkRow(sql("SELECT * FROM datetime1 where id = 1").head()) } + + test("SPARK-20557: column type TIMEZONE with TIME STAMP should be recognized") { + val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) + val rows = dfRead.collect() + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types(1).equals("class java.sql.Timestamp")) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 0183805d56257..fb877d1ca7639 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -223,6 +223,9 @@ object JdbcUtils extends Logging { case java.sql.Types.STRUCT => StringType case java.sql.Types.TIME => TimestampType case java.sql.Types.TIMESTAMP => TimestampType + case java.sql.Types.TIMESTAMP_WITH_TIMEZONE + => TimestampType + case -101 => TimestampType // Value for Timestamp with Time Zone in Oracle case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType From 5d75b14bf0f4c1f0813287efaabf49797908ed55 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 5 May 2017 15:31:06 -0700 Subject: [PATCH 03/30] [SPARK-20616] RuleExecutor logDebug of batch results should show diff to start of batch ## What changes were proposed in this pull request? Due to a likely typo, the logDebug msg printing the diff of query plans shows a diff to the initial plan, not diff to the start of batch. ## How was this patch tested? Now the debug message prints the diff between start and end of batch. Author: Juliusz Sompolski Closes #17875 from juliuszsompolski/SPARK-20616. --- .../org/apache/spark/sql/catalyst/rules/RuleExecutor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 6fc828f63f152..85b368c862630 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -122,7 +122,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { logDebug( s""" |=== Result of Batch ${batch.name} === - |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")} + |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")} """.stripMargin) } else { logTrace(s"Batch ${batch.name} has no effect.") From b433acae74887e59f2e237a6284a4ae04fbbe854 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 5 May 2017 21:26:55 -0700 Subject: [PATCH 04/30] [SPARK-20614][PROJECT INFRA] Use the same log4j configuration with Jenkins in AppVeyor ## What changes were proposed in this pull request? Currently, there are flooding logs in AppVeyor (in the console). This has been fine because we can download all the logs. However, (given my observations so far), logs are truncated when there are too many. It has been grown recently and it started to get truncated. For example, see https://ci.appveyor.com/project/ApacheSoftwareFoundation/spark/build/1209-master Even after the log is downloaded, it looks truncated as below: ``` [00:44:21] 17/05/04 18:56:18 INFO TaskSetManager: Finished task 197.0 in stage 601.0 (TID 9211) in 0 ms on localhost (executor driver) (194/200) [00:44:21] 17/05/04 18:56:18 INFO Executor: Running task 199.0 in stage 601.0 (TID 9213) [00:44:21] 17/05/04 18:56:18 INFO Executor: Finished task 198.0 in stage 601.0 (TID 9212). 2473 bytes result sent to driver ... ``` Probably, it looks better to use the same log4j configuration that we are using for SparkR tests in Jenkins(please see https://github.com/apache/spark/blob/fc472bddd1d9c6a28e57e31496c0166777af597e/R/run-tests.sh#L26 and https://github.com/apache/spark/blob/fc472bddd1d9c6a28e57e31496c0166777af597e/R/log4j.properties) ``` # Set everything to be logged to the file target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true log4j.appender.file.file=R/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN org.eclipse.jetty.LEVEL=WARN ``` ## How was this patch tested? Manually tested with spark-test account - https://ci.appveyor.com/project/spark-test/spark/build/672-r-log4j (there is an example for flaky test here) - https://ci.appveyor.com/project/spark-test/spark/build/673-r-log4j (I re-ran the build). Author: hyukjinkwon Closes #17873 from HyukjinKwon/appveyor-reduce-logs. --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index bbb27589cad09..4d31af70f056e 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -49,7 +49,7 @@ build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package test_script: - - cmd: .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R + - cmd: .\bin\spark-submit2.cmd --driver-java-options "-Dlog4j.configuration=file:///%CD:\=/%/R/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R notifications: - provider: Email From cafca54c0ea8bd9c3b80dcbc88d9f2b8d708a026 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sat, 6 May 2017 22:21:19 -0700 Subject: [PATCH 05/30] [SPARK-20557][SQL] Support JDBC data type Time with Time Zone ### What changes were proposed in this pull request? This PR is to support JDBC data type TIME WITH TIME ZONE. It can be converted to TIMESTAMP In addition, before this PR, for unsupported data types, we simply output the type number instead of the type name. ``` java.sql.SQLException: Unsupported type 2014 ``` After this PR, the message is like ``` java.sql.SQLException: Unsupported type TIMESTAMP_WITH_TIMEZONE ``` - Also upgrade the H2 version to `1.4.195` which has the type fix for "TIMESTAMP WITH TIMEZONE". However, it is not fully supported. Thus, we capture the exception, but we still need it to partially test the support of "TIMESTAMP WITH TIMEZONE", because Docker tests are not regularly run. ### How was this patch tested? Added test cases. Author: Xiao Li Closes #17835 from gatorsmile/h2. --- .../sql/jdbc/OracleIntegrationSuite.scala | 2 +- .../sql/jdbc/PostgresIntegrationSuite.scala | 15 ++++++++++++ sql/core/pom.xml | 2 +- .../datasources/jdbc/JdbcUtils.scala | 12 +++++++--- .../spark/sql/internal/CatalogImpl.scala | 1 - .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 24 +++++++++++++++++-- 6 files changed, 48 insertions(+), 8 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 85d4a4a791e6b..f7b1ec34ced76 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -192,7 +192,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo checkRow(sql("SELECT * FROM datetime1 where id = 1").head()) } - test("SPARK-20557: column type TIMEZONE with TIME STAMP should be recognized") { + test("SPARK-20557: column type TIMESTAMP with TIME ZONE should be recognized") { val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) val rows = dfRead.collect() val types = rows(0).toSeq.map(x => x.getClass.toString) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index a1a065a443e67..eb3c458360e7b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -55,6 +55,13 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { + "null, null, null, null, null, " + "null, null, null, null, null, null, null)" ).executeUpdate() + + conn.prepareStatement("CREATE TABLE ts_with_timezone " + + "(id integer, tstz TIMESTAMP WITH TIME ZONE, ttz TIME WITH TIME ZONE)") + .executeUpdate() + conn.prepareStatement("INSERT INTO ts_with_timezone VALUES " + + "(1, TIMESTAMP WITH TIME ZONE '2016-08-12 10:22:31.949271-07', TIME WITH TIME ZONE '17:22:31.949271+00')") + .executeUpdate() } test("Type mapping for various types") { @@ -126,4 +133,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(schema(0).dataType == FloatType) assert(schema(1).dataType == ShortType) } + + test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE should be recognized") { + val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) + val rows = dfRead.collect() + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) + } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index e170133f0f0bf..fe4be963e8184 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -115,7 +115,7 @@ com.h2database h2 - 1.4.183 + 1.4.195 test diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index fb877d1ca7639..71eaab119d75d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} +import java.sql.{Connection, Driver, DriverManager, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} import java.util.Locale import scala.collection.JavaConverters._ @@ -217,11 +217,14 @@ object JdbcUtils extends Logging { case java.sql.Types.OTHER => null case java.sql.Types.REAL => DoubleType case java.sql.Types.REF => StringType + case java.sql.Types.REF_CURSOR => null case java.sql.Types.ROWID => LongType case java.sql.Types.SMALLINT => IntegerType case java.sql.Types.SQLXML => StringType case java.sql.Types.STRUCT => StringType case java.sql.Types.TIME => TimestampType + case java.sql.Types.TIME_WITH_TIMEZONE + => TimestampType case java.sql.Types.TIMESTAMP => TimestampType case java.sql.Types.TIMESTAMP_WITH_TIMEZONE => TimestampType @@ -229,11 +232,14 @@ object JdbcUtils extends Logging { case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType - case _ => null + case _ => + throw new SQLException("Unrecognized SQL type " + sqlType) // scalastyle:on } - if (answer == null) throw new SQLException("Unsupported type " + sqlType) + if (answer == null) { + throw new SQLException("Unsupported type " + JDBCType.valueOf(sqlType).getName) + } answer } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index e1049c665a417..142b005850a49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel - /** * Internal implementation of the user-facing `Catalog`. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5bd36ec25ccb0..d9f3689411ab7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal -import java.sql.{Date, DriverManager, Timestamp} +import java.sql.{Date, DriverManager, SQLException, Timestamp} import java.util.{Calendar, GregorianCalendar, Properties} import org.h2.jdbc.JdbcSQLException import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec @@ -141,6 +141,15 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("CREATE TABLE test.timezone (tz TIMESTAMP WITH TIME ZONE) " + + "AS SELECT '1999-01-08 04:05:06.543543543 GMT-08:00'") + .executeUpdate() + conn.commit() + + conn.prepareStatement("CREATE TABLE test.array (ar ARRAY) " + + "AS SELECT '(1, 2, 3)'") + .executeUpdate() + conn.commit() conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(38, 18))" ).executeUpdate() @@ -919,6 +928,17 @@ class JDBCSuite extends SparkFunSuite assert(res === (foobarCnt, 0L, foobarCnt) :: Nil) } + test("unsupported types") { + var e = intercept[SparkException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.TIMEZONE", new Properties()).collect() + }.getMessage + assert(e.contains("java.lang.UnsupportedOperationException: unimplemented")) + e = intercept[SQLException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.ARRAY", new Properties()).collect() + }.getMessage + assert(e.contains("Unsupported type ARRAY")) + } + test("SPARK-19318: Connection properties keys should be case-sensitive.") { def testJdbcOptions(options: JDBCOptions): Unit = { // Spark JDBC data source options are case-insensitive From 63d90e7da4913917982c0501d63ccc433a9b6b46 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sat, 6 May 2017 22:28:42 -0700 Subject: [PATCH 06/30] [SPARK-18777][PYTHON][SQL] Return UDF from udf.register ## What changes were proposed in this pull request? - Move udf wrapping code from `functions.udf` to `functions.UserDefinedFunction`. - Return wrapped udf from `catalog.registerFunction` and dependent methods. - Update docstrings in `catalog.registerFunction` and `SQLContext.registerFunction`. - Unit tests. ## How was this patch tested? - Existing unit tests and docstests. - Additional tests covering new feature. Author: zero323 Closes #17831 from zero323/SPARK-18777. --- python/pyspark/sql/catalog.py | 11 ++++++++--- python/pyspark/sql/context.py | 12 ++++++++---- python/pyspark/sql/functions.py | 23 ++++++++++++++--------- python/pyspark/sql/tests.py | 9 +++++++++ 4 files changed, 39 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 41e68a45a6159..5f25dce161963 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -237,23 +237,28 @@ def registerFunction(self, name, f, returnType=StringType()): :param name: name of the UDF :param f: python function :param returnType: a :class:`pyspark.sql.types.DataType` object + :return: a wrapped :class:`UserDefinedFunction` - >>> spark.catalog.registerFunction("stringLengthString", lambda x: len(x)) + >>> strlen = spark.catalog.registerFunction("stringLengthString", len) >>> spark.sql("SELECT stringLengthString('test')").collect() [Row(stringLengthString(test)=u'4')] + >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + >>> from pyspark.sql.types import IntegerType - >>> spark.catalog.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] >>> from pyspark.sql.types import IntegerType - >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ udf = UserDefinedFunction(f, returnType, name) self._jsparkSession.udf().registerPython(name, udf._judf) + return udf._wrapped() @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index fdb7abbad4e5f..5197a9e004610 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -185,22 +185,26 @@ def registerFunction(self, name, f, returnType=StringType()): :param name: name of the UDF :param f: python function :param returnType: a :class:`pyspark.sql.types.DataType` object + :return: a wrapped :class:`UserDefinedFunction` - >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) + >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() [Row(stringLengthString(test)=u'4')] + >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + >>> from pyspark.sql.types import IntegerType - >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] >>> from pyspark.sql.types import IntegerType - >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ - self.sparkSession.catalog.registerFunction(name, f, returnType) + return self.sparkSession.catalog.registerFunction(name, f, returnType) @ignore_unicode_prefix @since(2.1) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 843ae3816f061..8b3487c3f1083 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1917,6 +1917,19 @@ def __call__(self, *cols): sc = SparkContext._active_spark_context return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + def _wrapped(self): + """ + Wrap this udf with a function and attach docstring from func + """ + @functools.wraps(self.func) + def wrapper(*args): + return self(*args) + + wrapper.func = self.func + wrapper.returnType = self.returnType + + return wrapper + @since(1.3) def udf(f=None, returnType=StringType()): @@ -1951,15 +1964,7 @@ def udf(f=None, returnType=StringType()): """ def _udf(f, returnType=StringType()): udf_obj = UserDefinedFunction(f, returnType) - - @functools.wraps(f) - def wrapper(*args): - return udf_obj(*args) - - wrapper.func = udf_obj.func - wrapper.returnType = udf_obj.returnType - - return wrapper + return udf_obj._wrapped() # decorator @udf, @udf() or @udf(dataType()) if f is None or isinstance(f, (str, DataType)): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f644624f7f317..7983bc536fc6c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -436,6 +436,15 @@ def test_udf_with_order_by_and_limit(self): res.explain(True) self.assertEqual(res.collect(), [Row(id=0, copy=0)]) + def test_udf_registration_returns_udf(self): + df = self.spark.range(10) + add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_three(id) AS plus_three").collect(), + df.select(add_three("id").alias("plus_three")).collect() + ) + def test_wholefile_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", From 37f963ac13ec1bd958c44c7c15b5e8cb6c06cbbc Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Sun, 7 May 2017 10:08:06 +0100 Subject: [PATCH 07/30] [SPARK-20518][CORE] Supplement the new blockidsuite unit tests ## What changes were proposed in this pull request? This PR adds the new unit tests to support ShuffleDataBlockId , ShuffleIndexBlockId , TempShuffleBlockId , TempLocalBlockId ## How was this patch tested? The new unit test. Author: caoxuewen Closes #17794 from heary-cao/blockidsuite. --- .../apache/spark/storage/BlockIdSuite.scala | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index 89ed031b6fcd1..f0c521b00b583 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.storage +import java.util.UUID + import org.apache.spark.SparkFunSuite class BlockIdSuite extends SparkFunSuite { @@ -67,6 +69,32 @@ class BlockIdSuite extends SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("shuffle data") { + val id = ShuffleDataBlockId(4, 5, 6) + assertSame(id, ShuffleDataBlockId(4, 5, 6)) + assertDifferent(id, ShuffleDataBlockId(6, 5, 6)) + assert(id.name === "shuffle_4_5_6.data") + assert(id.asRDDId === None) + assert(id.shuffleId === 4) + assert(id.mapId === 5) + assert(id.reduceId === 6) + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + + test("shuffle index") { + val id = ShuffleIndexBlockId(7, 8, 9) + assertSame(id, ShuffleIndexBlockId(7, 8, 9)) + assertDifferent(id, ShuffleIndexBlockId(9, 8, 9)) + assert(id.name === "shuffle_7_8_9.index") + assert(id.asRDDId === None) + assert(id.shuffleId === 7) + assert(id.mapId === 8) + assert(id.reduceId === 9) + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + test("broadcast") { val id = BroadcastBlockId(42) assertSame(id, BroadcastBlockId(42)) @@ -101,6 +129,30 @@ class BlockIdSuite extends SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("temp local") { + val id = TempLocalBlockId(new UUID(5, 2)) + assertSame(id, TempLocalBlockId(new UUID(5, 2))) + assertDifferent(id, TempLocalBlockId(new UUID(5, 3))) + assert(id.name === "temp_local_00000000-0000-0005-0000-000000000002") + assert(id.asRDDId === None) + assert(id.isBroadcast === false) + assert(id.id.getMostSignificantBits() === 5) + assert(id.id.getLeastSignificantBits() === 2) + assert(!id.isShuffle) + } + + test("temp shuffle") { + val id = TempShuffleBlockId(new UUID(1, 2)) + assertSame(id, TempShuffleBlockId(new UUID(1, 2))) + assertDifferent(id, TempShuffleBlockId(new UUID(1, 3))) + assert(id.name === "temp_shuffle_00000000-0000-0001-0000-000000000002") + assert(id.asRDDId === None) + assert(id.isBroadcast === false) + assert(id.id.getMostSignificantBits() === 1) + assert(id.id.getLeastSignificantBits() === 2) + assert(!id.isShuffle) + } + test("test") { val id = TestBlockId("abc") assertSame(id, TestBlockId("abc")) From 88e6d75072c23fa99d4df00d087d03d8c38e8c69 Mon Sep 17 00:00:00 2001 From: Daniel Li Date: Sun, 7 May 2017 10:09:58 +0100 Subject: [PATCH 08/30] [SPARK-20484][MLLIB] Add documentation to ALS code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR adds documentation to the ALS code. ## How was this patch tested? Existing tests were used. mengxr srowen This contribution is my original work. I have the license to work on this project under the Spark project’s open source license. Author: Daniel Li Closes #17793 from danielyli/spark-20484. --- .../apache/spark/ml/recommendation/ALS.scala | 236 +++++++++++++++--- 1 file changed, 202 insertions(+), 34 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index a20ef72446661..1562bf1beb7e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -774,6 +774,28 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { /** * :: DeveloperApi :: * Implementation of the ALS algorithm. + * + * This implementation of the ALS factorization algorithm partitions the two sets of factors among + * Spark workers so as to reduce network communication by only sending one copy of each factor + * vector to each Spark worker on each iteration, and only if needed. This is achieved by + * precomputing some information about the ratings matrix to determine which users require which + * item factors and vice versa. See the Scaladoc for `InBlock` for a detailed explanation of how + * the precomputation is done. + * + * In addition, since each iteration of calculating the factor matrices depends on the known + * ratings, which are spread across Spark partitions, a naive implementation would incur + * significant network communication overhead between Spark workers, as the ratings RDD would be + * repeatedly shuffled during each iteration. This implementation reduces that overhead by + * performing the shuffling operation up front, precomputing each partition's ratings dependencies + * and duplicating those values to the appropriate workers before starting iterations to solve for + * the factor matrices. See the Scaladoc for `OutBlock` for a detailed explanation of how the + * precomputation is done. + * + * Note that the term "rating block" is a bit of a misnomer, as the ratings are not partitioned by + * contiguous blocks from the ratings matrix but by a hash function on the rating's location in + * the matrix. If it helps you to visualize the partitions, it is easier to think of the term + * "block" as referring to a subset of an RDD containing the ratings rather than a contiguous + * submatrix of the ratings matrix. */ @DeveloperApi def train[ID: ClassTag]( // scalastyle:ignore @@ -791,32 +813,43 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { checkpointInterval: Int = 10, seed: Long = 0L)( implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = { + require(!ratings.isEmpty(), s"No ratings available from $ratings") require(intermediateRDDStorageLevel != StorageLevel.NONE, "ALS is not designed to run without persisting intermediate RDDs.") + val sc = ratings.sparkContext + + // Precompute the rating dependencies of each partition val userPart = new ALSPartitioner(numUserBlocks) val itemPart = new ALSPartitioner(numItemBlocks) - val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) - val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) - val solver = if (nonnegative) new NNLSSolver else new CholeskySolver val blockRatings = partitionRatings(ratings, userPart, itemPart) .persist(intermediateRDDStorageLevel) val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel) - // materialize blockRatings and user blocks - userOutBlocks.count() + userOutBlocks.count() // materialize blockRatings and user blocks val swappedBlockRatings = blockRatings.map { case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) => ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings)) } val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel) - // materialize item blocks - itemOutBlocks.count() + itemOutBlocks.count() // materialize item blocks + + // Encoders for storing each user/item's partition ID and index within its partition using a + // single integer; used as an optimization + val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) + val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) + + // These are the user and item factor matrices that, once trained, are multiplied together to + // estimate the rating matrix. The two matrices are stored in RDDs, partitioned by column such + // that each factor column resides on the same Spark worker as its corresponding user or item. val seedGen = new XORShiftRandom(seed) var userFactors = initialize(userInBlocks, rank, seedGen.nextLong()) var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong()) + + val solver = if (nonnegative) new NNLSSolver else new CholeskySolver + var previousCheckpointFile: Option[String] = None val shouldCheckpoint: Int => Boolean = (iter) => sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0) @@ -830,6 +863,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { logWarning(s"Cannot delete checkpoint file $file:", e) } } + if (implicitPrefs) { for (iter <- 1 to maxIter) { userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel) @@ -910,26 +944,154 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { private type FactorBlock = Array[Array[Float]] /** - * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to - * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the - * src factors in this block to send to dst block 0. + * A mapping of the columns of the items factor matrix that are needed when calculating each row + * of the users factor matrix, and vice versa. + * + * Specifically, when calculating a user factor vector, since only those columns of the items + * factor matrix that correspond to the items that that user has rated are needed, we can avoid + * having to repeatedly copy the entire items factor matrix to each worker later in the algorithm + * by precomputing these dependencies for all users, storing them in an RDD of `OutBlock`s. The + * items' dependencies on the columns of the users factor matrix is computed similarly. + * + * =Example= + * + * Using the example provided in the `InBlock` Scaladoc, `userOutBlocks` would look like the + * following: + * + * {{{ + * userOutBlocks.collect() == Seq( + * 0 -> Array(Array(0, 1), Array(0, 1)), + * 1 -> Array(Array(0), Array(0)) + * ) + * }}} + * + * Each value in this map-like sequence is of type `Array[Array[Int]]`. The values in the + * inner array are the ranks of the sorted user IDs in that partition; so in the example above, + * `Array(0, 1)` in partition 0 refers to user IDs 0 and 6, since when all unique user IDs in + * partition 0 are sorted, 0 is the first ID and 6 is the second. The position of each inner + * array in its enclosing outer array denotes the partition number to which item IDs map; in the + * example, the first `Array(0, 1)` is in position 0 of its outer array, denoting item IDs that + * map to partition 0. + * + * In summary, the data structure encodes the following information: + * + * * There are ratings with user IDs 0 and 6 (encoded in `Array(0, 1)`, where 0 and 1 are the + * indices of the user IDs 0 and 6 on partition 0) whose item IDs map to partitions 0 and 1 + * (represented by the fact that `Array(0, 1)` appears in both the 0th and 1st positions). + * + * * There are ratings with user ID 3 (encoded in `Array(0)`, where 0 is the index of the user + * ID 3 on partition 1) whose item IDs map to partitions 0 and 1 (represented by the fact that + * `Array(0)` appears in both the 0th and 1st positions). */ private type OutBlock = Array[Array[Int]] /** - * In-link block for computing src (user/item) factors. This includes the original src IDs - * of the elements within this block as well as encoded dst (item/user) indices and corresponding - * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original - * dst IDs. To compute src factors, we expect receiving dst factors that match the dst indices. - * For example, if we have an in-link record + * In-link block for computing user and item factor matrices. + * + * The ALS algorithm partitions the columns of the users factor matrix evenly among Spark workers. + * Since each column of the factor matrix is calculated using the known ratings of the correspond- + * ing user, and since the ratings don't change across iterations, the ALS algorithm preshuffles + * the ratings to the appropriate partitions, storing them in `InBlock` objects. + * + * The ratings shuffled by item ID are computed similarly and also stored in `InBlock` objects. + * Note that this means every rating is stored twice, once as shuffled by user ID and once by item + * ID. This is a necessary tradeoff, since in general a rating will not be on the same worker + * when partitioned by user as by item. + * + * =Example= + * + * Say we have a small collection of eight items to offer the seven users in our application. We + * have some known ratings given by the users, as seen in the matrix below: + * + * {{{ + * Items + * 0 1 2 3 4 5 6 7 + * +---+---+---+---+---+---+---+---+ + * 0 | |0.1| | |0.4| | |0.7| + * +---+---+---+---+---+---+---+---+ + * 1 | | | | | | | | | + * +---+---+---+---+---+---+---+---+ + * U 2 | | | | | | | | | + * s +---+---+---+---+---+---+---+---+ + * e 3 | |3.1| | |3.4| | |3.7| + * r +---+---+---+---+---+---+---+---+ + * s 4 | | | | | | | | | + * +---+---+---+---+---+---+---+---+ + * 5 | | | | | | | | | + * +---+---+---+---+---+---+---+---+ + * 6 | |6.1| | |6.4| | |6.7| + * +---+---+---+---+---+---+---+---+ + * }}} + * + * The ratings are represented as an RDD, passed to the `partitionRatings` method as the `ratings` + * parameter: + * + * {{{ + * ratings.collect() == Seq( + * Rating(0, 1, 0.1f), + * Rating(0, 4, 0.4f), + * Rating(0, 7, 0.7f), + * Rating(3, 1, 3.1f), + * Rating(3, 4, 3.4f), + * Rating(3, 7, 3.7f), + * Rating(6, 1, 6.1f), + * Rating(6, 4, 6.4f), + * Rating(6, 7, 6.7f) + * ) + * }}} * - * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0}, + * Say that we are using two partitions to calculate each factor matrix: * - * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which - * is a blockId to dst factors map, the corresponding dst factor of the record is dstFactor(2)(3). + * {{{ + * val userPart = new ALSPartitioner(2) + * val itemPart = new ALSPartitioner(2) + * val blockRatings = partitionRatings(ratings, userPart, itemPart) + * }}} * - * We use a CSC-like (compressed sparse column) format to store the in-link information. So we can - * compute src factors one after another using only one normal equation instance. + * Ratings are mapped to partitions using the user/item IDs modulo the number of partitions. With + * two partitions, ratings with even-valued user IDs are shuffled to partition 0 while those with + * odd-valued user IDs are shuffled to partition 1: + * + * {{{ + * userInBlocks.collect() == Seq( + * 0 -> Seq( + * // Internally, the class stores the ratings in a more optimized format than + * // a sequence of `Rating`s, but for clarity we show it as such here. + * Rating(0, 1, 0.1f), + * Rating(0, 4, 0.4f), + * Rating(0, 7, 0.7f), + * Rating(6, 1, 6.1f), + * Rating(6, 4, 6.4f), + * Rating(6, 7, 6.7f) + * ), + * 1 -> Seq( + * Rating(3, 1, 3.1f), + * Rating(3, 4, 3.4f), + * Rating(3, 7, 3.7f) + * ) + * ) + * }}} + * + * Similarly, ratings with even-valued item IDs are shuffled to partition 0 while those with + * odd-valued item IDs are shuffled to partition 1: + * + * {{{ + * itemInBlocks.collect() == Seq( + * 0 -> Seq( + * Rating(0, 4, 0.4f), + * Rating(3, 4, 3.4f), + * Rating(6, 4, 6.4f) + * ), + * 1 -> Seq( + * Rating(0, 1, 0.1f), + * Rating(0, 7, 0.7f), + * Rating(3, 1, 3.1f), + * Rating(3, 7, 3.7f), + * Rating(6, 1, 6.1f), + * Rating(6, 7, 6.7f) + * ) + * ) + * }}} * * @param srcIds src ids (ordered) * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and @@ -1026,7 +1188,24 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { } /** - * Partitions raw ratings into blocks. + * Groups an RDD of [[Rating]]s by the user partition and item partition to which each `Rating` + * maps according to the given partitioners. The returned pair RDD holds the ratings, encoded in + * a memory-efficient format but otherwise unchanged, keyed by the (user partition ID, item + * partition ID) pair. + * + * Performance note: This is an expensive operation that performs an RDD shuffle. + * + * Implementation note: This implementation produces the same result as the following but + * generates fewer intermediate objects: + * + * {{{ + * ratings.map { r => + * ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r) + * }.aggregateByKey(new RatingBlockBuilder)( + * seqOp = (b, r) => b.add(r), + * combOp = (b0, b1) => b0.merge(b1.build())) + * .mapValues(_.build()) + * }}} * * @param ratings raw ratings * @param srcPart partitioner for src IDs @@ -1037,17 +1216,6 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { ratings: RDD[Rating[ID]], srcPart: Partitioner, dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = { - - /* The implementation produces the same result as the following but generates less objects. - - ratings.map { r => - ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r) - }.aggregateByKey(new RatingBlockBuilder)( - seqOp = (b, r) => b.add(r), - combOp = (b0, b1) => b0.merge(b1.build())) - .mapValues(_.build()) - */ - val numPartitions = srcPart.numPartitions * dstPart.numPartitions ratings.mapPartitions { iter => val builders = Array.fill(numPartitions)(new RatingBlockBuilder[ID]) @@ -1135,8 +1303,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { def length: Int = srcIds.length /** - * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a - * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format. + * Compresses the block into an `InBlock`. The algorithm is the same as converting a sparse + * matrix from coordinate list (COO) format into compressed sparse column (CSC) format. * Sorting is done using Spark's built-in Timsort to avoid generating too many objects. */ def compress(): InBlock[ID] = { From 2cf83c47838115f71419ba5b9296c69ec1d746cd Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Sun, 7 May 2017 10:15:31 +0100 Subject: [PATCH 09/30] [SPARK-7481][BUILD] Add spark-hadoop-cloud module to pull in object store access. ## What changes were proposed in this pull request? Add a new `spark-hadoop-cloud` module and maven profile to pull in object store support from `hadoop-openstack`, `hadoop-aws` and `hadoop-azure` (Hadoop 2.7+) JARs, along with their dependencies, fixing up the dependencies so that everything works, in particular Jackson. It restores `s3n://` access to S3, adds its `s3a://` replacement, OpenStack `swift://` and azure `wasb://`. There's a documentation page, `cloud_integration.md`, which covers the basic details of using Spark with object stores, referring the reader to the supplier's own documentation, with specific warnings on security and the possible mismatch between a store's behavior and that of a filesystem. In particular, users are advised be very cautious when trying to use an object store as the destination of data, and to consult the documentation of the storage supplier and the connector. (this is the successor to #12004; I can't re-open it) ## How was this patch tested? Downstream tests exist in [https://github.com/steveloughran/spark-cloud-examples/tree/master/cloud-examples](https://github.com/steveloughran/spark-cloud-examples/tree/master/cloud-examples) Those verify that the dependencies are sufficient to allow downstream applications to work with s3a, azure wasb and swift storage connectors, and perform basic IO & dataframe operations thereon. All seems well. Manually clean build & verify that assembly contains the relevant aws-* hadoop-* artifacts on Hadoop 2.6; azure on a hadoop-2.7 profile. SBT build: `build/sbt -Phadoop-cloud -Phadoop-2.7 package` maven build `mvn install -Phadoop-cloud -Phadoop-2.7` This PR *does not* update `dev/deps/spark-deps-hadoop-2.7` or `dev/deps/spark-deps-hadoop-2.6`, because unless the hadoop-cloud profile is enabled, no extra JARs show up in the dependency list. The dependency check in Jenkins isn't setting the property, so the new JARs aren't visible. Author: Steve Loughran Author: Steve Loughran Closes #17834 from steveloughran/cloud/SPARK-7481-current. --- assembly/pom.xml | 14 +++ docs/cloud-integration.md | 200 ++++++++++++++++++++++++++++++++ docs/index.md | 1 + docs/rdd-programming-guide.md | 6 +- docs/storage-openstack-swift.md | 38 ++---- hadoop-cloud/pom.xml | 185 +++++++++++++++++++++++++++++ pom.xml | 7 ++ project/SparkBuild.scala | 4 +- 8 files changed, 424 insertions(+), 31 deletions(-) create mode 100644 docs/cloud-integration.md create mode 100644 hadoop-cloud/pom.xml diff --git a/assembly/pom.xml b/assembly/pom.xml index 742a4a1531e71..464af16e46f6e 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -226,5 +226,19 @@ provided + + + + hadoop-cloud + + + org.apache.spark + spark-hadoop-cloud_${scala.binary.version} + ${project.version} + + + diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md new file mode 100644 index 0000000000000..751a192da4ffd --- /dev/null +++ b/docs/cloud-integration.md @@ -0,0 +1,200 @@ +--- +layout: global +displayTitle: Integration with Cloud Infrastructures +title: Integration with Cloud Infrastructures +description: Introduction to cloud storage support in Apache Spark SPARK_VERSION_SHORT +--- + + +* This will become a table of contents (this text will be scraped). +{:toc} + +## Introduction + + +All major cloud providers offer persistent data storage in *object stores*. +These are not classic "POSIX" file systems. +In order to store hundreds of petabytes of data without any single points of failure, +object stores replace the classic filesystem directory tree +with a simpler model of `object-name => data`. To enable remote access, operations +on objects are usually offered as (slow) HTTP REST operations. + +Spark can read and write data in object stores through filesystem connectors implemented +in Hadoop or provided by the infrastructure suppliers themselves. +These connectors make the object stores look *almost* like filesystems, with directories and files +and the classic operations on them such as list, delete and rename. + + +### Important: Cloud Object Stores are Not Real Filesystems + +While the stores appear to be filesystems, underneath +they are still object stores, [and the difference is significant](https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-common/filesystem/introduction.html) + +They cannot be used as a direct replacement for a cluster filesystem such as HDFS +*except where this is explicitly stated*. + +Key differences are: + +* Changes to stored objects may not be immediately visible, both in directory listings and actual data access. +* The means by which directories are emulated may make working with them slow. +* Rename operations may be very slow and, on failure, leave the store in an unknown state. +* Seeking within a file may require new HTTP calls, hurting performance. + +How does this affect Spark? + +1. Reading and writing data can be significantly slower than working with a normal filesystem. +1. Some directory structures may be very inefficient to scan during query split calculation. +1. The output of work may not be immediately visible to a follow-on query. +1. The rename-based algorithm by which Spark normally commits work when saving an RDD, DataFrame or Dataset + is potentially both slow and unreliable. + +For these reasons, it is not always safe to use an object store as a direct destination of queries, or as +an intermediate store in a chain of queries. Consult the documentation of the object store and its +connector to determine which uses are considered safe. + +In particular: *without some form of consistency layer, Amazon S3 cannot +be safely used as the direct destination of work with the normal rename-based committer.* + +### Installation + +With the relevant libraries on the classpath and Spark configured with valid credentials, +objects can be can be read or written by using their URLs as the path to data. +For example `sparkContext.textFile("s3a://landsat-pds/scene_list.gz")` will create +an RDD of the file `scene_list.gz` stored in S3, using the s3a connector. + +To add the relevant libraries to an application's classpath, include the `hadoop-cloud` +module and its dependencies. + +In Maven, add the following to the `pom.xml` file, assuming `spark.version` +is set to the chosen version of Spark: + +{% highlight xml %} + + ... + + org.apache.spark + hadoop-cloud_2.11 + ${spark.version} + + ... + +{% endhighlight %} + +Commercial products based on Apache Spark generally directly set up the classpath +for talking to cloud infrastructures, in which case this module may not be needed. + +### Authenticating + +Spark jobs must authenticate with the object stores to access data within them. + +1. When Spark is running in a cloud infrastructure, the credentials are usually automatically set up. +1. `spark-submit` reads the `AWS_ACCESS_KEY`, `AWS_SECRET_KEY` +and `AWS_SESSION_TOKEN` environment variables and sets the associated authentication options +for the `s3n` and `s3a` connectors to Amazon S3. +1. In a Hadoop cluster, settings may be set in the `core-site.xml` file. +1. Authentication details may be manually added to the Spark configuration in `spark-default.conf` +1. Alternatively, they can be programmatically set in the `SparkConf` instance used to configure +the application's `SparkContext`. + +*Important: never check authentication secrets into source code repositories, +especially public ones* + +Consult [the Hadoop documentation](https://hadoop.apache.org/docs/current/) for the relevant +configuration and security options. + +## Configuring + +Each cloud connector has its own set of configuration parameters, again, +consult the relevant documentation. + +### Recommended settings for writing to object stores + +For object stores whose consistency model means that rename-based commits are safe +use the `FileOutputCommitter` v2 algorithm for performance: + +``` +spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version 2 +``` + +This does less renaming at the end of a job than the "version 1" algorithm. +As it still uses `rename()` to commit files, it is unsafe to use +when the object store does not have consistent metadata/listings. + +The committer can also be set to ignore failures when cleaning up temporary +files; this reduces the risk that a transient network problem is escalated into a +job failure: + +``` +spark.hadoop.mapreduce.fileoutputcommitter.cleanup-failures.ignored true +``` + +As storing temporary files can run up charges; delete +directories called `"_temporary"` on a regular basis to avoid this. + +### Parquet I/O Settings + +For optimal performance when working with Parquet data use the following settings: + +``` +spark.hadoop.parquet.enable.summary-metadata false +spark.sql.parquet.mergeSchema false +spark.sql.parquet.filterPushdown true +spark.sql.hive.metastorePartitionPruning true +``` + +These minimise the amount of data read during queries. + +### ORC I/O Settings + +For best performance when working with ORC data, use these settings: + +``` +spark.sql.orc.filterPushdown true +spark.sql.orc.splits.include.file.footer true +spark.sql.orc.cache.stripe.details.size 10000 +spark.sql.hive.metastorePartitionPruning true +``` + +Again, these minimise the amount of data read during queries. + +## Spark Streaming and Object Storage + +Spark Streaming can monitor files added to object stores, by +creating a `FileInputDStream` to monitor a path in the store through a call to +`StreamingContext.textFileStream()`. + +1. The time to scan for new files is proportional to the number of files +under the path, not the number of *new* files, so it can become a slow operation. +The size of the window needs to be set to handle this. + +1. Files only appear in an object store once they are completely written; there +is no need for a worklow of write-then-rename to ensure that files aren't picked up +while they are still being written. Applications can write straight to the monitored directory. + +1. Streams should only be checkpointed to an store implementing a fast and +atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. + +## Further Reading + +Here is the documentation on the standard connectors both from Apache and the cloud providers. + +* [OpenStack Swift](https://hadoop.apache.org/docs/current/hadoop-openstack/index.html). Hadoop 2.6+ +* [Azure Blob Storage](https://hadoop.apache.org/docs/current/hadoop-aws/tools/hadoop-aws/index.html). Since Hadoop 2.7 +* [Azure Data Lake](https://hadoop.apache.org/docs/current/hadoop-azure-datalake/index.html). Since Hadoop 2.8 +* [Amazon S3 via S3A and S3N](https://hadoop.apache.org/docs/current/hadoop-aws/tools/hadoop-aws/index.html). Hadoop 2.6+ +* [Amazon EMR File System (EMRFS)](https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-fs.html). From Amazon +* [Google Cloud Storage Connector for Spark and Hadoop](https://cloud.google.com/hadoop/google-cloud-storage-connector). From Google + + diff --git a/docs/index.md b/docs/index.md index ad4f24ff1a5d1..960b968454d0e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -126,6 +126,7 @@ options for deployment: * [Security](security.html): Spark security support * [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware * Integration with other storage systems: + * [Cloud Infrastructures](cloud-integration.html) * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system * [Contributing to Spark](http://spark.apache.org/contributing.html) diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index e2bf2d7ca77ca..52e59df9990e9 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -323,7 +323,7 @@ One important parameter for parallel collections is the number of *partitions* t Spark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html). -Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3n://`, etc URI) and reads it as a collection of lines. Here is an example invocation: +Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: {% highlight scala %} scala> val distFile = sc.textFile("data.txt") @@ -356,7 +356,7 @@ Apart from text files, Spark's Scala API also supports several other data format Spark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html). -Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3n://`, etc URI) and reads it as a collection of lines. Here is an example invocation: +Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: {% highlight java %} JavaRDD distFile = sc.textFile("data.txt"); @@ -388,7 +388,7 @@ Apart from text files, Spark's Java API also supports several other data formats PySpark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html). -Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3n://`, etc URI) and reads it as a collection of lines. Here is an example invocation: +Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: {% highlight python %} >>> distFile = sc.textFile("data.txt") diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index c39ef1ce59e1c..f4bb2353e3c49 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -8,7 +8,8 @@ same URI formats as in Hadoop. You can specify a path in Swift as input through URI of the form swift://container.PROVIDER/path. You will also need to set your Swift security credentials, through core-site.xml or via SparkContext.hadoopConfiguration. -Current Swift driver requires Swift to use Keystone authentication method. +The current Swift driver requires Swift to use the Keystone authentication method, or +its Rackspace-specific predecessor. # Configuring Swift for Better Data Locality @@ -19,41 +20,30 @@ Although not mandatory, it is recommended to configure the proxy server of Swift # Dependencies -The Spark application should include hadoop-openstack dependency. +The Spark application should include hadoop-openstack dependency, which can +be done by including the `hadoop-cloud` module for the specific version of spark used. For example, for Maven support, add the following to the pom.xml file: {% highlight xml %} ... - org.apache.hadoop - hadoop-openstack - 2.3.0 + org.apache.spark + hadoop-cloud_2.11 + ${spark.version} ... {% endhighlight %} - # Configuration Parameters Create core-site.xml and place it inside Spark's conf directory. -There are two main categories of parameters that should to be configured: declaration of the -Swift driver and the parameters that are required by Keystone. +The main category of parameters that should be configured are the authentication parameters +required by Keystone. -Configuration of Hadoop to use Swift File system achieved via - - - - - - - -
Property NameValue
fs.swift.implorg.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem
- -Additional parameters required by Keystone (v2.0) and should be provided to the Swift driver. Those -parameters will be used to perform authentication in Keystone to access Swift. The following table -contains a list of Keystone mandatory parameters. PROVIDER can be any name. +The following table contains a list of Keystone mandatory parameters. PROVIDER can be +any (alphanumeric) name. @@ -94,7 +84,7 @@ contains a list of Keystone mandatory parameters. PROVIDER can be a - +
Property NameMeaningRequired
fs.swift.service.PROVIDER.publicIndicates if all URLs are publicIndicates whether to use the public (off cloud) or private (in cloud; no transfer fees) endpoints Mandatory
@@ -104,10 +94,6 @@ defined for tenant test. Then core-site.xml should inc {% highlight xml %} - - fs.swift.impl - org.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem - fs.swift.service.SparkTest.auth.url http://127.0.0.1:5000/v2.0/tokens diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml new file mode 100644 index 0000000000000..aa36dd4774d86 --- /dev/null +++ b/hadoop-cloud/pom.xml @@ -0,0 +1,185 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.3.0-SNAPSHOT + ../pom.xml + + + spark-hadoop-cloud_2.11 + jar + Spark Project Cloud Integration through Hadoop Libraries + + Contains support for cloud infrastructures, specifically the Hadoop JARs and + transitive dependencies needed to interact with the infrastructures, + making everything consistent with Spark's other dependencies. + + + hadoop-cloud + + + + + + org.apache.hadoop + hadoop-aws + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + commons-logging + commons-logging + + + org.codehaus.jackson + jackson-mapper-asl + + + org.codehaus.jackson + jackson-core-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + + + org.apache.hadoop + hadoop-openstack + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + commons-logging + commons-logging + + + junit + junit + + + org.mockito + mockito-all + + + + + + + joda-time + joda-time + ${hadoop.deps.scope} + + + + com.fasterxml.jackson.core + jackson-databind + ${hadoop.deps.scope} + + + com.fasterxml.jackson.core + jackson-annotations + ${hadoop.deps.scope} + + + com.fasterxml.jackson.dataformat + jackson-dataformat-cbor + ${fasterxml.jackson.version} + + + + org.apache.httpcomponents + httpclient + ${hadoop.deps.scope} + + + + org.apache.httpcomponents + httpcore + ${hadoop.deps.scope} + + + + + + + hadoop-2.7 + + + + + + org.apache.hadoop + hadoop-azure + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + org.codehaus.jackson + jackson-mapper-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.google.guava + guava + + + + + + + + + diff --git a/pom.xml b/pom.xml index a1a1817e2f7d3..0533a8dcf2e0a 100644 --- a/pom.xml +++ b/pom.xml @@ -2546,6 +2546,13 @@ + + hadoop-cloud + + hadoop-cloud + + + scala-2.10 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e52baf51aed1a..b5362ec1ae452 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -57,9 +57,9 @@ object BuildCommons { ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects val optionallyEnabledProjects@Seq(mesos, yarn, sparkGangliaLgpl, - streamingKinesisAsl, dockerIntegrationTests) = + streamingKinesisAsl, dockerIntegrationTests, hadoopCloud) = Seq("mesos", "yarn", "ganglia-lgpl", "streaming-kinesis-asl", - "docker-integration-tests").map(ProjectRef(buildLocation, _)) + "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly") From 7087e01194964a1aad0b45bdb41506a17100eacf Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 7 May 2017 13:10:10 -0700 Subject: [PATCH 10/30] [SPARK-20543][SPARKR][FOLLOWUP] Don't skip tests on AppVeyor ## What changes were proposed in this pull request? add environment ## How was this patch tested? wait for appveyor run Author: Felix Cheung Closes #17878 from felixcheung/appveyorrcran. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- appveyor.yml | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 47cc34a6c5b75..232246d6be9b4 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -3387,7 +3387,7 @@ compare_list <- function(list1, list2) { # This should always be the **very last test** in this test file. test_that("No extra files are created in SPARK_HOME by starting session and making calls", { - skip_on_cran() + skip_on_cran() # skip because when run from R CMD check SPARK_HOME is not the current directory # Check that it is not creating any extra file. # Does not check the tempdir which would be cleaned up after. diff --git a/appveyor.yml b/appveyor.yml index 4d31af70f056e..58c2e98289e96 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -48,6 +48,9 @@ install: build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package +environment: + NOT_CRAN: true + test_script: - cmd: .\bin\spark-submit2.cmd --driver-java-options "-Dlog4j.configuration=file:///%CD:\=/%/R/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R @@ -56,4 +59,3 @@ notifications: on_build_success: false on_build_failure: false on_build_status_changed: false - From 500436b4368207db9e9b9cef83f9c11d33e31e1a Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Sun, 7 May 2017 13:56:13 -0700 Subject: [PATCH 11/30] [MINOR][SQL][DOCS] Improve unix_timestamp's scaladoc (and typo hunting) ## What changes were proposed in this pull request? * Docs are consistent (across different `unix_timestamp` variants and their internal expressions) * typo hunting ## How was this patch tested? local build Author: Jacek Laskowski Closes #17801 from jaceklaskowski/unix_timestamp. --- .../expressions/datetimeExpressions.scala | 6 ++--- .../sql/catalyst/util/DateTimeUtils.scala | 2 +- .../org/apache/spark/sql/functions.scala | 26 ++++++++++++------- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index bb8fd5032d63d..a98cd33f2780c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -488,7 +488,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti * Deterministic version of [[UnixTimestamp]], must have at least one parameter. */ @ExpressionDescription( - usage = "_FUNC_(expr[, pattern]) - Returns the UNIX timestamp of the give time.", + usage = "_FUNC_(expr[, pattern]) - Returns the UNIX timestamp of the given time.", extended = """ Examples: > SELECT _FUNC_('2016-04-08', 'yyyy-MM-dd'); @@ -1225,8 +1225,8 @@ case class ParseToTimestamp(left: Expression, format: Expression, child: Express extends RuntimeReplaceable { def this(left: Expression, format: Expression) = { - this(left, format, Cast(UnixTimestamp(left, format), TimestampType)) -} + this(left, format, Cast(UnixTimestamp(left, format), TimestampType)) + } override def flatArguments: Iterator[Any] = Iterator(left, format) override def sql: String = s"$prettyName(${left.sql}, ${format.sql})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index eb6aad5b2d2bb..6c1592fd8881d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -423,7 +423,7 @@ object DateTimeUtils { } /** - * Parses a given UTF8 date string to the corresponding a corresponding [[Int]] value. + * Parses a given UTF8 date string to a corresponding [[Int]] value. * The return type is [[Option]] in order to distinguish between 0 and null. The following * formats are allowed: * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f07e04368389f..987011edfe1e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2491,10 +2491,10 @@ object functions { * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. * - * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All - * pattern letters of `java.text.SimpleDateFormat` can be used. + * A pattern `dd.MM.yyyy` would return a string like `18.03.1993`. + * All pattern letters of `java.text.SimpleDateFormat` can be used. * - * @note Use when ever possible specialized functions like [[year]]. These benefit from a + * @note Use specialized functions like [[year]] whenever possible as they benefit from a * specialized implementation. * * @group datetime_funcs @@ -2647,7 +2647,11 @@ object functions { } /** - * Gets current Unix timestamp in seconds. + * Returns the current Unix timestamp (in seconds). + * + * @note All calls of `unix_timestamp` within the same query return the same value + * (i.e. the current timestamp is calculated at the start of query evaluation). + * * @group datetime_funcs * @since 1.5.0 */ @@ -2657,7 +2661,9 @@ object functions { /** * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), - * using the default timezone and the default locale, return null if fail. + * using the default timezone and the default locale. + * Returns `null` if fails. + * * @group datetime_funcs * @since 1.5.0 */ @@ -2666,13 +2672,15 @@ object functions { } /** - * Convert time string with given pattern - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * to Unix time stamp (in seconds), return null if fail. + * Converts time string with given pattern to Unix timestamp (in seconds). + * Returns `null` if fails. + * + * @see + * Customizing Formats * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column, p: String): Column = withExpr {UnixTimestamp(s.expr, Literal(p)) } + def unix_timestamp(s: Column, p: String): Column = withExpr { UnixTimestamp(s.expr, Literal(p)) } /** * Convert time string to a Unix timestamp (in seconds). From 1f73d3589a84b78473598c17ac328a9805896778 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 7 May 2017 16:24:42 -0700 Subject: [PATCH 12/30] [SPARK-20550][SPARKR] R wrapper for Dataset.alias ## What changes were proposed in this pull request? - Add SparkR wrapper for `Dataset.alias`. - Adjust roxygen annotations for `functions.alias` (including example usage). ## How was this patch tested? Unit tests, `check_cran.sh`. Author: zero323 Closes #17825 from zero323/SPARK-20550. --- R/pkg/R/DataFrame.R | 24 +++++++++++++++++++++++ R/pkg/R/column.R | 16 +++++++-------- R/pkg/R/generics.R | 11 +++++++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 10 ++++++++++ 4 files changed, 53 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1c8869202f677..b56dddcb9f2ef 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3745,3 +3745,27 @@ setMethod("hint", jdf <- callJMethod(x@sdf, "hint", name, parameters) dataFrame(jdf) }) + +#' alias +#' +#' @aliases alias,SparkDataFrame-method +#' @family SparkDataFrame functions +#' @rdname alias +#' @name alias +#' @export +#' @examples +#' \dontrun{ +#' df <- alias(createDataFrame(mtcars), "mtcars") +#' avg_mpg <- alias(agg(groupBy(df, df$cyl), avg(df$mpg)), "avg_mpg") +#' +#' head(select(df, column("mtcars.mpg"))) +#' head(join(df, avg_mpg, column("mtcars.cyl") == column("avg_mpg.cyl"))) +#' } +#' @note alias(SparkDataFrame) since 2.3.0 +setMethod("alias", + signature(object = "SparkDataFrame"), + function(object, data) { + stopifnot(is.character(data)) + sdf <- callJMethod(object@sdf, "alias", data) + dataFrame(sdf) + }) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 147ee4b6887b9..574078012adad 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -130,19 +130,19 @@ createMethods <- function() { createMethods() -#' alias -#' -#' Set a new name for a column -#' -#' @param object Column to rename -#' @param data new name to use -#' #' @rdname alias #' @name alias #' @aliases alias,Column-method #' @family colum_func #' @export -#' @note alias since 1.4.0 +#' @examples \dontrun{ +#' df <- createDataFrame(iris) +#' +#' head(select( +#' df, alias(df$Sepal_Length, "slength"), alias(df$Petal_Length, "plength") +#' )) +#' } +#' @note alias(Column) since 1.4.0 setMethod("alias", signature(object = "Column"), function(object, data) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e835ef3e4f40d..3c84bf8a4803e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -387,6 +387,17 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @export setGeneric("agg", function (x, ...) { standardGeneric("agg") }) +#' alias +#' +#' Returns a new SparkDataFrame or a Column with an alias set. Equivalent to SQL "AS" keyword. +#' +#' @name alias +#' @rdname alias +#' @param object x a SparkDataFrame or a Column +#' @param data new name to use +#' @return a SparkDataFrame or a Column +NULL + #' @rdname arrange #' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 232246d6be9b4..0856bab5686c5 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1223,6 +1223,16 @@ test_that("select with column", { expect_equal(columns(df4), c("name", "age")) expect_equal(count(df4), 3) + # Test select with alias + df5 <- alias(df, "table") + + expect_equal(columns(select(df5, column("table.name"))), "name") + expect_equal(columns(select(df5, "table.name")), "name") + + # Test that stats::alias is not masked + expect_is(alias(aov(yield ~ block + N * P * K, npk)), "listof") + + expect_error(select(df, c("name", "age"), "name"), "To select multiple columns, use a character vector or list for col") }) From f53a820721fe0525c275e2bb4415c20909c42dc3 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 8 May 2017 10:58:27 +0800 Subject: [PATCH 13/30] [SPARK-16931][PYTHON][SQL] Add Python wrapper for bucketBy ## What changes were proposed in this pull request? Adds Python wrappers for `DataFrameWriter.bucketBy` and `DataFrameWriter.sortBy` ([SPARK-16931](https://issues.apache.org/jira/browse/SPARK-16931)) ## How was this patch tested? Unit tests covering new feature. __Note__: Based on work of GregBowyer (f49b9a23468f7af32cb53d2b654272757c151725) CC HyukjinKwon Author: zero323 Author: Greg Bowyer Closes #17077 from zero323/SPARK-16931. --- python/pyspark/sql/readwriter.py | 57 ++++++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 54 ++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 960fb882cf901..90ce8f81eb7fd 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -563,6 +563,63 @@ def partitionBy(self, *cols): self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols)) return self + @since(2.3) + def bucketBy(self, numBuckets, col, *cols): + """Buckets the output by the given columns.If specified, + the output is laid out on the file system similar to Hive's bucketing scheme. + + :param numBuckets: the number of buckets to save + :param col: a name of a column, or a list of names. + :param cols: additional names (optional). If `col` is a list it should be empty. + + .. note:: Applicable for file-based data sources in combination with + :py:meth:`DataFrameWriter.saveAsTable`. + + >>> (df.write.format('parquet') + ... .bucketBy(100, 'year', 'month') + ... .mode("overwrite") + ... .saveAsTable('bucketed_table')) + """ + if not isinstance(numBuckets, int): + raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets))) + + if isinstance(col, (list, tuple)): + if cols: + raise ValueError("col is a {0} but cols are not empty".format(type(col))) + + col, cols = col[0], col[1:] + + if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)): + raise TypeError("all names should be `str`") + + self._jwrite = self._jwrite.bucketBy(numBuckets, col, _to_seq(self._spark._sc, cols)) + return self + + @since(2.3) + def sortBy(self, col, *cols): + """Sorts the output in each bucket by the given columns on the file system. + + :param col: a name of a column, or a list of names. + :param cols: additional names (optional). If `col` is a list it should be empty. + + >>> (df.write.format('parquet') + ... .bucketBy(100, 'year', 'month') + ... .sortBy('day') + ... .mode("overwrite") + ... .saveAsTable('sorted_bucketed_table')) + """ + if isinstance(col, (list, tuple)): + if cols: + raise ValueError("col is a {0} but cols are not empty".format(type(col))) + + col, cols = col[0], col[1:] + + if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)): + raise TypeError("all names should be `str`") + + self._jwrite = self._jwrite.sortBy(col, _to_seq(self._spark._sc, cols)) + return self + @since(1.4) def save(self, path=None, format=None, mode=None, partitionBy=None, **options): """Saves the contents of the :class:`DataFrame` to a data source. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7983bc536fc6c..e3fe01eae243f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -211,6 +211,12 @@ def test_sqlcontext_reuses_sparksession(self): sqlContext2 = SQLContext(self.sc) self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession) + def tearDown(self): + super(SQLTests, self).tearDown() + + # tear down test_bucketed_write state + self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket") + def test_row_should_be_read_only(self): row = Row(a=1, b=2) self.assertEqual(1, row.a) @@ -2196,6 +2202,54 @@ def test_BinaryType_serialization(self): df = self.spark.createDataFrame(data, schema=schema) df.collect() + def test_bucketed_write(self): + data = [ + (1, "foo", 3.0), (2, "foo", 5.0), + (3, "bar", -1.0), (4, "bar", 6.0), + ] + df = self.spark.createDataFrame(data, ["x", "y", "z"]) + + def count_bucketed_cols(names, table="pyspark_bucket"): + """Given a sequence of column names and a table name + query the catalog and return number o columns which are + used for bucketing + """ + cols = self.spark.catalog.listColumns(table) + num = len([c for c in cols if c.name in names and c.isBucket]) + return num + + # Test write with one bucketing column + df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x"]), 1) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write two bucketing columns + df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x", "y"]), 2) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with bucket and sort + df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x"]), 1) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with a list of columns + df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x", "y"]), 2) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with bucket and sort with a list of columns + (df.write.bucketBy(2, "x") + .sortBy(["y", "z"]) + .mode("overwrite").saveAsTable("pyspark_bucket")) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with bucket and sort with multiple columns + (df.write.bucketBy(2, "x") + .sortBy("y", "z") + .mode("overwrite").saveAsTable("pyspark_bucket")) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + class HiveSparkSubmitTests(SparkSubmitTests): From 22691556e5f0dfbac81b8cc9ca0a67c70c1711ca Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 8 May 2017 12:16:00 +0900 Subject: [PATCH 14/30] [SPARK-12297][SQL] Hive compatibility for Parquet Timestamps ## What changes were proposed in this pull request? This change allows timestamps in parquet-based hive table to behave as a "floating time", without a timezone, as timestamps are for other file formats. If the storage timezone is the same as the session timezone, this conversion is a no-op. When data is read from a hive table, the table property is *always* respected. This allows spark to not change behavior when reading old data, but read newly written data correctly (whatever the source of the data is). Spark inherited the original behavior from Hive, but Hive is also updating behavior to use the same scheme in HIVE-12767 / HIVE-16231. The default for Spark remains unchanged; created tables do not include the new table property. This will only apply to hive tables; nothing is added to parquet metadata to indicate the timezone, so data that is read or written directly from parquet files will never have any conversions applied. ## How was this patch tested? Added a unit test which creates tables, reads and writes data, under a variety of permutations (different storage timezones, different session timezones, vectorized reading on and off). Author: Imran Rashid Closes #16781 from squito/SPARK-12297. --- .../sql/catalyst/catalog/interface.scala | 4 +- .../sql/catalyst/util/DateTimeUtils.scala | 5 + .../parquet/VectorizedColumnReader.java | 28 +- .../VectorizedParquetRecordReader.java | 6 +- .../spark/sql/execution/command/tables.scala | 8 +- .../parquet/ParquetFileFormat.scala | 2 + .../parquet/ParquetReadSupport.scala | 3 +- .../parquet/ParquetRecordMaterializer.scala | 9 +- .../parquet/ParquetRowConverter.scala | 53 ++- .../parquet/ParquetWriteSupport.scala | 25 +- .../spark/sql/hive/HiveExternalCatalog.scala | 11 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 12 +- .../hive/ParquetHiveCompatibilitySuite.scala | 379 +++++++++++++++++- 13 files changed, 516 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index cc0cbba275b81..c39017ebbfe60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -132,10 +132,10 @@ case class CatalogTablePartition( /** * Given the partition schema, returns a row with that schema holding the partition values. */ - def toRow(partitionSchema: StructType, defaultTimeZondId: String): InternalRow = { + def toRow(partitionSchema: StructType, defaultTimeZoneId: String): InternalRow = { val caseInsensitiveProperties = CaseInsensitiveMap(storage.properties) val timeZoneId = caseInsensitiveProperties.getOrElse( - DateTimeUtils.TIMEZONE_OPTION, defaultTimeZondId) + DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId) InternalRow.fromSeq(partitionSchema.map { field => val partValue = if (spec(field.name) == ExternalCatalogUtils.DEFAULT_PARTITION_NAME) { null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 6c1592fd8881d..bf596fa0a89db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -498,6 +498,11 @@ object DateTimeUtils { false } + lazy val validTimezones = TimeZone.getAvailableIDs().toSet + def isValidTimezone(timezoneId: String): Boolean = { + validTimezones.contains(timezoneId) + } + /** * Returns the microseconds since year zero (-17999) from microseconds since epoch. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 9d641b528723a..dabbc2b6387e4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.util.TimeZone; +import org.apache.hadoop.conf.Configuration; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; @@ -30,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -90,11 +93,30 @@ public class VectorizedColumnReader { private final PageReader pageReader; private final ColumnDescriptor descriptor; + private final TimeZone storageTz; + private final TimeZone sessionTz; - public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader) + public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader, + Configuration conf) throws IOException { this.descriptor = descriptor; this.pageReader = pageReader; + // If the table has a timezone property, apply the correct conversions. See SPARK-12297. + // The conf is sometimes null in tests. + String sessionTzString = + conf == null ? null : conf.get(SQLConf.SESSION_LOCAL_TIMEZONE().key()); + if (sessionTzString == null || sessionTzString.isEmpty()) { + sessionTz = DateTimeUtils.defaultTimeZone(); + } else { + sessionTz = TimeZone.getTimeZone(sessionTzString); + } + String storageTzString = + conf == null ? null : conf.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY()); + if (storageTzString == null || storageTzString.isEmpty()) { + storageTz = sessionTz; + } else { + storageTz = TimeZone.getTimeZone(storageTzString); + } this.maxDefLevel = descriptor.getMaxDefinitionLevel(); DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); @@ -289,7 +311,7 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, // TODO: Convert dictionary of Binaries to dictionary of Longs if (!column.isNullAt(i)) { Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i)); - column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v)); + column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v, sessionTz, storageTz)); } } } else { @@ -422,7 +444,7 @@ private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOE if (defColumn.readInteger() == maxDefLevel) { column.putLong(rowId + i, // Read 12 bytes for INT96 - ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12))); + ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12), sessionTz, storageTz)); } else { column.putNull(rowId + i); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 51bdf0f0f2291..d8974ddf24704 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.column.ColumnDescriptor; @@ -95,6 +96,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private boolean returnColumnarBatch; + private Configuration conf; + /** * The default config on whether columnarBatch should be offheap. */ @@ -107,6 +110,7 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException, InterruptedException, UnsupportedOperationException { super.initialize(inputSplit, taskAttemptContext); + this.conf = taskAttemptContext.getConfiguration(); initializeInternal(); } @@ -277,7 +281,7 @@ private void checkEndOfRowGroup() throws IOException { for (int i = 0; i < columns.size(); ++i) { if (missingColumns[i]) continue; columnReaders[i] = new VectorizedColumnReader(columns.get(i), - pages.getPageReader(columns.get(i))); + pages.getPageReader(columns.get(i)), conf); } totalCountLoadedSoFar += pages.getRowCount(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index ebf03e1bf8869..5843c5b56d44c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -26,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import scala.util.Try -import org.apache.commons.lang3.StringEscapeUtils import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -37,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat @@ -74,6 +73,10 @@ case class CreateTableLikeCommand( sourceTableDesc.provider } + val properties = sourceTableDesc.properties.filter { case (k, _) => + k == ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + } + // If the location is specified, we create an external table internally. // Otherwise create a managed table. val tblType = if (location.isEmpty) CatalogTableType.MANAGED else CatalogTableType.EXTERNAL @@ -86,6 +89,7 @@ case class CreateTableLikeCommand( locationUri = location.map(CatalogUtils.stringToURI(_))), schema = sourceTableDesc.schema, provider = newProvider, + properties = properties, partitionColumnNames = sourceTableDesc.partitionColumnNames, bucketSpec = sourceTableDesc.bucketSpec) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 2f3a2c62b912c..8113768cd793f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -632,4 +632,6 @@ object ParquetFileFormat extends Logging { Failure(cause) }.toOption } + + val PARQUET_TIMEZONE_TABLE_PROPERTY = "parquet.mr.int96.write.zone" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index f1a35dd8a6200..bf395a0bef745 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -95,7 +95,8 @@ private[parquet] class ParquetReadSupport extends ReadSupport[UnsafeRow] with Lo new ParquetRecordMaterializer( parquetRequestedSchema, ParquetReadSupport.expandUDT(catalystRequestedSchema), - new ParquetSchemaConverter(conf)) + new ParquetSchemaConverter(conf), + conf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala index 4e49a0dac97c0..df041996cdea9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet +import org.apache.hadoop.conf.Configuration import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} import org.apache.parquet.schema.MessageType @@ -29,13 +30,17 @@ import org.apache.spark.sql.types.StructType * @param parquetSchema Parquet schema of the records to be read * @param catalystSchema Catalyst schema of the rows to be constructed * @param schemaConverter A Parquet-Catalyst schema converter that helps initializing row converters + * @param hadoopConf hadoop Configuration for passing extra params for parquet conversion */ private[parquet] class ParquetRecordMaterializer( - parquetSchema: MessageType, catalystSchema: StructType, schemaConverter: ParquetSchemaConverter) + parquetSchema: MessageType, + catalystSchema: StructType, + schemaConverter: ParquetSchemaConverter, + hadoopConf: Configuration) extends RecordMaterializer[UnsafeRow] { private val rootConverter = - new ParquetRowConverter(schemaConverter, parquetSchema, catalystSchema, NoopUpdater) + new ParquetRowConverter(schemaConverter, parquetSchema, catalystSchema, hadoopConf, NoopUpdater) override def getCurrentRecord: UnsafeRow = rootConverter.currentRecord diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 32e6c60cd9766..d52ff62d93b26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder +import java.util.TimeZone import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.conf.Configuration import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} import org.apache.parquet.schema.{GroupType, MessageType, OriginalType, Type} @@ -34,6 +36,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -117,12 +120,14 @@ private[parquet] class ParquetPrimitiveConverter(val updater: ParentContainerUpd * @param parquetType Parquet schema of Parquet records * @param catalystType Spark SQL schema that corresponds to the Parquet record type. User-defined * types should have been expanded. + * @param hadoopConf a hadoop Configuration for passing any extra parameters for parquet conversion * @param updater An updater which propagates converted field values to the parent container */ private[parquet] class ParquetRowConverter( schemaConverter: ParquetSchemaConverter, parquetType: GroupType, catalystType: StructType, + hadoopConf: Configuration, updater: ParentContainerUpdater) extends ParquetGroupConverter(updater) with Logging { @@ -261,18 +266,18 @@ private[parquet] class ParquetRowConverter( case TimestampType => // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. + // If the table has a timezone property, apply the correct conversions. See SPARK-12297. + val sessionTzString = hadoopConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key) + val sessionTz = Option(sessionTzString).map(TimeZone.getTimeZone(_)) + .getOrElse(DateTimeUtils.defaultTimeZone()) + val storageTzString = hadoopConf.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY) + val storageTz = Option(storageTzString).map(TimeZone.getTimeZone(_)).getOrElse(sessionTz) new ParquetPrimitiveConverter(updater) { // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { - assert( - value.length() == 12, - "Timestamps (with nanoseconds) are expected to be stored in 12-byte long binaries, " + - s"but got a ${value.length()}-byte binary.") - - val buf = value.toByteBuffer.order(ByteOrder.LITTLE_ENDIAN) - val timeOfDayNanos = buf.getLong - val julianDay = buf.getInt - updater.setLong(DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos)) + val timestamp = ParquetRowConverter.binaryToSQLTimestamp(value, sessionTz = sessionTz, + storageTz = storageTz) + updater.setLong(timestamp) } } @@ -302,7 +307,7 @@ private[parquet] class ParquetRowConverter( case t: StructType => new ParquetRowConverter( - schemaConverter, parquetType.asGroupType(), t, new ParentContainerUpdater { + schemaConverter, parquetType.asGroupType(), t, hadoopConf, new ParentContainerUpdater { override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy()) }) @@ -651,6 +656,7 @@ private[parquet] class ParquetRowConverter( } private[parquet] object ParquetRowConverter { + def binaryToUnscaledLong(binary: Binary): Long = { // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without @@ -673,12 +679,35 @@ private[parquet] object ParquetRowConverter { unscaled } - def binaryToSQLTimestamp(binary: Binary): SQLTimestamp = { + /** + * Converts an int96 to a SQLTimestamp, given both the storage timezone and the local timezone. + * The timestamp is really meant to be interpreted as a "floating time", but since we + * actually store it as micros since epoch, why we have to apply a conversion when timezones + * change. + * + * @param binary a parquet Binary which holds one int96 + * @param sessionTz the session timezone. This will be used to determine how to display the time, + * and compute functions on the timestamp which involve a timezone, eg. extract + * the hour. + * @param storageTz the timezone which was used to store the timestamp. This should come from the + * timestamp table property, or else assume its the same as the sessionTz + * @return a timestamp (millis since epoch) which will render correctly in the sessionTz + */ + def binaryToSQLTimestamp( + binary: Binary, + sessionTz: TimeZone, + storageTz: TimeZone): SQLTimestamp = { assert(binary.length() == 12, s"Timestamps (with nanoseconds) are expected to be stored in" + s" 12-byte long binaries. Found a ${binary.length()}-byte binary instead.") val buffer = binary.toByteBuffer.order(ByteOrder.LITTLE_ENDIAN) val timeOfDayNanos = buffer.getLong val julianDay = buffer.getInt - DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) + val utcEpochMicros = DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) + // avoid expensive time logic if possible. + if (sessionTz.getID() != storageTz.getID()) { + DateTimeUtils.convertTz(utcEpochMicros, sessionTz, storageTz) + } else { + utcEpochMicros + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index 38b0e33937f3c..679ed8e361b74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.{ByteBuffer, ByteOrder} import java.util +import java.util.TimeZone import scala.collection.JavaConverters.mapAsJavaMapConverter @@ -75,6 +76,9 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit // Reusable byte array used to write decimal values private val decimalBuffer = new Array[Byte](minBytesForPrecision(DecimalType.MAX_PRECISION)) + private var storageTz: TimeZone = _ + private var sessionTz: TimeZone = _ + override def init(configuration: Configuration): WriteContext = { val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA) this.schema = StructType.fromString(schemaString) @@ -91,6 +95,19 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit this.rootFieldWriters = schema.map(_.dataType).map(makeWriter) + // If the table has a timezone property, apply the correct conversions. See SPARK-12297. + val sessionTzString = configuration.get(SQLConf.SESSION_LOCAL_TIMEZONE.key) + sessionTz = if (sessionTzString == null || sessionTzString == "") { + TimeZone.getDefault() + } else { + TimeZone.getTimeZone(sessionTzString) + } + val storageTzString = configuration.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY) + storageTz = if (storageTzString == null || storageTzString == "") { + sessionTz + } else { + TimeZone.getTimeZone(storageTzString) + } val messageType = new ParquetSchemaConverter(configuration).convert(schema) val metadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schemaString).asJava @@ -178,7 +195,13 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit // NOTE: Starting from Spark 1.5, Spark SQL `TimestampType` only has microsecond // precision. Nanosecond parts of timestamp values read from INT96 are simply stripped. - val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(row.getLong(ordinal)) + val rawMicros = row.getLong(ordinal) + val adjustedMicros = if (sessionTz.getID() == storageTz.getID()) { + rawMicros + } else { + DateTimeUtils.convertTz(rawMicros, storageTz, sessionTz) + } + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(adjustedMicros) val buf = ByteBuffer.wrap(timestampBuffer) buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) recordConsumer.addBinary(Binary.fromReusedByteArray(timestampBuffer)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ba48facff2933..8fef467f5f5cb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -39,9 +39,10 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ @@ -224,6 +225,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat throw new TableAlreadyExistsException(db = db, table = table) } + val tableTz = tableDefinition.properties.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY) + tableTz.foreach { tz => + if (!DateTimeUtils.isValidTimezone(tz)) { + throw new AnalysisException(s"Cannot set" + + s" ${ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY} to invalid timezone $tz") + } + } + if (tableDefinition.tableType == VIEW) { client.createTable(tableDefinition, ignoreIfExists) } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6b98066cb76c8..e0b565c0d79a0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode._ import org.apache.spark.sql.types._ @@ -174,7 +175,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // We don't support hive bucketed tables, only ones we write out. bucketSpec = None, fileFormat = fileFormat, - options = options)(sparkSession = sparkSession) + options = options ++ getStorageTzOptions(relation))(sparkSession = sparkSession) val created = LogicalRelation(fsRelation, updatedTable) tableRelationCache.put(tableIdentifier, created) created @@ -201,7 +202,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log userSpecifiedSchema = Option(dataSchema), // We don't support hive bucketed tables, only ones we write out. bucketSpec = None, - options = options, + options = options ++ getStorageTzOptions(relation), className = fileType).resolveRelation(), table = updatedTable) @@ -222,6 +223,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log result.copy(output = newOutput) } + private def getStorageTzOptions(relation: CatalogRelation): Map[String, String] = { + // We add the table timezone to the relation options, which automatically gets injected into the + // hadoopConf for the Parquet Converters + val storageTzKey = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + relation.tableMeta.properties.get(storageTzKey).map(storageTzKey -> _).toMap + } + private def inferIfNeeded( relation: CatalogRelation, options: Map[String, String], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 05b6059472f59..2bfd63d9b56e6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -17,12 +17,22 @@ package org.apache.spark.sql.hive +import java.io.File +import java.net.URLDecoder import java.sql.Timestamp +import java.util.TimeZone -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName + +import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.execution.datasources.parquet.{ParquetCompatibilityTest, ParquetFileFormat} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{StringType, StructType, TimestampType} class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHiveSingleton { /** @@ -141,4 +151,369 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHi Row(Seq(Row(1))), "ARRAY>") } + + val testTimezones = Seq( + "UTC" -> "UTC", + "LA" -> "America/Los_Angeles", + "Berlin" -> "Europe/Berlin" + ) + // Check creating parquet tables with timestamps, writing data into them, and reading it back out + // under a variety of conditions: + // * tables with explicit tz and those without + // * altering table properties directly + // * variety of timezones, local & non-local + val sessionTimezones = testTimezones.map(_._2).map(Some(_)) ++ Seq(None) + sessionTimezones.foreach { sessionTzOpt => + val sparkSession = spark.newSession() + sessionTzOpt.foreach { tz => sparkSession.conf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, tz) } + testCreateWriteRead(sparkSession, "no_tz", None, sessionTzOpt) + val localTz = TimeZone.getDefault.getID() + testCreateWriteRead(sparkSession, "local", Some(localTz), sessionTzOpt) + // check with a variety of timezones. The unit tests currently are configured to always use + // America/Los_Angeles, but even if they didn't, we'd be sure to cover a non-local timezone. + testTimezones.foreach { case (tableName, zone) => + if (zone != localTz) { + testCreateWriteRead(sparkSession, tableName, Some(zone), sessionTzOpt) + } + } + } + + private def testCreateWriteRead( + sparkSession: SparkSession, + baseTable: String, + explicitTz: Option[String], + sessionTzOpt: Option[String]): Unit = { + testCreateAlterTablesWithTimezone(sparkSession, baseTable, explicitTz, sessionTzOpt) + testWriteTablesWithTimezone(sparkSession, baseTable, explicitTz, sessionTzOpt) + testReadTablesWithTimezone(sparkSession, baseTable, explicitTz, sessionTzOpt) + } + + private def checkHasTz(spark: SparkSession, table: String, tz: Option[String]): Unit = { + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier(table)) + assert(tableMetadata.properties.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY) === tz) + } + + private def testCreateAlterTablesWithTimezone( + spark: SparkSession, + baseTable: String, + explicitTz: Option[String], + sessionTzOpt: Option[String]): Unit = { + test(s"SPARK-12297: Create and Alter Parquet tables and timezones; explicitTz = $explicitTz; " + + s"sessionTzOpt = $sessionTzOpt") { + val key = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + withTable(baseTable, s"like_$baseTable", s"select_$baseTable", s"partitioned_$baseTable") { + // If we ever add a property to set the table timezone by default, defaultTz would change + val defaultTz = None + // check that created tables have correct TBLPROPERTIES + val tblProperties = explicitTz.map { + tz => s"""TBLPROPERTIES ($key="$tz")""" + }.getOrElse("") + spark.sql( + s"""CREATE TABLE $baseTable ( + | x int + | ) + | STORED AS PARQUET + | $tblProperties + """.stripMargin) + val expectedTableTz = explicitTz.orElse(defaultTz) + checkHasTz(spark, baseTable, expectedTableTz) + spark.sql( + s"""CREATE TABLE partitioned_$baseTable ( + | x int + | ) + | PARTITIONED BY (y int) + | STORED AS PARQUET + | $tblProperties + """.stripMargin) + checkHasTz(spark, s"partitioned_$baseTable", expectedTableTz) + spark.sql(s"CREATE TABLE like_$baseTable LIKE $baseTable") + checkHasTz(spark, s"like_$baseTable", expectedTableTz) + spark.sql( + s"""CREATE TABLE select_$baseTable + | STORED AS PARQUET + | AS + | SELECT * from $baseTable + """.stripMargin) + checkHasTz(spark, s"select_$baseTable", defaultTz) + + // check alter table, setting, unsetting, resetting the property + spark.sql( + s"""ALTER TABLE $baseTable SET TBLPROPERTIES ($key="America/Los_Angeles")""") + checkHasTz(spark, baseTable, Some("America/Los_Angeles")) + spark.sql(s"""ALTER TABLE $baseTable SET TBLPROPERTIES ($key="UTC")""") + checkHasTz(spark, baseTable, Some("UTC")) + spark.sql(s"""ALTER TABLE $baseTable UNSET TBLPROPERTIES ($key)""") + checkHasTz(spark, baseTable, None) + explicitTz.foreach { tz => + spark.sql(s"""ALTER TABLE $baseTable SET TBLPROPERTIES ($key="$tz")""") + checkHasTz(spark, baseTable, expectedTableTz) + } + } + } + } + + val desiredTimestampStrings = Seq( + "2015-12-31 22:49:59.123", + "2015-12-31 23:50:59.123", + "2016-01-01 00:39:59.123", + "2016-01-01 01:29:59.123" + ) + // We don't want to mess with timezones inside the tests themselves, since we use a shared + // spark context, and then we might be prone to issues from lazy vals for timezones. Instead, + // we manually adjust the timezone just to determine what the desired millis (since epoch, in utc) + // is for various "wall-clock" times in different timezones, and then we can compare against those + // in our tests. + val timestampTimezoneToMillis = { + val originalTz = TimeZone.getDefault + try { + desiredTimestampStrings.flatMap { timestampString => + Seq("America/Los_Angeles", "Europe/Berlin", "UTC").map { tzId => + TimeZone.setDefault(TimeZone.getTimeZone(tzId)) + val timestamp = Timestamp.valueOf(timestampString) + (timestampString, tzId) -> timestamp.getTime() + } + }.toMap + } finally { + TimeZone.setDefault(originalTz) + } + } + + private def createRawData(spark: SparkSession): Dataset[(String, Timestamp)] = { + import spark.implicits._ + val df = desiredTimestampStrings.toDF("display") + // this will get the millis corresponding to the display time given the current *session* + // timezone. + df.withColumn("ts", expr("cast(display as timestamp)")).as[(String, Timestamp)] + } + + private def testWriteTablesWithTimezone( + spark: SparkSession, + baseTable: String, + explicitTz: Option[String], + sessionTzOpt: Option[String]) : Unit = { + val key = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + test(s"SPARK-12297: Write to Parquet tables with Timestamps; explicitTz = $explicitTz; " + + s"sessionTzOpt = $sessionTzOpt") { + + withTable(s"saveAsTable_$baseTable", s"insert_$baseTable", s"partitioned_ts_$baseTable") { + val sessionTzId = sessionTzOpt.getOrElse(TimeZone.getDefault().getID()) + // check that created tables have correct TBLPROPERTIES + val tblProperties = explicitTz.map { + tz => s"""TBLPROPERTIES ($key="$tz")""" + }.getOrElse("") + + val rawData = createRawData(spark) + // Check writing data out. + // We write data into our tables, and then check the raw parquet files to see whether + // the correct conversion was applied. + rawData.write.saveAsTable(s"saveAsTable_$baseTable") + checkHasTz(spark, s"saveAsTable_$baseTable", None) + spark.sql( + s"""CREATE TABLE insert_$baseTable ( + | display string, + | ts timestamp + | ) + | STORED AS PARQUET + | $tblProperties + """.stripMargin) + checkHasTz(spark, s"insert_$baseTable", explicitTz) + rawData.write.insertInto(s"insert_$baseTable") + // no matter what, roundtripping via the table should leave the data unchanged + val readFromTable = spark.table(s"insert_$baseTable").collect() + .map { row => (row.getAs[String](0), row.getAs[Timestamp](1)).toString() }.sorted + assert(readFromTable === rawData.collect().map(_.toString()).sorted) + + // Now we load the raw parquet data on disk, and check if it was adjusted correctly. + // Note that we only store the timezone in the table property, so when we read the + // data this way, we're bypassing all of the conversion logic, and reading the raw + // values in the parquet file. + val onDiskLocation = spark.sessionState.catalog + .getTableMetadata(TableIdentifier(s"insert_$baseTable")).location.getPath + // we test reading the data back with and without the vectorized reader, to make sure we + // haven't broken reading parquet from non-hive tables, with both readers. + Seq(false, true).foreach { vectorized => + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, vectorized) + val readFromDisk = spark.read.parquet(onDiskLocation).collect() + val storageTzId = explicitTz.getOrElse(sessionTzId) + readFromDisk.foreach { row => + val displayTime = row.getAs[String](0) + val millis = row.getAs[Timestamp](1).getTime() + val expectedMillis = timestampTimezoneToMillis((displayTime, storageTzId)) + assert(expectedMillis === millis, s"Display time '$displayTime' was stored " + + s"incorrectly with sessionTz = ${sessionTzOpt}; Got $millis, expected " + + s"$expectedMillis (delta = ${millis - expectedMillis})") + } + } + + // check tables partitioned by timestamps. We don't compare the "raw" data in this case, + // since they are adjusted even when we bypass the hive table. + rawData.write.partitionBy("ts").saveAsTable(s"partitioned_ts_$baseTable") + val partitionDiskLocation = spark.sessionState.catalog + .getTableMetadata(TableIdentifier(s"partitioned_ts_$baseTable")).location.getPath + // no matter what mix of timezones we use, the dirs should specify the value with the + // same time we use for display. + val parts = new File(partitionDiskLocation).list().collect { + case name if name.startsWith("ts=") => URLDecoder.decode(name.stripPrefix("ts=")) + }.toSet + assert(parts === desiredTimestampStrings.toSet) + } + } + } + + private def testReadTablesWithTimezone( + spark: SparkSession, + baseTable: String, + explicitTz: Option[String], + sessionTzOpt: Option[String]): Unit = { + val key = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + test(s"SPARK-12297: Read from Parquet tables with Timestamps; explicitTz = $explicitTz; " + + s"sessionTzOpt = $sessionTzOpt") { + withTable(s"external_$baseTable", s"partitioned_$baseTable") { + // we intentionally save this data directly, without creating a table, so we can + // see that the data is read back differently depending on table properties. + // we'll save with adjusted millis, so that it should be the correct millis after reading + // back. + val rawData = createRawData(spark) + // to avoid closing over entire class + val timestampTimezoneToMillis = this.timestampTimezoneToMillis + import spark.implicits._ + val adjustedRawData = (explicitTz match { + case Some(tzId) => + rawData.map { case (displayTime, _) => + val storageMillis = timestampTimezoneToMillis((displayTime, tzId)) + (displayTime, new Timestamp(storageMillis)) + } + case _ => + rawData + }).withColumnRenamed("_1", "display").withColumnRenamed("_2", "ts") + withTempPath { basePath => + val unpartitionedPath = new File(basePath, "flat") + val partitionedPath = new File(basePath, "partitioned") + adjustedRawData.write.parquet(unpartitionedPath.getCanonicalPath) + val options = Map("path" -> unpartitionedPath.getCanonicalPath) ++ + explicitTz.map { tz => Map(key -> tz) }.getOrElse(Map()) + + spark.catalog.createTable( + tableName = s"external_$baseTable", + source = "parquet", + schema = new StructType().add("display", StringType).add("ts", TimestampType), + options = options + ) + + // also write out a partitioned table, to make sure we can access that correctly. + // add a column we can partition by (value doesn't particularly matter). + val partitionedData = adjustedRawData.withColumn("id", monotonicallyIncreasingId) + partitionedData.write.partitionBy("id") + .parquet(partitionedPath.getCanonicalPath) + // unfortunately, catalog.createTable() doesn't let us specify partitioning, so just use + // a "CREATE TABLE" stmt. + val tblOpts = explicitTz.map { tz => s"""TBLPROPERTIES ($key="$tz")""" }.getOrElse("") + spark.sql(s"""CREATE EXTERNAL TABLE partitioned_$baseTable ( + | display string, + | ts timestamp + |) + |PARTITIONED BY (id bigint) + |STORED AS parquet + |LOCATION 'file:${partitionedPath.getCanonicalPath}' + |$tblOpts + """.stripMargin) + spark.sql(s"msck repair table partitioned_$baseTable") + + for { + vectorized <- Seq(false, true) + partitioned <- Seq(false, true) + } { + withClue(s"vectorized = $vectorized; partitioned = $partitioned") { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, vectorized) + val sessionTz = sessionTzOpt.getOrElse(TimeZone.getDefault().getID()) + val table = if (partitioned) s"partitioned_$baseTable" else s"external_$baseTable" + val query = s"select display, cast(ts as string) as ts_as_string, ts " + + s"from $table" + val collectedFromExternal = spark.sql(query).collect() + assert( collectedFromExternal.size === 4) + collectedFromExternal.foreach { row => + val displayTime = row.getAs[String](0) + // the timestamp should still display the same, despite the changes in timezones + assert(displayTime === row.getAs[String](1).toString()) + // we'll also check that the millis behind the timestamp has the appropriate + // adjustments. + val millis = row.getAs[Timestamp](2).getTime() + val expectedMillis = timestampTimezoneToMillis((displayTime, sessionTz)) + val delta = millis - expectedMillis + val deltaHours = delta / (1000L * 60 * 60) + assert(millis === expectedMillis, s"Display time '$displayTime' did not have " + + s"correct millis: was $millis, expected $expectedMillis; delta = $delta " + + s"($deltaHours hours)") + } + + // Now test that the behavior is still correct even with a filter which could get + // pushed down into parquet. We don't need extra handling for pushed down + // predicates because (a) in ParquetFilters, we ignore TimestampType and (b) parquet + // does not read statistics from int96 fields, as they are unsigned. See + // scalastyle:off line.size.limit + // https://github.com/apache/parquet-mr/blob/2fd62ee4d524c270764e9b91dca72e5cf1a005b7/parquet-hadoop/src/main/java/org/apache/parquet/format/converter/ParquetMetadataConverter.java#L419 + // https://github.com/apache/parquet-mr/blob/2fd62ee4d524c270764e9b91dca72e5cf1a005b7/parquet-hadoop/src/main/java/org/apache/parquet/format/converter/ParquetMetadataConverter.java#L348 + // scalastyle:on line.size.limit + // + // Just to be defensive in case anything ever changes in parquet, this test checks + // the assumption on column stats, and also the end-to-end behavior. + + val hadoopConf = sparkContext.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val parts = if (partitioned) { + val subdirs = fs.listStatus(new Path(partitionedPath.getCanonicalPath)) + .filter(_.getPath().getName().startsWith("id=")) + fs.listStatus(subdirs.head.getPath()) + .filter(_.getPath().getName().endsWith(".parquet")) + } else { + fs.listStatus(new Path(unpartitionedPath.getCanonicalPath)) + .filter(_.getPath().getName().endsWith(".parquet")) + } + // grab the meta data from the parquet file. The next section of asserts just make + // sure the test is configured correctly. + assert(parts.size == 1) + val oneFooter = ParquetFileReader.readFooter(hadoopConf, parts.head.getPath) + assert(oneFooter.getFileMetaData.getSchema.getColumns.size === 2) + assert(oneFooter.getFileMetaData.getSchema.getColumns.get(1).getType() === + PrimitiveTypeName.INT96) + val oneBlockMeta = oneFooter.getBlocks().get(0) + val oneBlockColumnMeta = oneBlockMeta.getColumns().get(1) + val columnStats = oneBlockColumnMeta.getStatistics + // This is the important assert. Column stats are written, but they are ignored + // when the data is read back as mentioned above, b/c int96 is unsigned. This + // assert makes sure this holds even if we change parquet versions (if eg. there + // were ever statistics even on unsigned columns). + assert(columnStats.isEmpty) + + // These queries should return the entire dataset, but if the predicates were + // applied to the raw values in parquet, they would incorrectly filter data out. + Seq( + ">" -> "2015-12-31 22:00:00", + "<" -> "2016-01-01 02:00:00" + ).foreach { case (comparison, value) => + val query = + s"select ts from $table where ts $comparison '$value'" + val countWithFilter = spark.sql(query).count() + assert(countWithFilter === 4, query) + } + } + } + } + } + } + } + + test("SPARK-12297: exception on bad timezone") { + val key = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + val badTzException = intercept[AnalysisException] { + spark.sql( + s"""CREATE TABLE bad_tz_table ( + | x int + | ) + | STORED AS PARQUET + | TBLPROPERTIES ($key="Blart Versenwald III") + """.stripMargin) + } + assert(badTzException.getMessage.contains("Blart Versenwald III")) + } } From c24bdaab5a234d18b273544cefc44cc4005bf8fc Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 7 May 2017 23:10:18 -0700 Subject: [PATCH 15/30] [SPARK-20626][SPARKR] address date test warning with timezone on windows ## What changes were proposed in this pull request? set timezone on windows ## How was this patch tested? unit test, AppVeyor Author: Felix Cheung Closes #17892 from felixcheung/rtimestamptest. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 0856bab5686c5..f517ce6713133 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -96,6 +96,10 @@ mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}} mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) +if (.Platform$OS.type == "windows") { + Sys.setenv(TZ = "GMT") +} + test_that("calling sparkRSQL.init returns existing SQL context", { skip_on_cran() From 42cc6d13edbebb7c435ec47c0c12b445e05fdd49 Mon Sep 17 00:00:00 2001 From: sujith71955 Date: Sun, 7 May 2017 23:15:00 -0700 Subject: [PATCH 16/30] [SPARK-20380][SQL] Unable to set/unset table comment property using ALTER TABLE SET/UNSET TBLPROPERTIES ddl ### What changes were proposed in this pull request? Table comment was not getting set/unset using **ALTER TABLE SET/UNSET TBLPROPERTIES** query eg: ALTER TABLE table_with_comment SET TBLPROPERTIES("comment"= "modified comment) when user alter the table properties and adds/updates table comment,table comment which is a field of **CatalogTable** instance is not getting updated and old table comment if exists was shown to user, inorder to handle this issue, update the comment field value in **CatalogTable** with the newly added/modified comment along with other table level properties when user executes **ALTER TABLE SET TBLPROPERTIES** query. This pr has also taken care of unsetting the table comment when user executes query **ALTER TABLE UNSET TBLPROPERTIES** inorder to unset or remove table comment. eg: ALTER TABLE table_comment UNSET TBLPROPERTIES IF EXISTS ('comment') ### How was this patch tested? Added test cases as part of **SQLQueryTestSuite** for verifying table comment using desc formatted table query after adding/modifying table comment as part of **AlterTableSetPropertiesCommand** and unsetting the table comment using **AlterTableUnsetPropertiesCommand**. Author: sujith71955 Closes #17649 from sujith71955/alter_table_comment. --- .../catalyst/catalog/InMemoryCatalog.scala | 8 +- .../spark/sql/execution/command/ddl.scala | 12 +- .../describe-table-after-alter-table.sql | 29 ++++ .../describe-table-after-alter-table.sql.out | 161 ++++++++++++++++++ 4 files changed, 204 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 81dd8efc0015f..8a5319bebe54e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -216,8 +216,8 @@ class InMemoryCatalog( } else { tableDefinition } - - catalog(db).tables.put(table, new TableDesc(tableWithLocation)) + val tableProp = tableWithLocation.properties.filter(_._1 != "comment") + catalog(db).tables.put(table, new TableDesc(tableWithLocation.copy(properties = tableProp))) } } @@ -298,7 +298,9 @@ class InMemoryCatalog( assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) - catalog(db).tables(tableDefinition.identifier.table).table = tableDefinition + val updatedProperties = tableDefinition.properties.filter(kv => kv._1 != "comment") + val newTableDefinition = tableDefinition.copy(properties = updatedProperties) + catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } override def alterTableSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 55540563ef911..793fb9b795596 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -231,8 +231,12 @@ case class AlterTableSetPropertiesCommand( val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) DDLUtils.verifyAlterTableType(catalog, table, isView) - // This overrides old properties - val newTable = table.copy(properties = table.properties ++ properties) + // This overrides old properties and update the comment parameter of CatalogTable + // with the newly added/modified comment since CatalogTable also holds comment as its + // direct property. + val newTable = table.copy( + properties = table.properties ++ properties, + comment = properties.get("comment")) catalog.alterTable(newTable) Seq.empty[Row] } @@ -267,8 +271,10 @@ case class AlterTableUnsetPropertiesCommand( } } } + // If comment is in the table property, we reset it to None + val tableComment = if (propKeys.contains("comment")) None else table.properties.get("comment") val newProperties = table.properties.filter { case (k, _) => !propKeys.contains(k) } - val newTable = table.copy(properties = newProperties) + val newTable = table.copy(properties = newProperties, comment = tableComment) catalog.alterTable(newTable) Seq.empty[Row] } diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql new file mode 100644 index 0000000000000..69bff6656c43a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql @@ -0,0 +1,29 @@ +CREATE TABLE table_with_comment (a STRING, b INT, c STRING, d STRING) USING parquet COMMENT 'added'; + +DESC FORMATTED table_with_comment; + +-- ALTER TABLE BY MODIFYING COMMENT +ALTER TABLE table_with_comment SET TBLPROPERTIES("comment"= "modified comment", "type"= "parquet"); + +DESC FORMATTED table_with_comment; + +-- DROP TEST TABLE +DROP TABLE table_with_comment; + +-- CREATE TABLE WITHOUT COMMENT +CREATE TABLE table_comment (a STRING, b INT) USING parquet; + +DESC FORMATTED table_comment; + +-- ALTER TABLE BY ADDING COMMENT +ALTER TABLE table_comment SET TBLPROPERTIES(comment = "added comment"); + +DESC formatted table_comment; + +-- ALTER UNSET PROPERTIES COMMENT +ALTER TABLE table_comment UNSET TBLPROPERTIES IF EXISTS ('comment'); + +DESC FORMATTED table_comment; + +-- DROP TEST TABLE +DROP TABLE table_comment; diff --git a/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out new file mode 100644 index 0000000000000..1cc11c475bc40 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out @@ -0,0 +1,161 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 12 + + +-- !query 0 +CREATE TABLE table_with_comment (a STRING, b INT, c STRING, d STRING) USING parquet COMMENT 'added' +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +DESC FORMATTED table_with_comment +-- !query 1 schema +struct +-- !query 1 output +# col_name data_type comment +a string +b int +c string +d string + +# Detailed Table Information +Database default +Table table_with_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Comment added +Location [not included in comparison]sql/core/spark-warehouse/table_with_comment + + +-- !query 2 +ALTER TABLE table_with_comment SET TBLPROPERTIES("comment"= "modified comment", "type"= "parquet") +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +DESC FORMATTED table_with_comment +-- !query 3 schema +struct +-- !query 3 output +# col_name data_type comment +a string +b int +c string +d string + +# Detailed Table Information +Database default +Table table_with_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Comment modified comment +Properties [type=parquet] +Location [not included in comparison]sql/core/spark-warehouse/table_with_comment + + +-- !query 4 +DROP TABLE table_with_comment +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TABLE table_comment (a STRING, b INT) USING parquet +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +DESC FORMATTED table_comment +-- !query 6 schema +struct +-- !query 6 output +# col_name data_type comment +a string +b int + +# Detailed Table Information +Database default +Table table_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Location [not included in comparison]sql/core/spark-warehouse/table_comment + + +-- !query 7 +ALTER TABLE table_comment SET TBLPROPERTIES(comment = "added comment") +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +DESC formatted table_comment +-- !query 8 schema +struct +-- !query 8 output +# col_name data_type comment +a string +b int + +# Detailed Table Information +Database default +Table table_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Comment added comment +Location [not included in comparison]sql/core/spark-warehouse/table_comment + + +-- !query 9 +ALTER TABLE table_comment UNSET TBLPROPERTIES IF EXISTS ('comment') +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +DESC FORMATTED table_comment +-- !query 10 schema +struct +-- !query 10 output +# col_name data_type comment +a string +b int + +# Detailed Table Information +Database default +Table table_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Location [not included in comparison]sql/core/spark-warehouse/table_comment + + +-- !query 11 +DROP TABLE table_comment +-- !query 11 schema +struct<> +-- !query 11 output + From 2fdaeb52bbe2ed1a9127ac72917286e505303c85 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Sun, 7 May 2017 23:16:30 -0700 Subject: [PATCH 17/30] [SPARKR][DOC] fix typo in vignettes ## What changes were proposed in this pull request? Fix typo in vignettes Author: Wayne Zhang Closes #17884 from actuaryzhang/typo. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 36 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index d38ec4f1b6f37..49f4ab8f146a8 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -65,7 +65,7 @@ We can view the first few rows of the `SparkDataFrame` by `head` or `showDF` fun head(carsDF) ``` -Common data processing operations such as `filter`, `select` are supported on the `SparkDataFrame`. +Common data processing operations such as `filter` and `select` are supported on the `SparkDataFrame`. ```{r} carsSubDF <- select(carsDF, "model", "mpg", "hp") carsSubDF <- filter(carsSubDF, carsSubDF$hp >= 200) @@ -379,7 +379,7 @@ out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) head(collect(out)) ``` -Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `dapply`, `dapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of the function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} out <- dapplyCollect( @@ -405,7 +405,7 @@ result <- gapply( head(arrange(result, "max_mpg", decreasing = TRUE)) ``` -Like gapply, `gapplyCollect` applies a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `gapply`, `gapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} result <- gapplyCollect( @@ -458,20 +458,20 @@ options(ops) ### SQL Queries -A `SparkDataFrame` can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. +A `SparkDataFrame` can also be registered as a temporary view in Spark SQL so that one can run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. ```{r} people <- read.df(paste0(sparkR.conf("spark.home"), "/examples/src/main/resources/people.json"), "json") ``` -Register this SparkDataFrame as a temporary view. +Register this `SparkDataFrame` as a temporary view. ```{r} createOrReplaceTempView(people, "people") ``` -SQL statements can be run by using the sql method. +SQL statements can be run using the sql method. ```{r} teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") head(teenagers) @@ -780,7 +780,7 @@ head(predict(isoregModel, newDF)) `spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`. Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. -Similar to the random forest example above, we use the `longley` dataset to train a gradient-boosted tree and make predictions: +We use the `longley` dataset to train a gradient-boosted tree and make predictions: ```{r, warning=FALSE} df <- createDataFrame(longley) @@ -820,7 +820,7 @@ head(select(fitted, "Class", "prediction")) `spark.gaussianMixture` fits multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model. -We use a simulated example to demostrate the usage. +We use a simulated example to demonstrate the usage. ```{r} X1 <- data.frame(V1 = rnorm(4), V2 = rnorm(4)) X2 <- data.frame(V1 = rnorm(6, 3), V2 = rnorm(6, 4)) @@ -851,9 +851,9 @@ head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20 * Topics and documents both exist in a feature space, where feature vectors are vectors of word counts (bag of words). -* Rather than estimating a clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. +* Rather than clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. -To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two type options for the column: +To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two options for the column: * character string: This can be a string of the whole document. It will be parsed automatically. Additional stop words can be added in `customizedStopWords`. @@ -901,7 +901,7 @@ perplexity `spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). -There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. +There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. ```{r, eval=FALSE} ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -981,7 +981,7 @@ testSummary ### Model Persistence -The following example shows how to save/load an ML model by SparkR. +The following example shows how to save/load an ML model in SparkR. ```{r} t <- as.data.frame(Titanic) training <- createDataFrame(t) @@ -1079,19 +1079,19 @@ There are three main object classes in SparkR you may be working with. + `sdf` stores a reference to the corresponding Spark Dataset in the Spark JVM backend. + `env` saves the meta-information of the object such as `isCached`. -It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. + It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. -* `Column`: an S4 class representing column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding Column object in the Spark JVM backend. +* `Column`: an S4 class representing a column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding `Column` object in the Spark JVM backend. -It can be obtained from a `SparkDataFrame` by `$` operator, `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. + It can be obtained from a `SparkDataFrame` by `$` operator, e.g., `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. -* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a RelationalGroupedDataset object in the backend. +* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a `RelationalGroupedDataset` object in the backend. -This is often an intermediate object with group information and followed up by aggregation operations. + This is often an intermediate object with group information and followed up by aggregation operations. ### Architecture -A complete description of architecture can be seen in reference, in particular the paper *SparkR: Scaling R Programs with Spark*. +A complete description of architecture can be seen in the references, in particular the paper *SparkR: Scaling R Programs with Spark*. Under the hood of SparkR is Spark SQL engine. This avoids the overheads of running interpreted R code, and the optimized SQL execution engine in Spark uses structural information about data and computation flow to perform a bunch of optimizations to speed up the computation. From 0f820e2b6c507dc4156703862ce65e598ca41cca Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 8 May 2017 10:00:58 +0100 Subject: [PATCH 18/30] [SPARK-20519][SQL][CORE] Modify to prevent some possible runtime exceptions Signed-off-by: liuxian ## What changes were proposed in this pull request? When the input parameter is null, may be a runtime exception occurs ## How was this patch tested? Existing unit tests Author: liuxian Closes #17796 from 10110346/wip_lx_0428. --- .../scala/org/apache/spark/api/python/PythonRDD.scala | 2 +- .../scala/org/apache/spark/deploy/DeployMessage.scala | 8 ++++---- .../scala/org/apache/spark/deploy/master/Master.scala | 2 +- .../org/apache/spark/deploy/master/MasterArguments.scala | 4 ++-- .../org/apache/spark/deploy/master/WorkerInfo.scala | 2 +- .../scala/org/apache/spark/deploy/worker/Worker.scala | 2 +- .../org/apache/spark/deploy/worker/WorkerArguments.scala | 4 ++-- .../main/scala/org/apache/spark/executor/Executor.scala | 2 +- .../scala/org/apache/spark/storage/BlockManagerId.scala | 2 +- core/src/main/scala/org/apache/spark/util/RpcUtils.scala | 2 +- core/src/main/scala/org/apache/spark/util/Utils.scala | 9 +++++---- .../deploy/mesos/MesosClusterDispatcherArguments.scala | 2 +- 12 files changed, 21 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b0dd2fc187baf..fb0405b1a69c6 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -879,7 +879,7 @@ private[spark] class PythonAccumulatorV2( private val serverPort: Int) extends CollectionAccumulator[Array[Byte]] { - Utils.checkHost(serverHost, "Expected hostname") + Utils.checkHost(serverHost) val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index ac09c6c497f8b..b5cb3f0a0f9dc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -43,7 +43,7 @@ private[deploy] object DeployMessages { memory: Int, workerWebUiUrl: String) extends DeployMessage { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) } @@ -131,7 +131,7 @@ private[deploy] object DeployMessages { // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { - Utils.checkHostPort(hostPort, "Required hostport") + Utils.checkHostPort(hostPort) } case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], @@ -183,7 +183,7 @@ private[deploy] object DeployMessages { completedDrivers: Array[DriverInfo], status: MasterState) { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) def uri: String = "spark://" + host + ":" + port @@ -201,7 +201,7 @@ private[deploy] object DeployMessages { drivers: List[DriverRunner], finishedDrivers: List[DriverRunner], masterUrl: String, cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 816bf37e39fee..e061939623cbb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -80,7 +80,7 @@ private[deploy] class Master( private val waitingDrivers = new ArrayBuffer[DriverInfo] private var nextDriverNumber = 0 - Utils.checkHost(address.host, "Expected hostname") + Utils.checkHost(address.host) private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index c63793c16dcef..615d2533cf085 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -60,12 +60,12 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) exte @tailrec private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 4e20c10fd1427..c87d6e24b78c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -32,7 +32,7 @@ private[spark] class WorkerInfo( val webUiAddress: String) extends Serializable { - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(host) assert (port > 0) @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 00b9d1af373db..34e3a4c020c80 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -55,7 +55,7 @@ private[deploy] class Worker( private val host = rpcEnv.address.host private val port = rpcEnv.address.port - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(host) assert (port > 0) // A scheduled executor used to send messages at the specified time. diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 777020d4d5c84..bd07d342e04ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -68,12 +68,12 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { @tailrec private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 51b6c373c4daf..3bc47b670305b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -71,7 +71,7 @@ private[spark] class Executor( private val conf = env.conf // No ip or host:port - just hostname - Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") + Utils.checkHost(executorHostname) // must not have port specified. assert (0 == Utils.parseHostPort(executorHostname)._2) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index c37a3604d28fa..2c3da0ee85e06 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -46,7 +46,7 @@ class BlockManagerId private ( def executorId: String = executorId_ if (null != host_) { - Utils.checkHost(host_, "Expected hostname") + Utils.checkHost(host_) assert (port_ > 0) } diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 46a5cb2cff5a5..e5cccf39f9455 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -28,7 +28,7 @@ private[spark] object RpcUtils { def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { val driverHost: String = conf.get("spark.driver.host", "localhost") val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") + Utils.checkHost(driverHost) rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4d37db96dfc37..edfe229792323 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -937,12 +937,13 @@ private[spark] object Utils extends Logging { customHostname.getOrElse(InetAddresses.toUriString(localIpAddress)) } - def checkHost(host: String, message: String = "") { - assert(host.indexOf(':') == -1, message) + def checkHost(host: String) { + assert(host != null && host.indexOf(':') == -1, s"Expected hostname (not IP) but got $host") } - def checkHostPort(hostPort: String, message: String = "") { - assert(hostPort.indexOf(':') != -1, message) + def checkHostPort(hostPort: String) { + assert(hostPort != null && hostPort.indexOf(':') != -1, + s"Expected host and port but got $hostPort") } // Typically, this will be of order of number of nodes in cluster diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index ef08502ec8dd6..ddea762fdb919 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -59,7 +59,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: @tailrec private def parse(args: List[String]): Unit = args match { case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) From 15526653a93a32cde3c9ea0c0e68e35622b0a590 Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Mon, 8 May 2017 17:33:47 +0800 Subject: [PATCH 19/30] [SPARK-19956][CORE] Optimize a location order of blocks with topology information ## What changes were proposed in this pull request? When call the method getLocations of BlockManager, we only compare the data block host. Random selection for non-local data blocks, this may cause the selected data block to be in a different rack. So in this patch to increase the sort of the rack. ## How was this patch tested? New test case. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xianyang Liu Closes #17300 from ConeyLiu/blockmanager. --- .../apache/spark/storage/BlockManager.scala | 11 +++++-- .../spark/storage/BlockManagerSuite.scala | 31 +++++++++++++++++-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3219969bcd06f..33ce30c58e1ad 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -612,12 +612,19 @@ private[spark] class BlockManager( /** * Return a list of locations for the given block, prioritizing the local machine since - * multiple block managers can share the same host. + * multiple block managers can share the same host, followed by hosts on the same rack. */ private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { val locs = Random.shuffle(master.getLocations(blockId)) val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } - preferredLocs ++ otherLocs + blockManagerId.topologyInfo match { + case None => preferredLocs ++ otherLocs + case Some(_) => + val (sameRackLocs, differentRackLocs) = otherLocs.partition { + loc => blockManagerId.topologyInfo == loc.topologyInfo + } + preferredLocs ++ sameRackLocs ++ differentRackLocs + } } /** diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index a8b9604899838..1e7bcdb6740f6 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -496,8 +496,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } - test("optimize a location order of blocks") { - val localHost = Utils.localHostName() + test("optimize a location order of blocks without topology information") { + val localHost = "localhost" val otherHost = "otherHost" val bmMaster = mock(classOf[BlockManagerMaster]) val bmId1 = BlockManagerId("id1", localHost, 1) @@ -508,7 +508,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockManager = makeBlockManager(128, "exec", bmMaster) val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) - assert(locations.map(_.host).toSet === Set(localHost, localHost, otherHost)) + assert(locations.map(_.host) === Seq(localHost, localHost, otherHost)) + } + + test("optimize a location order of blocks with topology information") { + val localHost = "localhost" + val otherHost = "otherHost" + val localRack = "localRack" + val otherRack = "otherRack" + + val bmMaster = mock(classOf[BlockManagerMaster]) + val bmId1 = BlockManagerId("id1", localHost, 1, Some(localRack)) + val bmId2 = BlockManagerId("id2", localHost, 2, Some(localRack)) + val bmId3 = BlockManagerId("id3", otherHost, 3, Some(otherRack)) + val bmId4 = BlockManagerId("id4", otherHost, 4, Some(otherRack)) + val bmId5 = BlockManagerId("id5", otherHost, 5, Some(localRack)) + when(bmMaster.getLocations(mc.any[BlockId])) + .thenReturn(Seq(bmId1, bmId2, bmId5, bmId3, bmId4)) + + val blockManager = makeBlockManager(128, "exec", bmMaster) + blockManager.blockManagerId = + BlockManagerId(SparkContext.DRIVER_IDENTIFIER, localHost, 1, Some(localRack)) + val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) + val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + assert(locations.map(_.host) === Seq(localHost, localHost, otherHost, otherHost, otherHost)) + assert(locations.flatMap(_.topologyInfo) + === Seq(localRack, localRack, localRack, otherRack, otherRack)) } test("SPARK-9591: getRemoteBytes from another location when Exception throw") { From 58518d070777fc0665c4d02bad8adf910807df98 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Mon, 8 May 2017 12:45:00 +0200 Subject: [PATCH 20/30] [SPARK-20596][ML][TEST] Consolidate and improve ALS recommendAll test cases Existing test cases for `recommendForAllX` methods (added in [SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)) test `k < num items` and `k = num items`. Technically we should also test that `k > num items` returns the same results as `k = num items`. ## How was this patch tested? Updated existing unit tests. Author: Nick Pentreath Closes #17860 from MLnick/SPARK-20596-als-rec-tests. --- .../spark/ml/recommendation/ALSSuite.scala | 63 ++++++++----------- 1 file changed, 25 insertions(+), 38 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 7574af3d77ea8..9d31e792633cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -671,58 +671,45 @@ class ALSSuite .setItemCol("item") } - test("recommendForAllUsers with k < num_items") { - val topItems = getALSModel.recommendForAllUsers(2) - assert(topItems.count() == 3) - assert(topItems.columns.contains("user")) - - val expected = Map( - 0 -> Array((3, 54f), (4, 44f)), - 1 -> Array((3, 39f), (5, 33f)), - 2 -> Array((3, 51f), (5, 45f)) - ) - checkRecommendations(topItems, expected, "item") - } - - test("recommendForAllUsers with k = num_items") { - val topItems = getALSModel.recommendForAllUsers(4) - assert(topItems.count() == 3) - assert(topItems.columns.contains("user")) - + test("recommendForAllUsers with k <, = and > num_items") { + val model = getALSModel + val numUsers = model.userFactors.count + val numItems = model.itemFactors.count val expected = Map( 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) ) - checkRecommendations(topItems, expected, "item") - } - test("recommendForAllItems with k < num_users") { - val topUsers = getALSModel.recommendForAllItems(2) - assert(topUsers.count() == 4) - assert(topUsers.columns.contains("item")) - - val expected = Map( - 3 -> Array((0, 54f), (2, 51f)), - 4 -> Array((0, 44f), (2, 30f)), - 5 -> Array((2, 45f), (0, 42f)), - 6 -> Array((0, 28f), (2, 18f)) - ) - checkRecommendations(topUsers, expected, "user") + Seq(2, 4, 6).foreach { k => + val n = math.min(k, numItems).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topItems = model.recommendForAllUsers(k) + assert(topItems.count() == numUsers) + assert(topItems.columns.contains("user")) + checkRecommendations(topItems, expectedUpToN, "item") + } } - test("recommendForAllItems with k = num_users") { - val topUsers = getALSModel.recommendForAllItems(3) - assert(topUsers.count() == 4) - assert(topUsers.columns.contains("item")) - + test("recommendForAllItems with k <, = and > num_users") { + val model = getALSModel + val numUsers = model.userFactors.count + val numItems = model.itemFactors.count val expected = Map( 3 -> Array((0, 54f), (2, 51f), (1, 39f)), 4 -> Array((0, 44f), (2, 30f), (1, 26f)), 5 -> Array((2, 45f), (0, 42f), (1, 33f)), 6 -> Array((0, 28f), (2, 18f), (1, 16f)) ) - checkRecommendations(topUsers, expected, "user") + + Seq(2, 3, 4).foreach { k => + val n = math.min(k, numUsers).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topUsers = getALSModel.recommendForAllItems(k) + assert(topUsers.count() == numItems) + assert(topUsers.columns.contains("item")) + checkRecommendations(topUsers, expectedUpToN, "user") + } } private def checkRecommendations( From f3b7e0bb9c141058fdbcf202a4b8a47a25237613 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 3 Oct 2016 12:09:18 -0700 Subject: [PATCH 21/30] SHS-NG M1: Add KVStore abstraction, LevelDB implementation. The interface is described in KVIndex.java (see javadoc). Specifics of the LevelDB implementation are discussed in the javadocs of both LevelDB.java and LevelDBTypeInfo.java. Included also are a few small benchmarks just to get some idea of latency. Because they're too slow for regular unit test runs, they're disabled by default. --- common/kvstore/pom.xml | 90 ++++ .../org/apache/spark/kvstore/KVIndex.java | 69 +++ .../org/apache/spark/kvstore/KVStore.java | 137 ++++++ .../apache/spark/kvstore/KVStoreIterator.java | 47 +++ .../spark/kvstore/KVStoreSerializer.java | 70 +++ .../org/apache/spark/kvstore/KVStoreView.java | 91 ++++ .../org/apache/spark/kvstore/LevelDB.java | 239 +++++++++++ .../apache/spark/kvstore/LevelDBIterator.java | 249 +++++++++++ .../apache/spark/kvstore/LevelDBTypeInfo.java | 375 ++++++++++++++++ .../UnsupportedStoreVersionException.java | 27 ++ .../org/apache/spark/kvstore/CustomType1.java | 60 +++ .../spark/kvstore/LevelDBBenchmark.java | 323 ++++++++++++++ .../spark/kvstore/LevelDBIteratorSuite.java | 399 ++++++++++++++++++ .../apache/spark/kvstore/LevelDBSuite.java | 281 ++++++++++++ .../spark/kvstore/LevelDBTypeInfoSuite.java | 177 ++++++++ .../src/test/resources/log4j.properties | 27 ++ pom.xml | 11 + project/SparkBuild.scala | 6 +- 18 files changed, 2675 insertions(+), 3 deletions(-) create mode 100644 common/kvstore/pom.xml create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreIterator.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/UnsupportedStoreVersionException.java create mode 100644 common/kvstore/src/test/java/org/apache/spark/kvstore/CustomType1.java create mode 100644 common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBBenchmark.java create mode 100644 common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java create mode 100644 common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java create mode 100644 common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java create mode 100644 common/kvstore/src/test/resources/log4j.properties diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml new file mode 100644 index 0000000000000..ab296c5b5fb9b --- /dev/null +++ b/common/kvstore/pom.xml @@ -0,0 +1,90 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.3.0-SNAPSHOT + ../../pom.xml + + + spark-kvstore_2.11 + jar + Spark Project Local DB + http://spark.apache.org/ + + kvstore + + + + + com.google.guava + guava + + + org.fusesource.leveldbjni + leveldbjni-all + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + + commons-io + commons-io + test + + + log4j + log4j + test + + + org.slf4j + slf4j-api + test + + + org.slf4j + slf4j-log4j12 + test + + + io.dropwizard.metrics + metrics-core + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java new file mode 100644 index 0000000000000..3c61e7706079a --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Tags a field to be indexed when storing an object. + * + *

+ * Types are required to have a natural index that uniquely identifies instances in the store. + * The default value of the annotation identifies the natural index for the type. + *

+ * + *

+ * Indexes allow for more efficient sorting of data read from the store. By annotating a field or + * "getter" method with this annotation, an index will be created that will provide sorting based on + * the string value of that field. + *

+ * + *

+ * Note that creating indices means more space will be needed, and maintenance operations like + * updating or deleting a value will become more expensive. + *

+ * + *

+ * Indices are restricted to String, and integral types (byte, short, int, long, boolean). + *

+ */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.FIELD, ElementType.METHOD}) +public @interface KVIndex { + + public static final String NATURAL_INDEX_NAME = "__main__"; + + /** + * The name of the index to be created for the annotated entity. Must be unique within + * the class. Index names are not allowed to start with an underscore (that's reserved for + * internal use). The default value is the natural index name (which is always a copy index + * regardless of the annotation's values). + */ + String value() default NATURAL_INDEX_NAME; + + /** + * Whether to copy the instance's data to the index, instead of just storing a pointer to the + * data. The default behavior is to just store a reference; that saves disk space but is slower + * to read, since there's a level of indirection. + */ + boolean copy() default false; + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java new file mode 100644 index 0000000000000..31d4e6fefc289 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.Closeable; +import java.util.Iterator; +import java.util.Map; + +/** + * Abstraction for a local key/value store for storing app data. + * + *

+ * Use {@link KVStoreBuilder} to create an instance. There are two main features provided by the + * implementations of this interface: + *

+ * + *
    + *
  • serialization: this feature is not optional; data will be serialized to and deserialized + * from the underlying data store using a {@link KVStoreSerializer}, which can be customized by + * the application. The serializer is based on Jackson, so it supports all the Jackson annotations + * for controlling the serialization of app-defined types.
  • + * + *
  • key management: by using {@link #read(Class, Object)} and {@link #write(Class, Object)}, + * applications can leave key management to the implementation. For applications that want to + * manage their own keys, the {@link #get(byte[], Class)} and {@link #set(byte[], Object)} methods + * are available.
  • + *
+ * + *

Automatic Key Management

+ * + *

+ * When using the built-in key management, the implementation will automatically create unique + * keys for each type written to the store. Keys are based on the type name, and always start + * with the "+" prefix character (so that it's easy to use both manual and automatic key + * management APIs without conflicts). + *

+ * + *

+ * Another feature of automatic key management is indexing; by annotating fields or methods of + * objects written to the store with {@link KVIndex}, indices are created to sort the data + * by the values of those properties. This makes it possible to provide sorting without having + * to load all instances of those types from the store. + *

+ * + *

+ * KVStore instances are thread-safe for both reads and writes. + *

+ */ +public interface KVStore extends Closeable { + + /** + * Returns app-specific metadata from the store, or null if it's not currently set. + * + *

+ * The metadata type is application-specific. This is a convenience method so that applications + * don't need to define their own keys for this information. + *

+ */ + T getMetadata(Class klass) throws Exception; + + /** + * Writes the given value in the store metadata key. + */ + void setMetadata(Object value) throws Exception; + + /** + * Returns the value of a specific key, deserialized to the given type. + */ + T get(byte[] key, Class klass) throws Exception; + + /** + * Write a single key directly to the store, atomically. + */ + void put(byte[] key, Object value) throws Exception; + + /** + * Removes a key from the store. + */ + void delete(byte[] key) throws Exception; + + /** + * Returns an iterator that will only list values with keys starting with the given prefix. + */ + KVStoreIterator iterator(byte[] prefix, Class klass) throws Exception; + + /** + * Read a specific instance of an object. + */ + T read(Class klass, Object naturalKey) throws Exception; + + /** + * Writes the given object to the store, including indexed fields. Indices are updated based + * on the annotated fields of the object's class. + * + *

+ * Writes may be slower when the object already exists in the store, since it will involve + * updating existing indices. + *

+ * + * @param value The object to write. + */ + void write(Object value) throws Exception; + + /** + * Removes an object and all data related to it, like index entries, from the store. + * + * @param type The object's type. + * @param naturalKey The object's "natural key", which uniquely identifies it. + */ + void delete(Class type, Object naturalKey) throws Exception; + + /** + * Returns a configurable view for iterating over entities of the given type. + */ + KVStoreView view(Class type) throws Exception; + + /** + * Returns the number of items of the given type currently in the store. + */ + long count(Class type) throws Exception; + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreIterator.java new file mode 100644 index 0000000000000..3efdec9ed32be --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreIterator.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.util.Iterator; +import java.util.List; + +/** + * An iterator for KVStore. + * + *

+ * Iterators may keep references to resources that need to be closed. It's recommended that users + * explicitly close iterators after they're used. + *

+ */ +public interface KVStoreIterator extends Iterator, AutoCloseable { + + /** + * Retrieve multiple elements from the store. + * + * @param max Maximum number of elements to retrieve. + */ + List next(int max); + + /** + * Skip in the iterator. + * + * @return Whether there are items left after skipping. + */ + boolean skip(long n); + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java new file mode 100644 index 0000000000000..d9f9e2646cc14 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * Serializer used to translate between app-defined types and the LevelDB store. + * + *

+ * The serializer is based on Jackson, so values are written as JSON. It also allows "naked strings" + * and integers to be written as values directly, which will be written as UTF-8 strings. + *

+ */ +public class KVStoreSerializer { + + /** + * Object mapper used to process app-specific types. If an application requires a specific + * configuration of the mapper, it can subclass this serializer and add custom configuration + * to this object. + */ + protected final ObjectMapper mapper; + + public KVStoreSerializer() { + this.mapper = new ObjectMapper(); + } + + public final byte[] serialize(Object o) throws Exception { + if (o instanceof String) { + return ((String) o).getBytes(UTF_8); + } else { + return mapper.writeValueAsBytes(o); + } + } + + @SuppressWarnings("unchecked") + public final T deserialize(byte[] data, Class klass) throws Exception { + if (klass.equals(String.class)) { + return (T) new String(data, UTF_8); + } else { + return mapper.readValue(data, klass); + } + } + + final byte[] serialize(long value) { + return String.valueOf(value).getBytes(UTF_8); + } + + final long deserializeLong(byte[] data) { + return Long.parseLong(new String(data, UTF_8)); + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java new file mode 100644 index 0000000000000..a68c37942dee4 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.util.Iterator; +import java.util.Map; + +import com.google.common.base.Preconditions; + +/** + * A configurable view that allows iterating over values in a {@link KVStore}. + * + *

+ * The different methods can be used to configure the behavior of the iterator. Calling the same + * method multiple times is allowed; the most recent value will be used. + *

+ * + *

+ * The iterators returns by this view are of type {@link KVStoreIterator}; they auto-close + * when used in a for loop that exhausts their contents, but when used manually, they need + * to be closed explicitly unless all elements are read. + *

+ */ +public abstract class KVStoreView implements Iterable { + + final Class type; + + boolean ascending = true; + String index = KVIndex.NATURAL_INDEX_NAME; + Object first = null; + long skip = 0L; + + public KVStoreView(Class type) { + this.type = type; + } + + /** + * Reverses the order of iteration. By default, iterates in ascending order. + */ + public KVStoreView reverse() { + ascending = !ascending; + return this; + } + + /** + * Iterates according to the given index. + */ + public KVStoreView index(String name) { + this.index = Preconditions.checkNotNull(name); + return this; + } + + /** + * Iterates starting at the given value of the chosen index. + */ + public KVStoreView first(Object value) { + this.first = value; + return this; + } + + /** + * Skips a number of elements in the resulting iterator. + */ + public KVStoreView skip(long n) { + this.skip = n; + return this; + } + + /** + * Returns an iterator for the current configuration. + */ + public KVStoreIterator closeableIterator() throws Exception { + return (KVStoreIterator) iterator(); + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java new file mode 100644 index 0000000000000..51287c02ebab1 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.fusesource.leveldbjni.JniDBFactory; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.Options; +import org.iq80.leveldb.WriteBatch; +import org.iq80.leveldb.WriteOptions; + +/** + * Implementation of KVStore that uses LevelDB as the underlying data store. + */ +public class LevelDB implements KVStore { + + @VisibleForTesting + static final long STORE_VERSION = 1L; + + @VisibleForTesting + static final byte[] STORE_VERSION_KEY = "__version__".getBytes(UTF_8); + + /** DB key where app metadata is stored. */ + private static final byte[] METADATA_KEY = "__meta__".getBytes(UTF_8); + + final DB db; + final KVStoreSerializer serializer; + + private final ConcurrentMap, LevelDBTypeInfo> types; + private boolean closed; + + public LevelDB(File path) throws IOException { + this(path, new KVStoreSerializer()); + } + + public LevelDB(File path, KVStoreSerializer serializer) throws IOException { + this.serializer = serializer; + this.types = new ConcurrentHashMap<>(); + + Options options = new Options(); + options.createIfMissing(!path.exists()); + this.db = JniDBFactory.factory.open(path, options); + + byte[] versionData = db.get(STORE_VERSION_KEY); + if (versionData != null) { + long version = serializer.deserializeLong(versionData); + if (version != STORE_VERSION) { + throw new UnsupportedStoreVersionException(); + } + } else { + db.put(STORE_VERSION_KEY, serializer.serialize(STORE_VERSION)); + } + } + + @Override + public T getMetadata(Class klass) throws Exception { + try { + return get(METADATA_KEY, klass); + } catch (NoSuchElementException nsee) { + return null; + } + } + + @Override + public void setMetadata(Object value) throws Exception { + if (value != null) { + put(METADATA_KEY, value); + } else { + db.delete(METADATA_KEY); + } + } + + @Override + public T get(byte[] key, Class klass) throws Exception { + byte[] data = db.get(key); + if (data == null) { + throw new NoSuchElementException(new String(key, UTF_8)); + } + return serializer.deserialize(data, klass); + } + + @Override + public void put(byte[] key, Object value) throws Exception { + Preconditions.checkArgument(value != null, "Null values are not allowed."); + db.put(key, serializer.serialize(value)); + } + + @Override + public void delete(byte[] key) throws Exception { + db.delete(key); + } + + @Override + public KVStoreIterator iterator(byte[] prefix, Class klass) throws Exception { + throw new UnsupportedOperationException(); + } + + @Override + public T read(Class klass, Object naturalKey) throws Exception { + Preconditions.checkArgument(naturalKey != null, "Null keys are not allowed."); + byte[] key = getTypeInfo(klass).naturalIndex().start(naturalKey); + return get(key, klass); + } + + @Override + public void write(Object value) throws Exception { + write(value, false); + } + + public void write(Object value, boolean sync) throws Exception { + Preconditions.checkArgument(value != null, "Null values are not allowed."); + LevelDBTypeInfo ti = getTypeInfo(value.getClass()); + + WriteBatch batch = db.createWriteBatch(); + try { + byte[] data = serializer.serialize(value); + synchronized (ti) { + try { + Object existing = get(ti.naturalIndex().entityKey(value), value.getClass()); + removeInstance(ti, batch, existing); + } catch (NoSuchElementException e) { + // Ignore. No previous value. + } + for (LevelDBTypeInfo.Index idx : ti.indices()) { + idx.add(batch, value, data); + } + db.write(batch, new WriteOptions().sync(sync)); + } + } finally { + batch.close(); + } + } + + @Override + public void delete(Class type, Object naturalKey) throws Exception { + delete(type, naturalKey, false); + } + + public void delete(Class type, Object naturalKey, boolean sync) throws Exception { + Preconditions.checkArgument(naturalKey != null, "Null keys are not allowed."); + WriteBatch batch = db.createWriteBatch(); + try { + LevelDBTypeInfo ti = getTypeInfo(type); + byte[] key = ti.naturalIndex().start(naturalKey); + byte[] data = db.get(key); + if (data != null) { + Object existing = serializer.deserialize(data, type); + synchronized (ti) { + removeInstance(ti, batch, existing); + db.write(batch, new WriteOptions().sync(sync)); + } + } + } finally { + batch.close(); + } + } + + @Override + public KVStoreView view(Class type) throws Exception { + return new KVStoreView(type) { + @Override + public Iterator iterator() { + try { + return new LevelDBIterator<>(LevelDB.this, this); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + }; + } + + @Override + public long count(Class type) throws Exception { + LevelDBTypeInfo.Index idx = getTypeInfo(type).naturalIndex(); + return idx.getCount(idx.end()); + } + + @Override + public synchronized void close() throws IOException { + if (closed) { + return; + } + + try { + db.close(); + closed = true; + } catch (IOException ioe) { + throw ioe; + } catch (Exception e) { + throw new IOException(e.getMessage(), e); + } + } + + /** Returns metadata about indices for the given type. */ + LevelDBTypeInfo getTypeInfo(Class type) throws Exception { + LevelDBTypeInfo idx = types.get(type); + if (idx == null) { + LevelDBTypeInfo tmp = new LevelDBTypeInfo<>(this, type); + idx = types.putIfAbsent(type, tmp); + if (idx == null) { + idx = tmp; + } + } + return idx; + } + + private void removeInstance(LevelDBTypeInfo ti, WriteBatch batch, Object instance) + throws Exception { + for (LevelDBTypeInfo.Index idx : ti.indices()) { + idx.remove(batch, instance); + } + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java new file mode 100644 index 0000000000000..d0b6e25420812 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.IOException; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Throwables; +import org.iq80.leveldb.DBIterator; + +class LevelDBIterator implements KVStoreIterator { + + private final LevelDB db; + private final boolean ascending; + private final DBIterator it; + private final Class type; + private final LevelDBTypeInfo ti; + private final LevelDBTypeInfo.Index index; + private final byte[] indexKeyPrefix; + private final byte[] end; + + private boolean checkedNext; + private T next; + private boolean closed; + + /** + * Creates a simple iterator over db keys. + */ + LevelDBIterator(LevelDB db, byte[] keyPrefix, Class type) throws Exception { + this.db = db; + this.ascending = true; + this.type = type; + this.ti = null; + this.index = null; + this.it = db.db.iterator(); + this.indexKeyPrefix = keyPrefix; + this.end = null; + it.seek(keyPrefix); + } + + /** + * Creates an iterator for indexed types (i.e., those whose keys are managed by the library). + */ + LevelDBIterator(LevelDB db, KVStoreView params) throws Exception { + this.db = db; + this.ascending = params.ascending; + this.it = db.db.iterator(); + this.type = params.type; + this.ti = db.getTypeInfo(type); + this.index = ti.index(params.index); + this.indexKeyPrefix = index.keyPrefix(); + + byte[] firstKey; + if (params.first != null) { + if (ascending) { + firstKey = index.start(params.first); + } else { + firstKey = index.end(params.first); + } + } else if (ascending) { + firstKey = index.keyPrefix(); + } else { + firstKey = index.end(); + } + it.seek(firstKey); + + if (ascending) { + this.end = index.end(); + } else { + this.end = null; + if (it.hasNext()) { + it.next(); + } + } + + if (params.skip > 0) { + skip(params.skip); + } + } + + @Override + public boolean hasNext() { + if (!checkedNext && !closed) { + next = loadNext(); + checkedNext = true; + } + if (!closed && next == null) { + try { + close(); + } catch (IOException ioe) { + throw Throwables.propagate(ioe); + } + } + return next != null; + } + + @Override + public T next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + checkedNext = false; + return next; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public List next(int max) { + List list = new ArrayList<>(max); + while (hasNext() && list.size() < max) { + list.add(next()); + } + return list; + } + + @Override + public boolean skip(long n) { + long skipped = 0; + while (skipped < n) { + next = null; + boolean hasNext = ascending ? it.hasNext() : it.hasPrev(); + if (!hasNext) { + return false; + } + + Map.Entry e = ascending ? it.next() : it.prev(); + if (!isEndMarker(e.getKey())) { + skipped++; + } + } + + return true; + } + + @Override + public void close() throws IOException { + if (!closed) { + it.close(); + closed = true; + } + } + + private T loadNext() { + try { + while (true) { + boolean hasNext = ascending ? it.hasNext() : it.hasPrev(); + if (!hasNext) { + return null; + } + + Map.Entry nextEntry; + try { + // Avoid races if another thread is updating the DB. + nextEntry = ascending ? it.next() : it.prev(); + } catch (NoSuchElementException e) { + return null; + } + byte[] nextKey = nextEntry.getKey(); + + // If the next key is an end marker, then skip it. + if (isEndMarker(nextKey)) { + continue; + } + + // Next key is not part of the index, stop. + if (!startsWith(nextKey, indexKeyPrefix)) { + return null; + } + + // If there's a known end key and it's found, stop. + if (end != null && Arrays.equals(nextKey, end)) { + return null; + } + + // Next element is part of the iteration, return it. + if (index == null || index.isCopy()) { + return db.serializer.deserialize(nextEntry.getValue(), type); + } else { + byte[] key = stitch(ti.naturalIndex().keyPrefix(), nextEntry.getValue()); + return db.get(key, type); + } + } + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + @VisibleForTesting + static boolean startsWith(byte[] key, byte[] prefix) { + if (key.length < prefix.length) { + return false; + } + + for (int i = 0; i < prefix.length; i++) { + if (key[i] != prefix[i]) { + return false; + } + } + + return true; + } + + private boolean isEndMarker(byte[] key) { + return (key.length > 2 && + key[key.length - 2] == LevelDBTypeInfo.KEY_SEPARATOR && + key[key.length - 1] == (byte) LevelDBTypeInfo.END_MARKER.charAt(0)); + } + + private byte[] stitch(byte[]... comps) { + int len = 0; + for (byte[] comp : comps) { + len += comp.length; + } + + byte[] dest = new byte[len]; + int written = 0; + for (byte[] comp : comps) { + System.arraycopy(comp, 0, dest, written, comp.length); + written += comp.length; + } + + return dest; + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java new file mode 100644 index 0000000000000..c49b18324c00a --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java @@ -0,0 +1,375 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.iq80.leveldb.WriteBatch; + +/** + * Holds metadata about app-specific types stored in LevelDB. Serves as a cache for data collected + * via reflection, to make it cheaper to access it multiple times. + */ +class LevelDBTypeInfo { + + static final String ENTRY_PREFIX = "+"; + static final String END_MARKER = "-"; + static final byte KEY_SEPARATOR = 0x0; + + // These constants are used in the Index.toKey() method below when encoding numbers into keys. + // See javadoc for that method for details. + private static final char POSITIVE_FILL = '.'; + private static final char NEGATIVE_FILL = '~'; + private static final char POSITIVE_MARKER = '='; + private static final char NEGATIVE_MARKER = '*'; + + @VisibleForTesting + static final int BYTE_ENCODED_LEN = String.valueOf(Byte.MAX_VALUE).length() + 1; + @VisibleForTesting + static final int INT_ENCODED_LEN = String.valueOf(Integer.MAX_VALUE).length() + 1; + @VisibleForTesting + static final int LONG_ENCODED_LEN = String.valueOf(Long.MAX_VALUE).length() + 1; + @VisibleForTesting + static final int SHORT_ENCODED_LEN = String.valueOf(Short.MAX_VALUE).length() + 1; + + private final LevelDB db; + private final Class type; + private final Map indices; + private final byte[] typePrefix; + + LevelDBTypeInfo(LevelDB db, Class type) throws Exception { + this.db = db; + this.type = type; + this.indices = new HashMap<>(); + + for (Field f : type.getFields()) { + KVIndex idx = f.getAnnotation(KVIndex.class); + if (idx != null) { + register(idx, new FieldAccessor(f)); + } + } + + for (Method m : type.getMethods()) { + KVIndex idx = m.getAnnotation(KVIndex.class); + if (idx != null) { + Preconditions.checkArgument(m.getParameterTypes().length == 0, + "Annotated method %s::%s should not have any parameters.", type.getName(), m.getName()); + register(idx, new MethodAccessor(m)); + } + } + + Preconditions.checkArgument(indices.get(KVIndex.NATURAL_INDEX_NAME) != null, + "No natural index defined for type %s.", type.getName()); + + ByteArrayOutputStream typePrefix = new ByteArrayOutputStream(); + typePrefix.write(utf8(ENTRY_PREFIX)); + + // Change fully-qualified class names to make keys more spread out by placing the + // class name first, and the package name afterwards. + String[] components = type.getName().split("\\."); + typePrefix.write(utf8(components[components.length - 1])); + if (components.length > 1) { + typePrefix.write(utf8("/")); + } + for (int i = 0; i < components.length - 1; i++) { + typePrefix.write(utf8(components[i])); + if (i < components.length - 2) { + typePrefix.write(utf8(".")); + } + } + typePrefix.write(KEY_SEPARATOR); + this.typePrefix = typePrefix.toByteArray(); + } + + private void register(KVIndex idx, Accessor accessor) { + Preconditions.checkArgument(idx.value() != null && !idx.value().isEmpty(), + "No name provided for index in type %s.", type.getName()); + Preconditions.checkArgument( + !idx.value().startsWith("_") || idx.value().equals(KVIndex.NATURAL_INDEX_NAME), + "Index name %s (in type %s) is not allowed.", idx.value(), type.getName()); + Preconditions.checkArgument(indices.get(idx.value()) == null, + "Duplicate index %s for type %s.", idx.value(), type.getName()); + indices.put(idx.value(), new Index(idx.value(), idx.copy(), accessor)); + } + + Class type() { + return type; + } + + byte[] keyPrefix() { + return buildKey(false); + } + + Index naturalIndex() { + return index(KVIndex.NATURAL_INDEX_NAME); + } + + Index index(String name) { + Index i = indices.get(name); + Preconditions.checkArgument(i != null, "Index %s does not exist for type %s.", name, + type.getName()); + return i; + } + + Collection indices() { + return indices.values(); + } + + private byte[] utf8(String s) { + return s.getBytes(UTF_8); + } + + private byte[] buildKey(boolean trim, String... components) { + try { + ByteArrayOutputStream kos = new ByteArrayOutputStream(typePrefix.length * 2); + kos.write(typePrefix); + for (int i = 0; i < components.length; i++) { + kos.write(utf8(components[i])); + if (!trim || i < components.length - 1) { + kos.write(KEY_SEPARATOR); + } + } + return kos.toByteArray(); + } catch (IOException ioe) { + throw Throwables.propagate(ioe); + } + } + + /** + * Models a single index in LevelDB. Keys are stored under the type's prefix, in sequential + * order according to the indexed value. For non-natural indices, the key also contains the + * entity's natural key after the indexed value, so that it's possible for multiple entities + * to have the same indexed value. + * + *

+ * An end marker is used to mark where the index ends, and the boundaries of each indexed value + * within the index, to make descending iteration faster, at the expense of some disk space and + * minor overhead when iterating. A count of the number of indexed entities is kept at the end + * marker, so that it can be cleaned up when all entries are removed from the index. + *

+ */ + class Index { + + private final boolean copy; + private final boolean isNatural; + private final String name; + + @VisibleForTesting + final Accessor accessor; + + private Index(String name, boolean copy, Accessor accessor) { + this.name = name; + this.isNatural = name.equals(KVIndex.NATURAL_INDEX_NAME); + this.copy = isNatural || copy; + this.accessor = accessor; + } + + boolean isCopy() { + return copy; + } + + /** The prefix for all keys that belong to this index. */ + byte[] keyPrefix() { + return buildKey(false, name); + } + + /** The key where to start ascending iteration for entries that match the given value. */ + byte[] start(Object value) { + return buildKey(isNatural, name, toKey(value)); + } + + /** The key for the index's end marker. */ + byte[] end() { + return buildKey(true, name, END_MARKER); + } + + /** The key for the end marker for index entries with the given value. */ + byte[] end(Object value) throws Exception { + return buildKey(true, name, toKey(value), END_MARKER); + } + + /** The key in the index that identifies the given entity. */ + byte[] entityKey(Object entity) throws Exception { + Object indexValue = accessor.get(entity); + Preconditions.checkNotNull(indexValue, "Null index value for %s in type %s.", + name, type.getName()); + if (isNatural) { + return buildKey(true, name, toKey(indexValue)); + } else { + Object naturalKey = naturalIndex().accessor.get(entity); + return buildKey(true, name, toKey(accessor.get(entity)), toKey(naturalKey)); + } + } + + /** + * Add an entry to the index. + * + * @param batch Write batch with other related changes. + * @param entity The entity being added to the index. + * @param data Serialized entity to store (when storing the entity, not a reference). + * @param naturalKey The value's key. + */ + void add(WriteBatch batch, Object entity, byte[] data) throws Exception { + byte[] stored = data; + if (!copy) { + stored = db.serializer.serialize(toKey(naturalIndex().accessor.get(entity))); + } + batch.put(entityKey(entity), stored); + updateCount(batch, end(accessor.get(entity)), 1L); + updateCount(batch, end(), 1L); + } + + /** + * Remove a value from the index. + * + * @param batch Write batch with other related changes. + * @param entity The entity being removed, to identify the index entry to modify. + * @param naturalKey The value's key. + */ + void remove(WriteBatch batch, Object entity) throws Exception { + batch.delete(entityKey(entity)); + updateCount(batch, end(accessor.get(entity)), -1L); + updateCount(batch, end(), -1L); + } + + long getCount(byte[] key) throws Exception { + byte[] data = db.db.get(key); + return data != null ? db.serializer.deserializeLong(data) : 0; + } + + private void updateCount(WriteBatch batch, byte[] key, long delta) throws Exception { + long count = getCount(key) + delta; + if (count > 0) { + batch.put(key, db.serializer.serialize(count)); + } else { + batch.delete(key); + } + } + + /** + * Translates a value to be used as part of the store key. + * + * Integral numbers are encoded as a string in a way that preserves lexicographical + * ordering. The string is always as long as the maximum value for the given type (e.g. + * 11 characters for integers, including the character for the sign). The first character + * represents the sign (with the character for negative coming before the one for positive, + * which means you cannot use '-'...). The rest of the value is padded with a value that is + * "greater than 9" for negative values, so that for example "-123" comes before "-12" (the + * encoded value would look like "*~~~~~~~123"). For positive values, similarly, a value that + * is "lower than 0" (".") is used for padding. The fill characters were chosen for readability + * when looking at the encoded keys. + */ + @VisibleForTesting + String toKey(Object value) { + StringBuilder sb = new StringBuilder(ENTRY_PREFIX); + + if (value instanceof String) { + sb.append(value); + } else if (value instanceof Boolean) { + sb.append(((Boolean) value).toString().toLowerCase()); + } else { + int encodedLen; + + if (value instanceof Integer) { + encodedLen = INT_ENCODED_LEN; + } else if (value instanceof Long) { + encodedLen = LONG_ENCODED_LEN; + } else if (value instanceof Short) { + encodedLen = SHORT_ENCODED_LEN; + } else if (value instanceof Byte) { + encodedLen = BYTE_ENCODED_LEN; + } else { + throw new IllegalArgumentException(String.format("Type %s not allowed as key.", + value.getClass().getName())); + } + + long longValue = ((Number) value).longValue(); + String strVal; + if (longValue == Long.MIN_VALUE) { + // Math.abs() overflows for Long.MIN_VALUE. + strVal = String.valueOf(longValue).substring(1); + } else { + strVal = String.valueOf(Math.abs(longValue)); + } + + sb.append(longValue >= 0 ? POSITIVE_MARKER : NEGATIVE_MARKER); + + char fill = longValue >= 0 ? POSITIVE_FILL : NEGATIVE_FILL; + for (int i = 0; i < encodedLen - strVal.length() - 1; i++) { + sb.append(fill); + } + + sb.append(strVal); + } + + return sb.toString(); + } + + } + + /** + * Abstracts the difference between invoking a Field and a Method. + */ + @VisibleForTesting + interface Accessor { + + Object get(Object instance) throws Exception; + + } + + private class FieldAccessor implements Accessor { + + private final Field field; + + FieldAccessor(Field field) { + this.field = field; + } + + @Override + public Object get(Object instance) throws Exception { + return field.get(instance); + } + + } + + private class MethodAccessor implements Accessor { + + private final Method method; + + MethodAccessor(Method method) { + this.method = method; + } + + @Override + public Object get(Object instance) throws Exception { + return method.invoke(instance); + } + + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/UnsupportedStoreVersionException.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/UnsupportedStoreVersionException.java new file mode 100644 index 0000000000000..2ed246e4f4c97 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/UnsupportedStoreVersionException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.IOException; + +/** + * Exception thrown when the store implementation is not compatible with the underlying data. + */ +public class UnsupportedStoreVersionException extends IOException { + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/CustomType1.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/CustomType1.java new file mode 100644 index 0000000000000..2bea5b560681f --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/CustomType1.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import com.google.common.base.Objects; + +public class CustomType1 { + + @KVIndex + public String key; + + @KVIndex("id") + public String id; + + @KVIndex(value = "name", copy = true) + public String name; + + @KVIndex("int") + public int num; + + @Override + public boolean equals(Object o) { + if (o instanceof CustomType1) { + CustomType1 other = (CustomType1) o; + return id.equals(other.id) && name.equals(other.name); + } + return false; + } + + @Override + public int hashCode() { + return id.hashCode(); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("key", key) + .add("id", id) + .add("name", name) + .add("num", num) + .toString(); + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBBenchmark.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBBenchmark.java new file mode 100644 index 0000000000000..aecea26ec82f3 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBBenchmark.java @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.Slf4jReporter; +import com.codahale.metrics.Snapshot; +import com.codahale.metrics.Timer; +import org.apache.commons.io.FileUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.slf4j.LoggerFactory; +import static org.junit.Assert.*; + +/** + * A set of small benchmarks for the LevelDB implementation. + * + * The benchmarks are run over two different types (one with just a natural index, and one + * with a ref index), over a set of 2^20 elements, and the following tests are performed: + * + * - write (then update) elements in sequential natural key order + * - write (then update) elements in random natural key order + * - iterate over natural index, ascending and descending + * - iterate over ref index, ascending and descending + */ +@Ignore +public class LevelDBBenchmark { + + private static final int COUNT = 1024; + private static final AtomicInteger IDGEN = new AtomicInteger(); + private static final MetricRegistry metrics = new MetricRegistry(); + private static final Timer dbCreation = metrics.timer("dbCreation"); + private static final Timer dbClose = metrics.timer("dbClose"); + + private LevelDB db; + private File dbpath; + + @Before + public void setup() throws Exception { + dbpath = File.createTempFile("test.", ".ldb"); + dbpath.delete(); + try(Timer.Context ctx = dbCreation.time()) { + db = new LevelDB(dbpath); + } + } + + @After + public void cleanup() throws Exception { + if (db != null) { + try(Timer.Context ctx = dbClose.time()) { + db.close(); + } + } + if (dbpath != null) { + FileUtils.deleteQuietly(dbpath); + } + } + + @AfterClass + public static void report() { + if (metrics.getTimers().isEmpty()) { + return; + } + + int headingPrefix = 0; + for (Map.Entry e : metrics.getTimers().entrySet()) { + headingPrefix = Math.max(e.getKey().length(), headingPrefix); + } + headingPrefix += 4; + + StringBuilder heading = new StringBuilder(); + for (int i = 0; i < headingPrefix; i++) { + heading.append(" "); + } + heading.append("\tcount"); + heading.append("\tmean"); + heading.append("\tmin"); + heading.append("\tmax"); + heading.append("\t95th"); + System.out.println(heading); + + for (Map.Entry e : metrics.getTimers().entrySet()) { + StringBuilder row = new StringBuilder(); + row.append(e.getKey()); + for (int i = 0; i < headingPrefix - e.getKey().length(); i++) { + row.append(" "); + } + + Snapshot s = e.getValue().getSnapshot(); + row.append("\t").append(e.getValue().getCount()); + row.append("\t").append(toMs(s.getMean())); + row.append("\t").append(toMs(s.getMin())); + row.append("\t").append(toMs(s.getMax())); + row.append("\t").append(toMs(s.get95thPercentile())); + + System.out.println(row); + } + + Slf4jReporter.forRegistry(metrics).outputTo(LoggerFactory.getLogger(LevelDBBenchmark.class)) + .build().report(); + } + + private static String toMs(double nanos) { + return String.format("%.3f", nanos / 1000 / 1000); + } + + @Test + public void sequentialWritesNoIndex() throws Exception { + List entries = createSimpleType(); + writeAll(entries, false, "sequentialWritesNoIndex"); + writeAll(entries, false, "sequentialUpdatesNoIndex"); + deleteNoIndex(entries, false, "sequentialDeleteNoIndex"); + } + + @Test + public void sequentialSyncWritesNoIndex() throws Exception { + List entries = createSimpleType(); + writeAll(entries, true, "sequentialSyncWritesNoIndex"); + writeAll(entries, true, "sequentialSyncUpdatesNoIndex"); + deleteNoIndex(entries, true, "sequentialSyncDeleteNoIndex"); + } + + @Test + public void randomWritesNoIndex() throws Exception { + List entries = createSimpleType(); + + Collections.shuffle(entries); + writeAll(entries, false, "randomWritesNoIndex"); + + Collections.shuffle(entries); + writeAll(entries, false, "randomUpdatesNoIndex"); + + Collections.shuffle(entries); + deleteNoIndex(entries, false, "randomDeletesNoIndex"); + } + + @Test + public void randomSyncWritesNoIndex() throws Exception { + List entries = createSimpleType(); + + Collections.shuffle(entries); + writeAll(entries, true, "randomSyncWritesNoIndex"); + + Collections.shuffle(entries); + writeAll(entries, true, "randomSyncUpdatesNoIndex"); + + Collections.shuffle(entries); + deleteNoIndex(entries, true, "randomSyncDeletesNoIndex"); + } + + @Test + public void sequentialWritesIndexedType() throws Exception { + List entries = createIndexedType(); + writeAll(entries, false, "sequentialWritesIndexed"); + writeAll(entries, false, "sequentialUpdatesIndexed"); + deleteIndexed(entries, false, "sequentialDeleteIndexed"); + } + + @Test + public void sequentialSyncWritesIndexedType() throws Exception { + List entries = createIndexedType(); + writeAll(entries, true, "sequentialSyncWritesIndexed"); + writeAll(entries, true, "sequentialSyncUpdatesIndexed"); + deleteIndexed(entries, true, "sequentialSyncDeleteIndexed"); + } + + @Test + public void randomWritesIndexedTypeAndIteration() throws Exception { + List entries = createIndexedType(); + + Collections.shuffle(entries); + writeAll(entries, false, "randomWritesIndexed"); + + Collections.shuffle(entries); + writeAll(entries, false, "randomUpdatesIndexed"); + + // Run iteration benchmarks here since we've gone through the trouble of writing all + // the data already. + KVStoreView view = db.view(IndexedType.class); + iterate(view, "naturalIndex"); + iterate(view.reverse(), "naturalIndexDescending"); + iterate(view.index("name"), "refIndex"); + iterate(view.index("name").reverse(), "refIndexDescending"); + + Collections.shuffle(entries); + deleteIndexed(entries, false, "randomDeleteIndexed"); + } + + @Test + public void randomSyncWritesIndexedTypeAndIteration() throws Exception { + List entries = createIndexedType(); + + Collections.shuffle(entries); + writeAll(entries, true, "randomSyncWritesIndexed"); + + Collections.shuffle(entries); + deleteIndexed(entries, true, "randomSyncDeleteIndexed"); + } + + private void iterate(KVStoreView view, String name) throws Exception { + Timer create = metrics.timer(name + "CreateIterator"); + Timer iter = metrics.timer(name + "Iteration"); + KVStoreIterator it = null; + { + // Create the iterator several times, just to have multiple data points. + for (int i = 0; i < 1024; i++) { + if (it != null) { + it.close(); + } + try(Timer.Context ctx = create.time()) { + it = view.closeableIterator(); + } + } + } + + for (; it.hasNext(); ) { + try(Timer.Context ctx = iter.time()) { + it.next(); + } + } + } + + private void writeAll(List entries, boolean sync, String timerName) throws Exception { + Timer timer = newTimer(timerName); + for (Object o : entries) { + try(Timer.Context ctx = timer.time()) { + db.write(o, sync); + } + } + } + + private void deleteNoIndex(List entries, boolean sync, String timerName) + throws Exception { + Timer delete = newTimer(timerName); + for (SimpleType i : entries) { + try(Timer.Context ctx = delete.time()) { + db.delete(i.getClass(), i.key, sync); + } + } + } + + private void deleteIndexed(List entries, boolean sync, String timerName) + throws Exception { + Timer delete = newTimer(timerName); + for (IndexedType i : entries) { + try(Timer.Context ctx = delete.time()) { + db.delete(i.getClass(), i.key, sync); + } + } + } + + private List createSimpleType() { + List entries = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + SimpleType t = new SimpleType(); + t.key = IDGEN.getAndIncrement(); + t.name = "name" + (t.key % 1024); + entries.add(t); + } + return entries; + } + + private List createIndexedType() { + List entries = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + IndexedType t = new IndexedType(); + t.key = IDGEN.getAndIncrement(); + t.name = "name" + (t.key % 1024); + entries.add(t); + } + return entries; + } + + private Timer newTimer(String name) { + assertNull("Timer already exists: " + name, metrics.getTimers().get(name)); + return metrics.timer(name); + } + + public static class SimpleType { + + @KVIndex + public int key; + + public String name; + + } + + public static class IndexedType { + + @KVIndex + public int key; + + @KVIndex("name") + public String name; + + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java new file mode 100644 index 0000000000000..b67503b3fcbea --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.File; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Random; + +import com.google.common.base.Predicate; +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import org.apache.commons.io.FileUtils; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import static org.junit.Assert.*; + +public class LevelDBIteratorSuite { + + private static final int MIN_ENTRIES = 42; + private static final int MAX_ENTRIES = 1024; + private static final Random RND = new Random(); + + private static List allEntries; + private static List clashingEntries; + private static LevelDB db; + private static File dbpath; + + private abstract class BaseComparator implements Comparator { + /** + * Returns a comparator that falls back to natural order if this comparator's ordering + * returns equality for two elements. Used to mimic how the index sorts things internally. + */ + BaseComparator fallback() { + return new BaseComparator() { + @Override + public int compare(CustomType1 t1, CustomType1 t2) { + int result = BaseComparator.this.compare(t1, t2); + if (result != 0) { + return result; + } + + return t1.key.compareTo(t2.key); + } + }; + } + + /** Reverses the order of this comparator. */ + BaseComparator reverse() { + return new BaseComparator() { + @Override + public int compare(CustomType1 t1, CustomType1 t2) { + return -BaseComparator.this.compare(t1, t2); + } + }; + } + } + + private final BaseComparator NATURAL_ORDER = new BaseComparator() { + @Override + public int compare(CustomType1 t1, CustomType1 t2) { + return t1.key.compareTo(t2.key); + } + }; + + private final BaseComparator REF_INDEX_ORDER = new BaseComparator() { + @Override + public int compare(CustomType1 t1, CustomType1 t2) { + return t1.id.compareTo(t2.id); + } + }; + + private final BaseComparator COPY_INDEX_ORDER = new BaseComparator() { + @Override + public int compare(CustomType1 t1, CustomType1 t2) { + return t1.name.compareTo(t2.name); + } + }; + + private final BaseComparator NUMERIC_INDEX_ORDER = new BaseComparator() { + @Override + public int compare(CustomType1 t1, CustomType1 t2) { + return t1.num - t2.num; + } + }; + + @BeforeClass + public static void setup() throws Exception { + dbpath = File.createTempFile("test.", ".ldb"); + dbpath.delete(); + db = new LevelDB(dbpath); + + int count = RND.nextInt(MAX_ENTRIES) + MIN_ENTRIES; + + // Instead of generating sequential IDs, generate random unique IDs to avoid the insertion + // order matching the natural ordering. Just in case. + boolean[] usedIDs = new boolean[count]; + + allEntries = new ArrayList<>(count); + for (int i = 0; i < count; i++) { + CustomType1 t = new CustomType1(); + + int id; + do { + id = RND.nextInt(count); + } while (usedIDs[id]); + + usedIDs[id] = true; + t.key = "key" + id; + t.id = "id" + i; + t.name = "name" + RND.nextInt(MAX_ENTRIES); + t.num = RND.nextInt(MAX_ENTRIES); + allEntries.add(t); + db.write(t); + } + + // Pick the first generated value, and forcefully create a few entries that will clash + // with the indexed values (id and name), to make sure the index behaves correctly when + // multiple entities are indexed by the same value. + // + // This also serves as a test for the test code itself, to make sure it's sorting indices + // the same way the store is expected to. + CustomType1 first = allEntries.get(0); + clashingEntries = new ArrayList<>(); + for (int i = 0; i < RND.nextInt(MIN_ENTRIES) + 1; i++) { + CustomType1 t = new CustomType1(); + t.key = "n-key" + (count + i); + t.id = first.id; + t.name = first.name; + t.num = first.num; + allEntries.add(t); + clashingEntries.add(t); + db.write(t); + } + + // Create another entry that could cause problems: take the first entry, and make its indexed + // name be an extension of the existing ones, to make sure the implementation sorts these + // correctly even considering the separator character (shorter strings first). + CustomType1 t = new CustomType1(); + t.key = "extended-key-0"; + t.id = first.id; + t.name = first.name + "a"; + t.num = first.num; + allEntries.add(t); + db.write(t); + } + + @AfterClass + public static void cleanup() throws Exception { + allEntries = null; + if (db != null) { + db.close(); + } + if (dbpath != null) { + FileUtils.deleteQuietly(dbpath); + } + } + + @Test + public void naturalIndex() throws Exception { + testIteration(NATURAL_ORDER, view(), null); + } + + @Test + public void refIndex() throws Exception { + testIteration(REF_INDEX_ORDER, view().index("id"), null); + } + + @Test + public void copyIndex() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name"), null); + } + + @Test + public void numericIndex() throws Exception { + testIteration(NUMERIC_INDEX_ORDER, view().index("int"), null); + } + + @Test + public void naturalIndexDescending() throws Exception { + testIteration(NATURAL_ORDER, view().reverse(), null); + } + + @Test + public void refIndexDescending() throws Exception { + testIteration(REF_INDEX_ORDER, view().index("id").reverse(), null); + } + + @Test + public void copyIndexDescending() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").reverse(), null); + } + + @Test + public void numericIndexDescending() throws Exception { + testIteration(NUMERIC_INDEX_ORDER, view().index("int").reverse(), null); + } + + @Test + public void naturalIndexWithStart() throws Exception { + CustomType1 first = pickFirst(); + testIteration(NATURAL_ORDER, view().first(first.key), first); + } + + @Test + public void refIndexWithStart() throws Exception { + CustomType1 first = pickFirst(); + testIteration(REF_INDEX_ORDER, view().index("id").first(first.id), first); + } + + @Test + public void copyIndexWithStart() throws Exception { + CustomType1 first = pickFirst(); + testIteration(COPY_INDEX_ORDER, view().index("name").first(first.name), first); + } + + @Test + public void numericIndexWithStart() throws Exception { + CustomType1 first = pickFirst(); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").first(first.num), first); + } + + @Test + public void naturalIndexDescendingWithStart() throws Exception { + CustomType1 first = pickFirst(); + testIteration(NATURAL_ORDER, view().reverse().first(first.key), first); + } + + @Test + public void refIndexDescendingWithStart() throws Exception { + CustomType1 first = pickFirst(); + testIteration(REF_INDEX_ORDER, view().reverse().index("id").first(first.id), first); + } + + @Test + public void copyIndexDescendingWithStart() throws Exception { + CustomType1 first = pickFirst(); + testIteration(COPY_INDEX_ORDER, view().reverse().index("name").first(first.name), + first); + } + + @Test + public void numericIndexDescendingWithStart() throws Exception { + CustomType1 first = pickFirst(); + testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").first(first.num), + first); + } + + @Test + public void naturalIndexWithSkip() throws Exception { + testIteration(NATURAL_ORDER, view().skip(RND.nextInt(allEntries.size() / 2)), null); + } + + @Test + public void refIndexWithSkip() throws Exception { + testIteration(REF_INDEX_ORDER, view().index("id").skip(RND.nextInt(allEntries.size() / 2)), + null); + } + + @Test + public void copyIndexWithSkip() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").skip(RND.nextInt(allEntries.size() / 2)), + null); + } + + @Test + public void testRefWithIntNaturalKey() throws Exception { + LevelDBSuite.IntKeyType i = new LevelDBSuite.IntKeyType(); + i.key = 1; + i.id = "1"; + i.values = Arrays.asList("1"); + + db.write(i); + + try(KVStoreIterator it = db.view(i.getClass()).closeableIterator()) { + Object read = it.next(); + assertEquals(i, read); + } + } + + private CustomType1 pickFirst() { + // Picks a first element that has clashes with other elements in the given index. + return clashingEntries.get(RND.nextInt(clashingEntries.size())); + } + + /** + * Compares the two values and falls back to comparing the natural key of CustomType1 + * if they're the same, to mimic the behavior of the indexing code. + */ + private > int compareWithFallback( + T v1, + T v2, + CustomType1 ct1, + CustomType1 ct2) { + int result = v1.compareTo(v2); + if (result != 0) { + return result; + } + + return ct1.key.compareTo(ct2.key); + } + + private void testIteration( + final BaseComparator order, + final KVStoreView params, + final CustomType1 first) throws Exception { + List indexOrder = sortBy(order.fallback()); + if (!params.ascending) { + indexOrder = Lists.reverse(indexOrder); + } + + Iterable expected = indexOrder; + if (first != null) { + final BaseComparator expectedOrder = params.ascending ? order : order.reverse(); + expected = Iterables.filter(expected, new Predicate() { + @Override + public boolean apply(CustomType1 v) { + return expectedOrder.compare(first, v) <= 0; + } + }); + } + + if (params.skip > 0) { + expected = Iterables.skip(expected, (int) params.skip); + } + + List actual = collect(params); + compareLists(expected, actual); + } + + /** Could use assertEquals(), but that creates hard to read errors for large lists. */ + private void compareLists(Iterable expected, List actual) { + Iterator expectedIt = expected.iterator(); + Iterator actualIt = actual.iterator(); + + int count = 0; + while (expectedIt.hasNext()) { + if (!actualIt.hasNext()) { + break; + } + count++; + assertEquals(expectedIt.next(), actualIt.next()); + } + + String message; + Object[] remaining; + int expectedCount = count; + int actualCount = count; + + if (expectedIt.hasNext()) { + remaining = Iterators.toArray(expectedIt, Object.class); + expectedCount += remaining.length; + message = "missing"; + } else { + remaining = Iterators.toArray(actualIt, Object.class); + actualCount += remaining.length; + message = "stray"; + } + + assertEquals(String.format("Found %s elements: %s", message, Arrays.asList(remaining)), + expectedCount, actualCount); + } + + private KVStoreView view() throws Exception { + return db.view(CustomType1.class); + } + + private List collect(KVStoreView view) throws Exception { + return Arrays.asList(Iterables.toArray(view, CustomType1.class)); + } + + private List sortBy(Comparator comp) { + List copy = new ArrayList<>(allEntries); + Collections.sort(copy, comp); + return copy; + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java new file mode 100644 index 0000000000000..2f83345ba8d5a --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.File; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.collect.Iterators; +import org.apache.commons.io.FileUtils; +import org.iq80.leveldb.DBIterator; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import static org.junit.Assert.*; + +public class LevelDBSuite { + + private LevelDB db; + private File dbpath; + + @After + public void cleanup() throws Exception { + if (db != null) { + db.close(); + } + if (dbpath != null) { + FileUtils.deleteQuietly(dbpath); + } + } + + @Before + public void setup() throws Exception { + dbpath = File.createTempFile("test.", ".ldb"); + dbpath.delete(); + db = new LevelDB(dbpath); + } + + @Test + public void testReopenAndVersionCheckDb() throws Exception { + db.close(); + db = null; + assertTrue(dbpath.exists()); + + db = new LevelDB(dbpath); + assertEquals(LevelDB.STORE_VERSION, + db.serializer.deserializeLong(db.db.get(LevelDB.STORE_VERSION_KEY))); + db.db.put(LevelDB.STORE_VERSION_KEY, db.serializer.serialize(LevelDB.STORE_VERSION + 1)); + db.close(); + db = null; + + try { + db = new LevelDB(dbpath); + fail("Should have failed version check."); + } catch (UnsupportedStoreVersionException e) { + // Expected. + } + } + + @Test + public void testStringWriteReadDelete() throws Exception { + String string = "testString"; + byte[] key = string.getBytes(UTF_8); + testReadWriteDelete(key, string); + } + + @Test + public void testIntWriteReadDelete() throws Exception { + int value = 42; + byte[] key = "key".getBytes(UTF_8); + testReadWriteDelete(key, value); + } + + @Test + public void testSimpleTypeWriteReadDelete() throws Exception { + byte[] key = "testKey".getBytes(UTF_8); + CustomType1 t = new CustomType1(); + t.id = "id"; + t.name = "name"; + testReadWriteDelete(key, t); + } + + @Test + public void testObjectWriteReadDelete() throws Exception { + CustomType1 t = new CustomType1(); + t.key = "key"; + t.id = "id"; + t.name = "name"; + + try { + db.read(CustomType1.class, t.key); + fail("Expected exception for non-existant object."); + } catch (NoSuchElementException nsee) { + // Expected. + } + + db.write(t); + assertEquals(t, db.read(t.getClass(), t.key)); + assertEquals(1L, db.count(t.getClass())); + + db.delete(t.getClass(), t.key); + try { + db.read(t.getClass(), t.key); + fail("Expected exception for deleted object."); + } catch (NoSuchElementException nsee) { + // Expected. + } + + // Look into the actual DB and make sure that all the keys related to the type have been + // removed. + assertEquals(0, countKeys(t.getClass())); + } + + @Test + public void testMultipleObjectWriteReadDelete() throws Exception { + CustomType1 t1 = new CustomType1(); + t1.key = "key1"; + t1.id = "id"; + t1.name = "name1"; + + CustomType1 t2 = new CustomType1(); + t2.key = "key2"; + t2.id = "id"; + t2.name = "name2"; + + db.write(t1); + db.write(t2); + + assertEquals(t1, db.read(t1.getClass(), t1.key)); + assertEquals(t2, db.read(t2.getClass(), t2.key)); + assertEquals(2L, db.count(t1.getClass())); + + // There should be one "id" index entry with two values. + assertEquals(2, countIndexEntries(t1.getClass(), "id", t1.id)); + + // Delete the first entry; now there should be 3 remaining keys, since one of the "name" + // index entries should have been removed. + db.delete(t1.getClass(), t1.key); + + // Make sure there's a single entry in the "id" index now. + assertEquals(1, countIndexEntries(t2.getClass(), "id", t2.id)); + + // Delete the remaining entry, make sure all data is gone. + db.delete(t2.getClass(), t2.key); + assertEquals(0, countKeys(t2.getClass())); + } + + @Test + public void testMultipleTypesWriteReadDelete() throws Exception { + CustomType1 t1 = new CustomType1(); + t1.key = "1"; + t1.id = "id"; + t1.name = "name1"; + + IntKeyType t2 = new IntKeyType(); + t2.key = 2; + t2.id = "2"; + t2.values = Arrays.asList("value1", "value2"); + + db.write(t1); + db.write(t2); + + assertEquals(t1, db.read(t1.getClass(), t1.key)); + assertEquals(t2, db.read(t2.getClass(), t2.key)); + + // There should be one "id" index with a single entry for each type. + assertEquals(1, countIndexEntries(t1.getClass(), "id", t1.id)); + assertEquals(1, countIndexEntries(t2.getClass(), "id", t2.id)); + + // Delete the first entry; this should not affect the entries for the second type. + db.delete(t1.getClass(), t1.key); + assertEquals(0, countKeys(t1.getClass())); + assertEquals(1, countIndexEntries(t2.getClass(), "id", t2.id)); + + // Delete the remaining entry, make sure all data is gone. + db.delete(t2.getClass(), t2.key); + assertEquals(0, countKeys(t2.getClass())); + } + + @Test + public void testMetadata() throws Exception { + assertNull(db.getMetadata(CustomType1.class)); + + CustomType1 t = new CustomType1(); + t.id = "id"; + t.name = "name"; + + db.setMetadata(t); + assertEquals(t, db.getMetadata(CustomType1.class)); + + db.setMetadata(null); + assertNull(db.getMetadata(CustomType1.class)); + } + + private long countIndexEntries(Class type, String index, Object value) throws Exception { + LevelDBTypeInfo.Index idx = db.getTypeInfo(type).index(index); + return idx.getCount(idx.end()); + } + + private int countKeys(Class type) throws Exception { + byte[] prefix = db.getTypeInfo(type).keyPrefix(); + int count = 0; + + DBIterator it = db.db.iterator(); + it.seek(prefix); + + while (it.hasNext()) { + byte[] key = it.next().getKey(); + if (LevelDBIterator.startsWith(key, prefix)) { + count++; + } + } + + return count; + } + + private void testReadWriteDelete(byte[] key, T value) throws Exception { + try { + db.get(key, value.getClass()); + fail("Expected exception for non-existent key."); + } catch (NoSuchElementException nsee) { + // Expected. + } + + db.put(key, value); + assertEquals(value, db.get(key, value.getClass())); + + db.delete(key); + try { + db.get(key, value.getClass()); + fail("Expected exception for deleted key."); + } catch (NoSuchElementException nsee) { + // Expected. + } + } + + public static class IntKeyType { + + @KVIndex + public int key; + + @KVIndex("id") + public String id; + + public List values; + + @Override + public boolean equals(Object o) { + if (o instanceof IntKeyType) { + IntKeyType other = (IntKeyType) o; + return key == other.key && id.equals(other.id) && values.equals(other.values); + } + return false; + } + + @Override + public int hashCode() { + return id.hashCode(); + } + + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java new file mode 100644 index 0000000000000..2ef69fd33ce19 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import org.junit.Test; +import static org.junit.Assert.*; + +public class LevelDBTypeInfoSuite { + + @Test + public void testIndexAnnotation() throws Exception { + LevelDBTypeInfo ti = newTypeInfo(CustomType1.class); + assertEquals(4, ti.indices().size()); + + CustomType1 t1 = new CustomType1(); + t1.key = "key"; + t1.id = "id"; + t1.name = "name"; + t1.num = 42; + + assertEquals(t1.key, ti.naturalIndex().accessor.get(t1)); + assertEquals(t1.id, ti.index("id").accessor.get(t1)); + assertEquals(t1.name, ti.index("name").accessor.get(t1)); + assertEquals(t1.num, ti.index("int").accessor.get(t1)); + } + + @Test(expected = IllegalArgumentException.class) + public void testNoNaturalIndex() throws Exception { + newTypeInfo(NoNaturalIndex.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testDuplicateIndex() throws Exception { + newTypeInfo(DuplicateIndex.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyIndexName() throws Exception { + newTypeInfo(EmptyIndexName.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalIndexName() throws Exception { + newTypeInfo(IllegalIndexName.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalIndexMethod() throws Exception { + newTypeInfo(IllegalIndexMethod.class); + } + + @Test + public void testKeyClashes() throws Exception { + LevelDBTypeInfo ti = newTypeInfo(CustomType1.class); + + CustomType1 t1 = new CustomType1(); + t1.key = "key1"; + t1.name = "a"; + + CustomType1 t2 = new CustomType1(); + t2.key = "key2"; + t2.name = "aa"; + + CustomType1 t3 = new CustomType1(); + t3.key = "key3"; + t3.name = "aaa"; + + // Make sure entries with conflicting names are sorted correctly. + assertBefore(ti.index("name").entityKey(t1), ti.index("name").entityKey(t2)); + assertBefore(ti.index("name").entityKey(t1), ti.index("name").entityKey(t3)); + assertBefore(ti.index("name").entityKey(t2), ti.index("name").entityKey(t3)); + } + + @Test + public void testNumEncoding() throws Exception { + LevelDBTypeInfo.Index idx = newTypeInfo(CustomType1.class).indices().iterator().next(); + + assertBefore(idx.toKey(1), idx.toKey(2)); + assertBefore(idx.toKey(-1), idx.toKey(2)); + assertBefore(idx.toKey(-11), idx.toKey(2)); + assertBefore(idx.toKey(-11), idx.toKey(-1)); + assertBefore(idx.toKey(1), idx.toKey(11)); + assertBefore(idx.toKey(Integer.MIN_VALUE), idx.toKey(Integer.MAX_VALUE)); + assertEquals(LevelDBTypeInfo.INT_ENCODED_LEN + LevelDBTypeInfo.ENTRY_PREFIX.length(), + idx.toKey(Integer.MIN_VALUE).length()); + + assertBefore(idx.toKey(1L), idx.toKey(2L)); + assertBefore(idx.toKey(-1L), idx.toKey(2L)); + assertBefore(idx.toKey(Long.MIN_VALUE), idx.toKey(Long.MAX_VALUE)); + assertEquals(LevelDBTypeInfo.LONG_ENCODED_LEN + LevelDBTypeInfo.ENTRY_PREFIX.length(), + idx.toKey(Long.MIN_VALUE).length()); + + assertBefore(idx.toKey((short) 1), idx.toKey((short) 2)); + assertBefore(idx.toKey((short) -1), idx.toKey((short) 2)); + assertBefore(idx.toKey(Short.MIN_VALUE), idx.toKey(Short.MAX_VALUE)); + assertEquals(LevelDBTypeInfo.SHORT_ENCODED_LEN + LevelDBTypeInfo.ENTRY_PREFIX.length(), + idx.toKey(Short.MIN_VALUE).length()); + + assertBefore(idx.toKey((byte) 1), idx.toKey((byte) 2)); + assertBefore(idx.toKey((byte) -1), idx.toKey((byte) 2)); + assertBefore(idx.toKey(Byte.MIN_VALUE), idx.toKey(Byte.MAX_VALUE)); + assertEquals(LevelDBTypeInfo.BYTE_ENCODED_LEN + LevelDBTypeInfo.ENTRY_PREFIX.length(), + idx.toKey(Byte.MIN_VALUE).length()); + + assertEquals(LevelDBTypeInfo.ENTRY_PREFIX + "false", idx.toKey(false)); + assertEquals(LevelDBTypeInfo.ENTRY_PREFIX + "true", idx.toKey(true)); + } + + private LevelDBTypeInfo newTypeInfo(Class type) throws Exception { + return new LevelDBTypeInfo<>(null, type); + } + + private void assertBefore(byte[] key1, byte[] key2) { + assertBefore(new String(key1, UTF_8), new String(key2, UTF_8)); + } + + private void assertBefore(String str1, String str2) { + assertTrue(String.format("%s < %s failed", str1, str2), str1.compareTo(str2) < 0); + } + + public static class NoNaturalIndex { + + public String id; + + } + + public static class DuplicateIndex { + + @KVIndex("id") + public String id; + + @KVIndex("id") + public String id2; + + } + + public static class EmptyIndexName { + + @KVIndex("") + public String id; + + } + + public static class IllegalIndexName { + + @KVIndex("__invalid") + public String id; + + } + + public static class IllegalIndexMethod { + + @KVIndex("id") + public String id(boolean illegalParam) { + return null; + } + + } + +} diff --git a/common/kvstore/src/test/resources/log4j.properties b/common/kvstore/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..e8da774f7ca9e --- /dev/null +++ b/common/kvstore/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Silence verbose logs from 3rd-party libraries. +log4j.logger.io.netty=INFO diff --git a/pom.xml b/pom.xml index 0533a8dcf2e0a..6835ea14cd42b 100644 --- a/pom.xml +++ b/pom.xml @@ -83,6 +83,7 @@ common/sketch + common/kvstore common/network-common common/network-shuffle common/unsafe @@ -441,6 +442,11 @@ httpcore ${commons.httpcore.version}
+ + org.fusesource.leveldbjni + leveldbjni-all + 1.8 + org.seleniumhq.selenium selenium-java @@ -588,6 +594,11 @@ metrics-graphite ${codahale.metrics.version} + + com.fasterxml.jackson.core + jackson-core + ${fasterxml.jackson.version} + com.fasterxml.jackson.core jackson-databind diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b5362ec1ae452..89b0c7a3ab7b0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -50,10 +50,10 @@ object BuildCommons { ).map(ProjectRef(buildLocation, _)) val allProjects@Seq( - core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, _* + core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _* ) = Seq( "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", - "tags", "sketch" + "tags", "sketch", "kvstore" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects val optionallyEnabledProjects@Seq(mesos, yarn, sparkGangliaLgpl, @@ -310,7 +310,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, sqlKafka010 + unsafe, tags, sqlKafka010, kvstore ).contains(x) } From 52ed2b45c09e7104e4fef5adcf78025f53b7a8e0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 1 Nov 2016 11:34:25 -0700 Subject: [PATCH 22/30] SHS-NG M1: Add support for arrays when indexing. This is needed because some UI types have compound keys. --- .../org/apache/spark/kvstore/KVIndex.java | 3 +- .../apache/spark/kvstore/LevelDBTypeInfo.java | 12 ++++++ .../apache/spark/kvstore/LevelDBSuite.java | 37 ++++++++++++++++++- .../spark/kvstore/LevelDBTypeInfoSuite.java | 12 ++++++ 4 files changed, 62 insertions(+), 2 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java index 3c61e7706079a..bf5e4a66e510f 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java @@ -42,7 +42,8 @@ *

* *

- * Indices are restricted to String, and integral types (byte, short, int, long, boolean). + * Indices are restricted to String, integral types (byte, short, int, long, boolean), and arrays + * of those values. *

*/ @Retention(RetentionPolicy.RUNTIME) diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java index c49b18324c00a..947d6c4945e37 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java @@ -17,6 +17,7 @@ package org.apache.spark.kvstore; +import java.lang.reflect.Array; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.io.ByteArrayOutputStream; @@ -283,6 +284,8 @@ private void updateCount(WriteBatch batch, byte[] key, long delta) throws Except * encoded value would look like "*~~~~~~~123"). For positive values, similarly, a value that * is "lower than 0" (".") is used for padding. The fill characters were chosen for readability * when looking at the encoded keys. + * + * Arrays are encoded by encoding each element separately, separated by KEY_SEPARATOR. */ @VisibleForTesting String toKey(Object value) { @@ -292,6 +295,15 @@ String toKey(Object value) { sb.append(value); } else if (value instanceof Boolean) { sb.append(((Boolean) value).toString().toLowerCase()); + } else if (value.getClass().isArray()) { + int length = Array.getLength(value); + for (int i = 0; i < length; i++) { + sb.append(toKey(Array.get(value, i))); + sb.append(KEY_SEPARATOR); + } + if (length > 0) { + sb.setLength(sb.length() - 1); + } } else { int encodedLen; diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java index 2f83345ba8d5a..868baf743e027 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java @@ -175,24 +175,35 @@ public void testMultipleTypesWriteReadDelete() throws Exception { t2.id = "2"; t2.values = Arrays.asList("value1", "value2"); + ArrayKeyIndexType t3 = new ArrayKeyIndexType(); + t3.key = new int[] { 42, 84 }; + t3.id = new String[] { "id1", "id2" }; + db.write(t1); db.write(t2); + db.write(t3); assertEquals(t1, db.read(t1.getClass(), t1.key)); assertEquals(t2, db.read(t2.getClass(), t2.key)); + assertEquals(t3, db.read(t3.getClass(), t3.key)); // There should be one "id" index with a single entry for each type. assertEquals(1, countIndexEntries(t1.getClass(), "id", t1.id)); assertEquals(1, countIndexEntries(t2.getClass(), "id", t2.id)); + assertEquals(1, countIndexEntries(t3.getClass(), "id", t3.id)); // Delete the first entry; this should not affect the entries for the second type. db.delete(t1.getClass(), t1.key); assertEquals(0, countKeys(t1.getClass())); assertEquals(1, countIndexEntries(t2.getClass(), "id", t2.id)); + assertEquals(1, countIndexEntries(t3.getClass(), "id", t3.id)); - // Delete the remaining entry, make sure all data is gone. + // Delete the remaining entries, make sure all data is gone. db.delete(t2.getClass(), t2.key); assertEquals(0, countKeys(t2.getClass())); + + db.delete(t3.getClass(), t3.key); + assertEquals(0, countKeys(t3.getClass())); } @Test @@ -278,4 +289,28 @@ public int hashCode() { } + public static class ArrayKeyIndexType { + + @KVIndex + public int[] key; + + @KVIndex("id") + public String[] id; + + @Override + public boolean equals(Object o) { + if (o instanceof ArrayKeyIndexType) { + ArrayKeyIndexType other = (ArrayKeyIndexType) o; + return Arrays.equals(key, other.key) && Arrays.equals(id, other.id); + } + return false; + } + + @Override + public int hashCode() { + return key.hashCode(); + } + + } + } diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java index 2ef69fd33ce19..7ee0b24552219 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java @@ -123,6 +123,18 @@ public void testNumEncoding() throws Exception { assertEquals(LevelDBTypeInfo.ENTRY_PREFIX + "true", idx.toKey(true)); } + @Test + public void testArrayIndices() throws Exception { + LevelDBTypeInfo.Index idx = newTypeInfo(CustomType1.class).indices().iterator().next(); + + assertBefore(idx.toKey(new String[] { "str1" }), idx.toKey(new String[] { "str2" })); + assertBefore(idx.toKey(new String[] { "str1", "str2" }), + idx.toKey(new String[] { "str1", "str3" })); + + assertBefore(idx.toKey(new int[] { 1 }), idx.toKey(new int[] { 2 })); + assertBefore(idx.toKey(new int[] { 1, 2 }), idx.toKey(new int[] { 1, 3 })); + } + private LevelDBTypeInfo newTypeInfo(Class type) throws Exception { return new LevelDBTypeInfo<>(null, type); } From 4112afe723f85412035ad3a9c4801b583e74f876 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 3 Nov 2016 15:18:24 -0700 Subject: [PATCH 23/30] SHS-NG M1: Fix counts in LevelDB when updating entries. Also add unit test. When updating, the code needs to keep track of the aggregated delta to be added to each count stored in the db, instead of reading the count from the db for each update. --- .../org/apache/spark/kvstore/LevelDB.java | 10 +- .../apache/spark/kvstore/LevelDBTypeInfo.java | 22 +--- .../spark/kvstore/LevelDBWriteBatch.java | 113 ++++++++++++++++++ .../apache/spark/kvstore/LevelDBSuite.java | 26 ++++ 4 files changed, 150 insertions(+), 21 deletions(-) create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBWriteBatch.java diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java index 51287c02ebab1..2b7ea6889aee1 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java @@ -136,7 +136,7 @@ public void write(Object value, boolean sync) throws Exception { Preconditions.checkArgument(value != null, "Null values are not allowed."); LevelDBTypeInfo ti = getTypeInfo(value.getClass()); - WriteBatch batch = db.createWriteBatch(); + LevelDBWriteBatch batch = new LevelDBWriteBatch(this); try { byte[] data = serializer.serialize(value); synchronized (ti) { @@ -149,7 +149,7 @@ public void write(Object value, boolean sync) throws Exception { for (LevelDBTypeInfo.Index idx : ti.indices()) { idx.add(batch, value, data); } - db.write(batch, new WriteOptions().sync(sync)); + batch.write(sync); } } finally { batch.close(); @@ -163,7 +163,7 @@ public void delete(Class type, Object naturalKey) throws Exception { public void delete(Class type, Object naturalKey, boolean sync) throws Exception { Preconditions.checkArgument(naturalKey != null, "Null keys are not allowed."); - WriteBatch batch = db.createWriteBatch(); + LevelDBWriteBatch batch = new LevelDBWriteBatch(this); try { LevelDBTypeInfo ti = getTypeInfo(type); byte[] key = ti.naturalIndex().start(naturalKey); @@ -172,7 +172,7 @@ public void delete(Class type, Object naturalKey, boolean sync) throws Except Object existing = serializer.deserialize(data, type); synchronized (ti) { removeInstance(ti, batch, existing); - db.write(batch, new WriteOptions().sync(sync)); + batch.write(sync); } } } finally { @@ -229,7 +229,7 @@ LevelDBTypeInfo getTypeInfo(Class type) throws Exception { return idx; } - private void removeInstance(LevelDBTypeInfo ti, WriteBatch batch, Object instance) + private void removeInstance(LevelDBTypeInfo ti, LevelDBWriteBatch batch, Object instance) throws Exception { for (LevelDBTypeInfo.Index idx : ti.indices()) { idx.remove(batch, instance); diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java index 947d6c4945e37..184b6611e0e0a 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java @@ -30,7 +30,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; -import org.iq80.leveldb.WriteBatch; /** * Holds metadata about app-specific types stored in LevelDB. Serves as a cache for data collected @@ -235,14 +234,14 @@ byte[] entityKey(Object entity) throws Exception { * @param data Serialized entity to store (when storing the entity, not a reference). * @param naturalKey The value's key. */ - void add(WriteBatch batch, Object entity, byte[] data) throws Exception { + void add(LevelDBWriteBatch batch, Object entity, byte[] data) throws Exception { byte[] stored = data; if (!copy) { stored = db.serializer.serialize(toKey(naturalIndex().accessor.get(entity))); } batch.put(entityKey(entity), stored); - updateCount(batch, end(accessor.get(entity)), 1L); - updateCount(batch, end(), 1L); + batch.updateCount(end(accessor.get(entity)), 1L); + batch.updateCount(end(), 1L); } /** @@ -252,10 +251,10 @@ void add(WriteBatch batch, Object entity, byte[] data) throws Exception { * @param entity The entity being removed, to identify the index entry to modify. * @param naturalKey The value's key. */ - void remove(WriteBatch batch, Object entity) throws Exception { + void remove(LevelDBWriteBatch batch, Object entity) throws Exception { batch.delete(entityKey(entity)); - updateCount(batch, end(accessor.get(entity)), -1L); - updateCount(batch, end(), -1L); + batch.updateCount(end(accessor.get(entity)), -1L); + batch.updateCount(end(), -1L); } long getCount(byte[] key) throws Exception { @@ -263,15 +262,6 @@ long getCount(byte[] key) throws Exception { return data != null ? db.serializer.deserializeLong(data) : 0; } - private void updateCount(WriteBatch batch, byte[] key, long delta) throws Exception { - long count = getCount(key) + delta; - if (count > 0) { - batch.put(key, db.serializer.serialize(count)); - } else { - batch.delete(key); - } - } - /** * Translates a value to be used as part of the store key. * diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBWriteBatch.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBWriteBatch.java new file mode 100644 index 0000000000000..a6ca190222931 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBWriteBatch.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import org.iq80.leveldb.DB; +import org.iq80.leveldb.WriteBatch; +import org.iq80.leveldb.WriteOptions; + +/** + * A wrapper around the LevelDB library's WriteBatch with some extra functionality for keeping + * track of counts. + */ +class LevelDBWriteBatch { + + private final LevelDB db; + private final Map deltas; + private final WriteBatch batch; + + LevelDBWriteBatch(LevelDB db) { + this.db = db; + this.batch = db.db.createWriteBatch(); + this.deltas = new HashMap<>(); + } + + void updateCount(byte[] key, long delta) { + KeyWrapper kw = new KeyWrapper(key); + Long fullDelta = deltas.get(kw); + if (fullDelta != null) { + fullDelta += delta; + } else { + fullDelta = delta; + } + deltas.put(kw, fullDelta); + } + + void put(byte[] key, byte[] value) { + batch.put(key, value); + } + + void delete(byte[] key) { + batch.delete(key); + } + + void write(boolean sync) { + for (Map.Entry e : deltas.entrySet()) { + long delta = e.getValue(); + if (delta == 0) { + continue; + } + + byte[] key = e.getKey().key; + byte[] data = db.db.get(key); + long count = data != null ? db.serializer.deserializeLong(data) : 0L; + long newCount = count + delta; + + if (newCount > 0) { + batch.put(key, db.serializer.serialize(newCount)); + } else { + batch.delete(key); + } + } + + db.db.write(batch, new WriteOptions().sync(sync)); + } + + void close() throws IOException { + batch.close(); + } + + private static class KeyWrapper { + + private final byte[] key; + + KeyWrapper(byte[] key) { + this.key = key; + } + + @Override + public boolean equals(Object other) { + if (other instanceof KeyWrapper) { + return Arrays.equals(key, ((KeyWrapper) other).key); + } + return false; + } + + @Override + public int hashCode() { + return Arrays.hashCode(key); + } + + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java index 868baf743e027..5dc1c10765274 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java @@ -226,6 +226,32 @@ private long countIndexEntries(Class type, String index, Object value) throws return idx.getCount(idx.end()); } + @Test + public void testUpdate() throws Exception { + CustomType1 t = new CustomType1(); + t.key = "key"; + t.id = "id"; + t.name = "name"; + + db.write(t); + + t.name = "anotherName"; + + db.write(t); + + assertEquals(1, db.count(t.getClass())); + + LevelDBTypeInfo.Index ni = db.getTypeInfo(t.getClass()).index("name"); + assertEquals(1, ni.getCount(ni.end())); + assertEquals(1, ni.getCount(ni.end("anotherName"))); + try { + db.get(ni.end("name"), Integer.class); + fail("Should have gotten an exception."); + } catch (NoSuchElementException nsee) { + // Expected. + } + } + private int countKeys(Class type) throws Exception { byte[] prefix = db.getTypeInfo(type).keyPrefix(); int count = 0; From 718cabd098dd6a534e7952066cd43f89f6875a14 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 17 Mar 2017 20:17:04 -0700 Subject: [PATCH 24/30] SHS-NG M1: Try to prevent db use after close. This causes JVM crashes in the leveldb library, so try to avoid it; if there are still issues, we'll neeed locking. --- .../org/apache/spark/kvstore/LevelDB.java | 41 ++++++++++++------- .../apache/spark/kvstore/LevelDBIterator.java | 6 +-- .../apache/spark/kvstore/LevelDBTypeInfo.java | 2 +- .../spark/kvstore/LevelDBWriteBatch.java | 6 +-- .../apache/spark/kvstore/LevelDBSuite.java | 6 +-- 5 files changed, 37 insertions(+), 24 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java index 2b7ea6889aee1..35cdbb6733a39 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java @@ -23,6 +23,7 @@ import java.util.NoSuchElementException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.annotations.VisibleForTesting; @@ -48,11 +49,10 @@ public class LevelDB implements KVStore { /** DB key where app metadata is stored. */ private static final byte[] METADATA_KEY = "__meta__".getBytes(UTF_8); - final DB db; + final AtomicReference _db; final KVStoreSerializer serializer; private final ConcurrentMap, LevelDBTypeInfo> types; - private boolean closed; public LevelDB(File path) throws IOException { this(path, new KVStoreSerializer()); @@ -64,16 +64,16 @@ public LevelDB(File path, KVStoreSerializer serializer) throws IOException { Options options = new Options(); options.createIfMissing(!path.exists()); - this.db = JniDBFactory.factory.open(path, options); + this._db = new AtomicReference<>(JniDBFactory.factory.open(path, options)); - byte[] versionData = db.get(STORE_VERSION_KEY); + byte[] versionData = db().get(STORE_VERSION_KEY); if (versionData != null) { long version = serializer.deserializeLong(versionData); if (version != STORE_VERSION) { throw new UnsupportedStoreVersionException(); } } else { - db.put(STORE_VERSION_KEY, serializer.serialize(STORE_VERSION)); + db().put(STORE_VERSION_KEY, serializer.serialize(STORE_VERSION)); } } @@ -91,13 +91,13 @@ public void setMetadata(Object value) throws Exception { if (value != null) { put(METADATA_KEY, value); } else { - db.delete(METADATA_KEY); + db().delete(METADATA_KEY); } } @Override public T get(byte[] key, Class klass) throws Exception { - byte[] data = db.get(key); + byte[] data = db().get(key); if (data == null) { throw new NoSuchElementException(new String(key, UTF_8)); } @@ -107,12 +107,12 @@ public T get(byte[] key, Class klass) throws Exception { @Override public void put(byte[] key, Object value) throws Exception { Preconditions.checkArgument(value != null, "Null values are not allowed."); - db.put(key, serializer.serialize(value)); + db().put(key, serializer.serialize(value)); } @Override public void delete(byte[] key) throws Exception { - db.delete(key); + db().delete(key); } @Override @@ -167,7 +167,7 @@ public void delete(Class type, Object naturalKey, boolean sync) throws Except try { LevelDBTypeInfo ti = getTypeInfo(type); byte[] key = ti.naturalIndex().start(naturalKey); - byte[] data = db.get(key); + byte[] data = db().get(key); if (data != null) { Object existing = serializer.deserialize(data, type); synchronized (ti) { @@ -201,14 +201,14 @@ public long count(Class type) throws Exception { } @Override - public synchronized void close() throws IOException { - if (closed) { + public void close() throws IOException { + DB _db = this._db.getAndSet(null); + if (_db == null) { return; } try { - db.close(); - closed = true; + _db.close(); } catch (IOException ioe) { throw ioe; } catch (Exception e) { @@ -229,6 +229,19 @@ LevelDBTypeInfo getTypeInfo(Class type) throws Exception { return idx; } + /** + * Try to avoid use-after close since that has the tendency of crashing the JVM. This doesn't + * prevent methods that retrieved the instance from using it after close, but hopefully will + * catch most cases; otherwise, we'll need some kind of locking. + */ + DB db() { + DB _db = this._db.get(); + if (_db == null) { + throw new IllegalStateException("DB is closed."); + } + return _db; + } + private void removeInstance(LevelDBTypeInfo ti, LevelDBWriteBatch batch, Object instance) throws Exception { for (LevelDBTypeInfo.Index idx : ti.indices()) { diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java index d0b6e25420812..b777ff7bafc02 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java @@ -52,7 +52,7 @@ class LevelDBIterator implements KVStoreIterator { this.type = type; this.ti = null; this.index = null; - this.it = db.db.iterator(); + this.it = db.db().iterator(); this.indexKeyPrefix = keyPrefix; this.end = null; it.seek(keyPrefix); @@ -64,7 +64,7 @@ class LevelDBIterator implements KVStoreIterator { LevelDBIterator(LevelDB db, KVStoreView params) throws Exception { this.db = db; this.ascending = params.ascending; - this.it = db.db.iterator(); + this.it = db.db().iterator(); this.type = params.type; this.ti = db.getTypeInfo(type); this.index = ti.index(params.index); @@ -157,7 +157,7 @@ public boolean skip(long n) { } @Override - public void close() throws IOException { + public synchronized void close() throws IOException { if (!closed) { it.close(); closed = true; diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java index 184b6611e0e0a..e0f3dc80cfe62 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java @@ -258,7 +258,7 @@ void remove(LevelDBWriteBatch batch, Object entity) throws Exception { } long getCount(byte[] key) throws Exception { - byte[] data = db.db.get(key); + byte[] data = db.db().get(key); return data != null ? db.serializer.deserializeLong(data) : 0; } diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBWriteBatch.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBWriteBatch.java index a6ca190222931..f3de251de554f 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBWriteBatch.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBWriteBatch.java @@ -38,7 +38,7 @@ class LevelDBWriteBatch { LevelDBWriteBatch(LevelDB db) { this.db = db; - this.batch = db.db.createWriteBatch(); + this.batch = db.db().createWriteBatch(); this.deltas = new HashMap<>(); } @@ -69,7 +69,7 @@ void write(boolean sync) { } byte[] key = e.getKey().key; - byte[] data = db.db.get(key); + byte[] data = db.db().get(key); long count = data != null ? db.serializer.deserializeLong(data) : 0L; long newCount = count + delta; @@ -80,7 +80,7 @@ void write(boolean sync) { } } - db.db.write(batch, new WriteOptions().sync(sync)); + db.db().write(batch, new WriteOptions().sync(sync)); } void close() throws IOException { diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java index 5dc1c10765274..3b43307c9580c 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java @@ -62,8 +62,8 @@ public void testReopenAndVersionCheckDb() throws Exception { db = new LevelDB(dbpath); assertEquals(LevelDB.STORE_VERSION, - db.serializer.deserializeLong(db.db.get(LevelDB.STORE_VERSION_KEY))); - db.db.put(LevelDB.STORE_VERSION_KEY, db.serializer.serialize(LevelDB.STORE_VERSION + 1)); + db.serializer.deserializeLong(db.db().get(LevelDB.STORE_VERSION_KEY))); + db.db().put(LevelDB.STORE_VERSION_KEY, db.serializer.serialize(LevelDB.STORE_VERSION + 1)); db.close(); db = null; @@ -256,7 +256,7 @@ private int countKeys(Class type) throws Exception { byte[] prefix = db.getTypeInfo(type).keyPrefix(); int count = 0; - DBIterator it = db.db.iterator(); + DBIterator it = db.db().iterator(); it.seek(prefix); while (it.hasNext()) { From 45a027fd5e32421b57846236180d6012ee72e69b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 24 Mar 2017 13:19:07 -0700 Subject: [PATCH 25/30] SHS-NG M1: Use Java 8 lambdas. Also rename LevelDBIteratorSuite to work around some super weird issue with sbt. --- ...teratorSuite.java => DBIteratorSuite.java} | 72 ++++++------------- .../apache/spark/kvstore/LevelDBSuite.java | 1 - 2 files changed, 20 insertions(+), 53 deletions(-) rename common/kvstore/src/test/java/org/apache/spark/kvstore/{LevelDBIteratorSuite.java => DBIteratorSuite.java} (86%) diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java similarity index 86% rename from common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java rename to common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java index b67503b3fcbea..88c7cc08984bb 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java @@ -36,7 +36,11 @@ import org.junit.Test; import static org.junit.Assert.*; -public class LevelDBIteratorSuite { +/** + * This class should really be called "LevelDBIteratorSuite" but for some reason I don't know, + * sbt does not run the tests if it has that name. + */ +public class DBIteratorSuite { private static final int MIN_ENTRIES = 42; private static final int MAX_ENTRIES = 1024; @@ -47,63 +51,32 @@ public class LevelDBIteratorSuite { private static LevelDB db; private static File dbpath; - private abstract class BaseComparator implements Comparator { + private interface BaseComparator extends Comparator { /** * Returns a comparator that falls back to natural order if this comparator's ordering * returns equality for two elements. Used to mimic how the index sorts things internally. */ - BaseComparator fallback() { - return new BaseComparator() { - @Override - public int compare(CustomType1 t1, CustomType1 t2) { - int result = BaseComparator.this.compare(t1, t2); - if (result != 0) { - return result; - } - - return t1.key.compareTo(t2.key); + default BaseComparator fallback() { + return (t1, t2) -> { + int result = BaseComparator.this.compare(t1, t2); + if (result != 0) { + return result; } + + return t1.key.compareTo(t2.key); }; } /** Reverses the order of this comparator. */ - BaseComparator reverse() { - return new BaseComparator() { - @Override - public int compare(CustomType1 t1, CustomType1 t2) { - return -BaseComparator.this.compare(t1, t2); - } - }; + default BaseComparator reverse() { + return (t1, t2) -> -BaseComparator.this.compare(t1, t2); } } - private final BaseComparator NATURAL_ORDER = new BaseComparator() { - @Override - public int compare(CustomType1 t1, CustomType1 t2) { - return t1.key.compareTo(t2.key); - } - }; - - private final BaseComparator REF_INDEX_ORDER = new BaseComparator() { - @Override - public int compare(CustomType1 t1, CustomType1 t2) { - return t1.id.compareTo(t2.id); - } - }; - - private final BaseComparator COPY_INDEX_ORDER = new BaseComparator() { - @Override - public int compare(CustomType1 t1, CustomType1 t2) { - return t1.name.compareTo(t2.name); - } - }; - - private final BaseComparator NUMERIC_INDEX_ORDER = new BaseComparator() { - @Override - public int compare(CustomType1 t1, CustomType1 t2) { - return t1.num - t2.num; - } - }; + private final BaseComparator NATURAL_ORDER = (t1, t2) -> t1.key.compareTo(t2.key); + private final BaseComparator REF_INDEX_ORDER = (t1, t2) -> t1.id.compareTo(t2.id); + private final BaseComparator COPY_INDEX_ORDER = (t1, t2) -> t1.name.compareTo(t2.name); + private final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> t1.num - t2.num; @BeforeClass public static void setup() throws Exception { @@ -333,12 +306,7 @@ private void testIteration( Iterable expected = indexOrder; if (first != null) { final BaseComparator expectedOrder = params.ascending ? order : order.reverse(); - expected = Iterables.filter(expected, new Predicate() { - @Override - public boolean apply(CustomType1 v) { - return expectedOrder.compare(first, v) <= 0; - } - }); + expected = Iterables.filter(expected, v -> expectedOrder.compare(first, v) <= 0); } if (params.skip > 0) { diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java index 3b43307c9580c..1d33ba099f4f8 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java @@ -24,7 +24,6 @@ import java.util.NoSuchElementException; import static java.nio.charset.StandardCharsets.UTF_8; -import com.google.common.collect.Iterators; import org.apache.commons.io.FileUtils; import org.iq80.leveldb.DBIterator; import org.junit.After; From e592bf69b94c3308d194c2cb678be133931b95b5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 24 Mar 2017 17:24:08 -0700 Subject: [PATCH 26/30] SHS-NG M1: Compress values stored in LevelDB. LevelDB has built-in support for snappy compression, but it seems to be buggy in the leveldb-jni library; the compression threads don't seem to run by default, and when you enable them, there are weird issues when stopping the DB. So just do compression manually using the JRE libraries; it's probably a little slower but it saves a good chunk of disk space. --- .../spark/kvstore/KVStoreSerializer.java | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java index d9f9e2646cc14..b84ec91cf67a0 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java @@ -17,6 +17,10 @@ package org.apache.spark.kvstore; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; import static java.nio.charset.StandardCharsets.UTF_8; import com.fasterxml.jackson.databind.ObjectMapper; @@ -46,7 +50,14 @@ public final byte[] serialize(Object o) throws Exception { if (o instanceof String) { return ((String) o).getBytes(UTF_8); } else { - return mapper.writeValueAsBytes(o); + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + GZIPOutputStream out = new GZIPOutputStream(bytes); + try { + mapper.writeValue(out, o); + } finally { + out.close(); + } + return bytes.toByteArray(); } } @@ -55,7 +66,12 @@ public final T deserialize(byte[] data, Class klass) throws Exception { if (klass.equals(String.class)) { return (T) new String(data, UTF_8); } else { - return mapper.readValue(data, klass); + GZIPInputStream in = new GZIPInputStream(new ByteArrayInputStream(data)); + try { + return mapper.readValue(in, klass); + } finally { + in.close(); + } } } From 889963f2ffbcb628f9e53e7142fd37931ba09a54 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 24 Mar 2017 18:24:58 -0700 Subject: [PATCH 27/30] SHS-NG M1: Use type aliases as keys in Level DB. The type name gets repeated a lot in the store, so using it as the prefix for every key causes disk usage to grow unnecessarily. Instead, create a short alias for the type and keep a mapping of aliases to known types in a map in memory; the map is also saved to the database so it can be read later. --- .../org/apache/spark/kvstore/LevelDB.java | 62 ++++++++++++++++--- .../apache/spark/kvstore/LevelDBTypeInfo.java | 18 +----- .../spark/kvstore/LevelDBTypeInfoSuite.java | 2 +- 3 files changed, 56 insertions(+), 26 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java index 35cdbb6733a39..e423d71a335e9 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java @@ -19,7 +19,9 @@ import java.io.File; import java.io.IOException; +import java.util.HashMap; import java.util.Iterator; +import java.util.Map; import java.util.NoSuchElementException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -49,16 +51,20 @@ public class LevelDB implements KVStore { /** DB key where app metadata is stored. */ private static final byte[] METADATA_KEY = "__meta__".getBytes(UTF_8); + /** DB key where type aliases are stored. */ + private static final byte[] TYPE_ALIASES_KEY = "__types__".getBytes(UTF_8); + final AtomicReference _db; final KVStoreSerializer serializer; + private final ConcurrentMap typeAliases; private final ConcurrentMap, LevelDBTypeInfo> types; - public LevelDB(File path) throws IOException { + public LevelDB(File path) throws Exception { this(path, new KVStoreSerializer()); } - public LevelDB(File path, KVStoreSerializer serializer) throws IOException { + public LevelDB(File path, KVStoreSerializer serializer) throws Exception { this.serializer = serializer; this.types = new ConcurrentHashMap<>(); @@ -75,6 +81,14 @@ public LevelDB(File path, KVStoreSerializer serializer) throws IOException { } else { db().put(STORE_VERSION_KEY, serializer.serialize(STORE_VERSION)); } + + Map aliases; + try { + aliases = get(TYPE_ALIASES_KEY, TypeAliases.class).aliases; + } catch (NoSuchElementException e) { + aliases = new HashMap<>(); + } + typeAliases = new ConcurrentHashMap<>(aliases); } @Override @@ -218,15 +232,15 @@ public void close() throws IOException { /** Returns metadata about indices for the given type. */ LevelDBTypeInfo getTypeInfo(Class type) throws Exception { - LevelDBTypeInfo idx = types.get(type); - if (idx == null) { - LevelDBTypeInfo tmp = new LevelDBTypeInfo<>(this, type); - idx = types.putIfAbsent(type, tmp); - if (idx == null) { - idx = tmp; + LevelDBTypeInfo ti = types.get(type); + if (ti == null) { + LevelDBTypeInfo tmp = new LevelDBTypeInfo<>(this, type, getTypeAlias(type)); + ti = types.putIfAbsent(type, tmp); + if (ti == null) { + ti = tmp; } } - return idx; + return ti; } /** @@ -249,4 +263,34 @@ private void removeInstance(LevelDBTypeInfo ti, LevelDBWriteBatch batch, Obje } } + private byte[] getTypeAlias(Class klass) throws Exception { + byte[] alias = typeAliases.get(klass.getName()); + if (alias == null) { + synchronized (typeAliases) { + byte[] tmp = String.valueOf(typeAliases.size()).getBytes(UTF_8); + alias = typeAliases.putIfAbsent(klass.getName(), tmp); + if (alias == null) { + alias = tmp; + put(TYPE_ALIASES_KEY, new TypeAliases(typeAliases)); + } + } + } + return alias; + } + + /** Needs to be public for Jackson. */ + public static class TypeAliases { + + public Map aliases; + + TypeAliases(Map aliases) { + this.aliases = aliases; + } + + TypeAliases() { + this(null); + } + + } + } diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java index e0f3dc80cfe62..b9bb1959f5ae0 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java @@ -62,7 +62,7 @@ class LevelDBTypeInfo { private final Map indices; private final byte[] typePrefix; - LevelDBTypeInfo(LevelDB db, Class type) throws Exception { + LevelDBTypeInfo(LevelDB db, Class type, byte[] alias) throws Exception { this.db = db; this.type = type; this.indices = new HashMap<>(); @@ -88,21 +88,7 @@ class LevelDBTypeInfo { ByteArrayOutputStream typePrefix = new ByteArrayOutputStream(); typePrefix.write(utf8(ENTRY_PREFIX)); - - // Change fully-qualified class names to make keys more spread out by placing the - // class name first, and the package name afterwards. - String[] components = type.getName().split("\\."); - typePrefix.write(utf8(components[components.length - 1])); - if (components.length > 1) { - typePrefix.write(utf8("/")); - } - for (int i = 0; i < components.length - 1; i++) { - typePrefix.write(utf8(components[i])); - if (i < components.length - 2) { - typePrefix.write(utf8(".")); - } - } - typePrefix.write(KEY_SEPARATOR); + typePrefix.write(alias); this.typePrefix = typePrefix.toByteArray(); } diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java index 7ee0b24552219..4cddab1acac08 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java @@ -136,7 +136,7 @@ public void testArrayIndices() throws Exception { } private LevelDBTypeInfo newTypeInfo(Class type) throws Exception { - return new LevelDBTypeInfo<>(null, type); + return new LevelDBTypeInfo<>(null, type, type.getName().getBytes(UTF_8)); } private void assertBefore(byte[] key1, byte[] key2) { From 84ab160699ef8dad4df1fa4cbba29deec7c92c06 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 3 Apr 2017 11:35:50 -0700 Subject: [PATCH 28/30] SHS-NG M1: Separate index introspection from storage. The new KVTypeInfo class can help with writing different implementations of KVStore without duplicating logic from LevelDBTypeInfo. --- .../org/apache/spark/kvstore/KVTypeInfo.java | 135 ++++++++++++++++++ .../org/apache/spark/kvstore/LevelDB.java | 18 +-- .../apache/spark/kvstore/LevelDBIterator.java | 4 +- .../apache/spark/kvstore/LevelDBTypeInfo.java | 87 ++--------- .../apache/spark/kvstore/LevelDBSuite.java | 2 +- .../spark/kvstore/LevelDBTypeInfoSuite.java | 22 +-- 6 files changed, 168 insertions(+), 100 deletions(-) create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java new file mode 100644 index 0000000000000..1a0bee958d482 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; + +import com.google.common.base.Preconditions; + +/** + * Wrapper around types managed in a KVStore, providing easy access to their indexed fields. + */ +public class KVTypeInfo { + + private final Class type; + private final Collection indices; + private final Map accessors; + + public KVTypeInfo(Class type) throws Exception { + this.type = type; + this.indices = new ArrayList<>(); + this.accessors = new HashMap<>(); + + for (Field f : type.getFields()) { + KVIndex idx = f.getAnnotation(KVIndex.class); + if (idx != null) { + checkIndex(idx); + indices.add(idx); + accessors.put(idx.value(), new FieldAccessor(f)); + } + } + + for (Method m : type.getMethods()) { + KVIndex idx = m.getAnnotation(KVIndex.class); + if (idx != null) { + checkIndex(idx); + Preconditions.checkArgument(m.getParameterTypes().length == 0, + "Annotated method %s::%s should not have any parameters.", type.getName(), m.getName()); + indices.add(idx); + accessors.put(idx.value(), new MethodAccessor(m)); + } + } + + Preconditions.checkArgument(accessors.containsKey(KVIndex.NATURAL_INDEX_NAME), + "No natural index defined for type %s.", type.getName()); + } + + private void checkIndex(KVIndex idx) { + Preconditions.checkArgument(idx.value() != null && !idx.value().isEmpty(), + "No name provided for index in type %s.", type.getName()); + Preconditions.checkArgument( + !idx.value().startsWith("_") || idx.value().equals(KVIndex.NATURAL_INDEX_NAME), + "Index name %s (in type %s) is not allowed.", idx.value(), type.getName()); + Preconditions.checkArgument(!indices.contains(idx.value()), + "Duplicate index %s for type %s.", idx.value(), type.getName()); + } + + public Class getType() { + return type; + } + + public Object getIndexValue(String indexName, Object instance) throws Exception { + return getAccessor(indexName).get(instance); + } + + public Stream indices() { + return indices.stream(); + } + + Accessor getAccessor(String indexName) { + Accessor a = accessors.get(indexName); + Preconditions.checkArgument(a != null, "No index %s.", indexName); + return a; + } + + /** + * Abstracts the difference between invoking a Field and a Method. + */ + interface Accessor { + + Object get(Object instance) throws Exception; + + } + + private class FieldAccessor implements Accessor { + + private final Field field; + + FieldAccessor(Field field) { + this.field = field; + } + + @Override + public Object get(Object instance) throws Exception { + return field.get(instance); + } + + } + + private class MethodAccessor implements Accessor { + + private final Method method; + + MethodAccessor(Method method) { + this.method = method; + } + + @Override + public Object get(Object instance) throws Exception { + return method.invoke(instance); + } + + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java index e423d71a335e9..337b9541e2879 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java @@ -148,7 +148,7 @@ public void write(Object value) throws Exception { public void write(Object value, boolean sync) throws Exception { Preconditions.checkArgument(value != null, "Null values are not allowed."); - LevelDBTypeInfo ti = getTypeInfo(value.getClass()); + LevelDBTypeInfo ti = getTypeInfo(value.getClass()); LevelDBWriteBatch batch = new LevelDBWriteBatch(this); try { @@ -160,7 +160,7 @@ public void write(Object value, boolean sync) throws Exception { } catch (NoSuchElementException e) { // Ignore. No previous value. } - for (LevelDBTypeInfo.Index idx : ti.indices()) { + for (LevelDBTypeInfo.Index idx : ti.indices()) { idx.add(batch, value, data); } batch.write(sync); @@ -179,7 +179,7 @@ public void delete(Class type, Object naturalKey, boolean sync) throws Except Preconditions.checkArgument(naturalKey != null, "Null keys are not allowed."); LevelDBWriteBatch batch = new LevelDBWriteBatch(this); try { - LevelDBTypeInfo ti = getTypeInfo(type); + LevelDBTypeInfo ti = getTypeInfo(type); byte[] key = ti.naturalIndex().start(naturalKey); byte[] data = db().get(key); if (data != null) { @@ -210,7 +210,7 @@ public Iterator iterator() { @Override public long count(Class type) throws Exception { - LevelDBTypeInfo.Index idx = getTypeInfo(type).naturalIndex(); + LevelDBTypeInfo.Index idx = getTypeInfo(type).naturalIndex(); return idx.getCount(idx.end()); } @@ -231,10 +231,10 @@ public void close() throws IOException { } /** Returns metadata about indices for the given type. */ - LevelDBTypeInfo getTypeInfo(Class type) throws Exception { - LevelDBTypeInfo ti = types.get(type); + LevelDBTypeInfo getTypeInfo(Class type) throws Exception { + LevelDBTypeInfo ti = types.get(type); if (ti == null) { - LevelDBTypeInfo tmp = new LevelDBTypeInfo<>(this, type, getTypeAlias(type)); + LevelDBTypeInfo tmp = new LevelDBTypeInfo(this, type, getTypeAlias(type)); ti = types.putIfAbsent(type, tmp); if (ti == null) { ti = tmp; @@ -256,9 +256,9 @@ DB db() { return _db; } - private void removeInstance(LevelDBTypeInfo ti, LevelDBWriteBatch batch, Object instance) + private void removeInstance(LevelDBTypeInfo ti, LevelDBWriteBatch batch, Object instance) throws Exception { - for (LevelDBTypeInfo.Index idx : ti.indices()) { + for (LevelDBTypeInfo.Index idx : ti.indices()) { idx.remove(batch, instance); } } diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java index b777ff7bafc02..3b00c171740db 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java @@ -34,8 +34,8 @@ class LevelDBIterator implements KVStoreIterator { private final boolean ascending; private final DBIterator it; private final Class type; - private final LevelDBTypeInfo ti; - private final LevelDBTypeInfo.Index index; + private final LevelDBTypeInfo ti; + private final LevelDBTypeInfo.Index index; private final byte[] indexKeyPrefix; private final byte[] end; diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java index b9bb1959f5ae0..826e6cf068fd9 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java @@ -35,7 +35,7 @@ * Holds metadata about app-specific types stored in LevelDB. Serves as a cache for data collected * via reflection, to make it cheaper to access it multiple times. */ -class LevelDBTypeInfo { +class LevelDBTypeInfo { static final String ENTRY_PREFIX = "+"; static final String END_MARKER = "-"; @@ -58,33 +58,19 @@ class LevelDBTypeInfo { static final int SHORT_ENCODED_LEN = String.valueOf(Short.MAX_VALUE).length() + 1; private final LevelDB db; - private final Class type; + private final Class type; private final Map indices; private final byte[] typePrefix; - LevelDBTypeInfo(LevelDB db, Class type, byte[] alias) throws Exception { + LevelDBTypeInfo(LevelDB db, Class type, byte[] alias) throws Exception { this.db = db; this.type = type; this.indices = new HashMap<>(); - for (Field f : type.getFields()) { - KVIndex idx = f.getAnnotation(KVIndex.class); - if (idx != null) { - register(idx, new FieldAccessor(f)); - } - } - - for (Method m : type.getMethods()) { - KVIndex idx = m.getAnnotation(KVIndex.class); - if (idx != null) { - Preconditions.checkArgument(m.getParameterTypes().length == 0, - "Annotated method %s::%s should not have any parameters.", type.getName(), m.getName()); - register(idx, new MethodAccessor(m)); - } - } - - Preconditions.checkArgument(indices.get(KVIndex.NATURAL_INDEX_NAME) != null, - "No natural index defined for type %s.", type.getName()); + KVTypeInfo ti = new KVTypeInfo(type); + ti.indices().forEach(idx -> { + indices.put(idx.value(), new Index(idx.value(), idx.copy(), ti.getAccessor(idx.value()))); + }); ByteArrayOutputStream typePrefix = new ByteArrayOutputStream(); typePrefix.write(utf8(ENTRY_PREFIX)); @@ -92,18 +78,7 @@ class LevelDBTypeInfo { this.typePrefix = typePrefix.toByteArray(); } - private void register(KVIndex idx, Accessor accessor) { - Preconditions.checkArgument(idx.value() != null && !idx.value().isEmpty(), - "No name provided for index in type %s.", type.getName()); - Preconditions.checkArgument( - !idx.value().startsWith("_") || idx.value().equals(KVIndex.NATURAL_INDEX_NAME), - "Index name %s (in type %s) is not allowed.", idx.value(), type.getName()); - Preconditions.checkArgument(indices.get(idx.value()) == null, - "Duplicate index %s for type %s.", idx.value(), type.getName()); - indices.put(idx.value(), new Index(idx.value(), idx.copy(), accessor)); - } - - Class type() { + Class type() { return type; } @@ -164,11 +139,9 @@ class Index { private final boolean copy; private final boolean isNatural; private final String name; + private final KVTypeInfo.Accessor accessor; - @VisibleForTesting - final Accessor accessor; - - private Index(String name, boolean copy, Accessor accessor) { + private Index(String name, boolean copy, KVTypeInfo.Accessor accessor) { this.name = name; this.isNatural = name.equals(KVIndex.NATURAL_INDEX_NAME); this.copy = isNatural || copy; @@ -320,44 +293,4 @@ String toKey(Object value) { } - /** - * Abstracts the difference between invoking a Field and a Method. - */ - @VisibleForTesting - interface Accessor { - - Object get(Object instance) throws Exception; - - } - - private class FieldAccessor implements Accessor { - - private final Field field; - - FieldAccessor(Field field) { - this.field = field; - } - - @Override - public Object get(Object instance) throws Exception { - return field.get(instance); - } - - } - - private class MethodAccessor implements Accessor { - - private final Method method; - - MethodAccessor(Method method) { - this.method = method; - } - - @Override - public Object get(Object instance) throws Exception { - return method.invoke(instance); - } - - } - } diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java index 1d33ba099f4f8..c3baf76589286 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java @@ -221,7 +221,7 @@ public void testMetadata() throws Exception { } private long countIndexEntries(Class type, String index, Object value) throws Exception { - LevelDBTypeInfo.Index idx = db.getTypeInfo(type).index(index); + LevelDBTypeInfo.Index idx = db.getTypeInfo(type).index(index); return idx.getCount(idx.end()); } diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java index 4cddab1acac08..cf69f32dfb354 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java @@ -26,8 +26,8 @@ public class LevelDBTypeInfoSuite { @Test public void testIndexAnnotation() throws Exception { - LevelDBTypeInfo ti = newTypeInfo(CustomType1.class); - assertEquals(4, ti.indices().size()); + KVTypeInfo ti = new KVTypeInfo(CustomType1.class); + assertEquals(4, ti.indices().count()); CustomType1 t1 = new CustomType1(); t1.key = "key"; @@ -35,10 +35,10 @@ public void testIndexAnnotation() throws Exception { t1.name = "name"; t1.num = 42; - assertEquals(t1.key, ti.naturalIndex().accessor.get(t1)); - assertEquals(t1.id, ti.index("id").accessor.get(t1)); - assertEquals(t1.name, ti.index("name").accessor.get(t1)); - assertEquals(t1.num, ti.index("int").accessor.get(t1)); + assertEquals(t1.key, ti.getIndexValue(KVIndex.NATURAL_INDEX_NAME, t1)); + assertEquals(t1.id, ti.getIndexValue("id", t1)); + assertEquals(t1.name, ti.getIndexValue("name", t1)); + assertEquals(t1.num, ti.getIndexValue("int", t1)); } @Test(expected = IllegalArgumentException.class) @@ -68,7 +68,7 @@ public void testIllegalIndexMethod() throws Exception { @Test public void testKeyClashes() throws Exception { - LevelDBTypeInfo ti = newTypeInfo(CustomType1.class); + LevelDBTypeInfo ti = newTypeInfo(CustomType1.class); CustomType1 t1 = new CustomType1(); t1.key = "key1"; @@ -90,7 +90,7 @@ public void testKeyClashes() throws Exception { @Test public void testNumEncoding() throws Exception { - LevelDBTypeInfo.Index idx = newTypeInfo(CustomType1.class).indices().iterator().next(); + LevelDBTypeInfo.Index idx = newTypeInfo(CustomType1.class).indices().iterator().next(); assertBefore(idx.toKey(1), idx.toKey(2)); assertBefore(idx.toKey(-1), idx.toKey(2)); @@ -125,7 +125,7 @@ public void testNumEncoding() throws Exception { @Test public void testArrayIndices() throws Exception { - LevelDBTypeInfo.Index idx = newTypeInfo(CustomType1.class).indices().iterator().next(); + LevelDBTypeInfo.Index idx = newTypeInfo(CustomType1.class).indices().iterator().next(); assertBefore(idx.toKey(new String[] { "str1" }), idx.toKey(new String[] { "str2" })); assertBefore(idx.toKey(new String[] { "str1", "str2" }), @@ -135,8 +135,8 @@ public void testArrayIndices() throws Exception { assertBefore(idx.toKey(new int[] { 1, 2 }), idx.toKey(new int[] { 1, 3 })); } - private LevelDBTypeInfo newTypeInfo(Class type) throws Exception { - return new LevelDBTypeInfo<>(null, type, type.getName().getBytes(UTF_8)); + private LevelDBTypeInfo newTypeInfo(Class type) throws Exception { + return new LevelDBTypeInfo(null, type, type.getName().getBytes(UTF_8)); } private void assertBefore(byte[] key1, byte[] key2) { From 7b870212e80e70b8c3f3eb4279e3bb9ec0125d2d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 26 Apr 2017 11:54:33 -0700 Subject: [PATCH 29/30] SHS-NG M1: Remove unused methods from KVStore. Turns out I ended up not using the raw storage methods in KVStore, so this change removes them to simplify the API and save some code. --- .../org/apache/spark/kvstore/KVStore.java | 44 +++++-------------- .../org/apache/spark/kvstore/LevelDB.java | 16 +------ .../apache/spark/kvstore/LevelDBIterator.java | 18 -------- .../apache/spark/kvstore/LevelDBSuite.java | 43 ------------------ 4 files changed, 14 insertions(+), 107 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java index 31d4e6fefc289..667fccccd5428 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java @@ -25,21 +25,21 @@ * Abstraction for a local key/value store for storing app data. * *

- * Use {@link KVStoreBuilder} to create an instance. There are two main features provided by the - * implementations of this interface: + * There are two main features provided by the implementations of this interface: *

* - *
    - *
  • serialization: this feature is not optional; data will be serialized to and deserialized - * from the underlying data store using a {@link KVStoreSerializer}, which can be customized by - * the application. The serializer is based on Jackson, so it supports all the Jackson annotations - * for controlling the serialization of app-defined types.
  • + *

    Serialization

    * - *
  • key management: by using {@link #read(Class, Object)} and {@link #write(Class, Object)}, - * applications can leave key management to the implementation. For applications that want to - * manage their own keys, the {@link #get(byte[], Class)} and {@link #set(byte[], Object)} methods - * are available.
  • - *
+ *

+ * Data will be serialized to and deserialized from the underlying data store using a + * {@link KVStoreSerializer}, which can be customized by the application. The serializer is + * based on Jackson, so it supports all the Jackson annotations for controlling the serialization + * of app-defined types. + *

+ * + *

+ * Data is also automatically compressed to save disk space. + *

* *

Automatic Key Management

* @@ -78,26 +78,6 @@ public interface KVStore extends Closeable { */ void setMetadata(Object value) throws Exception; - /** - * Returns the value of a specific key, deserialized to the given type. - */ - T get(byte[] key, Class klass) throws Exception; - - /** - * Write a single key directly to the store, atomically. - */ - void put(byte[] key, Object value) throws Exception; - - /** - * Removes a key from the store. - */ - void delete(byte[] key) throws Exception; - - /** - * Returns an iterator that will only list values with keys starting with the given prefix. - */ - KVStoreIterator iterator(byte[] prefix, Class klass) throws Exception; - /** * Read a specific instance of an object. */ diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java index 337b9541e2879..b40c7950d1d11 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java @@ -109,8 +109,7 @@ public void setMetadata(Object value) throws Exception { } } - @Override - public T get(byte[] key, Class klass) throws Exception { + T get(byte[] key, Class klass) throws Exception { byte[] data = db().get(key); if (data == null) { throw new NoSuchElementException(new String(key, UTF_8)); @@ -118,22 +117,11 @@ public T get(byte[] key, Class klass) throws Exception { return serializer.deserialize(data, klass); } - @Override - public void put(byte[] key, Object value) throws Exception { + private void put(byte[] key, Object value) throws Exception { Preconditions.checkArgument(value != null, "Null values are not allowed."); db().put(key, serializer.serialize(value)); } - @Override - public void delete(byte[] key) throws Exception { - db().delete(key); - } - - @Override - public KVStoreIterator iterator(byte[] prefix, Class klass) throws Exception { - throw new UnsupportedOperationException(); - } - @Override public T read(Class klass, Object naturalKey) throws Exception { Preconditions.checkArgument(naturalKey != null, "Null keys are not allowed."); diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java index 3b00c171740db..f65152a9fc36a 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java @@ -43,24 +43,6 @@ class LevelDBIterator implements KVStoreIterator { private T next; private boolean closed; - /** - * Creates a simple iterator over db keys. - */ - LevelDBIterator(LevelDB db, byte[] keyPrefix, Class type) throws Exception { - this.db = db; - this.ascending = true; - this.type = type; - this.ti = null; - this.index = null; - this.it = db.db().iterator(); - this.indexKeyPrefix = keyPrefix; - this.end = null; - it.seek(keyPrefix); - } - - /** - * Creates an iterator for indexed types (i.e., those whose keys are managed by the library). - */ LevelDBIterator(LevelDB db, KVStoreView params) throws Exception { this.db = db; this.ascending = params.ascending; diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java index c3baf76589286..1f88aae0be2aa 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java @@ -74,29 +74,6 @@ public void testReopenAndVersionCheckDb() throws Exception { } } - @Test - public void testStringWriteReadDelete() throws Exception { - String string = "testString"; - byte[] key = string.getBytes(UTF_8); - testReadWriteDelete(key, string); - } - - @Test - public void testIntWriteReadDelete() throws Exception { - int value = 42; - byte[] key = "key".getBytes(UTF_8); - testReadWriteDelete(key, value); - } - - @Test - public void testSimpleTypeWriteReadDelete() throws Exception { - byte[] key = "testKey".getBytes(UTF_8); - CustomType1 t = new CustomType1(); - t.id = "id"; - t.name = "name"; - testReadWriteDelete(key, t); - } - @Test public void testObjectWriteReadDelete() throws Exception { CustomType1 t = new CustomType1(); @@ -268,26 +245,6 @@ private int countKeys(Class type) throws Exception { return count; } - private void testReadWriteDelete(byte[] key, T value) throws Exception { - try { - db.get(key, value.getClass()); - fail("Expected exception for non-existent key."); - } catch (NoSuchElementException nsee) { - // Expected. - } - - db.put(key, value); - assertEquals(value, db.get(key, value.getClass())); - - db.delete(key); - try { - db.get(key, value.getClass()); - fail("Expected exception for deleted key."); - } catch (NoSuchElementException nsee) { - // Expected. - } - } - public static class IntKeyType { @KVIndex From 5197c218525db2ad849dfe77d83dddf2311bb5ad Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 5 May 2017 14:36:00 -0700 Subject: [PATCH 30/30] SHS-NG M1: Add "max" and "last" to kvstore iterators. This makes it easier for callers to control the end of iteration, making it easier to write Scala code that automatically closes underlying iterator resources. Before, code had to use Scala's "takeWhile", convert the result to a list, and manually close the iterators; with these two parameters, that can be avoided in a bunch of cases, with iterators auto-closing when the last element is reached. --- .../org/apache/spark/kvstore/KVStoreView.java | 22 ++- .../apache/spark/kvstore/LevelDBIterator.java | 50 ++++++- .../apache/spark/kvstore/DBIteratorSuite.java | 135 +++++++++++++----- 3 files changed, 167 insertions(+), 40 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java index a68c37942dee4..65edc0149b438 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java @@ -43,7 +43,9 @@ public abstract class KVStoreView implements Iterable { boolean ascending = true; String index = KVIndex.NATURAL_INDEX_NAME; Object first = null; + Object last = null; long skip = 0L; + long max = Long.MAX_VALUE; public KVStoreView(Class type) { this.type = type; @@ -74,7 +76,25 @@ public KVStoreView first(Object value) { } /** - * Skips a number of elements in the resulting iterator. + * Stops iteration at the given value of the chosen index. + */ + public KVStoreView last(Object value) { + this.last = value; + return this; + } + + /** + * Stops iteration after a number of elements has been retrieved. + */ + public KVStoreView max(long max) { + Preconditions.checkArgument(max > 0L, "max must be positive."); + this.max = max; + return this; + } + + /** + * Skips a number of elements at the start of iteration. Skipped elements are not accounted + * when using {@link #max(long)}. */ public KVStoreView skip(long n) { this.skip = n; diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java index f65152a9fc36a..73ca8afc9eb28 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java @@ -38,10 +38,12 @@ class LevelDBIterator implements KVStoreIterator { private final LevelDBTypeInfo.Index index; private final byte[] indexKeyPrefix; private final byte[] end; + private final long max; private boolean checkedNext; private T next; private boolean closed; + private long count; LevelDBIterator(LevelDB db, KVStoreView params) throws Exception { this.db = db; @@ -51,6 +53,7 @@ class LevelDBIterator implements KVStoreIterator { this.ti = db.getTypeInfo(type); this.index = ti.index(params.index); this.indexKeyPrefix = index.keyPrefix(); + this.max = params.max; byte[] firstKey; if (params.first != null) { @@ -66,14 +69,27 @@ class LevelDBIterator implements KVStoreIterator { } it.seek(firstKey); + byte[] end = null; if (ascending) { - this.end = index.end(); + end = params.last != null ? index.end(params.last) : index.end(); } else { - this.end = null; + if (params.last != null) { + end = index.start(params.last); + } if (it.hasNext()) { - it.next(); + // When descending, the caller may have set up the start of iteration at a non-existant + // entry that is guaranteed to be after the desired entry. For example, if you have a + // compound key (a, b) where b is a, integer, you may seek to the end of the elements that + // have the same "a" value by specifying Integer.MAX_VALUE for "b", and that value may not + // exist in the database. So need to check here whether the next value actually belongs to + // the set being returned by the iterator before advancing. + byte[] nextKey = it.peekNext().getKey(); + if (compare(nextKey, indexKeyPrefix) <= 0) { + it.next(); + } } } + this.end = end; if (params.skip > 0) { skip(params.skip); @@ -147,6 +163,10 @@ public synchronized void close() throws IOException { } private T loadNext() { + if (count >= max) { + return null; + } + try { while (true) { boolean hasNext = ascending ? it.hasNext() : it.hasPrev(); @@ -173,11 +193,16 @@ private T loadNext() { return null; } - // If there's a known end key and it's found, stop. - if (end != null && Arrays.equals(nextKey, end)) { - return null; + // If there's a known end key and iteration has gone past it, stop. + if (end != null) { + int comp = compare(nextKey, end) * (ascending ? 1 : -1); + if (comp > 0) { + return null; + } } + count++; + // Next element is part of the iteration, return it. if (index == null || index.isCopy()) { return db.serializer.deserialize(nextEntry.getValue(), type); @@ -228,4 +253,17 @@ private byte[] stitch(byte[]... comps) { return dest; } + private int compare(byte[] a, byte[] b) { + int diff = 0; + int minLen = Math.min(a.length, b.length); + for (int i = 0; i < minLen; i++) { + diff += (a[i] - b[i]); + if (diff != 0) { + return diff; + } + } + + return a.length - b.length; + } + } diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java index 88c7cc08984bb..6c4469e1ed5d0 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java @@ -152,111 +152,170 @@ public static void cleanup() throws Exception { @Test public void naturalIndex() throws Exception { - testIteration(NATURAL_ORDER, view(), null); + testIteration(NATURAL_ORDER, view(), null, null); } @Test public void refIndex() throws Exception { - testIteration(REF_INDEX_ORDER, view().index("id"), null); + testIteration(REF_INDEX_ORDER, view().index("id"), null, null); } @Test public void copyIndex() throws Exception { - testIteration(COPY_INDEX_ORDER, view().index("name"), null); + testIteration(COPY_INDEX_ORDER, view().index("name"), null, null); } @Test public void numericIndex() throws Exception { - testIteration(NUMERIC_INDEX_ORDER, view().index("int"), null); + testIteration(NUMERIC_INDEX_ORDER, view().index("int"), null, null); } @Test public void naturalIndexDescending() throws Exception { - testIteration(NATURAL_ORDER, view().reverse(), null); + testIteration(NATURAL_ORDER, view().reverse(), null, null); } @Test public void refIndexDescending() throws Exception { - testIteration(REF_INDEX_ORDER, view().index("id").reverse(), null); + testIteration(REF_INDEX_ORDER, view().index("id").reverse(), null, null); } @Test public void copyIndexDescending() throws Exception { - testIteration(COPY_INDEX_ORDER, view().index("name").reverse(), null); + testIteration(COPY_INDEX_ORDER, view().index("name").reverse(), null, null); } @Test public void numericIndexDescending() throws Exception { - testIteration(NUMERIC_INDEX_ORDER, view().index("int").reverse(), null); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").reverse(), null, null); } @Test public void naturalIndexWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(NATURAL_ORDER, view().first(first.key), first); + CustomType1 first = pickLimit(); + testIteration(NATURAL_ORDER, view().first(first.key), first, null); } @Test public void refIndexWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(REF_INDEX_ORDER, view().index("id").first(first.id), first); + CustomType1 first = pickLimit(); + testIteration(REF_INDEX_ORDER, view().index("id").first(first.id), first, null); } @Test public void copyIndexWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(COPY_INDEX_ORDER, view().index("name").first(first.name), first); + CustomType1 first = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().index("name").first(first.name), first, null); } @Test public void numericIndexWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(NUMERIC_INDEX_ORDER, view().index("int").first(first.num), first); + CustomType1 first = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").first(first.num), first, null); } @Test public void naturalIndexDescendingWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(NATURAL_ORDER, view().reverse().first(first.key), first); + CustomType1 first = pickLimit(); + testIteration(NATURAL_ORDER, view().reverse().first(first.key), first, null); } @Test public void refIndexDescendingWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(REF_INDEX_ORDER, view().reverse().index("id").first(first.id), first); + CustomType1 first = pickLimit(); + testIteration(REF_INDEX_ORDER, view().reverse().index("id").first(first.id), first, null); } @Test public void copyIndexDescendingWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(COPY_INDEX_ORDER, view().reverse().index("name").first(first.name), - first); + CustomType1 first = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().reverse().index("name").first(first.name), first, null); } @Test public void numericIndexDescendingWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").first(first.num), - first); + CustomType1 first = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").first(first.num), first, null); } @Test public void naturalIndexWithSkip() throws Exception { - testIteration(NATURAL_ORDER, view().skip(RND.nextInt(allEntries.size() / 2)), null); + testIteration(NATURAL_ORDER, view().skip(RND.nextInt(allEntries.size() / 2)), null, null); } @Test public void refIndexWithSkip() throws Exception { testIteration(REF_INDEX_ORDER, view().index("id").skip(RND.nextInt(allEntries.size() / 2)), - null); + null, null); } @Test public void copyIndexWithSkip() throws Exception { testIteration(COPY_INDEX_ORDER, view().index("name").skip(RND.nextInt(allEntries.size() / 2)), - null); + null, null); } + @Test + public void naturalIndexWithMax() throws Exception { + testIteration(NATURAL_ORDER, view().max(RND.nextInt(allEntries.size() / 2)), null, null); + } + + @Test + public void copyIndexWithMax() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").max(RND.nextInt(allEntries.size() / 2)), + null, null); + } + + @Test + public void naturalIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NATURAL_ORDER, view().last(last.key), null, last); + } + + @Test + public void refIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(REF_INDEX_ORDER, view().index("id").last(last.id), null, last); + } + + @Test + public void copyIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().index("name").last(last.name), null, last); + } + + @Test + public void numericIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").last(last.num), null, last); + } + + @Test + public void naturalIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NATURAL_ORDER, view().reverse().last(last.key), null, last); + } + + @Test + public void refIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(REF_INDEX_ORDER, view().reverse().index("id").last(last.id), null, last); + } + + @Test + public void copyIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().reverse().index("name").last(last.name), + null, last); + } + + @Test + public void numericIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").last(last.num), + null, last); + } + @Test public void testRefWithIntNaturalKey() throws Exception { LevelDBSuite.IntKeyType i = new LevelDBSuite.IntKeyType(); @@ -272,8 +331,8 @@ public void testRefWithIntNaturalKey() throws Exception { } } - private CustomType1 pickFirst() { - // Picks a first element that has clashes with other elements in the given index. + private CustomType1 pickLimit() { + // Picks an element that has clashes with other elements in the given index. return clashingEntries.get(RND.nextInt(clashingEntries.size())); } @@ -297,22 +356,32 @@ private > int compareWithFallback( private void testIteration( final BaseComparator order, final KVStoreView params, - final CustomType1 first) throws Exception { + final CustomType1 first, + final CustomType1 last) throws Exception { List indexOrder = sortBy(order.fallback()); if (!params.ascending) { indexOrder = Lists.reverse(indexOrder); } Iterable expected = indexOrder; + BaseComparator expectedOrder = params.ascending ? order : order.reverse(); + if (first != null) { - final BaseComparator expectedOrder = params.ascending ? order : order.reverse(); expected = Iterables.filter(expected, v -> expectedOrder.compare(first, v) <= 0); } + if (last != null) { + expected = Iterables.filter(expected, v -> expectedOrder.compare(v, last) <= 0); + } + if (params.skip > 0) { expected = Iterables.skip(expected, (int) params.skip); } + if (params.max != Long.MAX_VALUE) { + expected = Iterables.limit(expected, (int) params.max); + } + List actual = collect(params); compareLists(expected, actual); }