Skip to content

Commit

Permalink
Spark: Coerce shorts and bytes into ints in Parquet Writer (#10349)
Browse files Browse the repository at this point in the history
  • Loading branch information
shardulm94 authored May 20, 2024
1 parent 236f625 commit 8d6bee7
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;
Expand Down Expand Up @@ -267,7 +269,7 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
case BOOLEAN:
return ParquetValueWriters.booleans(desc);
case INT32:
return ParquetValueWriters.ints(desc);
return ints(sType, desc);
case INT64:
return ParquetValueWriters.longs(desc);
case FLOAT:
Expand All @@ -280,6 +282,15 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
}
}

/**
 * Returns the Parquet INT32 writer matching the Spark source type.
 *
 * <p>Spark byte and short columns are stored as INT32 in Parquet, so they are routed through
 * the tinyint/smallint writers; any other INT32-backed type uses the plain int writer.
 */
private static PrimitiveWriter<?> ints(DataType type, ColumnDescriptor desc) {
  if (type instanceof ShortType) {
    return ParquetValueWriters.shorts(desc);
  } else if (type instanceof ByteType) {
    return ParquetValueWriters.tinyints(desc);
  } else {
    return ParquetValueWriters.ints(desc);
  }
}

// Returns a writer that writes Spark UTF8String values to the given Parquet column.
private static PrimitiveWriter<UTF8String> utf8Strings(ColumnDescriptor desc) {
  return new UTF8StringWriter(desc);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark.source;

import org.apache.iceberg.FileFormat;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.spark.SparkTestBaseWithCatalog;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Verifies that byte and short columns written through the DataFrameWriterV2 API are coerced
 * and round-trip correctly for each supported file format.
 */
@RunWith(Parameterized.class)
public class TestDataFrameWriterV2Coercion extends SparkTestBaseWithCatalog {

  private final FileFormat format;
  private final String dataType;

  public TestDataFrameWriterV2Coercion(FileFormat format, String dataType) {
    this.format = format;
    this.dataType = dataType;
  }

  @Parameterized.Parameters(name = "format = {0}, dataType = {1}")
  public static Object[][] parameters() {
    // Every (file format, source column type) combination under test.
    return new Object[][] {
      testCase(FileFormat.AVRO, "byte"),
      testCase(FileFormat.ORC, "byte"),
      testCase(FileFormat.PARQUET, "byte"),
      testCase(FileFormat.AVRO, "short"),
      testCase(FileFormat.ORC, "short"),
      testCase(FileFormat.PARQUET, "short")
    };
  }

  // Bundles the arguments for one parameterized run.
  private static Object[] testCase(FileFormat format, String dataType) {
    return new Object[] {format, dataType};
  }

  @Test
  public void testByteAndShortCoercion() {
    String schema = "id " + dataType + ", data string";
    Dataset<Row> df =
        jsonToDF(schema, "{ \"id\": 1, \"data\": \"a\" }", "{ \"id\": 2, \"data\": \"b\" }");

    df.writeTo(tableName).option("write-format", format.name()).createOrReplace();

    assertEquals(
        "Should have initial 2-column rows",
        ImmutableList.of(row(1, "a"), row(2, "b")),
        sql("select * from %s order by id", tableName));
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;
Expand Down Expand Up @@ -267,7 +269,7 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
case BOOLEAN:
return ParquetValueWriters.booleans(desc);
case INT32:
return ParquetValueWriters.ints(desc);
return ints(sType, desc);
case INT64:
return ParquetValueWriters.longs(desc);
case FLOAT:
Expand All @@ -280,6 +282,15 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
}
}

/**
 * Returns the Parquet INT32 writer matching the Spark source type.
 *
 * <p>Spark byte and short columns are stored as INT32 in Parquet, so they are routed through
 * the tinyint/smallint writers; any other INT32-backed type uses the plain int writer.
 */
private static PrimitiveWriter<?> ints(DataType type, ColumnDescriptor desc) {
  if (type instanceof ShortType) {
    return ParquetValueWriters.shorts(desc);
  } else if (type instanceof ByteType) {
    return ParquetValueWriters.tinyints(desc);
  } else {
    return ParquetValueWriters.ints(desc);
  }
}

// Returns a writer that writes Spark UTF8String values to the given Parquet column.
private static PrimitiveWriter<UTF8String> utf8Strings(ColumnDescriptor desc) {
  return new UTF8StringWriter(desc);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark.source;

import org.apache.iceberg.FileFormat;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.spark.SparkTestBaseWithCatalog;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Verifies that byte and short columns written through the DataFrameWriterV2 API are coerced
 * and round-trip correctly for each supported file format.
 */
@RunWith(Parameterized.class)
public class TestDataFrameWriterV2Coercion extends SparkTestBaseWithCatalog {

  private final FileFormat format;
  private final String dataType;

  public TestDataFrameWriterV2Coercion(FileFormat format, String dataType) {
    this.format = format;
    this.dataType = dataType;
  }

  @Parameterized.Parameters(name = "format = {0}, dataType = {1}")
  public static Object[][] parameters() {
    // Every (file format, source column type) combination under test.
    return new Object[][] {
      testCase(FileFormat.AVRO, "byte"),
      testCase(FileFormat.ORC, "byte"),
      testCase(FileFormat.PARQUET, "byte"),
      testCase(FileFormat.AVRO, "short"),
      testCase(FileFormat.ORC, "short"),
      testCase(FileFormat.PARQUET, "short")
    };
  }

  // Bundles the arguments for one parameterized run.
  private static Object[] testCase(FileFormat format, String dataType) {
    return new Object[] {format, dataType};
  }

  @Test
  public void testByteAndShortCoercion() {
    String schema = "id " + dataType + ", data string";
    Dataset<Row> df =
        jsonToDF(schema, "{ \"id\": 1, \"data\": \"a\" }", "{ \"id\": 2, \"data\": \"b\" }");

    df.writeTo(tableName).option("write-format", format.name()).createOrReplace();

    assertEquals(
        "Should have initial 2-column rows",
        ImmutableList.of(row(1, "a"), row(2, "b")),
        sql("select * from %s order by id", tableName));
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;
Expand Down Expand Up @@ -266,7 +268,7 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
case BOOLEAN:
return ParquetValueWriters.booleans(desc);
case INT32:
return ParquetValueWriters.ints(desc);
return ints(sType, desc);
case INT64:
return ParquetValueWriters.longs(desc);
case FLOAT:
Expand All @@ -279,6 +281,15 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
}
}

/**
 * Returns the Parquet INT32 writer matching the Spark source type.
 *
 * <p>Spark byte and short columns are stored as INT32 in Parquet, so they are routed through
 * the tinyint/smallint writers; any other INT32-backed type uses the plain int writer.
 */
private static PrimitiveWriter<?> ints(DataType type, ColumnDescriptor desc) {
  if (type instanceof ShortType) {
    return ParquetValueWriters.shorts(desc);
  } else if (type instanceof ByteType) {
    return ParquetValueWriters.tinyints(desc);
  } else {
    return ParquetValueWriters.ints(desc);
  }
}

// Returns a writer that writes Spark UTF8String values to the given Parquet column.
private static PrimitiveWriter<UTF8String> utf8Strings(ColumnDescriptor desc) {
  return new UTF8StringWriter(desc);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark.source;

import org.apache.iceberg.FileFormat;
import org.apache.iceberg.Parameter;
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Parameters;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.spark.SparkCatalogConfig;
import org.apache.iceberg.spark.TestBaseWithCatalog;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;

/**
 * Verifies that byte and short columns written through the DataFrameWriterV2 API are coerced
 * and round-trip correctly for each supported file format.
 */
@ExtendWith(ParameterizedTestExtension.class)
public class TestDataFrameWriterV2Coercion extends TestBaseWithCatalog {

  @Parameters(
      name = "catalogName = {0}, implementation = {1}, config = {2}, format = {3}, dataType = {4}")
  public static Object[][] parameters() {
    // Every (file format, source column type) combination, all against the Hadoop catalog.
    return new Object[][] {
      hadoopCatalogCase(FileFormat.AVRO, "byte"),
      hadoopCatalogCase(FileFormat.ORC, "byte"),
      hadoopCatalogCase(FileFormat.PARQUET, "byte"),
      hadoopCatalogCase(FileFormat.AVRO, "short"),
      hadoopCatalogCase(FileFormat.ORC, "short"),
      hadoopCatalogCase(FileFormat.PARQUET, "short")
    };
  }

  // Pairs a file format and source column type with the Hadoop catalog configuration.
  private static Object[] hadoopCatalogCase(FileFormat fileFormat, String dataType) {
    SparkCatalogConfig catalog = SparkCatalogConfig.HADOOP;
    return new Object[] {
      catalog.catalogName(), catalog.implementation(), catalog.properties(), fileFormat, dataType
    };
  }

  @Parameter(index = 3)
  private FileFormat format;

  @Parameter(index = 4)
  private String dataType;

  @TestTemplate
  public void testByteAndShortCoercion() {
    String schema = "id " + dataType + ", data string";
    Dataset<Row> df =
        jsonToDF(schema, "{ \"id\": 1, \"data\": \"a\" }", "{ \"id\": 2, \"data\": \"b\" }");

    df.writeTo(tableName).option("write-format", format.name()).createOrReplace();

    assertEquals(
        "Should have initial 2-column rows",
        ImmutableList.of(row(1, "a"), row(2, "b")),
        sql("select * from %s order by id", tableName));
  }
}

0 comments on commit 8d6bee7

Please sign in to comment.