Refactor common Delta Lake test code
Signed-off-by: Jason Lowe <[email protected]>
jlowe committed Oct 18, 2023
1 parent 6334ece commit 965d69a
Showing 7 changed files with 205 additions and 199 deletions.
@@ -15,7 +15,7 @@
import pytest
from asserts import assert_gpu_and_cpu_writes_are_equal_collect, with_cpu_session, with_gpu_session
from data_gen import copy_and_update
from delta_lake_write_test import delta_meta_allow
from delta_lake_utils import delta_meta_allow
from marks import allow_non_gpu, delta_lake
from pyspark.sql.functions import *
from spark_session import is_databricks104_or_later
19 changes: 9 additions & 10 deletions integration_tests/src/main/python/delta_lake_delete_test.py
@@ -16,8 +16,7 @@

from asserts import assert_equal, assert_gpu_and_cpu_writes_are_equal_collect, assert_gpu_fallback_write
from data_gen import *
from delta_lake_write_test import assert_gpu_and_cpu_delta_logs_equivalent, delta_meta_allow, delta_writes_enabled_conf
from delta_lake_merge_test import read_delta_path, read_delta_path_with_cdf, setup_dest_tables
from delta_lake_utils import *
from marks import *
from spark_session import is_before_spark_320, is_databricks_runtime, supports_delta_lake_deletion_vectors, \
with_cpu_session, with_gpu_session
@@ -30,7 +29,7 @@ def delta_sql_delete_test(spark_tmp_path, use_cdf, dest_table_func, delete_sql,
check_func, partition_columns=None):
data_path = spark_tmp_path + "/DELTA_DATA"
def setup_tables(spark):
setup_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns)
setup_delta_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns)
def do_delete(spark, path):
return spark.sql(delete_sql.format(path=path))
with_cpu_session(setup_tables)
@@ -74,9 +73,9 @@ def checker(data_path, do_delete):
def test_delta_delete_disabled_fallback(spark_tmp_path, disable_conf):
data_path = spark_tmp_path + "/DELTA_DATA"
def setup_tables(spark):
setup_dest_tables(spark, data_path,
dest_table_func=lambda spark: unary_op_df(spark, int_gen),
use_cdf=False)
setup_delta_dest_tables(spark, data_path,
dest_table_func=lambda spark: unary_op_df(spark, int_gen),
use_cdf=False)
def write_func(spark, path):
delete_sql="DELETE FROM delta.`{}`".format(path)
spark.sql(delete_sql)
@@ -93,9 +92,9 @@ def write_func(spark, path):
def test_delta_deletion_vector_fallback(spark_tmp_path, use_cdf):
data_path = spark_tmp_path + "/DELTA_DATA"
def setup_tables(spark):
setup_dest_tables(spark, data_path,
dest_table_func=lambda spark: unary_op_df(spark, int_gen),
use_cdf=use_cdf, enable_deletion_vectors=True)
setup_delta_dest_tables(spark, data_path,
dest_table_func=lambda spark: unary_op_df(spark, int_gen),
use_cdf=use_cdf, enable_deletion_vectors=True)
def write_func(spark, path):
delete_sql="DELETE FROM delta.`{}`".format(path)
spark.sql(delete_sql)
@@ -182,7 +181,7 @@ def generate_dest_data(spark):
SetValuesGen(IntegerType(), range(5)),
SetValuesGen(StringType(), "abcdefg"),
string_gen, num_slices=num_slices_to_test)
with_cpu_session(lambda spark: setup_dest_tables(spark, data_path, generate_dest_data, use_cdf, partition_columns))
with_cpu_session(lambda spark: setup_delta_dest_tables(spark, data_path, generate_dest_data, use_cdf, partition_columns))
def do_delete(spark, path):
dest_table = DeltaTable.forPath(spark, path)
dest_table.delete("b > 'c'")
50 changes: 4 additions & 46 deletions integration_tests/src/main/python/delta_lake_merge_test.py
@@ -18,10 +18,10 @@

from asserts import *
from data_gen import *
from delta_lake_utils import *
from marks import *
from delta_lake_write_test import assert_gpu_and_cpu_delta_logs_equivalent, delta_meta_allow, delta_writes_enabled_conf
from pyspark.sql.types import *
from spark_session import is_before_spark_320, is_databricks_runtime, is_databricks122_or_later, spark_version
from spark_session import is_before_spark_320, is_databricks_runtime, spark_version

# Databricks changes the number of files being written, so we cannot compare logs
num_slices_to_test = [10] if is_databricks_runtime() else [1, 10]
@@ -30,59 +30,17 @@
{"spark.rapids.sql.command.MergeIntoCommand": "true",
"spark.rapids.sql.command.MergeIntoCommandEdge": "true"})

delta_write_fallback_allow = "ExecutedCommandExec,DataWritingCommandExec" if is_databricks122_or_later() else "ExecutedCommandExec"
delta_write_fallback_check = "DataWritingCommandExec" if is_databricks122_or_later() else "ExecutedCommandExec"

def read_delta_path(spark, path):
return spark.read.format("delta").load(path)

def read_delta_path_with_cdf(spark, path):
return spark.read.format("delta") \
.option("readChangeDataFeed", "true").option("startingVersion", 0) \
.load(path).drop("_commit_timestamp")

def schema_to_ddl(spark, schema):
return spark.sparkContext._jvm.org.apache.spark.sql.types.DataType.fromJson(schema.json()).toDDL()

def make_df(spark, gen, num_slices):
return three_col_df(spark, gen, SetValuesGen(StringType(), string.ascii_lowercase),
SetValuesGen(StringType(), string.ascii_uppercase), num_slices=num_slices)

def setup_dest_table(spark, path, dest_table_func, use_cdf, partition_columns=None, enable_deletion_vectors=False):
dest_df = dest_table_func(spark)
writer = dest_df.write.format("delta")
ddl = schema_to_ddl(spark, dest_df.schema)
table_properties = {}
if use_cdf:
table_properties['delta.enableChangeDataFeed'] = 'true'
if enable_deletion_vectors:
table_properties['delta.enableDeletionVectors'] = 'true'
if len(table_properties) > 0:
# if any table properties are specified then we need to use SQL to define the table
sql_text = "CREATE TABLE delta.`{path}` ({ddl}) USING DELTA".format(path=path, ddl=ddl)
if partition_columns:
sql_text += " PARTITIONED BY ({})".format(",".join(partition_columns))
properties = ', '.join(key + ' = ' + value for key, value in table_properties.items())
sql_text += " TBLPROPERTIES ({})".format(properties)
spark.sql(sql_text)
elif partition_columns:
writer = writer.partitionBy(*partition_columns)
if use_cdf or enable_deletion_vectors:
writer = writer.mode("append")
writer.save(path)

def setup_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns=None, enable_deletion_vectors=False):
for name in ["CPU", "GPU"]:
path = "{}/{}".format(data_path, name)
setup_dest_table(spark, path, dest_table_func, use_cdf, partition_columns, enable_deletion_vectors)

def delta_sql_merge_test(spark_tmp_path, spark_tmp_table_factory, use_cdf,
src_table_func, dest_table_func, merge_sql, check_func,
partition_columns=None):
data_path = spark_tmp_path + "/DELTA_DATA"
src_table = spark_tmp_table_factory.get()
def setup_tables(spark):
setup_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns)
setup_delta_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns)
src_table_func(spark).createOrReplaceTempView(src_table)
def do_merge(spark, path):
dest_table = spark_tmp_table_factory.get()
@@ -327,7 +285,7 @@ def test_delta_merge_dataframe_api(spark_tmp_path, use_cdf, num_slices):
from delta.tables import DeltaTable
data_path = spark_tmp_path + "/DELTA_DATA"
dest_table_func = lambda spark: two_col_df(spark, SetValuesGen(IntegerType(), [None] + list(range(100))), string_gen, seed=1, num_slices=num_slices)
with_cpu_session(lambda spark: setup_dest_tables(spark, data_path, dest_table_func, use_cdf))
with_cpu_session(lambda spark: setup_delta_dest_tables(spark, data_path, dest_table_func, use_cdf))
def do_merge(spark, path):
# Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous
src_df = two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b"))
9 changes: 4 additions & 5 deletions integration_tests/src/main/python/delta_lake_test.py
@@ -16,8 +16,7 @@
from pyspark.sql import Row
from asserts import assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_collect
from data_gen import *
from delta_lake_merge_test import setup_dest_table
from delta_lake_write_test import delta_meta_allow
from delta_lake_utils import delta_meta_allow, setup_delta_dest_table
from marks import allow_non_gpu, delta_lake, ignore_order
from parquet_test import reader_opt_confs_no_native
from spark_session import with_cpu_session, with_gpu_session, is_databricks_runtime, \
@@ -79,9 +78,9 @@ def test_delta_deletion_vector_read_fallback(spark_tmp_path, use_cdf):
data_path = spark_tmp_path + "/DELTA_DATA"
conf = {"spark.databricks.delta.delete.deletionVectors.persistent": "true"}
def setup_tables(spark):
setup_dest_table(spark, data_path,
dest_table_func=lambda spark: unary_op_df(spark, int_gen),
use_cdf=use_cdf, enable_deletion_vectors=True)
setup_delta_dest_table(spark, data_path,
dest_table_func=lambda spark: unary_op_df(spark, int_gen),
use_cdf=use_cdf, enable_deletion_vectors=True)
spark.sql("INSERT INTO delta.`{}` VALUES(1)".format(data_path))
spark.sql("DELETE FROM delta.`{}` WHERE a = 1".format(data_path))
with_cpu_session(setup_tables, conf=conf)
18 changes: 7 additions & 11 deletions integration_tests/src/main/python/delta_lake_update_test.py
@@ -16,24 +16,20 @@

from asserts import assert_equal, assert_gpu_and_cpu_writes_are_equal_collect, assert_gpu_fallback_write
from data_gen import *
from delta_lake_write_test import assert_gpu_and_cpu_delta_logs_equivalent, delta_meta_allow, delta_writes_enabled_conf
from delta_lake_merge_test import read_delta_path, read_delta_path_with_cdf, setup_dest_tables
from delta_lake_utils import *
from marks import *
from spark_session import is_before_spark_320, is_databricks_runtime, is_databricks122_or_later, \
from spark_session import is_before_spark_320, is_databricks_runtime, \
supports_delta_lake_deletion_vectors, with_cpu_session, with_gpu_session

delta_update_enabled_conf = copy_and_update(delta_writes_enabled_conf,
{"spark.rapids.sql.command.UpdateCommand": "true",
"spark.rapids.sql.command.UpdateCommandEdge": "true"})

delta_write_fallback_allow = "ExecutedCommandExec,DataWritingCommandExec" if is_databricks122_or_later() else "ExecutedCommandExec"
delta_write_fallback_check = "DataWritingCommandExec" if is_databricks122_or_later() else "ExecutedCommandExec"

def delta_sql_update_test(spark_tmp_path, use_cdf, dest_table_func, update_sql,
check_func, partition_columns=None, enable_deletion_vectors=False):
data_path = spark_tmp_path + "/DELTA_DATA"
def setup_tables(spark):
setup_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns, enable_deletion_vectors)
setup_delta_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns, enable_deletion_vectors)
def do_update(spark, path):
return spark.sql(update_sql.format(path=path))
with_cpu_session(setup_tables)
@@ -78,9 +74,9 @@ def checker(data_path, do_update):
def test_delta_update_disabled_fallback(spark_tmp_path, disable_conf):
data_path = spark_tmp_path + "/DELTA_DATA"
def setup_tables(spark):
setup_dest_tables(spark, data_path,
dest_table_func=lambda spark: unary_op_df(spark, int_gen),
use_cdf=False)
setup_delta_dest_tables(spark, data_path,
dest_table_func=lambda spark: unary_op_df(spark, int_gen),
use_cdf=False)
def write_func(spark, path):
update_sql="UPDATE delta.`{}` SET a = 0".format(path)
spark.sql(update_sql)
@@ -173,7 +169,7 @@ def generate_dest_data(spark):
SetValuesGen(IntegerType(), range(5)),
SetValuesGen(StringType(), "abcdefg"),
string_gen, num_slices=num_slices_to_test)
with_cpu_session(lambda spark: setup_dest_tables(spark, data_path, generate_dest_data, use_cdf, partition_columns))
with_cpu_session(lambda spark: setup_delta_dest_tables(spark, data_path, generate_dest_data, use_cdf, partition_columns))
def do_update(spark, path):
dest_table = DeltaTable.forPath(spark, path)
dest_table.update(condition="b > 'c'", set={"c": f.col("b"), "a": f.lit(1)})
(Diffs for the remaining changed files were not loaded in this view.)
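One of the diffs that did not load is the new shared module, integration_tests/src/main/python/delta_lake_utils.py, which the updated imports above reference. Below is a hypothetical sketch of what that module plausibly contains, assembled from the helpers removed from delta_lake_merge_test.py in this commit and the renamed call sites (setup_delta_dest_table, setup_delta_dest_tables). It is an illustrative reconstruction, not the actual file contents; in particular, the values of delta_meta_allow and delta_writes_enabled_conf are not visible in the loaded diffs and are left as placeholders.

# Hypothetical sketch of delta_lake_utils.py; the helper bodies below are copied from
# the code removed from delta_lake_merge_test.py above, renamed to match the new call sites.
import string

from data_gen import SetValuesGen, three_col_df
from pyspark.sql.types import StringType
from spark_session import is_databricks122_or_later

delta_meta_allow = "..."        # placeholder: actual allow-list not shown in the loaded diffs
delta_writes_enabled_conf = {}  # placeholder: actual conf dict not shown in the loaded diffs

delta_write_fallback_allow = "ExecutedCommandExec,DataWritingCommandExec" if is_databricks122_or_later() else "ExecutedCommandExec"
delta_write_fallback_check = "DataWritingCommandExec" if is_databricks122_or_later() else "ExecutedCommandExec"

def read_delta_path(spark, path):
    return spark.read.format("delta").load(path)

def read_delta_path_with_cdf(spark, path):
    return spark.read.format("delta") \
        .option("readChangeDataFeed", "true").option("startingVersion", 0) \
        .load(path).drop("_commit_timestamp")

def schema_to_ddl(spark, schema):
    return spark.sparkContext._jvm.org.apache.spark.sql.types.DataType.fromJson(schema.json()).toDDL()

def make_df(spark, gen, num_slices):
    return three_col_df(spark, gen, SetValuesGen(StringType(), string.ascii_lowercase),
                        SetValuesGen(StringType(), string.ascii_uppercase), num_slices=num_slices)

def setup_delta_dest_table(spark, path, dest_table_func, use_cdf, partition_columns=None,
                           enable_deletion_vectors=False):
    dest_df = dest_table_func(spark)
    writer = dest_df.write.format("delta")
    ddl = schema_to_ddl(spark, dest_df.schema)
    table_properties = {}
    if use_cdf:
        table_properties['delta.enableChangeDataFeed'] = 'true'
    if enable_deletion_vectors:
        table_properties['delta.enableDeletionVectors'] = 'true'
    if len(table_properties) > 0:
        # if any table properties are specified then we need to use SQL to define the table
        sql_text = "CREATE TABLE delta.`{path}` ({ddl}) USING DELTA".format(path=path, ddl=ddl)
        if partition_columns:
            sql_text += " PARTITIONED BY ({})".format(",".join(partition_columns))
        properties = ', '.join(key + ' = ' + value for key, value in table_properties.items())
        sql_text += " TBLPROPERTIES ({})".format(properties)
        spark.sql(sql_text)
    elif partition_columns:
        writer = writer.partitionBy(*partition_columns)
    if use_cdf or enable_deletion_vectors:
        writer = writer.mode("append")
    writer.save(path)

def setup_delta_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns=None,
                            enable_deletion_vectors=False):
    # Creates matching CPU and GPU copies of the destination table under data_path.
    for name in ["CPU", "GPU"]:
        setup_delta_dest_table(spark, "{}/{}".format(data_path, name), dest_table_func, use_cdf,
                               partition_columns, enable_deletion_vectors)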
