From 8d6bee736884575da7368e0963268d1cbe362d90 Mon Sep 17 00:00:00 2001
From: Shardul Mahadik
Date: Sun, 19 May 2024 18:17:36 -0700
Subject: [PATCH] Spark: Coerce shorts and bytes into ints in Parquet Writer (#10349)

---
 .../spark/data/SparkParquetWriters.java       | 13 ++-
 .../source/TestDataFrameWriterV2Coercion.java | 69 ++++++++++++++++
 .../spark/data/SparkParquetWriters.java       | 13 ++-
 .../source/TestDataFrameWriterV2Coercion.java | 69 ++++++++++++++++
 .../spark/data/SparkParquetWriters.java       | 13 ++-
 .../source/TestDataFrameWriterV2Coercion.java | 81 +++++++++++++++++++
 6 files changed, 255 insertions(+), 3 deletions(-)
 create mode 100644 spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java
 create mode 100644 spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java
 create mode 100644 spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java

diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
index 8baea6c5ab59..1a4f7052de39 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
@@ -49,9 +49,11 @@
 import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.catalyst.util.MapData;
 import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.ByteType;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.sql.types.MapType;
+import org.apache.spark.sql.types.ShortType;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -267,7 +269,7 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
         case BOOLEAN:
           return ParquetValueWriters.booleans(desc);
         case INT32:
-          return ParquetValueWriters.ints(desc);
+          return ints(sType, desc);
         case INT64:
           return ParquetValueWriters.longs(desc);
         case FLOAT:
@@ -280,6 +282,15 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
     }
   }
 
+  private static PrimitiveWriter<?> ints(DataType type, ColumnDescriptor desc) {
+    if (type instanceof ByteType) {
+      return ParquetValueWriters.tinyints(desc);
+    } else if (type instanceof ShortType) {
+      return ParquetValueWriters.shorts(desc);
+    }
+    return ParquetValueWriters.ints(desc);
+  }
+
   private static PrimitiveWriter<UTF8String> utf8Strings(ColumnDescriptor desc) {
     return new UTF8StringWriter(desc);
   }
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java
new file mode 100644
index 000000000000..efb6352ce8ba
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.SparkTestBaseWithCatalog;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class TestDataFrameWriterV2Coercion extends SparkTestBaseWithCatalog {
+
+  private final FileFormat format;
+  private final String dataType;
+
+  public TestDataFrameWriterV2Coercion(FileFormat format, String dataType) {
+    this.format = format;
+    this.dataType = dataType;
+  }
+
+  @Parameterized.Parameters(name = "format = {0}, dataType = {1}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      new Object[] {FileFormat.AVRO, "byte"},
+      new Object[] {FileFormat.ORC, "byte"},
+      new Object[] {FileFormat.PARQUET, "byte"},
+      new Object[] {FileFormat.AVRO, "short"},
+      new Object[] {FileFormat.ORC, "short"},
+      new Object[] {FileFormat.PARQUET, "short"}
+    };
+  }
+
+  @Test
+  public void testByteAndShortCoercion() {
+
+    Dataset<Row> df =
+        jsonToDF(
+            "id " + dataType + ", data string",
+            "{ \"id\": 1, \"data\": \"a\" }",
+            "{ \"id\": 2, \"data\": \"b\" }");
+
+    df.writeTo(tableName).option("write-format", format.name()).createOrReplace();
+
+    assertEquals(
+        "Should have initial 2-column rows",
+        ImmutableList.of(row(1, "a"), row(2, "b")),
+        sql("select * from %s order by id", tableName));
+  }
+}
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
index 8baea6c5ab59..1a4f7052de39 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
@@ -49,9 +49,11 @@
 import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.catalyst.util.MapData;
 import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.ByteType;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.sql.types.MapType;
+import org.apache.spark.sql.types.ShortType;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -267,7 +269,7 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
         case BOOLEAN:
           return ParquetValueWriters.booleans(desc);
         case INT32:
-          return ParquetValueWriters.ints(desc);
+          return ints(sType, desc);
         case INT64:
           return ParquetValueWriters.longs(desc);
         case FLOAT:
@@ -280,6 +282,15 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
     }
   }
 
+  private static PrimitiveWriter<?> ints(DataType type, ColumnDescriptor desc) {
+    if (type instanceof ByteType) {
+      return ParquetValueWriters.tinyints(desc);
+    } else if (type instanceof ShortType) {
+      return ParquetValueWriters.shorts(desc);
+    }
+    return ParquetValueWriters.ints(desc);
+  }
+
   private static PrimitiveWriter<UTF8String> utf8Strings(ColumnDescriptor desc) {
     return new UTF8StringWriter(desc);
   }
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java
new file mode 100644
index 000000000000..efb6352ce8ba
--- /dev/null
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.SparkTestBaseWithCatalog;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class TestDataFrameWriterV2Coercion extends SparkTestBaseWithCatalog {
+
+  private final FileFormat format;
+  private final String dataType;
+
+  public TestDataFrameWriterV2Coercion(FileFormat format, String dataType) {
+    this.format = format;
+    this.dataType = dataType;
+  }
+
+  @Parameterized.Parameters(name = "format = {0}, dataType = {1}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      new Object[] {FileFormat.AVRO, "byte"},
+      new Object[] {FileFormat.ORC, "byte"},
+      new Object[] {FileFormat.PARQUET, "byte"},
+      new Object[] {FileFormat.AVRO, "short"},
+      new Object[] {FileFormat.ORC, "short"},
+      new Object[] {FileFormat.PARQUET, "short"}
+    };
+  }
+
+  @Test
+  public void testByteAndShortCoercion() {
+
+    Dataset<Row> df =
+        jsonToDF(
+            "id " + dataType + ", data string",
+            "{ \"id\": 1, \"data\": \"a\" }",
+            "{ \"id\": 2, \"data\": \"b\" }");
+
+    df.writeTo(tableName).option("write-format", format.name()).createOrReplace();
+
+    assertEquals(
+        "Should have initial 2-column rows",
+        ImmutableList.of(row(1, "a"), row(2, "b")),
+        sql("select * from %s order by id", tableName));
+  }
+}
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
index b2e95df59eb8..209c06bacb3e 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
@@ -49,9 +49,11 @@
 import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.catalyst.util.MapData;
 import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.ByteType;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.sql.types.MapType;
+import org.apache.spark.sql.types.ShortType;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -266,7 +268,7 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
         case BOOLEAN:
           return ParquetValueWriters.booleans(desc);
         case INT32:
-          return ParquetValueWriters.ints(desc);
+          return ints(sType, desc);
         case INT64:
           return ParquetValueWriters.longs(desc);
         case FLOAT:
@@ -279,6 +281,15 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
     }
   }
 
+  private static PrimitiveWriter<?> ints(DataType type, ColumnDescriptor desc) {
+    if (type instanceof ByteType) {
+      return ParquetValueWriters.tinyints(desc);
+    } else if (type instanceof ShortType) {
+      return ParquetValueWriters.shorts(desc);
+    }
+    return ParquetValueWriters.ints(desc);
+  }
+
   private static PrimitiveWriter<UTF8String> utf8Strings(ColumnDescriptor desc) {
     return new UTF8StringWriter(desc);
   }
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java
new file mode 100644
index 000000000000..f51a06853a69
--- /dev/null
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataFrameWriterV2Coercion.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.Parameter;
+import org.apache.iceberg.ParameterizedTestExtension;
+import org.apache.iceberg.Parameters;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.spark.TestBaseWithCatalog;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+@ExtendWith(ParameterizedTestExtension.class)
+public class TestDataFrameWriterV2Coercion extends TestBaseWithCatalog {
+
+  @Parameters(
+      name = "catalogName = {0}, implementation = {1}, config = {2}, format = {3}, dataType = {4}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      parameter(FileFormat.AVRO, "byte"),
+      parameter(FileFormat.ORC, "byte"),
+      parameter(FileFormat.PARQUET, "byte"),
+      parameter(FileFormat.AVRO, "short"),
+      parameter(FileFormat.ORC, "short"),
+      parameter(FileFormat.PARQUET, "short")
+    };
+  }
+
+  private static Object[] parameter(FileFormat fileFormat, String dataType) {
+    return new Object[] {
+      SparkCatalogConfig.HADOOP.catalogName(),
+      SparkCatalogConfig.HADOOP.implementation(),
+      SparkCatalogConfig.HADOOP.properties(),
+      fileFormat,
+      dataType
+    };
+  }
+
+  @Parameter(index = 3)
+  private FileFormat format;
+
+  @Parameter(index = 4)
+  private String dataType;
+
+  @TestTemplate
+  public void testByteAndShortCoercion() {
+
+    Dataset<Row> df =
+        jsonToDF(
+            "id " + dataType + ", data string",
+            "{ \"id\": 1, \"data\": \"a\" }",
+            "{ \"id\": 2, \"data\": \"b\" }");
+
+    df.writeTo(tableName).option("write-format", format.name()).createOrReplace();
+
+    assertEquals(
+        "Should have initial 2-column rows",
+        ImmutableList.of(row(1, "a"), row(2, "b")),
+        sql("select * from %s order by id", tableName));
+  }
+}
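
Note (illustrative, not part of the patch): the change routes Spark ByteType and ShortType columns
through ParquetValueWriters.tinyints()/shorts() instead of the plain int writer; the new test writes
byte and short columns through Avro, ORC, and Parquet and reads them back as ints. Below is a minimal,
hypothetical sketch of the user-facing write path the test exercises. The SparkSession setup and the
table name local.db.coercion_test are assumptions made for the example, not taken from the patch.

import java.util.Arrays;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ByteShortWriteExample {
  public static void main(String[] args) throws Exception {
    // Assumes a SparkSession already configured with an Iceberg catalog named "local".
    SparkSession spark = SparkSession.builder().appName("byte-short-coercion").getOrCreate();

    // "id" is declared as a short column; before this patch the Parquet write path handed it
    // to the plain int writer (the line removed in SparkParquetWriters above).
    Dataset<Row> df =
        spark
            .read()
            .schema("id short, data string")
            .json(
                spark.createDataset(
                    Arrays.asList(
                        "{ \"id\": 1, \"data\": \"a\" }", "{ \"id\": 2, \"data\": \"b\" }"),
                    Encoders.STRING()));

    // DataFrameWriterV2 path, matching the test: write-format selects Parquet explicitly.
    df.writeTo("local.db.coercion_test").option("write-format", "parquet").createOrReplace();

    spark.stop();
  }
}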