diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index 383a24018af..cc11cedfaa7 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -329,19 +329,39 @@ def start(self, rand):
         self._start(rand, lambda: self.next_val())
 
 class RepeatSeqGen(DataGen):
-    """Generate Repeated seq of `length` random items"""
-    def __init__(self, child, length):
-        super().__init__(child.data_type, nullable=False)
-        self.nullable = child.nullable
-        self._child = child
+    """Generate a repeated seq of `length` random items if child is a DataGen,
+    otherwise repeat the provided seq when child is a list.
+
+    When child is a list:
+        data_type must be specified
+        length must be <= length of child
+    When child is a DataGen:
+        length must be specified
+        data_type must be None or match child's
+    """
+    def __init__(self, child, length=None, data_type=None):
+        if isinstance(child, list):
+            super().__init__(data_type, nullable=False)
+            self.nullable = None in child
+            assert (length is None or length <= len(child))
+            self._length = length if length is not None else len(child)
+            self._child = child[:length] if length is not None else child
+        else:
+            super().__init__(child.data_type, nullable=False)
+            self.nullable = child.nullable
+            assert (data_type is None or data_type == child.data_type)
+            assert (length is not None)
+            self._length = length
+            self._child = child
         self._vals = []
-        self._length = length
         self._index = 0
 
     def __repr__(self):
         return super().__repr__() + '(' + str(self._child) + ')'
 
     def _cache_repr(self):
+        if isinstance(self._child, list):
+            return super()._cache_repr() + '(' + str(self._child) + ',' + str(self._length) + ')'
         return super()._cache_repr() + '(' + self._child._cache_repr() + ',' + str(self._length) + ')'
 
     def _loop_values(self):
@@ -351,9 +371,12 @@ def _loop_values(self):
 
     def start(self, rand):
         self._index = 0
-        self._child.start(rand)
         self._start(rand, self._loop_values)
-        self._vals = [self._child.gen() for _ in range(0, self._length)]
+        if isinstance(self._child, list):
+            self._vals = self._child
+        else:
+            self._child.start(rand)
+            self._vals = [self._child.gen() for _ in range(0, self._length)]
 
 class SetValuesGen(DataGen):
     """A set of values that are randomly selected"""
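For review context, a short sketch of the two construction modes the reworked `__init__` accepts. This is hypothetical usage, not part of the patch; it assumes it runs from `integration_tests/src/main/python` so that `data_gen` is importable (which requires pyspark on the path):

```python
# Hypothetical usage of the reworked RepeatSeqGen (not part of the patch).
from pyspark.sql.types import IntegerType
from data_gen import RepeatSeqGen, IntegerGen

# DataGen child: `length` is required; the data type is taken from the child.
random_repeat = RepeatSeqGen(IntegerGen(nullable=False), length=3)

# list child: `data_type` is required; the list is repeated verbatim, and
# nullability is inferred from whether the list contains None.
fixed_repeat = RepeatSeqGen([None, 1, 2], data_type=IntegerType())
```

Note that the list mode skips random generation entirely in `start`, so the generated column is fully deterministic across runs and across tables of the same length.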
diff --git a/integration_tests/src/main/python/dpp_test.py b/integration_tests/src/main/python/dpp_test.py
index f56bb603ac4..4e967262c14 100644
--- a/integration_tests/src/main/python/dpp_test.py
+++ b/integration_tests/src/main/python/dpp_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, NVIDIA CORPORATION.
+# Copyright (c) 2021-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,12 +14,17 @@
 
 import pytest
 
+from pyspark.sql.types import IntegerType
+
 from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture, assert_gpu_and_cpu_are_equal_collect
 from conftest import spark_tmp_table_factory
 from data_gen import *
 from marks import ignore_order, allow_non_gpu
 from spark_session import is_before_spark_320, with_cpu_session, is_before_spark_312, is_databricks_runtime, is_databricks113_or_later
 
+# non-positive values can produce a degenerate join, so we repeat a fixed sequence that is
+# guaranteed to contain positive values, ensuring the join will produce rows.
+# See https://github.com/NVIDIA/spark-rapids/issues/10147
+value_gen = RepeatSeqGen([None, INT_MIN, -1, 0, 1, INT_MAX], data_type=IntegerType())
 
 def create_dim_table(table_name, table_format, length=500):
     def fn(spark):
@@ -27,7 +32,7 @@ def fn(spark):
             ('key', IntegerGen(nullable=False, min_val=0, max_val=9, special_cases=[])),
             ('skey', IntegerGen(nullable=False, min_val=0, max_val=4, special_cases=[])),
             ('ex_key', IntegerGen(nullable=False, min_val=0, max_val=3, special_cases=[])),
-            ('value', int_gen),
+            ('value', value_gen),
             # specify nullable=False for `filter` to avoid generating invalid SQL with
             # expression `filter = None` (https://github.com/NVIDIA/spark-rapids/issues/9817)
             ('filter', RepeatSeqGen(
@@ -49,7 +54,7 @@ def fn(spark):
             ('skey', IntegerGen(nullable=False, min_val=0, max_val=4, special_cases=[])),
             # ex_key is not a partition column
             ('ex_key', IntegerGen(nullable=False, min_val=0, max_val=3, special_cases=[])),
-            ('value', int_gen)], length)
+            ('value', value_gen)], length)
         df.write.format(table_format) \
             .mode("overwrite") \
             .partitionBy('key', 'skey') \
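To illustrate why the fixed sequence removes the flakiness, here is a standalone sketch (assumed reasoning, not code from the repo; it presumes the flaky queries filter on `value`, as the degenerate-join comment suggests): the list mode cycles through the sequence, so positive values land in the `value` column at a fixed, seed-independent rate.

```python
# Standalone sketch (not from the repo): the generated column cycles through
# the sequence, so the share of positive rows is fixed, not seed-dependent.
INT_MIN, INT_MAX = -(2**31), 2**31 - 1
seq = [None, INT_MIN, -1, 0, 1, INT_MAX]   # the sequence used by value_gen

length = 500
column = [seq[i % len(seq)] for i in range(length)]  # what RepeatSeqGen yields

# Each full cycle contributes exactly two positive values (1 and INT_MAX),
# so a predicate like `value > 0` can never filter out every row.
positives = sum(1 for v in column if v is not None and v > 0)
assert positives >= (length // len(seq)) * 2
```

With the old `int_gen`, the `value` column was drawn randomly, so an unlucky seed could leave the filtered join with no matching rows; the repeated sequence makes the row counts deterministic for every seed.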