diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 918e59350dac6..6d87d467e00ba 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -47,6 +47,7 @@ jobs: SPARK_BENCHMARK_NUM_SPLITS: ${{ github.event.inputs.num-splits }} SPARK_BENCHMARK_CUR_SPLIT: ${{ matrix.split }} SPARK_GENERATE_BENCHMARK_FILES: 1 + SPARK_LOCAL_IP: localhost steps: - name: Checkout Spark repository uses: actions/checkout@v2 diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 3abe20608a11a..68462dacb0e53 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -83,6 +83,7 @@ jobs: CONDA_PREFIX: /usr/share/miniconda GITHUB_PREV_SHA: ${{ github.event.before }} GITHUB_INPUT_BRANCH: ${{ github.event.inputs.target }} + SPARK_LOCAL_IP: localhost steps: - name: Checkout Spark repository uses: actions/checkout@v2 @@ -171,6 +172,7 @@ jobs: CONDA_PREFIX: /usr/share/miniconda GITHUB_PREV_SHA: ${{ github.event.before }} GITHUB_INPUT_BRANCH: ${{ github.event.inputs.target }} + SPARK_LOCAL_IP: localhost steps: - name: Checkout Spark repository uses: actions/checkout@v2 @@ -238,6 +240,7 @@ jobs: HIVE_PROFILE: hive2.3 GITHUB_PREV_SHA: ${{ github.event.before }} GITHUB_INPUT_BRANCH: ${{ github.event.inputs.target }} + SPARK_LOCAL_IP: localhost steps: - name: Checkout Spark repository uses: actions/checkout@v2 @@ -468,6 +471,8 @@ jobs: tpcds-1g: name: Run TPC-DS queries with SF=1 runs-on: ubuntu-20.04 + env: + SPARK_LOCAL_IP: localhost steps: - name: Checkout Spark repository uses: actions/checkout@v2 diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 562075bb63dcd..7fd9af1389cb4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -418,12 +418,7 @@ class MasterSuite extends SparkFunSuite (workerResponse \ "masterwebuiurl").extract[String] should be (reverseProxyUrl + "/") } - // with LocalCluster, we have masters and workers in the same JVM, each overwriting - // system property spark.ui.proxyBase. 
- // so we need to manage this property explicitly for test - System.getProperty("spark.ui.proxyBase") should startWith - (s"$reverseProxyUrl/proxy/worker-") - System.setProperty("spark.ui.proxyBase", reverseProxyUrl) + System.getProperty("spark.ui.proxyBase") should be (reverseProxyUrl) val html = Utils .tryWithResource(Source.fromURL(s"$masterUrl/"))(_.getLines().mkString("\n")) html should include ("Spark Master at spark://") diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 6823415137663..ab60939fbf22f 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -610,6 +610,8 @@ def __hash__(self): "pyspark.pandas.spark.accessors", "pyspark.pandas.spark.utils", "pyspark.pandas.typedef.typehints", + # unittests + "pyspark.pandas.tests.test_dataframe", ], excluded_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 64fb588e98506..b4965003ba33d 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -134,4 +134,7 @@ private[avro] case class AvroDataToCatalyst( """ }) } + + override protected def withNewChildInternal(newChild: Expression): AvroDataToCatalyst = + copy(child = newChild) } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala index 53910b752fdd6..5d79c44ad422e 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala @@ -64,4 +64,7 @@ private[avro] case class CatalystDataToAvro( defineCodeGen(ctx, ev, input => s"(byte[]) $expr.nullSafeEval($input)") } + + override protected def withNewChildInternal(newChild: Expression): CatalystDataToAvro = + copy(child = newChild) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 109ccbd964aca..a3dd133a4ce8d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -374,6 +374,10 @@ private[spark] object SummaryBuilderImpl extends Logging { override def left: Expression = featuresExpr override def right: Expression = weightExpr + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): MetricsAggregate = + copy(featuresExpr = newLeft, weightExpr = newRight) + override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = { val features = vectorUDT.deserialize(featuresExpr.eval(row)) val weight = weightExpr.eval(row).asInstanceOf[Double] diff --git a/pom.xml b/pom.xml index 2aca8c7656f76..32eb56036608b 100644 --- a/pom.xml +++ b/pom.xml @@ -138,7 +138,7 @@ 10.14.2.0 1.12.0 1.6.7 - 9.4.37.v20210219 + 9.4.39.v20210325 4.0.3 0.9.5 2.4.0 diff --git a/python/pyspark/pandas/testing/__init__.py b/python/pyspark/pandas/testing/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/pandas/testing/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# 
contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/pandas/testing/utils.py b/python/pyspark/pandas/testing/utils.py
new file mode 100644
index 0000000000000..d8b164d6b96fe
--- /dev/null
+++ b/python/pyspark/pandas/testing/utils.py
@@ -0,0 +1,432 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+import shutil
+import tempfile
+import unittest
+import warnings
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+
+import pandas as pd
+from pandas.api.types import is_list_like
+from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal
+
+from pyspark import pandas as pp
+from pyspark.pandas.frame import DataFrame
+from pyspark.pandas.indexes import Index
+from pyspark.pandas.series import Series
+from pyspark.pandas.utils import default_session, sql_conf as sqlc, SPARK_CONF_ARROW_ENABLED
+
+
+tabulate_requirement_message = None
+try:
+    from tabulate import tabulate  # noqa: F401
+except ImportError as e:
+    # If tabulate requirement is not satisfied, skip related tests.
+    tabulate_requirement_message = str(e)
+have_tabulate = tabulate_requirement_message is None
+
+
+class SQLTestUtils(object):
+    """
+    This util assumes that the instance using it has a 'spark' attribute holding a Spark session.
+    It is usually mixed into 'ReusedSQLTestCase', but it can be used on its own as long as
+    the implementing class provides a 'spark' attribute.
+    """
+
+    @contextmanager
+    def sql_conf(self, pairs):
+        """
+        A convenient context manager to test some configuration-specific logic. It sets
+        each given configuration `key` to its `value` and restores the original values
+        when it exits.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+        with sqlc(pairs, spark=self.spark):
+            yield
+
+    @contextmanager
+    def database(self, *databases):
+        """
+        A convenient context manager to test with some specific databases. It drops the given
+        databases if they exist and sets the current database back to "default" when it exits.
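+
+        For example (the database name here is only illustrative)::
+
+            with self.database("test_db"):
+                self.spark.sql("CREATE DATABASE test_db")
+                self.spark.catalog.setCurrentDatabase("test_db")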
+ """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for db in databases: + self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db) + self.spark.catalog.setCurrentDatabase("default") + + @contextmanager + def table(self, *tables): + """ + A convenient context manager to test with some specific tables. This drops the given tables + if it exists. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for t in tables: + self.spark.sql("DROP TABLE IF EXISTS %s" % t) + + @contextmanager + def tempView(self, *views): + """ + A convenient context manager to test with some specific views. This drops the given views + if it exists. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for v in views: + self.spark.catalog.dropTempView(v) + + @contextmanager + def function(self, *functions): + """ + A convenient context manager to test with some specific functions. This drops the given + functions if it exists. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for f in functions: + self.spark.sql("DROP FUNCTION IF EXISTS %s" % f) + + +class ReusedSQLTestCase(unittest.TestCase, SQLTestUtils): + @classmethod + def setUpClass(cls): + cls.spark = default_session() + cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True) + + @classmethod + def tearDownClass(cls): + # We don't stop Spark session to reuse across all tests. + # The Spark session will be started and stopped at PyTest session level. + # Please see databricks/koalas/conftest.py. + pass + + def assertPandasEqual(self, left, right, check_exact=True): + if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): + try: + if LooseVersion(pd.__version__) >= LooseVersion("1.1"): + kwargs = dict(check_freq=False) + else: + kwargs = dict() + + assert_frame_equal( + left, + right, + check_index_type=("equiv" if len(left.index) > 0 else False), + check_column_type=("equiv" if len(left.columns) > 0 else False), + check_exact=check_exact, + **kwargs + ) + except AssertionError as e: + msg = ( + str(e) + + "\n\nLeft:\n%s\n%s" % (left, left.dtypes) + + "\n\nRight:\n%s\n%s" % (right, right.dtypes) + ) + raise AssertionError(msg) from e + elif isinstance(left, pd.Series) and isinstance(right, pd.Series): + try: + if LooseVersion(pd.__version__) >= LooseVersion("1.1"): + kwargs = dict(check_freq=False) + else: + kwargs = dict() + + assert_series_equal( + left, + right, + check_index_type=("equiv" if len(left.index) > 0 else False), + check_exact=check_exact, + **kwargs + ) + except AssertionError as e: + msg = ( + str(e) + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + raise AssertionError(msg) from e + elif isinstance(left, pd.Index) and isinstance(right, pd.Index): + try: + assert_index_equal(left, right, check_exact=check_exact) + except AssertionError as e: + msg = ( + str(e) + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + raise AssertionError(msg) from e + else: + raise ValueError("Unexpected values: (%s, %s)" % (left, right)) + + def assertPandasAlmostEqual(self, left, right): + """ + This function checks if given pandas objects approximately same, + which means the conditions below: + - Both objects are nullable + - Compare floats rounding to the 
number of decimal places, 7 after + dropping missing values (NaN, NaT, None) + """ + if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): + msg = ( + "DataFrames are not almost equal: " + + "\n\nLeft:\n%s\n%s" % (left, left.dtypes) + + "\n\nRight:\n%s\n%s" % (right, right.dtypes) + ) + self.assertEqual(left.shape, right.shape, msg=msg) + for lcol, rcol in zip(left.columns, right.columns): + self.assertEqual(lcol, rcol, msg=msg) + for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()): + self.assertEqual(lnull, rnull, msg=msg) + for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()): + self.assertAlmostEqual(lval, rval, msg=msg) + self.assertEqual(left.columns.names, right.columns.names, msg=msg) + elif isinstance(left, pd.Series) and isinstance(right, pd.Series): + msg = ( + "Series are not almost equal: " + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + self.assertEqual(left.name, right.name, msg=msg) + self.assertEqual(len(left), len(right), msg=msg) + for lnull, rnull in zip(left.isnull(), right.isnull()): + self.assertEqual(lnull, rnull, msg=msg) + for lval, rval in zip(left.dropna(), right.dropna()): + self.assertAlmostEqual(lval, rval, msg=msg) + elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex): + msg = ( + "MultiIndices are not almost equal: " + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + self.assertEqual(len(left), len(right), msg=msg) + for lval, rval in zip(left, right): + self.assertAlmostEqual(lval, rval, msg=msg) + elif isinstance(left, pd.Index) and isinstance(right, pd.Index): + msg = ( + "Indices are not almost equal: " + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + self.assertEqual(len(left), len(right), msg=msg) + for lnull, rnull in zip(left.isnull(), right.isnull()): + self.assertEqual(lnull, rnull, msg=msg) + for lval, rval in zip(left.dropna(), right.dropna()): + self.assertAlmostEqual(lval, rval, msg=msg) + else: + raise ValueError("Unexpected values: (%s, %s)" % (left, right)) + + def assert_eq(self, left, right, check_exact=True, almost=False): + """ + Asserts if two arbitrary objects are equal or not. If given objects are Koalas DataFrame + or Series, they are converted into pandas' and compared. + + :param left: object to compare + :param right: object to compare + :param check_exact: if this is False, the comparison is done less precisely. + :param almost: if this is enabled, the comparison is delegated to `unittest`'s + `assertAlmostEqual`. See its documentation for more details. 
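+
+        For example, the tests below typically compare a Koalas object against its
+        pandas counterpart with::
+
+            self.assert_eq(kdf.sort_index(), pdf.sort_index(), almost=True)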
+ """ + lobj = self._to_pandas(left) + robj = self._to_pandas(right) + if isinstance(lobj, (pd.DataFrame, pd.Series, pd.Index)): + if almost: + self.assertPandasAlmostEqual(lobj, robj) + else: + self.assertPandasEqual(lobj, robj, check_exact=check_exact) + elif is_list_like(lobj) and is_list_like(robj): + self.assertTrue(len(left) == len(right)) + for litem, ritem in zip(left, right): + self.assert_eq(litem, ritem, check_exact=check_exact, almost=almost) + elif (lobj is not None and pd.isna(lobj)) and (robj is not None and pd.isna(robj)): + pass + else: + if almost: + self.assertAlmostEqual(lobj, robj) + else: + self.assertEqual(lobj, robj) + + @staticmethod + def _to_pandas(obj): + if isinstance(obj, (DataFrame, Series, Index)): + return obj.to_pandas() + else: + return obj + + +class TestUtils(object): + @contextmanager + def temp_dir(self): + tmp = tempfile.mkdtemp() + try: + yield tmp + finally: + shutil.rmtree(tmp) + + @contextmanager + def temp_file(self): + with self.temp_dir() as tmp: + yield tempfile.mktemp(dir=tmp) + + +class ComparisonTestBase(ReusedSQLTestCase): + @property + def kdf(self): + return pp.from_pandas(self.pdf) + + @property + def pdf(self): + return self.kdf.to_pandas() + + +def compare_both(f=None, almost=True): + + if f is None: + return functools.partial(compare_both, almost=almost) + elif isinstance(f, bool): + return functools.partial(compare_both, almost=f) + + @functools.wraps(f) + def wrapped(self): + if almost: + compare = self.assertPandasAlmostEqual + else: + compare = self.assertPandasEqual + + for result_pandas, result_spark in zip(f(self, self.pdf), f(self, self.kdf)): + compare(result_pandas, result_spark.to_pandas()) + + return wrapped + + +@contextmanager +def assert_produces_warning( + expected_warning=Warning, + filter_level="always", + check_stacklevel=True, + raise_on_extra_warnings=True, +): + """ + Context manager for running code expected to either raise a specific + warning, or not raise any warnings. Verifies that the code raises the + expected warning, and that it does not raise any other unexpected + warnings. It is basically a wrapper around ``warnings.catch_warnings``. + + Notes + ----- + Replicated from pandas._testing. + + Parameters + ---------- + expected_warning : {Warning, False, None}, default Warning + The type of Exception raised. ``exception.Warning`` is the base + class for all warnings. To check that no warning is returned, + specify ``False`` or ``None``. + filter_level : str or None, default "always" + Specifies whether warnings are ignored, displayed, or turned + into errors. + Valid values are: + * "error" - turns matching warnings into exceptions + * "ignore" - discard the warning + * "always" - always emit a warning + * "default" - print the warning the first time it is generated + from each location + * "module" - print the warning the first time it is generated + from each module + * "once" - print the warning the first time it is generated + check_stacklevel : bool, default True + If True, displays the line that called the function containing + the warning to show were the function is called. Otherwise, the + line that implements the function is displayed. + raise_on_extra_warnings : bool, default True + Whether extra warnings not of the type `expected_warning` should + cause the test to fail. + + Examples + -------- + >>> import warnings + >>> with assert_produces_warning(): + ... warnings.warn(UserWarning()) + ... + >>> with assert_produces_warning(False): # doctest: +SKIP + ... 
warnings.warn(RuntimeWarning()) + ... + Traceback (most recent call last): + ... + AssertionError: Caused unexpected warning(s): ['RuntimeWarning']. + >>> with assert_produces_warning(UserWarning): # doctest: +SKIP + ... warnings.warn(RuntimeWarning()) + Traceback (most recent call last): + ... + AssertionError: Did not see expected warning of class 'UserWarning' + ..warn:: This is *not* thread-safe. + """ + __tracebackhide__ = True + + with warnings.catch_warnings(record=True) as w: + + saw_warning = False + warnings.simplefilter(filter_level) + yield w + extra_warnings = [] + + for actual_warning in w: + if expected_warning and issubclass(actual_warning.category, expected_warning): + saw_warning = True + + if check_stacklevel and issubclass( + actual_warning.category, (FutureWarning, DeprecationWarning) + ): + from inspect import getframeinfo, stack + + caller = getframeinfo(stack()[2][0]) + msg = ( + "Warning not set with correct stacklevel. ", + "File where warning is raised: {} != ".format(actual_warning.filename), + "{}. Warning message: {}".format(caller.filename, actual_warning.message), + ) + assert actual_warning.filename == caller.filename, msg + else: + extra_warnings.append( + ( + actual_warning.category.__name__, + actual_warning.message, + actual_warning.filename, + actual_warning.lineno, + ) + ) + if expected_warning: + msg = "Did not see expected warning of class {}".format(repr(expected_warning.__name__)) + assert saw_warning, msg + if raise_on_extra_warnings and extra_warnings: + raise AssertionError("Caused unexpected warning(s): {}".format(repr(extra_warnings))) diff --git a/python/pyspark/pandas/tests/__init__.py b/python/pyspark/pandas/tests/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/pandas/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py new file mode 100644 index 0000000000000..397ae27eda8b1 --- /dev/null +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -0,0 +1,5560 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from datetime import datetime +from distutils.version import LooseVersion +import inspect +import sys +import unittest +from io import StringIO + +import numpy as np +import pandas as pd +from pandas.tseries.offsets import DateOffset +import pyspark +from pyspark import StorageLevel +from pyspark.ml.linalg import SparseVector +from pyspark.sql import functions as F + +from pyspark import pandas as pp +from pyspark.pandas.config import option_context +from pyspark.pandas.exceptions import PandasNotImplementedError +from pyspark.pandas.frame import CachedDataFrame +from pyspark.pandas.missing.frame import _MissingPandasLikeDataFrame +from pyspark.pandas.typedef.typehints import ( + extension_dtypes, + extension_dtypes_available, + extension_float_dtypes_available, + extension_object_dtypes_available, +) +from pyspark.pandas.testing.utils import ( + have_tabulate, + ReusedSQLTestCase, + SQLTestUtils, + SPARK_CONF_ARROW_ENABLED, +) +from pyspark.pandas.utils import name_like_string + + +class DataFrameTest(ReusedSQLTestCase, SQLTestUtils): + @property + def pdf(self): + return pd.DataFrame( + {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 0]}, + index=np.random.rand(9), + ) + + @property + def kdf(self): + return pp.from_pandas(self.pdf) + + @property + def df_pair(self): + pdf = self.pdf + kdf = pp.from_pandas(pdf) + return pdf, kdf + + def test_dataframe(self): + pdf, kdf = self.df_pair + + self.assert_eq(kdf["a"] + 1, pdf["a"] + 1) + + self.assert_eq(kdf.columns, pd.Index(["a", "b"])) + + self.assert_eq(kdf[kdf["b"] > 2], pdf[pdf["b"] > 2]) + self.assert_eq(-kdf[kdf["b"] > 2], -pdf[pdf["b"] > 2]) + self.assert_eq(kdf[["a", "b"]], pdf[["a", "b"]]) + self.assert_eq(kdf.a, pdf.a) + self.assert_eq(kdf.b.mean(), pdf.b.mean()) + self.assert_eq(kdf.b.var(), pdf.b.var()) + self.assert_eq(kdf.b.std(), pdf.b.std()) + + pdf, kdf = self.df_pair + self.assert_eq(kdf[["a", "b"]], pdf[["a", "b"]]) + + self.assertEqual(kdf.a.notnull().rename("x").name, "x") + + # check pp.DataFrame(pp.Series) + pser = pd.Series([1, 2, 3], name="x", index=np.random.rand(3)) + kser = pp.from_pandas(pser) + self.assert_eq(pd.DataFrame(pser), pp.DataFrame(kser)) + + # check kdf[pd.Index] + pdf, kdf = self.df_pair + column_mask = pdf.columns.isin(["a", "b"]) + index_cols = pdf.columns[column_mask] + self.assert_eq(kdf[index_cols], pdf[index_cols]) + + def _check_extension(self, kdf, pdf): + if LooseVersion("1.1") <= LooseVersion(pd.__version__) < LooseVersion("1.2.2"): + self.assert_eq(kdf, pdf, check_exact=False) + for dtype in kdf.dtypes: + self.assertTrue(isinstance(dtype, extension_dtypes)) + else: + self.assert_eq(kdf, pdf) + + @unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes are not available") + def test_extension_dtypes(self): + pdf = pd.DataFrame( + { + "a": pd.Series([1, 2, None, 4], dtype="Int8"), + "b": pd.Series([1, None, None, 4], dtype="Int16"), + "c": pd.Series([1, 2, None, None], dtype="Int32"), + "d": pd.Series([None, 2, None, 4], dtype="Int64"), + } + ) + kdf = pp.from_pandas(pdf) + + self._check_extension(kdf, pdf) + self._check_extension(kdf + 
F.lit(1).cast("byte"), pdf + 1) + self._check_extension(kdf + kdf, pdf + pdf) + + @unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes are not available") + def test_astype_extension_dtypes(self): + pdf = pd.DataFrame( + { + "a": [1, 2, None, 4], + "b": [1, None, None, 4], + "c": [1, 2, None, None], + "d": [None, 2, None, 4], + } + ) + kdf = pp.from_pandas(pdf) + + astype = {"a": "Int8", "b": "Int16", "c": "Int32", "d": "Int64"} + + self._check_extension(kdf.astype(astype), pdf.astype(astype)) + + @unittest.skipIf( + not extension_object_dtypes_available, "pandas extension object dtypes are not available" + ) + def test_extension_object_dtypes(self): + pdf = pd.DataFrame( + { + "a": pd.Series(["a", "b", None, "c"], dtype="string"), + "b": pd.Series([True, None, False, True], dtype="boolean"), + } + ) + kdf = pp.from_pandas(pdf) + + self._check_extension(kdf, pdf) + + @unittest.skipIf( + not extension_object_dtypes_available, "pandas extension object dtypes are not available" + ) + def test_astype_extension_object_dtypes(self): + pdf = pd.DataFrame({"a": ["a", "b", None, "c"], "b": [True, None, False, True]}) + kdf = pp.from_pandas(pdf) + + astype = {"a": "string", "b": "boolean"} + + self._check_extension(kdf.astype(astype), pdf.astype(astype)) + + @unittest.skipIf( + not extension_float_dtypes_available, "pandas extension float dtypes are not available" + ) + def test_extension_float_dtypes(self): + pdf = pd.DataFrame( + { + "a": pd.Series([1.0, 2.0, None, 4.0], dtype="Float32"), + "b": pd.Series([1.0, None, 3.0, 4.0], dtype="Float64"), + } + ) + kdf = pp.from_pandas(pdf) + + self._check_extension(kdf, pdf) + self._check_extension(kdf + 1, pdf + 1) + self._check_extension(kdf + kdf, pdf + pdf) + + @unittest.skipIf( + not extension_float_dtypes_available, "pandas extension float dtypes are not available" + ) + def test_astype_extension_float_dtypes(self): + pdf = pd.DataFrame({"a": [1.0, 2.0, None, 4.0], "b": [1.0, None, 3.0, 4.0]}) + kdf = pp.from_pandas(pdf) + + astype = {"a": "Float32", "b": "Float64"} + + self._check_extension(kdf.astype(astype), pdf.astype(astype)) + + def test_insert(self): + # + # Basic DataFrame + # + pdf = pd.DataFrame([1, 2, 3]) + kdf = pp.from_pandas(pdf) + + kdf.insert(1, "b", 10) + pdf.insert(1, "b", 10) + self.assert_eq(kdf.sort_index(), pdf.sort_index(), almost=True) + kdf.insert(2, "c", 0.1) + pdf.insert(2, "c", 0.1) + self.assert_eq(kdf.sort_index(), pdf.sort_index(), almost=True) + kdf.insert(3, "d", kdf.b + 1) + pdf.insert(3, "d", pdf.b + 1) + self.assert_eq(kdf.sort_index(), pdf.sort_index(), almost=True) + + kser = pp.Series([4, 5, 6]) + self.assertRaises(ValueError, lambda: kdf.insert(0, "y", kser)) + self.assertRaisesRegex( + ValueError, "cannot insert b, already exists", lambda: kdf.insert(1, "b", 10) + ) + self.assertRaisesRegex( + ValueError, + '"column" should be a scalar value or tuple that contains scalar values', + lambda: kdf.insert(0, list("abc"), kser), + ) + self.assertRaises(ValueError, lambda: kdf.insert(0, "e", [7, 8, 9, 10])) + self.assertRaises(ValueError, lambda: kdf.insert(0, "f", pp.Series([7, 8]))) + self.assertRaises(AssertionError, lambda: kdf.insert(100, "y", kser)) + self.assertRaises(AssertionError, lambda: kdf.insert(1, "y", kser, allow_duplicates=True)) + + # + # DataFrame with MultiIndex as columns + # + pdf = pd.DataFrame({("x", "a", "b"): [1, 2, 3]}) + kdf = pp.from_pandas(pdf) + + kdf.insert(1, "b", 10) + pdf.insert(1, "b", 10) + self.assert_eq(kdf.sort_index(), pdf.sort_index(), almost=True) + 
kdf.insert(2, "c", 0.1) + pdf.insert(2, "c", 0.1) + self.assert_eq(kdf.sort_index(), pdf.sort_index(), almost=True) + kdf.insert(3, "d", kdf.b + 1) + pdf.insert(3, "d", pdf.b + 1) + self.assert_eq(kdf.sort_index(), pdf.sort_index(), almost=True) + + self.assertRaisesRegex( + ValueError, "cannot insert d, already exists", lambda: kdf.insert(4, "d", 11) + ) + self.assertRaisesRegex( + ValueError, + '"column" must have length equal to number of column levels.', + lambda: kdf.insert(4, ("e",), 11), + ) + + def test_inplace(self): + pdf, kdf = self.df_pair + + pser = pdf.a + kser = kdf.a + + pdf["a"] = pdf["a"] + 10 + kdf["a"] = kdf["a"] + 10 + + self.assert_eq(kdf, pdf) + self.assert_eq(kser, pser) + + def test_assign_list(self): + pdf, kdf = self.df_pair + + pser = pdf.a + kser = kdf.a + + pdf["x"] = [10, 20, 30, 40, 50, 60, 70, 80, 90] + kdf["x"] = [10, 20, 30, 40, 50, 60, 70, 80, 90] + + self.assert_eq(kdf.sort_index(), pdf.sort_index()) + self.assert_eq(kser, pser) + + with self.assertRaisesRegex(ValueError, "Length of values does not match length of index"): + kdf["z"] = [10, 20, 30, 40, 50, 60, 70, 80] + + def test_dataframe_multiindex_columns(self): + pdf = pd.DataFrame( + { + ("x", "a", "1"): [1, 2, 3], + ("x", "b", "2"): [4, 5, 6], + ("y.z", "c.d", "3"): [7, 8, 9], + ("x", "b", "4"): [10, 11, 12], + }, + index=np.random.rand(3), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf, pdf) + self.assert_eq(kdf["x"], pdf["x"]) + self.assert_eq(kdf["y.z"], pdf["y.z"]) + self.assert_eq(kdf["x"]["b"], pdf["x"]["b"]) + self.assert_eq(kdf["x"]["b"]["2"], pdf["x"]["b"]["2"]) + + self.assert_eq(kdf.x, pdf.x) + self.assert_eq(kdf.x.b, pdf.x.b) + self.assert_eq(kdf.x.b["2"], pdf.x.b["2"]) + + self.assertRaises(KeyError, lambda: kdf["z"]) + self.assertRaises(AttributeError, lambda: kdf.z) + + self.assert_eq(kdf[("x",)], pdf[("x",)]) + self.assert_eq(kdf[("x", "a")], pdf[("x", "a")]) + self.assert_eq(kdf[("x", "a", "1")], pdf[("x", "a", "1")]) + + def test_dataframe_column_level_name(self): + column = pd.Index(["A", "B", "C"], name="X") + pdf = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=column, index=np.random.rand(2)) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf, pdf) + self.assert_eq(kdf.columns.names, pdf.columns.names) + self.assert_eq(kdf.to_pandas().columns.names, pdf.columns.names) + + def test_dataframe_multiindex_names_level(self): + columns = pd.MultiIndex.from_tuples( + [("X", "A", "Z"), ("X", "B", "Z"), ("Y", "C", "Z"), ("Y", "D", "Z")], + names=["lvl_1", "lvl_2", "lv_3"], + ) + pdf = pd.DataFrame( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16], [17, 18, 19, 20]], + columns=columns, + index=np.random.rand(5), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.columns.names, pdf.columns.names) + self.assert_eq(kdf.to_pandas().columns.names, pdf.columns.names) + + kdf1 = pp.from_pandas(pdf) + self.assert_eq(kdf1.columns.names, pdf.columns.names) + + self.assertRaises( + AssertionError, lambda: pp.DataFrame(kdf1._internal.copy(column_label_names=("level",))) + ) + + self.assert_eq(kdf["X"], pdf["X"]) + self.assert_eq(kdf["X"].columns.names, pdf["X"].columns.names) + self.assert_eq(kdf["X"].to_pandas().columns.names, pdf["X"].columns.names) + self.assert_eq(kdf["X"]["A"], pdf["X"]["A"]) + self.assert_eq(kdf["X"]["A"].columns.names, pdf["X"]["A"].columns.names) + self.assert_eq(kdf["X"]["A"].to_pandas().columns.names, pdf["X"]["A"].columns.names) + self.assert_eq(kdf[("X", "A")], pdf[("X", "A")]) + self.assert_eq(kdf[("X", "A")].columns.names, pdf[("X", 
"A")].columns.names) + self.assert_eq(kdf[("X", "A")].to_pandas().columns.names, pdf[("X", "A")].columns.names) + self.assert_eq(kdf[("X", "A", "Z")], pdf[("X", "A", "Z")]) + + def test_itertuples(self): + pdf = pd.DataFrame({"num_legs": [4, 2], "num_wings": [0, 2]}, index=["dog", "hawk"]) + kdf = pp.from_pandas(pdf) + + for ptuple, ktuple in zip( + pdf.itertuples(index=False, name="Animal"), kdf.itertuples(index=False, name="Animal") + ): + self.assert_eq(ptuple, ktuple) + for ptuple, ktuple in zip(pdf.itertuples(name=None), kdf.itertuples(name=None)): + self.assert_eq(ptuple, ktuple) + + pdf.index = pd.MultiIndex.from_arrays( + [[1, 2], ["black", "brown"]], names=("count", "color") + ) + kdf = pp.from_pandas(pdf) + for ptuple, ktuple in zip(pdf.itertuples(name="Animal"), kdf.itertuples(name="Animal")): + self.assert_eq(ptuple, ktuple) + + pdf.columns = pd.MultiIndex.from_arrays( + [["CA", "WA"], ["age", "children"]], names=("origin", "info") + ) + kdf = pp.from_pandas(pdf) + for ptuple, ktuple in zip(pdf.itertuples(name="Animal"), kdf.itertuples(name="Animal")): + self.assert_eq(ptuple, ktuple) + + pdf = pd.DataFrame([1, 2, 3]) + kdf = pp.from_pandas(pdf) + for ptuple, ktuple in zip( + (pdf + 1).itertuples(name="num"), (kdf + 1).itertuples(name="num") + ): + self.assert_eq(ptuple, ktuple) + + # DataFrames with a large number of columns (>254) + pdf = pd.DataFrame(np.random.random((1, 255))) + kdf = pp.from_pandas(pdf) + for ptuple, ktuple in zip(pdf.itertuples(name="num"), kdf.itertuples(name="num")): + self.assert_eq(ptuple, ktuple) + + def test_iterrows(self): + pdf = pd.DataFrame( + { + ("x", "a", "1"): [1, 2, 3], + ("x", "b", "2"): [4, 5, 6], + ("y.z", "c.d", "3"): [7, 8, 9], + ("x", "b", "4"): [10, 11, 12], + }, + index=np.random.rand(3), + ) + kdf = pp.from_pandas(pdf) + + for (pdf_k, pdf_v), (kdf_k, kdf_v) in zip(pdf.iterrows(), kdf.iterrows()): + self.assert_eq(pdf_k, kdf_k) + self.assert_eq(pdf_v, kdf_v) + + def test_reset_index(self): + pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=np.random.rand(3)) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.reset_index(), pdf.reset_index()) + self.assert_eq(kdf.reset_index().index, pdf.reset_index().index) + self.assert_eq(kdf.reset_index(drop=True), pdf.reset_index(drop=True)) + + pdf.index.name = "a" + kdf.index.name = "a" + + with self.assertRaisesRegex(ValueError, "cannot insert a, already exists"): + kdf.reset_index() + + self.assert_eq(kdf.reset_index(drop=True), pdf.reset_index(drop=True)) + + # inplace + pser = pdf.a + kser = kdf.a + pdf.reset_index(drop=True, inplace=True) + kdf.reset_index(drop=True, inplace=True) + self.assert_eq(kdf, pdf) + self.assert_eq(kser, pser) + + def test_reset_index_with_default_index_types(self): + pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=np.random.rand(3)) + kdf = pp.from_pandas(pdf) + + with pp.option_context("compute.default_index_type", "sequence"): + self.assert_eq(kdf.reset_index(), pdf.reset_index()) + + with pp.option_context("compute.default_index_type", "distributed-sequence"): + self.assert_eq(kdf.reset_index(), pdf.reset_index()) + + with pp.option_context("compute.default_index_type", "distributed"): + # the index is different. 
+ self.assert_eq(kdf.reset_index().to_pandas().reset_index(drop=True), pdf.reset_index()) + + def test_reset_index_with_multiindex_columns(self): + index = pd.MultiIndex.from_tuples( + [("bird", "falcon"), ("bird", "parrot"), ("mammal", "lion"), ("mammal", "monkey")], + names=["class", "name"], + ) + columns = pd.MultiIndex.from_tuples([("speed", "max"), ("species", "type")]) + pdf = pd.DataFrame( + [(389.0, "fly"), (24.0, "fly"), (80.5, "run"), (np.nan, "jump")], + index=index, + columns=columns, + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf, pdf) + self.assert_eq(kdf.reset_index(), pdf.reset_index()) + self.assert_eq(kdf.reset_index(level="class"), pdf.reset_index(level="class")) + self.assert_eq( + kdf.reset_index(level="class", col_level=1), pdf.reset_index(level="class", col_level=1) + ) + self.assert_eq( + kdf.reset_index(level="class", col_level=1, col_fill="species"), + pdf.reset_index(level="class", col_level=1, col_fill="species"), + ) + self.assert_eq( + kdf.reset_index(level="class", col_level=1, col_fill="genus"), + pdf.reset_index(level="class", col_level=1, col_fill="genus"), + ) + + with self.assertRaisesRegex(IndexError, "Index has only 2 levels, not 3"): + kdf.reset_index(col_level=2) + + pdf.index.names = [("x", "class"), ("y", "name")] + kdf.index.names = [("x", "class"), ("y", "name")] + + self.assert_eq(kdf.reset_index(), pdf.reset_index()) + + with self.assertRaisesRegex(ValueError, "Item must have length equal to number of levels."): + kdf.reset_index(col_level=1) + + def test_index_to_frame_reset_index(self): + def check(kdf, pdf): + self.assert_eq(kdf.reset_index(), pdf.reset_index()) + self.assert_eq(kdf.reset_index(drop=True), pdf.reset_index(drop=True)) + + pdf.reset_index(drop=True, inplace=True) + kdf.reset_index(drop=True, inplace=True) + self.assert_eq(kdf, pdf) + + pdf, kdf = self.df_pair + check(kdf.index.to_frame(), pdf.index.to_frame()) + check(kdf.index.to_frame(index=False), pdf.index.to_frame(index=False)) + + if LooseVersion(pd.__version__) >= LooseVersion("0.24"): + # The `name` argument is added in pandas 0.24. 
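+            # Exercise both plain string names and tuple names, with and without
+            # keeping the index.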
+ check(kdf.index.to_frame(name="a"), pdf.index.to_frame(name="a")) + check( + kdf.index.to_frame(index=False, name="a"), pdf.index.to_frame(index=False, name="a") + ) + check(kdf.index.to_frame(name=("x", "a")), pdf.index.to_frame(name=("x", "a"))) + check( + kdf.index.to_frame(index=False, name=("x", "a")), + pdf.index.to_frame(index=False, name=("x", "a")), + ) + + def test_multiindex_column_access(self): + columns = pd.MultiIndex.from_tuples( + [ + ("a", "", "", "b"), + ("c", "", "d", ""), + ("e", "", "f", ""), + ("e", "g", "", ""), + ("", "", "", "h"), + ("i", "", "", ""), + ] + ) + + pdf = pd.DataFrame( + [ + (1, "a", "x", 10, 100, 1000), + (2, "b", "y", 20, 200, 2000), + (3, "c", "z", 30, 300, 3000), + ], + columns=columns, + index=np.random.rand(3), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf, pdf) + self.assert_eq(kdf["a"], pdf["a"]) + self.assert_eq(kdf["a"]["b"], pdf["a"]["b"]) + self.assert_eq(kdf["c"], pdf["c"]) + self.assert_eq(kdf["c"]["d"], pdf["c"]["d"]) + self.assert_eq(kdf["e"], pdf["e"]) + self.assert_eq(kdf["e"][""]["f"], pdf["e"][""]["f"]) + self.assert_eq(kdf["e"]["g"], pdf["e"]["g"]) + self.assert_eq(kdf[""], pdf[""]) + self.assert_eq(kdf[""]["h"], pdf[""]["h"]) + self.assert_eq(kdf["i"], pdf["i"]) + + self.assert_eq(kdf[["a", "e"]], pdf[["a", "e"]]) + self.assert_eq(kdf[["e", "a"]], pdf[["e", "a"]]) + + self.assert_eq(kdf[("a",)], pdf[("a",)]) + self.assert_eq(kdf[("e", "g")], pdf[("e", "g")]) + # self.assert_eq(kdf[("i",)], pdf[("i",)]) + self.assert_eq(kdf[("i", "")], pdf[("i", "")]) + + self.assertRaises(KeyError, lambda: kdf[("a", "b")]) + + def test_repr_cache_invalidation(self): + # If there is any cache, inplace operations should invalidate it. + df = pp.range(10) + df.__repr__() + df["a"] = df["id"] + self.assertEqual(df.__repr__(), df.to_pandas().__repr__()) + + def test_repr_html_cache_invalidation(self): + # If there is any cache, inplace operations should invalidate it. 
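+        # Render once to populate the cache, mutate the frame in place, and then check
+        # that the re-rendered HTML matches pandas on the updated frame.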
+ df = pp.range(10) + df._repr_html_() + df["a"] = df["id"] + self.assertEqual(df._repr_html_(), df.to_pandas()._repr_html_()) + + def test_empty_dataframe(self): + pdf = pd.DataFrame({"a": pd.Series([], dtype="i1"), "b": pd.Series([], dtype="str")}) + + kdf = pp.from_pandas(pdf) + if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"): + self.assert_eq(kdf, pdf) + else: + with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): + self.assert_eq(kdf, pdf) + + with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf, pdf) + + def test_all_null_dataframe(self): + pdf = pd.DataFrame( + { + "a": [None, None, None, "a"], + "b": [None, None, None, 1], + "c": [None, None, None] + list(np.arange(1, 2).astype("i1")), + "d": [None, None, None, 1.0], + "e": [None, None, None, True], + "f": [None, None, None] + list(pd.date_range("20130101", periods=1)), + }, + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.iloc[:-1], pdf.iloc[:-1]) + + with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): + self.assert_eq(kdf.iloc[:-1], pdf.iloc[:-1]) + + pdf = pd.DataFrame( + { + "a": pd.Series([None, None, None], dtype="float64"), + "b": pd.Series([None, None, None], dtype="str"), + }, + ) + + kdf = pp.from_pandas(pdf) + if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"): + self.assert_eq(kdf, pdf) + else: + with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): + self.assert_eq(kdf, pdf) + + with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf, pdf) + + def test_nullable_object(self): + pdf = pd.DataFrame( + { + "a": list("abc") + [np.nan, None], + "b": list(range(1, 4)) + [np.nan, None], + "c": list(np.arange(3, 6).astype("i1")) + [np.nan, None], + "d": list(np.arange(4.0, 7.0, dtype="float64")) + [np.nan, None], + "e": [True, False, True, np.nan, None], + "f": list(pd.date_range("20130101", periods=3)) + [np.nan, None], + }, + index=np.random.rand(5), + ) + + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf, pdf) + + with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf, pdf) + + def test_assign(self): + pdf, kdf = self.df_pair + + kdf["w"] = 1.0 + pdf["w"] = 1.0 + + self.assert_eq(kdf, pdf) + + kdf.w = 10.0 + pdf.w = 10.0 + + self.assert_eq(kdf, pdf) + + kdf[1] = 1.0 + pdf[1] = 1.0 + + self.assert_eq(kdf, pdf) + + kdf = kdf.assign(a=kdf["a"] * 2) + pdf = pdf.assign(a=pdf["a"] * 2) + + self.assert_eq(kdf, pdf) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("y", "w"), ("y", "v")]) + pdf.columns = columns + kdf.columns = columns + + kdf[("a", "c")] = "def" + pdf[("a", "c")] = "def" + + self.assert_eq(kdf, pdf) + + kdf = kdf.assign(Z="ZZ") + pdf = pdf.assign(Z="ZZ") + + self.assert_eq(kdf, pdf) + + kdf["x"] = "ghi" + pdf["x"] = "ghi" + + self.assert_eq(kdf, pdf) + + def test_head(self): + pdf, kdf = self.df_pair + + self.assert_eq(kdf.head(2), pdf.head(2)) + self.assert_eq(kdf.head(3), pdf.head(3)) + self.assert_eq(kdf.head(0), pdf.head(0)) + self.assert_eq(kdf.head(-3), pdf.head(-3)) + self.assert_eq(kdf.head(-10), pdf.head(-10)) + + def test_attributes(self): + kdf = self.kdf + + self.assertIn("a", dir(kdf)) + self.assertNotIn("foo", dir(kdf)) + self.assertRaises(AttributeError, lambda: kdf.foo) + + kdf = pp.DataFrame({"a b c": [1, 2, 3]}) + self.assertNotIn("a b c", dir(kdf)) + kdf = pp.DataFrame({"a": [1, 2], 5: [1, 2]}) + self.assertIn("a", dir(kdf)) + self.assertNotIn(5, dir(kdf)) + + def 
test_column_names(self): + pdf, kdf = self.df_pair + + self.assert_eq(kdf.columns, pdf.columns) + self.assert_eq(kdf[["b", "a"]].columns, pdf[["b", "a"]].columns) + self.assert_eq(kdf["a"].name, pdf["a"].name) + self.assert_eq((kdf["a"] + 1).name, (pdf["a"] + 1).name) + + self.assert_eq((kdf.a + kdf.b).name, (pdf.a + pdf.b).name) + self.assert_eq((kdf.a + kdf.b.rename("a")).name, (pdf.a + pdf.b.rename("a")).name) + self.assert_eq((kdf.a + kdf.b.rename()).name, (pdf.a + pdf.b.rename()).name) + self.assert_eq((kdf.a.rename() + kdf.b).name, (pdf.a.rename() + pdf.b).name) + self.assert_eq( + (kdf.a.rename() + kdf.b.rename()).name, (pdf.a.rename() + pdf.b.rename()).name + ) + + def test_rename_columns(self): + pdf = pd.DataFrame( + {"a": [1, 2, 3, 4, 5, 6, 7], "b": [7, 6, 5, 4, 3, 2, 1]}, index=np.random.rand(7) + ) + kdf = pp.from_pandas(pdf) + + kdf.columns = ["x", "y"] + pdf.columns = ["x", "y"] + self.assert_eq(kdf.columns, pd.Index(["x", "y"])) + self.assert_eq(kdf, pdf) + self.assert_eq(kdf._internal.data_spark_column_names, ["x", "y"]) + self.assert_eq(kdf.to_spark().columns, ["x", "y"]) + self.assert_eq(kdf.to_spark(index_col="index").columns, ["index", "x", "y"]) + + columns = pdf.columns + columns.name = "lvl_1" + + kdf.columns = columns + self.assert_eq(kdf.columns.names, ["lvl_1"]) + self.assert_eq(kdf, pdf) + + msg = "Length mismatch: Expected axis has 2 elements, new values have 4 elements" + with self.assertRaisesRegex(ValueError, msg): + kdf.columns = [1, 2, 3, 4] + + # Multi-index columns + pdf = pd.DataFrame( + {("A", "0"): [1, 2, 2, 3], ("B", "1"): [1, 2, 3, 4]}, index=np.random.rand(4) + ) + kdf = pp.from_pandas(pdf) + + columns = pdf.columns + self.assert_eq(kdf.columns, columns) + self.assert_eq(kdf, pdf) + + pdf.columns = ["x", "y"] + kdf.columns = ["x", "y"] + self.assert_eq(kdf.columns, pd.Index(["x", "y"])) + self.assert_eq(kdf, pdf) + self.assert_eq(kdf._internal.data_spark_column_names, ["x", "y"]) + self.assert_eq(kdf.to_spark().columns, ["x", "y"]) + self.assert_eq(kdf.to_spark(index_col="index").columns, ["index", "x", "y"]) + + pdf.columns = columns + kdf.columns = columns + self.assert_eq(kdf.columns, columns) + self.assert_eq(kdf, pdf) + self.assert_eq(kdf._internal.data_spark_column_names, ["(A, 0)", "(B, 1)"]) + self.assert_eq(kdf.to_spark().columns, ["(A, 0)", "(B, 1)"]) + self.assert_eq(kdf.to_spark(index_col="index").columns, ["index", "(A, 0)", "(B, 1)"]) + + columns.names = ["lvl_1", "lvl_2"] + + kdf.columns = columns + self.assert_eq(kdf.columns.names, ["lvl_1", "lvl_2"]) + self.assert_eq(kdf, pdf) + self.assert_eq(kdf._internal.data_spark_column_names, ["(A, 0)", "(B, 1)"]) + self.assert_eq(kdf.to_spark().columns, ["(A, 0)", "(B, 1)"]) + self.assert_eq(kdf.to_spark(index_col="index").columns, ["index", "(A, 0)", "(B, 1)"]) + + def test_rename_dataframe(self): + pdf1 = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + kdf1 = pp.from_pandas(pdf1) + + self.assert_eq( + kdf1.rename(columns={"A": "a", "B": "b"}), pdf1.rename(columns={"A": "a", "B": "b"}) + ) + + result_kdf = kdf1.rename(index={1: 10, 2: 20}) + result_pdf = pdf1.rename(index={1: 10, 2: 20}) + self.assert_eq(result_kdf, result_pdf) + + # inplace + pser = result_pdf.A + kser = result_kdf.A + result_kdf.rename(index={10: 100, 20: 200}, inplace=True) + result_pdf.rename(index={10: 100, 20: 200}, inplace=True) + self.assert_eq(result_kdf, result_pdf) + self.assert_eq(kser, pser) + + def str_lower(s) -> str: + return str.lower(s) + + self.assert_eq( + kdf1.rename(str_lower, axis="columns"), 
pdf1.rename(str_lower, axis="columns") + ) + + def mul10(x) -> int: + return x * 10 + + self.assert_eq(kdf1.rename(mul10, axis="index"), pdf1.rename(mul10, axis="index")) + + self.assert_eq( + kdf1.rename(columns=str_lower, index={1: 10, 2: 20}), + pdf1.rename(columns=str_lower, index={1: 10, 2: 20}), + ) + + idx = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C"), ("Y", "D")]) + pdf2 = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], columns=idx) + kdf2 = pp.from_pandas(pdf2) + + self.assert_eq(kdf2.rename(columns=str_lower), pdf2.rename(columns=str_lower)) + + self.assert_eq( + kdf2.rename(columns=str_lower, level=0), pdf2.rename(columns=str_lower, level=0) + ) + self.assert_eq( + kdf2.rename(columns=str_lower, level=1), pdf2.rename(columns=str_lower, level=1) + ) + + pdf3 = pd.DataFrame([[1, 2], [3, 4], [5, 6], [7, 8]], index=idx, columns=list("ab")) + kdf3 = pp.from_pandas(pdf3) + + self.assert_eq(kdf3.rename(index=str_lower), pdf3.rename(index=str_lower)) + self.assert_eq(kdf3.rename(index=str_lower, level=0), pdf3.rename(index=str_lower, level=0)) + self.assert_eq(kdf3.rename(index=str_lower, level=1), pdf3.rename(index=str_lower, level=1)) + + pdf4 = pdf2 + 1 + kdf4 = kdf2 + 1 + self.assert_eq(kdf4.rename(columns=str_lower), pdf4.rename(columns=str_lower)) + + pdf5 = pdf3 + 1 + kdf5 = kdf3 + 1 + self.assert_eq(kdf5.rename(index=str_lower), pdf5.rename(index=str_lower)) + + def test_rename_axis(self): + index = pd.Index(["A", "B", "C"], name="index") + columns = pd.Index(["numbers", "values"], name="cols") + pdf = pd.DataFrame([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], index=index, columns=columns) + kdf = pp.from_pandas(pdf) + + for axis in [0, "index"]: + self.assert_eq( + pdf.rename_axis("index2", axis=axis).sort_index(), + kdf.rename_axis("index2", axis=axis).sort_index(), + ) + self.assert_eq( + pdf.rename_axis(["index2"], axis=axis).sort_index(), + kdf.rename_axis(["index2"], axis=axis).sort_index(), + ) + + for axis in [1, "columns"]: + self.assert_eq( + pdf.rename_axis("cols2", axis=axis).sort_index(), + kdf.rename_axis("cols2", axis=axis).sort_index(), + ) + self.assert_eq( + pdf.rename_axis(["cols2"], axis=axis).sort_index(), + kdf.rename_axis(["cols2"], axis=axis).sort_index(), + ) + + pdf2 = pdf.copy() + kdf2 = kdf.copy() + pdf2.rename_axis("index2", axis="index", inplace=True) + kdf2.rename_axis("index2", axis="index", inplace=True) + self.assert_eq(pdf2.sort_index(), kdf2.sort_index()) + + self.assertRaises(ValueError, lambda: kdf.rename_axis(["index2", "index3"], axis=0)) + self.assertRaises(ValueError, lambda: kdf.rename_axis(["cols2", "cols3"], axis=1)) + self.assertRaises(TypeError, lambda: kdf.rename_axis(mapper=["index2"], index=["index3"])) + + # index/columns parameters and dict_like/functions mappers introduced in pandas 0.24.0 + if LooseVersion(pd.__version__) >= LooseVersion("0.24.0"): + self.assert_eq( + pdf.rename_axis(index={"index": "index2"}, columns={"cols": "cols2"}).sort_index(), + kdf.rename_axis(index={"index": "index2"}, columns={"cols": "cols2"}).sort_index(), + ) + + self.assert_eq( + pdf.rename_axis( + index={"missing": "index2"}, columns={"missing": "cols2"} + ).sort_index(), + kdf.rename_axis( + index={"missing": "index2"}, columns={"missing": "cols2"} + ).sort_index(), + ) + + self.assert_eq( + pdf.rename_axis(index=str.upper, columns=str.upper).sort_index(), + kdf.rename_axis(index=str.upper, columns=str.upper).sort_index(), + ) + else: + expected = pdf + expected.index.name = "index2" + expected.columns.name = "cols2" + result = 
kdf.rename_axis( + index={"index": "index2"}, columns={"cols": "cols2"} + ).sort_index() + self.assert_eq(expected, result) + + expected.index.name = "index" + expected.columns.name = "cols" + result = kdf.rename_axis( + index={"missing": "index2"}, columns={"missing": "cols2"} + ).sort_index() + self.assert_eq(expected, result) + + expected.index.name = "INDEX" + expected.columns.name = "COLS" + result = kdf.rename_axis(index=str.upper, columns=str.upper).sort_index() + self.assert_eq(expected, result) + + index = pd.MultiIndex.from_tuples( + [("A", "B"), ("C", "D"), ("E", "F")], names=["index1", "index2"] + ) + columns = pd.MultiIndex.from_tuples( + [("numbers", "first"), ("values", "second")], names=["cols1", "cols2"] + ) + pdf = pd.DataFrame([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], index=index, columns=columns) + kdf = pp.from_pandas(pdf) + + for axis in [0, "index"]: + self.assert_eq( + pdf.rename_axis(["index3", "index4"], axis=axis).sort_index(), + kdf.rename_axis(["index3", "index4"], axis=axis).sort_index(), + ) + + for axis in [1, "columns"]: + self.assert_eq( + pdf.rename_axis(["cols3", "cols4"], axis=axis).sort_index(), + kdf.rename_axis(["cols3", "cols4"], axis=axis).sort_index(), + ) + + self.assertRaises( + ValueError, lambda: kdf.rename_axis(["index3", "index4", "index5"], axis=0) + ) + self.assertRaises(ValueError, lambda: kdf.rename_axis(["cols3", "cols4", "cols5"], axis=1)) + + # index/columns parameters and dict_like/functions mappers introduced in pandas 0.24.0 + if LooseVersion(pd.__version__) >= LooseVersion("0.24.0"): + self.assert_eq( + pdf.rename_axis( + index={"index1": "index3"}, columns={"cols1": "cols3"} + ).sort_index(), + kdf.rename_axis( + index={"index1": "index3"}, columns={"cols1": "cols3"} + ).sort_index(), + ) + + self.assert_eq( + pdf.rename_axis( + index={"missing": "index3"}, columns={"missing": "cols3"} + ).sort_index(), + kdf.rename_axis( + index={"missing": "index3"}, columns={"missing": "cols3"} + ).sort_index(), + ) + + self.assert_eq( + pdf.rename_axis( + index={"index1": "index3", "index2": "index4"}, + columns={"cols1": "cols3", "cols2": "cols4"}, + ).sort_index(), + kdf.rename_axis( + index={"index1": "index3", "index2": "index4"}, + columns={"cols1": "cols3", "cols2": "cols4"}, + ).sort_index(), + ) + + self.assert_eq( + pdf.rename_axis(index=str.upper, columns=str.upper).sort_index(), + kdf.rename_axis(index=str.upper, columns=str.upper).sort_index(), + ) + else: + expected = pdf + expected.index.names = ["index3", "index2"] + expected.columns.names = ["cols3", "cols2"] + result = kdf.rename_axis( + index={"index1": "index3"}, columns={"cols1": "cols3"} + ).sort_index() + self.assert_eq(expected, result) + + expected.index.names = ["index1", "index2"] + expected.columns.names = ["cols1", "cols2"] + result = kdf.rename_axis( + index={"missing": "index2"}, columns={"missing": "cols2"} + ).sort_index() + self.assert_eq(expected, result) + + expected.index.names = ["index3", "index4"] + expected.columns.names = ["cols3", "cols4"] + result = kdf.rename_axis( + index={"index1": "index3", "index2": "index4"}, + columns={"cols1": "cols3", "cols2": "cols4"}, + ).sort_index() + self.assert_eq(expected, result) + + expected.index.names = ["INDEX1", "INDEX2"] + expected.columns.names = ["COLS1", "COLS2"] + result = kdf.rename_axis(index=str.upper, columns=str.upper).sort_index() + self.assert_eq(expected, result) + + def test_dot_in_column_name(self): + self.assert_eq( + pp.DataFrame(pp.range(1)._internal.spark_frame.selectExpr("1L as `a.b`"))["a.b"], + 
pp.Series([1], name="a.b"), + ) + + def test_aggregate(self): + pdf = pd.DataFrame( + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [np.nan, np.nan, np.nan]], columns=["A", "B", "C"] + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq( + kdf.agg(["sum", "min"])[["A", "B", "C"]].sort_index(), # TODO?: fix column order + pdf.agg(["sum", "min"])[["A", "B", "C"]].sort_index(), + ) + self.assert_eq( + kdf.agg({"A": ["sum", "min"], "B": ["min", "max"]})[["A", "B"]].sort_index(), + pdf.agg({"A": ["sum", "min"], "B": ["min", "max"]})[["A", "B"]].sort_index(), + ) + + self.assertRaises(KeyError, lambda: kdf.agg({"A": ["sum", "min"], "X": ["min", "max"]})) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C")]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq( + kdf.agg(["sum", "min"])[[("X", "A"), ("X", "B"), ("Y", "C")]].sort_index(), + pdf.agg(["sum", "min"])[[("X", "A"), ("X", "B"), ("Y", "C")]].sort_index(), + ) + self.assert_eq( + kdf.agg({("X", "A"): ["sum", "min"], ("X", "B"): ["min", "max"]})[ + [("X", "A"), ("X", "B")] + ].sort_index(), + pdf.agg({("X", "A"): ["sum", "min"], ("X", "B"): ["min", "max"]})[ + [("X", "A"), ("X", "B")] + ].sort_index(), + ) + + self.assertRaises(TypeError, lambda: kdf.agg({"X": ["sum", "min"], "Y": ["min", "max"]})) + + # non-string names + pdf = pd.DataFrame( + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [np.nan, np.nan, np.nan]], columns=[10, 20, 30] + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq( + kdf.agg(["sum", "min"])[[10, 20, 30]].sort_index(), + pdf.agg(["sum", "min"])[[10, 20, 30]].sort_index(), + ) + self.assert_eq( + kdf.agg({10: ["sum", "min"], 20: ["min", "max"]})[[10, 20]].sort_index(), + pdf.agg({10: ["sum", "min"], 20: ["min", "max"]})[[10, 20]].sort_index(), + ) + + columns = pd.MultiIndex.from_tuples([("X", 10), ("X", 20), ("Y", 30)]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq( + kdf.agg(["sum", "min"])[[("X", 10), ("X", 20), ("Y", 30)]].sort_index(), + pdf.agg(["sum", "min"])[[("X", 10), ("X", 20), ("Y", 30)]].sort_index(), + ) + self.assert_eq( + kdf.agg({("X", 10): ["sum", "min"], ("X", 20): ["min", "max"]})[ + [("X", 10), ("X", 20)] + ].sort_index(), + pdf.agg({("X", 10): ["sum", "min"], ("X", 20): ["min", "max"]})[ + [("X", 10), ("X", 20)] + ].sort_index(), + ) + + pdf = pd.DataFrame( + [datetime(2019, 2, 2, 0, 0, 0, 0), datetime(2019, 2, 3, 0, 0, 0, 0)], + columns=["timestamp"], + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.timestamp.min(), pdf.timestamp.min()) + self.assert_eq(kdf.timestamp.max(), pdf.timestamp.max()) + + def test_droplevel(self): + pdf = ( + pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) + .set_index([0, 1]) + .rename_axis(["a", "b"]) + ) + pdf.columns = pd.MultiIndex.from_tuples( + [("c", "e"), ("d", "f")], names=["level_1", "level_2"] + ) + kdf = pp.from_pandas(pdf) + + self.assertRaises(ValueError, lambda: kdf.droplevel(["a", "b"])) + self.assertRaises(ValueError, lambda: kdf.droplevel([1, 1, 1, 1, 1])) + self.assertRaises(IndexError, lambda: kdf.droplevel(2)) + self.assertRaises(IndexError, lambda: kdf.droplevel(-3)) + self.assertRaises(KeyError, lambda: kdf.droplevel({"a"})) + self.assertRaises(KeyError, lambda: kdf.droplevel({"a": 1})) + + self.assertRaises(ValueError, lambda: kdf.droplevel(["level_1", "level_2"], axis=1)) + self.assertRaises(IndexError, lambda: kdf.droplevel(2, axis=1)) + self.assertRaises(IndexError, lambda: kdf.droplevel(-3, axis=1)) + self.assertRaises(KeyError, lambda: kdf.droplevel({"level_1"}, 
axis=1)) + self.assertRaises(KeyError, lambda: kdf.droplevel({"level_1": 1}, axis=1)) + + # droplevel is new in pandas 0.24.0 + if LooseVersion(pd.__version__) >= LooseVersion("0.24.0"): + self.assert_eq(pdf.droplevel("a"), kdf.droplevel("a")) + self.assert_eq(pdf.droplevel(["a"]), kdf.droplevel(["a"])) + self.assert_eq(pdf.droplevel(("a",)), kdf.droplevel(("a",))) + self.assert_eq(pdf.droplevel(0), kdf.droplevel(0)) + self.assert_eq(pdf.droplevel(-1), kdf.droplevel(-1)) + + self.assert_eq(pdf.droplevel("level_1", axis=1), kdf.droplevel("level_1", axis=1)) + self.assert_eq(pdf.droplevel(["level_1"], axis=1), kdf.droplevel(["level_1"], axis=1)) + self.assert_eq(pdf.droplevel(("level_1",), axis=1), kdf.droplevel(("level_1",), axis=1)) + self.assert_eq(pdf.droplevel(0, axis=1), kdf.droplevel(0, axis=1)) + self.assert_eq(pdf.droplevel(-1, axis=1), kdf.droplevel(-1, axis=1)) + else: + expected = pdf.copy() + expected.index = expected.index.droplevel("a") + + self.assert_eq(expected, kdf.droplevel("a")) + self.assert_eq(expected, kdf.droplevel(["a"])) + self.assert_eq(expected, kdf.droplevel(("a",))) + self.assert_eq(expected, kdf.droplevel(0)) + + expected = pdf.copy() + expected.index = expected.index.droplevel(-1) + + self.assert_eq(expected, kdf.droplevel(-1)) + + expected = pdf.copy() + expected.columns = expected.columns.droplevel("level_1") + + self.assert_eq(expected, kdf.droplevel("level_1", axis=1)) + self.assert_eq(expected, kdf.droplevel(["level_1"], axis=1)) + self.assert_eq(expected, kdf.droplevel(("level_1",), axis=1)) + self.assert_eq(expected, kdf.droplevel(0, axis=1)) + + expected = pdf.copy() + expected.columns = expected.columns.droplevel(-1) + + self.assert_eq(expected, kdf.droplevel(-1, axis=1)) + + # Tupled names + pdf.columns.names = [("level", 1), ("level", 2)] + pdf.index.names = [("a", 10), ("x", 20)] + kdf = pp.from_pandas(pdf) + + self.assertRaises(KeyError, lambda: kdf.droplevel("a")) + self.assertRaises(KeyError, lambda: kdf.droplevel(("a", 10))) + + # droplevel is new in pandas 0.24.0 + if LooseVersion(pd.__version__) >= LooseVersion("0.24.0"): + self.assert_eq(pdf.droplevel([("a", 10)]), kdf.droplevel([("a", 10)])) + self.assert_eq( + pdf.droplevel([("level", 1)], axis=1), kdf.droplevel([("level", 1)], axis=1) + ) + else: + expected = pdf.copy() + expected.index = expected.index.droplevel([("a", 10)]) + + self.assert_eq(expected, kdf.droplevel([("a", 10)])) + + expected = pdf.copy() + expected.columns = expected.columns.droplevel([("level", 1)]) + + self.assert_eq(expected, kdf.droplevel([("level", 1)], axis=1)) + + # non-string names + pdf = ( + pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) + .set_index([0, 1]) + .rename_axis([10.0, 20.0]) + ) + pdf.columns = pd.MultiIndex.from_tuples([("c", "e"), ("d", "f")], names=[100.0, 200.0]) + kdf = pp.from_pandas(pdf) + + # droplevel is new in pandas 0.24.0 + if LooseVersion(pd.__version__) >= LooseVersion("0.24.0"): + self.assert_eq(pdf.droplevel(10.0), kdf.droplevel(10.0)) + self.assert_eq(pdf.droplevel([10.0]), kdf.droplevel([10.0])) + self.assert_eq(pdf.droplevel((10.0,)), kdf.droplevel((10.0,))) + self.assert_eq(pdf.droplevel(0), kdf.droplevel(0)) + self.assert_eq(pdf.droplevel(-1), kdf.droplevel(-1)) + self.assert_eq(pdf.droplevel(100.0, axis=1), kdf.droplevel(100.0, axis=1)) + self.assert_eq(pdf.droplevel(0, axis=1), kdf.droplevel(0, axis=1)) + else: + expected = pdf.copy() + expected.index = expected.index.droplevel(10.0) + + self.assert_eq(expected, kdf.droplevel(10.0)) + self.assert_eq(expected, 
kdf.droplevel([10.0])) + self.assert_eq(expected, kdf.droplevel((10.0,))) + self.assert_eq(expected, kdf.droplevel(0)) + + expected = pdf.copy() + expected.index = expected.index.droplevel(-1) + self.assert_eq(expected, kdf.droplevel(-1)) + + expected = pdf.copy() + expected.columns = expected.columns.droplevel(100.0) + + self.assert_eq(expected, kdf.droplevel(100.0, axis=1)) + self.assert_eq(expected, kdf.droplevel(0, axis=1)) + + def test_drop(self): + pdf = pd.DataFrame({"x": [1, 2], "y": [3, 4], "z": [5, 6]}, index=np.random.rand(2)) + kdf = pp.from_pandas(pdf) + + # Assert 'labels' or 'columns' parameter is set + expected_error_message = "Need to specify at least one of 'labels' or 'columns'" + with self.assertRaisesRegex(ValueError, expected_error_message): + kdf.drop() + # Assert axis cannot be 0 + with self.assertRaisesRegex(NotImplementedError, "Drop currently only works for axis=1"): + kdf.drop("x", axis=0) + # Assert using a str for 'labels' works + self.assert_eq(kdf.drop("x", axis=1), pdf.drop("x", axis=1)) + # Assert axis is 1 by default + self.assert_eq(kdf.drop("x"), pdf.drop("x", axis=1)) + # Assert using a list for 'labels' works + self.assert_eq(kdf.drop(["y", "z"], axis=1), pdf.drop(["y", "z"], axis=1)) + # Assert using 'columns' instead of 'labels' produces the same results + self.assert_eq(kdf.drop(columns="x"), pdf.drop(columns="x")) + self.assert_eq(kdf.drop(columns=["y", "z"]), pdf.drop(columns=["y", "z"])) + + # Assert 'labels' being used when both 'labels' and 'columns' are specified + # TODO: should throw an error? + expected_output = pd.DataFrame({"y": [3, 4], "z": [5, 6]}, index=kdf.index.to_pandas()) + self.assert_eq(kdf.drop(labels=["x"], columns=["y"]), expected_output) + + columns = pd.MultiIndex.from_tuples([(1, "x"), (1, "y"), (2, "z")]) + pdf.columns = columns + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.drop(columns=1), pdf.drop(columns=1)) + self.assert_eq(kdf.drop(columns=(1, "x")), pdf.drop(columns=(1, "x"))) + self.assert_eq(kdf.drop(columns=[(1, "x"), 2]), pdf.drop(columns=[(1, "x"), 2])) + + self.assertRaises(KeyError, lambda: kdf.drop(columns=3)) + self.assertRaises(KeyError, lambda: kdf.drop(columns=(1, "z"))) + + # non-string names + pdf = pd.DataFrame({10: [1, 2], 20: [3, 4], 30: [5, 6]}, index=np.random.rand(2)) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.drop(10), pdf.drop(10, axis=1)) + self.assert_eq(kdf.drop([20, 30]), pdf.drop([20, 30], axis=1)) + + def _test_dropna(self, pdf, axis): + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.dropna(axis=axis), pdf.dropna(axis=axis)) + self.assert_eq(kdf.dropna(axis=axis, how="all"), pdf.dropna(axis=axis, how="all")) + self.assert_eq(kdf.dropna(axis=axis, subset=["x"]), pdf.dropna(axis=axis, subset=["x"])) + self.assert_eq(kdf.dropna(axis=axis, subset="x"), pdf.dropna(axis=axis, subset=["x"])) + self.assert_eq( + kdf.dropna(axis=axis, subset=["y", "z"]), pdf.dropna(axis=axis, subset=["y", "z"]) + ) + self.assert_eq( + kdf.dropna(axis=axis, subset=["y", "z"], how="all"), + pdf.dropna(axis=axis, subset=["y", "z"], how="all"), + ) + + self.assert_eq(kdf.dropna(axis=axis, thresh=2), pdf.dropna(axis=axis, thresh=2)) + self.assert_eq( + kdf.dropna(axis=axis, thresh=1, subset=["y", "z"]), + pdf.dropna(axis=axis, thresh=1, subset=["y", "z"]), + ) + + pdf2 = pdf.copy() + kdf2 = kdf.copy() + pser = pdf2[pdf2.columns[0]] + kser = kdf2[kdf2.columns[0]] + pdf2.dropna(inplace=True) + kdf2.dropna(inplace=True) + self.assert_eq(kdf2, pdf2) + self.assert_eq(kser, pser) + + # multi-index + columns = 
pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]) + if axis == 0: + pdf.columns = columns + else: + pdf.index = columns + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.dropna(axis=axis), pdf.dropna(axis=axis)) + self.assert_eq(kdf.dropna(axis=axis, how="all"), pdf.dropna(axis=axis, how="all")) + self.assert_eq( + kdf.dropna(axis=axis, subset=[("a", "x")]), pdf.dropna(axis=axis, subset=[("a", "x")]) + ) + self.assert_eq( + kdf.dropna(axis=axis, subset=("a", "x")), pdf.dropna(axis=axis, subset=[("a", "x")]) + ) + self.assert_eq( + kdf.dropna(axis=axis, subset=[("a", "y"), ("b", "z")]), + pdf.dropna(axis=axis, subset=[("a", "y"), ("b", "z")]), + ) + self.assert_eq( + kdf.dropna(axis=axis, subset=[("a", "y"), ("b", "z")], how="all"), + pdf.dropna(axis=axis, subset=[("a", "y"), ("b", "z")], how="all"), + ) + + self.assert_eq(kdf.dropna(axis=axis, thresh=2), pdf.dropna(axis=axis, thresh=2)) + self.assert_eq( + kdf.dropna(axis=axis, thresh=1, subset=[("a", "y"), ("b", "z")]), + pdf.dropna(axis=axis, thresh=1, subset=[("a", "y"), ("b", "z")]), + ) + + def test_dropna_axis_index(self): + pdf = pd.DataFrame( + { + "x": [np.nan, 2, 3, 4, np.nan, 6], + "y": [1, 2, np.nan, 4, np.nan, np.nan], + "z": [1, 2, 3, 4, np.nan, np.nan], + }, + index=np.random.rand(6), + ) + kdf = pp.from_pandas(pdf) + + self._test_dropna(pdf, axis=0) + + # empty + pdf = pd.DataFrame(index=np.random.rand(6)) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.dropna(), pdf.dropna()) + self.assert_eq(kdf.dropna(how="all"), pdf.dropna(how="all")) + self.assert_eq(kdf.dropna(thresh=0), pdf.dropna(thresh=0)) + self.assert_eq(kdf.dropna(thresh=1), pdf.dropna(thresh=1)) + + with self.assertRaisesRegex(ValueError, "No axis named foo"): + kdf.dropna(axis="foo") + + self.assertRaises(KeyError, lambda: kdf.dropna(subset="1")) + with self.assertRaisesRegex(ValueError, "invalid how option: 1"): + kdf.dropna(how=1) + with self.assertRaisesRegex(TypeError, "must specify how or thresh"): + kdf.dropna(how=None) + + def test_dropna_axis_column(self): + pdf = pd.DataFrame( + { + "x": [np.nan, 2, 3, 4, np.nan, 6], + "y": [1, 2, np.nan, 4, np.nan, np.nan], + "z": [1, 2, 3, 4, np.nan, np.nan], + }, + index=[str(r) for r in np.random.rand(6)], + ).T + + self._test_dropna(pdf, axis=1) + + # empty + pdf = pd.DataFrame({"x": [], "y": [], "z": []}) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.dropna(axis=1), pdf.dropna(axis=1)) + self.assert_eq(kdf.dropna(axis=1, how="all"), pdf.dropna(axis=1, how="all")) + self.assert_eq(kdf.dropna(axis=1, thresh=0), pdf.dropna(axis=1, thresh=0)) + self.assert_eq(kdf.dropna(axis=1, thresh=1), pdf.dropna(axis=1, thresh=1)) + + def test_dtype(self): + pdf = pd.DataFrame( + { + "a": list("abc"), + "b": list(range(1, 4)), + "c": np.arange(3, 6).astype("i1"), + "d": np.arange(4.0, 7.0, dtype="float64"), + "e": [True, False, True], + "f": pd.date_range("20130101", periods=3), + }, + index=np.random.rand(3), + ) + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf, pdf) + self.assertTrue((kdf.dtypes == pdf.dtypes).all()) + + # multi-index columns + columns = pd.MultiIndex.from_tuples(zip(list("xxxyyz"), list("abcdef"))) + pdf.columns = columns + kdf.columns = columns + self.assertTrue((kdf.dtypes == pdf.dtypes).all()) + + def test_fillna(self): + pdf = pd.DataFrame( + { + "x": [np.nan, 2, 3, 4, np.nan, 6], + "y": [1, 2, np.nan, 4, np.nan, np.nan], + "z": [1, 2, 3, 4, np.nan, np.nan], + }, + index=np.random.rand(6), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf, pdf) + self.assert_eq(kdf.fillna(-1), 
pdf.fillna(-1)) + self.assert_eq( + kdf.fillna({"x": -1, "y": -2, "z": -5}), pdf.fillna({"x": -1, "y": -2, "z": -5}) + ) + self.assert_eq(pdf.fillna(method="ffill"), kdf.fillna(method="ffill")) + self.assert_eq(pdf.fillna(method="ffill", limit=2), kdf.fillna(method="ffill", limit=2)) + self.assert_eq(pdf.fillna(method="bfill"), kdf.fillna(method="bfill")) + self.assert_eq(pdf.fillna(method="bfill", limit=2), kdf.fillna(method="bfill", limit=2)) + + pdf = pdf.set_index(["x", "y"]) + kdf = pp.from_pandas(pdf) + # check multi index + self.assert_eq(kdf.fillna(-1), pdf.fillna(-1)) + self.assert_eq(pdf.fillna(method="bfill"), kdf.fillna(method="bfill")) + self.assert_eq(pdf.fillna(method="ffill"), kdf.fillna(method="ffill")) + + pser = pdf.z + kser = kdf.z + pdf.fillna({"x": -1, "y": -2, "z": -5}, inplace=True) + kdf.fillna({"x": -1, "y": -2, "z": -5}, inplace=True) + self.assert_eq(kdf, pdf) + self.assert_eq(kser, pser) + + s_nan = pd.Series([-1, -2, -5], index=["x", "y", "z"], dtype=int) + self.assert_eq(kdf.fillna(s_nan), pdf.fillna(s_nan)) + + with self.assertRaisesRegex(NotImplementedError, "fillna currently only"): + kdf.fillna(-1, axis=1) + with self.assertRaisesRegex(NotImplementedError, "fillna currently only"): + kdf.fillna(-1, axis="columns") + with self.assertRaisesRegex(ValueError, "limit parameter for value is not support now"): + kdf.fillna(-1, limit=1) + with self.assertRaisesRegex(TypeError, "Unsupported.*DataFrame"): + kdf.fillna(pd.DataFrame({"x": [-1], "y": [-1], "z": [-1]})) + with self.assertRaisesRegex(TypeError, "Unsupported.*int64"): + kdf.fillna({"x": np.int64(-6), "y": np.int64(-4), "z": -5}) + with self.assertRaisesRegex(ValueError, "Expecting 'pad', 'ffill', 'backfill' or 'bfill'."): + kdf.fillna(method="xxx") + with self.assertRaisesRegex( + ValueError, "Must specify a fillna 'value' or 'method' parameter." + ): + kdf.fillna() + + # multi-index columns + pdf = pd.DataFrame( + { + ("x", "a"): [np.nan, 2, 3, 4, np.nan, 6], + ("x", "b"): [1, 2, np.nan, 4, np.nan, np.nan], + ("y", "c"): [1, 2, 3, 4, np.nan, np.nan], + }, + index=np.random.rand(6), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.fillna(-1), pdf.fillna(-1)) + self.assert_eq( + kdf.fillna({("x", "a"): -1, ("x", "b"): -2, ("y", "c"): -5}), + pdf.fillna({("x", "a"): -1, ("x", "b"): -2, ("y", "c"): -5}), + ) + self.assert_eq(pdf.fillna(method="ffill"), kdf.fillna(method="ffill")) + self.assert_eq(pdf.fillna(method="ffill", limit=2), kdf.fillna(method="ffill", limit=2)) + self.assert_eq(pdf.fillna(method="bfill"), kdf.fillna(method="bfill")) + self.assert_eq(pdf.fillna(method="bfill", limit=2), kdf.fillna(method="bfill", limit=2)) + + self.assert_eq(kdf.fillna({"x": -1}), pdf.fillna({"x": -1})) + + if sys.version_info >= (3, 6): + # flaky in Python 3.5. 
+ self.assert_eq( + kdf.fillna({"x": -1, ("x", "b"): -2}), pdf.fillna({"x": -1, ("x", "b"): -2}) + ) + self.assert_eq( + kdf.fillna({("x", "b"): -2, "x": -1}), pdf.fillna({("x", "b"): -2, "x": -1}) + ) + + # check multi index + pdf = pdf.set_index([("x", "a"), ("x", "b")]) + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf.fillna(-1), pdf.fillna(-1)) + self.assert_eq( + kdf.fillna({("x", "a"): -1, ("x", "b"): -2, ("y", "c"): -5}), + pdf.fillna({("x", "a"): -1, ("x", "b"): -2, ("y", "c"): -5}), + ) + + def test_isnull(self): + pdf = pd.DataFrame( + {"x": [1, 2, 3, 4, None, 6], "y": list("abdabd")}, index=np.random.rand(6) + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.notnull(), pdf.notnull()) + self.assert_eq(kdf.isnull(), pdf.isnull()) + + def test_to_datetime(self): + pdf = pd.DataFrame( + {"year": [2015, 2016], "month": [2, 3], "day": [4, 5]}, index=np.random.rand(2) + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(pd.to_datetime(pdf), pp.to_datetime(kdf)) + + def test_nunique(self): + pdf = pd.DataFrame({"A": [1, 2, 3], "B": [np.nan, 3, np.nan]}, index=np.random.rand(3)) + kdf = pp.from_pandas(pdf) + + # Assert NaNs are dropped by default + self.assert_eq(kdf.nunique(), pdf.nunique()) + + # Assert including NaN values + self.assert_eq(kdf.nunique(dropna=False), pdf.nunique(dropna=False)) + + # Assert approximate counts + self.assert_eq( + pp.DataFrame({"A": range(100)}).nunique(approx=True), pd.Series([103], index=["A"]), + ) + self.assert_eq( + pp.DataFrame({"A": range(100)}).nunique(approx=True, rsd=0.01), + pd.Series([100], index=["A"]), + ) + + # Assert unsupported axis value yet + msg = 'axis should be either 0 or "index" currently.' + with self.assertRaisesRegex(NotImplementedError, msg): + kdf.nunique(axis=1) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("X", "A"), ("Y", "B")], names=["1", "2"]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq(kdf.nunique(), pdf.nunique()) + self.assert_eq(kdf.nunique(dropna=False), pdf.nunique(dropna=False)) + + def test_sort_values(self): + pdf = pd.DataFrame( + {"a": [1, 2, 3, 4, 5, None, 7], "b": [7, 6, 5, 4, 3, 2, 1]}, index=np.random.rand(7) + ) + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf.sort_values("b"), pdf.sort_values("b")) + self.assert_eq(kdf.sort_values(["b", "a"]), pdf.sort_values(["b", "a"])) + self.assert_eq( + kdf.sort_values(["b", "a"], ascending=[False, True]), + pdf.sort_values(["b", "a"], ascending=[False, True]), + ) + + self.assertRaises(ValueError, lambda: kdf.sort_values(["b", "a"], ascending=[False])) + + self.assert_eq( + kdf.sort_values(["b", "a"], na_position="first"), + pdf.sort_values(["b", "a"], na_position="first"), + ) + + self.assertRaises(ValueError, lambda: kdf.sort_values(["b", "a"], na_position="invalid")) + + pserA = pdf.a + kserA = kdf.a + self.assert_eq(kdf.sort_values("b", inplace=True), pdf.sort_values("b", inplace=True)) + self.assert_eq(kdf, pdf) + self.assert_eq(kserA, pserA) + + # multi-index columns + pdf = pd.DataFrame( + {("X", 10): [1, 2, 3, 4, 5, None, 7], ("X", 20): [7, 6, 5, 4, 3, 2, 1]}, + index=np.random.rand(7), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.sort_values(("X", 20)), pdf.sort_values(("X", 20))) + self.assert_eq( + kdf.sort_values([("X", 20), ("X", 10)]), pdf.sort_values([("X", 20), ("X", 10)]) + ) + + self.assertRaisesRegex( + ValueError, + "For a multi-index, the label must be a tuple with elements", + lambda: kdf.sort_values(["X"]), + ) + + # non-string names + pdf = pd.DataFrame( + {10: [1, 2, 3, 4, 5, None, 7], 
20: [7, 6, 5, 4, 3, 2, 1]}, index=np.random.rand(7) + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.sort_values(20), pdf.sort_values(20)) + self.assert_eq(kdf.sort_values([20, 10]), pdf.sort_values([20, 10])) + + def test_sort_index(self): + pdf = pd.DataFrame( + {"A": [2, 1, np.nan], "B": [np.nan, 0, np.nan]}, index=["b", "a", np.nan] + ) + kdf = pp.from_pandas(pdf) + + # Assert invalid parameters + self.assertRaises(NotImplementedError, lambda: kdf.sort_index(axis=1)) + self.assertRaises(NotImplementedError, lambda: kdf.sort_index(kind="mergesort")) + self.assertRaises(ValueError, lambda: kdf.sort_index(na_position="invalid")) + + # Assert default behavior without parameters + self.assert_eq(kdf.sort_index(), pdf.sort_index()) + # Assert sorting descending + self.assert_eq(kdf.sort_index(ascending=False), pdf.sort_index(ascending=False)) + # Assert sorting NA indices first + self.assert_eq(kdf.sort_index(na_position="first"), pdf.sort_index(na_position="first")) + + # Assert sorting inplace + pserA = pdf.A + kserA = kdf.A + self.assertEqual(kdf.sort_index(inplace=True), pdf.sort_index(inplace=True)) + self.assert_eq(kdf, pdf) + self.assert_eq(kserA, pserA) + + # Assert multi-indices + pdf = pd.DataFrame( + {"A": range(4), "B": range(4)[::-1]}, index=[["b", "b", "a", "a"], [1, 0, 1, 0]] + ) + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf.sort_index(), pdf.sort_index()) + self.assert_eq(kdf.sort_index(level=[1, 0]), pdf.sort_index(level=[1, 0])) + self.assert_eq(kdf.reset_index().sort_index(), pdf.reset_index().sort_index()) + + # Assert with multi-index columns + columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B")]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq(kdf.sort_index(), pdf.sort_index()) + + def test_swaplevel(self): + # MultiIndex with two levels + arrays = [[1, 1, 2, 2], ["red", "blue", "red", "blue"]] + pidx = pd.MultiIndex.from_arrays(arrays, names=("number", "color")) + pdf = pd.DataFrame({"x1": ["a", "b", "c", "d"], "x2": ["a", "b", "c", "d"]}, index=pidx) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.swaplevel(), kdf.swaplevel()) + self.assert_eq(pdf.swaplevel(0, 1), kdf.swaplevel(0, 1)) + self.assert_eq(pdf.swaplevel(1, 1), kdf.swaplevel(1, 1)) + self.assert_eq(pdf.swaplevel("number", "color"), kdf.swaplevel("number", "color")) + + # MultiIndex with more than two levels + arrays = [[1, 1, 2, 2], ["red", "blue", "red", "blue"], ["l", "m", "s", "xs"]] + pidx = pd.MultiIndex.from_arrays(arrays, names=("number", "color", "size")) + pdf = pd.DataFrame({"x1": ["a", "b", "c", "d"], "x2": ["a", "b", "c", "d"]}, index=pidx) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.swaplevel(), kdf.swaplevel()) + self.assert_eq(pdf.swaplevel(0, 1), kdf.swaplevel(0, 1)) + self.assert_eq(pdf.swaplevel(0, 2), kdf.swaplevel(0, 2)) + self.assert_eq(pdf.swaplevel(1, 2), kdf.swaplevel(1, 2)) + self.assert_eq(pdf.swaplevel(1, 1), kdf.swaplevel(1, 1)) + self.assert_eq(pdf.swaplevel(-1, -2), kdf.swaplevel(-1, -2)) + self.assert_eq(pdf.swaplevel("number", "color"), kdf.swaplevel("number", "color")) + self.assert_eq(pdf.swaplevel("number", "size"), kdf.swaplevel("number", "size")) + self.assert_eq(pdf.swaplevel("color", "size"), kdf.swaplevel("color", "size")) + self.assert_eq( + pdf.swaplevel("color", "size", axis="index"), + kdf.swaplevel("color", "size", axis="index"), + ) + self.assert_eq( + pdf.swaplevel("color", "size", axis=0), kdf.swaplevel("color", "size", axis=0) + ) + + pdf = pd.DataFrame( + { + "x1": ["a", "b", "c", "d"], + "x2": ["a", "b", "c", "d"], 
+ "x3": ["a", "b", "c", "d"], + "x4": ["a", "b", "c", "d"], + } + ) + pidx = pd.MultiIndex.from_arrays(arrays, names=("number", "color", "size")) + pdf.columns = pidx + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.swaplevel(axis=1), kdf.swaplevel(axis=1)) + self.assert_eq(pdf.swaplevel(0, 1, axis=1), kdf.swaplevel(0, 1, axis=1)) + self.assert_eq(pdf.swaplevel(0, 2, axis=1), kdf.swaplevel(0, 2, axis=1)) + self.assert_eq(pdf.swaplevel(1, 2, axis=1), kdf.swaplevel(1, 2, axis=1)) + self.assert_eq(pdf.swaplevel(1, 1, axis=1), kdf.swaplevel(1, 1, axis=1)) + self.assert_eq(pdf.swaplevel(-1, -2, axis=1), kdf.swaplevel(-1, -2, axis=1)) + self.assert_eq( + pdf.swaplevel("number", "color", axis=1), kdf.swaplevel("number", "color", axis=1) + ) + self.assert_eq( + pdf.swaplevel("number", "size", axis=1), kdf.swaplevel("number", "size", axis=1) + ) + self.assert_eq( + pdf.swaplevel("color", "size", axis=1), kdf.swaplevel("color", "size", axis=1) + ) + self.assert_eq( + pdf.swaplevel("color", "size", axis="columns"), + kdf.swaplevel("color", "size", axis="columns"), + ) + + # Error conditions + self.assertRaises(AssertionError, lambda: pp.DataFrame([1, 2]).swaplevel()) + self.assertRaises(IndexError, lambda: kdf.swaplevel(0, 9, axis=1)) + self.assertRaises(KeyError, lambda: kdf.swaplevel("not_number", "color", axis=1)) + self.assertRaises(ValueError, lambda: kdf.swaplevel(axis=2)) + + def test_swapaxes(self): + pdf = pd.DataFrame( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], index=["x", "y", "z"], columns=["a", "b", "c"] + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.swapaxes(0, 1), pdf.swapaxes(0, 1)) + self.assert_eq(kdf.swapaxes(1, 0), pdf.swapaxes(1, 0)) + self.assert_eq(kdf.swapaxes("index", "columns"), pdf.swapaxes("index", "columns")) + self.assert_eq(kdf.swapaxes("columns", "index"), pdf.swapaxes("columns", "index")) + self.assert_eq((kdf + 1).swapaxes(0, 1), (pdf + 1).swapaxes(0, 1)) + + self.assertRaises(AssertionError, lambda: kdf.swapaxes(0, 1, copy=False)) + self.assertRaises(ValueError, lambda: kdf.swapaxes(0, -1)) + + def test_nlargest(self): + pdf = pd.DataFrame( + {"a": [1, 2, 3, 4, 5, None, 7], "b": [7, 6, 5, 4, 3, 2, 1]}, index=np.random.rand(7) + ) + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf.nlargest(n=5, columns="a"), pdf.nlargest(5, columns="a")) + self.assert_eq(kdf.nlargest(n=5, columns=["a", "b"]), pdf.nlargest(5, columns=["a", "b"])) + + def test_nsmallest(self): + pdf = pd.DataFrame( + {"a": [1, 2, 3, 4, 5, None, 7], "b": [7, 6, 5, 4, 3, 2, 1]}, index=np.random.rand(7) + ) + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf.nsmallest(n=5, columns="a"), pdf.nsmallest(5, columns="a")) + self.assert_eq(kdf.nsmallest(n=5, columns=["a", "b"]), pdf.nsmallest(5, columns=["a", "b"])) + + def test_xs(self): + d = { + "num_legs": [4, 4, 2, 2], + "num_wings": [0, 0, 2, 2], + "class": ["mammal", "mammal", "mammal", "bird"], + "animal": ["cat", "dog", "bat", "penguin"], + "locomotion": ["walks", "walks", "flies", "walks"], + } + pdf = pd.DataFrame(data=d) + pdf = pdf.set_index(["class", "animal", "locomotion"]) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.xs("mammal"), pdf.xs("mammal")) + self.assert_eq(kdf.xs(("mammal",)), pdf.xs(("mammal",))) + self.assert_eq(kdf.xs(("mammal", "dog", "walks")), pdf.xs(("mammal", "dog", "walks"))) + self.assert_eq( + pp.concat([kdf, kdf]).xs(("mammal", "dog", "walks")), + pd.concat([pdf, pdf]).xs(("mammal", "dog", "walks")), + ) + self.assert_eq(kdf.xs("cat", level=1), pdf.xs("cat", level=1)) + self.assert_eq(kdf.xs("flies", level=2), 
pdf.xs("flies", level=2)) + self.assert_eq(kdf.xs("mammal", level=-3), pdf.xs("mammal", level=-3)) + + msg = 'axis should be either 0 or "index" currently.' + with self.assertRaisesRegex(NotImplementedError, msg): + kdf.xs("num_wings", axis=1) + with self.assertRaises(KeyError): + kdf.xs(("mammal", "dog", "walk")) + msg = r"'Key length \(4\) exceeds index depth \(3\)'" + with self.assertRaisesRegex(KeyError, msg): + kdf.xs(("mammal", "dog", "walks", "foo")) + + self.assertRaises(IndexError, lambda: kdf.xs("foo", level=-4)) + self.assertRaises(IndexError, lambda: kdf.xs("foo", level=3)) + + self.assertRaises(KeyError, lambda: kdf.xs(("dog", "walks"), level=1)) + + # non-string names + pdf = pd.DataFrame(data=d) + pdf = pdf.set_index(["class", "animal", "num_legs", "num_wings"]) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.xs(("mammal", "dog", 4)), pdf.xs(("mammal", "dog", 4))) + self.assert_eq(kdf.xs(2, level=2), pdf.xs(2, level=2)) + + self.assert_eq((kdf + "a").xs(("mammal", "dog", 4)), (pdf + "a").xs(("mammal", "dog", 4))) + self.assert_eq((kdf + "a").xs(2, level=2), (pdf + "a").xs(2, level=2)) + + def test_missing(self): + kdf = self.kdf + + missing_functions = inspect.getmembers(_MissingPandasLikeDataFrame, inspect.isfunction) + unsupported_functions = [ + name for (name, type_) in missing_functions if type_.__name__ == "unsupported_function" + ] + for name in unsupported_functions: + with self.assertRaisesRegex( + PandasNotImplementedError, + "method.*DataFrame.*{}.*not implemented( yet\\.|\\. .+)".format(name), + ): + getattr(kdf, name)() + + deprecated_functions = [ + name for (name, type_) in missing_functions if type_.__name__ == "deprecated_function" + ] + for name in deprecated_functions: + with self.assertRaisesRegex( + PandasNotImplementedError, "method.*DataFrame.*{}.*is deprecated".format(name) + ): + getattr(kdf, name)() + + missing_properties = inspect.getmembers( + _MissingPandasLikeDataFrame, lambda o: isinstance(o, property) + ) + unsupported_properties = [ + name + for (name, type_) in missing_properties + if type_.fget.__name__ == "unsupported_property" + ] + for name in unsupported_properties: + with self.assertRaisesRegex( + PandasNotImplementedError, + "property.*DataFrame.*{}.*not implemented( yet\\.|\\. 
.+)".format(name), + ): + getattr(kdf, name) + deprecated_properties = [ + name + for (name, type_) in missing_properties + if type_.fget.__name__ == "deprecated_property" + ] + for name in deprecated_properties: + with self.assertRaisesRegex( + PandasNotImplementedError, "property.*DataFrame.*{}.*is deprecated".format(name) + ): + getattr(kdf, name) + + def test_to_numpy(self): + pdf = pd.DataFrame( + { + "a": [4, 2, 3, 4, 8, 6], + "b": [1, 2, 9, 4, 2, 4], + "c": ["one", "three", "six", "seven", "one", "5"], + }, + index=np.random.rand(6), + ) + + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.to_numpy(), pdf.values) + + def test_to_pandas(self): + pdf, kdf = self.df_pair + self.assert_eq(kdf.toPandas(), pdf) + self.assert_eq(kdf.to_pandas(), pdf) + + def test_isin(self): + pdf = pd.DataFrame( + { + "a": [4, 2, 3, 4, 8, 6], + "b": [1, 2, 9, 4, 2, 4], + "c": ["one", "three", "six", "seven", "one", "5"], + }, + index=np.random.rand(6), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.isin([4, "six"]), pdf.isin([4, "six"])) + # Seems like pandas has a bug when passing `np.array` as parameter + self.assert_eq(kdf.isin(np.array([4, "six"])), pdf.isin([4, "six"])) + self.assert_eq( + kdf.isin({"a": [2, 8], "c": ["three", "one"]}), + pdf.isin({"a": [2, 8], "c": ["three", "one"]}), + ) + self.assert_eq( + kdf.isin({"a": np.array([2, 8]), "c": ["three", "one"]}), + pdf.isin({"a": np.array([2, 8]), "c": ["three", "one"]}), + ) + + msg = "'DataFrame' object has no attribute {'e'}" + with self.assertRaisesRegex(AttributeError, msg): + kdf.isin({"e": [5, 7], "a": [1, 6]}) + + msg = "DataFrame and Series are not supported" + with self.assertRaisesRegex(NotImplementedError, msg): + kdf.isin(pdf) + + msg = "Values should be iterable, Series, DataFrame or dict." 
+ with self.assertRaisesRegex(TypeError, msg): + kdf.isin(1) + + def test_merge(self): + left_pdf = pd.DataFrame( + { + "lkey": ["foo", "bar", "baz", "foo", "bar", "l"], + "value": [1, 2, 3, 5, 6, 7], + "x": list("abcdef"), + }, + columns=["lkey", "value", "x"], + ) + right_pdf = pd.DataFrame( + { + "rkey": ["baz", "foo", "bar", "baz", "foo", "r"], + "value": [4, 5, 6, 7, 8, 9], + "y": list("efghij"), + }, + columns=["rkey", "value", "y"], + ) + right_ps = pd.Series(list("defghi"), name="x", index=[5, 6, 7, 8, 9, 10]) + + left_kdf = pp.from_pandas(left_pdf) + right_kdf = pp.from_pandas(right_pdf) + right_kser = pp.from_pandas(right_ps) + + def check(op, right_kdf=right_kdf, right_pdf=right_pdf): + k_res = op(left_kdf, right_kdf) + k_res = k_res.to_pandas() + k_res = k_res.sort_values(by=list(k_res.columns)) + k_res = k_res.reset_index(drop=True) + p_res = op(left_pdf, right_pdf) + p_res = p_res.sort_values(by=list(p_res.columns)) + p_res = p_res.reset_index(drop=True) + self.assert_eq(k_res, p_res) + + check(lambda left, right: left.merge(right)) + check(lambda left, right: left.merge(right, on="value")) + check(lambda left, right: left.merge(right, left_on="lkey", right_on="rkey")) + check(lambda left, right: left.set_index("lkey").merge(right.set_index("rkey"))) + check( + lambda left, right: left.set_index("lkey").merge( + right, left_index=True, right_on="rkey" + ) + ) + check( + lambda left, right: left.merge( + right.set_index("rkey"), left_on="lkey", right_index=True + ) + ) + check( + lambda left, right: left.set_index("lkey").merge( + right.set_index("rkey"), left_index=True, right_index=True + ) + ) + + # MultiIndex + check( + lambda left, right: left.merge( + right, left_on=["lkey", "value"], right_on=["rkey", "value"] + ) + ) + check( + lambda left, right: left.set_index(["lkey", "value"]).merge( + right, left_index=True, right_on=["rkey", "value"] + ) + ) + check( + lambda left, right: left.merge( + right.set_index(["rkey", "value"]), left_on=["lkey", "value"], right_index=True + ) + ) + # TODO: when both left_index=True and right_index=True with multi-index + # check(lambda left, right: left.set_index(['lkey', 'value']).merge( + # right.set_index(['rkey', 'value']), left_index=True, right_index=True)) + + # join types + for how in ["inner", "left", "right", "outer"]: + check(lambda left, right: left.merge(right, on="value", how=how)) + check(lambda left, right: left.merge(right, left_on="lkey", right_on="rkey", how=how)) + + # suffix + check( + lambda left, right: left.merge( + right, left_on="lkey", right_on="rkey", suffixes=["_left", "_right"] + ) + ) + + # Test Series on the right + # pd.DataFrame.merge with Series is implemented since version 0.24.0 + if LooseVersion(pd.__version__) >= LooseVersion("0.24.0"): + check(lambda left, right: left.merge(right), right_kser, right_ps) + check( + lambda left, right: left.merge(right, left_on="x", right_on="x"), + right_kser, + right_ps, + ) + check( + lambda left, right: left.set_index("x").merge(right, left_index=True, right_on="x"), + right_kser, + right_ps, + ) + + # Test join types with Series + for how in ["inner", "left", "right", "outer"]: + check(lambda left, right: left.merge(right, how=how), right_kser, right_ps) + check( + lambda left, right: left.merge(right, left_on="x", right_on="x", how=how), + right_kser, + right_ps, + ) + + # suffix with Series + check( + lambda left, right: left.merge( + right, + suffixes=["_left", "_right"], + how="outer", + left_index=True, + right_index=True, + ), + right_kser, + right_ps, + ) 
+ + # multi-index columns + left_columns = pd.MultiIndex.from_tuples([(10, "lkey"), (10, "value"), (20, "x")]) + left_pdf.columns = left_columns + left_kdf.columns = left_columns + + right_columns = pd.MultiIndex.from_tuples([(10, "rkey"), (10, "value"), (30, "y")]) + right_pdf.columns = right_columns + right_kdf.columns = right_columns + + check(lambda left, right: left.merge(right)) + check(lambda left, right: left.merge(right, on=[(10, "value")])) + check( + lambda left, right: (left.set_index((10, "lkey")).merge(right.set_index((10, "rkey")))) + ) + check( + lambda left, right: ( + left.set_index((10, "lkey")).merge( + right.set_index((10, "rkey")), left_index=True, right_index=True + ) + ) + ) + # TODO: when both left_index=True and right_index=True with multi-index columns + # check(lambda left, right: left.merge(right, + # left_on=[('a', 'lkey')], right_on=[('a', 'rkey')])) + # check(lambda left, right: (left.set_index(('a', 'lkey')) + # .merge(right, left_index=True, right_on=[('a', 'rkey')]))) + + # non-string names + left_pdf.columns = [10, 100, 1000] + left_kdf.columns = [10, 100, 1000] + + right_pdf.columns = [20, 100, 2000] + right_kdf.columns = [20, 100, 2000] + + check(lambda left, right: left.merge(right)) + check(lambda left, right: left.merge(right, on=[100])) + check(lambda left, right: (left.set_index(10).merge(right.set_index(20)))) + check( + lambda left, right: ( + left.set_index(10).merge(right.set_index(20), left_index=True, right_index=True) + ) + ) + + def test_merge_same_anchor(self): + pdf = pd.DataFrame( + { + "lkey": ["foo", "bar", "baz", "foo", "bar", "l"], + "rkey": ["baz", "foo", "bar", "baz", "foo", "r"], + "value": [1, 1, 3, 5, 6, 7], + "x": list("abcdef"), + "y": list("efghij"), + }, + columns=["lkey", "rkey", "value", "x", "y"], + ) + kdf = pp.from_pandas(pdf) + + left_pdf = pdf[["lkey", "value", "x"]] + right_pdf = pdf[["rkey", "value", "y"]] + left_kdf = kdf[["lkey", "value", "x"]] + right_kdf = kdf[["rkey", "value", "y"]] + + def check(op, right_kdf=right_kdf, right_pdf=right_pdf): + k_res = op(left_kdf, right_kdf) + k_res = k_res.to_pandas() + k_res = k_res.sort_values(by=list(k_res.columns)) + k_res = k_res.reset_index(drop=True) + p_res = op(left_pdf, right_pdf) + p_res = p_res.sort_values(by=list(p_res.columns)) + p_res = p_res.reset_index(drop=True) + self.assert_eq(k_res, p_res) + + check(lambda left, right: left.merge(right)) + check(lambda left, right: left.merge(right, on="value")) + check(lambda left, right: left.merge(right, left_on="lkey", right_on="rkey")) + check(lambda left, right: left.set_index("lkey").merge(right.set_index("rkey"))) + check( + lambda left, right: left.set_index("lkey").merge( + right, left_index=True, right_on="rkey" + ) + ) + check( + lambda left, right: left.merge( + right.set_index("rkey"), left_on="lkey", right_index=True + ) + ) + check( + lambda left, right: left.set_index("lkey").merge( + right.set_index("rkey"), left_index=True, right_index=True + ) + ) + + def test_merge_retains_indices(self): + left_pdf = pd.DataFrame({"A": [0, 1]}) + right_pdf = pd.DataFrame({"B": [1, 2]}, index=[1, 2]) + left_kdf = pp.from_pandas(left_pdf) + right_kdf = pp.from_pandas(right_pdf) + + self.assert_eq( + left_kdf.merge(right_kdf, left_index=True, right_index=True), + left_pdf.merge(right_pdf, left_index=True, right_index=True), + ) + self.assert_eq( + left_kdf.merge(right_kdf, left_on="A", right_index=True), + left_pdf.merge(right_pdf, left_on="A", right_index=True), + ) + self.assert_eq( + left_kdf.merge(right_kdf, 
left_index=True, right_on="B"), + left_pdf.merge(right_pdf, left_index=True, right_on="B"), + ) + self.assert_eq( + left_kdf.merge(right_kdf, left_on="A", right_on="B"), + left_pdf.merge(right_pdf, left_on="A", right_on="B"), + ) + + def test_merge_how_parameter(self): + left_pdf = pd.DataFrame({"A": [1, 2]}) + right_pdf = pd.DataFrame({"B": ["x", "y"]}, index=[1, 2]) + left_kdf = pp.from_pandas(left_pdf) + right_kdf = pp.from_pandas(right_pdf) + + kdf = left_kdf.merge(right_kdf, left_index=True, right_index=True) + pdf = left_pdf.merge(right_pdf, left_index=True, right_index=True) + self.assert_eq( + kdf.sort_values(by=list(kdf.columns)).reset_index(drop=True), + pdf.sort_values(by=list(pdf.columns)).reset_index(drop=True), + ) + + kdf = left_kdf.merge(right_kdf, left_index=True, right_index=True, how="left") + pdf = left_pdf.merge(right_pdf, left_index=True, right_index=True, how="left") + self.assert_eq( + kdf.sort_values(by=list(kdf.columns)).reset_index(drop=True), + pdf.sort_values(by=list(pdf.columns)).reset_index(drop=True), + ) + + kdf = left_kdf.merge(right_kdf, left_index=True, right_index=True, how="right") + pdf = left_pdf.merge(right_pdf, left_index=True, right_index=True, how="right") + self.assert_eq( + kdf.sort_values(by=list(kdf.columns)).reset_index(drop=True), + pdf.sort_values(by=list(pdf.columns)).reset_index(drop=True), + ) + + kdf = left_kdf.merge(right_kdf, left_index=True, right_index=True, how="outer") + pdf = left_pdf.merge(right_pdf, left_index=True, right_index=True, how="outer") + self.assert_eq( + kdf.sort_values(by=list(kdf.columns)).reset_index(drop=True), + pdf.sort_values(by=list(pdf.columns)).reset_index(drop=True), + ) + + def test_merge_raises(self): + left = pp.DataFrame( + {"value": [1, 2, 3, 5, 6], "x": list("abcde")}, + columns=["value", "x"], + index=["foo", "bar", "baz", "foo", "bar"], + ) + right = pp.DataFrame( + {"value": [4, 5, 6, 7, 8], "y": list("fghij")}, + columns=["value", "y"], + index=["baz", "foo", "bar", "baz", "foo"], + ) + + with self.assertRaisesRegex(ValueError, "No common columns to perform merge on"): + left[["x"]].merge(right[["y"]]) + + with self.assertRaisesRegex(ValueError, "not a combination of both"): + left.merge(right, on="value", left_on="x") + + with self.assertRaisesRegex(ValueError, "Must pass right_on or right_index=True"): + left.merge(right, left_on="x") + + with self.assertRaisesRegex(ValueError, "Must pass right_on or right_index=True"): + left.merge(right, left_index=True) + + with self.assertRaisesRegex(ValueError, "Must pass left_on or left_index=True"): + left.merge(right, right_on="y") + + with self.assertRaisesRegex(ValueError, "Must pass left_on or left_index=True"): + left.merge(right, right_index=True) + + with self.assertRaisesRegex( + ValueError, "len\\(left_keys\\) must equal len\\(right_keys\\)" + ): + left.merge(right, left_on="value", right_on=["value", "y"]) + + with self.assertRaisesRegex( + ValueError, "len\\(left_keys\\) must equal len\\(right_keys\\)" + ): + left.merge(right, left_on=["value", "x"], right_on="value") + + with self.assertRaisesRegex(ValueError, "['inner', 'left', 'right', 'full', 'outer']"): + left.merge(right, left_index=True, right_index=True, how="foo") + + with self.assertRaisesRegex(KeyError, "id"): + left.merge(right, on="id") + + def test_append(self): + pdf = pd.DataFrame([[1, 2], [3, 4]], columns=list("AB")) + kdf = pp.from_pandas(pdf) + other_pdf = pd.DataFrame([[3, 4], [5, 6]], columns=list("BC"), index=[2, 3]) + other_kdf = pp.from_pandas(other_pdf) + + 
self.assert_eq(kdf.append(kdf), pdf.append(pdf)) + self.assert_eq(kdf.append(kdf, ignore_index=True), pdf.append(pdf, ignore_index=True)) + + # Assert DataFrames with non-matching columns + self.assert_eq(kdf.append(other_kdf), pdf.append(other_pdf)) + + # Assert appending a Series fails + msg = "DataFrames.append() does not support appending Series to DataFrames" + with self.assertRaises(ValueError, msg=msg): + kdf.append(kdf["A"]) + + # Assert using the sort parameter raises an exception + msg = "The 'sort' parameter is currently not supported" + with self.assertRaises(NotImplementedError, msg=msg): + kdf.append(kdf, sort=True) + + # Assert using 'verify_integrity' only raises an exception for overlapping indices + self.assert_eq( + kdf.append(other_kdf, verify_integrity=True), + pdf.append(other_pdf, verify_integrity=True), + ) + msg = "Indices have overlapping values" + with self.assertRaises(ValueError, msg=msg): + kdf.append(kdf, verify_integrity=True) + + # Skip integrity verification when ignore_index=True + self.assert_eq( + kdf.append(kdf, ignore_index=True, verify_integrity=True), + pdf.append(pdf, ignore_index=True, verify_integrity=True), + ) + + # Assert appending multi-index DataFrames + multi_index_pdf = pd.DataFrame([[1, 2], [3, 4]], columns=list("AB"), index=[[2, 3], [4, 5]]) + multi_index_kdf = pp.from_pandas(multi_index_pdf) + other_multi_index_pdf = pd.DataFrame( + [[5, 6], [7, 8]], columns=list("AB"), index=[[2, 3], [6, 7]] + ) + other_multi_index_kdf = pp.from_pandas(other_multi_index_pdf) + + self.assert_eq( + multi_index_kdf.append(multi_index_kdf), multi_index_pdf.append(multi_index_pdf) + ) + + # Assert DataFrames with non-matching columns + self.assert_eq( + multi_index_kdf.append(other_multi_index_kdf), + multi_index_pdf.append(other_multi_index_pdf), + ) + + # Assert using 'verify_integrity' only raises an exception for overlapping indices + self.assert_eq( + multi_index_kdf.append(other_multi_index_kdf, verify_integrity=True), + multi_index_pdf.append(other_multi_index_pdf, verify_integrity=True), + ) + with self.assertRaises(ValueError, msg=msg): + multi_index_kdf.append(multi_index_kdf, verify_integrity=True) + + # Skip integrity verification when ignore_index=True + self.assert_eq( + multi_index_kdf.append(multi_index_kdf, ignore_index=True, verify_integrity=True), + multi_index_pdf.append(multi_index_pdf, ignore_index=True, verify_integrity=True), + ) + + # Assert trying to append DataFrames with different index levels + msg = "Both DataFrames have to have the same number of index levels" + with self.assertRaises(ValueError, msg=msg): + kdf.append(multi_index_kdf) + + # Skip index level check when ignore_index=True + self.assert_eq( + kdf.append(multi_index_kdf, ignore_index=True), + pdf.append(multi_index_pdf, ignore_index=True), + ) + + columns = pd.MultiIndex.from_tuples([("A", "X"), ("A", "Y")]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq(kdf.append(kdf), pdf.append(pdf)) + + def test_clip(self): + pdf = pd.DataFrame( + {"A": [0, 2, 4], "B": [4, 2, 0], "X": [-1, 10, 0]}, index=np.random.rand(3) + ) + kdf = pp.from_pandas(pdf) + + # Assert list-like values are not accepted for 'lower' and 'upper' + msg = "List-like value are not supported for 'lower' and 'upper' at the moment" + with self.assertRaises(ValueError, msg=msg): + kdf.clip(lower=[1]) + with self.assertRaises(ValueError, msg=msg): + kdf.clip(upper=[1]) + + # Assert no lower or upper + self.assert_eq(kdf.clip(), pdf.clip()) + # Assert lower only + 
self.assert_eq(kdf.clip(1), pdf.clip(1)) + # Assert upper only + self.assert_eq(kdf.clip(upper=3), pdf.clip(upper=3)) + # Assert lower and upper + self.assert_eq(kdf.clip(1, 3), pdf.clip(1, 3)) + + pdf["clip"] = pdf.A.clip(lower=1, upper=3) + kdf["clip"] = kdf.A.clip(lower=1, upper=3) + self.assert_eq(kdf, pdf) + + # Assert behavior on string values + str_kdf = pp.DataFrame({"A": ["a", "b", "c"]}, index=np.random.rand(3)) + self.assert_eq(str_kdf.clip(1, 3), str_kdf) + + def test_binary_operators(self): + pdf = pd.DataFrame( + {"A": [0, 2, 4], "B": [4, 2, 0], "X": [-1, 10, 0]}, index=np.random.rand(3) + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf + kdf.copy(), pdf + pdf.copy()) + + self.assertRaisesRegex( + ValueError, + "it comes from a different dataframe", + lambda: pp.range(10).add(pp.range(10)), + ) + + self.assertRaisesRegex( + ValueError, + "add with a sequence is currently not supported", + lambda: pp.range(10).add(pp.range(10).id), + ) + + def test_binary_operator_add(self): + # Positive + pdf = pd.DataFrame({"a": ["x"], "b": ["y"], "c": [1], "d": [2]}) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf["a"] + kdf["b"], pdf["a"] + pdf["b"]) + self.assert_eq(kdf["c"] + kdf["d"], pdf["c"] + pdf["d"]) + + # Negative + ks_err_msg = "string addition can only be applied to string series or literals" + + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["a"] + kdf["c"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["c"] + kdf["a"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["c"] + "literal") + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: "literal" + kdf["c"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: 1 + kdf["a"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["a"] + 1) + + def test_binary_operator_sub(self): + # Positive + pdf = pd.DataFrame({"a": [2], "b": [1]}) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf["a"] - kdf["b"], pdf["a"] - pdf["b"]) + + # Negative + kdf = pp.DataFrame({"a": ["x"], "b": [1]}) + ks_err_msg = "substraction can not be applied to string series or literals" + + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["a"] - kdf["b"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["b"] - kdf["a"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["b"] - "literal") + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: "literal" - kdf["b"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: 1 - kdf["a"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["a"] - 1) + + kdf = pp.DataFrame({"a": ["x"], "b": ["y"]}) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["a"] - kdf["b"]) + + def test_binary_operator_truediv(self): + # Positive + pdf = pd.DataFrame({"a": [3], "b": [2]}) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf["a"] / kdf["b"], pdf["a"] / pdf["b"]) + + # Negative + kdf = pp.DataFrame({"a": ["x"], "b": [1]}) + ks_err_msg = "division can not be applied on string series or literals" + + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["a"] / kdf["b"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["b"] / kdf["a"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["b"] / "literal") + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: "literal" / kdf["b"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: 1 / kdf["a"]) + + def test_binary_operator_floordiv(self): + kdf = pp.DataFrame({"a": ["x"], "b": [1]}) + ks_err_msg = "division can not be applied on string series or literals" 
+ + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["a"] // kdf["b"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["b"] // kdf["a"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["b"] // "literal") + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: "literal" // kdf["b"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: 1 // kdf["a"]) + + def test_binary_operator_mod(self): + # Positive + pdf = pd.DataFrame({"a": [3], "b": [2]}) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf["a"] % kdf["b"], pdf["a"] % pdf["b"]) + + # Negative + kdf = pp.DataFrame({"a": ["x"], "b": [1]}) + ks_err_msg = "modulo can not be applied on string series or literals" + + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["a"] % kdf["b"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["b"] % kdf["a"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["b"] % "literal") + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: 1 % kdf["a"]) + + def test_binary_operator_multiply(self): + # Positive + pdf = pd.DataFrame({"a": ["x", "y"], "b": [1, 2], "c": [3, 4]}) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf["b"] * kdf["c"], pdf["b"] * pdf["c"]) + self.assert_eq(kdf["c"] * kdf["b"], pdf["c"] * pdf["b"]) + self.assert_eq(kdf["a"] * kdf["b"], pdf["a"] * pdf["b"]) + self.assert_eq(kdf["b"] * kdf["a"], pdf["b"] * pdf["a"]) + self.assert_eq(kdf["a"] * 2, pdf["a"] * 2) + self.assert_eq(kdf["b"] * 2, pdf["b"] * 2) + self.assert_eq(2 * kdf["a"], 2 * pdf["a"]) + self.assert_eq(2 * kdf["b"], 2 * pdf["b"]) + + # Negative + kdf = pp.DataFrame({"a": ["x"], "b": [2]}) + ks_err_msg = "multiplication can not be applied to a string literal" + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["b"] * "literal") + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: "literal" * kdf["b"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["a"] * "literal") + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: "literal" * kdf["a"]) + + ks_err_msg = "a string series can only be multiplied to an int series or literal" + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["a"] * kdf["a"]) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: kdf["a"] * 0.1) + self.assertRaisesRegex(TypeError, ks_err_msg, lambda: 0.1 * kdf["a"]) + + def test_sample(self): + pdf = pd.DataFrame({"A": [0, 2, 4]}) + kdf = pp.from_pandas(pdf) + + # Make sure the tests run, but we can't check the result because they are non-deterministic. 
+ kdf.sample(frac=0.1) + kdf.sample(frac=0.2, replace=True) + kdf.sample(frac=0.2, random_state=5) + kdf["A"].sample(frac=0.2) + kdf["A"].sample(frac=0.2, replace=True) + kdf["A"].sample(frac=0.2, random_state=5) + + with self.assertRaises(ValueError): + kdf.sample() + with self.assertRaises(NotImplementedError): + kdf.sample(n=1) + + def test_add_prefix(self): + pdf = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]}, index=np.random.rand(4)) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.add_prefix("col_"), kdf.add_prefix("col_")) + + columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B")]) + pdf.columns = columns + kdf.columns = columns + self.assert_eq(pdf.add_prefix("col_"), kdf.add_prefix("col_")) + + def test_add_suffix(self): + pdf = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]}, index=np.random.rand(4)) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.add_suffix("first_series"), kdf.add_suffix("first_series")) + + columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B")]) + pdf.columns = columns + kdf.columns = columns + self.assert_eq(pdf.add_suffix("first_series"), kdf.add_suffix("first_series")) + + def test_join(self): + # check basic function + pdf1 = pd.DataFrame( + {"key": ["K0", "K1", "K2", "K3"], "A": ["A0", "A1", "A2", "A3"]}, columns=["key", "A"] + ) + pdf2 = pd.DataFrame( + {"key": ["K0", "K1", "K2"], "B": ["B0", "B1", "B2"]}, columns=["key", "B"] + ) + kdf1 = pp.from_pandas(pdf1) + kdf2 = pp.from_pandas(pdf2) + + join_pdf = pdf1.join(pdf2, lsuffix="_left", rsuffix="_right") + join_pdf.sort_values(by=list(join_pdf.columns), inplace=True) + + join_kdf = kdf1.join(kdf2, lsuffix="_left", rsuffix="_right") + join_kdf.sort_values(by=list(join_kdf.columns), inplace=True) + + self.assert_eq(join_pdf, join_kdf) + + # join with duplicated columns in Series + with self.assertRaisesRegex(ValueError, "columns overlap but no suffix specified"): + ks1 = pp.Series(["A1", "A5"], index=[1, 2], name="A") + kdf1.join(ks1, how="outer") + # join with duplicated columns in DataFrame + with self.assertRaisesRegex(ValueError, "columns overlap but no suffix specified"): + kdf1.join(kdf2, how="outer") + + # check `on` parameter + join_pdf = pdf1.join(pdf2.set_index("key"), on="key", lsuffix="_left", rsuffix="_right") + join_pdf.sort_values(by=list(join_pdf.columns), inplace=True) + + join_kdf = kdf1.join(kdf2.set_index("key"), on="key", lsuffix="_left", rsuffix="_right") + join_kdf.sort_values(by=list(join_kdf.columns), inplace=True) + self.assert_eq(join_pdf.reset_index(drop=True), join_kdf.reset_index(drop=True)) + + join_pdf = pdf1.set_index("key").join( + pdf2.set_index("key"), on="key", lsuffix="_left", rsuffix="_right" + ) + join_pdf.sort_values(by=list(join_pdf.columns), inplace=True) + + join_kdf = kdf1.set_index("key").join( + kdf2.set_index("key"), on="key", lsuffix="_left", rsuffix="_right" + ) + join_kdf.sort_values(by=list(join_kdf.columns), inplace=True) + self.assert_eq(join_pdf.reset_index(drop=True), join_kdf.reset_index(drop=True)) + + # multi-index columns + columns1 = pd.MultiIndex.from_tuples([("x", "key"), ("Y", "A")]) + columns2 = pd.MultiIndex.from_tuples([("x", "key"), ("Y", "B")]) + pdf1.columns = columns1 + pdf2.columns = columns2 + kdf1.columns = columns1 + kdf2.columns = columns2 + + join_pdf = pdf1.join(pdf2, lsuffix="_left", rsuffix="_right") + join_pdf.sort_values(by=list(join_pdf.columns), inplace=True) + + join_kdf = kdf1.join(kdf2, lsuffix="_left", rsuffix="_right") + join_kdf.sort_values(by=list(join_kdf.columns), inplace=True) + + 
self.assert_eq(join_pdf, join_kdf) + + # check `on` parameter + join_pdf = pdf1.join( + pdf2.set_index(("x", "key")), on=[("x", "key")], lsuffix="_left", rsuffix="_right" + ) + join_pdf.sort_values(by=list(join_pdf.columns), inplace=True) + + join_kdf = kdf1.join( + kdf2.set_index(("x", "key")), on=[("x", "key")], lsuffix="_left", rsuffix="_right" + ) + join_kdf.sort_values(by=list(join_kdf.columns), inplace=True) + + self.assert_eq(join_pdf.reset_index(drop=True), join_kdf.reset_index(drop=True)) + + join_pdf = pdf1.set_index(("x", "key")).join( + pdf2.set_index(("x", "key")), on=[("x", "key")], lsuffix="_left", rsuffix="_right" + ) + join_pdf.sort_values(by=list(join_pdf.columns), inplace=True) + + join_kdf = kdf1.set_index(("x", "key")).join( + kdf2.set_index(("x", "key")), on=[("x", "key")], lsuffix="_left", rsuffix="_right" + ) + join_kdf.sort_values(by=list(join_kdf.columns), inplace=True) + + self.assert_eq(join_pdf.reset_index(drop=True), join_kdf.reset_index(drop=True)) + + # multi-index + midx1 = pd.MultiIndex.from_tuples( + [("w", "a"), ("x", "b"), ("y", "c"), ("z", "d")], names=["index1", "index2"] + ) + midx2 = pd.MultiIndex.from_tuples( + [("w", "a"), ("x", "b"), ("y", "c")], names=["index1", "index2"] + ) + pdf1.index = midx1 + pdf2.index = midx2 + kdf1 = pp.from_pandas(pdf1) + kdf2 = pp.from_pandas(pdf2) + + join_pdf = pdf1.join(pdf2, on=["index1", "index2"], rsuffix="_right") + join_pdf.sort_values(by=list(join_pdf.columns), inplace=True) + + join_kdf = kdf1.join(kdf2, on=["index1", "index2"], rsuffix="_right") + join_kdf.sort_values(by=list(join_kdf.columns), inplace=True) + + self.assert_eq(join_pdf, join_kdf) + + with self.assertRaisesRegex( + ValueError, r'len\(left_on\) must equal the number of levels in the index of "right"' + ): + kdf1.join(kdf2, on=["index1"], rsuffix="_right") + + def test_replace(self): + pdf = pd.DataFrame( + { + "name": ["Ironman", "Captain America", "Thor", "Hulk"], + "weapon": ["Mark-45", "Shield", "Mjolnir", "Smash"], + }, + index=np.random.rand(4), + ) + kdf = pp.from_pandas(pdf) + + with self.assertRaisesRegex( + NotImplementedError, "replace currently works only for method='pad" + ): + kdf.replace(method="bfill") + with self.assertRaisesRegex( + NotImplementedError, "replace currently works only when limit=None" + ): + kdf.replace(limit=10) + with self.assertRaisesRegex( + NotImplementedError, "replace currently doesn't supports regex" + ): + kdf.replace(regex="") + + with self.assertRaisesRegex(ValueError, "Length of to_replace and value must be same"): + kdf.replace(to_replace=["Ironman"], value=["Spiderman", "Doctor Strange"]) + + self.assert_eq(kdf.replace("Ironman", "Spiderman"), pdf.replace("Ironman", "Spiderman")) + self.assert_eq( + kdf.replace(["Ironman", "Captain America"], ["Rescue", "Hawkeye"]), + pdf.replace(["Ironman", "Captain America"], ["Rescue", "Hawkeye"]), + ) + self.assert_eq( + kdf.replace(("Ironman", "Captain America"), ("Rescue", "Hawkeye")), + pdf.replace(("Ironman", "Captain America"), ("Rescue", "Hawkeye")), + ) + + # inplace + pser = pdf.name + kser = kdf.name + pdf.replace("Ironman", "Spiderman", inplace=True) + kdf.replace("Ironman", "Spiderman", inplace=True) + self.assert_eq(kdf, pdf) + self.assert_eq(kser, pser) + + pdf = pd.DataFrame( + {"A": [0, 1, 2, 3, np.nan], "B": [5, 6, 7, 8, np.nan], "C": ["a", "b", "c", "d", None]}, + index=np.random.rand(5), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.replace([0, 1, 2, 3, 5, 6], 4), pdf.replace([0, 1, 2, 3, 5, 6], 4)) + + self.assert_eq( + 
kdf.replace([0, 1, 2, 3, 5, 6], [6, 5, 4, 3, 2, 1]), + pdf.replace([0, 1, 2, 3, 5, 6], [6, 5, 4, 3, 2, 1]), + ) + + self.assert_eq(kdf.replace({0: 10, 1: 100, 7: 200}), pdf.replace({0: 10, 1: 100, 7: 200})) + + self.assert_eq( + kdf.replace({"A": [0, np.nan], "B": [5, np.nan]}, 100), + pdf.replace({"A": [0, np.nan], "B": [5, np.nan]}, 100), + ) + + self.assert_eq( + kdf.replace({"A": {0: 100, 4: 400, np.nan: 700}}), + pdf.replace({"A": {0: 100, 4: 400, np.nan: 700}}), + ) + self.assert_eq( + kdf.replace({"X": {0: 100, 4: 400, np.nan: 700}}), + pdf.replace({"X": {0: 100, 4: 400, np.nan: 700}}), + ) + + self.assert_eq(kdf.replace({"C": ["a", None]}, "e"), pdf.replace({"C": ["a", None]}, "e")) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C")]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq(kdf.replace([0, 1, 2, 3, 5, 6], 4), pdf.replace([0, 1, 2, 3, 5, 6], 4)) + + self.assert_eq( + kdf.replace([0, 1, 2, 3, 5, 6], [6, 5, 4, 3, 2, 1]), + pdf.replace([0, 1, 2, 3, 5, 6], [6, 5, 4, 3, 2, 1]), + ) + + self.assert_eq(kdf.replace({0: 10, 1: 100, 7: 200}), pdf.replace({0: 10, 1: 100, 7: 200})) + + self.assert_eq( + kdf.replace({("X", "A"): [0, np.nan], ("X", "B"): 5}, 100), + pdf.replace({("X", "A"): [0, np.nan], ("X", "B"): 5}, 100), + ) + + self.assert_eq( + kdf.replace({("X", "A"): {0: 100, 4: 400, np.nan: 700}}), + pdf.replace({("X", "A"): {0: 100, 4: 400, np.nan: 700}}), + ) + self.assert_eq( + kdf.replace({("X", "B"): {0: 100, 4: 400, np.nan: 700}}), + pdf.replace({("X", "B"): {0: 100, 4: 400, np.nan: 700}}), + ) + + self.assert_eq( + kdf.replace({("Y", "C"): ["a", None]}, "e"), pdf.replace({("Y", "C"): ["a", None]}, "e") + ) + + def test_update(self): + # check base function + def get_data(left_columns=None, right_columns=None): + left_pdf = pd.DataFrame( + {"A": ["1", "2", "3", "4"], "B": ["100", "200", np.nan, np.nan]}, columns=["A", "B"] + ) + right_pdf = pd.DataFrame( + {"B": ["x", np.nan, "y", np.nan], "C": ["100", "200", "300", "400"]}, + columns=["B", "C"], + ) + + left_kdf = pp.DataFrame( + {"A": ["1", "2", "3", "4"], "B": ["100", "200", None, None]}, columns=["A", "B"] + ) + right_kdf = pp.DataFrame( + {"B": ["x", None, "y", None], "C": ["100", "200", "300", "400"]}, columns=["B", "C"] + ) + if left_columns is not None: + left_pdf.columns = left_columns + left_kdf.columns = left_columns + if right_columns is not None: + right_pdf.columns = right_columns + right_kdf.columns = right_columns + return left_kdf, left_pdf, right_kdf, right_pdf + + left_kdf, left_pdf, right_kdf, right_pdf = get_data() + pser = left_pdf.B + kser = left_kdf.B + left_pdf.update(right_pdf) + left_kdf.update(right_kdf) + self.assert_eq(left_pdf.sort_values(by=["A", "B"]), left_kdf.sort_values(by=["A", "B"])) + self.assert_eq(kser.sort_index(), pser.sort_index()) + + left_kdf, left_pdf, right_kdf, right_pdf = get_data() + left_pdf.update(right_pdf, overwrite=False) + left_kdf.update(right_kdf, overwrite=False) + self.assert_eq(left_pdf.sort_values(by=["A", "B"]), left_kdf.sort_values(by=["A", "B"])) + + with self.assertRaises(NotImplementedError): + left_kdf.update(right_kdf, join="right") + + # multi-index columns + left_columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B")]) + right_columns = pd.MultiIndex.from_tuples([("X", "B"), ("Y", "C")]) + + left_kdf, left_pdf, right_kdf, right_pdf = get_data( + left_columns=left_columns, right_columns=right_columns + ) + left_pdf.update(right_pdf) + left_kdf.update(right_kdf) + self.assert_eq( + 
left_pdf.sort_values(by=[("X", "A"), ("X", "B")]), + left_kdf.sort_values(by=[("X", "A"), ("X", "B")]), + ) + + left_kdf, left_pdf, right_kdf, right_pdf = get_data( + left_columns=left_columns, right_columns=right_columns + ) + left_pdf.update(right_pdf, overwrite=False) + left_kdf.update(right_kdf, overwrite=False) + self.assert_eq( + left_pdf.sort_values(by=[("X", "A"), ("X", "B")]), + left_kdf.sort_values(by=[("X", "A"), ("X", "B")]), + ) + + right_columns = pd.MultiIndex.from_tuples([("Y", "B"), ("Y", "C")]) + left_kdf, left_pdf, right_kdf, right_pdf = get_data( + left_columns=left_columns, right_columns=right_columns + ) + left_pdf.update(right_pdf) + left_kdf.update(right_kdf) + self.assert_eq( + left_pdf.sort_values(by=[("X", "A"), ("X", "B")]), + left_kdf.sort_values(by=[("X", "A"), ("X", "B")]), + ) + + def test_pivot_table_dtypes(self): + pdf = pd.DataFrame( + { + "a": [4, 2, 3, 4, 8, 6], + "b": [1, 2, 2, 4, 2, 4], + "e": [1, 2, 2, 4, 2, 4], + "c": [1, 2, 9, 4, 7, 4], + }, + index=np.random.rand(6), + ) + kdf = pp.from_pandas(pdf) + + # Skip columns comparison by reset_index + res_df = kdf.pivot_table( + index=["c"], columns="a", values=["b"], aggfunc={"b": "mean"} + ).dtypes.reset_index(drop=True) + exp_df = pdf.pivot_table( + index=["c"], columns="a", values=["b"], aggfunc={"b": "mean"} + ).dtypes.reset_index(drop=True) + self.assert_eq(res_df, exp_df) + + # Results don't have the same column's name + + # Todo: self.assert_eq(kdf.pivot_table(columns="a", values="b").dtypes, + # pdf.pivot_table(columns="a", values="b").dtypes) + + # Todo: self.assert_eq(kdf.pivot_table(index=['c'], columns="a", values="b").dtypes, + # pdf.pivot_table(index=['c'], columns="a", values="b").dtypes) + + # Todo: self.assert_eq(kdf.pivot_table(index=['e', 'c'], columns="a", values="b").dtypes, + # pdf.pivot_table(index=['e', 'c'], columns="a", values="b").dtypes) + + # Todo: self.assert_eq(kdf.pivot_table(index=['e', 'c'], + # columns="a", values="b", fill_value=999).dtypes, pdf.pivot_table(index=['e', 'c'], + # columns="a", values="b", fill_value=999).dtypes) + + def test_pivot_table(self): + pdf = pd.DataFrame( + { + "a": [4, 2, 3, 4, 8, 6], + "b": [1, 2, 2, 4, 2, 4], + "e": [10, 20, 20, 40, 20, 40], + "c": [1, 2, 9, 4, 7, 4], + "d": [-1, -2, -3, -4, -5, -6], + }, + index=np.random.rand(6), + ) + kdf = pp.from_pandas(pdf) + + # Checking if both DataFrames have the same results + self.assert_eq( + kdf.pivot_table(columns="a", values="b").sort_index(), + pdf.pivot_table(columns="a", values="b").sort_index(), + almost=True, + ) + + self.assert_eq( + kdf.pivot_table(index=["c"], columns="a", values="b").sort_index(), + pdf.pivot_table(index=["c"], columns="a", values="b").sort_index(), + almost=True, + ) + + self.assert_eq( + kdf.pivot_table(index=["c"], columns="a", values="b", aggfunc="sum").sort_index(), + pdf.pivot_table(index=["c"], columns="a", values="b", aggfunc="sum").sort_index(), + almost=True, + ) + + self.assert_eq( + kdf.pivot_table(index=["c"], columns="a", values=["b"], aggfunc="sum").sort_index(), + pdf.pivot_table(index=["c"], columns="a", values=["b"], aggfunc="sum").sort_index(), + almost=True, + ) + + self.assert_eq( + kdf.pivot_table( + index=["c"], columns="a", values=["b", "e"], aggfunc="sum" + ).sort_index(), + pdf.pivot_table( + index=["c"], columns="a", values=["b", "e"], aggfunc="sum" + ).sort_index(), + almost=True, + ) + + self.assert_eq( + kdf.pivot_table( + index=["c"], columns="a", values=["b", "e", "d"], aggfunc="sum" + ).sort_index(), + pdf.pivot_table( + index=["c"], 
columns="a", values=["b", "e", "d"], aggfunc="sum" + ).sort_index(), + almost=True, + ) + + self.assert_eq( + kdf.pivot_table( + index=["c"], columns="a", values=["b", "e"], aggfunc={"b": "mean", "e": "sum"} + ).sort_index(), + pdf.pivot_table( + index=["c"], columns="a", values=["b", "e"], aggfunc={"b": "mean", "e": "sum"} + ).sort_index(), + almost=True, + ) + + self.assert_eq( + kdf.pivot_table(index=["e", "c"], columns="a", values="b").sort_index(), + pdf.pivot_table(index=["e", "c"], columns="a", values="b").sort_index(), + almost=True, + ) + + self.assert_eq( + kdf.pivot_table(index=["e", "c"], columns="a", values="b", fill_value=999).sort_index(), + pdf.pivot_table(index=["e", "c"], columns="a", values="b", fill_value=999).sort_index(), + almost=True, + ) + + # multi-index columns + columns = pd.MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "e"), ("z", "c"), ("w", "d")] + ) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq( + kdf.pivot_table(columns=("x", "a"), values=("x", "b")).sort_index(), + pdf.pivot_table(columns=[("x", "a")], values=[("x", "b")]).sort_index(), + almost=True, + ) + + self.assert_eq( + kdf.pivot_table( + index=[("z", "c")], columns=("x", "a"), values=[("x", "b")] + ).sort_index(), + pdf.pivot_table( + index=[("z", "c")], columns=[("x", "a")], values=[("x", "b")] + ).sort_index(), + almost=True, + ) + + self.assert_eq( + kdf.pivot_table( + index=[("z", "c")], columns=("x", "a"), values=[("x", "b"), ("y", "e")] + ).sort_index(), + pdf.pivot_table( + index=[("z", "c")], columns=[("x", "a")], values=[("x", "b"), ("y", "e")] + ).sort_index(), + almost=True, + ) + + self.assert_eq( + kdf.pivot_table( + index=[("z", "c")], columns=("x", "a"), values=[("x", "b"), ("y", "e"), ("w", "d")] + ).sort_index(), + pdf.pivot_table( + index=[("z", "c")], + columns=[("x", "a")], + values=[("x", "b"), ("y", "e"), ("w", "d")], + ).sort_index(), + almost=True, + ) + + self.assert_eq( + kdf.pivot_table( + index=[("z", "c")], + columns=("x", "a"), + values=[("x", "b"), ("y", "e")], + aggfunc={("x", "b"): "mean", ("y", "e"): "sum"}, + ).sort_index(), + pdf.pivot_table( + index=[("z", "c")], + columns=[("x", "a")], + values=[("x", "b"), ("y", "e")], + aggfunc={("x", "b"): "mean", ("y", "e"): "sum"}, + ).sort_index(), + almost=True, + ) + + def test_pivot_table_and_index(self): + # https://github.com/databricks/koalas/issues/805 + pdf = pd.DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "C": [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + "E": [2, 4, 5, 5, 6, 6, 8, 9, 9], + }, + columns=["A", "B", "C", "D", "E"], + index=np.random.rand(9), + ) + kdf = pp.from_pandas(pdf) + + ptable = pdf.pivot_table( + values="D", index=["A", "B"], columns="C", aggfunc="sum", fill_value=0 + ).sort_index() + ktable = kdf.pivot_table( + values="D", index=["A", "B"], columns="C", aggfunc="sum", fill_value=0 + ).sort_index() + + self.assert_eq(ktable, ptable) + self.assert_eq(ktable.index, ptable.index) + self.assert_eq(repr(ktable.index), repr(ptable.index)) + + @unittest.skipIf( + LooseVersion(pyspark.__version__) < LooseVersion("2.4"), + "stack won't work properly with PySpark<2.4", + ) + def test_stack(self): + pdf_single_level_cols = pd.DataFrame( + [[0, 1], [2, 3]], index=["cat", "dog"], columns=["weight", "height"] + ) + kdf_single_level_cols = 
pp.from_pandas(pdf_single_level_cols) + + self.assert_eq( + kdf_single_level_cols.stack().sort_index(), pdf_single_level_cols.stack().sort_index() + ) + + multicol1 = pd.MultiIndex.from_tuples( + [("weight", "kg"), ("weight", "pounds")], names=["x", "y"] + ) + pdf_multi_level_cols1 = pd.DataFrame( + [[1, 2], [2, 4]], index=["cat", "dog"], columns=multicol1 + ) + kdf_multi_level_cols1 = pp.from_pandas(pdf_multi_level_cols1) + + self.assert_eq( + kdf_multi_level_cols1.stack().sort_index(), pdf_multi_level_cols1.stack().sort_index() + ) + + multicol2 = pd.MultiIndex.from_tuples([("weight", "kg"), ("height", "m")]) + pdf_multi_level_cols2 = pd.DataFrame( + [[1.0, 2.0], [3.0, 4.0]], index=["cat", "dog"], columns=multicol2 + ) + kdf_multi_level_cols2 = pp.from_pandas(pdf_multi_level_cols2) + + self.assert_eq( + kdf_multi_level_cols2.stack().sort_index(), pdf_multi_level_cols2.stack().sort_index() + ) + + pdf = pd.DataFrame( + { + ("y", "c"): [True, True], + ("x", "b"): [False, False], + ("x", "c"): [True, False], + ("y", "a"): [False, True], + } + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.stack().sort_index(), pdf.stack().sort_index()) + self.assert_eq(kdf[[]].stack().sort_index(), pdf[[]].stack().sort_index(), almost=True) + + def test_unstack(self): + pdf = pd.DataFrame( + np.random.randn(3, 3), + index=pd.MultiIndex.from_tuples([("rg1", "x"), ("rg1", "y"), ("rg2", "z")]), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.unstack().sort_index(), pdf.unstack().sort_index(), almost=True) + + def test_pivot_errors(self): + kdf = pp.range(10) + + with self.assertRaisesRegex(ValueError, "columns should be set"): + kdf.pivot(index="id") + + with self.assertRaisesRegex(ValueError, "values should be set"): + kdf.pivot(index="id", columns="id") + + def test_pivot_table_errors(self): + pdf = pd.DataFrame( + { + "a": [4, 2, 3, 4, 8, 6], + "b": [1, 2, 2, 4, 2, 4], + "e": [1, 2, 2, 4, 2, 4], + "c": [1, 2, 9, 4, 7, 4], + }, + index=np.random.rand(6), + ) + kdf = pp.from_pandas(pdf) + + self.assertRaises(KeyError, lambda: kdf.pivot_table(index=["c"], columns="a", values=5)) + + msg = "index should be a None or a list of columns." + with self.assertRaisesRegex(ValueError, msg): + kdf.pivot_table(index="c", columns="a", values="b") + + msg = "pivot_table doesn't support aggfunc as dict and without index." + with self.assertRaisesRegex(NotImplementedError, msg): + kdf.pivot_table(columns="a", values=["b", "e"], aggfunc={"b": "mean", "e": "sum"}) + + msg = "columns should be one column name." + with self.assertRaisesRegex(ValueError, msg): + kdf.pivot_table(columns=["a"], values=["b"], aggfunc={"b": "mean", "e": "sum"}) + + msg = "Columns in aggfunc must be the same as values." + with self.assertRaisesRegex(ValueError, msg): + kdf.pivot_table( + index=["e", "c"], columns="a", values="b", aggfunc={"b": "mean", "e": "sum"} + ) + + msg = "values can't be a list without index." + with self.assertRaisesRegex(NotImplementedError, msg): + kdf.pivot_table(columns="a", values=["b", "e"]) + + msg = "Wrong columns A." 
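+        # "A" is not among this frame's (lower-case) columns, so pivot_table's
+        # column validation is expected to fail as checked below.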
+ with self.assertRaisesRegex(ValueError, msg): + kdf.pivot_table( + index=["c"], columns="A", values=["b", "e"], aggfunc={"b": "mean", "e": "sum"} + ) + + kdf = pp.DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "C": [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + "E": [2, 4, 5, 5, 6, 6, 8, 9, 9], + }, + columns=["A", "B", "C", "D", "E"], + index=np.random.rand(9), + ) + + msg = "values should be a numeric type." + with self.assertRaisesRegex(TypeError, msg): + kdf.pivot_table( + index=["C"], columns="A", values=["B", "E"], aggfunc={"B": "mean", "E": "sum"} + ) + + msg = "values should be a numeric type." + with self.assertRaisesRegex(TypeError, msg): + kdf.pivot_table(index=["C"], columns="A", values="B", aggfunc={"B": "mean"}) + + def test_transpose(self): + # TODO: what if with random index? + pdf1 = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}, columns=["col1", "col2"]) + kdf1 = pp.from_pandas(pdf1) + + pdf2 = pd.DataFrame( + data={"score": [9, 8], "kids": [0, 0], "age": [12, 22]}, + columns=["score", "kids", "age"], + ) + kdf2 = pp.from_pandas(pdf2) + + self.assert_eq(pdf1.transpose().sort_index(), kdf1.transpose().sort_index()) + self.assert_eq(pdf2.transpose().sort_index(), kdf2.transpose().sort_index()) + + with option_context("compute.max_rows", None): + self.assert_eq(pdf1.transpose().sort_index(), kdf1.transpose().sort_index()) + + self.assert_eq(pdf2.transpose().sort_index(), kdf2.transpose().sort_index()) + + pdf3 = pd.DataFrame( + { + ("cg1", "a"): [1, 2, 3], + ("cg1", "b"): [4, 5, 6], + ("cg2", "c"): [7, 8, 9], + ("cg3", "d"): [9, 9, 9], + }, + index=pd.MultiIndex.from_tuples([("rg1", "x"), ("rg1", "y"), ("rg2", "z")]), + ) + kdf3 = pp.from_pandas(pdf3) + + self.assert_eq(pdf3.transpose().sort_index(), kdf3.transpose().sort_index()) + + with option_context("compute.max_rows", None): + self.assert_eq(pdf3.transpose().sort_index(), kdf3.transpose().sort_index()) + + def _test_cummin(self, pdf, kdf): + self.assert_eq(pdf.cummin(), kdf.cummin()) + self.assert_eq(pdf.cummin(skipna=False), kdf.cummin(skipna=False)) + self.assert_eq(pdf.cummin().sum(), kdf.cummin().sum()) + + def test_cummin(self): + pdf = pd.DataFrame( + [[2.0, 1.0], [5, None], [1.0, 0.0], [2.0, 4.0], [4.0, 9.0]], + columns=list("AB"), + index=np.random.rand(5), + ) + kdf = pp.from_pandas(pdf) + self._test_cummin(pdf, kdf) + + def test_cummin_multiindex_columns(self): + arrays = [np.array(["A", "A", "B", "B"]), np.array(["one", "two", "one", "two"])] + pdf = pd.DataFrame(np.random.randn(3, 4), index=["A", "C", "B"], columns=arrays) + pdf.at["C", ("A", "two")] = None + kdf = pp.from_pandas(pdf) + self._test_cummin(pdf, kdf) + + def _test_cummax(self, pdf, kdf): + self.assert_eq(pdf.cummax(), kdf.cummax()) + self.assert_eq(pdf.cummax(skipna=False), kdf.cummax(skipna=False)) + self.assert_eq(pdf.cummax().sum(), kdf.cummax().sum()) + + def test_cummax(self): + pdf = pd.DataFrame( + [[2.0, 1.0], [5, None], [1.0, 0.0], [2.0, 4.0], [4.0, 9.0]], + columns=list("AB"), + index=np.random.rand(5), + ) + kdf = pp.from_pandas(pdf) + self._test_cummax(pdf, kdf) + + def test_cummax_multiindex_columns(self): + arrays = [np.array(["A", "A", "B", "B"]), np.array(["one", "two", "one", "two"])] + pdf = pd.DataFrame(np.random.randn(3, 4), index=["A", "C", "B"], columns=arrays) + pdf.at["C", ("A", "two")] = None + kdf = 
pp.from_pandas(pdf) + self._test_cummax(pdf, kdf) + + def _test_cumsum(self, pdf, kdf): + self.assert_eq(pdf.cumsum(), kdf.cumsum()) + self.assert_eq(pdf.cumsum(skipna=False), kdf.cumsum(skipna=False)) + self.assert_eq(pdf.cumsum().sum(), kdf.cumsum().sum()) + + def test_cumsum(self): + pdf = pd.DataFrame( + [[2.0, 1.0], [5, None], [1.0, 0.0], [2.0, 4.0], [4.0, 9.0]], + columns=list("AB"), + index=np.random.rand(5), + ) + kdf = pp.from_pandas(pdf) + self._test_cumsum(pdf, kdf) + + def test_cumsum_multiindex_columns(self): + arrays = [np.array(["A", "A", "B", "B"]), np.array(["one", "two", "one", "two"])] + pdf = pd.DataFrame(np.random.randn(3, 4), index=["A", "C", "B"], columns=arrays) + pdf.at["C", ("A", "two")] = None + kdf = pp.from_pandas(pdf) + self._test_cumsum(pdf, kdf) + + def _test_cumprod(self, pdf, kdf): + self.assert_eq(pdf.cumprod(), kdf.cumprod(), almost=True) + self.assert_eq(pdf.cumprod(skipna=False), kdf.cumprod(skipna=False), almost=True) + self.assert_eq(pdf.cumprod().sum(), kdf.cumprod().sum(), almost=True) + + def test_cumprod(self): + if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"): + pdf = pd.DataFrame( + [[2.0, 1.0, 1], [5, None, 2], [1.0, -1.0, -3], [2.0, 0, 4], [4.0, 9.0, 5]], + columns=list("ABC"), + index=np.random.rand(5), + ) + kdf = pp.from_pandas(pdf) + self._test_cumprod(pdf, kdf) + else: + pdf = pd.DataFrame( + [[2, 1, 1], [5, 1, 2], [1, -1, -3], [2, 0, 4], [4, 9, 5]], + columns=list("ABC"), + index=np.random.rand(5), + ) + kdf = pp.from_pandas(pdf) + self._test_cumprod(pdf, kdf) + + def test_cumprod_multiindex_columns(self): + arrays = [np.array(["A", "A", "B", "B"]), np.array(["one", "two", "one", "two"])] + pdf = pd.DataFrame(np.random.rand(3, 4), index=["A", "C", "B"], columns=arrays) + pdf.at["C", ("A", "two")] = None + kdf = pp.from_pandas(pdf) + self._test_cumprod(pdf, kdf) + + def test_drop_duplicates(self): + pdf = pd.DataFrame( + {"a": [1, 2, 2, 2, 3], "b": ["a", "a", "a", "c", "d"]}, index=np.random.rand(5) + ) + kdf = pp.from_pandas(pdf) + + # inplace is False + for keep in ["first", "last", False]: + with self.subTest(keep=keep): + self.assert_eq( + pdf.drop_duplicates(keep=keep).sort_index(), + kdf.drop_duplicates(keep=keep).sort_index(), + ) + self.assert_eq( + pdf.drop_duplicates("a", keep=keep).sort_index(), + kdf.drop_duplicates("a", keep=keep).sort_index(), + ) + self.assert_eq( + pdf.drop_duplicates(["a", "b"], keep=keep).sort_index(), + kdf.drop_duplicates(["a", "b"], keep=keep).sort_index(), + ) + self.assert_eq( + pdf.set_index("a", append=True).drop_duplicates(keep=keep).sort_index(), + kdf.set_index("a", append=True).drop_duplicates(keep=keep).sort_index(), + ) + self.assert_eq( + pdf.set_index("a", append=True).drop_duplicates("b", keep=keep).sort_index(), + kdf.set_index("a", append=True).drop_duplicates("b", keep=keep).sort_index(), + ) + + columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b")]) + pdf.columns = columns + kdf.columns = columns + + # inplace is False + for keep in ["first", "last", False]: + with self.subTest("multi-index columns", keep=keep): + self.assert_eq( + pdf.drop_duplicates(keep=keep).sort_index(), + kdf.drop_duplicates(keep=keep).sort_index(), + ) + self.assert_eq( + pdf.drop_duplicates(("x", "a"), keep=keep).sort_index(), + kdf.drop_duplicates(("x", "a"), keep=keep).sort_index(), + ) + self.assert_eq( + pdf.drop_duplicates([("x", "a"), ("y", "b")], keep=keep).sort_index(), + kdf.drop_duplicates([("x", "a"), ("y", "b")], keep=keep).sort_index(), + ) + + # inplace is True + subset_list = 
[None, "a", ["a", "b"]] + for subset in subset_list: + pdf = pd.DataFrame( + {"a": [1, 2, 2, 2, 3], "b": ["a", "a", "a", "c", "d"]}, index=np.random.rand(5) + ) + kdf = pp.from_pandas(pdf) + pser = pdf.a + kser = kdf.a + pdf.drop_duplicates(subset=subset, inplace=True) + kdf.drop_duplicates(subset=subset, inplace=True) + self.assert_eq(kdf.sort_index(), pdf.sort_index()) + self.assert_eq(kser.sort_index(), pser.sort_index()) + + # multi-index columns, inplace is True + subset_list = [None, ("x", "a"), [("x", "a"), ("y", "b")]] + for subset in subset_list: + pdf = pd.DataFrame( + {"a": [1, 2, 2, 2, 3], "b": ["a", "a", "a", "c", "d"]}, index=np.random.rand(5) + ) + kdf = pp.from_pandas(pdf) + columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b")]) + pdf.columns = columns + kdf.columns = columns + pser = pdf[("x", "a")] + kser = kdf[("x", "a")] + pdf.drop_duplicates(subset=subset, inplace=True) + kdf.drop_duplicates(subset=subset, inplace=True) + self.assert_eq(kdf.sort_index(), pdf.sort_index()) + self.assert_eq(kser.sort_index(), pser.sort_index()) + + # non-string names + pdf = pd.DataFrame( + {10: [1, 2, 2, 2, 3], 20: ["a", "a", "a", "c", "d"]}, index=np.random.rand(5) + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq( + pdf.drop_duplicates(10, keep=keep).sort_index(), + kdf.drop_duplicates(10, keep=keep).sort_index(), + ) + self.assert_eq( + pdf.drop_duplicates([10, 20], keep=keep).sort_index(), + kdf.drop_duplicates([10, 20], keep=keep).sort_index(), + ) + + def test_reindex(self): + index = pd.Index(["A", "B", "C", "D", "E"]) + columns = pd.Index(["numbers"]) + pdf = pd.DataFrame([1.0, 2.0, 3.0, 4.0, None], index=index, columns=columns) + kdf = pp.from_pandas(pdf) + + columns2 = pd.Index(["numbers", "2", "3"], name="cols2") + self.assert_eq( + pdf.reindex(columns=columns2).sort_index(), kdf.reindex(columns=columns2).sort_index(), + ) + + columns = pd.Index(["numbers"], name="cols") + pdf.columns = columns + kdf.columns = columns + + self.assert_eq( + pdf.reindex(["A", "B", "C"], columns=["numbers", "2", "3"]).sort_index(), + kdf.reindex(["A", "B", "C"], columns=["numbers", "2", "3"]).sort_index(), + ) + + self.assert_eq( + pdf.reindex(["A", "B", "C"], index=["numbers", "2", "3"]).sort_index(), + kdf.reindex(["A", "B", "C"], index=["numbers", "2", "3"]).sort_index(), + ) + + self.assert_eq( + pdf.reindex(index=["A", "B"]).sort_index(), kdf.reindex(index=["A", "B"]).sort_index() + ) + + self.assert_eq( + pdf.reindex(index=["A", "B", "2", "3"]).sort_index(), + kdf.reindex(index=["A", "B", "2", "3"]).sort_index(), + ) + + self.assert_eq( + pdf.reindex(index=["A", "E", "2", "3"], fill_value=0).sort_index(), + kdf.reindex(index=["A", "E", "2", "3"], fill_value=0).sort_index(), + ) + + self.assert_eq( + pdf.reindex(columns=["numbers"]).sort_index(), + kdf.reindex(columns=["numbers"]).sort_index(), + ) + + # Using float as fill_value to avoid int64/32 clash + self.assert_eq( + pdf.reindex(columns=["numbers", "2", "3"], fill_value=0.0).sort_index(), + kdf.reindex(columns=["numbers", "2", "3"], fill_value=0.0).sort_index(), + ) + + columns2 = pd.Index(["numbers", "2", "3"]) + self.assert_eq( + pdf.reindex(columns=columns2).sort_index(), kdf.reindex(columns=columns2).sort_index(), + ) + + columns2 = pd.Index(["numbers", "2", "3"], name="cols2") + self.assert_eq( + pdf.reindex(columns=columns2).sort_index(), kdf.reindex(columns=columns2).sort_index(), + ) + + # Reindexing single Index on single Index + pindex2 = pd.Index(["A", "C", "D", "E", "0"], name="index2") + kindex2 = 
pp.from_pandas(pindex2) + + for fill_value in [None, 0]: + self.assert_eq( + pdf.reindex(index=pindex2, fill_value=fill_value).sort_index(), + kdf.reindex(index=kindex2, fill_value=fill_value).sort_index(), + ) + + pindex2 = pd.DataFrame({"index2": ["A", "C", "D", "E", "0"]}).set_index("index2").index + kindex2 = pp.from_pandas(pindex2) + + for fill_value in [None, 0]: + self.assert_eq( + pdf.reindex(index=pindex2, fill_value=fill_value).sort_index(), + kdf.reindex(index=kindex2, fill_value=fill_value).sort_index(), + ) + + # Reindexing MultiIndex on single Index + pindex = pd.MultiIndex.from_tuples( + [("A", "B"), ("C", "D"), ("F", "G")], names=["name1", "name2"] + ) + kindex = pp.from_pandas(pindex) + + self.assert_eq( + pdf.reindex(index=pindex, fill_value=0.0).sort_index(), + kdf.reindex(index=kindex, fill_value=0.0).sort_index(), + ) + + self.assertRaises(TypeError, lambda: kdf.reindex(columns=["numbers", "2", "3"], axis=1)) + self.assertRaises(TypeError, lambda: kdf.reindex(columns=["numbers", "2", "3"], axis=2)) + self.assertRaises(TypeError, lambda: kdf.reindex(index=["A", "B", "C"], axis=1)) + self.assertRaises(TypeError, lambda: kdf.reindex(index=123)) + + # Reindexing MultiIndex on MultiIndex + pdf = pd.DataFrame({"numbers": [1.0, 2.0, None]}, index=pindex) + kdf = pp.from_pandas(pdf) + pindex2 = pd.MultiIndex.from_tuples( + [("A", "G"), ("C", "D"), ("I", "J")], names=["name1", "name2"] + ) + kindex2 = pp.from_pandas(pindex2) + + for fill_value in [None, 0.0]: + self.assert_eq( + pdf.reindex(index=pindex2, fill_value=fill_value).sort_index(), + kdf.reindex(index=kindex2, fill_value=fill_value).sort_index(), + ) + + pindex2 = ( + pd.DataFrame({"index_level_1": ["A", "C", "I"], "index_level_2": ["G", "D", "J"]}) + .set_index(["index_level_1", "index_level_2"]) + .index + ) + kindex2 = pp.from_pandas(pindex2) + + for fill_value in [None, 0.0]: + self.assert_eq( + pdf.reindex(index=pindex2, fill_value=fill_value).sort_index(), + kdf.reindex(index=kindex2, fill_value=fill_value).sort_index(), + ) + + columns = pd.MultiIndex.from_tuples([("X", "numbers")], names=["cols1", "cols2"]) + pdf.columns = columns + kdf.columns = columns + + # Reindexing MultiIndex index on MultiIndex columns and MultiIndex index + for fill_value in [None, 0.0]: + self.assert_eq( + pdf.reindex(index=pindex2, fill_value=fill_value).sort_index(), + kdf.reindex(index=kindex2, fill_value=fill_value).sort_index(), + ) + + index = pd.Index(["A", "B", "C", "D", "E"]) + pdf = pd.DataFrame(data=[1.0, 2.0, 3.0, 4.0, None], index=index, columns=columns) + kdf = pp.from_pandas(pdf) + pindex2 = pd.Index(["A", "C", "D", "E", "0"], name="index2") + kindex2 = pp.from_pandas(pindex2) + + # Reindexing single Index on MultiIndex columns and single Index + for fill_value in [None, 0.0]: + self.assert_eq( + pdf.reindex(index=pindex2, fill_value=fill_value).sort_index(), + kdf.reindex(index=kindex2, fill_value=fill_value).sort_index(), + ) + + for fill_value in [None, 0.0]: + self.assert_eq( + pdf.reindex( + columns=[("X", "numbers"), ("Y", "2"), ("Y", "3")], fill_value=fill_value + ).sort_index(), + kdf.reindex( + columns=[("X", "numbers"), ("Y", "2"), ("Y", "3")], fill_value=fill_value + ).sort_index(), + ) + + columns2 = pd.MultiIndex.from_tuples( + [("X", "numbers"), ("Y", "2"), ("Y", "3")], names=["cols3", "cols4"] + ) + self.assert_eq( + pdf.reindex(columns=columns2).sort_index(), kdf.reindex(columns=columns2).sort_index(), + ) + + self.assertRaises(TypeError, lambda: kdf.reindex(columns=["X"])) + 
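+        # With MultiIndex columns, reindex expects tuple labels that cover every
+        # column level: a bare string raises TypeError above, while a too-short
+        # tuple raises ValueError below.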
self.assertRaises(ValueError, lambda: kdf.reindex(columns=[("X",)])) + + def test_reindex_like(self): + data = [[1.0, 2.0], [3.0, None], [None, 4.0]] + index = pd.Index(["A", "B", "C"], name="index") + columns = pd.Index(["numbers", "values"], name="cols") + pdf = pd.DataFrame(data=data, index=index, columns=columns) + kdf = pp.from_pandas(pdf) + + # Reindexing single Index on single Index + data2 = [[5.0, None], [6.0, 7.0], [8.0, None]] + index2 = pd.Index(["A", "C", "D"], name="index2") + columns2 = pd.Index(["numbers", "F"], name="cols2") + pdf2 = pd.DataFrame(data=data2, index=index2, columns=columns2) + kdf2 = pp.from_pandas(pdf2) + + self.assert_eq( + pdf.reindex_like(pdf2).sort_index(), kdf.reindex_like(kdf2).sort_index(), + ) + + pdf2 = pd.DataFrame({"index_level_1": ["A", "C", "I"]}) + kdf2 = pp.from_pandas(pdf2) + + self.assert_eq( + pdf.reindex_like(pdf2.set_index(["index_level_1"])).sort_index(), + kdf.reindex_like(kdf2.set_index(["index_level_1"])).sort_index(), + ) + + # Reindexing MultiIndex on single Index + index2 = pd.MultiIndex.from_tuples( + [("A", "G"), ("C", "D"), ("I", "J")], names=["name3", "name4"] + ) + pdf2 = pd.DataFrame(data=data2, index=index2) + kdf2 = pp.from_pandas(pdf2) + + self.assert_eq( + pdf.reindex_like(pdf2).sort_index(), kdf.reindex_like(kdf2).sort_index(), + ) + + self.assertRaises(TypeError, lambda: kdf.reindex_like(index2)) + self.assertRaises(AssertionError, lambda: kdf2.reindex_like(kdf)) + + # Reindexing MultiIndex on MultiIndex + columns2 = pd.MultiIndex.from_tuples( + [("numbers", "third"), ("values", "second")], names=["cols3", "cols4"] + ) + pdf2.columns = columns2 + kdf2.columns = columns2 + + columns = pd.MultiIndex.from_tuples( + [("numbers", "first"), ("values", "second")], names=["cols1", "cols2"] + ) + index = pd.MultiIndex.from_tuples( + [("A", "B"), ("C", "D"), ("E", "F")], names=["name1", "name2"] + ) + pdf = pd.DataFrame(data=data, index=index, columns=columns) + kdf = pp.from_pandas(pdf) + + self.assert_eq( + pdf.reindex_like(pdf2).sort_index(), kdf.reindex_like(kdf2).sort_index(), + ) + + def test_melt(self): + pdf = pd.DataFrame( + {"A": [1, 3, 5], "B": [2, 4, 6], "C": [7, 8, 9]}, index=np.random.rand(3) + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq( + kdf.melt().sort_values(["variable", "value"]).reset_index(drop=True), + pdf.melt().sort_values(["variable", "value"]), + ) + self.assert_eq( + kdf.melt(id_vars="A").sort_values(["variable", "value"]).reset_index(drop=True), + pdf.melt(id_vars="A").sort_values(["variable", "value"]), + ) + self.assert_eq( + kdf.melt(id_vars=["A", "B"]).sort_values(["variable", "value"]).reset_index(drop=True), + pdf.melt(id_vars=["A", "B"]).sort_values(["variable", "value"]), + ) + self.assert_eq( + kdf.melt(id_vars=("A", "B")).sort_values(["variable", "value"]).reset_index(drop=True), + pdf.melt(id_vars=("A", "B")).sort_values(["variable", "value"]), + ) + self.assert_eq( + kdf.melt(id_vars=["A"], value_vars=["C"]) + .sort_values(["variable", "value"]) + .reset_index(drop=True), + pdf.melt(id_vars=["A"], value_vars=["C"]).sort_values(["variable", "value"]), + ) + self.assert_eq( + kdf.melt(id_vars=["A"], value_vars=["B"], var_name="myVarname", value_name="myValname") + .sort_values(["myVarname", "myValname"]) + .reset_index(drop=True), + pdf.melt( + id_vars=["A"], value_vars=["B"], var_name="myVarname", value_name="myValname" + ).sort_values(["myVarname", "myValname"]), + ) + self.assert_eq( + kdf.melt(value_vars=("A", "B")) + .sort_values(["variable", "value"]) + .reset_index(drop=True), + 
pdf.melt(value_vars=("A", "B")).sort_values(["variable", "value"]), + ) + + self.assertRaises(KeyError, lambda: kdf.melt(id_vars="Z")) + self.assertRaises(KeyError, lambda: kdf.melt(value_vars="Z")) + + # multi-index columns + if LooseVersion("0.24") <= LooseVersion(pd.__version__) < LooseVersion("1.0.0"): + # pandas >=0.24,<1.0 doesn't support mixed int/str columns in melt. + # see: https://github.com/pandas-dev/pandas/pull/29792 + TEN = "10" + TWELVE = "20" + else: + TEN = 10.0 + TWELVE = 20.0 + + columns = pd.MultiIndex.from_tuples([(TEN, "A"), (TEN, "B"), (TWELVE, "C")]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq( + kdf.melt().sort_values(["variable_0", "variable_1", "value"]).reset_index(drop=True), + pdf.melt().sort_values(["variable_0", "variable_1", "value"]), + ) + self.assert_eq( + kdf.melt(id_vars=[(TEN, "A")]) + .sort_values(["variable_0", "variable_1", "value"]) + .reset_index(drop=True), + pdf.melt(id_vars=[(TEN, "A")]) + .sort_values(["variable_0", "variable_1", "value"]) + .rename(columns=name_like_string), + ) + self.assert_eq( + kdf.melt(id_vars=[(TEN, "A")], value_vars=[(TWELVE, "C")]) + .sort_values(["variable_0", "variable_1", "value"]) + .reset_index(drop=True), + pdf.melt(id_vars=[(TEN, "A")], value_vars=[(TWELVE, "C")]) + .sort_values(["variable_0", "variable_1", "value"]) + .rename(columns=name_like_string), + ) + self.assert_eq( + kdf.melt( + id_vars=[(TEN, "A")], + value_vars=[(TEN, "B")], + var_name=["myV1", "myV2"], + value_name="myValname", + ) + .sort_values(["myV1", "myV2", "myValname"]) + .reset_index(drop=True), + pdf.melt( + id_vars=[(TEN, "A")], + value_vars=[(TEN, "B")], + var_name=["myV1", "myV2"], + value_name="myValname", + ) + .sort_values(["myV1", "myV2", "myValname"]) + .rename(columns=name_like_string), + ) + + columns.names = ["v0", "v1"] + pdf.columns = columns + kdf.columns = columns + + self.assert_eq( + kdf.melt().sort_values(["v0", "v1", "value"]).reset_index(drop=True), + pdf.melt().sort_values(["v0", "v1", "value"]), + ) + + self.assertRaises(ValueError, lambda: kdf.melt(id_vars=(TEN, "A"))) + self.assertRaises(ValueError, lambda: kdf.melt(value_vars=(TEN, "A"))) + self.assertRaises(KeyError, lambda: kdf.melt(id_vars=[TEN])) + self.assertRaises(KeyError, lambda: kdf.melt(id_vars=[(TWELVE, "A")])) + self.assertRaises(KeyError, lambda: kdf.melt(value_vars=[TWELVE])) + self.assertRaises(KeyError, lambda: kdf.melt(value_vars=[(TWELVE, "A")])) + + # non-string names + pdf.columns = [10.0, 20.0, 30.0] + kdf.columns = [10.0, 20.0, 30.0] + + self.assert_eq( + kdf.melt().sort_values(["variable", "value"]).reset_index(drop=True), + pdf.melt().sort_values(["variable", "value"]), + ) + self.assert_eq( + kdf.melt(id_vars=10.0).sort_values(["variable", "value"]).reset_index(drop=True), + pdf.melt(id_vars=10.0).sort_values(["variable", "value"]), + ) + self.assert_eq( + kdf.melt(id_vars=[10.0, 20.0]) + .sort_values(["variable", "value"]) + .reset_index(drop=True), + pdf.melt(id_vars=[10.0, 20.0]).sort_values(["variable", "value"]), + ) + self.assert_eq( + kdf.melt(id_vars=(10.0, 20.0)) + .sort_values(["variable", "value"]) + .reset_index(drop=True), + pdf.melt(id_vars=(10.0, 20.0)).sort_values(["variable", "value"]), + ) + self.assert_eq( + kdf.melt(id_vars=[10.0], value_vars=[30.0]) + .sort_values(["variable", "value"]) + .reset_index(drop=True), + pdf.melt(id_vars=[10.0], value_vars=[30.0]).sort_values(["variable", "value"]), + ) + self.assert_eq( + kdf.melt(value_vars=(10.0, 20.0)) + .sort_values(["variable", "value"]) + 
.reset_index(drop=True), + pdf.melt(value_vars=(10.0, 20.0)).sort_values(["variable", "value"]), + ) + + def test_all(self): + pdf = pd.DataFrame( + { + "col1": [False, False, False], + "col2": [True, False, False], + "col3": [0, 0, 1], + "col4": [0, 1, 2], + "col5": [False, False, None], + "col6": [True, False, None], + }, + index=np.random.rand(3), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.all(), pdf.all()) + + columns = pd.MultiIndex.from_tuples( + [ + ("a", "col1"), + ("a", "col2"), + ("a", "col3"), + ("b", "col4"), + ("b", "col5"), + ("c", "col6"), + ] + ) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq(kdf.all(), pdf.all()) + + columns.names = ["X", "Y"] + pdf.columns = columns + kdf.columns = columns + + self.assert_eq(kdf.all(), pdf.all()) + + with self.assertRaisesRegex( + NotImplementedError, 'axis should be either 0 or "index" currently.' + ): + kdf.all(axis=1) + + def test_any(self): + pdf = pd.DataFrame( + { + "col1": [False, False, False], + "col2": [True, False, False], + "col3": [0, 0, 1], + "col4": [0, 1, 2], + "col5": [False, False, None], + "col6": [True, False, None], + }, + index=np.random.rand(3), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.any(), pdf.any()) + + columns = pd.MultiIndex.from_tuples( + [ + ("a", "col1"), + ("a", "col2"), + ("a", "col3"), + ("b", "col4"), + ("b", "col5"), + ("c", "col6"), + ] + ) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq(kdf.any(), pdf.any()) + + columns.names = ["X", "Y"] + pdf.columns = columns + kdf.columns = columns + + self.assert_eq(kdf.any(), pdf.any()) + + with self.assertRaisesRegex( + NotImplementedError, 'axis should be either 0 or "index" currently.' + ): + kdf.any(axis=1) + + def test_rank(self): + pdf = pd.DataFrame( + data={"col1": [1, 2, 3, 1], "col2": [3, 4, 3, 1]}, + columns=["col1", "col2"], + index=np.random.rand(4), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(pdf.rank().sort_index(), kdf.rank().sort_index()) + self.assert_eq(pdf.rank().sum(), kdf.rank().sum()) + self.assert_eq( + pdf.rank(ascending=False).sort_index(), kdf.rank(ascending=False).sort_index() + ) + self.assert_eq(pdf.rank(method="min").sort_index(), kdf.rank(method="min").sort_index()) + self.assert_eq(pdf.rank(method="max").sort_index(), kdf.rank(method="max").sort_index()) + self.assert_eq(pdf.rank(method="first").sort_index(), kdf.rank(method="first").sort_index()) + self.assert_eq(pdf.rank(method="dense").sort_index(), kdf.rank(method="dense").sort_index()) + + msg = "method must be one of 'average', 'min', 'max', 'first', 'dense'" + with self.assertRaisesRegex(ValueError, msg): + kdf.rank(method="nothing") + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("x", "col1"), ("y", "col2")]) + pdf.columns = columns + kdf.columns = columns + self.assert_eq(pdf.rank().sort_index(), kdf.rank().sort_index()) + + def test_round(self): + pdf = pd.DataFrame( + { + "A": [0.028208, 0.038683, 0.877076], + "B": [0.992815, 0.645646, 0.149370], + "C": [0.173891, 0.577595, 0.491027], + }, + columns=["A", "B", "C"], + index=np.random.rand(3), + ) + kdf = pp.from_pandas(pdf) + + pser = pd.Series([1, 0, 2], index=["A", "B", "C"]) + kser = pp.Series([1, 0, 2], index=["A", "B", "C"]) + self.assert_eq(pdf.round(2), kdf.round(2)) + self.assert_eq(pdf.round({"A": 1, "C": 2}), kdf.round({"A": 1, "C": 2})) + self.assert_eq(pdf.round({"A": 1, "D": 2}), kdf.round({"A": 1, "D": 2})) + self.assert_eq(pdf.round(pser), kdf.round(kser)) + msg = "decimals must be an integer, a dict-like or a 
Series" + with self.assertRaisesRegex(ValueError, msg): + kdf.round(1.5) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C")]) + pdf.columns = columns + kdf.columns = columns + pser = pd.Series([1, 0, 2], index=columns) + kser = pp.Series([1, 0, 2], index=columns) + self.assert_eq(pdf.round(2), kdf.round(2)) + self.assert_eq( + pdf.round({("X", "A"): 1, ("Y", "C"): 2}), kdf.round({("X", "A"): 1, ("Y", "C"): 2}) + ) + self.assert_eq(pdf.round({("X", "A"): 1, "Y": 2}), kdf.round({("X", "A"): 1, "Y": 2})) + self.assert_eq(pdf.round(pser), kdf.round(kser)) + + # non-string names + pdf = pd.DataFrame( + { + 10: [0.028208, 0.038683, 0.877076], + 20: [0.992815, 0.645646, 0.149370], + 30: [0.173891, 0.577595, 0.491027], + }, + index=np.random.rand(3), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(pdf.round({10: 1, 30: 2}), kdf.round({10: 1, 30: 2})) + + def test_shift(self): + pdf = pd.DataFrame( + { + "Col1": [10, 20, 15, 30, 45], + "Col2": [13, 23, 18, 33, 48], + "Col3": [17, 27, 22, 37, 52], + }, + index=np.random.rand(5), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(pdf.shift(3), kdf.shift(3)) + self.assert_eq(pdf.shift().shift(-1), kdf.shift().shift(-1)) + self.assert_eq(pdf.shift().sum().astype(int), kdf.shift().sum()) + + # Need the expected result since pandas 0.23 does not support `fill_value` argument. + pdf1 = pd.DataFrame( + {"Col1": [0, 0, 0, 10, 20], "Col2": [0, 0, 0, 13, 23], "Col3": [0, 0, 0, 17, 27]}, + index=pdf.index, + ) + self.assert_eq(pdf1, kdf.shift(periods=3, fill_value=0)) + msg = "should be an int" + with self.assertRaisesRegex(ValueError, msg): + kdf.shift(1.5) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("x", "Col1"), ("x", "Col2"), ("y", "Col3")]) + pdf.columns = columns + kdf.columns = columns + self.assert_eq(pdf.shift(3), kdf.shift(3)) + self.assert_eq(pdf.shift().shift(-1), kdf.shift().shift(-1)) + + def test_diff(self): + pdf = pd.DataFrame( + {"a": [1, 2, 3, 4, 5, 6], "b": [1, 1, 2, 3, 5, 8], "c": [1, 4, 9, 16, 25, 36]}, + index=np.random.rand(6), + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(pdf.diff(), kdf.diff()) + self.assert_eq(pdf.diff().diff(-1), kdf.diff().diff(-1)) + self.assert_eq(pdf.diff().sum().astype(int), kdf.diff().sum()) + + msg = "should be an int" + with self.assertRaisesRegex(ValueError, msg): + kdf.diff(1.5) + msg = 'axis should be either 0 or "index" currently.' 
+ with self.assertRaisesRegex(NotImplementedError, msg): + kdf.diff(axis=1) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("x", "Col1"), ("x", "Col2"), ("y", "Col3")]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq(pdf.diff(), kdf.diff()) + + def test_duplicated(self): + pdf = pd.DataFrame( + {"a": [1, 1, 2, 3], "b": [1, 1, 1, 4], "c": [1, 1, 1, 5]}, index=np.random.rand(4) + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(pdf.duplicated().sort_index(), kdf.duplicated().sort_index()) + self.assert_eq( + pdf.duplicated(keep="last").sort_index(), kdf.duplicated(keep="last").sort_index(), + ) + self.assert_eq( + pdf.duplicated(keep=False).sort_index(), kdf.duplicated(keep=False).sort_index(), + ) + self.assert_eq( + pdf.duplicated(subset="b").sort_index(), kdf.duplicated(subset="b").sort_index(), + ) + self.assert_eq( + pdf.duplicated(subset=["b"]).sort_index(), kdf.duplicated(subset=["b"]).sort_index(), + ) + with self.assertRaisesRegex(ValueError, "'keep' only supports 'first', 'last' and False"): + kdf.duplicated(keep="false") + with self.assertRaisesRegex(KeyError, "'d'"): + kdf.duplicated(subset=["d"]) + + pdf.index.name = "x" + kdf.index.name = "x" + self.assert_eq(pdf.duplicated().sort_index(), kdf.duplicated().sort_index()) + + # multi-index + self.assert_eq( + pdf.set_index("a", append=True).duplicated().sort_index(), + kdf.set_index("a", append=True).duplicated().sort_index(), + ) + self.assert_eq( + pdf.set_index("a", append=True).duplicated(keep=False).sort_index(), + kdf.set_index("a", append=True).duplicated(keep=False).sort_index(), + ) + self.assert_eq( + pdf.set_index("a", append=True).duplicated(subset=["b"]).sort_index(), + kdf.set_index("a", append=True).duplicated(subset=["b"]).sort_index(), + ) + + # mutli-index columns + columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("y", "c")]) + pdf.columns = columns + kdf.columns = columns + self.assert_eq(pdf.duplicated().sort_index(), kdf.duplicated().sort_index()) + self.assert_eq( + pdf.duplicated(subset=("x", "b")).sort_index(), + kdf.duplicated(subset=("x", "b")).sort_index(), + ) + self.assert_eq( + pdf.duplicated(subset=[("x", "b")]).sort_index(), + kdf.duplicated(subset=[("x", "b")]).sort_index(), + ) + + # non-string names + pdf = pd.DataFrame( + {10: [1, 1, 2, 3], 20: [1, 1, 1, 4], 30: [1, 1, 1, 5]}, index=np.random.rand(4) + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(pdf.duplicated().sort_index(), kdf.duplicated().sort_index()) + self.assert_eq( + pdf.duplicated(subset=10).sort_index(), kdf.duplicated(subset=10).sort_index(), + ) + + def test_ffill(self): + idx = np.random.rand(6) + pdf = pd.DataFrame( + { + "x": [np.nan, 2, 3, 4, np.nan, 6], + "y": [1, 2, np.nan, 4, np.nan, np.nan], + "z": [1, 2, 3, 4, np.nan, np.nan], + }, + index=idx, + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.ffill(), pdf.ffill()) + self.assert_eq(kdf.ffill(limit=1), pdf.ffill(limit=1)) + + pser = pdf.y + kser = kdf.y + + kdf.ffill(inplace=True) + pdf.ffill(inplace=True) + + self.assert_eq(kdf, pdf) + self.assert_eq(kser, pser) + self.assert_eq(kser[idx[2]], pser[idx[2]]) + + def test_bfill(self): + idx = np.random.rand(6) + pdf = pd.DataFrame( + { + "x": [np.nan, 2, 3, 4, np.nan, 6], + "y": [1, 2, np.nan, 4, np.nan, np.nan], + "z": [1, 2, 3, 4, np.nan, np.nan], + }, + index=idx, + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.bfill(), pdf.bfill()) + self.assert_eq(kdf.bfill(limit=1), pdf.bfill(limit=1)) + + pser = pdf.x + kser = kdf.x + + kdf.bfill(inplace=True) + 
pdf.bfill(inplace=True) + + self.assert_eq(kdf, pdf) + self.assert_eq(kser, pser) + self.assert_eq(kser[idx[0]], pser[idx[0]]) + + def test_filter(self): + pdf = pd.DataFrame( + { + "aa": ["aa", "bd", "bc", "ab", "ce"], + "ba": [1, 2, 3, 4, 5], + "cb": [1.0, 2.0, 3.0, 4.0, 5.0], + "db": [1.0, np.nan, 3.0, np.nan, 5.0], + } + ) + pdf = pdf.set_index("aa") + kdf = pp.from_pandas(pdf) + + self.assert_eq( + kdf.filter(items=["ab", "aa"], axis=0).sort_index(), + pdf.filter(items=["ab", "aa"], axis=0).sort_index(), + ) + self.assert_eq( + kdf.filter(items=["ba", "db"], axis=1).sort_index(), + pdf.filter(items=["ba", "db"], axis=1).sort_index(), + ) + + self.assert_eq(kdf.filter(like="b", axis="index"), pdf.filter(like="b", axis="index")) + self.assert_eq(kdf.filter(like="c", axis="columns"), pdf.filter(like="c", axis="columns")) + + self.assert_eq(kdf.filter(regex="b.*", axis="index"), pdf.filter(regex="b.*", axis="index")) + self.assert_eq( + kdf.filter(regex="b.*", axis="columns"), pdf.filter(regex="b.*", axis="columns") + ) + + pdf = pdf.set_index("ba", append=True) + kdf = pp.from_pandas(pdf) + + self.assert_eq( + kdf.filter(items=[("aa", 1), ("bd", 2)], axis=0).sort_index(), + pdf.filter(items=[("aa", 1), ("bd", 2)], axis=0).sort_index(), + ) + + with self.assertRaisesRegex(TypeError, "Unsupported type list"): + kdf.filter(items=[["aa", 1], ("bd", 2)], axis=0) + + with self.assertRaisesRegex(ValueError, "The item should not be empty."): + kdf.filter(items=[(), ("bd", 2)], axis=0) + + self.assert_eq(kdf.filter(like="b", axis=0), pdf.filter(like="b", axis=0)) + + self.assert_eq(kdf.filter(regex="b.*", axis=0), pdf.filter(regex="b.*", axis=0)) + + with self.assertRaisesRegex(ValueError, "items should be a list-like object"): + kdf.filter(items="b") + + with self.assertRaisesRegex(ValueError, "No axis named"): + kdf.filter(regex="b.*", axis=123) + + with self.assertRaisesRegex(TypeError, "Must pass either `items`, `like`"): + kdf.filter() + + with self.assertRaisesRegex(TypeError, "mutually exclusive"): + kdf.filter(regex="b.*", like="aaa") + + # multi-index columns + pdf = pd.DataFrame( + { + ("x", "aa"): ["aa", "ab", "bc", "bd", "ce"], + ("x", "ba"): [1, 2, 3, 4, 5], + ("y", "cb"): [1.0, 2.0, 3.0, 4.0, 5.0], + ("z", "db"): [1.0, np.nan, 3.0, np.nan, 5.0], + } + ) + pdf = pdf.set_index(("x", "aa")) + kdf = pp.from_pandas(pdf) + + self.assert_eq( + kdf.filter(items=["ab", "aa"], axis=0).sort_index(), + pdf.filter(items=["ab", "aa"], axis=0).sort_index(), + ) + self.assert_eq( + kdf.filter(items=[("x", "ba"), ("z", "db")], axis=1).sort_index(), + pdf.filter(items=[("x", "ba"), ("z", "db")], axis=1).sort_index(), + ) + + self.assert_eq(kdf.filter(like="b", axis="index"), pdf.filter(like="b", axis="index")) + self.assert_eq(kdf.filter(like="c", axis="columns"), pdf.filter(like="c", axis="columns")) + + self.assert_eq(kdf.filter(regex="b.*", axis="index"), pdf.filter(regex="b.*", axis="index")) + self.assert_eq( + kdf.filter(regex="b.*", axis="columns"), pdf.filter(regex="b.*", axis="columns") + ) + + def test_pipe(self): + kdf = pp.DataFrame( + {"category": ["A", "A", "B"], "col1": [1, 2, 3], "col2": [4, 5, 6]}, + columns=["category", "col1", "col2"], + ) + + self.assertRaisesRegex( + ValueError, + "arg is both the pipe target and a keyword argument", + lambda: kdf.pipe((lambda x: x, "arg"), arg="1"), + ) + + def test_transform(self): + pdf = pd.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6] * 100, + "b": [1.0, 1.0, 2.0, 3.0, 5.0, 8.0] * 100, + "c": [1, 4, 9, 16, 25, 36] * 100, + }, + columns=["a", "b", 
"c"], + index=np.random.rand(600), + ) + kdf = pp.DataFrame(pdf) + self.assert_eq( + kdf.transform(lambda x: x + 1).sort_index(), pdf.transform(lambda x: x + 1).sort_index() + ) + self.assert_eq( + kdf.transform(lambda x, y: x + y, y=2).sort_index(), + pdf.transform(lambda x, y: x + y, y=2).sort_index(), + ) + with option_context("compute.shortcut_limit", 500): + self.assert_eq( + kdf.transform(lambda x: x + 1).sort_index(), + pdf.transform(lambda x: x + 1).sort_index(), + ) + self.assert_eq( + kdf.transform(lambda x, y: x + y, y=1).sort_index(), + pdf.transform(lambda x, y: x + y, y=1).sort_index(), + ) + + with self.assertRaisesRegex(AssertionError, "the first argument should be a callable"): + kdf.transform(1) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("y", "c")]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq( + kdf.transform(lambda x: x + 1).sort_index(), pdf.transform(lambda x: x + 1).sort_index() + ) + with option_context("compute.shortcut_limit", 500): + self.assert_eq( + kdf.transform(lambda x: x + 1).sort_index(), + pdf.transform(lambda x: x + 1).sort_index(), + ) + + def test_apply(self): + pdf = pd.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6] * 100, + "b": [1.0, 1.0, 2.0, 3.0, 5.0, 8.0] * 100, + "c": [1, 4, 9, 16, 25, 36] * 100, + }, + columns=["a", "b", "c"], + index=np.random.rand(600), + ) + kdf = pp.DataFrame(pdf) + + self.assert_eq( + kdf.apply(lambda x: x + 1).sort_index(), pdf.apply(lambda x: x + 1).sort_index() + ) + self.assert_eq( + kdf.apply(lambda x, b: x + b, args=(1,)).sort_index(), + pdf.apply(lambda x, b: x + b, args=(1,)).sort_index(), + ) + self.assert_eq( + kdf.apply(lambda x, b: x + b, b=1).sort_index(), + pdf.apply(lambda x, b: x + b, b=1).sort_index(), + ) + + with option_context("compute.shortcut_limit", 500): + self.assert_eq( + kdf.apply(lambda x: x + 1).sort_index(), pdf.apply(lambda x: x + 1).sort_index() + ) + self.assert_eq( + kdf.apply(lambda x, b: x + b, args=(1,)).sort_index(), + pdf.apply(lambda x, b: x + b, args=(1,)).sort_index(), + ) + self.assert_eq( + kdf.apply(lambda x, b: x + b, b=1).sort_index(), + pdf.apply(lambda x, b: x + b, b=1).sort_index(), + ) + + # returning a Series + self.assert_eq( + kdf.apply(lambda x: len(x), axis=1).sort_index(), + pdf.apply(lambda x: len(x), axis=1).sort_index(), + ) + self.assert_eq( + kdf.apply(lambda x, c: len(x) + c, axis=1, c=100).sort_index(), + pdf.apply(lambda x, c: len(x) + c, axis=1, c=100).sort_index(), + ) + with option_context("compute.shortcut_limit", 500): + self.assert_eq( + kdf.apply(lambda x: len(x), axis=1).sort_index(), + pdf.apply(lambda x: len(x), axis=1).sort_index(), + ) + self.assert_eq( + kdf.apply(lambda x, c: len(x) + c, axis=1, c=100).sort_index(), + pdf.apply(lambda x, c: len(x) + c, axis=1, c=100).sort_index(), + ) + + with self.assertRaisesRegex(AssertionError, "the first argument should be a callable"): + kdf.apply(1) + + with self.assertRaisesRegex(TypeError, "The given function.*1 or 'column'; however"): + + def f1(_) -> pp.DataFrame[int]: + pass + + kdf.apply(f1, axis=0) + + with self.assertRaisesRegex(TypeError, "The given function.*0 or 'index'; however"): + + def f2(_) -> pp.Series[int]: + pass + + kdf.apply(f2, axis=1) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("y", "c")]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq( + kdf.apply(lambda x: x + 1).sort_index(), pdf.apply(lambda x: x + 1).sort_index() + ) + with 
option_context("compute.shortcut_limit", 500): + self.assert_eq( + kdf.apply(lambda x: x + 1).sort_index(), pdf.apply(lambda x: x + 1).sort_index() + ) + + # returning a Series + self.assert_eq( + kdf.apply(lambda x: len(x), axis=1).sort_index(), + pdf.apply(lambda x: len(x), axis=1).sort_index(), + ) + with option_context("compute.shortcut_limit", 500): + self.assert_eq( + kdf.apply(lambda x: len(x), axis=1).sort_index(), + pdf.apply(lambda x: len(x), axis=1).sort_index(), + ) + + def test_apply_batch(self): + pdf = pd.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6] * 100, + "b": [1.0, 1.0, 2.0, 3.0, 5.0, 8.0] * 100, + "c": [1, 4, 9, 16, 25, 36] * 100, + }, + columns=["a", "b", "c"], + index=np.random.rand(600), + ) + kdf = pp.DataFrame(pdf) + + # One to test alias. + self.assert_eq(kdf.apply_batch(lambda pdf: pdf + 1).sort_index(), (pdf + 1).sort_index()) + self.assert_eq( + kdf.koalas.apply_batch(lambda pdf, a: pdf + a, args=(1,)).sort_index(), + (pdf + 1).sort_index(), + ) + with option_context("compute.shortcut_limit", 500): + self.assert_eq( + kdf.koalas.apply_batch(lambda pdf: pdf + 1).sort_index(), (pdf + 1).sort_index() + ) + self.assert_eq( + kdf.koalas.apply_batch(lambda pdf, b: pdf + b, b=1).sort_index(), + (pdf + 1).sort_index(), + ) + + with self.assertRaisesRegex(AssertionError, "the first argument should be a callable"): + kdf.koalas.apply_batch(1) + + with self.assertRaisesRegex(TypeError, "The given function.*frame as its type hints"): + + def f2(_) -> pp.Series[int]: + pass + + kdf.koalas.apply_batch(f2) + + with self.assertRaisesRegex(ValueError, "The given function should return a frame"): + kdf.koalas.apply_batch(lambda pdf: 1) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("y", "c")]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq(kdf.koalas.apply_batch(lambda x: x + 1).sort_index(), (pdf + 1).sort_index()) + with option_context("compute.shortcut_limit", 500): + self.assert_eq( + kdf.koalas.apply_batch(lambda x: x + 1).sort_index(), (pdf + 1).sort_index() + ) + + def test_transform_batch(self): + pdf = pd.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6] * 100, + "b": [1.0, 1.0, 2.0, 3.0, 5.0, 8.0] * 100, + "c": [1, 4, 9, 16, 25, 36] * 100, + }, + columns=["a", "b", "c"], + index=np.random.rand(600), + ) + kdf = pp.DataFrame(pdf) + + # One to test alias. 
+ self.assert_eq( + kdf.transform_batch(lambda pdf: pdf + 1).sort_index(), (pdf + 1).sort_index() + ) + self.assert_eq( + kdf.koalas.transform_batch(lambda pdf: pdf.c + 1).sort_index(), (pdf.c + 1).sort_index() + ) + self.assert_eq( + kdf.koalas.transform_batch(lambda pdf, a: pdf + a, 1).sort_index(), + (pdf + 1).sort_index(), + ) + self.assert_eq( + kdf.koalas.transform_batch(lambda pdf, a: pdf.c + a, a=1).sort_index(), + (pdf.c + 1).sort_index(), + ) + + with option_context("compute.shortcut_limit", 500): + self.assert_eq( + kdf.koalas.transform_batch(lambda pdf: pdf + 1).sort_index(), (pdf + 1).sort_index() + ) + self.assert_eq( + kdf.koalas.transform_batch(lambda pdf: pdf.b + 1).sort_index(), + (pdf.b + 1).sort_index(), + ) + self.assert_eq( + kdf.koalas.transform_batch(lambda pdf, a: pdf + a, 1).sort_index(), + (pdf + 1).sort_index(), + ) + self.assert_eq( + kdf.koalas.transform_batch(lambda pdf, a: pdf.c + a, a=1).sort_index(), + (pdf.c + 1).sort_index(), + ) + + with self.assertRaisesRegex(AssertionError, "the first argument should be a callable"): + kdf.koalas.transform_batch(1) + + with self.assertRaisesRegex(ValueError, "The given function should return a frame"): + kdf.koalas.transform_batch(lambda pdf: 1) + + with self.assertRaisesRegex( + ValueError, "transform_batch cannot produce aggregated results" + ): + kdf.koalas.transform_batch(lambda pdf: pd.Series(1)) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("y", "c")]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq( + kdf.koalas.transform_batch(lambda x: x + 1).sort_index(), (pdf + 1).sort_index() + ) + with option_context("compute.shortcut_limit", 500): + self.assert_eq( + kdf.koalas.transform_batch(lambda x: x + 1).sort_index(), (pdf + 1).sort_index() + ) + + def test_transform_batch_same_anchor(self): + kdf = pp.range(10) + kdf["d"] = kdf.koalas.transform_batch(lambda pdf: pdf.id + 1) + self.assert_eq( + kdf, pd.DataFrame({"id": list(range(10)), "d": list(range(1, 11))}, columns=["id", "d"]) + ) + + kdf = pp.range(10) + # One to test alias. 
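+        # Here the Series-level alias `kdf.id.transform_batch(...)` stands in for
+        # the accessor call `kdf.id.koalas.transform_batch(...)` exercised later.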
+ kdf["d"] = kdf.id.transform_batch(lambda ser: ser + 1) + self.assert_eq( + kdf, pd.DataFrame({"id": list(range(10)), "d": list(range(1, 11))}, columns=["id", "d"]) + ) + + kdf = pp.range(10) + + def plus_one(pdf) -> pp.Series[np.int64]: + return pdf.id + 1 + + kdf["d"] = kdf.koalas.transform_batch(plus_one) + self.assert_eq( + kdf, pd.DataFrame({"id": list(range(10)), "d": list(range(1, 11))}, columns=["id", "d"]) + ) + + kdf = pp.range(10) + + def plus_one(ser) -> pp.Series[np.int64]: + return ser + 1 + + kdf["d"] = kdf.id.koalas.transform_batch(plus_one) + self.assert_eq( + kdf, pd.DataFrame({"id": list(range(10)), "d": list(range(1, 11))}, columns=["id", "d"]) + ) + + def test_empty_timestamp(self): + pdf = pd.DataFrame( + { + "t": [ + datetime(2019, 1, 1, 0, 0, 0), + datetime(2019, 1, 2, 0, 0, 0), + datetime(2019, 1, 3, 0, 0, 0), + ] + }, + index=np.random.rand(3), + ) + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf[kdf["t"] != kdf["t"]], pdf[pdf["t"] != pdf["t"]]) + self.assert_eq(kdf[kdf["t"] != kdf["t"]].dtypes, pdf[pdf["t"] != pdf["t"]].dtypes) + + def test_to_spark(self): + kdf = pp.from_pandas(self.pdf) + + with self.assertRaisesRegex(ValueError, "'index_col' cannot be overlapped"): + kdf.to_spark(index_col="a") + + with self.assertRaisesRegex(ValueError, "length of index columns.*1.*3"): + kdf.to_spark(index_col=["x", "y", "z"]) + + def test_keys(self): + pdf = pd.DataFrame( + [[1, 2], [4, 5], [7, 8]], + index=["cobra", "viper", "sidewinder"], + columns=["max_speed", "shield"], + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.keys(), pdf.keys()) + + def test_quantile(self): + pdf, kdf = self.df_pair + + self.assert_eq(kdf.quantile(0.5), pdf.quantile(0.5)) + self.assert_eq(kdf.quantile([0.25, 0.5, 0.75]), pdf.quantile([0.25, 0.5, 0.75])) + + self.assert_eq(kdf.loc[[]].quantile(0.5), pdf.loc[[]].quantile(0.5)) + self.assert_eq( + kdf.loc[[]].quantile([0.25, 0.5, 0.75]), pdf.loc[[]].quantile([0.25, 0.5, 0.75]) + ) + + with self.assertRaisesRegex( + NotImplementedError, 'axis should be either 0 or "index" currently.' 
+ ): + kdf.quantile(0.5, axis=1) + with self.assertRaisesRegex(ValueError, "accuracy must be an integer; however"): + kdf.quantile(accuracy="a") + with self.assertRaisesRegex(ValueError, "q must be a float or an array of floats;"): + kdf.quantile(q="a") + with self.assertRaisesRegex(ValueError, "q must be a float or an array of floats;"): + kdf.quantile(q=["a"]) + + self.assert_eq(kdf.quantile(0.5, numeric_only=False), pdf.quantile(0.5, numeric_only=False)) + self.assert_eq( + kdf.quantile([0.25, 0.5, 0.75], numeric_only=False), + pdf.quantile([0.25, 0.5, 0.75], numeric_only=False), + ) + + # multi-index column + columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b")]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq(kdf.quantile(0.5), pdf.quantile(0.5)) + self.assert_eq(kdf.quantile([0.25, 0.5, 0.75]), pdf.quantile([0.25, 0.5, 0.75])) + + pdf = pd.DataFrame({"x": ["a", "b", "c"]}) + kdf = pp.from_pandas(pdf) + + if LooseVersion(pd.__version__) >= LooseVersion("1.0.0"): + self.assert_eq(kdf.quantile(0.5), pdf.quantile(0.5)) + self.assert_eq(kdf.quantile([0.25, 0.5, 0.75]), pdf.quantile([0.25, 0.5, 0.75])) + else: + self.assert_eq(kdf.quantile(0.5), pd.Series(name=0.5)) + self.assert_eq(kdf.quantile([0.25, 0.5, 0.75]), pd.DataFrame(index=[0.25, 0.5, 0.75])) + + with self.assertRaisesRegex(TypeError, "Could not convert object \\(string\\) to numeric"): + kdf.quantile(0.5, numeric_only=False) + with self.assertRaisesRegex(TypeError, "Could not convert object \\(string\\) to numeric"): + kdf.quantile([0.25, 0.5, 0.75], numeric_only=False) + + def test_pct_change(self): + pdf = pd.DataFrame( + {"a": [1, 2, 3, 2], "b": [4.0, 2.0, 3.0, 1.0], "c": [300, 200, 400, 200]}, + index=np.random.rand(4), + ) + pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")]) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.pct_change(2), pdf.pct_change(2), check_exact=False) + self.assert_eq(kdf.pct_change().sum(), pdf.pct_change().sum(), check_exact=False) + + def test_where(self): + kdf = pp.from_pandas(self.pdf) + + with self.assertRaisesRegex(ValueError, "type of cond must be a DataFrame or Series"): + kdf.where(1) + + def test_mask(self): + kdf = pp.from_pandas(self.pdf) + + with self.assertRaisesRegex(ValueError, "type of cond must be a DataFrame or Series"): + kdf.mask(1) + + def test_query(self): + pdf = pd.DataFrame({"A": range(1, 6), "B": range(10, 0, -2), "C": range(10, 5, -1)}) + kdf = pp.from_pandas(pdf) + + exprs = ("A > B", "A < C", "C == B") + for expr in exprs: + self.assert_eq(kdf.query(expr), pdf.query(expr)) + + # test `inplace=True` + for expr in exprs: + dummy_kdf = kdf.copy() + dummy_pdf = pdf.copy() + + pser = dummy_pdf.A + kser = dummy_kdf.A + dummy_pdf.query(expr, inplace=True) + dummy_kdf.query(expr, inplace=True) + + self.assert_eq(dummy_kdf, dummy_pdf) + self.assert_eq(kser, pser) + + # invalid values for `expr` + invalid_exprs = (1, 1.0, (exprs[0],), [exprs[0]]) + for expr in invalid_exprs: + with self.assertRaisesRegex( + ValueError, + "expr must be a string to be evaluated, {} given".format(type(expr).__name__), + ): + kdf.query(expr) + + # invalid values for `inplace` + invalid_inplaces = (1, 0, "True", "False") + for inplace in invalid_inplaces: + with self.assertRaisesRegex( + ValueError, + 'For argument "inplace" expected type bool, received type {}.'.format( + type(inplace).__name__ + ), + ): + kdf.query("a < b", inplace=inplace) + + # doesn't support for MultiIndex columns + columns = pd.MultiIndex.from_tuples([("A", "Z"), ("B", 
"X"), ("C", "C")]) + kdf.columns = columns + with self.assertRaisesRegex(ValueError, "Doesn't support for MultiIndex columns"): + kdf.query("('A', 'Z') > ('B', 'X')") + + def test_take(self): + pdf = pd.DataFrame( + {"A": range(0, 50000), "B": range(100000, 0, -2), "C": range(100000, 50000, -1)} + ) + kdf = pp.from_pandas(pdf) + + # axis=0 (default) + self.assert_eq(kdf.take([1, 2]).sort_index(), pdf.take([1, 2]).sort_index()) + self.assert_eq(kdf.take([-1, -2]).sort_index(), pdf.take([-1, -2]).sort_index()) + self.assert_eq( + kdf.take(range(100, 110)).sort_index(), pdf.take(range(100, 110)).sort_index() + ) + self.assert_eq( + kdf.take(range(-110, -100)).sort_index(), pdf.take(range(-110, -100)).sort_index() + ) + self.assert_eq( + kdf.take([10, 100, 1000, 10000]).sort_index(), + pdf.take([10, 100, 1000, 10000]).sort_index(), + ) + self.assert_eq( + kdf.take([-10, -100, -1000, -10000]).sort_index(), + pdf.take([-10, -100, -1000, -10000]).sort_index(), + ) + + # axis=1 + self.assert_eq(kdf.take([1, 2], axis=1).sort_index(), pdf.take([1, 2], axis=1).sort_index()) + self.assert_eq( + kdf.take([-1, -2], axis=1).sort_index(), pdf.take([-1, -2], axis=1).sort_index() + ) + self.assert_eq( + kdf.take(range(1, 3), axis=1).sort_index(), pdf.take(range(1, 3), axis=1).sort_index(), + ) + self.assert_eq( + kdf.take(range(-1, -3), axis=1).sort_index(), + pdf.take(range(-1, -3), axis=1).sort_index(), + ) + self.assert_eq( + kdf.take([2, 1], axis=1).sort_index(), pdf.take([2, 1], axis=1).sort_index(), + ) + self.assert_eq( + kdf.take([-1, -2], axis=1).sort_index(), pdf.take([-1, -2], axis=1).sort_index(), + ) + + # MultiIndex columns + columns = pd.MultiIndex.from_tuples([("A", "Z"), ("B", "X"), ("C", "C")]) + kdf.columns = columns + pdf.columns = columns + + # MultiIndex columns with axis=0 (default) + self.assert_eq(kdf.take([1, 2]).sort_index(), pdf.take([1, 2]).sort_index()) + self.assert_eq(kdf.take([-1, -2]).sort_index(), pdf.take([-1, -2]).sort_index()) + self.assert_eq( + kdf.take(range(100, 110)).sort_index(), pdf.take(range(100, 110)).sort_index() + ) + self.assert_eq( + kdf.take(range(-110, -100)).sort_index(), pdf.take(range(-110, -100)).sort_index() + ) + self.assert_eq( + kdf.take([10, 100, 1000, 10000]).sort_index(), + pdf.take([10, 100, 1000, 10000]).sort_index(), + ) + self.assert_eq( + kdf.take([-10, -100, -1000, -10000]).sort_index(), + pdf.take([-10, -100, -1000, -10000]).sort_index(), + ) + + # axis=1 + self.assert_eq(kdf.take([1, 2], axis=1).sort_index(), pdf.take([1, 2], axis=1).sort_index()) + self.assert_eq( + kdf.take([-1, -2], axis=1).sort_index(), pdf.take([-1, -2], axis=1).sort_index() + ) + self.assert_eq( + kdf.take(range(1, 3), axis=1).sort_index(), pdf.take(range(1, 3), axis=1).sort_index(), + ) + self.assert_eq( + kdf.take(range(-1, -3), axis=1).sort_index(), + pdf.take(range(-1, -3), axis=1).sort_index(), + ) + self.assert_eq( + kdf.take([2, 1], axis=1).sort_index(), pdf.take([2, 1], axis=1).sort_index(), + ) + self.assert_eq( + kdf.take([-1, -2], axis=1).sort_index(), pdf.take([-1, -2], axis=1).sort_index(), + ) + + # Checking the type of indices. 
+ self.assertRaises(ValueError, lambda: kdf.take(1)) + self.assertRaises(ValueError, lambda: kdf.take("1")) + self.assertRaises(ValueError, lambda: kdf.take({1, 2})) + self.assertRaises(ValueError, lambda: kdf.take({1: None, 2: None})) + + def test_axes(self): + pdf = self.pdf + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.axes, kdf.axes) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b")]) + pdf.columns = columns + kdf.columns = columns + self.assert_eq(pdf.axes, kdf.axes) + + def test_udt(self): + sparse_values = {0: 0.1, 1: 1.1} + sparse_vector = SparseVector(len(sparse_values), sparse_values) + pdf = pd.DataFrame({"a": [sparse_vector], "b": [10]}) + + if LooseVersion(pyspark.__version__) < LooseVersion("2.4"): + with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf, pdf) + else: + kdf = pp.from_pandas(pdf) + self.assert_eq(kdf, pdf) + + def test_eval(self): + pdf = pd.DataFrame({"A": range(1, 6), "B": range(10, 0, -2)}) + kdf = pp.from_pandas(pdf) + + # operation between columns (returns Series) + self.assert_eq(pdf.eval("A + B"), kdf.eval("A + B")) + self.assert_eq(pdf.eval("A + A"), kdf.eval("A + A")) + # assignment (returns DataFrame) + self.assert_eq(pdf.eval("C = A + B"), kdf.eval("C = A + B")) + self.assert_eq(pdf.eval("A = A + A"), kdf.eval("A = A + A")) + # operation between scalars (returns scalar) + self.assert_eq(pdf.eval("1 + 1"), kdf.eval("1 + 1")) + # complicated operations with assignment + self.assert_eq( + pdf.eval("B = A + B // (100 + 200) * (500 - B) - 10.5"), + kdf.eval("B = A + B // (100 + 200) * (500 - B) - 10.5"), + ) + + # inplace=True (only support for assignment) + pdf.eval("C = A + B", inplace=True) + kdf.eval("C = A + B", inplace=True) + self.assert_eq(pdf, kdf) + pser = pdf.A + kser = kdf.A + pdf.eval("A = B + C", inplace=True) + kdf.eval("A = B + C", inplace=True) + self.assert_eq(pdf, kdf) + self.assert_eq(pser, kser) + + # doesn't support for multi-index columns + columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b"), ("z", "c")]) + kdf.columns = columns + self.assertRaises(ValueError, lambda: kdf.eval("x.a + y.b")) + + @unittest.skipIf(not have_tabulate, "tabulate not installed") + def test_to_markdown(self): + pdf = pd.DataFrame(data={"animal_1": ["elk", "pig"], "animal_2": ["dog", "quetzal"]}) + kdf = pp.from_pandas(pdf) + + # `to_markdown()` is supported in pandas >= 1.0.0 since it's newly added in pandas 1.0.0. 
+ if LooseVersion(pd.__version__) < LooseVersion("1.0.0"): + self.assertRaises(NotImplementedError, lambda: kdf.to_markdown()) + else: + self.assert_eq(pdf.to_markdown(), kdf.to_markdown()) + + def test_cache(self): + pdf = pd.DataFrame( + [(0.2, 0.3), (0.0, 0.6), (0.6, 0.0), (0.2, 0.1)], columns=["dogs", "cats"] + ) + kdf = pp.from_pandas(pdf) + + with kdf.cache() as cached_df: + self.assert_eq(isinstance(cached_df, CachedDataFrame), True) + self.assert_eq( + repr(cached_df.storage_level), repr(StorageLevel(True, True, False, True)) + ) + + def test_persist(self): + pdf = pd.DataFrame( + [(0.2, 0.3), (0.0, 0.6), (0.6, 0.0), (0.2, 0.1)], columns=["dogs", "cats"] + ) + kdf = pp.from_pandas(pdf) + storage_levels = [ + StorageLevel.DISK_ONLY, + StorageLevel.MEMORY_AND_DISK, + StorageLevel.MEMORY_ONLY, + StorageLevel.OFF_HEAP, + ] + + for storage_level in storage_levels: + with kdf.persist(storage_level) as cached_df: + self.assert_eq(isinstance(cached_df, CachedDataFrame), True) + self.assert_eq(repr(cached_df.storage_level), repr(storage_level)) + + self.assertRaises(TypeError, lambda: kdf.persist("DISK_ONLY")) + + def test_squeeze(self): + axises = [None, 0, 1, "rows", "index", "columns"] + + # Multiple columns + pdf = pd.DataFrame([[1, 2], [3, 4]], columns=["a", "b"], index=["x", "y"]) + kdf = pp.from_pandas(pdf) + for axis in axises: + self.assert_eq(pdf.squeeze(axis), kdf.squeeze(axis)) + # Multiple columns with MultiIndex columns + columns = pd.MultiIndex.from_tuples([("A", "Z"), ("B", "X")]) + pdf.columns = columns + kdf.columns = columns + for axis in axises: + self.assert_eq(pdf.squeeze(axis), kdf.squeeze(axis)) + + # Single column with single value + pdf = pd.DataFrame([[1]], columns=["a"], index=["x"]) + kdf = pp.from_pandas(pdf) + for axis in axises: + self.assert_eq(pdf.squeeze(axis), kdf.squeeze(axis)) + # Single column with single value with MultiIndex column + columns = pd.MultiIndex.from_tuples([("A", "Z")]) + pdf.columns = columns + kdf.columns = columns + for axis in axises: + self.assert_eq(pdf.squeeze(axis), kdf.squeeze(axis)) + + # Single column with multiple values + pdf = pd.DataFrame([1, 2, 3, 4], columns=["a"]) + kdf = pp.from_pandas(pdf) + for axis in axises: + self.assert_eq(pdf.squeeze(axis), kdf.squeeze(axis)) + # Single column with multiple values with MultiIndex column + pdf.columns = columns + kdf.columns = columns + for axis in axises: + self.assert_eq(pdf.squeeze(axis), kdf.squeeze(axis)) + + def test_rfloordiv(self): + pdf = pd.DataFrame( + {"angles": [0, 3, 4], "degrees": [360, 180, 360]}, + index=["circle", "triangle", "rectangle"], + columns=["angles", "degrees"], + ) + kdf = pp.from_pandas(pdf) + + if LooseVersion(pd.__version__) < LooseVersion("1.0.0") and LooseVersion( + pd.__version__ + ) >= LooseVersion("0.24.0"): + expected_result = pd.DataFrame( + {"angles": [np.inf, 3.0, 2.0], "degrees": [0.0, 0.0, 0.0]}, + index=["circle", "triangle", "rectangle"], + columns=["angles", "degrees"], + ) + else: + expected_result = pdf.rfloordiv(10) + + self.assert_eq(kdf.rfloordiv(10), expected_result) + + def test_truncate(self): + pdf1 = pd.DataFrame( + { + "A": ["a", "b", "c", "d", "e", "f", "g"], + "B": ["h", "i", "j", "k", "l", "m", "n"], + "C": ["o", "p", "q", "r", "s", "t", "u"], + }, + index=[-500, -20, -1, 0, 400, 550, 1000], + ) + kdf1 = pp.from_pandas(pdf1) + pdf2 = pd.DataFrame( + { + "A": ["a", "b", "c", "d", "e", "f", "g"], + "B": ["h", "i", "j", "k", "l", "m", "n"], + "C": ["o", "p", "q", "r", "s", "t", "u"], + }, + index=[1000, 550, 400, 0, -1, 
-20, -500], + ) + kdf2 = pp.from_pandas(pdf2) + + self.assert_eq(kdf1.truncate(), pdf1.truncate()) + self.assert_eq(kdf1.truncate(before=-20), pdf1.truncate(before=-20)) + self.assert_eq(kdf1.truncate(after=400), pdf1.truncate(after=400)) + self.assert_eq(kdf1.truncate(copy=False), pdf1.truncate(copy=False)) + self.assert_eq(kdf1.truncate(-20, 400, copy=False), pdf1.truncate(-20, 400, copy=False)) + # The bug for these tests has been fixed in pandas 1.1.0. + if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): + self.assert_eq(kdf2.truncate(0, 550), pdf2.truncate(0, 550)) + self.assert_eq(kdf2.truncate(0, 550, copy=False), pdf2.truncate(0, 550, copy=False)) + else: + expected_kdf = pp.DataFrame( + {"A": ["b", "c", "d"], "B": ["i", "j", "k"], "C": ["p", "q", "r"]}, + index=[550, 400, 0], + ) + self.assert_eq(kdf2.truncate(0, 550), expected_kdf) + self.assert_eq(kdf2.truncate(0, 550, copy=False), expected_kdf) + + # axis = 1 + self.assert_eq(kdf1.truncate(axis=1), pdf1.truncate(axis=1)) + self.assert_eq(kdf1.truncate(before="B", axis=1), pdf1.truncate(before="B", axis=1)) + self.assert_eq(kdf1.truncate(after="A", axis=1), pdf1.truncate(after="A", axis=1)) + self.assert_eq(kdf1.truncate(copy=False, axis=1), pdf1.truncate(copy=False, axis=1)) + self.assert_eq(kdf2.truncate("B", "C", axis=1), pdf2.truncate("B", "C", axis=1)) + self.assert_eq( + kdf1.truncate("B", "C", copy=False, axis=1), + pdf1.truncate("B", "C", copy=False, axis=1), + ) + + # MultiIndex columns + columns = pd.MultiIndex.from_tuples([("A", "Z"), ("B", "X"), ("C", "Z")]) + pdf1.columns = columns + kdf1.columns = columns + pdf2.columns = columns + kdf2.columns = columns + + self.assert_eq(kdf1.truncate(), pdf1.truncate()) + self.assert_eq(kdf1.truncate(before=-20), pdf1.truncate(before=-20)) + self.assert_eq(kdf1.truncate(after=400), pdf1.truncate(after=400)) + self.assert_eq(kdf1.truncate(copy=False), pdf1.truncate(copy=False)) + self.assert_eq(kdf1.truncate(-20, 400, copy=False), pdf1.truncate(-20, 400, copy=False)) + # The bug for these tests has been fixed in pandas 1.1.0. 
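+ # pandas < 1.1.0 mishandles truncate() on a descending index, so the result is checked against a hand-built expected frame instead of pdf2.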
+ if LooseVersion(pd.__version__) >= LooseVersion("1.1.0"): + self.assert_eq(kdf2.truncate(0, 550), pdf2.truncate(0, 550)) + self.assert_eq(kdf2.truncate(0, 550, copy=False), pdf2.truncate(0, 550, copy=False)) + else: + expected_kdf.columns = columns + self.assert_eq(kdf2.truncate(0, 550), expected_kdf) + self.assert_eq(kdf2.truncate(0, 550, copy=False), expected_kdf) + # axis = 1 + self.assert_eq(kdf1.truncate(axis=1), pdf1.truncate(axis=1)) + self.assert_eq(kdf1.truncate(before="B", axis=1), pdf1.truncate(before="B", axis=1)) + self.assert_eq(kdf1.truncate(after="A", axis=1), pdf1.truncate(after="A", axis=1)) + self.assert_eq(kdf1.truncate(copy=False, axis=1), pdf1.truncate(copy=False, axis=1)) + self.assert_eq(kdf2.truncate("B", "C", axis=1), pdf2.truncate("B", "C", axis=1)) + self.assert_eq( + kdf1.truncate("B", "C", copy=False, axis=1), + pdf1.truncate("B", "C", copy=False, axis=1), + ) + + # Exceptions + kdf = pp.DataFrame( + { + "A": ["a", "b", "c", "d", "e", "f", "g"], + "B": ["h", "i", "j", "k", "l", "m", "n"], + "C": ["o", "p", "q", "r", "s", "t", "u"], + }, + index=[-500, 100, 400, 0, -1, 550, -20], + ) + msg = "truncate requires a sorted index" + with self.assertRaisesRegex(ValueError, msg): + kdf.truncate() + + kdf = pp.DataFrame( + { + "A": ["a", "b", "c", "d", "e", "f", "g"], + "B": ["h", "i", "j", "k", "l", "m", "n"], + "C": ["o", "p", "q", "r", "s", "t", "u"], + }, + index=[-500, -20, -1, 0, 400, 550, 1000], + ) + msg = "Truncate: -20 must be after 400" + with self.assertRaisesRegex(ValueError, msg): + kdf.truncate(400, -20) + msg = "Truncate: B must be after C" + with self.assertRaisesRegex(ValueError, msg): + kdf.truncate("C", "B", axis=1) + + def test_explode(self): + pdf = pd.DataFrame({"A": [[-1.0, np.nan], [0.0, np.inf], [1.0, -np.inf]], "B": 1}) + pdf.index.name = "index" + pdf.columns.name = "columns" + kdf = pp.from_pandas(pdf) + + if LooseVersion(pd.__version__) >= LooseVersion("0.25.0"): + expected_result1 = pdf.explode("A") + expected_result2 = pdf.explode("B") + else: + expected_result1 = pd.DataFrame( + {"A": [-1, np.nan, 0, np.inf, 1, -np.inf], "B": [1, 1, 1, 1, 1, 1]}, + index=pd.Index([0, 0, 1, 1, 2, 2]), + ) + expected_result1.index.name = "index" + expected_result1.columns.name = "columns" + expected_result2 = pdf + + self.assert_eq(kdf.explode("A"), expected_result1, almost=True) + self.assert_eq(repr(kdf.explode("B")), repr(expected_result2)) + self.assert_eq(kdf.explode("A").index.name, expected_result1.index.name) + self.assert_eq(kdf.explode("A").columns.name, expected_result1.columns.name) + + self.assertRaises(ValueError, lambda: kdf.explode(["A", "B"])) + + # MultiIndex + midx = pd.MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "c")], names=["index1", "index2"] + ) + pdf.index = midx + kdf = pp.from_pandas(pdf) + + if LooseVersion(pd.__version__) >= LooseVersion("0.25.0"): + expected_result1 = pdf.explode("A") + expected_result2 = pdf.explode("B") + else: + midx = pd.MultiIndex.from_tuples( + [("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("y", "c"), ("y", "c")], + names=["index1", "index2"], + ) + expected_result1.index = midx + expected_result2 = pdf + + self.assert_eq(kdf.explode("A"), expected_result1, almost=True) + self.assert_eq(repr(kdf.explode("B")), repr(expected_result2)) + self.assert_eq(kdf.explode("A").index.names, expected_result1.index.names) + self.assert_eq(kdf.explode("A").columns.name, expected_result1.columns.name) + + self.assertRaises(ValueError, lambda: kdf.explode(["A", "B"])) + + # MultiIndex columns + 
columns = pd.MultiIndex.from_tuples([("A", "Z"), ("B", "X")], names=["column1", "column2"]) + pdf.columns = columns + kdf.columns = columns + + if LooseVersion(pd.__version__) >= LooseVersion("0.25.0"): + expected_result1 = pdf.explode(("A", "Z")) + expected_result2 = pdf.explode(("B", "X")) + expected_result3 = pdf.A.explode("Z") + else: + expected_result1.columns = columns + expected_result2 = pdf + expected_result3 = pd.DataFrame({"Z": [-1, np.nan, 0, np.inf, 1, -np.inf]}, index=midx) + expected_result3.index.name = "index" + expected_result3.columns.name = "column2" + + self.assert_eq(kdf.explode(("A", "Z")), expected_result1, almost=True) + self.assert_eq(repr(kdf.explode(("B", "X"))), repr(expected_result2)) + self.assert_eq(kdf.explode(("A", "Z")).index.names, expected_result1.index.names) + self.assert_eq(kdf.explode(("A", "Z")).columns.names, expected_result1.columns.names) + + self.assert_eq(kdf.A.explode("Z"), expected_result3, almost=True) + + self.assertRaises(ValueError, lambda: kdf.explode(["A", "B"])) + self.assertRaises(ValueError, lambda: kdf.explode("A")) + + def test_spark_schema(self): + kdf = pp.DataFrame( + { + "a": list("abc"), + "b": list(range(1, 4)), + "c": np.arange(3, 6).astype("i1"), + "d": np.arange(4.0, 7.0, dtype="float64"), + "e": [True, False, True], + "f": pd.date_range("20130101", periods=3), + }, + columns=["a", "b", "c", "d", "e", "f"], + ) + self.assertEqual(kdf.spark_schema(), kdf.spark.schema()) + self.assertEqual(kdf.spark_schema("index"), kdf.spark.schema("index")) + + def test_print_schema(self): + kdf = pp.DataFrame( + {"a": list("abc"), "b": list(range(1, 4)), "c": np.arange(3, 6).astype("i1")}, + columns=["a", "b", "c"], + ) + + prev = sys.stdout + try: + out = StringIO() + sys.stdout = out + kdf.print_schema() + actual = out.getvalue().strip() + + out = StringIO() + sys.stdout = out + kdf.spark.print_schema() + expected = out.getvalue().strip() + + self.assertEqual(actual, expected) + finally: + sys.stdout = prev + + def test_explain_hint(self): + kdf1 = pp.DataFrame( + {"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]}, columns=["lkey", "value"] + ) + kdf2 = pp.DataFrame( + {"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]}, columns=["rkey", "value"] + ) + merged = kdf1.merge(kdf2.hint("broadcast"), left_on="lkey", right_on="rkey") + prev = sys.stdout + try: + out = StringIO() + sys.stdout = out + merged.explain() + actual = out.getvalue().strip() + + out = StringIO() + sys.stdout = out + merged.spark.explain() + expected = out.getvalue().strip() + + self.assertEqual(actual, expected) + finally: + sys.stdout = prev + + def test_mad(self): + pdf = pd.DataFrame( + { + "A": [1, 2, None, 4, np.nan], + "B": [-0.1, 0.2, -0.3, np.nan, 0.5], + "C": ["a", "b", "c", "d", "e"], + } + ) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.mad(), pdf.mad()) + self.assert_eq(kdf.mad(axis=1), pdf.mad(axis=1)) + + with self.assertRaises(ValueError): + kdf.mad(axis=2) + + # MultiIndex columns + columns = pd.MultiIndex.from_tuples([("A", "X"), ("A", "Y"), ("A", "Z")]) + pdf.columns = columns + kdf.columns = columns + + self.assert_eq(kdf.mad(), pdf.mad()) + self.assert_eq(kdf.mad(axis=1), pdf.mad(axis=1)) + + pdf = pd.DataFrame({"A": [True, True, False, False], "B": [True, False, False, True]}) + kdf = pp.from_pandas(pdf) + + self.assert_eq(kdf.mad(), pdf.mad()) + self.assert_eq(kdf.mad(axis=1), pdf.mad(axis=1)) + + def test_abs(self): + pdf = pd.DataFrame({"a": [-2, -1, 0, 1]}) + kdf = pp.from_pandas(pdf) + + self.assert_eq(abs(kdf), 
abs(pdf)) + self.assert_eq(np.abs(kdf), np.abs(pdf)) + + def test_iteritems(self): + pdf = pd.DataFrame( + {"species": ["bear", "bear", "marsupial"], "population": [1864, 22000, 80000]}, + index=["panda", "polar", "koala"], + columns=["species", "population"], + ) + kdf = pp.from_pandas(pdf) + + for (p_name, p_items), (k_name, k_items) in zip(pdf.iteritems(), kdf.iteritems()): + self.assert_eq(p_name, k_name) + self.assert_eq(p_items, k_items) + + @unittest.skipIf( + LooseVersion(pyspark.__version__) < LooseVersion("3.0"), + "tail won't work properly with PySpark<3.0", + ) + def test_tail(self): + pdf = pd.DataFrame({"x": range(1000)}) + kdf = pp.from_pandas(pdf) + + self.assert_eq(pdf.tail(), kdf.tail()) + self.assert_eq(pdf.tail(10), kdf.tail(10)) + self.assert_eq(pdf.tail(-990), kdf.tail(-990)) + self.assert_eq(pdf.tail(0), kdf.tail(0)) + self.assert_eq(pdf.tail(-1001), kdf.tail(-1001)) + self.assert_eq(pdf.tail(1001), kdf.tail(1001)) + self.assert_eq((pdf + 1).tail(), (kdf + 1).tail()) + self.assert_eq((pdf + 1).tail(10), (kdf + 1).tail(10)) + self.assert_eq((pdf + 1).tail(-990), (kdf + 1).tail(-990)) + self.assert_eq((pdf + 1).tail(0), (kdf + 1).tail(0)) + self.assert_eq((pdf + 1).tail(-1001), (kdf + 1).tail(-1001)) + self.assert_eq((pdf + 1).tail(1001), (kdf + 1).tail(1001)) + with self.assertRaisesRegex(TypeError, "bad operand type for unary -: 'str'"): + kdf.tail("10") + + @unittest.skipIf( + LooseVersion(pyspark.__version__) < LooseVersion("3.0"), + "last_valid_index won't work properly with PySpark<3.0", + ) + def test_last_valid_index(self): + pdf = pd.DataFrame( + {"a": [1, 2, 3, None], "b": [1.0, 2.0, 3.0, None], "c": [100, 200, 400, None]}, + index=["Q", "W", "E", "R"], + ) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.last_valid_index(), kdf.last_valid_index()) + self.assert_eq(pdf[[]].last_valid_index(), kdf[[]].last_valid_index()) + + # MultiIndex columns + pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")]) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.last_valid_index(), kdf.last_valid_index()) + + # Empty DataFrame + pdf = pd.Series([]).to_frame() + kdf = pp.Series([]).to_frame() + self.assert_eq(pdf.last_valid_index(), kdf.last_valid_index()) + + def test_last(self): + index = pd.date_range("2018-04-09", periods=4, freq="2D") + pdf = pd.DataFrame([1, 2, 3, 4], index=index) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.last("1D"), kdf.last("1D")) + self.assert_eq(pdf.last(DateOffset(days=1)), kdf.last(DateOffset(days=1))) + with self.assertRaisesRegex(TypeError, "'last' only supports a DatetimeIndex"): + pp.DataFrame([1, 2, 3, 4]).last("1D") + + def test_first(self): + index = pd.date_range("2018-04-09", periods=4, freq="2D") + pdf = pd.DataFrame([1, 2, 3, 4], index=index) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.first("1D"), kdf.first("1D")) + self.assert_eq(pdf.first(DateOffset(days=1)), kdf.first(DateOffset(days=1))) + with self.assertRaisesRegex(TypeError, "'first' only supports a DatetimeIndex"): + pp.DataFrame([1, 2, 3, 4]).first("1D") + + def test_first_valid_index(self): + pdf = pd.DataFrame( + {"a": [None, 2, 3, 2], "b": [None, 2.0, 3.0, 1.0], "c": [None, 200, 400, 200]}, + index=["Q", "W", "E", "R"], + ) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.first_valid_index(), kdf.first_valid_index()) + self.assert_eq(pdf[[]].first_valid_index(), kdf[[]].first_valid_index()) + + # MultiIndex columns + pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")]) + kdf = pp.from_pandas(pdf) + 
self.assert_eq(pdf.first_valid_index(), kdf.first_valid_index()) + + # Empty DataFrame + pdf = pd.Series([]).to_frame() + kdf = pp.Series([]).to_frame() + self.assert_eq(pdf.first_valid_index(), kdf.first_valid_index()) + + pdf = pd.DataFrame( + {"a": [None, 2, 3, 2], "b": [None, 2.0, 3.0, 1.0], "c": [None, 200, 400, 200]}, + index=[ + datetime(2021, 1, 1), + datetime(2021, 2, 1), + datetime(2021, 3, 1), + datetime(2021, 4, 1), + ], + ) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.first_valid_index(), kdf.first_valid_index()) + + def test_product(self): + pdf = pd.DataFrame( + {"A": [1, 2, 3, 4, 5], "B": [10, 20, 30, 40, 50], "C": ["a", "b", "c", "d", "e"]} + ) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.prod(), kdf.prod().sort_index()) + + # Named columns + pdf.columns.name = "Koalas" + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.prod(), kdf.prod().sort_index()) + + # MultiIndex columns + pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")]) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.prod(), kdf.prod().sort_index()) + + # Named MultiIndex columns + pdf.columns.names = ["Hello", "Koalas"] + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.prod(), kdf.prod().sort_index()) + + # No numeric columns + pdf = pd.DataFrame({"key": ["a", "b", "c"], "val": ["x", "y", "z"]}) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.prod(), kdf.prod().sort_index()) + + # No numeric named columns + pdf.columns.name = "Koalas" + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.prod(), kdf.prod().sort_index(), almost=True) + + # No numeric MultiIndex columns + pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y")]) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.prod(), kdf.prod().sort_index(), almost=True) + + # No numeric named MultiIndex columns + pdf.columns.names = ["Hello", "Koalas"] + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.prod(), kdf.prod().sort_index(), almost=True) + + # All NaN columns + pdf = pd.DataFrame( + { + "A": [np.nan, np.nan, np.nan, np.nan, np.nan], + "B": [10, 20, 30, 40, 50], + "C": ["a", "b", "c", "d", "e"], + } + ) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.prod(), kdf.prod().sort_index(), check_exact=False) + + # All NaN named columns + pdf.columns.name = "Koalas" + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.prod(), kdf.prod().sort_index(), check_exact=False) + + # All NaN MultiIndex columns + pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")]) + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.prod(), kdf.prod().sort_index(), check_exact=False) + + # All NaN named MultiIndex columns + pdf.columns.names = ["Hello", "Koalas"] + kdf = pp.from_pandas(pdf) + self.assert_eq(pdf.prod(), kdf.prod().sort_index(), check_exact=False) + + def test_from_dict(self): + data = {"row_1": [3, 2, 1, 0], "row_2": [10, 20, 30, 40]} + pdf = pd.DataFrame.from_dict(data) + kdf = pp.DataFrame.from_dict(data) + self.assert_eq(pdf, kdf) + + pdf = pd.DataFrame.from_dict(data, dtype="int8") + kdf = pp.DataFrame.from_dict(data, dtype="int8") + self.assert_eq(pdf, kdf) + + pdf = pd.DataFrame.from_dict(data, orient="index", columns=["A", "B", "C", "D"]) + kdf = pp.DataFrame.from_dict(data, orient="index", columns=["A", "B", "C", "D"]) + self.assert_eq(pdf, kdf) + + def test_pad(self): + pdf = pd.DataFrame( + { + "A": [None, 3, None, None], + "B": [2, 4, None, 3], + "C": [None, None, None, 1], + "D": [0, 1, 5, 4], + }, + columns=["A", "B", "C", "D"], + ) + kdf = pp.from_pandas(pdf) + + if LooseVersion(pd.__version__) >= 
LooseVersion("1.1"): + self.assert_eq(pdf.pad(), kdf.pad()) + + # Test `inplace=True` + pdf.pad(inplace=True) + kdf.pad(inplace=True) + self.assert_eq(pdf, kdf) + else: + expected = pp.DataFrame( + { + "A": [None, 3, 3, 3], + "B": [2.0, 4.0, 4.0, 3.0], + "C": [None, None, None, 1], + "D": [0, 1, 5, 4], + }, + columns=["A", "B", "C", "D"], + ) + self.assert_eq(expected, kdf.pad()) + + # Test `inplace=True` + kdf.pad(inplace=True) + self.assert_eq(expected, kdf) + + def test_backfill(self): + pdf = pd.DataFrame( + { + "A": [None, 3, None, None], + "B": [2, 4, None, 3], + "C": [None, None, None, 1], + "D": [0, 1, 5, 4], + }, + columns=["A", "B", "C", "D"], + ) + kdf = pp.from_pandas(pdf) + + if LooseVersion(pd.__version__) >= LooseVersion("1.1"): + self.assert_eq(pdf.backfill(), kdf.backfill()) + + # Test `inplace=True` + pdf.backfill(inplace=True) + kdf.backfill(inplace=True) + self.assert_eq(pdf, kdf) + else: + expected = pp.DataFrame( + { + "A": [3.0, 3.0, None, None], + "B": [2.0, 4.0, 3.0, 3.0], + "C": [1.0, 1.0, 1.0, 1.0], + "D": [0, 1, 5, 4], + }, + columns=["A", "B", "C", "D"], + ) + self.assert_eq(expected, kdf.backfill()) + + # Test `inplace=True` + kdf.backfill(inplace=True) + self.assert_eq(expected, kdf) + + def test_align(self): + pdf1 = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}, index=[10, 20, 30]) + kdf1 = pp.from_pandas(pdf1) + + for join in ["outer", "inner", "left", "right"]: + for axis in [None, 0, 1]: + kdf_l, kdf_r = kdf1.align(kdf1[["b"]], join=join, axis=axis) + pdf_l, pdf_r = pdf1.align(pdf1[["b"]], join=join, axis=axis) + self.assert_eq(kdf_l, pdf_l) + self.assert_eq(kdf_r, pdf_r) + + kdf_l, kdf_r = kdf1[["a"]].align(kdf1[["b", "a"]], join=join, axis=axis) + pdf_l, pdf_r = pdf1[["a"]].align(pdf1[["b", "a"]], join=join, axis=axis) + self.assert_eq(kdf_l, pdf_l) + self.assert_eq(kdf_r, pdf_r) + + kdf_l, kdf_r = kdf1[["b", "a"]].align(kdf1[["a"]], join=join, axis=axis) + pdf_l, pdf_r = pdf1[["b", "a"]].align(pdf1[["a"]], join=join, axis=axis) + self.assert_eq(kdf_l, pdf_l) + self.assert_eq(kdf_r, pdf_r) + + kdf_l, kdf_r = kdf1.align(kdf1["b"], axis=0) + pdf_l, pdf_r = pdf1.align(pdf1["b"], axis=0) + self.assert_eq(kdf_l, pdf_l) + self.assert_eq(kdf_r, pdf_r) + + kdf_l, kser_b = kdf1[["a"]].align(kdf1["b"], axis=0) + pdf_l, pser_b = pdf1[["a"]].align(pdf1["b"], axis=0) + self.assert_eq(kdf_l, pdf_l) + self.assert_eq(kser_b, pser_b) + + self.assertRaises(ValueError, lambda: kdf1.align(kdf1, join="unknown")) + self.assertRaises(ValueError, lambda: kdf1.align(kdf1["b"])) + self.assertRaises(NotImplementedError, lambda: kdf1.align(kdf1["b"], axis=1)) + + pdf2 = pd.DataFrame({"a": [4, 5, 6], "d": ["d", "e", "f"]}, index=[10, 11, 12]) + kdf2 = pp.from_pandas(pdf2) + + for join in ["outer", "inner", "left", "right"]: + kdf_l, kdf_r = kdf1.align(kdf2, join=join, axis=1) + pdf_l, pdf_r = pdf1.align(pdf2, join=join, axis=1) + self.assert_eq(kdf_l.sort_index(), pdf_l.sort_index()) + self.assert_eq(kdf_r.sort_index(), pdf_r.sort_index()) + + def test_between_time(self): + idx = pd.date_range("2018-04-09", periods=4, freq="1D20min") + pdf = pd.DataFrame({"A": [1, 2, 3, 4]}, index=idx) + kdf = pp.from_pandas(pdf) + self.assert_eq( + pdf.between_time("0:15", "0:45").sort_index(), + kdf.between_time("0:15", "0:45").sort_index(), + ) + + pdf.index.name = "ts" + kdf = pp.from_pandas(pdf) + self.assert_eq( + pdf.between_time("0:15", "0:45").sort_index(), + kdf.between_time("0:15", "0:45").sort_index(), + ) + + # Column label is 'index' + pdf.columns = pd.Index(["index"]) + kdf = 
pp.from_pandas(pdf) + self.assert_eq( + pdf.between_time("0:15", "0:45").sort_index(), + kdf.between_time("0:15", "0:45").sort_index(), + ) + + # Both index name and column label are 'index' + pdf.index.name = "index" + kdf = pp.from_pandas(pdf) + self.assert_eq( + pdf.between_time("0:15", "0:45").sort_index(), + kdf.between_time("0:15", "0:45").sort_index(), + ) + + # Index name is 'index', column label is ('X', 'A') + pdf.columns = pd.MultiIndex.from_arrays([["X"], ["A"]]) + kdf = pp.from_pandas(pdf) + self.assert_eq( + pdf.between_time("0:15", "0:45").sort_index(), + kdf.between_time("0:15", "0:45").sort_index(), + ) + + with self.assertRaisesRegex( + NotImplementedError, "between_time currently only works for axis=0" + ): + kdf.between_time("0:15", "0:45", axis=1) + + kdf = pp.DataFrame({"A": [1, 2, 3, 4]}) + with self.assertRaisesRegex(TypeError, "Index must be DatetimeIndex"): + kdf.between_time("0:15", "0:45") + + def test_at_time(self): + idx = pd.date_range("2018-04-09", periods=4, freq="1D20min") + pdf = pd.DataFrame({"A": [1, 2, 3, 4]}, index=idx) + kdf = pp.from_pandas(pdf) + kdf.at_time("0:20") + self.assert_eq( + pdf.at_time("0:20").sort_index(), kdf.at_time("0:20").sort_index(), + ) + + # Index name is 'ts' + pdf.index.name = "ts" + kdf = pp.from_pandas(pdf) + self.assert_eq( + pdf.at_time("0:20").sort_index(), kdf.at_time("0:20").sort_index(), + ) + + # Index name is 'ts', column label is 'index' + pdf.columns = pd.Index(["index"]) + kdf = pp.from_pandas(pdf) + self.assert_eq( + pdf.at_time("0:40").sort_index(), kdf.at_time("0:40").sort_index(), + ) + + # Both index name and column label are 'index' + pdf.index.name = "index" + kdf = pp.from_pandas(pdf) + self.assert_eq( + pdf.at_time("0:40").sort_index(), kdf.at_time("0:40").sort_index(), + ) + + # Index name is 'index', column label is ('X', 'A') + pdf.columns = pd.MultiIndex.from_arrays([["X"], ["A"]]) + kdf = pp.from_pandas(pdf) + self.assert_eq( + pdf.at_time("0:40").sort_index(), kdf.at_time("0:40").sort_index(), + ) + + with self.assertRaisesRegex(NotImplementedError, "'asof' argument is not supported"): + kdf.at_time("0:15", asof=True) + + with self.assertRaisesRegex(NotImplementedError, "at_time currently only works for axis=0"): + kdf.at_time("0:15", axis=1) + + kdf = pp.DataFrame({"A": [1, 2, 3, 4]}) + with self.assertRaisesRegex(TypeError, "Index must be DatetimeIndex"): + kdf.at_time("0:15") + + +if __name__ == "__main__": + from pyspark.pandas.tests.test_dataframe import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/run-tests.py b/python/run-tests.py index a13828d81f04f..fd9f287710b29 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -279,7 +279,8 @@ def main(): if python_implementation not in module.excluded_python_implementations: for test_goal in module.python_test_goals: heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests', - 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests'] + 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests', + 'pyspark.pandas.tests'] if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)): priority = 0 else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4f924755bb909..a4cb34b5c3e19 100644 
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1472,7 +1472,11 @@ class Analyzer(override val catalogManager: CatalogManager) // The update value can access columns from both target and source tables. UpdateAction( resolvedUpdateCondition, - resolveAssignments(assignments, m, resolveValuesWithSourceOnly = false)) + resolveAssignments(Some(assignments), m, resolveValuesWithSourceOnly = false)) + case UpdateStarAction(updateCondition) => + UpdateAction( + updateCondition.map(resolveExpressionByPlanChildren(_, m)), + resolveAssignments(assignments = None, m, resolveValuesWithSourceOnly = false)) case o => o } val newNotMatchedActions = m.notMatchedActions.map { @@ -1483,7 +1487,15 @@ class Analyzer(override val catalogManager: CatalogManager) resolveExpressionByPlanChildren(_, Project(Nil, m.sourceTable))) InsertAction( resolvedInsertCondition, - resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true)) + resolveAssignments(Some(assignments), m, resolveValuesWithSourceOnly = true)) + case InsertStarAction(insertCondition) => + // The insert action is used when not matched, so its condition and value can only + // access columns from the source table. + val resolvedInsertCondition = insertCondition.map( + resolveExpressionByPlanChildren(_, Project(Nil, m.sourceTable))) + InsertAction( + resolvedInsertCondition, + resolveAssignments(assignments = None, m, resolveValuesWithSourceOnly = true)) case o => o } val resolvedMergeCondition = resolveExpressionByPlanChildren(m.mergeCondition, m) @@ -1501,7 +1513,7 @@ class Analyzer(override val catalogManager: CatalogManager) } def resolveAssignments( - assignments: Seq[Assignment], + assignments: Option[Seq[Assignment]], mergeInto: MergeIntoTable, resolveValuesWithSourceOnly: Boolean): Seq[Assignment] = { if (assignments.isEmpty) { @@ -1509,7 +1521,7 @@ class Analyzer(override val catalogManager: CatalogManager) val expandedValues = mergeInto.sourceTable.output expandedColumns.zip(expandedValues).map(kv => Assignment(kv._1, kv._2)) } else { - assignments.map { assign => + assignments.get.map { assign => val resolvedKey = assign.key match { case c if !c.resolved => resolveExpressionByPlanChildren(c, Project(Nil, mergeInto.targetTable)) @@ -1790,16 +1802,30 @@ class Analyzer(override val catalogManager: CatalogManager) // Replace the index with the corresponding expression in aggregateExpressions. 
The index is // a 1-base position of aggregateExpressions, which is output columns (select expression) case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && - groups.exists(_.isInstanceOf[UnresolvedOrdinal]) => - val newGroups = groups.map { - case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => - aggs(index - 1) - case ordinal @ UnresolvedOrdinal(index) => - throw QueryCompilationErrors.groupByPositionRangeError(index, aggs.size, ordinal) - case o => o - } + groups.exists(containUnresolvedOrdinal) => + val newGroups = groups.map(resolveGroupByExpressionOrdinal(_, aggs)) Aggregate(newGroups, aggs, child) } + + private def containUnresolvedOrdinal(e: Expression): Boolean = e match { + case _: UnresolvedOrdinal => true + case gs: BaseGroupingSets => gs.children.exists(containUnresolvedOrdinal) + case _ => false + } + + private def resolveGroupByExpressionOrdinal( + expr: Expression, + aggs: Seq[Expression]): Expression = expr match { + case ordinal @ UnresolvedOrdinal(index) => + if (index > 0 && index <= aggs.size) { + aggs(index - 1) + } else { + throw QueryCompilationErrors.groupByPositionRangeError(index, aggs.size, ordinal) + } + case gs: BaseGroupingSets => + gs.withNewChildren(gs.children.map(resolveGroupByExpressionOrdinal(_, aggs))) + case others => others + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 389bbb828da6f..e751c32a7a068 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -260,7 +260,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { s"join condition '${condition.sql}' " + s"of type ${condition.dataType.catalogString} is not a boolean.") - case Aggregate(groupingExprs, aggregateExprs, child) => + case a @ Aggregate(groupingExprs, aggregateExprs, child) => def isAggregateExpression(expr: Expression): Boolean = { expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr) } @@ -305,6 +305,12 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { s"nor is it an aggregate function. " + "Add to group by or wrap in first() (or first_value) if you don't care " + "which value you get.") + case s: ScalarSubquery + if s.children.nonEmpty && !groupingExprs.exists(_.semanticEquals(s)) => + failAnalysis(s"Correlated scalar subquery '${s.sql}' is neither " + + "present in the group by, nor in an aggregate function. Add it to group by " + + "using ordinal position or wrap it in first() (or first_value) if you don't " + + "care which value you get.") case e if groupingExprs.exists(_.semanticEquals(e)) => // OK case e => e.children.foreach(checkValidAggregateExpression) } @@ -735,6 +741,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { case child => child } + // Check whether the given expressions contains the subquery expression. + def containsExpr(expressions: Seq[Expression]): Boolean = { + expressions.exists(_.find(_.semanticEquals(expr)).isDefined) + } + // Validate the subquery plan. checkAnalysis(expr.plan) @@ -756,7 +767,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // Only certain operators are allowed to host subquery expression containing // outer references. 
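// For Aggregate, the case added below is stricter: a correlated scalar subquery that appears only in the grouping expressions (and not in the aggregate expressions) now fails analysis, because the optimizer can only rewrite it when it is also part of the aggregate list. Schematically, SELECT max(c1) FROM t GROUP BY (SELECT c2 FROM s WHERE s.k = t.k) is rejected unless the subquery is also selected or wrapped in first().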
plan match { - case _: Filter | _: Aggregate | _: Project | _: SupportsSubquery => // Ok + case _: Filter | _: Project | _: SupportsSubquery => // Ok + case a: Aggregate => + // If the correlated scalar subquery is in the grouping expressions of an Aggregate, + // it must also be in the aggregate expressions to be rewritten in the optimization + // phase. + if (containsExpr(a.groupingExpressions) && !containsExpr(a.aggregateExpressions)) { + failAnalysis("Correlated scalar subqueries in the group by clause " + + s"must also be in the aggregate expressions:\n$a") + } case other => failAnalysis( "Correlated scalar sub-queries can only be used in a " + s"Filter/Aggregate/Project and a few commands: $plan") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index 1e7480a69e40f..c64c2ddb06cec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{BaseGroupingSets, Expression, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin @@ -27,13 +27,20 @@ import org.apache.spark.sql.types.IntegerType * Replaces ordinal in 'order by' or 'group by' with UnresolvedOrdinal expression. 
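* With this change the substitution also descends into grouping analytics expressions, so an ordinal inside e.g. GROUP BY GROUPING SETS ((1), (2)) is replaced with UnresolvedOrdinal as well, not only an ordinal in a plain GROUP BY 1, 2 (assuming spark.sql.groupByOrdinal is enabled, as for plain ordinals).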
*/ object SubstituteUnresolvedOrdinals extends Rule[LogicalPlan] { - private def isIntLiteral(e: Expression) = e match { + private def containIntLiteral(e: Expression): Boolean = e match { case Literal(_, IntegerType) => true + case gs: BaseGroupingSets => gs.children.exists(containIntLiteral) case _ => false } + private def substituteUnresolvedOrdinal(expression: Expression): Expression = expression match { + case ordinal @ Literal(index: Int, IntegerType) => + withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) + case e => e + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => + case s: Sort if conf.orderByOrdinal && s.order.exists(o => containIntLiteral(o.child)) => val newOrders = s.order.map { case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) @@ -42,10 +49,12 @@ object SubstituteUnresolvedOrdinals extends Rule[LogicalPlan] { } withOrigin(s.origin)(s.copy(order = newOrders)) - case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) => + case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(containIntLiteral) => val newGroups = a.groupingExpressions.map { case ordinal @ Literal(index: Int, IntegerType) => withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) + case gs: BaseGroupingSets => + withOrigin(gs.origin)(gs.withNewChildren(gs.children.map(substituteUnresolvedOrdinal))) case other => other } withOrigin(a.origin)(a.copy(groupingExpressions = newGroups)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 3fc3db30fa7b1..3b2f4ca79cbc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -263,6 +263,9 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio override def terminate(): TraversableOnce[InternalRow] = throw QueryExecutionErrors.cannotTerminateGeneratorError(this) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): UnresolvedGenerator = copy(children = newChildren) } case class UnresolvedFunction( @@ -284,6 +287,15 @@ case class UnresolvedFunction( val distinct = if (isDistinct) "distinct " else "" s"'$name($distinct${children.mkString(", ")})" } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): UnresolvedFunction = { + if (filter.isDefined) { + copy(arguments = newChildren.dropRight(1), filter = Some(newChildren.last)) + } else { + copy(arguments = newChildren) + } + } } object UnresolvedFunction { @@ -441,6 +453,8 @@ case class MultiAlias(child: Expression, names: Seq[String]) override def toString: String = s"$child AS $names" + override protected def withNewChildInternal(newChild: Expression): MultiAlias = + copy(child = newChild) } /** @@ -475,6 +489,11 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override def toString: String = s"$child[$extraction]" override def sql: String = s"${child.sql}[${extraction.sql}]" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): UnresolvedExtractValue = { + copy(child = newLeft, extraction = newRight) + } } /** @@ -499,6 
+518,9 @@ case class UnresolvedAlias( override def newInstance(): NamedExpression = throw new UnresolvedException("newInstance") override lazy val resolved = false + + override protected def withNewChildInternal(newChild: Expression): UnresolvedAlias = + copy(child = newChild) } /** @@ -520,6 +542,9 @@ case class UnresolvedSubqueryColumnAliases( override def output: Seq[Attribute] = Nil override lazy val resolved = false + + override protected def withNewChildInternal( + newChild: LogicalPlan): UnresolvedSubqueryColumnAliases = copy(child = newChild) } /** @@ -541,6 +566,9 @@ case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq override def dataType: DataType = throw new UnresolvedException("dataType") override def nullable: Boolean = throw new UnresolvedException("nullable") override lazy val resolved = false + + override protected def withNewChildInternal(newChild: Expression): UnresolvedDeserializer = + copy(deserializer = newChild) } case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpression @@ -587,6 +615,8 @@ case class UnresolvedHaving( extends UnaryNode { override lazy val resolved: Boolean = false override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: LogicalPlan): UnresolvedHaving = + copy(child = newChild) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 0de17d420f0c9..7cb830d115689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -114,6 +114,9 @@ case class CallMethodViaReflection(children: Seq[Expression]) /** A temporary buffer used to hold intermediate results returned by children. 
*/ @transient private lazy val buffer = new Array[Object](argExprs.length) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CallMethodViaReflection = copy(children = newChildren) } object CallMethodViaReflection { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 879b154a84761..1e1b7eeca0f35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1812,6 +1812,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } else { s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}" } + + override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild) } /** @@ -1841,6 +1843,8 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St Some(SQLConf.STORE_ASSIGNMENT_POLICY.key), Some(SQLConf.StoreAssignmentPolicy.LEGACY.toString)) + override protected def withNewChildInternal(newChild: Expression): AnsiCast = + copy(child = newChild) } object AnsiCast { @@ -1998,4 +2002,6 @@ case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: S case DecimalType => DecimalType.SYSTEM_DEFAULT case _ => target.asInstanceOf[DataType] } + + override protected def withNewChildInternal(newChild: Expression): UpCast = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala index 550fa4c3f73e4..de4b874637f09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala @@ -78,6 +78,9 @@ case class DynamicPruningSubquery( buildKeys = buildKeys.map(_.canonicalized), exprId = ExprId(0)) } + + override protected def withNewChildInternal(newChild: Expression): DynamicPruningSubquery = + copy(pruningKey = newChild) } /** @@ -94,4 +97,7 @@ case class DynamicPruningExpression(child: Expression) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.genCode(ctx) } + + override protected def withNewChildInternal(newChild: Expression): DynamicPruningExpression = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala index 05d553757e742..ab390618d4c5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala @@ -43,6 +43,7 @@ abstract class PartitionTransformExpression extends Expression with Unevaluable */ case class Years(child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType + override protected def withNewChildInternal(newChild: Expression): Years = copy(child = newChild) } /** @@ -50,6 +51,7 @@ case class Years(child: Expression) extends PartitionTransformExpression { */ case class Months(child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType + 
override protected def withNewChildInternal(newChild: Expression): Months = copy(child = newChild) } /** @@ -57,6 +59,7 @@ case class Months(child: Expression) extends PartitionTransformExpression { */ case class Days(child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType + override protected def withNewChildInternal(newChild: Expression): Days = copy(child = newChild) } /** @@ -64,6 +67,7 @@ case class Days(child: Expression) extends PartitionTransformExpression { */ case class Hours(child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType + override protected def withNewChildInternal(newChild: Expression): Hours = copy(child = newChild) } /** @@ -71,4 +75,5 @@ case class Hours(child: Expression) extends PartitionTransformExpression { */ case class Bucket(numBuckets: Literal, child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType + override protected def withNewChildInternal(newChild: Expression): Bucket = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index da2e1821feb0f..73f8c300b4ae7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -73,4 +73,7 @@ case class PythonUDF( // `resultId` can be seen as cosmetic variation in PythonUDF, as it doesn't affect the result. this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDF = + copy(children = newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 4086e7698e7b1..375ae95acfc39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1195,4 +1195,7 @@ case class ScalaUDF( resultConverter(result) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ScalaUDF = + copy(children = newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d9923b5d022e0..9aef25ce60599 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -88,6 +88,9 @@ case class SortOrder( children.exists(required.child.semanticEquals) && direction == required.direction && nullOrdering == required.nullOrdering } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): SortOrder = + copy(child = newChildren.head, sameOrderExpressions = newChildren.tail) } object SortOrder { @@ -226,4 +229,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { } override def dataType: DataType = LongType + + override protected def withNewChildInternal(newChild: Expression): SortPrefix = + copy(child = newChild.asInstanceOf[SortOrder]) } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala index a1f7ba3008775..0f224fefe3911 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala @@ -140,6 +140,9 @@ case class ExpressionProxy( } override def hashCode(): Int = this.id.hashCode() + + override protected def withNewChildInternal(newChild: Expression): ExpressionProxy = + copy(child = newChild) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index f7fe467cea830..ed1d77017c120 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -92,6 +92,9 @@ case class TimeWindow( } dataTypeCheck } + + override protected def withNewChildInternal(newChild: Expression): TimeWindow = + copy(timeColumn = newChild) } object TimeWindow { @@ -155,4 +158,7 @@ case class PreciseTimestampConversion( """.stripMargin) } override def nullSafeEval(input: Any): Any = input + + override protected def withNewChildInternal(newChild: Expression): PreciseTimestampConversion = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala index cae25a263a8e3..0f63de1bf7e45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala @@ -84,4 +84,13 @@ case class TryCast(child: Expression, dataType: DataType, timeZoneId: Option[Str override def typeCheckFailureMessage: String = AnsiCast.typeCheckFailureMessage(child.dataType, dataType, None, None) + + override protected def withNewChildInternal(newChild: Expression): TryCast = + copy(child = newChild) + + override def toString: String = { + s"try_cast($child as ${dataType.simpleString})" + } + + override def sql: String = s"TRY_CAST(${child.sql} AS ${dataType.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala index 42dc6f6b200d0..19e212d1f9e69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala @@ -249,4 +249,8 @@ case class ApproxCountDistinctForIntervals( override def getLong(offset: Int): Long = array(offset) override def setLong(offset: Int, value: Long): Unit = { array(offset) = value } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ApproxCountDistinctForIntervals = + copy(child = newLeft, endpointsExpression = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 4e4a06a628453..38d8d7d71ead8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -208,6 +208,10 @@ case class ApproximatePercentile( override def deserialize(bytes: Array[Byte]): PercentileDigest = { ApproximatePercentile.serializer.deserialize(bytes) } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): ApproximatePercentile = + copy(child = newFirst, percentageExpression = newSecond, accuracyExpression = newThird) } object ApproximatePercentile { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 36004b0ea6244..90e91ae41856c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -93,4 +93,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit coalesce(child.cast(sumDataType), Literal.default(sumDataType))), /* count = */ If(child.isNull, count, count + 1L) ) + + override protected def withNewChildInternal(newChild: Expression): Average = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 4ca933ff45d02..c5c78e5062f56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -167,6 +167,9 @@ case class StddevPop( } override def prettyName: String = "stddev_pop" + + override protected def withNewChildInternal(newChild: Expression): StddevPop = + copy(child = newChild) } // Compute the sample standard deviation of a column @@ -197,6 +200,9 @@ case class StddevSamp( override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("stddev_samp") + + override protected def withNewChildInternal(newChild: Expression): StddevSamp = + copy(child = newChild) } // Compute the population variance of a column @@ -223,6 +229,9 @@ case class VariancePop( } override def prettyName: String = "var_pop" + + override protected def withNewChildInternal(newChild: Expression): VariancePop = + copy(child = newChild) } // Compute the sample variance of a column @@ -250,6 +259,9 @@ case class VarianceSamp( } override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("var_samp") + + override protected def withNewChildInternal(newChild: Expression): VarianceSamp = + copy(child = newChild) } @ExpressionDescription( @@ -278,6 +290,9 @@ case class Skewness( If(n === 0.0, Literal.create(null, DoubleType), If(m2 === 0.0, divideByZeroEvalResult, sqrt(n) * m3 / sqrt(m2 * m2 * m2))) } + + override protected def withNewChildInternal(newChild: Expression): Skewness = + copy(child = newChild) } @ExpressionDescription( @@ -306,4 +321,7 @@ case class Kurtosis( } override def prettyName: String = "kurtosis" + + 
override protected def withNewChildInternal(newChild: Expression): Kurtosis = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index d819971478ecf..c798004fe7843 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -127,4 +127,7 @@ case class Corr( } override def prettyName: String = "corr" + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Corr = + copy(x = newLeft, y = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 189d21603e70f..1d13155ef6898 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -89,6 +89,9 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { ) } } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Count = + copy(children = newChildren) } object Count { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala index c1c4c84497bcd..d4fdd5115b59d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala @@ -56,4 +56,7 @@ case class CountIf(predicate: Expression) extends UnevaluableAggregate with Impl s"function $prettyName requires boolean type, not ${predicate.dataType.catalogString}" ) } + + override protected def withNewChildInternal(newChild: Expression): CountIf = + copy(predicate = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala index a838a0a0e8977..38d0db1e7610c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala @@ -154,4 +154,12 @@ case class CountMinSketchAgg( override def second: Expression = epsExpression override def third: Expression = confidenceExpression override def fourth: Expression = seedExpression + + override protected def withNewChildrenInternal(first: Expression, second: Expression, + third: Expression, fourth: Expression): CountMinSketchAgg = + copy( + child = first, + epsExpression = second, + confidenceExpression = third, + seedExpression = fourth) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index 8fcee104d276b..9ea9b3782032b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -109,6 +109,10 @@ case class CovPopulation( If(n === 0.0, Literal.create(null, DoubleType), ck / n) } override def prettyName: String = "covar_pop" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): CovPopulation = + copy(left = newLeft, right = newRight) } @@ -135,4 +139,7 @@ case class CovSample( If(n === 1.0, divideByZeroEvalResult, ck / (n - 1.0))) } override def prettyName: String = "covar_samp" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): CovSample = copy(left = newLeft, right = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index accd15a711503..ea994af0e6168 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -118,6 +118,8 @@ case class First(child: Expression, ignoreNulls: Boolean) override lazy val evaluateExpression: AttributeReference = first override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}" + + override protected def withNewChildInternal(newChild: Expression): First = copy(child = newChild) } object FirstLast { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 430c25cee2a93..9b0493f3e68a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -138,6 +138,9 @@ case class HyperLogLogPlusPlus( override def eval(buffer: InternalRow): Any = { hllppHelper.query(buffer, mutableAggBufferOffset) } + + override protected def withNewChildInternal(newChild: Expression): HyperLogLogPlusPlus = + copy(child = newChild) } object HyperLogLogPlusPlus { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index e3c427d584489..0fe6199cd8c31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -115,4 +115,6 @@ case class Last(child: Expression, ignoreNulls: Boolean) override lazy val evaluateExpression: AttributeReference = last override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}" + + override protected def withNewChildInternal(newChild: Expression): Last = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 42721ea48c7ca..b802678ec0468 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -62,4 +62,6 @@ case class Max(child: Expression) extends 
DeclarativeAggregate with UnaryLike[Ex } override lazy val evaluateExpression: AttributeReference = max + + override protected def withNewChildInternal(newChild: Expression): Max = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala index e402bcae144ad..664bc32ccc464 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala @@ -110,6 +110,9 @@ case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMin override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression = greatest(oldExpr, newExpr) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): MaxBy = + copy(valueExpr = newLeft, orderingExpr = newRight) } @ExpressionDescription( @@ -130,4 +133,7 @@ case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMin override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression = least(oldExpr, newExpr) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): MinBy = + copy(valueExpr = newLeft, orderingExpr = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 84410c7de3229..9c5c7bbda4dc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -62,4 +62,6 @@ case class Min(child: Expression) extends DeclarativeAggregate with UnaryLike[Ex } override lazy val evaluateExpression: AttributeReference = min + + override protected def withNewChildInternal(newChild: Expression): Min = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index b81c523ce32ba..5bce4d348c726 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -304,4 +304,11 @@ case class Percentile( bis.close() } } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Percentile = copy( + child = newFirst, + percentageExpression = newSecond, + frequencyExpression = newThird + ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 422fcab5bf890..b90e46e1545d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -153,5 +153,9 @@ case class PivotFirst( override val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) + + override protected def 
withNewChildrenInternal( + newLeft: Expression, newRight: Expression): PivotFirst = + copy(pivotColumn = newLeft, valueColumn = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala index 50c74f1c49a99..3af3944fd47d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala @@ -59,4 +59,7 @@ case class Product(child: Expression) Seq(coalesce(coalesce(product.left, one) * product.right, product.left)) override lazy val evaluateExpression: Expression = product + + override protected def withNewChildInternal(newChild: Expression): Product = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index f412a3ec31e0f..56eebedddf08d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -148,4 +148,6 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled)) case _ => sum } + + override protected def withNewChildInternal(newChild: Expression): Sum = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala index 5b914c4333687..878d853aca3cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala @@ -56,6 +56,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression) since = "3.0.0") case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) { override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_and") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(arg = newChild) } @ExpressionDescription( @@ -73,4 +75,6 @@ case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) { since = "3.0.0") case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) { override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_or") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(arg = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala index 5ffc0f6ce3a42..86a16ad389b5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala @@ -69,6 +69,9 @@ case class BitAndAgg(child: Expression) extends BitAggregate { override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = { BitwiseAnd(left, right) } + + override protected def 
withNewChildInternal(newChild: Expression): BitAndAgg = + copy(child = newChild) } @ExpressionDescription( @@ -87,6 +90,9 @@ case class BitOrAgg(child: Expression) extends BitAggregate { override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = { BitwiseOr(left, right) } + + override protected def withNewChildInternal(newChild: Expression): BitOrAgg = + copy(child = newChild) } @ExpressionDescription( @@ -105,4 +111,7 @@ case class BitXorAgg(child: Expression) extends BitAggregate { override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = { BitwiseXor(left, right) } + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index d8a76d7add262..a8db8211a9e4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -125,6 +125,9 @@ case class CollectList( override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { new GenericArrayData(buffer.toArray) } + + override protected def withNewChildInternal(newChild: Expression): CollectList = + copy(child = newChild) } /** @@ -191,4 +194,7 @@ case class CollectSet( override def prettyName: String = "collect_set" override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty + + override protected def withNewChildInternal(newChild: Expression): CollectSet = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index e0c6ce7208c94..281734c6f14ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -164,6 +164,16 @@ case class AggregateExpression( case _ => aggFuncStr } } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): AggregateExpression = + if (filter.isDefined) { + copy( + aggregateFunction = newChildren(0).asInstanceOf[AggregateFunction], + filter = Some(newChildren(1))) + } else { + copy(aggregateFunction = newChildren(0).asInstanceOf[AggregateFunction]) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 64ea579e5ca05..28851918429aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -105,6 +105,9 @@ case class UnaryMinus( case funcName => s"$funcName(${child.sql})" } } + + override protected def withNewChildInternal(newChild: Expression): UnaryMinus = + copy(child = newChild) } @ExpressionDescription( @@ -131,6 +134,9 @@ case class UnaryPositive(child: Expression) protected override def nullSafeEval(input: Any): Any = input override def sql: String = s"(+ ${child.sql})" + + override protected def withNewChildInternal(newChild: Expression): UnaryPositive = + copy(child = newChild) } /** @@ -183,6 
+189,8 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) + + override protected def withNewChildInternal(newChild: Expression): Abs = copy(child = newChild) } abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { @@ -309,6 +317,9 @@ case class Add( } override def exactMathMethod: Option[String] = Some("addExact") + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Add = + copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -352,6 +363,9 @@ case class Subtract( } override def exactMathMethod: Option[String] = Some("subtractExact") + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Subtract = copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -380,6 +394,9 @@ case class Multiply( protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) override def exactMathMethod: Option[String] = Some("multiplyExact") + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Multiply = copy(left = newLeft, right = newRight) } // Common base trait for Divide and Remainder, since these two classes are almost identical @@ -506,6 +523,9 @@ case class Divide( } override def evalOperation(left: Any, right: Any): Any = div(left, right) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Divide = copy(left = newLeft, right = newRight) } // scalastyle:off line.size.limit @@ -553,6 +573,10 @@ case class IntegralDivide( } override def evalOperation(left: Any, right: Any): Any = div(left, right) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): IntegralDivide = + copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -607,6 +631,9 @@ case class Remainder( } override def evalOperation(left: Any, right: Any): Any = mod(left, right) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Remainder = copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -791,6 +818,9 @@ case class Pmod( } override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Pmod = + copy(left = newLeft, right = newRight) } /** @@ -866,6 +896,9 @@ case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression |$codes """.stripMargin) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Least = + copy(children = newChildren) } /** @@ -941,4 +974,7 @@ case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpress |$codes """.stripMargin) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Greatest = + copy(children = newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index a1fb68ea169c5..3940c65593ec5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -56,6 +56,9 @@ case class BitwiseAnd(left: Expression, right: Expression) 
extends BinaryArithme } protected override def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): BitwiseAnd = copy(left = newLeft, right = newRight) } /** @@ -92,6 +95,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet } protected override def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): BitwiseOr = copy(left = newLeft, right = newRight) } /** @@ -128,6 +134,9 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme } protected override def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): BitwiseXor = copy(left = newLeft, right = newRight) } /** @@ -169,6 +178,9 @@ case class BitwiseNot(child: Expression) protected override def nullSafeEval(input: Any): Any = not(input) override def sql: String = s"~${child.sql}" + + override protected def withNewChildInternal(newChild: Expression): BitwiseNot = + copy(child = newChild) } @ExpressionDescription( @@ -204,6 +216,9 @@ case class BitwiseCount(child: Expression) case IntegerType => java.lang.Long.bitCount(input.asInstanceOf[Int]) case LongType => java.lang.Long.bitCount(input.asInstanceOf[Long]) } + + override protected def withNewChildInternal(newChild: Expression): BitwiseCount = + copy(child = newChild) } object BitwiseGetUtil { @@ -262,4 +277,7 @@ case class BitwiseGet(left: Expression, right: Expression) override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bit_get") + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): BitwiseGet = copy(left = newLeft, right = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 689858dc6ee67..c840cdfd8b2dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -22,7 +22,7 @@ import java.lang.{Boolean => JBool} import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{LeafLike, TreeNode} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{BooleanType, DataType} @@ -298,11 +298,13 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends } buf.toString } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Block]): Block = + super.legacyWithNewChildren(newChildren) } -case object EmptyBlock extends Block with Serializable { +case object EmptyBlock extends Block with Serializable with LeafLike[Block] { override val code: String = "" - override def children: Seq[Block] = Seq.empty } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d3fad8cb329c2..125e796a98c2d 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -125,6 +125,8 @@ case class Size(child: Expression, legacySizeOfNull: Boolean) defineCodeGen(ctx, ev, c => s"($c).numElements()") } } + + override protected def withNewChildInternal(newChild: Expression): Size = copy(child = newChild) } object Size { @@ -159,6 +161,9 @@ case class MapKeys(child: Expression) } override def prettyName: String = "map_keys" + + override protected def withNewChildInternal(newChild: Expression): MapKeys = + copy(child = newChild) } @ExpressionDescription( @@ -321,6 +326,9 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI } override def prettyName: String = "arrays_zip" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ArraysZip = + copy(children = newChildren) } /** @@ -351,6 +359,9 @@ case class MapValues(child: Expression) } override def prettyName: String = "map_values" + + override protected def withNewChildInternal(newChild: Expression): MapValues = + copy(child = newChild) } /** @@ -523,6 +534,8 @@ case class MapEntries(child: Expression) } override def prettyName: String = "map_entries" + + override def withNewChildInternal(newChild: Expression): MapEntries = copy(child = newChild) } /** @@ -642,6 +655,9 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres } override def prettyName: String = "map_concat" + + override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): MapConcat = + copy(children = newChildren) } /** @@ -720,6 +736,9 @@ case class MapFromEntries(child: Expression) extends UnaryExpression with NullIn } override def prettyName: String = "map_from_entries" + + override protected def withNewChildInternal(newChild: Expression): MapFromEntries = + copy(child = newChild) } @@ -919,6 +938,10 @@ case class SortArray(base: Expression, ascendingOrder: Expression) } override def prettyName: String = "sort_array" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): SortArray = + copy(base = newLeft, ascendingOrder = newRight) } /** @@ -1007,6 +1030,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) } override def freshCopy(): Shuffle = Shuffle(child, randomSeed) + + override def withNewChildInternal(newChild: Expression): Shuffle = copy(child = newChild) } /** @@ -1083,6 +1108,9 @@ case class Reverse(child: Expression) } override def prettyName: String = "reverse" + + override protected def withNewChildInternal(newChild: Expression): Reverse = + copy(child = newChild) } /** @@ -1180,6 +1208,10 @@ case class ArrayContains(left: Expression, right: Expression) } override def prettyName: String = "array_contains" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayContains = + copy(left = newLeft, right = newRight) } /** @@ -1403,6 +1435,10 @@ case class ArraysOverlap(left: Expression, right: Expression) } override def prettyName: String = "arrays_overlap" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArraysOverlap = + copy(left = newLeft, right = newRight) } /** @@ -1516,6 +1552,10 @@ case class Slice(x: Expression, start: Expression, length: Expression) |} """.stripMargin } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: 
Expression, newThird: Expression): Slice = + copy(x = newFirst, start = newSecond, length = newThird) } /** @@ -1559,6 +1599,16 @@ case class ArrayJoin( Seq(array, delimiter) } + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + if (nullReplacement.isDefined) { + copy( + array = newChildren(0), + delimiter = newChildren(1), + nullReplacement = Some(newChildren(2))) + } else { + copy(array = newChildren(0), delimiter = newChildren(1)) + } + override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -1756,6 +1806,9 @@ case class ArrayMin(child: Expression) } override def prettyName: String = "array_min" + + override protected def withNewChildInternal(newChild: Expression): ArrayMin = + copy(child = newChild) } /** @@ -1824,6 +1877,9 @@ case class ArrayMax(child: Expression) } override def prettyName: String = "array_max" + + override protected def withNewChildInternal(newChild: Expression): ArrayMax = + copy(child = newChild) } @@ -1903,6 +1959,10 @@ case class ArrayPosition(left: Expression, right: Expression) """.stripMargin }) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayPosition = + copy(left = newLeft, right = newRight) } /** @@ -2085,6 +2145,9 @@ case class ElementAt( } override def prettyName: String = "element_at" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ElementAt = copy(left = newLeft, right = newRight) } /** @@ -2291,6 +2354,9 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio override def toString: String = s"concat(${children.mkString(", ")})" override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Concat = + copy(children = newChildren) } /** @@ -2403,6 +2469,9 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran } override def prettyName: String = "flatten" + + override protected def withNewChildInternal(newChild: Expression): Flatten = + copy(child = newChild) } @ExpressionDescription( @@ -2460,6 +2529,15 @@ case class Sequence( override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt + override def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): TimeZoneAwareExpression = { + if (stepOpt.isDefined) { + copy(start = newChildren(0), stop = newChildren(1), stepOpt = Some(newChildren(2))) + } else { + copy(start = newChildren(0), stop = newChildren(1)) + } + } + override def foldable: Boolean = children.forall(_.foldable) override def nullable: Boolean = children.exists(_.nullable) @@ -2949,6 +3027,8 @@ case class ArrayRepeat(left: Expression, right: Expression) """.stripMargin } + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayRepeat = copy(left = newLeft, right = newRight) } /** @@ -3063,6 +3143,9 @@ case class ArrayRemove(left: Expression, right: Expression) } override def prettyName: String = "array_remove" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayRemove = copy(left = newLeft, right = newRight) } /** @@ -3295,6 +3378,9 @@ case class ArrayDistinct(child: Expression) } override def prettyName: String = "array_distinct" + + override protected def withNewChildInternal(newChild: Expression): ArrayDistinct = + copy(child = newChild) } /** @@ 
-3497,6 +3583,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi } override def prettyName: String = "array_union" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayUnion = copy(left = newLeft, right = newRight) } object ArrayUnion { @@ -3780,6 +3869,10 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina } override def prettyName: String = "array_intersect" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayIntersect = + copy(left = newLeft, right = newRight) } /** @@ -4004,4 +4097,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL } override def prettyName: String = "array_except" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayExcept = copy(left = newLeft, right = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 3c016a7a54995..f1456c4c8e079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -102,6 +102,9 @@ case class CreateArray(children: Seq[Expression], useStringTypeWhenEmpty: Boolea } override def prettyName: String = "array" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): CreateArray = + copy(children = newChildren) } object CreateArray { @@ -254,6 +257,9 @@ case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean) } override def prettyName: String = "map" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): CreateMap = + copy(children = newChildren) } object CreateMap { @@ -314,6 +320,10 @@ case class MapFromArrays(left: Expression, right: Expression) } override def prettyName: String = "map_from_arrays" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): MapFromArrays = + copy(left = newLeft, right = newRight) } /** @@ -493,6 +503,9 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with val childrenSQL = children.indices.filter(_ % 2 == 1).map(children(_).sql).mkString(", ") s"$alias($childrenSQL)" }.getOrElse(super.sql) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CreateNamedStruct = copy(children = newChildren) } /** @@ -576,6 +589,13 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E } override def prettyName: String = "str_to_map" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy( + text = newFirst, + pairDelim = newSecond, + keyValueDelim = newThird + ) } /** @@ -627,6 +647,9 @@ case class WithField(name: String, valExpr: Expression) "WithField.nullable should not be called.") override def prettyName: String = "WithField" + + override protected def withNewChildInternal(newChild: Expression): WithField = + copy(valExpr = newChild) } /** @@ -659,6 +682,9 @@ case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperat case e: Expression => e } + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression = + super.legacyWithNewChildren(newChildren) + override def dataType: StructType = StructType(newFields) override def nullable: Boolean = structExpr.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 139d9a584ccbe..f64cc8a28b566 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -138,6 +138,9 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] } }) } + + override protected def withNewChildInternal(newChild: Expression): GetStructField = + copy(child = newChild) } /** @@ -212,6 +215,9 @@ case class GetArrayStructFields( """ }) } + + override protected def withNewChildInternal(newChild: Expression): GetArrayStructFields = + copy(child = newChild) } /** @@ -292,6 +298,10 @@ case class GetArrayItem( """ }) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): GetArrayItem = + copy(child = newLeft, ordinal = newRight) } /** @@ -470,4 +480,8 @@ case class GetMapValue( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType], failOnError) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): GetMapValue = + copy(child = newLeft, key = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index a062dd49a3c92..e708d56cd89c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -95,6 +95,13 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def toString: String = s"if ($predicate) $trueValue else $falseValue" override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy( + predicate = newFirst, + trueValue = newSecond, + falseValue = newThird + ) } /** @@ -132,6 +139,9 @@ case class CaseWhen( override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + super.legacyWithNewChildren(newChildren) + // both then and else expressions should be considered. 
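Aside on the pattern so far: most fixed-arity expressions in this patch implement withNewChildInternal / withNewChildrenInternal as a plain copy(...) with the new child or children, while nodes whose children sequence is assembled from heterogeneous parts (CodeBlock, UpdateFields and CaseWhen above) fall back to super.legacyWithNewChildren and let the generic rebuild reconstruct them. Below is a minimal, self-contained sketch of those two shapes; Expr, Lit, Negate and Branches are illustrative stand-ins, not Spark's TreeNode or Expression API.

    // Sketch only: illustrative stand-ins for the two override shapes in this patch.
    sealed trait Expr {
      def children: Seq[Expr]
      def withNewChildren(newChildren: Seq[Expr]): Expr
    }

    case class Lit(value: Int) extends Expr {
      def children: Seq[Expr] = Nil
      def withNewChildren(newChildren: Seq[Expr]): Expr = this
    }

    // Fixed arity: the override is a plain case-class copy, so the concrete type
    // (here Negate) is preserved -- the same shape as copy(child = newChild) above.
    case class Negate(child: Expr) extends Expr {
      def children: Seq[Expr] = Seq(child)
      def withNewChildren(newChildren: Seq[Expr]): Negate = copy(child = newChildren.head)
    }

    // Children assembled from (condition, value) pairs plus an optional else branch,
    // loosely mirroring CaseWhen: the flat children sequence has to be sliced back
    // into its original shape, which is why such nodes can instead delegate to a
    // generic rebuild like legacyWithNewChildren.
    case class Branches(branches: Seq[(Expr, Expr)], elseValue: Option[Expr]) extends Expr {
      def children: Seq[Expr] = branches.flatMap { case (c, v) => Seq(c, v) } ++ elseValue
      def withNewChildren(newChildren: Seq[Expr]): Branches = {
        val (branchKids, rest) = newChildren.splitAt(2 * branches.length)
        copy(
          branches = branchKids.grouped(2).map { case Seq(c, v) => (c, v) }.toSeq,
          elseValue = rest.headOption)
      }
    }

    object WithNewChildrenSketch {
      def main(args: Array[String]): Unit = {
        val cw = Branches(Seq((Lit(1), Lit(2))), Some(Lit(3)))
        val rebuilt = cw.withNewChildren(cw.children.map {
          case Lit(v) => Lit(v * 10)
          case e => e
        })
        println(rebuilt) // Branches(List((Lit(10),Lit(20))),Some(Lit(30)))
      }
    }

Returning the concrete type from the copy-based override is what lets the framework replace children while keeping the node's static type, which is why the overrides above are typed as e.g. Average or Corr rather than Expression.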
@transient override lazy val inputTypesForMerging: Seq[DataType] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala index 5bfae7b77e096..8feaf52ecb134 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala @@ -36,6 +36,12 @@ case class KnownNotNull(child: Expression) extends TaggingExpression { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.genCode(ctx).copy(isNull = FalseLiteral) } + + override protected def withNewChildInternal(newChild: Expression): KnownNotNull = + copy(child = newChild) } -case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression +case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression { + override protected def withNewChildInternal(newChild: Expression): KnownFloatingPointNormalized = + copy(child = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index ac47020de4d46..79bbc103c92d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -140,6 +140,9 @@ case class CsvToStructs( override def inputTypes: Seq[AbstractDataType] = StringType :: Nil override def prettyName: String = "from_csv" + + override protected def withNewChildInternal(newChild: Expression): CsvToStructs = + copy(child = newChild) } /** @@ -197,6 +200,9 @@ case class SchemaOfCsv( } override def prettyName: String = "schema_of_csv" + + override protected def withNewChildInternal(newChild: Expression): SchemaOfCsv = + copy(child = newChild) } /** @@ -264,4 +270,7 @@ case class StructsToCsv( override def inputTypes: Seq[AbstractDataType] = StructType :: Nil override def prettyName: String = "to_csv" + + override protected def withNewChildInternal(newChild: Expression): StructsToCsv = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 355064e73dfab..ba9d458c0ae5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -251,6 +251,9 @@ case class DateAdd(startDate: Expression, days: Expression) } override def prettyName: String = "date_add" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DateAdd = copy(startDate = newLeft, days = newRight) } /** @@ -286,6 +289,9 @@ case class DateSub(startDate: Expression, days: Expression) } override def prettyName: String = "date_sub" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DateSub = copy(startDate = newLeft, days = newRight) } trait GetTimeField extends UnaryExpression @@ -323,6 +329,7 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None) extends Ge override def withTimeZone(timeZoneId: 
String): Hour = copy(timeZoneId = Option(timeZoneId)) override val func = DateTimeUtils.getHours override val funcName = "getHours" + override protected def withNewChildInternal(newChild: Expression): Hour = copy(child = newChild) } @ExpressionDescription( @@ -339,6 +346,7 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None) extends override def withTimeZone(timeZoneId: String): Minute = copy(timeZoneId = Option(timeZoneId)) override val func = DateTimeUtils.getMinutes override val funcName = "getMinutes" + override protected def withNewChildInternal(newChild: Expression): Minute = copy(child = newChild) } @ExpressionDescription( @@ -355,6 +363,8 @@ case class Second(child: Expression, timeZoneId: Option[String] = None) extends override def withTimeZone(timeZoneId: String): Second = copy(timeZoneId = Option(timeZoneId)) override val func = DateTimeUtils.getSeconds override val funcName = "getSeconds" + override protected def withNewChildInternal(newChild: Expression): Second = + copy(child = newChild) } case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = None) @@ -366,6 +376,8 @@ case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = No copy(timeZoneId = Option(timeZoneId)) override val func = DateTimeUtils.getSecondsWithFraction override val funcName = "getSecondsWithFraction" + override protected def withNewChildInternal(newChild: Expression): SecondWithFraction = + copy(child = newChild) } trait GetDateField extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { @@ -398,6 +410,8 @@ trait GetDateField extends UnaryExpression with ImplicitCastInputTypes with Null case class DayOfYear(child: Expression) extends GetDateField { override val func = DateTimeUtils.getDayInYear override val funcName = "getDayInYear" + override protected def withNewChildInternal(newChild: Expression): DayOfYear = + copy(child = newChild) } @ExpressionDescription( @@ -421,6 +435,9 @@ case class DateFromUnixDate(child: Expression) extends UnaryExpression defineCodeGen(ctx, ev, c => c) override def prettyName: String = "date_from_unix_date" + + override protected def withNewChildInternal(newChild: Expression): DateFromUnixDate = + copy(child = newChild) } @ExpressionDescription( @@ -444,6 +461,9 @@ case class UnixDate(child: Expression) extends UnaryExpression defineCodeGen(ctx, ev, c => c) override def prettyName: String = "unix_date" + + override protected def withNewChildInternal(newChild: Expression): UnixDate = + copy(child = newChild) } abstract class IntegralToTimestampBase extends UnaryExpression @@ -531,6 +551,9 @@ case class SecondsToTimestamp(child: Expression) extends UnaryExpression } override def prettyName: String = "timestamp_seconds" + + override protected def withNewChildInternal(newChild: Expression): SecondsToTimestamp = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -550,6 +573,9 @@ case class MillisToTimestamp(child: Expression) override def upScaleFactor: Long = MICROS_PER_MILLIS override def prettyName: String = "timestamp_millis" + + override protected def withNewChildInternal(newChild: Expression): MillisToTimestamp = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -569,6 +595,9 @@ case class MicrosToTimestamp(child: Expression) override def upScaleFactor: Long = 1L override def prettyName: String = "timestamp_micros" + + override protected def withNewChildInternal(newChild: Expression): MicrosToTimestamp = + copy(child = newChild) } abstract class TimestampToLongBase 
extends UnaryExpression @@ -608,6 +637,9 @@ case class UnixSeconds(child: Expression) extends TimestampToLongBase { override def scaleFactor: Long = MICROS_PER_SECOND override def prettyName: String = "unix_seconds" + + override protected def withNewChildInternal(newChild: Expression): UnixSeconds = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -625,6 +657,9 @@ case class UnixMillis(child: Expression) extends TimestampToLongBase { override def scaleFactor: Long = MICROS_PER_MILLIS override def prettyName: String = "unix_millis" + + override protected def withNewChildInternal(newChild: Expression): UnixMillis = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -642,6 +677,9 @@ case class UnixMicros(child: Expression) extends TimestampToLongBase { override def scaleFactor: Long = 1L override def prettyName: String = "unix_micros" + + override protected def withNewChildInternal(newChild: Expression): UnixMicros = + copy(child = newChild) } @ExpressionDescription( @@ -656,11 +694,15 @@ case class UnixMicros(child: Expression) extends TimestampToLongBase { case class Year(child: Expression) extends GetDateField { override val func = DateTimeUtils.getYear override val funcName = "getYear" + override protected def withNewChildInternal(newChild: Expression): Year = + copy(child = newChild) } case class YearOfWeek(child: Expression) extends GetDateField { override val func = DateTimeUtils.getWeekBasedYear override val funcName = "getWeekBasedYear" + override protected def withNewChildInternal(newChild: Expression): YearOfWeek = + copy(child = newChild) } @ExpressionDescription( @@ -675,6 +717,8 @@ case class YearOfWeek(child: Expression) extends GetDateField { case class Quarter(child: Expression) extends GetDateField { override val func = DateTimeUtils.getQuarter override val funcName = "getQuarter" + override protected def withNewChildInternal(newChild: Expression): Quarter = + copy(child = newChild) } @ExpressionDescription( @@ -689,6 +733,7 @@ case class Quarter(child: Expression) extends GetDateField { case class Month(child: Expression) extends GetDateField { override val func = DateTimeUtils.getMonth override val funcName = "getMonth" + override protected def withNewChildInternal(newChild: Expression): Month = copy(child = newChild) } @ExpressionDescription( @@ -703,6 +748,8 @@ case class Month(child: Expression) extends GetDateField { case class DayOfMonth(child: Expression) extends GetDateField { override val func = DateTimeUtils.getDayOfMonth override val funcName = "getDayOfMonth" + override protected def withNewChildInternal(newChild: Expression): DayOfMonth = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -719,6 +766,8 @@ case class DayOfMonth(child: Expression) extends GetDateField { case class DayOfWeek(child: Expression) extends GetDateField { override val func = DateTimeUtils.getDayOfWeek override val funcName = "getDayOfWeek" + override protected def withNewChildInternal(newChild: Expression): DayOfWeek = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -735,6 +784,8 @@ case class DayOfWeek(child: Expression) extends GetDateField { case class WeekDay(child: Expression) extends GetDateField { override val func = DateTimeUtils.getWeekDay override val funcName = "getWeekDay" + override protected def withNewChildInternal(newChild: Expression): WeekDay = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -751,6 +802,8 @@ case class WeekDay(child: Expression) extends GetDateField { case class WeekOfYear(child: 
Expression) extends GetDateField { override val func = DateTimeUtils.getWeekOfYear override val funcName = "getWeekOfYear" + override protected def withNewChildInternal(newChild: Expression): WeekOfYear = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -814,6 +867,10 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti override protected def formatString: Expression = right override protected def isParsing: Boolean = false + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DateFormatClass = + copy(left = newLeft, right = newRight) } /** @@ -859,6 +916,10 @@ case class ToUnixTimestamp( } override def prettyName: String = "to_unix_timestamp" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ToUnixTimestamp = + copy(timeExp = newLeft, format = newRight) } // scalastyle:off line.size.limit @@ -915,6 +976,10 @@ case class UnixTimestamp( } override def prettyName: String = "unix_timestamp" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): UnixTimestamp = + copy(timeExp = newLeft, format = newRight) } abstract class ToTimestamp @@ -1120,6 +1185,10 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ override protected def formatString: Expression = format override protected def isParsing: Boolean = false + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): FromUnixTime = + copy(sec = newLeft, format = newRight) } /** @@ -1152,6 +1221,9 @@ case class LastDay(startDate: Expression) } override def prettyName: String = "last_day" + + override protected def withNewChildInternal(newChild: Expression): LastDay = + copy(startDate = newChild) } /** @@ -1249,6 +1321,10 @@ case class NextDay( } override def prettyName: String = "next_day" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): NextDay = + copy(startDate = newLeft, dayOfWeek = newRight) } /** @@ -1292,6 +1368,10 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S }) } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): TimeAdd = + copy(start = newLeft, interval = newRight) } /** @@ -1305,6 +1385,8 @@ case class DatetimeSub( override def exprsReplaced: Seq[Expression] = Seq(start, interval) override def toString: String = s"$start - $interval" override def mkString(childrenString: Seq[String]): String = childrenString.mkString(" - ") + override protected def withNewChildInternal(newChild: Expression): DatetimeSub = + copy(child = newChild) } /** @@ -1367,6 +1449,10 @@ case class DateAddInterval( override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DateAddInterval = + copy(start = newLeft, interval = newRight) } sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { @@ -1447,6 +1533,9 @@ case class FromUTCTimestamp(left: Expression, right: Expression) extends UTCTime override val func = DateTimeUtils.fromUTCTime override val funcName: String = "fromUTCTime" override val prettyName: String = "from_utc_timestamp" + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): FromUTCTimestamp = + copy(left = newLeft, right = newRight) 
} /** @@ -1478,6 +1567,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) extends UTCTimest override val func = DateTimeUtils.toUTCTime override val funcName: String = "toUTCTime" override val prettyName: String = "to_utc_timestamp" + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ToUTCTimestamp = + copy(left = newLeft, right = newRight) } abstract class AddMonthsBase extends BinaryExpression with ImplicitCastInputTypes @@ -1517,6 +1609,10 @@ case class AddMonths(startDate: Expression, numMonths: Expression) extends AddMo override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) override def prettyName: String = "add_months" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): AddMonths = + copy(startDate = newLeft, numMonths = newRight) } // Adds the year-month interval to the date @@ -1528,6 +1624,10 @@ case class DateAddYMInterval(date: Expression, interval: Expression) extends Add override def toString: String = s"$left + $right" override def sql: String = s"${left.sql} + ${right.sql}" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DateAddYMInterval = + copy(date = newLeft, interval = newRight) } // Adds the year-month interval to the timestamp @@ -1562,6 +1662,10 @@ case class TimestampAddYMInterval( s"""$dtu.timestampAddMonths($micros, $months, $zid)""" }) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): TimestampAddYMInterval = + copy(timestamp = newLeft, interval = newRight) } /** @@ -1628,6 +1732,10 @@ case class MonthsBetween( } override def prettyName: String = "months_between" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): MonthsBetween = + copy(date1 = newFirst, date2 = newSecond, roundOff = newThird) } /** @@ -1672,6 +1780,9 @@ case class ParseToDate(left: Expression, format: Option[Expression], child: Expr override def flatArguments: Iterator[Any] = Iterator(left, format) override def prettyName: String = "to_date" + + override protected def withNewChildInternal(newChild: Expression): ParseToDate = + copy(child = newChild) } /** @@ -1714,6 +1825,9 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child: override def prettyName: String = "to_timestamp" override def dataType: DataType = TimestampType + + override protected def withNewChildInternal(newChild: Expression): ParseToTimestamp = + copy(child = newChild) } trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { @@ -1849,6 +1963,10 @@ case class TruncDate(date: Expression, format: Expression) (date: String, fmt: String) => s"truncDate($date, $fmt);" } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): TruncDate = + copy(date = newLeft, format = newRight) } /** @@ -1920,6 +2038,10 @@ case class TruncTimestamp( s"truncTimestamp($date, $fmt, $zid);" } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): TruncTimestamp = + copy(format = newLeft, timestamp = newRight) } /** @@ -1952,6 +2074,10 @@ case class DateDiff(endDate: Expression, startDate: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (end, start) => s"$end - $start") } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): 
DateDiff = + copy(endDate = newLeft, startDate = newRight) } /** @@ -1969,6 +2095,10 @@ private case class GetTimestamp( override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): GetTimestamp = + copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -2032,6 +2162,10 @@ case class MakeDate( } override def prettyName: String = "make_date" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): MakeDate = + copy(year = newFirst, month = newSecond, day = newThird) } // scalastyle:off line.size.limit @@ -2198,6 +2332,20 @@ case class MakeTimestamp( } override def prettyName: String = "make_timestamp" + +// override def children: Seq[Expression] = Seq(year, month, day, hour, min, sec) ++ timezone + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): MakeTimestamp = { + val timezoneOpt = if (timezone.isDefined) Some(newChildren(6)) else None + copy( + year = newChildren(0), + month = newChildren(1), + day = newChildren(2), + hour = newChildren(3), + min = newChildren(4), + sec = newChildren(5), + timezone = timezoneOpt) + } } object DatePart { @@ -2284,6 +2432,9 @@ case class DatePart(field: Expression, source: Expression, child: Expression) override def exprsReplaced: Seq[Expression] = Seq(field, source) override def prettyName: String = "date_part" + + override protected def withNewChildInternal(newChild: Expression): DatePart = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -2349,6 +2500,9 @@ case class Extract(field: Expression, source: Expression, child: Expression) override def mkString(childrenString: Seq[String]): String = { prettyName + childrenString.mkString("(", " FROM ", ")") } + + override protected def withNewChildInternal(newChild: Expression): Extract = + copy(child = newChild) } /** @@ -2401,6 +2555,10 @@ case class SubtractTimestamps( defineCodeGen(ctx, ev, (end, start) => s"new org.apache.spark.unsafe.types.CalendarInterval(0, 0, $end - $start)") } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): SubtractTimestamps = + copy(left = newLeft, right = newRight) } object SubtractTimestamps { @@ -2452,6 +2610,10 @@ case class SubtractDates( s"$dtu.subtractDates($leftDays, $rightDays)" }) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): SubtractDates = + copy(left = newLeft, right = newRight) } object SubtractDates { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index b987beda6407e..7165bca201a9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -40,6 +40,9 @@ case class UnscaledValue(child: Expression) extends UnaryExpression with NullInt override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") } + + override protected def withNewChildInternal(newChild: Expression): UnscaledValue = + copy(child = newChild) } /** @@ -89,6 +92,9 @@ case class MakeDecimal( |""".stripMargin }) } + + override protected def 
withNewChildInternal(newChild: Expression): MakeDecimal = + copy(child = newChild) } object MakeDecimal { @@ -111,6 +117,9 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def prettyName: String = "promote_precision" override def sql: String = child.sql override lazy val canonicalized: Expression = child.canonicalized + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) } /** @@ -145,6 +154,9 @@ case class CheckOverflow( override def toString: String = s"CheckOverflow($child, $dataType, $nullOnOverflow)" override def sql: String = child.sql + + override protected def withNewChildInternal(newChild: Expression): CheckOverflow = + copy(child = newChild) } // A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`. @@ -194,4 +206,7 @@ case class CheckOverflowInSum( override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)" override def sql: String = child.sql + + override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index f10ceea519cce..fef9bb338d834 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -118,6 +118,9 @@ case class UserDefinedGenerator( } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): UserDefinedGenerator = copy(children = newChildren) } /** @@ -227,6 +230,9 @@ case class Stack(children: Seq[Expression]) extends Generator { |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); """.stripMargin, isNull = FalseLiteral) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Stack = + copy(children = newChildren) } /** @@ -253,6 +259,9 @@ case class ReplicateRows(children: Seq[Expression]) extends Generator with Codeg InternalRow(fields: _*) } } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): ReplicateRows = copy(children = newChildren) } /** @@ -269,6 +278,9 @@ case class GeneratorOuter(child: Generator) extends UnaryExpression with Generat override def elementSchema: StructType = child.elementSchema override lazy val resolved: Boolean = false + + override protected def withNewChildInternal(newChild: Expression): GeneratorOuter = + copy(child = newChild.asInstanceOf[Generator]) } /** @@ -369,6 +381,8 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with // scalastyle:on line.size.limit case class Explode(child: Expression) extends ExplodeBase { override val position: Boolean = false + override protected def withNewChildInternal(newChild: Expression): Explode = + copy(child = newChild) } /** @@ -394,6 +408,8 @@ case class Explode(child: Expression) extends ExplodeBase { // scalastyle:on line.size.limit line.contains.tab case class PosExplode(child: Expression) extends ExplodeBase { override val position = true + override protected def withNewChildInternal(newChild: Expression): PosExplode = + copy(child = newChild) } /** @@ -445,4 +461,6 @@ case class Inline(child: Expression) extends UnaryExpression 
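UserDefinedGenerator, Stack and ReplicateRows above keep all of their children in a single Seq field, so the rebuild is just copy(children = newChildren). Toy version on the same invented Expr base (StackLike is invented):

case class StackLike(override val children: Seq[Expr]) extends Expr {
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): StackLike =
    copy(children = newChildren)
}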
with CollectionGene override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.genCode(ctx) } + + override protected def withNewChildInternal(newChild: Expression): Inline = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index bf28efabcd561..0dd82bed15082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -111,6 +111,8 @@ case class Cube( children: Seq[Expression]) extends BaseGroupingSets { override def groupingSets: Seq[Seq[Expression]] = groupingSetIndexes.map(_.map(children)) override def selectedGroupByExprs: Seq[Seq[Expression]] = BaseGroupingSets.cubeExprs(groupingSets) + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Cube = + copy(children = newChildren) } object Cube { @@ -125,6 +127,8 @@ case class Rollup( override def groupingSets: Seq[Seq[Expression]] = groupingSetIndexes.map(_.map(children)) override def selectedGroupByExprs: Seq[Seq[Expression]] = BaseGroupingSets.rollupExprs(groupingSets) + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Rollup = + copy(children = newChildren) } object Rollup { @@ -142,6 +146,9 @@ case class GroupingSets( // Includes the `userGivenGroupByExprs` in the children, which will be included in the final // GROUP BY expressions, so that `SELECT c ... GROUP BY (a, b, c) GROUPING SETS (a, b)` works. override def children: Seq[Expression] = flatGroupingSets ++ userGivenGroupByExprs + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): GroupingSets = + super.legacyWithNewChildren(newChildren).asInstanceOf[GroupingSets] } object GroupingSets { @@ -184,6 +191,8 @@ case class Grouping(child: Expression) extends Expression with Unevaluable AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) override def dataType: DataType = ByteType override def nullable: Boolean = false + override protected def withNewChildInternal(newChild: Expression): Grouping = + copy(child = newChild) } /** @@ -223,6 +232,8 @@ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Une override def dataType: DataType = GroupingID.dataType override def nullable: Boolean = false override def prettyName: String = "grouping_id" + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): GroupingID = + copy(groupByExprs = newChildren) } object GroupingID { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 9738559b6d67a..f23c1e56ce4e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -69,6 +69,8 @@ case class Md5(child: Expression) defineCodeGen(ctx, ev, c => s"UTF8String.fromString(${classOf[DigestUtils].getName}.md5Hex($c))") } + + override protected def withNewChildInternal(newChild: Expression): Md5 = copy(child = newChild) } /** @@ -152,6 +154,9 @@ case class Sha2(left: Expression, right: Expression) """ }) } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Sha2 = + copy(left = newLeft, right 
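GroupingSets above is the one node in this file that falls back to legacyWithNewChildren: its children are the concatenation of two Seqs (flatGroupingSets ++ userGivenGroupByExprs), so the positional mapping depends on runtime sizes, and the hunk reuses what is presumably the old reflection-based copy path instead. A hand-written alternative would have to split at the recorded boundary, roughly like this toy node (TwoLists is invented; the same fallback appears later for InitializeJavaBean):

case class TwoLists(flat: Seq[Expr], userGiven: Seq[Expr]) extends Expr {
  override def children: Seq[Expr] = flat ++ userGiven
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): TwoLists = {
    // split the rebuilt children back at the original boundary between the two lists
    val (newFlat, newUserGiven) = newChildren.splitAt(flat.size)
    copy(flat = newFlat, userGiven = newUserGiven)
  }
}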
= newRight) } /** @@ -182,6 +187,8 @@ case class Sha1(child: Expression) s"UTF8String.fromString(${classOf[DigestUtils].getName}.sha1Hex($c))" ) } + + override protected def withNewChildInternal(newChild: Expression): Sha1 = copy(child = newChild) } /** @@ -221,6 +228,8 @@ case class Crc32(child: Expression) """ }) } + + override protected def withNewChildInternal(newChild: Expression): Crc32 = copy(child = newChild) } @@ -598,6 +607,9 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpress override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { Murmur3HashFunction.hash(value, dataType, seed).toInt } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Murmur3Hash = + copy(children = newChildren) } object Murmur3HashFunction extends InterpretedHashFunction { @@ -638,6 +650,9 @@ case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpressio override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = { XxHash64Function.hash(value, dataType, seed) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): XxHash64 = + copy(children = newChildren) } object XxHash64Function extends InterpretedHashFunction { @@ -842,6 +857,9 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { |$code """.stripMargin } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): HiveHash = + copy(children = newChildren) } object HiveHashFunction extends InterpretedHashFunction { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index bbfdf7135824c..a0f9dc2f58b20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -103,6 +103,12 @@ case class LambdaFunction( lazy val bound: Boolean = arguments.forall(_.resolved) override def eval(input: InternalRow): Any = function.eval(input) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): LambdaFunction = + copy( + function = newChildren.head, + arguments = newChildren.tail.asInstanceOf[Seq[NamedExpression]]) } object LambdaFunction { @@ -219,6 +225,7 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expr nullSafeEval(inputRow, value) } } + } trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { @@ -289,6 +296,10 @@ case class ArrayTransform( } override def prettyName: String = "transform" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayTransform = + copy(argument = newLeft, function = newRight) } /** @@ -378,6 +389,10 @@ case class ArraySort( } override def prettyName: String = "array_sort" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArraySort = + copy(argument = newLeft, function = newRight) } object ArraySort { @@ -448,6 +463,10 @@ case class MapFilter( override def functionType: AbstractDataType = BooleanType override def prettyName: String = "map_filter" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): MapFilter = + copy(argument = newLeft, function = newRight) } /** @@ -513,6 +532,10 
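Murmur3Hash and XxHash64 above carry a non-child constructor parameter (the seed) next to their children; because the hunks rebuild them with copy(children = newChildren), the seed rides along unchanged. Toy version (HashLike is invented):

case class HashLike(override val children: Seq[Expr], seed: Int) extends Expr {
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): HashLike =
    copy(children = newChildren)  // seed is preserved by copy
}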
@@ case class ArrayFilter( } override def prettyName: String = "filter" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayFilter = + copy(argument = newLeft, function = newRight) } /** @@ -594,6 +617,10 @@ case class ArrayExists( } override def prettyName: String = "exists" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayExists = + copy(argument = newLeft, function = newRight) } object ArrayExists { @@ -670,6 +697,10 @@ case class ArrayForAll( } override def prettyName: String = "forall" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayForAll = + copy(argument = newLeft, function = newRight) } /** @@ -767,6 +798,10 @@ case class ArrayAggregate( override def second: Expression = zero override def third: Expression = merge override def fourth: Expression = finish + + override protected def withNewChildrenInternal(first: Expression, second: Expression, + third: Expression, fourth: Expression): ArrayAggregate = + copy(argument = first, zero = second, merge = third, finish = fourth) } /** @@ -802,7 +837,7 @@ case class TransformKeys( } @transient lazy val LambdaFunction( - _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + _, Seq(keyVar: NamedLambdaVariable, valueVar: NamedLambdaVariable), _) = function private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) @@ -821,6 +856,10 @@ case class TransformKeys( } override def prettyName: String = "transform_keys" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): TransformKeys = + copy(argument = newLeft, function = newRight) } /** @@ -852,7 +891,7 @@ case class TransformValues( } @transient lazy val LambdaFunction( - _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + _, Seq(keyVar: NamedLambdaVariable, valueVar: NamedLambdaVariable), _) = function override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val map = argumentValue.asInstanceOf[MapData] @@ -869,6 +908,10 @@ case class TransformValues( } override def prettyName: String = "transform_values" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): TransformValues = + copy(argument = newLeft, function = newRight) } /** @@ -1056,6 +1099,13 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) override def first: Expression = left override def second: Expression = right override def third: Expression = function + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): MapZipWith = + copy( + left = newFirst, + right = newSecond, + function = newThird) } // scalastyle:off line.size.limit @@ -1136,4 +1186,8 @@ case class ZipWith(left: Expression, right: Expression, function: Expression) override def first: Expression = left override def second: Expression = right override def third: Expression = function + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): ZipWith = + copy(left = newFirst, right = newSecond, function = newThird) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 
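The TransformKeys and TransformValues hunks above also change a destructuring pattern from (keyVar) :: (valueVar) :: Nil to Seq(keyVar, valueVar). The cons pattern only matches a List, and after a rebuild through LambdaFunction.withNewChildrenInternal the lambda's argument collection is presumably an IndexedSeq rather than a List; Seq(...) matches either shape. Small standalone illustration (PatternDemo is invented):

object PatternDemo extends App {
  val rebuilt: Seq[Int] = Vector(1, 2)   // e.g. lambda arguments rebuilt as an IndexedSeq
  rebuilt match {
    case a :: b :: Nil => println(s"cons pattern: $a, $b")  // only matches a List, not hit here
    case Seq(a, b)     => println(s"Seq pattern: $a, $b")   // matches any two-element Seq
    case _             => println("unexpected shape")
  }
}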
23cf0bcafbe10..4311b38bdc78c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -49,22 +49,40 @@ abstract class ExtractIntervalPart( } case class ExtractIntervalYears(child: Expression) - extends ExtractIntervalPart(child, IntegerType, getYears, "getYears") + extends ExtractIntervalPart(child, IntegerType, getYears, "getYears") { + override protected def withNewChildInternal(newChild: Expression): ExtractIntervalYears = + copy(child = newChild) +} case class ExtractIntervalMonths(child: Expression) - extends ExtractIntervalPart(child, ByteType, getMonths, "getMonths") + extends ExtractIntervalPart(child, ByteType, getMonths, "getMonths") { + override protected def withNewChildInternal(newChild: Expression): ExtractIntervalMonths = + copy(child = newChild) +} case class ExtractIntervalDays(child: Expression) - extends ExtractIntervalPart(child, IntegerType, getDays, "getDays") + extends ExtractIntervalPart(child, IntegerType, getDays, "getDays") { + override protected def withNewChildInternal(newChild: Expression): ExtractIntervalDays = + copy(child = newChild) +} case class ExtractIntervalHours(child: Expression) - extends ExtractIntervalPart(child, LongType, getHours, "getHours") + extends ExtractIntervalPart(child, LongType, getHours, "getHours") { + override protected def withNewChildInternal(newChild: Expression): ExtractIntervalHours = + copy(child = newChild) +} case class ExtractIntervalMinutes(child: Expression) - extends ExtractIntervalPart(child, ByteType, getMinutes, "getMinutes") + extends ExtractIntervalPart(child, ByteType, getMinutes, "getMinutes") { + override protected def withNewChildInternal(newChild: Expression): ExtractIntervalMinutes = + copy(child = newChild) +} case class ExtractIntervalSeconds(child: Expression) - extends ExtractIntervalPart(child, DecimalType(8, 6), getSeconds, "getSeconds") + extends ExtractIntervalPart(child, DecimalType(8, 6), getSeconds, "getSeconds") { + override protected def withNewChildInternal(newChild: Expression): ExtractIntervalSeconds = + copy(child = newChild) +} object ExtractIntervalPart { @@ -119,6 +137,10 @@ case class MultiplyInterval( if (failOnError) multiplyExact else multiply override protected def operationName: String = if (failOnError) "multiplyExact" else "multiply" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): MultiplyInterval = + copy(interval = newLeft, num = newRight) } case class DivideInterval( @@ -131,6 +153,10 @@ case class DivideInterval( if (failOnError) divideExact else divide override protected def operationName: String = if (failOnError) "divideExact" else "divide" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DivideInterval = + copy(interval = newLeft, num = newRight) } // scalastyle:off line.size.limit @@ -251,6 +277,19 @@ case class MakeInterval( } override def prettyName: String = "make_interval" + + // Seq(years, months, weeks, days, hours, mins, secs) + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): MakeInterval = + copy( + years = newChildren(0), + months = newChildren(1), + weeks = newChildren(2), + days = newChildren(3), + hours = newChildren(4), + mins = newChildren(5), + secs = newChildren(6) + ) } // Multiply an year-month interval by a numeric @@ -298,6 +337,10 @@ case class 
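The ExtractInterval* hunks above repeat an almost identical one-line override in every concrete case class rather than defining it once on ExtractIntervalPart. That is hard to avoid: copy(...) is only generated on each concrete case class, so the abstract parent cannot provide the rebuild generically (short of a reflective fallback). Toy version (PartLike/YearsLike are invented):

abstract class PartLike(val child: Expr) extends Expr {
  override def children: Seq[Expr] = Seq(child)
  // no withNewChildrenInternal here: there is no generic copy() to call
}

case class YearsLike(override val child: Expr) extends PartLike(child) {
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): YearsLike =
    copy(child = newChildren(0))
}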
MultiplyYMInterval( } override def toString: String = s"($left * $right)" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): MultiplyYMInterval = + copy(interval = newLeft, num = newRight) } // Multiply a day-time interval by a numeric @@ -340,6 +383,10 @@ case class MultiplyDTInterval( } override def toString: String = s"($left * $right)" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): MultiplyDTInterval = + copy(interval = newLeft, num = newRight) } // Divide an year-month interval by a numeric @@ -394,6 +441,10 @@ case class DivideYMInterval( } override def toString: String = s"($left / $right)" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DivideYMInterval = + copy(interval = newLeft, num = newRight) } // Divide a day-time interval by a numeric @@ -437,4 +488,8 @@ case class DivideDTInterval( } override def toString: String = s"($left / $right)" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DivideDTInterval = + copy(interval = newLeft, num = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index b217110f075a7..6a56bbf1916bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -335,6 +335,10 @@ case class GetJsonObject(json: Expression, path: Expression) false } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): GetJsonObject = + copy(json = newLeft, path = newRight) } // scalastyle:off line.size.limit line.contains.tab @@ -498,6 +502,9 @@ case class JsonTuple(children: Seq[Expression]) generator.copyCurrentStructure(parser) } } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): JsonTuple = + copy(children = newChildren) } /** @@ -609,6 +616,9 @@ case class JsonToStructs( } override def prettyName: String = "from_json" + + override protected def withNewChildInternal(newChild: Expression): JsonToStructs = + copy(child = newChild) } /** @@ -731,6 +741,9 @@ case class StructsToJson( override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil override def prettyName: String = "to_json" + + override protected def withNewChildInternal(newChild: Expression): StructsToJson = + copy(child = newChild) } /** @@ -805,6 +818,9 @@ case class SchemaOfJson( } override def prettyName: String = "schema_of_json" + + override protected def withNewChildInternal(newChild: Expression): SchemaOfJson = + copy(child = newChild) } /** @@ -874,6 +890,9 @@ case class LengthOfJsonArray(child: Expression) extends UnaryExpression } length } + + override protected def withNewChildInternal(newChild: Expression): LengthOfJsonArray = + copy(child = newChild) } /** @@ -943,4 +962,7 @@ case class JsonObjectKeys(child: Expression) extends UnaryExpression with Codege } new GenericArrayData(arrayBufferOfKeys.toArray) } + + override protected def withNewChildInternal(newChild: Expression): JsonObjectKeys = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 3b58f3d868d3c..516eeb9929e80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -187,7 +187,9 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") """, since = "1.4.0", group = "math_funcs") -case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") +case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") { + override protected def withNewChildInternal(newChild: Expression): Acos = copy(child = newChild) +} @ExpressionDescription( usage = """ @@ -203,7 +205,9 @@ case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS" """, since = "1.4.0", group = "math_funcs") -case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") +case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") { + override protected def withNewChildInternal(newChild: Expression): Asin = copy(child = newChild) +} @ExpressionDescription( usage = """ @@ -217,7 +221,9 @@ case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN" """, since = "1.4.0", group = "math_funcs") -case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") +case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") { + override protected def withNewChildInternal(newChild: Expression): Atan = copy(child = newChild) +} @ExpressionDescription( usage = "_FUNC_(expr) - Returns the cube root of `expr`.", @@ -228,7 +234,9 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN" """, since = "1.4.0", group = "math_funcs") -case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") +case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") { + override protected def withNewChildInternal(newChild: Expression): Cbrt = copy(child = newChild) +} @ExpressionDescription( usage = "_FUNC_(expr) - Returns the smallest integer not smaller than `expr`.", @@ -267,6 +275,8 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") } } + + override protected def withNewChildInternal(newChild: Expression): Ceil = copy(child = newChild) } @ExpressionDescription( @@ -285,7 +295,9 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" """, since = "1.4.0", group = "math_funcs") -case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") +case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") { + override protected def withNewChildInternal(newChild: Expression): Cos = copy(child = newChild) +} @ExpressionDescription( usage = """ @@ -303,7 +315,9 @@ case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") """, since = "1.4.0", group = "math_funcs") -case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") +case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") { + override protected def withNewChildInternal(newChild: Expression): Cosh = copy(child = newChild) +} @ExpressionDescription( usage = """ @@ -324,6 +338,7 @@ case class Acosh(child: Expression) defineCodeGen(ctx, ev, c => s"java.lang.StrictMath.log($c + 
java.lang.Math.sqrt($c * $c - 1.0))") } + override protected def withNewChildInternal(newChild: Expression): Acosh = copy(child = newChild) } /** @@ -372,6 +387,10 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre """ ) } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + copy(numExpr = newFirst, fromBaseExpr = newSecond, toBaseExpr = newThird) } @ExpressionDescription( @@ -387,6 +406,7 @@ case class Exp(child: Expression) extends UnaryMathExpression(StrictMath.exp, "E override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"java.lang.StrictMath.exp($c)") } + override protected def withNewChildInternal(newChild: Expression): Exp = copy(child = newChild) } @ExpressionDescription( @@ -402,6 +422,7 @@ case class Expm1(child: Expression) extends UnaryMathExpression(StrictMath.expm1 override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"java.lang.StrictMath.expm1($c)") } + override protected def withNewChildInternal(newChild: Expression): Expm1 = copy(child = newChild) } @ExpressionDescription( @@ -441,6 +462,8 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") } } + + override protected def withNewChildInternal(newChild: Expression): Floor = copy(child = newChild) } object Factorial { @@ -514,6 +537,9 @@ case class Factorial(child: Expression) """ }) } + + override protected def withNewChildInternal(newChild: Expression): Factorial = + copy(child = newChild) } @ExpressionDescription( @@ -527,6 +553,7 @@ case class Factorial(child: Expression) group = "math_funcs") case class Log(child: Expression) extends UnaryLogExpression(StrictMath.log, "LOG") { override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("ln") + override protected def withNewChildInternal(newChild: Expression): Log = copy(child = newChild) } @ExpressionDescription( @@ -551,6 +578,7 @@ case class Log2(child: Expression) """ ) } + override protected def withNewChildInternal(newChild: Expression): Log2 = copy(child = newChild) } @ExpressionDescription( @@ -562,7 +590,9 @@ case class Log2(child: Expression) """, since = "1.4.0", group = "math_funcs") -case class Log10(child: Expression) extends UnaryLogExpression(StrictMath.log10, "LOG10") +case class Log10(child: Expression) extends UnaryLogExpression(StrictMath.log10, "LOG10") { + override protected def withNewChildInternal(newChild: Expression): Log10 = copy(child = newChild) +} @ExpressionDescription( usage = "_FUNC_(expr) - Returns log(1 + `expr`).", @@ -575,6 +605,7 @@ case class Log10(child: Expression) extends UnaryLogExpression(StrictMath.log10, group = "math_funcs") case class Log1p(child: Expression) extends UnaryLogExpression(StrictMath.log1p, "LOG1P") { protected override val yAsymptote: Double = -1.0 + override protected def withNewChildInternal(newChild: Expression): Log1p = copy(child = newChild) } // scalastyle:off line.size.limit @@ -591,6 +622,7 @@ case class Log1p(child: Expression) extends UnaryLogExpression(StrictMath.log1p, case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { override def funcName: String = "rint" override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("rint") + override protected def withNewChildInternal(newChild: Expression): Rint = 
copy(child = newChild) } @ExpressionDescription( @@ -602,7 +634,9 @@ case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND """, since = "1.4.0", group = "math_funcs") -case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") +case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") { + override protected def withNewChildInternal(newChild: Expression): Signum = copy(child = newChild) +} @ExpressionDescription( usage = "_FUNC_(expr) - Returns the sine of `expr`, as if computed by `java.lang.Math._FUNC_`.", @@ -617,7 +651,9 @@ case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "S """, since = "1.4.0", group = "math_funcs") -case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") +case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") { + override protected def withNewChildInternal(newChild: Expression): Sin = copy(child = newChild) +} @ExpressionDescription( usage = """ @@ -634,7 +670,9 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") """, since = "1.4.0", group = "math_funcs") -case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") +case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") { + override protected def withNewChildInternal(newChild: Expression): Sinh = copy(child = newChild) +} @ExpressionDescription( usage = """ @@ -656,6 +694,7 @@ case class Asinh(child: Expression) s"$c == Double.NEGATIVE_INFINITY ? Double.NEGATIVE_INFINITY : " + s"java.lang.StrictMath.log($c + java.lang.Math.sqrt($c * $c + 1.0))") } + override protected def withNewChildInternal(newChild: Expression): Asinh = copy(child = newChild) } @ExpressionDescription( @@ -667,7 +706,9 @@ case class Asinh(child: Expression) """, since = "1.1.1", group = "math_funcs") -case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") +case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") { + override protected def withNewChildInternal(newChild: Expression): Sqrt = copy(child = newChild) +} @ExpressionDescription( usage = """ @@ -684,7 +725,9 @@ case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT" """, since = "1.4.0", group = "math_funcs") -case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") +case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") { + override protected def withNewChildInternal(newChild: Expression): Tan = copy(child = newChild) +} @ExpressionDescription( usage = """ @@ -706,6 +749,7 @@ case class Cot(child: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"${ev.value} = 1 / java.lang.Math.tan($c);") } + override protected def withNewChildInternal(newChild: Expression): Cot = copy(child = newChild) } @ExpressionDescription( @@ -724,7 +768,9 @@ case class Cot(child: Expression) """, since = "1.4.0", group = "math_funcs") -case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") +case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") { + override protected def withNewChildInternal(newChild: Expression): Tanh = copy(child = newChild) +} @ExpressionDescription( usage = """ @@ -747,6 +793,7 @@ case class Atanh(child: Expression) defineCodeGen(ctx, ev, c => s"0.5 * (java.lang.StrictMath.log1p($c) - java.lang.StrictMath.log1p(- $c))") 
} + override protected def withNewChildInternal(newChild: Expression): Atanh = copy(child = newChild) } @ExpressionDescription( @@ -764,6 +811,8 @@ case class Atanh(child: Expression) group = "math_funcs") case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") { override def funcName: String = "toDegrees" + override protected def withNewChildInternal(newChild: Expression): ToDegrees = + copy(child = newChild) } @ExpressionDescription( @@ -781,6 +830,8 @@ case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegre group = "math_funcs") case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { override def funcName: String = "toRadians" + override protected def withNewChildInternal(newChild: Expression): ToRadians = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -811,6 +862,8 @@ case class Bin(child: Expression) defineCodeGen(ctx, ev, (c) => s"UTF8String.fromString(java.lang.Long.toBinaryString($c))") } + + override protected def withNewChildInternal(newChild: Expression): Bin = copy(child = newChild) } object Hex { @@ -923,6 +976,8 @@ case class Hex(child: Expression) }) }) } + + override protected def withNewChildInternal(newChild: Expression): Hex = copy(child = newChild) } /** @@ -958,6 +1013,8 @@ case class Unhex(child: Expression) """ }) } + + override protected def withNewChildInternal(newChild: Expression): Unhex = copy(child = newChild) } @@ -996,6 +1053,9 @@ case class Atan2(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -1012,6 +1072,8 @@ case class Pow(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.StrictMath.pow($c1, $c2)") } + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight) } @@ -1048,6 +1110,9 @@ case class ShiftLeft(left: Expression, right: Expression) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (left, right) => s"$left << $right") } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ShiftLeft = copy(left = newLeft, right = newRight) } @@ -1084,6 +1149,9 @@ case class ShiftRight(left: Expression, right: Expression) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (left, right) => s"$left >> $right") } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ShiftRight = copy(left = newLeft, right = newRight) } @@ -1120,6 +1188,10 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right") } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ShiftRightUnsigned = + copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -1132,7 +1204,10 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) since = "1.4.0", group = 
"math_funcs") case class Hypot(left: Expression, right: Expression) - extends BinaryMathExpression(math.hypot, "HYPOT") + extends BinaryMathExpression(math.hypot, "HYPOT") { + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Hypot = + copy(left = newLeft, right = newRight) +} /** @@ -1190,6 +1265,9 @@ case class Logarithm(left: Expression, right: Expression) """) } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Logarithm = copy(left = newLeft, right = newRight) } /** @@ -1387,6 +1465,8 @@ case class Round(child: Expression, scale: Expression) extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP") with Serializable with ImplicitCastInputTypes { def this(child: Expression) = this(child, Literal(0)) + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Round = + copy(child = newLeft, scale = newRight) } /** @@ -1409,6 +1489,8 @@ case class BRound(child: Expression, scale: Expression) extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN") with Serializable with ImplicitCastInputTypes { def this(child: Expression) = this(child, Literal(0)) + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): BRound = copy(child = newLeft, scale = newRight) } object WidthBucket { @@ -1511,4 +1593,8 @@ case class WidthBucket( override def second: Expression = minValue override def third: Expression = maxValue override def fourth: Expression = numBucket + + override protected def withNewChildrenInternal( + first: Expression, second: Expression, third: Expression, fourth: Expression): WidthBucket = + copy(value = first, minValue = second, maxValue = third, numBucket = fourth) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 6b3b949af24cf..9e854cf5fd891 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -51,6 +51,9 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { | ${ev.value} = $c; """.stripMargin) } + + override protected def withNewChildInternal(newChild: Expression): PrintToStderr = + copy(child = newChild) } /** @@ -100,6 +103,9 @@ case class RaiseError(child: Expression, dataType: DataType) value = JavaCode.defaultLiteral(dataType) ) } + + override protected def withNewChildInternal(newChild: Expression): RaiseError = + copy(child = newChild) } object RaiseError { @@ -133,6 +139,9 @@ case class AssertTrue(left: Expression, right: Expression, child: Expression) override def flatArguments: Iterator[Any] = Iterator(left, right) override def exprsReplaced: Seq[Expression] = Seq(left, right) + + override protected def withNewChildInternal(newChild: Expression): AssertTrue = + copy(child = newChild) } object AssertTrue { @@ -268,4 +277,6 @@ case class TypeOf(child: Expression) extends UnaryExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, _ => s"""UTF8String.fromString(${child.dataType.catalogString})""") } + + override protected def withNewChildInternal(newChild: Expression): TypeOf = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index e73b024dd18c2..b73a189027bfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -226,6 +226,9 @@ case class Alias(child: Expression, name: String)( if (qualifier.nonEmpty) qualifier.map(quoteIfNeeded).mkString(".") + "." else "" s"${child.sql} AS $qualifierPrefix${quoteIfNeeded(name)}" } + + override protected def withNewChildInternal(newChild: Expression): Alias = + copy(child = newChild)(exprId, qualifier, explicitMetadata, nonInheritableMetadataKeys) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index d508129c190b9..2c2df6bf438b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -120,6 +120,9 @@ case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpress |} while (false); """.stripMargin) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Coalesce = + copy(children = newChildren) } @@ -141,6 +144,8 @@ case class IfNull(left: Expression, right: Expression, child: Expression) override def flatArguments: Iterator[Any] = Iterator(left, right) override def exprsReplaced: Seq[Expression] = Seq(left, right) + + override protected def withNewChildInternal(newChild: Expression): IfNull = copy(child = newChild) } @@ -162,6 +167,8 @@ case class NullIf(left: Expression, right: Expression, child: Expression) override def flatArguments: Iterator[Any] = Iterator(left, right) override def exprsReplaced: Seq[Expression] = Seq(left, right) + + override protected def withNewChildInternal(newChild: Expression): NullIf = copy(child = newChild) } @@ -182,6 +189,8 @@ case class Nvl(left: Expression, right: Expression, child: Expression) extends R override def flatArguments: Iterator[Any] = Iterator(left, right) override def exprsReplaced: Seq[Expression] = Seq(left, right) + + override protected def withNewChildInternal(newChild: Expression): Nvl = copy(child = newChild) } @@ -205,6 +214,8 @@ case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression, child: override def flatArguments: Iterator[Any] = Iterator(expr1, expr2, expr3) override def exprsReplaced: Seq[Expression] = Seq(expr1, expr2, expr3) + + override protected def withNewChildInternal(newChild: Expression): Nvl2 = copy(child = newChild) } @@ -249,6 +260,8 @@ case class IsNaN(child: Expression) extends UnaryExpression ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) } } + + override protected def withNewChildInternal(newChild: Expression): IsNaN = copy(child = newChild) } /** @@ -311,6 +324,9 @@ case class NaNvl(left: Expression, right: Expression) }""") } } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): NaNvl = + copy(left = newLeft, right = newRight) } @@ -339,6 +355,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { } override def sql: String = s"(${child.sql} IS NULL)" + + override protected def withNewChildInternal(newChild: Expression): IsNull = copy(child = newChild) } @@ -374,6 +392,9 @@ 
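The Alias hunk just above has to spell out the second, curried argument list (exprId, qualifier, explicitMetadata, nonInheritableMetadataKeys) because copy on a multi-parameter-list case class only defaults the first list. Toy version (TaggedNode and tag are invented):

case class TaggedNode(child: Expr, name: String)(val tag: Long) extends Expr {
  override def children: Seq[Expr] = Seq(child)
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): TaggedNode =
    copy(child = newChildren(0))(tag)  // the second argument list must be passed explicitly
}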
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } override def sql: String = s"(${child.sql} IS NOT NULL)" + + override protected def withNewChildInternal(newChild: Expression): IsNotNull = + copy(child = newChild) } @@ -466,4 +487,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; """.stripMargin, isNull = FalseLiteral) } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): AtLeastNNonNulls = copy(children = newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 5be521683381d..5ae0cef7b400c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.{Method, Modifier} import scala.collection.JavaConverters._ -import scala.collection.mutable.{Builder, IndexedSeq, WrappedArray} +import scala.collection.mutable.{Builder, WrappedArray} import scala.reflect.ClassTag import scala.util.{Properties, Try} @@ -279,6 +279,9 @@ case class StaticInvoke( """ ev.copy(code = code) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(arguments = newChildren) } /** @@ -400,6 +403,9 @@ case class Invoke( } override def toString: String = s"$targetObject.$functionName" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Invoke = + copy(targetObject = newChildren.head, arguments = newChildren.tail) } object NewInstance { @@ -506,6 +512,9 @@ case class NewInstance( } override def toString: String = s"newInstance($cls)" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): NewInstance = + copy(arguments = newChildren) } /** @@ -543,6 +552,9 @@ case class UnwrapOption( """ ev.copy(code = code) } + + override protected def withNewChildInternal(newChild: Expression): UnwrapOption = + copy(child = newChild) } /** @@ -573,6 +585,9 @@ case class WrapOption(child: Expression, optType: DataType) """ ev.copy(code = code, isNull = FalseLiteral) } + + override protected def withNewChildInternal(newChild: Expression): WrapOption = + copy(child = newChild) } object LambdaVariable { @@ -659,6 +674,9 @@ case class UnresolvedMapObjects( override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse { throw QueryExecutionErrors.customCollectionClsNotResolvedError } + + override protected def withNewChildInternal(newChild: Expression): UnresolvedMapObjects = + copy(child = newChild) } object MapObjects { @@ -1025,6 +1043,13 @@ case class MapObjects private( """ ev.copy(code = code, isNull = genInputData.isNull) } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + copy( + loopVar = newFirst.asInstanceOf[LambdaVariable], + lambdaFunction = newSecond, + inputData = newThird) } /** @@ -1044,6 +1069,9 @@ case class UnresolvedCatalystToExternalMap( override lazy val resolved = false override def dataType: DataType = ObjectType(collClass) + + override protected def withNewChildInternal( + newChild: Expression): 
UnresolvedCatalystToExternalMap = copy(child = newChild) } object CatalystToExternalMap { @@ -1214,6 +1242,15 @@ case class CatalystToExternalMap private( """ ev.copy(code = code, isNull = genInputData.isNull) } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CatalystToExternalMap = + copy( + keyLoopVar = newChildren(0).asInstanceOf[LambdaVariable], + keyLambdaFunction = newChildren(1), + valueLoopVar = newChildren(2).asInstanceOf[LambdaVariable], + valueLambdaFunction = newChildren(3), + inputData = newChildren(4)) } object ExternalMapToCatalyst { @@ -1437,6 +1474,15 @@ case class ExternalMapToCatalyst private( """ ev.copy(code = code, isNull = inputMap.isNull) } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): ExternalMapToCatalyst = + copy( + keyLoopVar = newChildren(0).asInstanceOf[LambdaVariable], + keyConverter = newChildren(1), + valueLoopVar = newChildren(2).asInstanceOf[LambdaVariable], + valueConverter = newChildren(3), + inputData = newChildren(4)) } /** @@ -1487,6 +1533,9 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) """.stripMargin ev.copy(code = code, isNull = FalseLiteral) } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CreateExternalRow = copy(children = newChildren) } /** @@ -1516,6 +1565,9 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) } override def dataType: DataType = BinaryType + + override protected def withNewChildInternal(newChild: Expression): EncodeUsingSerializer = + copy(child = newChild) } /** @@ -1548,6 +1600,9 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B } override def dataType: DataType = ObjectType(tag.runtimeClass) + + override protected def withNewChildInternal(newChild: Expression): DecodeUsingSerializer[T] = + copy(child = newChild) } /** @@ -1629,6 +1684,10 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp """.stripMargin ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value) } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): InitializeJavaBean = + super.legacyWithNewChildren(newChildren).asInstanceOf[InitializeJavaBean] } /** @@ -1676,6 +1735,9 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) """ ev.copy(code = code, isNull = FalseLiteral, value = childGen.value) } + + override protected def withNewChildInternal(newChild: Expression): AssertNotNull = + copy(child = newChild) } /** @@ -1727,6 +1789,9 @@ case class GetExternalRowField( """ ev.copy(code = code, isNull = FalseLiteral) } + + override protected def withNewChildInternal(newChild: Expression): GetExternalRowField = + copy(child = newChild) } /** @@ -1801,4 +1866,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) """ ev.copy(code = code, isNull = input.isNull) } + + override protected def withNewChildInternal(newChild: Expression): ValidateExternalType = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 33eb120e009ed..d9d0643a9130c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -322,6 +322,8 @@ case 
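MapObjects, CatalystToExternalMap and ExternalMapToCatalyst above (like GeneratorOuter earlier) keep some children in fields whose static type is narrower than Expression (LambdaVariable, Generator), so the rebuilt child has to be cast back with asInstanceOf. Toy version (Lambda, LambdaVar and MapLike are invented):

trait Lambda extends Expr

case class LambdaVar(name: String) extends Lambda {
  override def children: Seq[Expr] = Nil
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): LambdaVar = this
}

case class MapLike(loopVar: Lambda, body: Expr, input: Expr) extends Expr {
  override def children: Seq[Expr] = Seq(loopVar, body, input)
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): MapLike =
    copy(
      loopVar = newChildren(0).asInstanceOf[Lambda],  // field type is narrower than Expr
      body = newChildren(1),
      input = newChildren(2))
}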
class Not(child: Expression) } override def sql: String = s"(NOT ${child.sql})" + + override protected def withNewChildInternal(newChild: Expression): Not = copy(child = newChild) } /** @@ -379,6 +381,9 @@ case class InSubquery(values: Seq[Expression], query: ListQuery) override def nullable: Boolean = children.exists(_.nullable) override def toString: String = s"$value IN ($query)" override def sql: String = s"(${value.sql} IN (${query.sql}))" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): InSubquery = + copy(values = newChildren.dropRight(1), query = newChildren.last.asInstanceOf[ListQuery]) } @@ -520,6 +525,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val listSQL = list.map(_.sql).mkString(", ") s"($valueSQL IN ($listSQL))" } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): In = + copy(value = newChildren.head, list = newChildren.tail) } /** @@ -625,6 +633,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with .mkString(", ") s"($valueSQL IN ($listSQL))" } + + override protected def withNewChildInternal(newChild: Expression): InSet = copy(child = newChild) } @ExpressionDescription( @@ -708,6 +718,9 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with """) } } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): And = + copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -792,6 +805,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P """) } } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Or = + copy(left = newLeft, right = newRight) } @@ -877,6 +893,9 @@ case class EqualTo(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2)) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): EqualTo = copy(left = newLeft, right = newRight) } // TODO: although map type is not orderable, technically map type should be able to be used @@ -938,6 +957,10 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): EqualNullSafe = + copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -970,6 +993,9 @@ case class LessThan(left: Expression, right: Expression) override def symbol: String = "<" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -1002,6 +1028,9 @@ case class LessThanOrEqual(left: Expression, right: Expression) override def symbol: String = "<=" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -1034,6 +1063,9 @@ case class GreaterThan(left: Expression, right: Expression) override def symbol: 
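In above keeps one fixed child (the value) followed by a variable-length list, so the hunk rebuilds it with newChildren.head and newChildren.tail; InSubquery does the mirror image with dropRight and last. Toy version (InLike is invented):

case class InLike(value: Expr, list: Seq[Expr]) extends Expr {
  override def children: Seq[Expr] = value +: list
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): InLike =
    copy(value = newChildren.head, list = newChildren.tail)
}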
String = ">" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -1066,6 +1098,10 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) override def symbol: String = ">=" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): GreaterThanOrEqual = + copy(left = newLeft, right = newRight) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 0a4c6e27d51d9..d470cadff85b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -111,6 +111,8 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG { override def sql: String = { s"rand(${if (hideSeed) "" else child.sql})" } + + override protected def withNewChildInternal(newChild: Expression): Rand = copy(child = newChild) } object Rand { @@ -162,6 +164,8 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG { override def sql: String = { s"randn(${if (hideSeed) "" else child.sql})" } + + override protected def withNewChildInternal(newChild: Expression): Randn = copy(child = newChild) } object Randn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 9fdab350ceb95..13d00faea37f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -180,6 +180,9 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) }) } } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Like = + copy(left = newLeft, right = newRight) } sealed abstract class MultiLikeBase @@ -268,10 +271,14 @@ sealed abstract class LikeAllBase extends MultiLikeBase { case class LikeAll(child: Expression, patterns: Seq[UTF8String]) extends LikeAllBase { override def isNotSpecified: Boolean = false + override protected def withNewChildInternal(newChild: Expression): LikeAll = + copy(child = newChild) } case class NotLikeAll(child: Expression, patterns: Seq[UTF8String]) extends LikeAllBase { override def isNotSpecified: Boolean = true + override protected def withNewChildInternal(newChild: Expression): NotLikeAll = + copy(child = newChild) } /** @@ -324,10 +331,14 @@ sealed abstract class LikeAnyBase extends MultiLikeBase { case class LikeAny(child: Expression, patterns: Seq[UTF8String]) extends LikeAnyBase { override def isNotSpecified: Boolean = false + override protected def withNewChildInternal(newChild: Expression): LikeAny = + copy(child = newChild) } case class NotLikeAny(child: Expression, patterns: Seq[UTF8String]) extends LikeAnyBase { override def isNotSpecified: Boolean = true + override protected def withNewChildInternal(newChild: Expression): NotLikeAny = + copy(child = newChild) } // 
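Note the mix of declared return types in the comparison hunks: LessThan, LessThanOrEqual and GreaterThan declare Expression, while GreaterThanOrEqual (like most other hunks in the patch) declares the concrete class. Both are valid, since an override may narrow the return type covariantly; the narrower form simply gives callers a more precise result. Toy illustration (LtLike/GteLike are invented):

case class LtLike(left: Expr, right: Expr) extends Expr {
  override def children: Seq[Expr] = Seq(left, right)
  // wide return type, as in the LessThan/GreaterThan hunks
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr =
    copy(left = newChildren(0), right = newChildren(1))
}

case class GteLike(left: Expr, right: Expr) extends Expr {
  override def children: Seq[Expr] = Seq(left, right)
  // covariantly narrowed return type, as in the GreaterThanOrEqual hunk
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): GteLike =
    copy(left = newChildren(0), right = newChildren(1))
}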
scalastyle:off line.contains.tab @@ -409,6 +420,9 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress }) } } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): RLike = + copy(left = newLeft, right = newRight) } @@ -467,6 +481,10 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) } override def prettyName: String = "split" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): StringSplit = + copy(str = newFirst, regex = newSecond, limit = newThird) } @@ -622,6 +640,10 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def second: Expression = regexp override def third: Expression = rep override def fourth: Expression = pos + + override protected def withNewChildrenInternal( + first: Expression, second: Expression, third: Expression, fourth: Expression): RegExpReplace = + copy(subject = first, regexp = second, rep = third, pos = fourth) } object RegExpReplace { @@ -765,6 +787,10 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio }""" }) } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): RegExpExtract = + copy(subject = newFirst, regexp = newSecond, idx = newThird) } /** @@ -868,4 +894,8 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres """ }) } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): RegExpExtractAll = + copy(subject = newFirst, regexp = newSecond, idx = newThird) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 714f1d6dc4bfc..3d5f812af9c2e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -227,6 +227,9 @@ case class ConcatWs(children: Seq[Expression]) """) } } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ConcatWs = + copy(children = newChildren) } /** @@ -366,6 +369,9 @@ case class Elt( |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Elt = + copy(children = newChildren) } @@ -403,6 +409,8 @@ case class Upper(child: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") } + + override protected def withNewChildInternal(newChild: Expression): Upper = copy(child = newChild) } /** @@ -430,6 +438,8 @@ case class Lower(child: Expression) override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("lower") + + override protected def withNewChildInternal(newChild: Expression): Lower = copy(child = newChild) } /** A base trait for functions that compare two strings, returning a boolean. 
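RegExpReplace above, like WidthBucket, Overlay and ArrayAggregate elsewhere in the patch, overrides a typed four-argument variant of withNewChildrenInternal rather than the IndexedSeq form, presumably supplied by a fixed-arity helper trait in the real code. On the toy Expr base it could be adapted like this (Quaternary and BucketLike are invented):

trait Quaternary extends Expr {
  def first: Expr
  def second: Expr
  def third: Expr
  def fourth: Expr
  final override def children: Seq[Expr] = Seq(first, second, third, fourth)
  // adapt the generic IndexedSeq form to a typed four-argument form once, here
  final override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr =
    withNewChildrenInternal(newChildren(0), newChildren(1), newChildren(2), newChildren(3))
  protected def withNewChildrenInternal(
      newFirst: Expr, newSecond: Expr, newThird: Expr, newFourth: Expr): Expr
}

case class BucketLike(value: Expr, min: Expr, max: Expr, num: Expr) extends Quaternary {
  override def first: Expr = value
  override def second: Expr = min
  override def third: Expr = max
  override def fourth: Expr = num
  override protected def withNewChildrenInternal(
      newFirst: Expr, newSecond: Expr, newThird: Expr, newFourth: Expr): BucketLike =
    copy(value = newFirst, min = newSecond, max = newThird, num = newFourth)
}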
*/ @@ -454,6 +464,8 @@ case class Contains(left: Expression, right: Expression) extends StringPredicate override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") } + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Contains = copy(left = newLeft, right = newRight) } /** @@ -464,6 +476,8 @@ case class StartsWith(left: Expression, right: Expression) extends StringPredica override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") } + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): StartsWith = copy(left = newLeft, right = newRight) } /** @@ -474,6 +488,8 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") } + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight) } /** @@ -522,6 +538,10 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp override def third: Expression = replaceExpr override def prettyName: String = "replace" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): StringReplace = + copy(srcExpr = newFirst, searchExpr = newSecond, replaceExpr = newThird) } object Overlay { @@ -634,6 +654,10 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len: override def second: Expression = replace override def third: Expression = pos override def fourth: Expression = len + + override protected def withNewChildrenInternal( + first: Expression, second: Expression, third: Expression, fourth: Expression): Overlay = + copy(input = first, replace = second, pos = third, len = fourth) } object StringTranslate { @@ -731,6 +755,10 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac override def second: Expression = matchingExpr override def third: Expression = replaceExpr override def prettyName: String = "translate" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): StringTranslate = + copy(srcExpr = newFirst, matchingExpr = newSecond, replaceExpr = newThird) } /** @@ -769,6 +797,9 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def dataType: DataType = IntegerType override def prettyName: String = "find_in_set" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): FindInSet = copy(left = newLeft, right = newRight) } trait String2TrimExpression extends Expression with ImplicitCastInputTypes { @@ -926,6 +957,11 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) srcString.trim(trimString) override val trimMethod: String = "trim" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy( + srcStr = newChildren.head, + trimStr = if (trimStr.isDefined) Some(newChildren.last) else None) } /** @@ -974,6 +1010,9 @@ case class StringTrimBoth(srcStr: Expression, trimStr: Option[Expression], child override def flatArguments: Iterator[Any] = Iterator(srcStr, trimStr) override def prettyName: String = "btrim" + + override protected def 
withNewChildInternal(newChild: Expression): StringTrimBoth = + copy(child = newChild) } object StringTrimLeft { @@ -1027,6 +1066,12 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None srcString.trimLeft(trimString) override val trimMethod: String = "trimLeft" + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): StringTrimLeft = + copy( + srcStr = newChildren.head, + trimStr = if (trimStr.isDefined) Some(newChildren.last) else None) } object StringTrimRight { @@ -1082,6 +1127,12 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non srcString.trimRight(trimString) override val trimMethod: String = "trimRight" + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): StringTrimRight = + copy( + srcStr = newChildren.head, + trimStr = if (trimStr.isDefined) Some(newChildren.last) else None) } /** @@ -1120,6 +1171,9 @@ case class StringInstr(str: Expression, substr: Expression) defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1") } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): StringInstr = copy(str = newLeft, substr = newRight) } /** @@ -1164,6 +1218,10 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)") } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): SubstringIndex = + copy(strExpr = newFirst, delimExpr = newSecond, countExpr = newThird) } /** @@ -1258,6 +1316,11 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("locate") + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): StringLocate = + copy(substr = newFirst, str = newSecond, start = newThird) + } /** @@ -1302,6 +1365,10 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera } override def prettyName: String = "lpad" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): StringLPad = + copy(str = newFirst, len = newSecond, pad = newThird) } /** @@ -1347,6 +1414,10 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera } override def prettyName: String = "rpad" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): StringRPad = + copy(str = newFirst, len = newSecond, pad = newThird) } object ParseUrl { @@ -1519,6 +1590,9 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge } } } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ParseUrl = + copy(children = newChildren) } /** @@ -1606,6 +1680,9 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC override def prettyName: String = getTagValue( FunctionRegistry.FUNC_ALIAS).getOrElse("format_string") + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): FormatString = FormatString(newChildren: _*) } /** @@ -1638,6 +1715,9 @@ case class InitCap(child: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { 
defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()") } + + override protected def withNewChildInternal(newChild: Expression): InitCap = + copy(child = newChild) } /** @@ -1669,6 +1749,9 @@ case class StringRepeat(str: Expression, times: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): StringRepeat = copy(str = newLeft, times = newRight) } /** @@ -1700,6 +1783,9 @@ case class StringSpace(child: Expression) } override def prettyName: String = "space" + + override protected def withNewChildInternal(newChild: Expression): StringSpace = + copy(child = newChild) } /** @@ -1767,6 +1853,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } }) } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Substring = + copy(str = newFirst, pos = newSecond, len = newThird) + } /** @@ -1791,6 +1882,8 @@ case class Right(str: Expression, len: Expression, child: Expression) extends Ru override def flatArguments: Iterator[Any] = Iterator(str, len) override def exprsReplaced: Seq[Expression] = Seq(str, len) + + override protected def withNewChildInternal(newChild: Expression): Right = copy(child = newChild) } /** @@ -1814,6 +1907,7 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run override def flatArguments: Iterator[Any] = Iterator(str, len) override def exprsReplaced: Seq[Expression] = Seq(str, len) + override protected def withNewChildInternal(newChild: Expression): Left = copy(child = newChild) } /** @@ -1851,6 +1945,8 @@ case class Length(child: Expression) case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") } } + + override protected def withNewChildInternal(newChild: Expression): Length = copy(child = newChild) } /** @@ -1883,6 +1979,9 @@ case class BitLength(child: Expression) } override def prettyName: String = "bit_length" + + override protected def withNewChildInternal(newChild: Expression): BitLength = + copy(child = newChild) } /** @@ -1916,6 +2015,9 @@ case class OctetLength(child: Expression) } override def prettyName: String = "octet_length" + + override protected def withNewChildInternal(newChild: Expression): OctetLength = + copy(child = newChild) } /** @@ -1943,6 +2045,9 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres nullSafeCodeGen(ctx, ev, (left, right) => s"${ev.value} = $left.levenshteinDistance($right);") } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Levenshtein = copy(left = newLeft, right = newRight) } /** @@ -1969,6 +2074,9 @@ case class SoundEx(child: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"$c.soundex()") } + + override protected def withNewChildInternal(newChild: Expression): SoundEx = + copy(child = newChild) } /** @@ -2012,6 +2120,8 @@ case class Ascii(child: Expression) } """}) } + + override protected def withNewChildInternal(newChild: Expression): Ascii = copy(child = newChild) } /** @@ -2060,6 +2170,8 @@ case class Chr(child: Expression) """ }) } + + override protected def withNewChildInternal(newChild: Expression): Chr = copy(child = newChild) } /** @@ -2090,6 +2202,8 @@ case class Base64(child: Expression) ${classOf[CommonsBase64].getName}.encodeBase64($child)); """}) } + + 
override protected def withNewChildInternal(newChild: Expression): Base64 = copy(child = newChild) } /** @@ -2119,6 +2233,9 @@ case class UnBase64(child: Expression) ${ev.value} = ${classOf[CommonsBase64].getName}.decodeBase64($child.toString()); """}) } + + override protected def withNewChildInternal(newChild: Expression): UnBase64 = + copy(child = newChild) } object Decode { @@ -2178,6 +2295,8 @@ case class Decode(params: Seq[Expression], child: Expression) extends RuntimeRep override def flatArguments: Iterator[Any] = Iterator(params) override def exprsReplaced: Seq[Expression] = params + + override protected def withNewChildInternal(newChild: Expression): Decode = copy(child = newChild) } /** @@ -2219,6 +2338,10 @@ case class StringDecode(bin: Expression, charset: Expression) } """) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): StringDecode = + copy(bin = newLeft, charset = newRight) } /** @@ -2259,6 +2382,9 @@ case class Encode(value: Expression, charset: Expression) org.apache.spark.unsafe.Platform.throwException(e); }""") } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Encode = copy(value = newLeft, charset = newRight) } /** @@ -2439,6 +2565,9 @@ case class FormatNumber(x: Expression, d: Expression) } override def prettyName: String = "format_number" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): FormatNumber = copy(x = newLeft, d = newRight) } /** @@ -2509,4 +2638,9 @@ case class Sentences( } new GenericArrayData(result.toSeq) } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Sentences = + copy(str = newFirst, language = newSecond, country = newThird) + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index ff8856708c6d1..ea6e427a95b5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -238,6 +238,9 @@ case class ScalarSubquery( children.map(_.canonicalized), ExprId(0)) } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): ScalarSubquery = copy(children = newChildren) } object ScalarSubquery { @@ -283,6 +286,9 @@ case class ListQuery( ExprId(0), childOutputs.map(_.canonicalized.asInstanceOf[Attribute])) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ListQuery = + copy(children = newChildren) } /** @@ -325,4 +331,7 @@ case class Exists( children.map(_.canonicalized), ExprId(0)) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Exists = + copy(children = newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index fa027d1ab0561..ff486bfbdef75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -47,6 +47,13 @@ case class WindowSpecDefinition( override def children: Seq[Expression] = partitionSpec ++ orderSpec :+ frameSpecification + override 
protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): WindowSpecDefinition = + copy( + partitionSpec = newChildren.take(partitionSpec.size), + orderSpec = newChildren.drop(partitionSpec.size).dropRight(1).asInstanceOf[Seq[SortOrder]], + frameSpecification = newChildren.last.asInstanceOf[WindowFrame]) + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess && frameSpecification.isInstanceOf[SpecifiedWindowFrame] @@ -266,6 +273,10 @@ case class SpecifiedWindowFrame( case _ => true } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): SpecifiedWindowFrame = + copy(lower = newLeft, upper = newRight) } case class UnresolvedWindowExpression( @@ -275,6 +286,9 @@ case class UnresolvedWindowExpression( override def dataType: DataType = throw new UnresolvedException("dataType") override def nullable: Boolean = throw new UnresolvedException("nullable") override lazy val resolved = false + + override protected def withNewChildInternal(newChild: Expression): UnresolvedWindowExpression = + copy(child = newChild) } case class WindowExpression( @@ -290,6 +304,10 @@ case class WindowExpression( override def toString: String = s"$windowFunction $windowSpec" override def sql: String = windowFunction.sql + " OVER " + windowSpec.sql + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): WindowExpression = + copy(windowFunction = newLeft, windowSpec = newRight.asInstanceOf[WindowSpecDefinition]) } /** @@ -458,6 +476,10 @@ case class Lead( override def first: Expression = input override def second: Expression = offset override def third: Expression = default + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Lead = + copy(input = newFirst, offset = newSecond, default = newThird) } /** @@ -513,6 +535,10 @@ case class Lag( override def first: Expression = input override def second: Expression = inputOffset override def third: Expression = default + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Lag = + copy(input = newFirst, inputOffset = newSecond, default = newThird) } abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowFunction { @@ -698,6 +724,10 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean) override def prettyName: String = "nth_value" override def sql: String = s"$prettyName(${input.sql}, ${offset.sql})${if (ignoreNulls) " ignore nulls" else ""}" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): NthValue = + copy(input = newLeft, offset = newRight) } /** @@ -800,6 +830,9 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow ) override val evaluateExpression = bucket + + override protected def withNewChildInternal( + newChild: Expression): NTile = copy(buckets = newChild) } /** @@ -884,6 +917,8 @@ abstract class RankLike extends AggregateWindowFunction { case class Rank(children: Seq[Expression]) extends RankLike { def this() = this(Nil) override def withOrder(order: Seq[Expression]): Rank = Rank(order) + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Rank = + copy(children = newChildren) } /** @@ -925,6 +960,8 @@ case class DenseRank(children: Seq[Expression]) extends RankLike { override val aggBufferAttributes = rank +: orderAttrs override 
val initialValues = zero +: orderInit override def prettyName: String = "dense_rank" + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): DenseRank = + copy(children = newChildren) } /** @@ -966,4 +1003,6 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase override val evaluateExpression = If(n > one, (rank - one).cast(DoubleType) / (n - one).cast(DoubleType), 0.0d) override def prettyName: String = "percent_rank" + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PercentRank = + copy(children = newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index b8fc830f18183..336dc7a480cff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -75,6 +75,9 @@ case class XPathBoolean(xml: Expression, path: Expression) extends XPathExtract override def nullSafeEval(xml: Any, path: Any): Any = { xpathUtil.evalBoolean(xml.asInstanceOf[UTF8String].toString, pathString) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): XPathBoolean = copy(xml = newLeft, path = newRight) } // scalastyle:off line.size.limit @@ -96,6 +99,9 @@ case class XPathShort(xml: Expression, path: Expression) extends XPathExtract { val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) if (ret eq null) null else ret.shortValue() } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): XPathShort = copy(xml = newLeft, path = newRight) } // scalastyle:off line.size.limit @@ -117,6 +123,9 @@ case class XPathInt(xml: Expression, path: Expression) extends XPathExtract { val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) if (ret eq null) null else ret.intValue() } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = copy(xml = newLeft, path = newRight) } // scalastyle:off line.size.limit @@ -138,6 +147,9 @@ case class XPathLong(xml: Expression, path: Expression) extends XPathExtract { val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) if (ret eq null) null else ret.longValue() } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): XPathLong = copy(xml = newLeft, path = newRight) } // scalastyle:off line.size.limit @@ -159,6 +171,9 @@ case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract { val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) if (ret eq null) null else ret.floatValue() } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): XPathFloat = copy(xml = newLeft, path = newRight) } // scalastyle:off line.size.limit @@ -181,6 +196,9 @@ case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract { val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) if (ret eq null) null else ret.doubleValue() } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): XPathDouble = copy(xml = newLeft, path = newRight) } // scalastyle:off line.size.limit @@ -202,6 +220,9 @@ case class XPathString(xml: 
Expression, path: Expression) extends XPathExtract { val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString) UTF8String.fromString(ret) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = copy(xml = newLeft, path = newRight) } // scalastyle:off line.size.limit @@ -233,4 +254,7 @@ case class XPathList(xml: Expression, path: Expression) extends XPathExtract { null } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): XPathList = copy(xml = newLeft, path = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 828f768f17701..2a288ffd8ecf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -107,6 +107,9 @@ case class OrderedJoin( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): OrderedJoin = + copy(left = newLeft, right = newRight) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index ac8766cd74367..a6444b13acd02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -211,4 +211,7 @@ case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with E nullSafeCodeGen(ctx, ev, codeToNormalize) } + + override protected def withNewChildInternal(newChild: Expression): NormalizeNaNAndZero = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 327856956c610..bca00ff3b7376 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, Expression, If, LambdaFunction, Literal, MapFilter, Or} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} -import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateTable} +import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, InsertStarAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateStarAction, UpdateTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.BooleanType import org.apache.spark.util.Utils @@ -123,8 +123,10 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { private 
def replaceNullWithFalse(mergeActions: Seq[MergeAction]): Seq[MergeAction] = { mergeActions.map { case u @ UpdateAction(Some(cond), _) => u.copy(condition = Some(replaceNullWithFalse(cond))) + case u @ UpdateStarAction(Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond))) case d @ DeleteAction(Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond))) case i @ InsertAction(Some(cond), _) => i.copy(condition = Some(replaceNullWithFalse(cond))) + case i @ InsertStarAction(Some(cond)) => i.copy(condition = Some(replaceNullWithFalse(cond))) case other => other } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 05678b7bbdabf..da1e42fd1288b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ @@ -607,6 +608,19 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe } } + /** + * Check if an [[Aggregate]] has no correlated subquery in aggregate expressions but + * still has correlated scalar subqueries in its grouping expressions, which will not + * be rewritten. + */ + private def checkScalarSubqueryInAgg(a: Aggregate): Unit = { + if (a.groupingExpressions.exists(hasCorrelatedScalarSubquery) && + !a.aggregateExpressions.exists(hasCorrelatedScalarSubquery)) { + throw new IllegalStateException( + s"Fail to rewrite correlated scalar subqueries in Aggregate:\n$a") + } + } + /** * Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar * subqueries. 
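As a concrete illustration of the case the new guard covers (the tables and column names below are made up, and depending on the Spark version such a query may already be rejected at analysis time), a correlated scalar subquery that is used only as a grouping expression cannot be pulled up by this rule, so the rewritten Aggregate would still contain it and the check fails fast instead of letting an unrewritten subquery reach later phases:

import org.apache.spark.sql.SparkSession

object GroupingSubqueryIllustration {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("grouping-subquery").getOrCreate()
    import spark.implicits._

    // Made-up tables, only to give the query something to bind to.
    Seq((1, "gold"), (2, "silver")).toDF("id", "tier").createOrReplaceTempView("customers")
    Seq((10, 1), (11, 2)).toDF("order_id", "customer_id").createOrReplaceTempView("orders")

    // The scalar subquery is correlated through c.id = o.customer_id and appears only in
    // GROUP BY, i.e. in Aggregate.groupingExpressions but not in the aggregate expressions,
    // which is the shape checkScalarSubqueryInAgg is written to detect.
    val df = spark.sql(
      """SELECT count(*)
        |FROM orders o
        |GROUP BY (SELECT max(c.tier) FROM customers c WHERE c.id = o.customer_id)
        |""".stripMargin)
    df.explain(true)

    spark.stop()
  }
}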
@@ -626,6 +640,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping) val newAgg = Aggregate(newGrouping, newExprs, newChild) val attrMapping = a.output.zip(newAgg.output) + checkScalarSubqueryInAgg(newAgg) newAgg -> attrMapping } else { a -> Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e9052fd6e4e33..a788f233af3a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -415,7 +415,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } else if (clause.matchedAction().UPDATE() != null) { val condition = Option(clause.matchedCond).map(expression) if (clause.matchedAction().ASTERISK() != null) { - UpdateAction(condition, Seq()) + UpdateStarAction(condition) } else { UpdateAction(condition, withAssignments(clause.matchedAction().assignmentList())) } @@ -430,7 +430,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg if (clause.notMatchedAction().INSERT() != null) { val condition = Option(clause.notMatchedCond).map(expression) if (clause.notMatchedAction().ASTERISK() != null) { - InsertAction(condition, Seq()) + InsertStarAction(condition) } else { val columns = clause.notMatchedAction().columns.multipartIdentifier() .asScala.map(attr => UnresolvedAttribute(visitMultipartIdentifier(attr))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala index b6bf7cd85d472..bf3f93de97f8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala @@ -61,4 +61,7 @@ case class EventTimeWatermark( a } } + + override protected def withNewChildInternal(newChild: LogicalPlan): EventTimeWatermark = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index 30bff884b2249..6299976911ee4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -35,6 +35,9 @@ case class ScriptTransformation( ioschema: ScriptInputOutputSchema) extends UnaryNode { @transient override lazy val references: AttributeSet = AttributeSet(input.flatMap(_.references)) + + override protected def withNewChildInternal(newChild: LogicalPlan): ScriptTransformation = + copy(child = newChild) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 962ce938d2954..ba54be7679ec1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -41,6 +41,8 @@ 
import org.apache.spark.util.random.RandomSampler case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { override def maxRows: Option[Long] = child.maxRows override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: LogicalPlan): ReturnAnswer = + copy(child = newChild) } /** @@ -52,6 +54,8 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { */ case class Subquery(child: LogicalPlan, correlated: Boolean) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: LogicalPlan): Subquery = + copy(child = newChild) } object Subquery { @@ -78,6 +82,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) override lazy val validConstraints: ExpressionSet = getAllValidConstraints(projectList) + + override protected def withNewChildInternal(newChild: LogicalPlan): Project = + copy(child = newChild) } /** @@ -136,6 +143,9 @@ case class Generate( } def output: Seq[Attribute] = requiredChildOutput ++ qualifiedGeneratorOutput + + override protected def withNewChildInternal(newChild: LogicalPlan): Generate = + copy(child = newChild) } case class Filter(condition: Expression, child: LogicalPlan) @@ -149,6 +159,9 @@ case class Filter(condition: Expression, child: LogicalPlan) .filterNot(SubqueryExpression.hasCorrelatedSubquery) child.constraints.union(ExpressionSet(predicates)) } + + override protected def withNewChildInternal(newChild: LogicalPlan): Filter = + copy(child = newChild) } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -201,6 +214,9 @@ case class Intersect( Some(children.flatMap(_.maxRows).min) } } + + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): Intersect = copy(left = newLeft, right = newRight) } case class Except( @@ -214,6 +230,9 @@ case class Except( override def metadataOutput: Seq[Attribute] = Nil override protected lazy val validConstraints: ExpressionSet = leftConstraints + + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): Except = copy(left = newLeft, right = newRight) } /** Factory for constructing new `Union` nodes. 
*/ @@ -326,6 +345,9 @@ case class Union( .map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) .reduce(merge(_, _)) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): Union = + copy(children = newChildren) } case class Join( @@ -436,6 +458,9 @@ case class Join( || e.asInstanceOf[JoinHint].leftHint.isDefined || e.asInstanceOf[JoinHint].rightHint.isDefined) } + + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): Join = copy(left = newLeft, right = newRight) } /** @@ -461,6 +486,9 @@ case class InsertIntoDir( override def output: Seq[Attribute] = Seq.empty override def metadataOutput: Seq[Attribute] = Nil override lazy val resolved: Boolean = false + + override protected def withNewChildInternal(newChild: LogicalPlan): InsertIntoDir = + copy(child = newChild) } /** @@ -515,6 +543,9 @@ case class View( case _ => false } } + + override protected def withNewChildInternal(newChild: LogicalPlan): View = + copy(child = newChild) } object View { @@ -548,12 +579,16 @@ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) } override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2) + + override protected def withNewChildInternal(newChild: LogicalPlan): With = copy(child = newChild) } case class WithWindowDefinition( windowDefinitions: Map[String, WindowSpecDefinition], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: LogicalPlan): WithWindowDefinition = + copy(child = newChild) } /** @@ -569,6 +604,7 @@ case class Sort( override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows override def outputOrdering: Seq[SortOrder] = order + override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild) } /** Factory for constructing new `Range` nodes. */ @@ -739,6 +775,9 @@ case class Aggregate( val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty) getAllValidConstraints(nonAgg) } + + override protected def withNewChildInternal(newChild: LogicalPlan): Aggregate = + copy(child = newChild) } case class Window( @@ -753,6 +792,9 @@ case class Window( override def producedAttributes: AttributeSet = windowOutputSet def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute)) + + override protected def withNewChildInternal(newChild: LogicalPlan): Window = + copy(child = newChild) } object Expand { @@ -869,6 +911,9 @@ case class Expand( // This operator can reuse attributes (for example making them null when doing a roll up) so // the constraints of the child may no longer be valid. 
override protected lazy val validConstraints: ExpressionSet = ExpressionSet() + + override protected def withNewChildInternal(newChild: LogicalPlan): Expand = + copy(child = newChild) } /** @@ -901,6 +946,8 @@ case class Pivot( groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg } override def metadataOutput: Seq[Attribute] = Nil + + override protected def withNewChildInternal(newChild: LogicalPlan): Pivot = copy(child = newChild) } /** @@ -950,6 +997,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderP case _ => None } } + + override protected def withNewChildInternal(newChild: LogicalPlan): GlobalLimit = + copy(child = newChild) } /** @@ -967,6 +1017,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPr case _ => None } } + + override protected def withNewChildInternal(newChild: LogicalPlan): LocalLimit = + copy(child = newChild) } /** @@ -987,6 +1040,8 @@ case class Tail(limitExpr: Expression, child: LogicalPlan) extends OrderPreservi case _ => None } } + + override protected def withNewChildInternal(newChild: LogicalPlan): Tail = copy(child = newChild) } /** @@ -1013,6 +1068,9 @@ case class SubqueryAlias( } override def doCanonicalize(): LogicalPlan = child.canonicalized + + override protected def withNewChildInternal(newChild: LogicalPlan): SubqueryAlias = + copy(child = newChild) } object SubqueryAlias { @@ -1066,6 +1124,9 @@ case class Sample( override def maxRows: Option[Long] = child.maxRows override def output: Seq[Attribute] = child.output + + override protected def withNewChildInternal(newChild: LogicalPlan): Sample = + copy(child = newChild) } /** @@ -1074,6 +1135,8 @@ case class Sample( case class Distinct(child: LogicalPlan) extends UnaryNode { override def maxRows: Option[Long] = child.maxRows override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: LogicalPlan): Distinct = + copy(child = newChild) } /** @@ -1104,6 +1167,8 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) case _ => RoundRobinPartitioning(numPartitions) } } + override protected def withNewChildInternal(newChild: LogicalPlan): Repartition = + copy(child = newChild) } /** @@ -1145,6 +1210,9 @@ case class RepartitionByExpression( } override def shuffle: Boolean = true + + override protected def withNewChildInternal(newChild: LogicalPlan): RepartitionByExpression = + copy(child = newChild) } object RepartitionByExpression { @@ -1178,6 +1246,8 @@ case class Deduplicate( child: LogicalPlan) extends UnaryNode { override def maxRows: Option[Long] = child.maxRows override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: LogicalPlan): Deduplicate = + copy(child = newChild) } /** @@ -1206,4 +1276,7 @@ case class CollectMetrics( } override def output: Seq[Attribute] = child.output + + override protected def withNewChildInternal(newChild: LogicalPlan): CollectMetrics = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index 4b5e278fccdfb..5bda94cea9527 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -31,6 +31,9 @@ case class UnresolvedHint(name: String, parameters: Seq[Any], child: LogicalPlan override lazy val 
resolved: Boolean = false override def output: Seq[Attribute] = child.output + + override protected def withNewChildInternal(newChild: LogicalPlan): UnresolvedHint = + copy(child = newChild) } /** @@ -43,6 +46,9 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) override def output: Seq[Attribute] = child.output override def doCanonicalize(): LogicalPlan = child.canonicalized + + override protected def withNewChildInternal(newChild: LogicalPlan): ResolvedHint = + copy(child = newChild) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index d383532cbd3d3..6d61a86ab5ef7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -79,7 +79,10 @@ trait ObjectConsumer extends UnaryNode { case class DeserializeToObject( deserializer: Expression, outputObjAttr: Attribute, - child: LogicalPlan) extends UnaryNode with ObjectProducer + child: LogicalPlan) extends UnaryNode with ObjectProducer { + override protected def withNewChildInternal(newChild: LogicalPlan): DeserializeToObject = + copy(child = newChild) +} /** * Takes the input object from child and turns it into unsafe row using the given serializer @@ -90,6 +93,9 @@ case class SerializeFromObject( child: LogicalPlan) extends ObjectConsumer { override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override protected def withNewChildInternal(newChild: LogicalPlan): SerializeFromObject = + copy(child = newChild) } object MapPartitions { @@ -111,7 +117,10 @@ object MapPartitions { case class MapPartitions( func: Iterator[Any] => Iterator[Any], outputObjAttr: Attribute, - child: LogicalPlan) extends ObjectConsumer with ObjectProducer + child: LogicalPlan) extends ObjectConsumer with ObjectProducer { + override protected def withNewChildInternal(newChild: LogicalPlan): MapPartitions = + copy(child = newChild) +} object MapPartitionsInR { def apply( @@ -159,6 +168,9 @@ case class MapPartitionsInR( override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema, outputObjAttr, child) + + override protected def withNewChildInternal(newChild: LogicalPlan): MapPartitionsInR = + copy(child = newChild) } /** @@ -182,6 +194,9 @@ case class MapPartitionsInRWithArrow( inputSchema, StructType.fromAttributes(output), child) override val producedAttributes = AttributeSet(output) + + override protected def withNewChildInternal(newChild: LogicalPlan): MapPartitionsInRWithArrow = + copy(child = newChild) } object MapElements { @@ -207,7 +222,10 @@ case class MapElements( argumentClass: Class[_], argumentSchema: StructType, outputObjAttr: Attribute, - child: LogicalPlan) extends ObjectConsumer with ObjectProducer + child: LogicalPlan) extends ObjectConsumer with ObjectProducer { + override protected def withNewChildInternal(newChild: LogicalPlan): MapElements = + copy(child = newChild) +} object TypedFilter { def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = { @@ -251,6 +269,9 @@ case class TypedFilter( val funcObj = Literal.create(func, ObjectType(funcMethod._1)) Invoke(funcObj, funcMethod._2, BooleanType, input :: Nil) } + + override protected def withNewChildInternal(newChild: LogicalPlan): TypedFilter = + copy(child = newChild) } object FunctionUtils { @@ -334,6 +355,9 @@ case class AppendColumns( override 
def output: Seq[Attribute] = child.output ++ newColumns def newColumns: Seq[Attribute] = serializer.map(_.toAttribute) + + override protected def withNewChildInternal(newChild: LogicalPlan): AppendColumns = + copy(child = newChild) } /** @@ -346,6 +370,9 @@ case class AppendColumnsWithObject( child: LogicalPlan) extends ObjectConsumer { override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute) + + override protected def withNewChildInternal(newChild: LogicalPlan): AppendColumnsWithObject = + copy(child = newChild) } /** Factory for constructing new `MapGroups` nodes. */ @@ -382,7 +409,10 @@ case class MapGroups( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, - child: LogicalPlan) extends UnaryNode with ObjectProducer + child: LogicalPlan) extends UnaryNode with ObjectProducer { + override protected def withNewChildInternal(newChild: LogicalPlan): MapGroups = + copy(child = newChild) +} /** Internal class representing State */ trait LogicalGroupState[S] @@ -453,6 +483,9 @@ case class FlatMapGroupsWithState( if (isMapGroupsWithState) { assert(outputMode == OutputMode.Update) } + + override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsWithState = + copy(child = newChild) } /** Factory for constructing new `FlatMapGroupsInR` nodes. */ @@ -513,6 +546,9 @@ case class FlatMapGroupsInR( override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) + + override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsInR = + copy(child = newChild) } /** @@ -537,6 +573,9 @@ case class FlatMapGroupsInRWithArrow( inputSchema, StructType.fromAttributes(output), keyDeserializer, groupingAttributes, child) override val producedAttributes = AttributeSet(output) + + override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsInRWithArrow = + copy(child = newChild) } /** Factory for constructing new `CoGroup` nodes. */ @@ -584,4 +623,7 @@ case class CoGroup( rightAttr: Seq[Attribute], outputObjAttr: Attribute, left: LogicalPlan, - right: LogicalPlan) extends BinaryNode with ObjectProducer + right: LogicalPlan) extends BinaryNode with ObjectProducer { + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): CoGroup = copy(left = newLeft, right = newRight) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 62f2d598b96dc..ba8352cf6ac89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -37,6 +37,9 @@ case class FlatMapGroupsInPandas( * from the input. 
*/ override val producedAttributes = AttributeSet(output) + + override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsInPandas = + copy(child = newChild) } /** @@ -49,6 +52,9 @@ case class MapInPandas( child: LogicalPlan) extends UnaryNode { override val producedAttributes = AttributeSet(output) + + override protected def withNewChildInternal(newChild: LogicalPlan): MapInPandas = + copy(child = newChild) } /** @@ -70,6 +76,10 @@ case class FlatMapCoGroupsInPandas( def leftAttributes: Seq[Attribute] = left.output.take(leftGroupingLen) def rightAttributes: Seq[Attribute] = right.output.take(rightGroupingLen) + + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapCoGroupsInPandas = + copy(left = newLeft, right = newRight) } trait BaseEvalPython extends UnaryNode { @@ -89,7 +99,10 @@ trait BaseEvalPython extends UnaryNode { case class BatchEvalPython( udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], - child: LogicalPlan) extends BaseEvalPython + child: LogicalPlan) extends BaseEvalPython { + override protected def withNewChildInternal(newChild: LogicalPlan): BatchEvalPython = + copy(child = newChild) +} /** * A logical plan that evaluates a [[PythonUDF]] with Apache Arrow. @@ -98,4 +111,7 @@ case class ArrowEvalPython( udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: LogicalPlan, - evalType: Int) extends BaseEvalPython + evalType: Int) extends BaseEvalPython { + override protected def withNewChildInternal(newChild: LogicalPlan): ArrowEvalPython = + copy(child = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index d600c15004d1e..44550ae2844ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -167,6 +167,8 @@ case class CreateTableAsSelectStatement( ifNotExists: Boolean) extends UnaryParsedStatement { override def child: LogicalPlan = asSelect + override protected def withNewChildInternal(newChild: LogicalPlan): CreateTableAsSelectStatement = + copy(asSelect = newChild) } /** @@ -181,7 +183,10 @@ case class CreateViewStatement( child: LogicalPlan, allowExisting: Boolean, replace: Boolean, - viewType: ViewType) extends UnaryParsedStatement + viewType: ViewType) extends UnaryParsedStatement { + override protected def withNewChildInternal(newChild: LogicalPlan): CreateViewStatement = + copy(child = newChild) +} /** * A REPLACE TABLE command, as parsed from SQL. 
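Taken together, these overrides all serve one contract: the tree-transformation framework hands a node its (possibly rewritten) children and the node rebuilds itself through its own copy(...), reusing the existing instance when nothing changed. A minimal, self-contained model of that contract, using made-up Node/Leaf/Add classes rather than Spark's actual TreeNode, might look like this:

// A deliberately simplified stand-in for the tree-node contract; not Spark's implementation.
sealed trait Node {
  def children: Seq[Node]

  // Each concrete node knows how to rebuild itself from new children via copy(...),
  // which is what every withNewChildInternal / withNewChildrenInternal override in this diff does.
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node

  // Framework-side entry point: reuse `this` when nothing changed, otherwise delegate
  // to the node-specific rebuild.
  final def withNewChildren(newChildren: Seq[Node]): Node = {
    require(newChildren.size == children.size, "Incorrect number of children")
    if (children.zip(newChildren).forall { case (o, n) => o eq n }) this
    else withNewChildrenInternal(newChildren.toIndexedSeq)
  }

  final def mapChildren(f: Node => Node): Node = withNewChildren(children.map(f))
}

case class Leaf(value: Int) extends Node {
  override def children: Seq[Node] = Nil
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node = this
}

case class Add(left: Node, right: Node) extends Node {
  override def children: Seq[Node] = Seq(left, right)
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Add =
    copy(left = newChildren(0), right = newChildren(1))
}

object WithNewChildrenSketch {
  def main(args: Array[String]): Unit = {
    val tree: Node = Add(Leaf(1), Add(Leaf(2), Leaf(3)))
    // Rewrites every Leaf; each Add is rebuilt through its own withNewChildrenInternal.
    def bump(n: Node): Node = n match {
      case Leaf(v) => Leaf(v + 1)
      case other   => other.mapChildren(bump)
    }
    println(bump(tree)) // Add(Leaf(2),Add(Leaf(3),Leaf(4)))
  }
}

The per-class overrides in this diff are the Spark-scale version of Add.withNewChildrenInternal above: pure boilerplate, but it lets the framework rebuild a node of the right concrete type without knowing its constructor shape.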
@@ -220,6 +225,8 @@ case class ReplaceTableAsSelectStatement( orCreate: Boolean) extends UnaryParsedStatement { override def child: LogicalPlan = asSelect + override protected def withNewChildInternal( + newChild: LogicalPlan): ReplaceTableAsSelectStatement = copy(asSelect = newChild) } @@ -300,6 +307,8 @@ case class InsertIntoStatement( "IF NOT EXISTS is only valid with static partitions") override def child: LogicalPlan = query + override protected def withNewChildInternal(newChild: LogicalPlan): InsertIntoStatement = + copy(query = newChild) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index c838ef2ae9265..84f3aacbf6d2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -77,6 +77,8 @@ case class AppendData( write: Option[Write] = None) extends V2WriteCommand { override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery) override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable) + override protected def withNewChildInternal(newChild: LogicalPlan): AppendData = + copy(query = newChild) } object AppendData { @@ -115,6 +117,9 @@ case class OverwriteByExpression( override def withNewTable(newTable: NamedRelation): OverwriteByExpression = { copy(table = newTable) } + + override protected def withNewChildInternal(newChild: LogicalPlan): OverwriteByExpression = + copy(query = newChild) } object OverwriteByExpression { @@ -150,6 +155,9 @@ case class OverwritePartitionsDynamic( override def withNewTable(newTable: NamedRelation): OverwritePartitionsDynamic = { copy(table = newTable) } + + override protected def withNewChildInternal(newChild: LogicalPlan): OverwritePartitionsDynamic = + copy(query = newChild) } object OverwritePartitionsDynamic { @@ -222,6 +230,9 @@ case class CreateTableAsSelect( override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = { this.copy(partitioning = rewritten) } + + override protected def withNewChildInternal(newChild: LogicalPlan): CreateTableAsSelect = + copy(query = newChild) } /** @@ -272,6 +283,9 @@ case class ReplaceTableAsSelect( override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = { this.copy(partitioning = rewritten) } + + override protected def withNewChildInternal(newChild: LogicalPlan): ReplaceTableAsSelect = + copy(query = newChild) } /** @@ -291,6 +305,8 @@ case class DropNamespace( ifExists: Boolean, cascade: Boolean) extends UnaryCommand { override def child: LogicalPlan = namespace + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = + copy(namespace = newChild) } /** @@ -301,6 +317,8 @@ case class DescribeNamespace( extended: Boolean, override val output: Seq[Attribute] = DescribeNamespace.getOutputAttrs) extends UnaryCommand { override def child: LogicalPlan = namespace + override protected def withNewChildInternal(newChild: LogicalPlan): DescribeNamespace = + copy(namespace = newChild) } object DescribeNamespace { @@ -319,6 +337,8 @@ case class SetNamespaceProperties( namespace: LogicalPlan, properties: Map[String, String]) extends UnaryCommand { override def child: LogicalPlan = namespace + override protected def withNewChildInternal(newChild: LogicalPlan): SetNamespaceProperties = + copy(namespace = newChild) } /** @@ 
-328,6 +348,8 @@ case class SetNamespaceLocation( namespace: LogicalPlan, location: String) extends UnaryCommand { override def child: LogicalPlan = namespace + override protected def withNewChildInternal(newChild: LogicalPlan): SetNamespaceLocation = + copy(namespace = newChild) } /** @@ -338,6 +360,8 @@ case class ShowNamespaces( pattern: Option[String], override val output: Seq[Attribute] = ShowNamespaces.getOutputAttrs) extends UnaryCommand { override def child: LogicalPlan = namespace + override protected def withNewChildInternal(newChild: LogicalPlan): ShowNamespaces = + copy(namespace = newChild) } object ShowNamespaces { @@ -355,6 +379,8 @@ case class DescribeRelation( isExtended: Boolean, override val output: Seq[Attribute] = DescribeRelation.getOutputAttrs) extends UnaryCommand { override def child: LogicalPlan = relation + override protected def withNewChildInternal(newChild: LogicalPlan): DescribeRelation = + copy(relation = newChild) } object DescribeRelation { @@ -370,6 +396,8 @@ case class DescribeColumn( isExtended: Boolean, override val output: Seq[Attribute] = DescribeColumn.getOutputAttrs) extends UnaryCommand { override def child: LogicalPlan = relation + override protected def withNewChildInternal(newChild: LogicalPlan): DescribeColumn = + copy(relation = newChild) } object DescribeColumn { @@ -383,6 +411,8 @@ case class DeleteFromTable( table: LogicalPlan, condition: Option[Expression]) extends UnaryCommand with SupportsSubquery { override def child: LogicalPlan = table + override protected def withNewChildInternal(newChild: LogicalPlan): DeleteFromTable = + copy(table = newChild) } /** @@ -393,6 +423,8 @@ case class UpdateTable( assignments: Seq[Assignment], condition: Option[Expression]) extends UnaryCommand with SupportsSubquery { override def child: LogicalPlan = table + override protected def withNewChildInternal(newChild: LogicalPlan): UpdateTable = + copy(table = newChild) } /** @@ -407,6 +439,9 @@ case class MergeIntoTable( def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty override def left: LogicalPlan = targetTable override def right: LogicalPlan = sourceTable + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): MergeIntoTable = + copy(targetTable = newLeft, sourceTable = newRight) } sealed abstract class MergeAction extends Expression with Unevaluable { @@ -416,18 +451,49 @@ sealed abstract class MergeAction extends Expression with Unevaluable { override def children: Seq[Expression] = condition.toSeq } -case class DeleteAction(condition: Option[Expression]) extends MergeAction +case class DeleteAction(condition: Option[Expression]) extends MergeAction { + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): DeleteAction = + copy(condition = if (condition.isDefined) Some(newChildren(0)) else None) +} case class UpdateAction( condition: Option[Expression], assignments: Seq[Assignment]) extends MergeAction { override def children: Seq[Expression] = condition.toSeq ++ assignments + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): UpdateAction = + copy( + condition = if (condition.isDefined) Some(newChildren.head) else None, + assignments = newChildren.tail.asInstanceOf[Seq[Assignment]]) +} + +case class UpdateStarAction(condition: Option[Expression]) extends MergeAction { + override def children: Seq[Expression] = condition.toSeq + override lazy val resolved = false + override protected def 
withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): UpdateStarAction = + copy(condition = if (condition.isDefined) Some(newChildren(0)) else None) } case class InsertAction( condition: Option[Expression], assignments: Seq[Assignment]) extends MergeAction { override def children: Seq[Expression] = condition.toSeq ++ assignments + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): InsertAction = + copy( + condition = if (condition.isDefined) Some(newChildren.head) else None, + assignments = newChildren.tail.asInstanceOf[Seq[Assignment]]) +} + +case class InsertStarAction(condition: Option[Expression]) extends MergeAction { + override def children: Seq[Expression] = condition.toSeq + override lazy val resolved = false + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): InsertStarAction = + copy(condition = if (condition.isDefined) Some(newChildren(0)) else None) } case class Assignment(key: Expression, value: Expression) extends Expression @@ -436,6 +502,8 @@ case class Assignment(key: Expression, value: Expression) extends Expression override def dataType: DataType = throw new UnresolvedException("nullable") override def left: Expression = key override def right: Expression = value + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Assignment = copy(key = newLeft, value = newRight) } /** @@ -452,7 +520,10 @@ case class Assignment(key: Expression, value: Expression) extends Expression case class DropTable( child: LogicalPlan, ifExists: Boolean, - purge: Boolean) extends UnaryCommand + purge: Boolean) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): DropTable = + copy(child = newChild) +} /** * The logical plan for no-op command handling non-existing table. @@ -499,7 +570,10 @@ case class AlterTable( case class RenameTable( child: LogicalPlan, newName: Seq[String], - isView: Boolean) extends UnaryCommand + isView: Boolean) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): RenameTable = + copy(child = newChild) +} /** * The logical plan of the SHOW TABLES command. @@ -509,6 +583,8 @@ case class ShowTables( pattern: Option[String], override val output: Seq[Attribute] = ShowTables.getOutputAttrs) extends UnaryCommand { override def child: LogicalPlan = namespace + override protected def withNewChildInternal(newChild: LogicalPlan): ShowTables = + copy(namespace = newChild) } object ShowTables { @@ -527,6 +603,8 @@ case class ShowTableExtended( partitionSpec: Option[PartitionSpec], override val output: Seq[Attribute] = ShowTableExtended.getOutputAttrs) extends UnaryCommand { override def child: LogicalPlan = namespace + override protected def withNewChildInternal(newChild: LogicalPlan): ShowTableExtended = + copy(namespace = newChild) } object ShowTableExtended { @@ -548,6 +626,8 @@ case class ShowViews( pattern: Option[String], override val output: Seq[Attribute] = ShowViews.getOutputAttrs) extends UnaryCommand { override def child: LogicalPlan = namespace + override protected def withNewChildInternal(newChild: LogicalPlan): ShowViews = + copy(namespace = newChild) } object ShowViews { @@ -568,7 +648,10 @@ case class SetCatalogAndNamespace( /** * The logical plan of the REFRESH TABLE command. 
*/ -case class RefreshTable(child: LogicalPlan) extends UnaryCommand +case class RefreshTable(child: LogicalPlan) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): RefreshTable = + copy(child = newChild) +} /** * The logical plan of the SHOW CURRENT NAMESPACE command. @@ -587,6 +670,8 @@ case class ShowTableProperties( propertyKey: Option[String], override val output: Seq[Attribute] = ShowTableProperties.getOutputAttrs) extends UnaryCommand { override def child: LogicalPlan = table + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = + copy(table = newChild) } object ShowTableProperties { @@ -605,7 +690,10 @@ object ShowTableProperties { * where the `text` is the new comment written as a string literal; or `NULL` to drop the comment. * */ -case class CommentOnNamespace(child: LogicalPlan, comment: String) extends UnaryCommand +case class CommentOnNamespace(child: LogicalPlan, comment: String) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): CommentOnNamespace = + copy(child = newChild) +} /** * The logical plan that defines or changes the comment of an TABLE for v2 catalogs. @@ -617,17 +705,26 @@ case class CommentOnNamespace(child: LogicalPlan, comment: String) extends Unary * where the `text` is the new comment written as a string literal; or `NULL` to drop the comment. * */ -case class CommentOnTable(child: LogicalPlan, comment: String) extends UnaryCommand +case class CommentOnTable(child: LogicalPlan, comment: String) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): CommentOnTable = + copy(child = newChild) +} /** * The logical plan of the REFRESH FUNCTION command. */ -case class RefreshFunction(child: LogicalPlan) extends UnaryCommand +case class RefreshFunction(child: LogicalPlan) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): RefreshFunction = + copy(child = newChild) +} /** * The logical plan of the DESCRIBE FUNCTION command. */ -case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends UnaryCommand +case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): DescribeFunction = + copy(child = newChild) +} /** * The logical plan of the DROP FUNCTION command. @@ -635,7 +732,10 @@ case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends Una case class DropFunction( child: LogicalPlan, ifExists: Boolean, - isTemp: Boolean) extends UnaryCommand + isTemp: Boolean) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): DropFunction = + copy(child = newChild) +} /** * The logical plan of the SHOW FUNCTIONS command. 
@@ -647,6 +747,9 @@ case class ShowFunctions( pattern: Option[String], override val output: Seq[Attribute] = ShowFunctions.getOutputAttrs) extends Command { override def children: Seq[LogicalPlan] = child.toSeq + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): ShowFunctions = + copy(child = if (child.isDefined) Some(newChildren.head) else None) } object ShowFunctions { @@ -661,7 +764,10 @@ object ShowFunctions { case class AnalyzeTable( child: LogicalPlan, partitionSpec: Map[String, Option[String]], - noScan: Boolean) extends UnaryCommand + noScan: Boolean) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): AnalyzeTable = + copy(child = newChild) +} /** * The logical plan of the ANALYZE TABLES command. @@ -670,6 +776,8 @@ case class AnalyzeTables( namespace: LogicalPlan, noScan: Boolean) extends UnaryCommand { override def child: LogicalPlan = namespace + override protected def withNewChildInternal(newChild: LogicalPlan): AnalyzeTables = + copy(namespace = newChild) } /** @@ -681,6 +789,9 @@ case class AnalyzeColumn( allColumns: Boolean) extends UnaryCommand { require(columnNames.isDefined ^ allColumns, "Parameter `columnNames` or `allColumns` are " + "mutually exclusive. Only one of them should be specified.") + + override protected def withNewChildInternal(newChild: LogicalPlan): AnalyzeColumn = + copy(child = newChild) } /** @@ -695,7 +806,10 @@ case class AnalyzeColumn( case class AddPartitions( table: LogicalPlan, parts: Seq[PartitionSpec], - ifNotExists: Boolean) extends V2PartitionCommand + ifNotExists: Boolean) extends V2PartitionCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): AddPartitions = + copy(table = newChild) +} /** * The logical plan of the ALTER TABLE DROP PARTITION command. @@ -713,7 +827,10 @@ case class DropPartitions( table: LogicalPlan, parts: Seq[PartitionSpec], ifExists: Boolean, - purge: Boolean) extends V2PartitionCommand + purge: Boolean) extends V2PartitionCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): DropPartitions = + copy(table = newChild) +} /** * The logical plan of the ALTER TABLE ... RENAME TO PARTITION command. @@ -721,12 +838,18 @@ case class DropPartitions( case class RenamePartitions( table: LogicalPlan, from: PartitionSpec, - to: PartitionSpec) extends V2PartitionCommand + to: PartitionSpec) extends V2PartitionCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): RenamePartitions = + copy(table = newChild) +} /** * The logical plan of the ALTER TABLE ... RECOVER PARTITIONS command. */ -case class RecoverPartitions(child: LogicalPlan) extends UnaryCommand +case class RecoverPartitions(child: LogicalPlan) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): RecoverPartitions = + copy(child = newChild) +} /** * The logical plan of the LOAD DATA INTO TABLE command. @@ -736,7 +859,10 @@ case class LoadData( path: String, isLocal: Boolean, isOverwrite: Boolean, - partition: Option[TablePartitionSpec]) extends UnaryCommand + partition: Option[TablePartitionSpec]) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): LoadData = + copy(child = newChild) +} /** * The logical plan of the SHOW CREATE TABLE command. 
@@ -744,7 +870,10 @@ case class LoadData( case class ShowCreateTable( child: LogicalPlan, asSerde: Boolean = false, - override val output: Seq[Attribute] = ShowCreateTable.getoutputAttrs) extends UnaryCommand + override val output: Seq[Attribute] = ShowCreateTable.getoutputAttrs) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): ShowCreateTable = + copy(child = newChild) +} object ShowCreateTable { def getoutputAttrs: Seq[Attribute] = { @@ -758,7 +887,10 @@ object ShowCreateTable { case class ShowColumns( child: LogicalPlan, namespace: Option[Seq[String]], - override val output: Seq[Attribute] = ShowColumns.getOutputAttrs) extends UnaryCommand + override val output: Seq[Attribute] = ShowColumns.getOutputAttrs) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): ShowColumns = + copy(child = newChild) +} object ShowColumns { def getOutputAttrs: Seq[Attribute] = { @@ -771,6 +903,8 @@ object ShowColumns { */ case class TruncateTable(table: LogicalPlan) extends UnaryCommand { override def child: LogicalPlan = table + override protected def withNewChildInternal(newChild: LogicalPlan): TruncateTable = + copy(table = newChild) } /** @@ -780,6 +914,8 @@ case class TruncatePartition( table: LogicalPlan, partitionSpec: PartitionSpec) extends V2PartitionCommand { override def allowPartialPartitionSpec: Boolean = true + override protected def withNewChildInternal(newChild: LogicalPlan): TruncatePartition = + copy(table = newChild) } /** @@ -791,6 +927,8 @@ case class ShowPartitions( override val output: Seq[Attribute] = ShowPartitions.getOutputAttrs) extends V2PartitionCommand { override def allowPartialPartitionSpec: Boolean = true + override protected def withNewChildInternal(newChild: LogicalPlan): ShowPartitions = + copy(table = newChild) } object ShowPartitions { @@ -804,7 +942,10 @@ object ShowPartitions { */ case class DropView( child: LogicalPlan, - ifExists: Boolean) extends UnaryCommand + ifExists: Boolean) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): DropView = + copy(child = newChild) +} /** * The logical plan of the MSCK REPAIR TABLE command. @@ -812,7 +953,10 @@ case class DropView( case class RepairTable( child: LogicalPlan, enableAddPartitions: Boolean, - enableDropPartitions: Boolean) extends UnaryCommand + enableDropPartitions: Boolean) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): RepairTable = + copy(child = newChild) +} /** * The logical plan of the ALTER VIEW ... AS command. @@ -823,6 +967,9 @@ case class AlterViewAs( query: LogicalPlan) extends BinaryCommand { override def left: LogicalPlan = child override def right: LogicalPlan = query + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): LogicalPlan = + copy(child = newLeft, query = newRight) } /** @@ -830,7 +977,10 @@ case class AlterViewAs( */ case class SetViewProperties( child: LogicalPlan, - properties: Map[String, String]) extends UnaryCommand + properties: Map[String, String]) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): SetViewProperties = + copy(child = newChild) +} /** * The logical plan of the ALTER VIEW ... UNSET TBLPROPERTIES command. 
@@ -838,7 +988,10 @@ case class SetViewProperties( case class UnsetViewProperties( child: LogicalPlan, propertyKeys: Seq[String], - ifExists: Boolean) extends UnaryCommand + ifExists: Boolean) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): UnsetViewProperties = + copy(child = newChild) +} /** * The logical plan of the ALTER TABLE ... SET [SERDE|SERDEPROPERTIES] command. @@ -847,7 +1000,10 @@ case class SetTableSerDeProperties( child: LogicalPlan, serdeClassName: Option[String], serdeProperties: Option[Map[String, String]], - partitionSpec: Option[TablePartitionSpec]) extends UnaryCommand + partitionSpec: Option[TablePartitionSpec]) extends UnaryCommand { + override protected def withNewChildInternal(newChild: LogicalPlan): SetTableSerDeProperties = + copy(child = newChild) +} /** * The logical plan of the CACHE TABLE command. @@ -868,6 +1024,10 @@ case class CacheTableAsSelect( isLazy: Boolean, options: Map[String, String], isAnalyzed: Boolean = false) extends AnalysisOnlyCommand { + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): CacheTableAsSelect = + copy(plan = newChildren.head) + override def childrenToAnalyze: Seq[LogicalPlan] = plan :: Nil override def markAsAnalyzed(): LogicalPlan = copy(isAnalyzed = true) @@ -889,6 +1049,8 @@ case class SetTableLocation( partitionSpec: Option[TablePartitionSpec], location: String) extends UnaryCommand { override def child: LogicalPlan = table + override protected def withNewChildInternal(newChild: LogicalPlan): SetTableLocation = + copy(table = newChild) } /** @@ -898,6 +1060,8 @@ case class SetTableProperties( table: LogicalPlan, properties: Map[String, String]) extends UnaryCommand { override def child: LogicalPlan = table + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = + copy(table = newChild) } /** @@ -908,4 +1072,6 @@ case class UnsetTableProperties( propertyKeys: Seq[String], ifExists: Boolean) extends UnaryCommand { override def child: LogicalPlan = table + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = + copy(table = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index c4002aa441a50..0f8c7887b2b1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -235,6 +235,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) * than numPartitions) based on hashing expressions. 
*/ def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions)) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren) } /** @@ -284,6 +287,10 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) } } } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): RangePartitioning = + copy(ordering = newChildren.asInstanceOf[Seq[SortOrder]]) } /** @@ -326,6 +333,10 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) override def toString: String = { partitionings.map(_.toString).mkString("(", " or ", ")") } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): PartitioningCollection = + super.legacyWithNewChildren(newChildren).asInstanceOf[PartitioningCollection] } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala index 990ae302dbbee..2a29137355a4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala @@ -39,5 +39,7 @@ case class WriteToStream( override def child: LogicalPlan = inputQuery + override protected def withNewChildInternal(newChild: LogicalPlan): WriteToStream = + copy(inputQuery = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala index 34a4c13efb62e..407c70a591d72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala @@ -57,5 +57,8 @@ case class WriteToStreamStatement( override def output: Seq[Attribute] = Nil override def child: LogicalPlan = inputQuery + + override protected def withNewChildInternal(newChild: LogicalPlan): WriteToStreamStatement = + copy(inputQuery = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 8fc62382bdbba..3fab95cbe4c38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -246,11 +246,50 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arr } + private def childrenFastEquals( + originalChildren: IndexedSeq[BaseType], newChildren: IndexedSeq[BaseType]): Boolean = { + val size = originalChildren.size + var i = 0 + while (i < size) { + if (!originalChildren(i).fastEquals(newChildren(i))) return false + i += 1 + } + true + } + + // This is a temporary solution, we will change the type of children to IndexedSeq in a + // followup PR + private def asIndexedSeq(seq: Seq[BaseType]): IndexedSeq[BaseType] = { + if (seq.isInstanceOf[IndexedSeq[BaseType]]) { + seq.asInstanceOf[IndexedSeq[BaseType]] + } else { + seq.toIndexedSeq + } + } + + final def withNewChildren(newChildren: Seq[BaseType]): BaseType = { + val childrenIndexedSeq = asIndexedSeq(children) + val newChildrenIndexedSeq = asIndexedSeq(newChildren) + 
assert(newChildrenIndexedSeq.size == childrenIndexedSeq.size, "Incorrect number of children") + if (childrenIndexedSeq.isEmpty || + childrenFastEquals(newChildrenIndexedSeq, childrenIndexedSeq)) { + this + } else { + CurrentOrigin.withOrigin(origin) { + val res = withNewChildrenInternal(newChildrenIndexedSeq) + res.copyTagsFrom(this) + res + } + } + } + + protected def withNewChildrenInternal(newChildren: IndexedSeq[BaseType]): BaseType + /** * Returns a copy of this node with the children replaced. * TODO: Validate somewhere (in debug mode?) that children are ordered correctly. */ - def withNewChildren(newChildren: Seq[BaseType]): BaseType = { + protected final def legacyWithNewChildren(newChildren: Seq[BaseType]): BaseType = { assert(newChildren.size == children.size, "Incorrect number of children") var changed = false val remainingNewChildren = newChildren.toBuffer @@ -355,7 +394,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { */ def mapChildren(f: BaseType => BaseType): BaseType = { if (containsChild.nonEmpty) { - mapChildren(f, forceCopy = false) + withNewChildren(children.map(f)) } else { this } @@ -844,24 +883,96 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { trait LeafLike[T <: TreeNode[T]] { self: TreeNode[T] => override final def children: Seq[T] = Nil + override final def mapChildren(f: T => T): T = this.asInstanceOf[T] + override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = this.asInstanceOf[T] } trait UnaryLike[T <: TreeNode[T]] { self: TreeNode[T] => def child: T - @transient override final lazy val children: Seq[T] = child :: Nil + @transient override final lazy val children: Seq[T] = IndexedSeq(child) + + override final def mapChildren(f: T => T): T = { + val newChild = f(child) + if (newChild fastEquals child) { + this.asInstanceOf[T] + } else { + CurrentOrigin.withOrigin(origin) { + val res = withNewChildInternal(newChild) + res.copyTagsFrom(this.asInstanceOf[T]) + res + } + } + } + + override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = { + assert(newChildren.size == 1, "Incorrect number of children") + withNewChildInternal(newChildren.head) + } + + protected def withNewChildInternal(newChild: T): T } trait BinaryLike[T <: TreeNode[T]] { self: TreeNode[T] => def left: T def right: T - @transient override final lazy val children: Seq[T] = left :: right :: Nil + @transient override final lazy val children: Seq[T] = IndexedSeq(left, right) + + override final def mapChildren(f: T => T): T = { + var newLeft = f(left) + newLeft = if (newLeft fastEquals left) left else newLeft + var newRight = f(right) + newRight = if (newRight fastEquals right) right else newRight + + if (newLeft.eq(left) && newRight.eq(right)) { + this.asInstanceOf[T] + } else { + CurrentOrigin.withOrigin(origin) { + val res = withNewChildrenInternal(newLeft, newRight) + res.copyTagsFrom(this.asInstanceOf[T]) + res + } + } + } + + override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = { + assert(newChildren.size == 2, "Incorrect number of children") + withNewChildrenInternal(newChildren(0), newChildren(1)) + } + + protected def withNewChildrenInternal(newLeft: T, newRight: T): T } trait TernaryLike[T <: TreeNode[T]] { self: TreeNode[T] => def first: T def second: T def third: T - @transient override final lazy val children: Seq[T] = first :: second :: third :: Nil + @transient override final lazy val children: Seq[T] = IndexedSeq(first, second, third) + + override final def 
mapChildren(f: T => T): T = { + var newFirst = f(first) + newFirst = if (newFirst fastEquals first) first else newFirst + var newSecond = f(second) + newSecond = if (newSecond fastEquals second) second else newSecond + var newThird = f(third) + newThird = if (newThird fastEquals third) third else newThird + + if (newFirst.eq(first) && newSecond.eq(second) && newThird.eq(third)) { + this.asInstanceOf[T] + } else { + CurrentOrigin.withOrigin(origin) { + val res = withNewChildrenInternal(newFirst, newSecond, newThird) + res.copyTagsFrom(this.asInstanceOf[T]) + res + } + } + } + + override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = { + assert(newChildren.size == 3, "Incorrect number of children") + withNewChildrenInternal(newChildren(0), newChildren(1), newChildren(2)) + } + + protected def withNewChildrenInternal(newFirst: T, newSecond: T, newThird: T): T } trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] => @@ -869,5 +980,33 @@ trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] => def second: T def third: T def fourth: T - @transient override final lazy val children: Seq[T] = first :: second :: third :: fourth :: Nil + @transient override final lazy val children: Seq[T] = IndexedSeq(first, second, third, fourth) + + override final def mapChildren(f: T => T): T = { + var newFirst = f(first) + newFirst = if (newFirst fastEquals first) first else newFirst + var newSecond = f(second) + newSecond = if (newSecond fastEquals second) second else newSecond + var newThird = f(third) + newThird = if (newThird fastEquals third) third else newThird + var newFourth = f(fourth) + newFourth = if (newFourth fastEquals fourth) fourth else newFourth + + if (newFirst.eq(first) && newSecond.eq(second) && newThird.eq(third) && newFourth.eq(fourth)) { + this.asInstanceOf[T] + } else { + CurrentOrigin.withOrigin(origin) { + val res = withNewChildrenInternal(newFirst, newSecond, newThird, newFourth) + res.copyTagsFrom(this.asInstanceOf[T]) + res + } + } + } + + override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = { + assert(newChildren.size == 4, "Incorrect number of children") + withNewChildrenInternal(newChildren(0), newChildren(1), newChildren(2), newChildren(3)) + } + + protected def withNewChildrenInternal(newFirst: T, newSecond: T, newThird: T, newFourth: T): T } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index ec5a9cc9afad5..aecbf241e3947 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -88,6 +88,8 @@ case class TestFunction( extends Expression with ImplicitCastInputTypes with Unevaluable { override def nullable: Boolean = true override def dataType: DataType = StringType + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) } case class UnresolvedTestPlan() extends LeafNode { @@ -733,4 +735,36 @@ class AnalysisErrorSuite extends AnalysisTest { s"data type mismatch: argument 1 requires (int or bigint) type" :: Nil) } } + + test("SPARK-34946: correlated scalar subquery in grouping expressions only") { + val c1 = AttributeReference("c1", IntegerType)() + val c2 = AttributeReference("c2", IntegerType)() + val t = LocalRelation(c1, c2) + val plan = Aggregate( + 
ScalarSubquery( + Aggregate(Nil, sum($"c2").as("sum") :: Nil, + Filter($"t1.c1" === $"t2.c1", + t.as("t2"))) + ) :: Nil, + sum($"c2").as("sum") :: Nil, t.as("t1")) + assertAnalysisError(plan, "Correlated scalar subqueries in the group by clause must also be " + + "in the aggregate expressions" :: Nil) + } + + test("SPARK-34946: correlated scalar subquery in aggregate expressions only") { + val c1 = AttributeReference("c1", IntegerType)() + val c2 = AttributeReference("c2", IntegerType)() + val t = LocalRelation(c1, c2) + val plan = Aggregate( + $"c1" :: Nil, + ScalarSubquery( + Aggregate(Nil, sum($"c2").as("sum") :: Nil, + Filter($"t1.c1" === $"t2.c1", + t.as("t2"))) + ).as("sub") :: Nil, t.as("t1")) + assertAnalysisError(plan, "Correlated scalar subquery 'scalarsubquery(t1.c1)' is " + + "neither present in the group by, nor in an aggregate function. Add it to group by " + + "using ordinal position or wrap it in first() (or first_value) if you don't care " + + "which value you get." :: Nil) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index a6145c5421d48..9058e3eb3f041 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1623,12 +1623,16 @@ object TypeCoercionSuite { extends UnaryExpression with ExpectsInputTypes with Unevaluable { override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) override def dataType: DataType = NullType + override protected def withNewChildInternal(newChild: Expression): AnyTypeUnaryExpression = + copy(child = newChild) } case class NumericTypeUnaryExpression(child: Expression) extends UnaryExpression with ExpectsInputTypes with Unevaluable { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def dataType: DataType = NullType + override protected def withNewChildInternal(newChild: Expression): NumericTypeUnaryExpression = + copy(child = newChild) } case class AnyTypeBinaryOperator(left: Expression, right: Expression) @@ -1636,6 +1640,9 @@ object TypeCoercionSuite { override def dataType: DataType = NullType override def inputType: AbstractDataType = AnyDataType override def symbol: String = "anytype" + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): AnyTypeBinaryOperator = + copy(left = newLeft, right = newRight) } case class NumericTypeBinaryOperator(left: Expression, right: Expression) @@ -1643,5 +1650,8 @@ object TypeCoercionSuite { override def dataType: DataType = NullType override def inputType: AbstractDataType = NumericType override def symbol: String = "numerictype" + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): NumericTypeBinaryOperator = + copy(left = newLeft, right = newRight) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 71993e1a369ec..dc62841e058e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -998,6 +998,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite with 
SQLHelper { case class StreamingPlanWrapper(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def isStreaming: Boolean = true + override protected def withNewChildInternal(newChild: LogicalPlan): StreamingPlanWrapper = + copy(child = newChild) } case class TestStreamingRelation(output: Seq[Attribute]) extends LeafNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 65671d253dc53..9bfe69b1709d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -314,4 +314,6 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel case class CodegenFallbackExpression(child: Expression) extends UnaryExpression with CodegenFallback { override def dataType: DataType = child.dataType + override protected def withNewChildInternal(newChild: Expression): CodegenFallbackExpression = + copy(child = newChild) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala index bcf8a22cf0b99..76ce96705c44f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, IntegerType} class TryCastSuite extends AnsiCastSuiteBase { override protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String]) = { @@ -48,4 +48,8 @@ class TryCastSuite extends AnsiCastSuiteBase { override def checkCastToNumericError(l: Literal, to: DataType, tryCastResult: Any): Unit = { checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value)) } + + test("try_cast: to_string") { + assert(TryCast(Literal("1"), IntegerType).toString == "try_cast(1 as int)") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala index 43579d4c903a1..02b6eed9ed050 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -104,4 +104,7 @@ case class ExprReuseOutput(child: Expression) extends UnaryExpression { row.update(0, child.eval(input)) row } + + override protected def withNewChildInternal(newChild: Expression): ExprReuseOutput = + copy(child = newChild) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 95d33da879a79..23a63d976d1ea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, DeleteFromTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, UpdateAction, UpdateTable} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, DeleteFromTable, InsertAction, InsertStarAction, LocalRelation, LogicalPlan, MergeIntoTable, UpdateAction, UpdateStarAction, UpdateTable} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, IntegerType} @@ -474,6 +474,22 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val optimizedPlan = Optimize.execute(originalPlan) val expectedPlan = func(testRelation, anotherTestRelation, expectedCond).analyze comparePlans(optimizedPlan, expectedPlan) + + // Test with star actions + def mergePlanWithStar(expr: Expression): MergeIntoTable = { + val matchedActions = UpdateStarAction(Some(expr)) :: Nil + val notMatchedActions = InsertStarAction(Some(expr)) :: Nil + // Only one of source and target should have the columns i and b, since those are used in + // the test expressions and having them in both relations would make them ambiguous. + // However, the source must have all the columns present in the target for star resolution. 
+ val source = LocalRelation('i.int, 'b.boolean, 'a.array(IntegerType)) + val target = LocalRelation('a.array(IntegerType)) + MergeIntoTable(target, source, mergeCondition = expr, matchedActions, notMatchedActions) + } + val originalPlanWithStar = mergePlanWithStar(originalCond).analyze + val optimizedPlanWithStar = Optimize.execute(originalPlanWithStar) + val expectedPlanWithStar = mergePlanWithStar(expectedCond).analyze + comparePlans(optimizedPlanWithStar, expectedPlanWithStar) } private def testHigherOrderFunc( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 057dda3776837..c0648919ecc57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -1519,9 +1519,8 @@ class DDLParserSuite extends AnalysisTest { SubqueryAlias("source", UnresolvedRelation(Seq("testcat2", "ns1", "ns2", "tbl"))), EqualTo(UnresolvedAttribute("target.col1"), UnresolvedAttribute("source.col1")), Seq(DeleteAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("delete")))), - UpdateAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("update"))), Seq())), - Seq(InsertAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("insert"))), - Seq())))) + UpdateStarAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("update"))))), + Seq(InsertStarAction(Some(EqualTo(UnresolvedAttribute("target.col2"), Literal("insert"))))))) } test("merge into table: columns aliases are not allowed") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 84452399de824..3784f40101702 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -66,6 +66,9 @@ class LogicalPlanSuite extends SparkFunSuite { case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): LogicalPlan = + copy(left = newLeft, right = newRight) } require(relation.isStreaming === false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala index 6f342b8d94379..009e2a731fe41 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala @@ -28,6 +28,8 @@ class LogicalPlanIntegritySuite extends PlanTest { case class OutputTestPlan(child: LogicalPlan, output: Seq[Attribute]) extends UnaryNode { override val analyzed = true + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = + copy(child = newChild) } test("Checks if the same `ExprId` refers to a semantically-equal attribute in a plan output") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 4ad8475a0113c..0d316779d8bcb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -47,6 +47,8 @@ case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFall override def dataType: NullType = NullType override lazy val resolved = true override def eval(input: InternalRow): Any = null.asInstanceOf[Any] + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(optKey = if (optKey.isDefined) Some(newChildren(0)) else None) } case class ComplexPlan(exprs: Seq[Seq[Expression]]) @@ -59,6 +61,8 @@ case class ExpressionInMap(map: Map[String, Expression]) extends Unevaluable { override def nullable: Boolean = true override def dataType: NullType = NullType override lazy val resolved = true + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + super.legacyWithNewChildren(newChildren) } case class SeqTupleExpression(sons: Seq[(Expression, Expression)], @@ -67,6 +71,9 @@ case class SeqTupleExpression(sons: Seq[(Expression, Expression)], override def nullable: Boolean = true override def dataType: NullType = NullType override lazy val resolved = true + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + super.legacyWithNewChildren(newChildren) } case class JsonTestTreeNode(arg: Any) extends LeafNode { @@ -738,7 +745,10 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { } object MalformedClassObject extends Serializable { - case class MalformedNameExpression(child: Expression) extends TaggingExpression + case class MalformedNameExpression(child: Expression) extends TaggingExpression { + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) + } } test("SPARK-32999: TreeNode.nodeName should not throw malformed class name error") { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index e8ffa62e41794..6264d6341c65a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -39,11 +39,6 @@ import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.HadoopReadOptions; import org.apache.parquet.ParquetReadOptions; -import org.apache.parquet.bytes.BytesInput; -import org.apache.parquet.bytes.BytesUtils; -import org.apache.parquet.column.ColumnDescriptor; -import org.apache.parquet.column.values.ValuesReader; -import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridDecoder; import org.apache.parquet.hadoop.BadConfigurationException; import org.apache.parquet.hadoop.ParquetFileReader; import org.apache.parquet.hadoop.ParquetInputFormat; @@ -198,62 +193,6 @@ public void close() throws IOException { } } - /** - * Utility classes to abstract over different way to read ints with different encodings. - * TODO: remove this layer of abstraction? 
- */ - abstract static class IntIterator { - abstract int nextInt() throws IOException; - } - - protected static final class ValuesReaderIntIterator extends IntIterator { - ValuesReader delegate; - - public ValuesReaderIntIterator(ValuesReader delegate) { - this.delegate = delegate; - } - - @Override - int nextInt() { - return delegate.readInteger(); - } - } - - protected static final class RLEIntIterator extends IntIterator { - RunLengthBitPackingHybridDecoder delegate; - - public RLEIntIterator(RunLengthBitPackingHybridDecoder delegate) { - this.delegate = delegate; - } - - @Override - int nextInt() throws IOException { - return delegate.readInt(); - } - } - - protected static final class NullIntIterator extends IntIterator { - @Override - int nextInt() { return 0; } - } - - /** - * Creates a reader for definition and repetition levels, returning an optimized one if - * the levels are not needed. - */ - protected static IntIterator createRLEIterator( - int maxLevel, BytesInput bytes, ColumnDescriptor descriptor) throws IOException { - try { - if (maxLevel == 0) return new NullIntIterator(); - return new RLEIntIterator( - new RunLengthBitPackingHybridDecoder( - BytesUtils.getWidthFromMaxInt(maxLevel), - bytes.toInputStream())); - } catch (IOException e) { - throw new IOException("could not read levels in page for col " + descriptor, e); - } - } - private static Map> toSetMultiMap(Map map) { Map> setMultiMap = new HashMap<>(); for (Map.Entry entry : map.entrySet()) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 672b73e94c42f..52620b0740851 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -48,8 +48,6 @@ import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; -import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.ValuesReaderIntIterator; -import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.createRLEIterator; /** * Decoder to return values from a single column. @@ -82,14 +80,13 @@ public class VectorizedColumnReader { private final int maxDefLevel; /** - * Repetition/Definition/Value readers. + * Value readers. */ - private SpecificParquetRecordReaderBase.IntIterator repetitionLevelColumn; - private SpecificParquetRecordReaderBase.IntIterator definitionLevelColumn; private ValuesReader dataColumn; - // Only set if vectorized decoding is true. This is used instead of the row by row decoding - // with `definitionLevelColumn`. + /** + * Vectorized RLE decoder for definition levels + */ private VectorizedRleValuesReader defColumn; /** @@ -171,23 +168,6 @@ public VectorizedColumnReader( this.int96RebaseMode = int96RebaseMode; } - /** - * Advances to the next value. Returns true if the value is non-null. - */ - private boolean next() throws IOException { - if (valuesRead >= endOfPageValueCount) { - if (valuesRead >= totalValueCount) { - // How do we get here? Throw end of stream exception? 
- return false; - } - readPage(); - } - ++valuesRead; - // TODO: Don't read for flat schemas - //repetitionLevel = repetitionLevelColumn.nextInt(); - return definitionLevelColumn.nextInt() == maxDefLevel; - } - private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName) { boolean isSupported = false; switch (typeName) { @@ -854,23 +834,24 @@ private void initDataReader(Encoding dataEncoding, ByteBufferInputStream in) thr private void readPageV1(DataPageV1 page) throws IOException { this.pageValueCount = page.getValueCount(); - ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); - ValuesReader dlReader; // Initialize the decoders. if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) { throw new UnsupportedOperationException("Unsupported encoding: " + page.getDlEncoding()); } + int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); this.defColumn = new VectorizedRleValuesReader(bitWidth); - dlReader = this.defColumn; - this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); - this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); try { BytesInput bytes = page.getBytes(); ByteBufferInputStream in = bytes.toInputStream(); - rlReader.initFromPage(pageValueCount, in); - dlReader.initFromPage(pageValueCount, in); + + // only used now to consume the repetition level data + page.getRlEncoding() + .getValuesReader(descriptor, REPETITION_LEVEL) + .initFromPage(pageValueCount, in); + + defColumn.initFromPage(pageValueCount, in); initDataReader(page.getValueEncoding(), in); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); @@ -879,15 +860,11 @@ private void readPageV1(DataPageV1 page) throws IOException { private void readPageV2(DataPageV2 page) throws IOException { this.pageValueCount = page.getValueCount(); - this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(), - page.getRepetitionLevels(), descriptor); int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); // do not read the length from the stream. v2 pages handle dividing the page bytes. 
- this.defColumn = new VectorizedRleValuesReader(bitWidth, false); - this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn); - this.defColumn.initFromPage( - this.pageValueCount, page.getDefinitionLevels().toInputStream()); + defColumn = new VectorizedRleValuesReader(bitWidth, false); + defColumn.initFromPage(this.pageValueCount, page.getDefinitionLevels().toInputStream()); try { initDataReader(page.getDataEncoding(), page.getData().toInputStream()); } catch (IOException e) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 595da20ad5e9b..6a0038dbdc44c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -226,6 +226,16 @@ public final void readBytes(int total, WritableColumnVector c, int rowId) { } } + @Override + public final void readShorts(int total, WritableColumnVector c, int rowId) { + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + for (int i = 0; i < total; i += 1) { + c.putShort(rowId + i, (short) buffer.getInt()); + } + } + @Override public final boolean readBoolean() { // TODO: vectorize decoding and keep boolean[] instead of currentByte @@ -260,6 +270,11 @@ public final byte readByte() { return (byte) readInteger(); } + @Override + public short readShort() { + return (short) readInteger(); + } + @Override public final float readFloat() { return getBuffer(4).getFloat(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 2eed66278be8c..125506d4d5013 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -359,9 +359,7 @@ public void readShorts( switch (mode) { case RLE: if (currentValue == level) { - for (int i = 0; i < n; i++) { - c.putShort(rowId + i, (short)data.readInteger()); - } + data.readShorts(n, c, rowId); } else { c.putNulls(rowId, n); } @@ -369,7 +367,7 @@ public void readShorts( case PACKED: for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { - c.putShort(rowId + i, (short)data.readInteger()); + c.putShort(rowId + i, data.readShort()); } else { c.putNull(rowId + i); } @@ -694,11 +692,21 @@ public byte readByte() { throw new UnsupportedOperationException("only readInts is valid."); } + @Override + public short readShort() { + throw new UnsupportedOperationException("only readInts is valid."); + } + @Override public void readBytes(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } + @Override + public void readShorts(int total, WritableColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + @Override public void readLongs(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java 
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java index d09f750beb285..a2d663fd8c8b6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -28,6 +28,7 @@ public interface VectorizedValuesReader { boolean readBoolean(); byte readByte(); + short readShort(); int readInteger(); long readLong(); float readFloat(); @@ -39,6 +40,7 @@ public interface VectorizedValuesReader { */ void readBooleans(int total, WritableColumnVector c, int rowId); void readBytes(int total, WritableColumnVector c, int rowId); + void readShorts(int total, WritableColumnVector c, int rowId); void readIntegers(int total, WritableColumnVector c, int rowId); void readIntegersWithRebase(int total, WritableColumnVector c, int rowId, boolean failIfRebase); void readUnsignedIntegers(int total, WritableColumnVector c, int rowId); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala index b0bbb52bc4990..500425e4809e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala @@ -78,6 +78,9 @@ case class CollectMetricsExec( } } } + + override protected def withNewChildInternal(newChild: SparkPlan): CollectMetricsExec = + copy(child = newChild) } object CollectMetricsExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala index 8d542792a0e28..6bdd93e9230b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala @@ -201,6 +201,9 @@ case class ColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransition w override def inputRDDs(): Seq[RDD[InternalRow]] = { Seq(child.executeColumnar().asInstanceOf[RDD[InternalRow]]) // Hack because of type erasure } + + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarToRowExec = + copy(child = newChild) } /** @@ -486,6 +489,9 @@ case class RowToColumnarExec(child: SparkPlan) extends RowToColumnarTransition { } } } + + override protected def withNewChildInternal(newChild: SparkPlan): RowToColumnarExec = + copy(child = newChild) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 6f5bf15d82638..3fd653130e57c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -203,4 +203,7 @@ case class ExpandExec( |} """.stripMargin } + + override protected def withNewChildInternal(newChild: SparkPlan): ExpandExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 0d5ec2d6c6f1c..6c7929437ffdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -325,4 +325,7 @@ case class GenerateExec( if (condition) Seq(code) else Seq.empty } + + 
override protected def withNewChildInternal(newChild: SparkPlan): GenerateExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 08950c827f5aa..48da6f0410690 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import java.time.{Instant, LocalDate, ZoneOffset} +import java.time.{Duration, Instant, LocalDate, Period, ZoneOffset} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.IntervalUtils.{durationToMicros, periodToMonths, toDayTimeIntervalString, toYearMonthIntervalString} import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedCommandExec, ShowTablesCommand, ShowViewsCommand} import org.apache.spark.sql.execution.datasources.v2.{DescribeTableExec, ShowTablesExec} import org.apache.spark.sql.internal.SQLConf @@ -117,6 +118,10 @@ object HiveResult { struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveString((v, t.dataType), true, formatters)}""" }.mkString("{", ",", "}") + case (period: Period, YearMonthIntervalType) => + toYearMonthIntervalString(periodToMonths(period)) + case (duration: Duration, DayTimeIntervalType) => + toDayTimeIntervalString(durationToMicros(duration)) case (other, _: UserDefinedType[_]) => other.toString } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 6b6ca531c6d3b..984a45cd058ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -202,4 +202,7 @@ case class SortExec( } super.cleanupResources() } + + override protected def withNewChildInternal(newChild: SparkPlan): SortExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala index 75c91667012a3..7f3628926e351 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala @@ -72,6 +72,9 @@ case class SparkScriptTransformationExec( outputIterator } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkScriptTransformationExec = + copy(child = newChild) } case class SparkScriptTransformationWriterThread( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala index cece43090cb76..a735d913c953a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala @@ -39,4 +39,7 @@ case class SubqueryAdaptiveBroadcastExec( throw new UnsupportedOperationException( "SubqueryAdaptiveBroadcastExec does not support the execute() code path.") } + + override protected def 
withNewChildInternal(newChild: SparkPlan): SubqueryAdaptiveBroadcastExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala index 70ba13550afbf..47cb70dde86a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala @@ -113,6 +113,9 @@ case class SubqueryBroadcastExec( } override def stringArgs: Iterator[Any] = super.stringArgs ++ Iterator(s"[id=#$id]") + + override protected def withNewChildInternal(newChild: SparkPlan): SubqueryBroadcastExec = + copy(child = newChild) } object SubqueryBroadcastExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 9c50dc91b6385..85bc98d194fee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -554,6 +554,9 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCod } override def needCopyResult: Boolean = false + + override protected def withNewChildInternal(newChild: SparkPlan): InputAdapter = + copy(child = newChild) } object WholeStageCodegenExec { @@ -829,6 +832,9 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) override def limitNotReachedChecks: Seq[String] = Nil override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer]) + + override protected def withNewChildInternal(newChild: SparkPlan): WholeStageCodegenExec = + copy(child = newChild)(codegenStageId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala index 4639ccc11fc6a..f2eefbc028b5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala @@ -195,4 +195,7 @@ case class CustomShuffleReaderExec private( override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { shuffleRDD.asInstanceOf[RDD[ColumnarBatch]] } + + override protected def withNewChildInternal(newChild: SparkPlan): CustomShuffleReaderExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 7d45638146a71..6e23a2844d148 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -1108,6 +1108,9 @@ case class HashAggregateExec( s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt" } } + + override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExec = + copy(child = newChild) } object HashAggregateExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 
e5f59e0d4e9bf..559f545dc05ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -138,6 +138,9 @@ case class ObjectHashAggregateExec( s"ObjectHashAggregate(keys=$keyString, functions=$functionString)" } } + + override protected def withNewChildInternal(newChild: SparkPlan): ObjectHashAggregateExec = + copy(child = newChild) } object ObjectHashAggregateExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 2400ceef544d6..4fb0f44db81c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -101,4 +101,7 @@ case class SortAggregateExec( s"SortAggregate(key=$keyString, functions=$functionString)" } } + + override protected def withNewChildInternal(newChild: SparkPlan): SortAggregateExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index ea44c6013b7d9..d958790dd09b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -203,6 +203,10 @@ case class SimpleTypedAggregateExpression( schema: StructType): TypedAggregateExpression = { copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema)) } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): SimpleTypedAggregateExpression = + super.legacyWithNewChildren(newChildren).asInstanceOf[SimpleTypedAggregateExpression] } case class ComplexTypedAggregateExpression( @@ -285,4 +289,8 @@ case class ComplexTypedAggregateExpression( schema: StructType): TypedAggregateExpression = { copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema)) } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): ComplexTypedAggregateExpression = + super.legacyWithNewChildren(newChildren).asInstanceOf[ComplexTypedAggregateExpression] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index e6851a9af739f..1aae76e0fb29b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -454,6 +454,9 @@ case class ScalaUDAF( override def nodeName: String = name override def name: String = udafName.getOrElse(udaf.getClass.getSimpleName) + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ScalaUDAF = + copy(children = newChildren) } case class ScalaAggregator[IN, BUF, OUT]( @@ -520,6 +523,10 @@ case class ScalaAggregator[IN, BUF, OUT]( override def nodeName: String = name override def name: String = aggregatorName.getOrElse(agg.getClass.getSimpleName) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): ScalaAggregator[IN, BUF, OUT] = + copy(children = 
newChildren) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index abd336006848b..b537040fe71df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -107,6 +107,9 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) |${ExplainUtils.generateFieldString("Input", child.output)} |""".stripMargin } + + override protected def withNewChildInternal(newChild: SparkPlan): ProjectExec = + copy(child = newChild) } trait GeneratePredicateHelper extends PredicateHelper { @@ -286,6 +289,9 @@ case class FilterExec(condition: Expression, child: SparkPlan) |Condition : ${condition} |""".stripMargin } + + override protected def withNewChildInternal(newChild: SparkPlan): FilterExec = + copy(child = newChild) } /** @@ -392,6 +398,9 @@ case class SampleExec( """.stripMargin.trim } } + + override protected def withNewChildInternal(newChild: SparkPlan): SampleExec = + copy(child = newChild) } @@ -687,6 +696,9 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan { protected override def doExecute(): RDD[InternalRow] = sparkContext.union(children.map(_.execute())) + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): UnionExec = + copy(children = newChildren) } /** @@ -720,6 +732,9 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN child.execute().coalesce(numPartitions, shuffle = false) } } + + override protected def withNewChildInternal(newChild: SparkPlan): CoalesceExec = + copy(child = newChild) } object CoalesceExec { @@ -849,6 +864,9 @@ case class SubqueryExec(name: String, child: SparkPlan, maxNumRows: Option[Int] } override def stringArgs: Iterator[Any] = Iterator(name, child) ++ Iterator(s"[id=#$id]") + + override protected def withNewChildInternal(newChild: SparkPlan): SubqueryExec = + copy(child = newChild) } object SubqueryExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 641bd26c381ad..e3c2e90a42dec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types._ case class AnalyzeColumnCommand( tableIdent: TableIdentifier, columnNames: Option[Seq[String]], - allColumns: Boolean) extends RunnableCommand { + allColumns: Boolean) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { require(columnNames.isDefined ^ allColumns, "Parameter `columnNames` or `allColumns` are " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala index 51d4c5f41b1d3..5b3cb7476608b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.util.PartitioningUtils case class AnalyzePartitionCommand( tableIdent: TableIdentifier, partitionSpec: 
Map[String, Option[String]], - noscan: Boolean = true) extends RunnableCommand { + noscan: Boolean = true) extends LeafRunnableCommand { private def getPartitionSpec(table: CatalogTable): Option[TablePartitionSpec] = { val normalizedPartitionSpec = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index d114ca015d7ca..157554e821811 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier */ case class AnalyzeTableCommand( tableIdent: TableIdentifier, - noScan: Boolean = true) extends RunnableCommand { + noScan: Boolean = true) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { CommandUtils.analyzeTable(sparkSession, tableIdent, noScan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala index ef0701909de2e..c9b22a7d1b258 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.{Row, SparkSession} */ case class AnalyzeTablesCommand( databaseName: Option[String], - noScan: Boolean) extends RunnableCommand { + noScan: Boolean) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala index d065bc0dab4cd..be680a733eac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala @@ -42,7 +42,7 @@ case class InsertIntoDataSourceDirCommand( storage: CatalogStorageFormat, provider: String, query: LogicalPlan, - overwrite: Boolean) extends RunnableCommand { + overwrite: Boolean) extends LeafRunnableCommand { override def innerChildren: Seq[LogicalPlan] = query :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 7d92e6e189fb2..0ebc927c552f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -34,7 +34,8 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType} * set; * }}} */ -case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableCommand with Logging { +case class SetCommand(kv: Option[(String, Option[String])]) + extends LeafRunnableCommand with Logging { private def keyValueOutput: Seq[Attribute] = { val schema = StructType( @@ -169,7 +170,7 @@ object SetCommand { * reset spark.sql.session.timeZone; * }}} */ -case class ResetCommand(config: Option[String]) extends RunnableCommand with IgnoreCachedData { +case class ResetCommand(config: 
Option[String]) extends LeafRunnableCommand with IgnoreCachedData { override def run(sparkSession: SparkSession): Seq[Row] = { val globalInitialConfigs = sparkSession.sharedState.conf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 2f72af7f4b512..de5dbddbfa146 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData /** * Clear all cached data from the in-memory cache. */ -case object ClearCacheCommand extends RunnableCommand with IgnoreCachedData { +case object ClearCacheCommand extends LeafRunnableCommand with IgnoreCachedData { override def run(sparkSession: SparkSession): Seq[Row] = { sparkSession.catalog.clearCache() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 8bc3cedff2426..7f4f816d328da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import org.apache.spark.sql.catalyst.trees.LeafLike import org.apache.spark.sql.connector.ExternalCommandRunner import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.SQLMetric @@ -48,6 +49,8 @@ trait RunnableCommand extends Command { def run(sparkSession: SparkSession): Seq[Row] } +trait LeafRunnableCommand extends RunnableCommand with LeafLike[LogicalPlan] + /** * A physical operator that executes the run method of a `RunnableCommand` and * saves the result to prevent multiple executions. @@ -132,6 +135,9 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { sqlContext.sparkContext.parallelize(sideEffectResult, 1) } + + override protected def withNewChildInternal(newChild: SparkPlan): DataWritingCommandExec = + copy(child = newChild) } /** @@ -150,7 +156,7 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan) case class ExplainCommand( logicalPlan: LogicalPlan, mode: ExplainMode) - extends RunnableCommand { + extends LeafRunnableCommand { override val output: Seq[Attribute] = Seq(AttributeReference("plan", StringType, nullable = true)()) @@ -167,7 +173,7 @@ case class ExplainCommand( /** An explain command for users to see how a streaming batch is executed. 
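Editor's note: the hunks above and below all roll out the same two mechanical changes. Every concrete plan node gains a withNewChildInternal / withNewChildrenInternal override that returns a copy of itself with the new child wired in, and commands with no logical children switch from RunnableCommand to the new LeafRunnableCommand marker (defined in commands.scala above) so they do not each need a child-rewrite override. The following is a minimal, self-contained sketch of the shape of that contract, using toy ToyNode / ToyUnaryLike / ToyLeafLike types invented purely for illustration, not Spark's actual TreeNode, UnaryLike, or LeafLike traits:

trait ToyNode {
  def children: Seq[ToyNode]
  protected def withNewChildrenInternal(newChildren: IndexedSeq[ToyNode]): ToyNode
  // Framework-level entry point: short-circuits when nothing changed,
  // otherwise dispatches to the node-specific rewrite.
  final def withNewChildren(newChildren: Seq[ToyNode]): ToyNode =
    if (newChildren == children) this else withNewChildrenInternal(newChildren.toIndexedSeq)
}

trait ToyUnaryLike extends ToyNode {
  def child: ToyNode
  final override def children: Seq[ToyNode] = child :: Nil
  protected def withNewChildInternal(newChild: ToyNode): ToyNode
  final override protected def withNewChildrenInternal(newChildren: IndexedSeq[ToyNode]): ToyNode =
    withNewChildInternal(newChildren.head)
}

trait ToyLeafLike extends ToyNode {
  final override def children: Seq[ToyNode] = Nil
  final override protected def withNewChildrenInternal(newChildren: IndexedSeq[ToyNode]): ToyNode = this
}

// Analogous to a command mixing in LeafRunnableCommand: no rewrite boilerplate needed.
case class ToyClearCacheCommand(name: String) extends ToyLeafLike

// Analogous to the one-line overrides added to the unary execs in this patch.
case class ToyFilterExec(condition: String, child: ToyNode) extends ToyUnaryLike {
  override protected def withNewChildInternal(newChild: ToyNode): ToyFilterExec =
    copy(child = newChild)
}

object ToyRewriteDemo extends App {
  val plan = ToyFilterExec("a > 0", ToyClearCacheCommand("t"))
  println(plan.withNewChildren(Seq(ToyClearCacheCommand("t_refreshed"))))
}

Because each node implements the rewrite itself, the concrete return type is preserved (the overrides in the patch all return their own class) and the unchanged-children case can short-circuit, rather than going through a generic copy of the node.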
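Editor's note on the HiveResult.scala hunk near the top of this section: the new match arms turn java.time.Period and java.time.Duration values into year-month and day-time interval strings by first reducing them to total months and total microseconds. The sketch below illustrates only that reduction step with plain java.time arithmetic; the yearMonthString formatter is a simplified stand-in and does not reproduce the helper names or exact output format of Spark's IntervalUtils:

import java.time.{Duration, Period}

object IntervalReductionSketch extends App {
  // Year-month intervals carry only a month count: toTotalMonths folds years into
  // months and ignores any day component of the Period.
  def periodToTotalMonths(p: Period): Long = p.toTotalMonths

  // Day-time intervals carry only microseconds.
  def durationToMicros(d: Duration): Long =
    Math.addExact(Math.multiplyExact(d.getSeconds, 1000000L), (d.getNano / 1000).toLong)

  // Simplified "Y-M" rendering for illustration only.
  def yearMonthString(months: Long): String = {
    val sign = if (months < 0) "-" else ""
    val m = math.abs(months)
    s"$sign${m / 12}-${m % 12}"
  }

  println(yearMonthString(periodToTotalMonths(Period.of(1, 2, 0))))   // 1-2
  println(durationToMicros(Duration.ofHours(26).plusMillis(123)))     // 93600123000
}

This is also why the pattern match in the hunk keys on both the external Java value (Period / Duration) and the Catalyst interval type tag before choosing a formatter.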
*/ case class StreamingExplainCommand( queryExecution: IncrementalExecution, - extended: Boolean) extends RunnableCommand { + extended: Boolean) extends LeafRunnableCommand { override val output: Seq[Attribute] = Seq(AttributeReference("plan", StringType, nullable = true)()) @@ -193,7 +199,7 @@ case class StreamingExplainCommand( case class ExternalCommandExecutor( runner: ExternalCommandRunner, command: String, - options: Map[String, String]) extends RunnableCommand { + options: Map[String, String]) extends LeafRunnableCommand { override def output: Seq[Attribute] = Seq(AttributeReference("command_output", StringType)()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index bb54457afdc78..bb3869ddf811e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType * }}} */ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boolean) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) @@ -227,4 +227,7 @@ case class CreateDataSourceTableAsSelectCommand( throw ex } } + + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = + copy(query = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 7330f5bee9c21..c7456cd9d2058 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -69,7 +69,7 @@ case class CreateDatabaseCommand( path: Option[String], comment: Option[String], props: Map[String, String]) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -105,7 +105,7 @@ case class DropDatabaseCommand( databaseName: String, ifExists: Boolean, cascade: Boolean) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { sparkSession.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade) @@ -125,7 +125,7 @@ case class DropDatabaseCommand( case class AlterDatabasePropertiesCommand( databaseName: String, props: Map[String, String]) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -146,7 +146,7 @@ case class AlterDatabasePropertiesCommand( * }}} */ case class AlterDatabaseSetLocationCommand(databaseName: String, location: String) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -171,7 +171,7 @@ case class DescribeDatabaseCommand( databaseName: String, extended: Boolean, override val output: Seq[Attribute]) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val dbMetadata: CatalogDatabase = @@ -211,7 +211,7 @@ case class DropTableCommand( tableName: TableIdentifier, 
ifExists: Boolean, isView: Boolean, - purge: Boolean) extends RunnableCommand { + purge: Boolean) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -264,7 +264,7 @@ case class AlterTableSetPropertiesCommand( tableName: TableIdentifier, properties: Map[String, String], isView: Boolean) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -295,7 +295,7 @@ case class AlterTableUnsetPropertiesCommand( propKeys: Seq[String], ifExists: Boolean, isView: Boolean) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -333,7 +333,7 @@ case class AlterTableUnsetPropertiesCommand( case class AlterTableChangeColumnCommand( tableName: TableIdentifier, columnName: String, - newColumn: StructField) extends RunnableCommand { + newColumn: StructField) extends LeafRunnableCommand { // TODO: support change column name/dataType/metadata/position. override def run(sparkSession: SparkSession): Seq[Row] = { @@ -402,7 +402,7 @@ case class AlterTableSerDePropertiesCommand( serdeClassName: Option[String], serdeProperties: Option[Map[String, String]], partSpec: Option[TablePartitionSpec]) - extends RunnableCommand { + extends LeafRunnableCommand { // should never happen if we parsed things correctly require(serdeClassName.isDefined || serdeProperties.isDefined, @@ -454,7 +454,7 @@ case class AlterTableAddPartitionCommand( tableName: TableIdentifier, partitionSpecsAndLocs: Seq[(TablePartitionSpec, Option[String])], ifNotExists: Boolean) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -509,7 +509,7 @@ case class AlterTableRenamePartitionCommand( tableName: TableIdentifier, oldPartition: TablePartitionSpec, newPartition: TablePartitionSpec) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -556,7 +556,7 @@ case class AlterTableDropPartitionCommand( ifExists: Boolean, purge: Boolean, retainData: Boolean) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -600,7 +600,7 @@ case class RepairTableCommand( tableName: TableIdentifier, enableAddPartitions: Boolean, enableDropPartitions: Boolean, - cmd: String = "MSCK REPAIR TABLE") extends RunnableCommand { + cmd: String = "MSCK REPAIR TABLE") extends LeafRunnableCommand { // These are list of statistics that can be collected quickly without requiring a scan of the data // see https://github.com/apache/hive/blob/master/ @@ -833,7 +833,7 @@ case class AlterTableSetLocationCommand( tableName: TableIdentifier, partitionSpec: Option[TablePartitionSpec], location: String) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index af5ba4839ea10..0eda90a596999 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -55,7 +55,7 @@ case class CreateFunctionCommand( isTemp: Boolean, ignoreIfExists: Boolean, replace: Boolean) - extends RunnableCommand { + extends LeafRunnableCommand { if (ignoreIfExists && replace) { throw new AnalysisException("CREATE FUNCTION with both IF NOT EXISTS and REPLACE" + @@ -112,7 +112,7 @@ case class CreateFunctionCommand( */ case class DescribeFunctionCommand( functionName: FunctionIdentifier, - isExtended: Boolean) extends RunnableCommand { + isExtended: Boolean) extends LeafRunnableCommand { override val output: Seq[Attribute] = { val schema = StructType(StructField("function_desc", StringType, nullable = false) :: Nil) @@ -177,7 +177,7 @@ case class DropFunctionCommand( functionName: String, ifExists: Boolean, isTemp: Boolean) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -216,7 +216,7 @@ case class ShowFunctionsCommand( pattern: Option[String], showUserFunctions: Boolean, showSystemFunctions: Boolean, - override val output: Seq[Attribute]) extends RunnableCommand { + override val output: Seq[Attribute]) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val dbName = db.getOrElse(sparkSession.sessionState.catalog.getCurrentDatabase) @@ -255,7 +255,7 @@ case class ShowFunctionsCommand( case class RefreshFunctionCommand( databaseName: Option[String], functionName: String) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala index 691837f38d7e3..af053f72cc647 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.StringType /** * Adds a jar to the current session so it can be used (for UDFs or serdes). */ -case class AddJarCommand(path: String) extends RunnableCommand { +case class AddJarCommand(path: String) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { sparkSession.sessionState.resourceLoader.addJar(path) Seq.empty[Row] @@ -39,7 +39,7 @@ case class AddJarCommand(path: String) extends RunnableCommand { /** * Adds a file to the current session so it can be used. */ -case class AddFileCommand(path: String) extends RunnableCommand { +case class AddFileCommand(path: String) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val recursive = !sparkSession.sessionState.conf.addSingleFileInAddFile sparkSession.sparkContext.addFile(path, recursive) @@ -50,7 +50,7 @@ case class AddFileCommand(path: String) extends RunnableCommand { /** * Adds an archive to the current session so it can be used. 
*/ -case class AddArchiveCommand(path: String) extends RunnableCommand { +case class AddArchiveCommand(path: String) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { sparkSession.sparkContext.addArchive(path) Seq.empty[Row] @@ -61,7 +61,7 @@ case class AddArchiveCommand(path: String) extends RunnableCommand { * Returns a list of file paths that are added to resources. * If file paths are provided, return the ones that are added to resources. */ -case class ListFilesCommand(files: Seq[String] = Seq.empty[String]) extends RunnableCommand { +case class ListFilesCommand(files: Seq[String] = Seq.empty[String]) extends LeafRunnableCommand { override val output: Seq[Attribute] = { AttributeReference("Results", StringType, nullable = false)() :: Nil } @@ -88,7 +88,7 @@ case class ListFilesCommand(files: Seq[String] = Seq.empty[String]) extends Runn * Returns a list of jar files that are added to resources. * If jar files are provided, return the ones that are added to resources. */ -case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand { +case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends LeafRunnableCommand { override val output: Seq[Attribute] = { AttributeReference("Results", StringType, nullable = false)() :: Nil } @@ -109,7 +109,8 @@ case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends Runnab * Returns a list of archive paths that are added to resources. * If archive paths are provided, return the ones that are added to resources. */ -case class ListArchivesCommand(archives: Seq[String] = Seq.empty[String]) extends RunnableCommand { +case class ListArchivesCommand(archives: Seq[String] = Seq.empty[String]) + extends LeafRunnableCommand { override val output: Seq[Attribute] = { AttributeReference("Results", StringType, nullable = false)() :: Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 488c628fb8633..72168f243900f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -82,7 +82,7 @@ case class CreateTableLikeCommand( fileFormat: CatalogStorageFormat, provider: Option[String], properties: Map[String, String] = Map.empty, - ifNotExists: Boolean) extends RunnableCommand { + ifNotExists: Boolean) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -161,7 +161,7 @@ case class CreateTableLikeCommand( */ case class CreateTableCommand( table: CatalogTable, - ignoreIfExists: Boolean) extends RunnableCommand { + ignoreIfExists: Boolean) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { sparkSession.sessionState.catalog.createTable(table, ignoreIfExists) @@ -183,7 +183,7 @@ case class AlterTableRenameCommand( oldName: TableIdentifier, newName: TableIdentifier, isView: Boolean) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -224,7 +224,7 @@ case class AlterTableRenameCommand( */ case class AlterTableAddColumnsCommand( table: TableIdentifier, - colsToAdd: Seq[StructField]) extends RunnableCommand { + colsToAdd: Seq[StructField]) extends LeafRunnableCommand { override def run(sparkSession: 
SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val catalogTable = verifyAlterTableAddColumn(sparkSession.sessionState.conf, catalog, table) @@ -300,7 +300,7 @@ case class LoadDataCommand( path: String, isLocal: Boolean, isOverwrite: Boolean, - partition: Option[TablePartitionSpec]) extends RunnableCommand { + partition: Option[TablePartitionSpec]) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -441,7 +441,7 @@ object LoadDataCommand { */ case class TruncateTableCommand( tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec]) extends RunnableCommand { + partitionSpec: Option[TablePartitionSpec]) extends LeafRunnableCommand { override def run(spark: SparkSession): Seq[Row] = { val catalog = spark.sessionState.catalog @@ -580,7 +580,7 @@ case class TruncateTableCommand( } } -abstract class DescribeCommandBase extends RunnableCommand { +abstract class DescribeCommandBase extends LeafRunnableCommand { protected def describeSchema( schema: StructType, buffer: ArrayBuffer[Row], @@ -745,7 +745,7 @@ case class DescribeColumnCommand( colNameParts: Seq[String], isExtended: Boolean, override val output: Seq[Attribute]) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { @@ -828,7 +828,7 @@ case class ShowTablesCommand( tableIdentifierPattern: Option[String], override val output: Seq[Attribute], isExtended: Boolean = false, - partitionSpec: Option[TablePartitionSpec] = None) extends RunnableCommand { + partitionSpec: Option[TablePartitionSpec] = None) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { // Since we need to return a Seq of rows, we will call getTables directly @@ -888,7 +888,7 @@ case class ShowTablesCommand( case class ShowTablePropertiesCommand( table: TableIdentifier, propertyKey: Option[String], - override val output: Seq[Attribute]) extends RunnableCommand { + override val output: Seq[Attribute]) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -924,7 +924,7 @@ case class ShowTablePropertiesCommand( case class ShowColumnsCommand( databaseName: Option[String], tableName: TableIdentifier, - override val output: Seq[Attribute]) extends RunnableCommand { + override val output: Seq[Attribute]) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -955,7 +955,7 @@ case class ShowColumnsCommand( case class ShowPartitionsCommand( tableName: TableIdentifier, override val output: Seq[Attribute], - spec: Option[TablePartitionSpec]) extends RunnableCommand { + spec: Option[TablePartitionSpec]) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -1080,7 +1080,7 @@ trait ShowCreateTableCommandBase { case class ShowCreateTableCommand( table: TableIdentifier, override val output: Seq[Attribute]) - extends RunnableCommand with ShowCreateTableCommandBase { + extends LeafRunnableCommand with ShowCreateTableCommandBase { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -1234,7 +1234,7 @@ case class ShowCreateTableCommand( case class ShowCreateTableAsSerdeCommand( table: TableIdentifier, override val output: Seq[Attribute]) - extends 
RunnableCommand with ShowCreateTableCommandBase { + extends LeafRunnableCommand with ShowCreateTableCommandBase { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog @@ -1354,7 +1354,7 @@ case class ShowCreateTableAsSerdeCommand( * }}} */ case class RefreshTableCommand(tableIdent: TableIdentifier) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { // Refresh the given table's metadata. If this table is cached as an InMemoryRelation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 2308ed7555476..745d463e960b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -71,6 +71,10 @@ case class CreateViewCommand( import ViewHelper._ + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): CreateViewCommand = + copy(plan = newChildren.head) + override def innerChildren: Seq[QueryPlan[_]] = Seq(plan) // `plan` needs to be analyzed, but shouldn't be optimized so that caching works correctly. @@ -244,6 +248,10 @@ case class AlterViewAsCommand( import ViewHelper._ + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): AlterViewAsCommand = + copy(query = newChildren.head) + override def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def childrenToAnalyze: Seq[LogicalPlan] = query :: Nil @@ -312,7 +320,7 @@ case class AlterViewAsCommand( case class ShowViewsCommand( databaseName: String, tableIdentifierPattern: Option[String], - override val output: Seq[Attribute]) extends RunnableCommand { + override val output: Seq[Attribute]) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 5f019557d337a..6300e10c0bb3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -68,6 +68,9 @@ object FileFormatWriter extends Logging { |}""".stripMargin }) } + + override protected def withNewChildInternal(newChild: Expression): Empty2Null = + copy(child = newChild) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index bd9cc0e44fca3..789b1d714fcb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.execution.command.LeafRunnableCommand import org.apache.spark.sql.sources.InsertableRelation @@ -31,7 +31,7 @@ case class 
InsertIntoDataSourceCommand( logicalRelation: LogicalRelation, query: LogicalPlan, overwrite: Boolean) - extends RunnableCommand { + extends LeafRunnableCommand { override def innerChildren: Seq[QueryPlan[_]] = Seq(query) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index b29ccb85d77a6..267b360b474ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -270,4 +270,7 @@ case class InsertIntoHadoopFsRelationCommand( } }.toMap } + + override protected def withNewChildInternal( + newChild: LogicalPlan): InsertIntoHadoopFsRelationCommand = copy(query = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 5195bb295f5bf..486f73cab44f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.execution.command.LeafRunnableCommand import org.apache.spark.sql.sources.CreatableRelationProvider /** @@ -36,7 +36,7 @@ case class SaveIntoDataSourceCommand( query: LogicalPlan, dataSource: CreatableRelationProvider, options: Map[String, String], - mode: SaveMode) extends RunnableCommand { + mode: SaveMode) extends LeafRunnableCommand { override def innerChildren: Seq[QueryPlan[_]] = Seq(query) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 137e50236a295..221db208bc629 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.command.{DDLUtils, RunnableCommand} +import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand} import org.apache.spark.sql.execution.command.ViewHelper.createTemporaryViewRelation import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.sql.types._ @@ -52,6 +52,10 @@ case class CreateTable( override def children: Seq[LogicalPlan] = query.toSeq override def output: Seq[Attribute] = Seq.empty override lazy val resolved: Boolean = false + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = + copy(query = if (query.isDefined) Some(newChildren.head) else None) } /** @@ -63,7 +67,7 @@ case class CreateTempViewUsing( replace: Boolean, global: Boolean, 
provider: String, - options: Map[String, String]) extends RunnableCommand { + options: Map[String, String]) extends LeafRunnableCommand { if (tableIdent.database.isDefined) { throw new AnalysisException( @@ -123,7 +127,7 @@ case class CreateTempViewUsing( } case class RefreshResource(path: String) - extends RunnableCommand { + extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { sparkSession.catalog.refreshByPath(path) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 2ed0e06807bf0..764b63db35a7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -47,6 +47,8 @@ case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan) extends UnaryNode { override def child: LogicalPlan = query override def output: Seq[Attribute] = Nil + override protected def withNewChildInternal(newChild: LogicalPlan): WriteToDataSourceV2 = + copy(query = newChild) } /** @@ -82,6 +84,9 @@ case class CreateTableAsSelectExec( partitioning.toArray, properties.asJava) writeToTable(catalog, table, writeOptions, ident) } + + override protected def withNewChildInternal(newChild: SparkPlan): CreateTableAsSelectExec = + copy(query = newChild) } /** @@ -116,6 +121,9 @@ case class AtomicCreateTableAsSelectExec( ident, schema, partitioning.toArray, properties.asJava) writeToTable(catalog, stagedTable, writeOptions, ident) } + + override protected def withNewChildInternal(newChild: SparkPlan): AtomicCreateTableAsSelectExec = + copy(query = newChild) } /** @@ -160,6 +168,9 @@ case class ReplaceTableAsSelectExec( ident, schema, partitioning.toArray, properties.asJava) writeToTable(catalog, table, writeOptions, ident) } + + override protected def withNewChildInternal(newChild: SparkPlan): ReplaceTableAsSelectExec = + copy(query = newChild) } /** @@ -207,6 +218,9 @@ case class AtomicReplaceTableAsSelectExec( } writeToTable(catalog, staged, writeOptions, ident) } + + override protected def withNewChildInternal(newChild: SparkPlan): AtomicReplaceTableAsSelectExec = + copy(query = newChild) } /** @@ -217,7 +231,10 @@ case class AtomicReplaceTableAsSelectExec( case class AppendDataExec( query: SparkPlan, refreshCache: () => Unit, - write: Write) extends V2ExistingTableWriteExec + write: Write) extends V2ExistingTableWriteExec { + override protected def withNewChildInternal(newChild: SparkPlan): AppendDataExec = + copy(query = newChild) +} /** * Physical plan node for overwrite into a v2 table. @@ -232,7 +249,10 @@ case class AppendDataExec( case class OverwriteByExpressionExec( query: SparkPlan, refreshCache: () => Unit, - write: Write) extends V2ExistingTableWriteExec + write: Write) extends V2ExistingTableWriteExec { + override protected def withNewChildInternal(newChild: SparkPlan): OverwriteByExpressionExec = + copy(query = newChild) +} /** * Physical plan node for dynamic partition overwrite into a v2 table. 
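Editor's note on the v2 write nodes just above: their single child is the constructor field named query (the plan producing the rows to write), with child defined only as an accessor on top of it, so the new overrides copy query rather than child. A tiny self-contained toy showing that shape, with invented ToyPlan / ToyAppendDataExec types rather than Spark's:

trait ToyPlan { def children: Seq[ToyPlan] }

case class ToyScanExec(table: String) extends ToyPlan {
  override def children: Seq[ToyPlan] = Nil
}

// `query` is the constructor field; `child` is just an accessor, so the rewrite
// goes through copy(query = ...), as AppendDataExec and friends do above.
case class ToyAppendDataExec(writeName: String, query: ToyPlan) extends ToyPlan {
  def child: ToyPlan = query
  override def children: Seq[ToyPlan] = child :: Nil
  def withNewChildInternal(newChild: ToyPlan): ToyAppendDataExec = copy(query = newChild)
}

object QueryAsChildDemo extends App {
  val plan = ToyAppendDataExec("append_t", ToyScanExec("source"))
  println(plan.withNewChildInternal(ToyScanExec("source_pruned")).query)
}

The same consideration shows up again later in the patch for the logical-plan writers (WriteToDataSourceV2, WriteToContinuousDataSource, WriteToMicroBatchDataSource), which also copy their query field.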
@@ -246,7 +266,10 @@ case class OverwriteByExpressionExec( case class OverwritePartitionsDynamicExec( query: SparkPlan, refreshCache: () => Unit, - write: Write) extends V2ExistingTableWriteExec + write: Write) extends V2ExistingTableWriteExec { + override protected def withNewChildInternal(newChild: SparkPlan): OverwritePartitionsDynamicExec = + copy(query = newChild) +} case class WriteToDataSourceV2Exec( batchWrite: BatchWrite, @@ -255,6 +278,9 @@ case class WriteToDataSourceV2Exec( override protected def run(): Seq[InternalRow] = { writeWithV2(batchWrite) } + + override protected def withNewChildInternal(newChild: SparkPlan): WriteToDataSourceV2Exec = + copy(query = newChild) } trait V2ExistingTableWriteExec extends V2TableWriteExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 3cbebca14f7dc..6c744e66d7abb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -288,5 +288,8 @@ package object debug { } override def supportsColumnar: Boolean = child.supportsColumnar + + override protected def withNewChildInternal(newChild: SparkPlan): DebugExec = + copy(child = newChild) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index ca640c43a03a0..94a8a8f0d9e5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -205,6 +205,9 @@ case class BroadcastExchangeExec( ex) } } + + override protected def withNewChildInternal(newChild: SparkPlan): BroadcastExchangeExec = + copy(child = newChild) } object BroadcastExchangeExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 2a7b12f7f515a..6ec376764a38f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -166,6 +166,9 @@ case class ShuffleExchangeExec( } cachedShuffleRDD } + + override protected def withNewChildInternal(newChild: SparkPlan): ShuffleExchangeExec = + copy(child = newChild) } object ShuffleExchangeExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index cec1286c98a7e..ccbcaa2573f64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -254,4 +254,8 @@ case class BroadcastHashJoinExec( super.codegenAnti(ctx, input) } } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): BroadcastHashJoinExec = + copy(left = newLeft, right = newRight) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index fa1a57a8ae3a5..acdd346c84594 
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -548,4 +548,8 @@ case class BroadcastNestedLoopJoinExec( """.stripMargin } } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): BroadcastNestedLoopJoinExec = + copy(left = newLeft, right = newRight) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index b6386d0d11b4b..1b2d3731f7e8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -101,4 +101,8 @@ case class CartesianProductExec( } } } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): CartesianProductExec = + copy(left = newLeft, right = newRight) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index cd57408e7972d..8514fc2fc4da1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -318,4 +318,8 @@ case class ShuffledHashJoinExec( v => s"$v = $thisPlan.buildHashedRelation(inputs[1]);", forceInline = true) HashedRelationInfo(relationTerm, keyIsUnique = false, isEmpty = false) } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): ShuffledHashJoinExec = + copy(left = newLeft, right = newRight) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index eabbdc8ed3243..8e0b7173ad453 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -633,6 +633,10 @@ case class SortMergeJoinExec( |$eagerCleanup """.stripMargin } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): SortMergeJoinExec = + copy(left = newLeft, right = newRight) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index e5a299523c79c..5114c075a72d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -73,6 +73,9 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends LimitExec { singlePartitionRDD.mapPartitionsInternal(_.take(limit)) } } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) } /** @@ -95,6 +98,9 @@ case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec { // job launch, we might just have to mimic the implementation of `CollectLimitExec`. 
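Editor's note: for the two-child operators in this stretch (the broadcast, nested-loop, cartesian, shuffled-hash and sort-merge joins, plus CoGroupExec and the streaming symmetric hash join further on), the patch adds the binary flavour of the same hook: both rewritten children arrive together and the node copies its left and right fields in one call. A self-contained toy version of that shape, with invented types rather than Spark's binary-node machinery:

trait ToyExec { def children: Seq[ToyExec] }

case class ToyScan(table: String) extends ToyExec {
  override def children: Seq[ToyExec] = Nil
}

case class ToySortMergeJoin(joinKeys: Seq[String], left: ToyExec, right: ToyExec) extends ToyExec {
  override def children: Seq[ToyExec] = Seq(left, right)
  // Mirrors the overrides added to BroadcastHashJoinExec, SortMergeJoinExec, etc.:
  // the rewrite receives both children and returns the concrete node type.
  def withNewChildrenInternal(newLeft: ToyExec, newRight: ToyExec): ToySortMergeJoin =
    copy(left = newLeft, right = newRight)
}

object BinaryRewriteDemo extends App {
  val join = ToySortMergeJoin(Seq("id"), ToyScan("orders"), ToyScan("customers"))
  println(join.withNewChildrenInternal(ToyScan("orders_filtered"), ToyScan("customers")))
}

One plausible reason for keeping separate unary and binary signatures, rather than a single Seq-based method everywhere, is that the very common one- and two-child nodes can be rewritten without building and destructuring a child sequence each time.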
sparkContext.parallelize(executeCollect(), numSlices = 1) } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) } object BaseLimitExec { @@ -160,7 +166,10 @@ trait BaseLimitExec extends LimitExec with CodegenSupport { /** * Take the first `limit` elements of each child partition, but do not collect or shuffle them. */ -case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec +case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} /** * Take the first `limit` elements of the child's single output partition. @@ -168,6 +177,9 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) } /** @@ -249,4 +261,7 @@ case class TakeOrderedAndProjectExec( s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index c08db132c946f..fa46f75abe8f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -99,6 +99,9 @@ case class DeserializeToObjectExec( iter.map(projection) } } + + override protected def withNewChildInternal(newChild: SparkPlan): DeserializeToObjectExec = + copy(child = newChild) } /** @@ -135,6 +138,9 @@ case class SerializeFromObjectExec( iter.map(projection) } } + + override protected def withNewChildInternal(newChild: SparkPlan): SerializeFromObjectExec = + copy(child = newChild) } /** @@ -195,6 +201,9 @@ case class MapPartitionsExec( func(iter.map(getObject)).map(outputObject) } } + + override protected def withNewChildInternal(newChild: SparkPlan): MapPartitionsExec = + copy(child = newChild) } /** @@ -252,6 +261,9 @@ case class MapPartitionsInRWithArrowExec( }.map(outputProject) } } + + override protected def withNewChildInternal(newChild: SparkPlan): MapPartitionsInRWithArrowExec = + copy(child = newChild) } /** @@ -304,6 +316,9 @@ case class MapElementsExec( override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputPartitioning: Partitioning = child.outputPartitioning + + override protected def withNewChildInternal(newChild: SparkPlan): MapElementsExec = + copy(child = newChild) } /** @@ -333,6 +348,9 @@ case class AppendColumnsExec( } } } + + override protected def withNewChildInternal(newChild: SparkPlan): AppendColumnsExec = + copy(child = newChild) } /** @@ -366,6 +384,9 @@ case class AppendColumnsWithObjectExec( } } } + + override protected def withNewChildInternal(newChild: SparkPlan): AppendColumnsWithObjectExec = + copy(child = newChild) } /** @@ -405,6 +426,9 @@ case class MapGroupsExec( } } } + + override protected def withNewChildInternal(newChild: SparkPlan): MapGroupsExec = + copy(child = newChild) } object MapGroupsExec { @@ -495,6 +519,9 @@ case class FlatMapGroupsInRExec( } } } + + override protected def withNewChildInternal(newChild: SparkPlan): 
FlatMapGroupsInRExec = + copy(child = newChild) } /** @@ -577,6 +604,9 @@ case class FlatMapGroupsInRWithArrowExec( }.map(outputProject) } } + + override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsInRWithArrowExec = + copy(child = newChild) } /** @@ -623,4 +653,7 @@ case class CoGroupExec( } } } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): CoGroupExec = copy(left = newLeft, right = newRight) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index dadf1129c34b5..5019008ec5e32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -154,4 +154,7 @@ case class AggregateInPandasExec( } }} } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 67f075f0785fb..096712cf93529 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -94,4 +94,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] batch.rowIterator.asScala } } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 2ab7262763835..10f7966b93d1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -103,4 +103,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] } } } + + override protected def withNewChildInternal(newChild: SparkPlan): BatchEvalPythonExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index b079405bdc2f8..e830ea6b54662 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -103,4 +103,8 @@ case class FlatMapCoGroupsInPandasExec( } } } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): FlatMapCoGroupsInPandasExec = + copy(left = newLeft, right = newRight) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 5032bc81327b9..3a3a6022f9985 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -94,4 +94,7 @@ case class 
FlatMapGroupsInPandasExec( executePython(data, output, runner) }} } + + override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsInPandasExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala index 71f51f1abc6f5..0434710da43ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala @@ -93,4 +93,7 @@ case class MapInPandasExec( }.map(unsafeProj) } } + + override protected def withNewChildInternal(newChild: SparkPlan): MapInPandasExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 983fe9db73824..909a026bac7d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -401,4 +401,7 @@ case class WindowInPandasExec( } } } + + override protected def withNewChildInternal(newChild: SparkPlan): WindowInPandasExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala index 20fb06a851dd7..7e094fee32547 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala @@ -125,4 +125,7 @@ case class EventTimeWatermarkExec( a } } + + override protected def withNewChildInternal(newChild: SparkPlan): EventTimeWatermarkExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 747094b7791c1..fe788dd8b9408 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -246,4 +246,7 @@ case class FlatMapGroupsWithStateExec( CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) } } + + override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsWithStateExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 73d2f826f1126..b2c8141e5db0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -620,4 +620,8 @@ case class StreamingSymmetricHashJoinExec( def numUpdatedStateRows: Long = updatedStateRowsCount } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): StreamingSymmetricHashJoinExec = + copy(left = newLeft, right = newRight) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala index 1923fc969801e..ceb52f520df66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -28,4 +28,6 @@ case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan extends UnaryNode { override def child: LogicalPlan = query override def output: Seq[Attribute] = Nil + override protected def withNewChildInternal( + newChild: LogicalPlan): WriteToContinuousDataSource = copy(query = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index f1898ad3f27ca..1e0caf4785d5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -70,4 +70,7 @@ case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPl sparkContext.emptyRDD } + + override protected def withNewChildInternal( + newChild: SparkPlan): WriteToContinuousDataSourceExec = copy(query = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala index 4bacd71a55ec1..7989b941563a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala @@ -36,4 +36,7 @@ case class WriteToMicroBatchDataSource(write: StreamingWrite, query: LogicalPlan def createPlan(batchId: Long): WriteToDataSourceV2 = { WriteToDataSourceV2(new MicroBatchWrite(batchId, write), query) } + + override protected def withNewChildInternal(newChild: LogicalPlan): WriteToMicroBatchDataSource = + copy(query = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index e52f2a17b659d..b52603ebc0443 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -281,6 +281,9 @@ case class StateStoreRestoreExec( ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } + + override protected def withNewChildInternal(newChild: SparkPlan): StateStoreRestoreExec = + copy(child = newChild) } /** @@ -436,6 +439,9 @@ case class StateStoreSaveExec( eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get } + + override protected def withNewChildInternal(newChild: SparkPlan): StateStoreSaveExec = + copy(child = newChild) } /** Physical operator for executing streaming Deduplicate. 
*/ @@ -509,6 +515,9 @@ case class StreamingDeduplicateExec( override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get } + + override protected def withNewChildInternal(newChild: SparkPlan): StreamingDeduplicateExec = + copy(child = newChild) } object StreamingDeduplicateExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala index e53e0644eb268..51723a25e04e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala @@ -95,6 +95,9 @@ case class StreamingGlobalLimitExec( private def getValueRow(value: Long): UnsafeRow = { UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value))) } + + override protected def withNewChildInternal(newChild: SparkPlan): StreamingGlobalLimitExec = + copy(child = newChild) } @@ -133,4 +136,7 @@ case class StreamingLocalLimitExec(limit: Int, child: SparkPlan) override def outputPartitioning: Partitioning = child.outputPartitioning override def output: Seq[Attribute] = child.output + + override protected def withNewChildInternal(newChild: SparkPlan): StreamingLocalLimitExec = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 9c950fd8033a7..15b85013c4621 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -166,6 +166,9 @@ case class InSubqueryExec( exprId = ExprId(0), resultBroadcast = null) } + + override protected def withNewChildInternal(newChild: Expression): InSubqueryExec = + copy(child = newChild) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 6e0e36cbe5901..8011c803394d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -211,4 +211,7 @@ case class WindowExec( } } } + + override protected def withNewChildInternal(newChild: SparkPlan): WindowExec = + copy(child = newChild) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 3144833b608be..b773396c050d2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -54,6 +54,41 @@ select count(a), a from (select 1 as a) tmp group by 2 having a > 0; -- mixed cases: group-by ordinals and aliases select a, a AS k, count(b) from data group by k, 1; +-- can use ordinal in CUBE +select a, b, count(1) from data group by cube(1, 2); + +-- mixed cases: can use ordinal in CUBE +select a, b, count(1) from data group by cube(1, b); + +-- can use ordinal with cube +select a, b, count(1) from data group by 1, 2 with cube; + +-- can use ordinal in ROLLUP +select a, b, count(1) from data group by rollup(1, 2); + +-- mixed cases: can use ordinal in ROLLUP +select a, b, count(1) from data group by rollup(1, b); + +-- can use ordinal 
with rollup +select a, b, count(1) from data group by 1, 2 with rollup; + +-- can use ordinal in GROUPING SETS +select a, b, count(1) from data group by grouping sets((1), (2), (1, 2)); + +-- mixed cases: can use ordinal in GROUPING SETS +select a, b, count(1) from data group by grouping sets((1), (b), (a, 2)); + +select a, b, count(1) from data group by a, 2 grouping sets((1), (b), (a, 2)); + +-- range error +select a, b, count(1) from data group by a, -1; + +select a, b, count(1) from data group by a, 3; + +select a, b, count(1) from data group by cube(-1, 2); + +select a, b, count(1) from data group by cube(1, 3); + -- turn off group by ordinal set spark.sql.groupByOrdinal=false; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index 48c4f8ac6503c..1f05dc08fd0a3 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 20 +-- Number of queries: 33 -- !query @@ -184,6 +184,204 @@ struct 3 3 2 +-- !query +select a, b, count(1) from data group by cube(1, 2) +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +NULL 1 3 +NULL 2 3 +NULL NULL 6 + + +-- !query +select a, b, count(1) from data group by cube(1, b) +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +NULL 1 3 +NULL 2 3 +NULL NULL 6 + + +-- !query +select a, b, count(1) from data group by 1, 2 with cube +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +NULL 1 3 +NULL 2 3 +NULL NULL 6 + + +-- !query +select a, b, count(1) from data group by rollup(1, 2) +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +NULL NULL 6 + + +-- !query +select a, b, count(1) from data group by rollup(1, b) +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +NULL NULL 6 + + +-- !query +select a, b, count(1) from data group by 1, 2 with rollup +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +NULL NULL 6 + + +-- !query +select a, b, count(1) from data group by grouping sets((1), (2), (1, 2)) +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +NULL 1 3 +NULL 2 3 + + +-- !query +select a, b, count(1) from data group by grouping sets((1), (b), (a, 2)) +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +NULL 1 3 +NULL 2 3 + + +-- !query +select a, b, count(1) from data group by a, 2 grouping sets((1), (b), (a, 2)) +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +NULL 1 3 +NULL 2 3 + + +-- !query +select a, b, count(1) from data group by a, -1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +GROUP BY position -1 is not in select list (valid range is [1, 3]); line 1 pos 44 + + +-- !query +select a, b, count(1) from data group by a, 3 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +aggregate functions are not allowed 
in GROUP BY, but found count(1) + + +-- !query +select a, b, count(1) from data group by cube(-1, 2) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +GROUP BY position -1 is not in select list (valid range is [1, 3]); line 1 pos 46 + + +-- !query +select a, b, count(1) from data group by cube(1, 3) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +grouping expressions sequence is empty, and 'data.a' is not an aggregate function. Wrap '(count(1) AS `count(1)`)' in windowing function(s) or wrap 'data.a' in first() (or first_value) if you don't care which value you get. + + -- !query set spark.sql.groupByOrdinal=false -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/try_cast.sql.out b/sql/core/src/test/resources/sql-tests/results/try_cast.sql.out index 810b82f7943df..8be8d6be3437e 100644 --- a/sql/core/src/test/resources/sql-tests/results/try_cast.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/try_cast.sql.out @@ -5,7 +5,7 @@ -- !query SELECT TRY_CAST('1.23' AS int) -- !query schema -struct +struct -- !query output NULL @@ -13,7 +13,7 @@ NULL -- !query SELECT TRY_CAST('1.23' AS long) -- !query schema -struct +struct -- !query output NULL @@ -21,7 +21,7 @@ NULL -- !query SELECT TRY_CAST('-4.56' AS int) -- !query schema -struct +struct -- !query output NULL @@ -29,7 +29,7 @@ NULL -- !query SELECT TRY_CAST('-4.56' AS long) -- !query schema -struct +struct -- !query output NULL @@ -37,7 +37,7 @@ NULL -- !query SELECT TRY_CAST('abc' AS int) -- !query schema -struct +struct -- !query output NULL @@ -45,7 +45,7 @@ NULL -- !query SELECT TRY_CAST('abc' AS long) -- !query schema -struct +struct -- !query output NULL @@ -53,7 +53,7 @@ NULL -- !query SELECT TRY_CAST('' AS int) -- !query schema -struct +struct -- !query output NULL @@ -61,7 +61,7 @@ NULL -- !query SELECT TRY_CAST('' AS long) -- !query schema -struct +struct -- !query output NULL @@ -69,7 +69,7 @@ NULL -- !query SELECT TRY_CAST(NULL AS int) -- !query schema -struct +struct -- !query output NULL @@ -77,7 +77,7 @@ NULL -- !query SELECT TRY_CAST(NULL AS long) -- !query schema -struct +struct -- !query output NULL @@ -85,7 +85,7 @@ NULL -- !query SELECT TRY_CAST('123.a' AS int) -- !query schema -struct +struct -- !query output NULL @@ -93,7 +93,7 @@ NULL -- !query SELECT TRY_CAST('123.a' AS long) -- !query schema -struct +struct -- !query output NULL @@ -101,7 +101,7 @@ NULL -- !query SELECT TRY_CAST('-2147483648' AS int) -- !query schema -struct +struct -- !query output -2147483648 @@ -109,7 +109,7 @@ struct -- !query SELECT TRY_CAST('-2147483649' AS int) -- !query schema -struct +struct -- !query output NULL @@ -117,7 +117,7 @@ NULL -- !query SELECT TRY_CAST('2147483647' AS int) -- !query schema -struct +struct -- !query output 2147483647 @@ -125,7 +125,7 @@ struct -- !query SELECT TRY_CAST('2147483648' AS int) -- !query schema -struct +struct -- !query output NULL @@ -133,7 +133,7 @@ NULL -- !query SELECT TRY_CAST('-9223372036854775808' AS long) -- !query schema -struct +struct -- !query output -9223372036854775808 @@ -141,7 +141,7 @@ struct -- !query SELECT TRY_CAST('-9223372036854775809' AS long) -- !query schema -struct +struct -- !query output NULL @@ -149,7 +149,7 @@ NULL -- !query SELECT TRY_CAST('9223372036854775807' AS long) -- !query schema -struct +struct -- !query output 9223372036854775807 @@ -157,7 +157,7 @@ struct -- !query SELECT TRY_CAST('9223372036854775808' AS long) -- !query schema -struct +struct -- !query 
output NULL @@ -165,7 +165,7 @@ NULL -- !query SELECT TRY_CAST('interval 3 month 1 hour' AS interval) -- !query schema -struct +struct -- !query output 3 months 1 hours @@ -173,7 +173,7 @@ struct -- !query SELECT TRY_CAST('abc' AS interval) -- !query schema -struct +struct -- !query output NULL @@ -181,7 +181,7 @@ NULL -- !query select TRY_CAST('true' as boolean) -- !query schema -struct +struct -- !query output true @@ -189,7 +189,7 @@ true -- !query select TRY_CAST('false' as boolean) -- !query schema -struct +struct -- !query output false @@ -197,7 +197,7 @@ false -- !query select TRY_CAST('abc' as boolean) -- !query schema -struct +struct -- !query output NULL @@ -205,7 +205,7 @@ NULL -- !query SELECT TRY_CAST("2021-01-01" AS date) -- !query schema -struct +struct -- !query output 2021-01-01 @@ -213,7 +213,7 @@ struct -- !query SELECT TRY_CAST("2021-101-01" AS date) -- !query schema -struct +struct -- !query output NULL @@ -221,7 +221,7 @@ NULL -- !query SELECT TRY_CAST("2021-01-01 00:00:00" AS timestamp) -- !query schema -struct +struct -- !query output 2021-01-01 00:00:00 @@ -229,6 +229,6 @@ struct -- !query SELECT TRY_CAST("2021-101-01 00:00:00" AS timestamp) -- !query schema -struct +struct -- !query output NULL diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt index e4ec487623d2c..0c191216db316 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt @@ -720,7 +720,7 @@ Input [6]: [i_brand_id#104, i_class_id#105, i_category_id#106, sales#116, number (130) Expand [codegen id : 130] Input [6]: [sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, i_category_id#56] -Arguments: [List(sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, i_category_id#56, 0), List(sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, null, 1), List(sales#68, number_sales#69, channel#73, i_brand_id#54, null, null, 3), List(sales#68, number_sales#69, channel#73, null, null, null, 7), List(sales#68, number_sales#69, null, null, null, null, 15)], [sales#68, number_sales#69, channel#120, i_brand_id#121, i_class_id#122, i_category_id#123, spark_grouping_id#124] +Arguments: [ArrayBuffer(sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, i_category_id#56, 0), ArrayBuffer(sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, null, 1), ArrayBuffer(sales#68, number_sales#69, channel#73, i_brand_id#54, null, null, 3), ArrayBuffer(sales#68, number_sales#69, channel#73, null, null, null, 7), ArrayBuffer(sales#68, number_sales#69, null, null, null, null, 15)], [sales#68, number_sales#69, channel#120, i_brand_id#121, i_class_id#122, i_category_id#123, spark_grouping_id#124] (131) HashAggregate [codegen id : 130] Input [7]: [sales#68, number_sales#69, channel#120, i_brand_id#121, i_class_id#122, i_category_id#123, spark_grouping_id#124] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt index 6f61fc8e96ae1..ffcbef4ce1602 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt @@ -625,7 
+625,7 @@ Input [6]: [i_brand_id#96, i_class_id#97, i_category_id#98, sales#109, number_sa (111) Expand [codegen id : 79] Input [6]: [sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, i_category_id#48] -Arguments: [List(sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, i_category_id#48, 0), List(sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, null, 1), List(sales#63, number_sales#64, channel#68, i_brand_id#46, null, null, 3), List(sales#63, number_sales#64, channel#68, null, null, null, 7), List(sales#63, number_sales#64, null, null, null, null, 15)], [sales#63, number_sales#64, channel#113, i_brand_id#114, i_class_id#115, i_category_id#116, spark_grouping_id#117] +Arguments: [ArrayBuffer(sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, i_category_id#48, 0), ArrayBuffer(sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, null, 1), ArrayBuffer(sales#63, number_sales#64, channel#68, i_brand_id#46, null, null, 3), ArrayBuffer(sales#63, number_sales#64, channel#68, null, null, null, 7), ArrayBuffer(sales#63, number_sales#64, null, null, null, null, 15)], [sales#63, number_sales#64, channel#113, i_brand_id#114, i_class_id#115, i_category_id#116, spark_grouping_id#117] (112) HashAggregate [codegen id : 79] Input [7]: [sales#63, number_sales#64, channel#113, i_brand_id#114, i_class_id#115, i_category_id#116, spark_grouping_id#117] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt index 28a457258eff7..c9a772d3163ca 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt @@ -429,7 +429,7 @@ Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#95))#129,17,2) AS sales# (77) Expand [codegen id : 23] Input [5]: [sales#41, RETURNS#42, profit#43, channel#44, id#45] -Arguments: [List(sales#41, returns#42, profit#43, channel#44, id#45, 0), List(sales#41, returns#42, profit#43, channel#44, null, 1), List(sales#41, returns#42, profit#43, null, null, 3)], [sales#41, returns#42, profit#43, channel#138, id#139, spark_grouping_id#140] +Arguments: [ArrayBuffer(sales#41, returns#42, profit#43, channel#44, id#45, 0), ArrayBuffer(sales#41, returns#42, profit#43, channel#44, null, 1), ArrayBuffer(sales#41, returns#42, profit#43, null, null, 3)], [sales#41, returns#42, profit#43, channel#138, id#139, spark_grouping_id#140] (78) HashAggregate [codegen id : 23] Input [6]: [sales#41, returns#42, profit#43, channel#138, id#139, spark_grouping_id#140] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt index cb130ce17795a..c01302bf69a40 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt @@ -414,7 +414,7 @@ Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#95))#128,17,2) AS sales# (74) Expand [codegen id : 20] Input [5]: [sales#41, RETURNS#42, profit#43, channel#44, id#45] -Arguments: [List(sales#41, returns#42, profit#43, channel#44, id#45, 0), List(sales#41, returns#42, profit#43, channel#44, null, 1), List(sales#41, returns#42, profit#43, null, null, 3)], [sales#41, 
returns#42, profit#43, channel#137, id#138, spark_grouping_id#139] +Arguments: [ArrayBuffer(sales#41, returns#42, profit#43, channel#44, id#45, 0), ArrayBuffer(sales#41, returns#42, profit#43, channel#44, null, 1), ArrayBuffer(sales#41, returns#42, profit#43, null, null, 3)], [sales#41, returns#42, profit#43, channel#137, id#138, spark_grouping_id#139] (75) HashAggregate [codegen id : 20] Input [6]: [sales#41, returns#42, profit#43, channel#137, id#138, spark_grouping_id#139] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt index 4b2299ca2e749..dc5a7fc792af9 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt @@ -488,7 +488,7 @@ Input [6]: [wp_web_page_sk#77, sales#86, profit#87, wp_web_page_sk#92, returns#1 (85) Expand [codegen id : 23] Input [5]: [sales#18, returns#37, profit#38, channel#39, id#40] -Arguments: [List(sales#18, returns#37, profit#38, channel#39, id#40, 0), List(sales#18, returns#37, profit#38, channel#39, null, 1), List(sales#18, returns#37, profit#38, null, null, 3)], [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110] +Arguments: [ArrayBuffer(sales#18, returns#37, profit#38, channel#39, id#40, 0), ArrayBuffer(sales#18, returns#37, profit#38, channel#39, null, 1), ArrayBuffer(sales#18, returns#37, profit#38, null, null, 3)], [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110] (86) HashAggregate [codegen id : 23] Input [6]: [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt index 618da39637e23..62bd5aba36e53 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt @@ -488,7 +488,7 @@ Input [6]: [wp_web_page_sk#77, sales#86, profit#87, wp_web_page_sk#93, returns#1 (85) Expand [codegen id : 23] Input [5]: [sales#18, returns#37, profit#38, channel#39, id#40] -Arguments: [List(sales#18, returns#37, profit#38, channel#39, id#40, 0), List(sales#18, returns#37, profit#38, channel#39, null, 1), List(sales#18, returns#37, profit#38, null, null, 3)], [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110] +Arguments: [ArrayBuffer(sales#18, returns#37, profit#38, channel#39, id#40, 0), ArrayBuffer(sales#18, returns#37, profit#38, channel#39, null, 1), ArrayBuffer(sales#18, returns#37, profit#38, null, null, 3)], [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110] (86) HashAggregate [codegen id : 23] Input [6]: [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt index bdb1a52a18f2d..040407d99e48d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt @@ -590,7 +590,7 @@ Results [5]: 
[MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#90))#117,17,2) AS (107) Expand [codegen id : 31] Input [5]: [sales#42, returns#43, profit#44, channel#45, id#46] -Arguments: [List(sales#42, returns#43, profit#44, channel#45, id#46, 0), List(sales#42, returns#43, profit#44, channel#45, null, 1), List(sales#42, returns#43, profit#44, null, null, 3)], [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127] +Arguments: [ArrayBuffer(sales#42, returns#43, profit#44, channel#45, id#46, 0), ArrayBuffer(sales#42, returns#43, profit#44, channel#45, null, 1), ArrayBuffer(sales#42, returns#43, profit#44, null, null, 3)], [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127] (108) HashAggregate [codegen id : 31] Input [6]: [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt index aa15d27d4e562..467127aa2e493 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt @@ -590,7 +590,7 @@ Results [5]: [MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#90))#117,17,2) AS (107) Expand [codegen id : 31] Input [5]: [sales#42, returns#43, profit#44, channel#45, id#46] -Arguments: [List(sales#42, returns#43, profit#44, channel#45, id#46, 0), List(sales#42, returns#43, profit#44, channel#45, null, 1), List(sales#42, returns#43, profit#44, null, null, 3)], [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127] +Arguments: [ArrayBuffer(sales#42, returns#43, profit#44, channel#45, id#46, 0), ArrayBuffer(sales#42, returns#43, profit#44, channel#45, null, 1), ArrayBuffer(sales#42, returns#43, profit#44, null, null, 3)], [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127] (108) HashAggregate [codegen id : 31] Input [6]: [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6914330bb289d..70dc0d09bcad5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3637,5 +3637,7 @@ object DataFrameFunctionsSuite { override def dataType: DataType = child.dataType override lazy val resolved = true override def eval(input: InternalRow): Any = child.eval(input) + override protected def withNewChildInternal(newChild: Expression): CodegenFallbackExpr = + copy(child = newChild) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index 9192370cfa620..bec68fae08719 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -21,10 +21,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan} import 
org.apache.spark.sql.test.SharedSparkSession -case class FastOperator(output: Seq[Attribute]) extends SparkPlan { +case class FastOperator(output: Seq[Attribute]) extends LeafExecNode { override protected def doExecute(): RDD[InternalRow] = { val str = Literal("so fast").value @@ -35,7 +35,6 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan { } override def producedAttributes: AttributeSet = outputSet - override def children: Seq[SparkPlan] = Nil } object TestStrategy extends Strategy { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 35d2513835611..d4a6d84ce2b30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -582,6 +582,10 @@ class ColumnarAlias(child: ColumnarExpression, name: String)( with ColumnarExpression { override def columnarEval(batch: ColumnarBatch): Any = child.columnarEval(batch) + + override protected def withNewChildInternal(newChild: Expression): ColumnarAlias = + new ColumnarAlias(newChild.asInstanceOf[ColumnarExpression], name)(exprId, qualifier, + explicitMetadata, nonInheritableMetadataKeys) } class ColumnarAttributeReference( @@ -641,6 +645,9 @@ class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } override def hashCode(): Int = super.hashCode() + + override def withNewChildInternal(newChild: SparkPlan): ColumnarProjectExec = + new ColumnarProjectExec(projectList, newChild) } /** @@ -705,6 +712,12 @@ class BrokenColumnarAdd( } ret } + + override def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): BrokenColumnarAdd = + new BrokenColumnarAdd( + left = newLeft.asInstanceOf[ColumnarExpression], + right = newRight.asInstanceOf[ColumnarExpression], failOnError) } class CannotReplaceException(str: String) extends RuntimeException(str) { @@ -781,6 +794,8 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE override def child: SparkPlan = delegate.child override protected def doExecute(): RDD[InternalRow] = delegate.execute() override def outputPartitioning: Partitioning = delegate.outputPartitioning + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + super.legacyWithNewChildren(Seq(newChild)) } /** @@ -798,6 +813,9 @@ case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends Broa override protected def doExecute(): RDD[InternalRow] = delegate.execute() override def doExecuteBroadcast[T](): Broadcast[T] = delegate.executeBroadcast() override def outputPartitioning: Partitioning = delegate.outputPartitioning + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + super.legacyWithNewChildren(Seq(newChild)) } class ReplacedRowToColumnarExec(override val child: SparkPlan) @@ -815,6 +833,9 @@ class ReplacedRowToColumnarExec(override val child: SparkPlan) } override def hashCode(): Int = super.hashCode() + + override def withNewChildInternal(newChild: SparkPlan): ReplacedRowToColumnarExec = + new ReplacedRowToColumnarExec(newChild) } case class MyPostRule() extends Rule[SparkPlan] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index abe94c2a0b410..986e625137a77 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -233,8 +233,7 @@ object TypedImperativeAggregateSuite { nullable: Boolean = false, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends TypedImperativeAggregate[MaxValue] - with ImplicitCastInputTypes + extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes with UnaryLike[Expression] { override def createAggregationBuffer(): MaxValue = { @@ -297,6 +296,9 @@ object TypedImperativeAggregateSuite { val value = stream.readInt() new MaxValue(value, isValueSet) } + + override protected def withNewChildInternal(newChild: Expression): TypedMax = + copy(child = newChild) } private class MaxValue(var value: Int, var isValueSet: Boolean = false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index cef870b249985..2011d057338c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -600,6 +600,9 @@ case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override protected def withNewChildInternal(newChild: SparkPlan): ExceptionInjectingOperator = + copy(child = newChild) } @SQLUserDefinedType(udt = classOf[SimpleTupleUDT]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ColumnarRulesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ColumnarRulesSuite.scala index dd2790040b9e8..df08acd35ef17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ColumnarRulesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ColumnarRulesSuite.scala @@ -60,4 +60,5 @@ case class LeafOp(override val supportsColumnar: Boolean) extends LeafExecNode { case class UnaryOp(child: SparkPlan, override val supportsColumnar: Boolean) extends UnaryExecNode { override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp = copy(child = newChild) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index fb97e15e4df63..9776e76b541ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -40,6 +40,9 @@ case class ColumnarExchange(child: SparkPlan) extends Exchange { override protected def doExecute(): RDD[InternalRow] = throw new RanRowBased override protected def doExecuteColumnar(): RDD[ColumnarBatch] = throw new RanColumnar + + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarExchange = + copy(child = newChild) } class ExchangeSuite extends SparkPlanTest with SharedSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala index a49beda2186b4..187fda749a983 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.time.{Duration, Period} + import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.connector.InMemoryTableCatalog import org.apache.spark.sql.execution.HiveResult._ @@ -107,4 +109,20 @@ class HiveResultSuite extends SharedSparkSession { } } } + + test("SPARK-34984: year-month interval formatting in hive result") { + val df = Seq(Period.ofYears(-10).minusMonths(1)).toDF("i") + val plan1 = df.queryExecution.executedPlan + assert(hiveResultString(plan1) === Seq("INTERVAL '-10-1' YEAR TO MONTH")) + val plan2 = df.selectExpr("array(i)").queryExecution.executedPlan + assert(hiveResultString(plan2) === Seq("[INTERVAL '-10-1' YEAR TO MONTH]")) + } + + test("SPARK-34984: day-time interval formatting in hive result") { + val df = Seq(Duration.ofDays(5).plusMillis(10)).toDF("i") + val plan1 = df.queryExecution.executedPlan + assert(hiveResultString(plan1) === Seq("INTERVAL '5 00:00:00.01' DAY TO SECOND")) + val plan2 = df.selectExpr("array(i)").queryExecution.executedPlan + assert(hiveResultString(plan2) === Seq("[INTERVAL '5 00:00:00.01' DAY TO SECOND]")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 1724f785c2ff9..0b30b8cdf2644 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -1264,4 +1264,6 @@ private case class DummySparkPlan( ) extends SparkPlan { override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException override def output: Seq[Attribute] = Seq.empty + override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = + copy(children = newChildren) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala index a31e2382940e6..1592949fe9a9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala @@ -58,4 +58,7 @@ case class ReferenceSort( override def outputOrdering: Seq[SortOrder] = sortOrder override def outputPartitioning: Partitioning = child.outputPartitioning + + override protected def withNewChildInternal(newChild: SparkPlan): ReferenceSort = + copy(child = newChild) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 4fd7bad4e376e..e1e37f7652bbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable} import 
org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, StringLiteral} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable} import org.apache.spark.sql.catalyst.rules.Rule @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.sources.SimpleScanSource -import org.apache.spark.sql.types.{CharType, DoubleType, IntegerType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{BooleanType, CharType, DoubleType, IntegerType, LongType, StringType, StructField, StructType} class PlanResolutionSuite extends AnalysisTest { import CatalystSqlParser._ @@ -1229,6 +1229,7 @@ class PlanResolutionSuite extends AnalysisTest { mergeCondition match { case EqualTo(l: AttributeReference, r: AttributeReference) => assert(l.sameRef(ti) && r.sameRef(si)) + case Literal(_, BooleanType) => // this is acceptable as a merge condition case other => fail("unexpected merge condition " + other) } @@ -1309,6 +1310,28 @@ class PlanResolutionSuite extends AnalysisTest { case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString) } + // merge with star should get resolved into specific actions even if there + // is no other unresolved expression in the merge + parseAndResolve(s""" + |MERGE INTO $target AS target + |USING $source AS source + |ON true + |WHEN MATCHED THEN UPDATE SET * + |WHEN NOT MATCHED THEN INSERT * + """.stripMargin) match { + case MergeIntoTable( + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), + SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(source)), + mergeCondition, + Seq(UpdateAction(None, updateAssigns)), + Seq(InsertAction(None, insertAssigns))) => + + checkResolution(target, source, mergeCondition, None, None, None, + updateAssigns, insertAssigns, starInUpdate = true) + + case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString) + } + // no additional conditions val sql3 = s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index b17c93503804c..b3d29df1b29bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoStatement, LogicalPlan, Project} import 
org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.execution.command.LeafRunnableCommand import org.apache.spark.sql.execution.datasources.{CreateTable, InsertIntoHadoopFsRelationCommand} import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.test.SharedSparkSession @@ -302,7 +302,7 @@ class DataFrameCallbackSuite extends QueryTest } /** A test command that throws `java.lang.Error` during execution. */ -case class ErrorTestCommand(foo: String) extends RunnableCommand { +case class ErrorTestCommand(foo: String) extends LeafRunnableCommand { override val output: Seq[Attribute] = Seq(AttributeReference("foo", StringType)()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 283c254b39602..fe5d74f889dbb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -130,6 +130,9 @@ case class CreateHiveTableAsSelectCommand( override def writingCommandClassName: String = Utils.getSimpleName(classOf[InsertIntoHiveTable]) + + override protected def withNewChildInternal( + newChild: LogicalPlan): CreateHiveTableAsSelectCommand = copy(query = newChild) } /** @@ -177,4 +180,7 @@ case class OptimizedCreateHiveTableAsSelectCommand( override def writingCommandClassName: String = Utils.getSimpleName(classOf[InsertIntoHadoopFsRelationCommand]) + + override protected def withNewChildInternal( + newChild: LogicalPlan): OptimizedCreateHiveTableAsSelectCommand = copy(query = newChild) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala index 2059f5bff9cbb..27fdb22391226 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -184,6 +184,9 @@ private[hive] case class HiveScriptTransformationExec( outputIterator } + + override protected def withNewChildInternal(newChild: SparkPlan): HiveScriptTransformationExec = + copy(child = newChild) } private[hive] case class HiveScriptTransformationWriterThread( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index 7ef637ed553ad..09aa1e8eea1f8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -137,5 +137,8 @@ case class InsertIntoHiveDirCommand( Seq.empty[Row] } + + override protected def withNewChildInternal( + newChild: LogicalPlan): InsertIntoHiveDirCommand = copy(query = newChild) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index bfb24cfedb55a..fcd11e67587cf 100644 
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -343,4 +343,7 @@ case class InsertIntoHiveTable( isSrcLocal = false) } } + + override protected def withNewChildInternal(newChild: LogicalPlan): InsertIntoHiveTable = + copy(query = newChild) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 7717e6ee207d9..7c3d1617bfaeb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -110,6 +110,9 @@ private[hive] case class HiveSimpleUDF( override def prettyName: String = name override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) } // Adapter from Catalyst ExpressionResult to Hive DeferredObject @@ -186,6 +189,9 @@ private[hive] case class HiveGenericUDF( override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) } /** @@ -279,6 +285,9 @@ private[hive] case class HiveGenericUDTF( } override def prettyName: String = name + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) } /** @@ -528,6 +537,9 @@ private[hive] case class HiveUDAFFunction( buffer } } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) } case class HiveUDAFBuffer(buf: AggregationBuffer, canDoMerge: Boolean) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala index 0ef7b3383e086..ee233fbd7238f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala @@ -78,6 +78,9 @@ case class TestingTypedCount( copy(inputAggBufferOffset = newInputAggBufferOffset) override val prettyName: String = "typed_count" + + override protected def withNewChildInternal(newChild: Expression): TestingTypedCount = + copy(child = newChild) } object TestingTypedCount {
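The hunks above all apply one mechanical pattern: each unary operator gains an override of the form `override protected def withNewChildInternal(newChild: ...) = copy(child = newChild)`, and each binary operator gains the analogous `withNewChildrenInternal(newLeft, newRight) = copy(left = newLeft, right = newRight)`, giving the tree-transformation framework a direct, per-class way to rebuild a node with replaced children. The following is a minimal, self-contained Scala sketch of that shape; the `Plan`/`UnaryPlan`/`BinaryPlan` hierarchy and the node names are illustrative stand-ins, not Spark's actual TreeNode/SparkPlan API.

// Minimal, self-contained model of the withNewChildInternal / withNewChildrenInternal
// pattern used throughout this patch. The hierarchy below is an illustrative stand-in,
// not Spark's TreeNode or SparkPlan API.
sealed trait Plan {
  def children: Seq[Plan]
}

trait UnaryPlan extends Plan {
  def child: Plan
  final override def children: Seq[Plan] = Seq(child)
  // Each concrete node rebuilds itself with a replaced child, typically one line of copy(...).
  protected def withNewChildInternal(newChild: Plan): UnaryPlan
  final def withNewChild(newChild: Plan): UnaryPlan =
    if (newChild eq child) this else withNewChildInternal(newChild)
}

trait BinaryPlan extends Plan {
  def left: Plan
  def right: Plan
  final override def children: Seq[Plan] = Seq(left, right)
  protected def withNewChildrenInternal(newLeft: Plan, newRight: Plan): BinaryPlan
  final def withNewChildren(newLeft: Plan, newRight: Plan): BinaryPlan =
    if ((newLeft eq left) && (newRight eq right)) this
    else withNewChildrenInternal(newLeft, newRight)
}

case class LeafNode(name: String) extends Plan {
  override def children: Seq[Plan] = Nil
}

// Mirrors the one-line overrides in the patch, e.g.
//   override protected def withNewChildInternal(newChild: SparkPlan): WindowExec =
//     copy(child = newChild)
case class FilterNode(condition: String, child: Plan) extends UnaryPlan {
  override protected def withNewChildInternal(newChild: Plan): FilterNode =
    copy(child = newChild)
}

case class JoinNode(left: Plan, right: Plan) extends BinaryPlan {
  override protected def withNewChildrenInternal(newLeft: Plan, newRight: Plan): JoinNode =
    copy(left = newLeft, right = newRight)
}

object WithNewChildDemo extends App {
  val plan = JoinNode(FilterNode("a > 0", LeafNode("t1")), LeafNode("t2"))
  // Replace only the right child; the untouched left subtree is reused by reference.
  val rewritten = plan.withNewChildren(plan.left, LeafNode("t3"))
  println(rewritten)  // JoinNode(FilterNode(a > 0,LeafNode(t1)),LeafNode(t3))
}

In the sketch, the split between the public method and the `*Internal` hook lets the caller-facing method short-circuit when no child actually changed, while each concrete node only has to supply the one-line `copy(...)`, which is exactly the shape of the overrides added across the files above.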