From 0319fa5c0527f68f3a3862afbbfd1b708f1d307d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 29 Jan 2018 16:32:02 +0900 Subject: [PATCH] Fix test failure and few minor clean up for tests --- python/pyspark/context.py | 2 + python/pyspark/sql/session.py | 1 - python/pyspark/sql/tests.py | 70 ++++++++++++++--------------------- 3 files changed, 29 insertions(+), 44 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 24905f1c97b21..95fb510b9accd 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -400,6 +400,8 @@ def stop(self): """ if getattr(self, "_jsc", None): try: + # We should clean the default session up. See SPARK-23228. + self._jvm.SparkSession.clearDefaultSession() self._jsc.stop() except Py4JError: # Case: SPARK-18523 diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 0d7d122ee1e30..5ae468e2d05a2 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -764,7 +764,6 @@ def stop(self): """Stop the underlying :class:`SparkContext`. 
""" self._sc.stop() - self._jvm.SparkSession.clearDefaultSession() SparkSession._instantiatedSession = None @since(2.0) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7e650033de31c..78ba2ca3adfa9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -69,7 +69,7 @@ from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings from pyspark.sql.types import _merge_type -from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException @@ -204,48 +204,6 @@ def assertPandasEqual(self, expected, result): self.assertTrue(expected.equals(result), msg=msg) -class PySparkSessionTests(unittest.TestCase): - - def test_set_jvm_default_session(self): - spark = None - sc = None - try: - sc = SparkContext('local[4]', "test_spark_session") - spark = SparkSession(sc) - self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) - finally: - if spark is not None: - spark.stop() - self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty()) - spark = None - sc = None - - if sc is not None: - sc.stop() - sc = None - - def test_jvm_default_session_already_set(self): - spark = None - sc = None - try: - sc = SparkContext('local[4]', "test_spark_session") - jsession = sc._jvm.SparkSession(sc._jsc.sc()) - sc._jvm.SparkSession.setDefaultSession(jsession) - - spark = SparkSession(sc, jsession) - self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) - self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get())) - finally: - if spark is not None: - spark.stop() - spark = None - sc 
= None - - if sc is not None: - sc.stop() - sc = None - - class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): @@ -2954,6 +2912,32 @@ def test_sparksession_with_stopped_sparkcontext(self): sc.stop() + +class SparkSessionTests(PySparkTestCase): + + # This test is separate because it's closely related to session's start and stop. + # See SPARK-23228. + def test_set_jvm_default_session(self): + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + finally: + spark.stop() + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty()) + + def test_jvm_default_session_already_set(self): + # Here, we assume there is the default session already set in JVM. + jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc()) + self.sc._jvm.SparkSession.setDefaultSession(jsession) + + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + # The session should be the same as the existing one. + self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get())) + finally: + spark.stop() + + class UDFInitializationTests(unittest.TestCase): def tearDown(self): if SparkSession._instantiatedSession is not None: