Skip to content

Commit

Permalink
Fix test failure and few minor clean up for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Jan 29, 2018
1 parent eec4386 commit 0319fa5
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 44 deletions.
2 changes: 2 additions & 0 deletions python/pyspark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,8 @@ def stop(self):
"""
if getattr(self, "_jsc", None):
try:
# We should clean the default session up. See SPARK-23228.
self._jvm.SparkSession.clearDefaultSession()
self._jsc.stop()
except Py4JError:
# Case: SPARK-18523
Expand Down
1 change: 0 additions & 1 deletion python/pyspark/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,6 @@ def stop(self):
"""Stop the underlying :class:`SparkContext`.
"""
self._sc.stop()
self._jvm.SparkSession.clearDefaultSession()
SparkSession._instantiatedSession = None

@since(2.0)
Expand Down
70 changes: 27 additions & 43 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
from pyspark.sql.types import _merge_type
from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests
from pyspark.sql.functions import UserDefinedFunction, sha2, lit
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
Expand Down Expand Up @@ -204,48 +204,6 @@ def assertPandasEqual(self, expected, result):
self.assertTrue(expected.equals(result), msg=msg)


class PySparkSessionTests(unittest.TestCase):

def test_set_jvm_default_session(self):
spark = None
sc = None
try:
sc = SparkContext('local[4]', "test_spark_session")
spark = SparkSession(sc)
self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
finally:
if spark is not None:
spark.stop()
self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty())
spark = None
sc = None

if sc is not None:
sc.stop()
sc = None

def test_jvm_default_session_already_set(self):
spark = None
sc = None
try:
sc = SparkContext('local[4]', "test_spark_session")
jsession = sc._jvm.SparkSession(sc._jsc.sc())
sc._jvm.SparkSession.setDefaultSession(jsession)

spark = SparkSession(sc, jsession)
self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get()))
finally:
if spark is not None:
spark.stop()
spark = None
sc = None

if sc is not None:
sc.stop()
sc = None


class DataTypeTests(unittest.TestCase):
# regression test for SPARK-6055
def test_data_type_eq(self):
Expand Down Expand Up @@ -2954,6 +2912,32 @@ def test_sparksession_with_stopped_sparkcontext(self):
sc.stop()


class SparkSessionTests(PySparkTestCase):

# This test is separate because it's closely related with session's start and stop.
# See SPARK-23228.
def test_set_jvm_default_session(self):
spark = SparkSession.builder.getOrCreate()
try:
self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
finally:
spark.stop()
self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty())

def test_jvm_default_session_already_set(self):
# Here, we assume there is the default session already set in JVM.
jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc())
self.sc._jvm.SparkSession.setDefaultSession(jsession)

spark = SparkSession.builder.getOrCreate()
try:
self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
# The session should be the same with the exiting one.
self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get()))
finally:
spark.stop()


class UDFInitializationTests(unittest.TestCase):
def tearDown(self):
if SparkSession._instantiatedSession is not None:
Expand Down

0 comments on commit 0319fa5

Please sign in to comment.