[HUDI-4813] Fix infer keygen not work in sparksql side issue (#6634)
* [HUDI-4813] Fix infer keygen not work in sparksql side issue

Co-authored-by: xiaoxingstack <[email protected]>
2 people authored and yuzhaojing committed Sep 22, 2022
1 parent: be67657 · commit: ba77748
Showing 4 changed files with 86 additions and 9 deletions.
DataSourceOptions.scala (object DataSourceOptionsHelper)
@@ -23,7 +23,7 @@ import org.apache.hudi.common.config.{ConfigProperty, DFSPropertiesConfiguration
 import org.apache.hudi.common.fs.ConsistencyGuardConfig
 import org.apache.hudi.common.model.{HoodieTableType, WriteOperationType}
 import org.apache.hudi.common.table.HoodieTableConfig
-import org.apache.hudi.common.util.Option
+import org.apache.hudi.common.util.{Option, StringUtils}
 import org.apache.hudi.common.util.ValidationUtils.checkState
 import org.apache.hudi.config.{HoodieClusteringConfig, HoodieWriteConfig}
 import org.apache.hudi.hive.{HiveSyncConfig, HiveSyncConfigHolder, HiveSyncTool}
@@ -787,9 +787,13 @@ object DataSourceOptionsHelper {
 
   def inferKeyGenClazz(props: TypedProperties): String = {
     val partitionFields = props.getString(DataSourceWriteOptions.PARTITIONPATH_FIELD.key(), null)
-    if (partitionFields != null) {
+    val recordsKeyFields = props.getString(DataSourceWriteOptions.RECORDKEY_FIELD.key(), DataSourceWriteOptions.RECORDKEY_FIELD.defaultValue())
+    inferKeyGenClazz(recordsKeyFields, partitionFields)
+  }
+
+  def inferKeyGenClazz(recordsKeyFields: String, partitionFields: String): String = {
+    if (!StringUtils.isNullOrEmpty(partitionFields)) {
       val numPartFields = partitionFields.split(",").length
-      val recordsKeyFields = props.getString(DataSourceWriteOptions.RECORDKEY_FIELD.key(), DataSourceWriteOptions.RECORDKEY_FIELD.defaultValue())
       val numRecordKeyFields = recordsKeyFields.split(",").length
       if (numPartFields == 1 && numRecordKeyFields == 1) {
         classOf[SimpleKeyGenerator].getName
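
The refactor extracts a string-based overload that callers outside the DataSource path can reuse. A minimal sketch of the resulting mapping; the expected results are taken from the new TestCreateTable case at the bottom of this commit, and the empty-partition fallback is not visible in this hunk, so treat it as an assumption verified by that test:

    import org.apache.hudi.DataSourceOptionsHelper

    // Sketch of inferKeyGenClazz(recordKeyFields, partitionFields); expected
    // outputs are the fully-qualified names asserted by the new test below.
    DataSourceOptionsHelper.inferKeyGenClazz("id", "dt")      // SimpleKeyGenerator: one record key, one partition field
    DataSourceOptionsHelper.inferKeyGenClazz("id,name", "dt") // ComplexKeyGenerator: composite record key
    DataSourceOptionsHelper.inferKeyGenClazz("id", "")        // NonpartitionedKeyGenerator: no partition fields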
HoodieCatalogTable.scala
@@ -17,19 +17,19 @@
 
 package org.apache.spark.sql.catalyst.catalog
 
-import org.apache.hudi.AvroConversionUtils
 import org.apache.hudi.DataSourceWriteOptions.OPERATION
 import org.apache.hudi.HoodieWriterUtils._
 import org.apache.hudi.common.config.DFSPropertiesConfiguration
 import org.apache.hudi.common.model.HoodieTableType
 import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient}
 import org.apache.hudi.common.util.{StringUtils, ValidationUtils}
-import org.apache.hudi.keygen.ComplexKeyGenerator
 import org.apache.hudi.keygen.factory.HoodieSparkKeyGeneratorFactory
+import org.apache.hudi.{AvroConversionUtils, DataSourceOptionsHelper}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.avro.SchemaConverters
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.hudi.HoodieOptionConfig
+import org.apache.spark.sql.hudi.HoodieOptionConfig.SQL_KEY_TABLE_PRIMARY_KEY
 import org.apache.spark.sql.hudi.HoodieSqlCommonUtils._
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.sql.{AnalysisException, SparkSession}
@@ -288,7 +288,10 @@ class HoodieCatalogTable(val spark: SparkSession, var table: CatalogTable) exten
       HoodieSparkKeyGeneratorFactory.convertToSparkKeyGenerator(
         originTableConfig(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME.key))
     } else {
-      extraConfig(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME.key) = classOf[ComplexKeyGenerator].getCanonicalName
+      val primaryKeys = table.properties.get(SQL_KEY_TABLE_PRIMARY_KEY.sqlKeyName).getOrElse(SQL_KEY_TABLE_PRIMARY_KEY.defaultValue.get)
+      val partitions = table.partitionColumnNames.mkString(",")
+      extraConfig(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME.key) =
+        DataSourceOptionsHelper.inferKeyGenClazz(primaryKeys, partitions)
     }
     extraConfig.toMap
   }
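
The substance of the fix is in this else branch: when the table config carries no explicit key generator, the catalog no longer hard-codes ComplexKeyGenerator but derives one from the SQL table definition, mirroring the DataSource write path. A hedged walk-through with hypothetical values:

    // Hypothetical values for: create table t (...) using hudi
    //   partitioned by (dt) tblproperties (primaryKey = 'id')
    val primaryKeys = "id" // the primaryKey table property (SQL_KEY_TABLE_PRIMARY_KEY)
    val partitions  = "dt" // table.partitionColumnNames joined with ","
    // One record key and one partition field, so inference now yields
    // SimpleKeyGenerator where the old code always wrote ComplexKeyGenerator.
    DataSourceOptionsHelper.inferKeyGenClazz(primaryKeys, partitions)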
TestHoodieSparkSqlWriter.scala
@@ -874,7 +874,7 @@ class TestHoodieSparkSqlWriter {
       .setBasePath(tablePath1).build().getTableConfig
     assert(tableConfig1.getHiveStylePartitioningEnable == "true")
     assert(tableConfig1.getUrlEncodePartitioning == "false")
-    assert(tableConfig1.getKeyGeneratorClassName == classOf[ComplexKeyGenerator].getName)
+    assert(tableConfig1.getKeyGeneratorClassName == classOf[SimpleKeyGenerator].getName)
     df.write.format("hudi")
       .options(options)
       .option(HoodieWriteConfig.TBL_NAME.key, tableName1)
TestCreateTable.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.CatalogTableType
 import org.apache.spark.sql.types._
-
 import org.junit.jupiter.api.Assertions.assertFalse
 
 import scala.collection.JavaConverters._
@@ -137,7 +136,7 @@ class TestCreateTable extends HoodieSparkSqlTestBase {
     assertResult("dt")(tableConfig(HoodieTableConfig.PARTITION_FIELDS.key))
     assertResult("id")(tableConfig(HoodieTableConfig.RECORDKEY_FIELDS.key))
     assertResult("ts")(tableConfig(HoodieTableConfig.PRECOMBINE_FIELD.key))
-    assertResult(classOf[ComplexKeyGenerator].getCanonicalName)(tableConfig(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME.key))
+    assertResult(classOf[SimpleKeyGenerator].getCanonicalName)(tableConfig(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME.key))
     assertResult("default")(tableConfig(HoodieTableConfig.DATABASE_NAME.key()))
     assertResult(tableName)(tableConfig(HoodieTableConfig.NAME.key()))
     assertFalse(tableConfig.contains(OPERATION.key()))
@@ -944,4 +943,75 @@
 
     spark.sql("use default")
   }
+
+  test("Test Infer KeyGenClazz") {
+    def checkKeyGenerator(targetGenerator: String, tableName: String) = {
+      val tablePath = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).location.getPath
+      val metaClient = HoodieTableMetaClient.builder()
+        .setBasePath(tablePath)
+        .setConf(spark.sessionState.newHadoopConf())
+        .build()
+      val realKeyGenerator =
+        metaClient.getTableConfig.getProps.asScala.toMap.get(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME.key).get
+      assertResult(targetGenerator)(realKeyGenerator)
+    }
+
+    val tableName = generateTableName
+
+    // Test non-partitioned table
+    spark.sql(
+      s"""
+         | create table $tableName (
+         |  id int,
+         |  name string,
+         |  price double,
+         |  ts long
+         | ) using hudi
+         | comment "This is a simple hudi table"
+         | tblproperties (
+         |   primaryKey = 'id',
+         |   preCombineField = 'ts'
+         | )
+       """.stripMargin)
+    checkKeyGenerator("org.apache.hudi.keygen.NonpartitionedKeyGenerator", tableName)
+    spark.sql(s"drop table $tableName")
+
+    // Test single-partition table
+    spark.sql(
+      s"""
+         | create table $tableName (
+         |  id int,
+         |  name string,
+         |  price double,
+         |  ts long
+         | ) using hudi
+         | comment "This is a simple hudi table"
+         | partitioned by (ts)
+         | tblproperties (
+         |   primaryKey = 'id',
+         |   preCombineField = 'ts'
+         | )
+       """.stripMargin)
+    checkKeyGenerator("org.apache.hudi.keygen.SimpleKeyGenerator", tableName)
+    spark.sql(s"drop table $tableName")
+
+    // Test single-partition table with two record keys
+    spark.sql(
+      s"""
+         | create table $tableName (
+         |  id int,
+         |  name string,
+         |  price double,
+         |  ts long
+         | ) using hudi
+         | comment "This is a simple hudi table"
+         | partitioned by (ts)
+         | tblproperties (
+         |   primaryKey = 'id,name',
+         |   preCombineField = 'ts'
+         | )
+       """.stripMargin)
+    checkKeyGenerator("org.apache.hudi.keygen.ComplexKeyGenerator", tableName)
+    spark.sql(s"drop table $tableName")
+  }
 }
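
To spot-check the fix on an existing table, the same inspection the test helper performs can be run against any table's base path. A minimal sketch, assuming an active SparkSession named spark; the path is a placeholder:

    import org.apache.hudi.common.table.HoodieTableMetaClient

    // Read hoodie.properties and print the recorded key generator class.
    // "/tmp/hudi/my_table" is a hypothetical base path; substitute your own.
    val metaClient = HoodieTableMetaClient.builder()
      .setBasePath("/tmp/hudi/my_table")
      .setConf(spark.sessionState.newHadoopConf())
      .build()
    println(metaClient.getTableConfig.getKeyGeneratorClassName)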
