Add connector cast suite and fix found issue.

urosstan-db committed May 20, 2024
1 parent 4fc2910 commit 9761f90

Showing 6 changed files with 488 additions and 2 deletions.
MySQLCastSuite.scala (new file)
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.jdbc.cast

import java.sql.{Connection, DriverManager}

import scala.collection.JavaConverters._

import com.databricks.sql.connector.JDBCConnectorCastSuiteBase

import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
import org.apache.spark.sql.execution.datasources.v2.jdbc.MysqlTableCatalog
import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite, JdbcDialect, MySQLDatabaseOnDocker, MySQLDialect}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class MySQLCastSuite extends DockerJDBCIntegrationSuite with JDBCConnectorCastSuiteBase {
  override val db: DatabaseOnDocker = new MySQLDatabaseOnDocker

  override def dataPreparation(connection: Connection): Unit = { }

  override protected def createConnection: Connection = {
    val jdbcUrl = db.getJdbcUrl("127.0.0.1", externalPort)
    DriverManager.getConnection(jdbcUrl, db.getJdbcProperties())
  }

  override def dialect: JdbcDialect = MySQLDialect

  override def tableCatalog: TableCatalog = {
    val catalog = new MysqlTableCatalog()
    val options = Map(
      "host" -> "127.0.0.1",
      "port" -> externalPort.toString,
      "database" -> "mysql",
      "user" -> "root",
      "password" -> "rootpass"
    )

    catalog.initialize("sfCat", new CaseInsensitiveStringMap(options.asJava))
    catalog
  }

  override protected def dropTable(table: Identifier): Unit =
    execUpdate(s"DROP TABLE ${table.toString}")

  override protected def dropSchema(schemaName: String): Unit = {
    // MySQL does not support CASCADE drop
    execUpdate(s"DROP SCHEMA $schemaName")
  }

  override def createNumericTypesTable: Identifier = {
    val identifier = Identifier.of(Array(schemaName), "CAST_NUMERIC_TABLE")
    execUpdate(
      s"""CREATE TABLE IF NOT EXISTS $schemaName.CAST_NUMERIC_TABLE
         |(COL_TINY TINYINT, COL_SMALL SMALLINT, COL_MEDIUM MEDIUMINT, COL_INT INT, COL_BIG BIGINT,
         | COL_DECIMAL DECIMAL(9,2), COL_FLOAT FLOAT, COL_DOUBLE DOUBLE)
         |""".stripMargin)
    execUpdate(
      s"""INSERT INTO $schemaName.CAST_NUMERIC_TABLE VALUES
         |(-100, -100, -100, -100, -100,
         | -10.25, -10.256, -10.256)""".stripMargin)
    execUpdate(
      s"""INSERT INTO $schemaName.CAST_NUMERIC_TABLE VALUES
         |(100, 100, 100, 100, 100,
         | 10.25, 10.256, 10.256)""".stripMargin)
    identifier
  }
}
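The shared JDBCConnectorCastSuiteBase is not part of this diff, so the exact test flow is assumed; a minimal sketch of how these overrides might be driven (TableCatalog.loadTable(Identifier) is the standard v2 catalog call, everything else is defined above):

    // Hypothetical flow inside a test of the base suite.
    val ident = createNumericTypesTable        // create and populate the table over JDBC
    val table = tableCatalog.loadTable(ident)  // resolve it through the v2 catalog
    // ... run CAST queries whose pushdown is compiled by `dialect` ...
    dropTable(ident)                           // clean up the table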
PostgreSQLCastSuite.scala (new file)
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.jdbc.cast

import java.sql.{Connection, DriverManager}

import scala.collection.JavaConverters._

import com.databricks.sql.connector.JDBCConnectorCastSuiteBase

import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
import org.apache.spark.sql.execution.datasources.v2.jdbc.PostgresqlTableCatalog
import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite, JdbcDialect, PostgresDialect}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class PostgreSQLCastSuite extends DockerJDBCIntegrationSuite with JDBCConnectorCastSuiteBase {
  override val schemaName: String = "cast_schema"

  override val db: DatabaseOnDocker = new DatabaseOnDocker {
    override val imageName: String =
      sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine")
    override val env: Map[String, String] = Map(
      "POSTGRES_PASSWORD" -> "rootpass"
    )
    override val usesIpc = false
    override val jdbcPort = 5432

    override def getJdbcUrl(ip: String, port: Int): String =
      s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass"
  }

  override def dataPreparation(connection: Connection): Unit = { }

  override protected def createConnection: Connection = {
    val jdbcUrl = db.getJdbcUrl("127.0.0.1", externalPort)
    DriverManager.getConnection(jdbcUrl, db.getJdbcProperties())
  }

  override def dialect: JdbcDialect = PostgresDialect

  override def tableCatalog: TableCatalog = {
    val catalog = new PostgresqlTableCatalog()
    val options = Map(
      "host" -> "127.0.0.1",
      "port" -> externalPort.toString,
      "database" -> "postgres",
      "user" -> "postgres",
      "password" -> "rootpass"
    )

    catalog.initialize("postgresCat", new CaseInsensitiveStringMap(options.asJava))
    catalog
  }

  override protected def dropTable(table: Identifier): Unit =
    execUpdate(s"DROP TABLE ${table.toString}")

  override def createNumericTypesTable: Identifier = {
    val identifier = Identifier.of(Array(schemaName), "cast_numeric_table")
    execUpdate(
      s"""CREATE TABLE IF NOT EXISTS $schemaName.cast_numeric_table
         |(COL_SMALLINT SMALLINT, COL_INT INT, COL_BIGINT BIGINT,
         | COL_DECIMAL DECIMAL(9,2), COL_REAL REAL, COL_DOUBLE FLOAT8)
         |""".stripMargin)
    execUpdate(
      s"""INSERT INTO $schemaName.cast_numeric_table VALUES
         |(-1000, -1000, -1000,
         | -10.25, -10.256, -10.256)""".stripMargin)
    execUpdate(
      s"""INSERT INTO $schemaName.cast_numeric_table VALUES
         |(1000, 1000, 1000,
         | 10.25, 10.256, 10.256)""".stripMargin)
    identifier
  }

  override def createStringTypeTable: Identifier = {
    super.createStringTypeTable
    // Postgres folds unquoted identifiers to lower case, so return the
    // lower-case table name.
    Identifier.of(Array(schemaName), "cast_string_table")
  }
}
MySQLDialect.scala
@@ -52,6 +52,27 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper {
    supportedFunctions.contains(funcName)

  class MySQLSQLBuilder extends JDBCSQLBuilder {
    override def visitCast(expr: String, exprDataType: DataType, dataType: DataType): String = {
      dataType match {
        case _: IntegralType =>
          // MySQL CAST does not accept SMALLINT, INT or BIGINT as target types;
          // every integral cast has to use SIGNED instead.
          s"CAST($expr AS SIGNED)"
        case _: StringType =>
          // getJDBCType returns LONGTEXT for StringType, but LONGTEXT is not a
          // valid target of the CAST function, so the type is overridden to CHAR.
          s"CAST($expr AS CHAR)"
        case _: DatetimeType =>
          // TODO: Check whether it is OK to cast to DATETIME instead of TIMESTAMP.
          // getJDBCType returns TIMESTAMP for the conversion, so to stay
          // consistent it is better, for now, to throw an exception rather than
          // push the cast down.
          throw new UnsupportedOperationException("Cannot cast to timestamp type")
        case _: BooleanType =>
          throw new UnsupportedOperationException("Cannot cast to boolean type")
        case _ =>
          super.visitCast(expr, exprDataType, dataType)
      }
    }
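    // Illustrative results (assuming the generic JDBCSQLBuilder CAST syntax):
    // a pushed-down cast to LongType compiles to CAST(expr AS SIGNED) and a
    // cast to StringType compiles to CAST(expr AS CHAR). The datetime and
    // boolean branches throw, which makes compileExpression return None, so
    // Spark evaluates those casts itself instead of pushing them to MySQL.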

    override def visitSortOrder(
        sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = {
      (sortDirection, nullOrdering) match {
PostgresDialect.scala
@@ -23,14 +23,15 @@ import java.util
import java.util.Locale

import scala.util.Using
import scala.util.control.NonFatal

import org.apache.spark.internal.LogKeys.COLUMN_NAME
import org.apache.spark.internal.MDC
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException}
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.{Expression, NamedReference}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -149,6 +150,8 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper {
    case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
    case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
    case ShortType | ByteType => Some(JdbcType("SMALLINT", Types.SMALLINT))
    case LongType => Some(JdbcType("BIGINT", Types.BIGINT))
    case TimestampNTZType => Some(JdbcType("TIMESTAMP", Types.TIMESTAMP))
    case t: DecimalType => Some(
      JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
    case ArrayType(et, _) if et.isInstanceOf[AtomicType] || et.isInstanceOf[ArrayType] =>
@@ -376,4 +379,27 @@
      case _ =>
    }
  }

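  // A None result tells Spark that the expression could not be compiled for
  // this dialect, so the cast is evaluated by Spark after the rows are
  // fetched instead of being pushed down to Postgres.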
  override def compileExpression(expr: Expression): Option[String] = {
    val builder = new PostgresSQLBuilder()
    try {
      Some(builder.build(expr))
    } catch {
      case NonFatal(e) =>
        logWarning("Error occurs while compiling V2 expression", e)
        None
    }
  }

  private class PostgresSQLBuilder extends JDBCSQLBuilder {
    override def visitCast(expr: String, exprDataType: DataType, dataType: DataType): String = {
      if (exprDataType.isInstanceOf[NumericType] && dataType.isInstanceOf[TimestampType]) {
        throw new UnsupportedOperationException(
          "Cannot cast from numeric types to timestamp type")
      }
      if (exprDataType.isInstanceOf[NumericType] && dataType.isInstanceOf[BooleanType]) {
        throw new UnsupportedOperationException(
          "Cannot cast from numeric types to boolean type")
      }
      super.visitCast(expr, exprDataType, dataType)
    }
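    // Illustrative effect (assumed): a pushed-down CAST(col AS TIMESTAMP) or
    // CAST(col AS BOOLEAN) over a numeric column throws here, so
    // compileExpression returns None and Spark evaluates the cast locally;
    // all other casts fall through to the generic builder.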
  }
}
SnowflakeDialect.scala
@@ -19,8 +19,11 @@ package org.apache.spark.sql.jdbc

import java.util.Locale

import scala.util.control.NonFatal

import org.apache.spark.sql.connector.expressions.Expression
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types.{BooleanType, DataType}
import org.apache.spark.sql.types.{BooleanType, DataType, FractionalType, TimestampType}

private case class SnowflakeDialect() extends JdbcDialect {
  override def canHandle(url: String): Boolean =
@@ -31,6 +34,35 @@ private case class SnowflakeDialect() extends JdbcDialect {
      // By default, Spark's BooleanType is mapped to BIT(1), but Snowflake has
      // no BIT type; it uses BOOLEAN instead.
      Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
    case TimestampType =>
      Some(JdbcType("TIMESTAMP_TZ", java.sql.Types.TIMESTAMP))
    case _ => JdbcUtils.getCommonJDBCType(dt)
  }
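  // For example (assumed): writing a DataFrame through this dialect creates a
  // Snowflake BOOLEAN column for a BooleanType field and a TIMESTAMP_TZ column
  // for a TimestampType field; all other types use the common JDBC mappings.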

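  // As in the Postgres dialect, a None result makes Spark skip the pushdown
  // and evaluate the expression itself after the rows are fetched.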
  override def compileExpression(expr: Expression): Option[String] = {
    val snowflakeSQLBuilder = new SnowflakeSQLBuilder()
    try {
      Some(snowflakeSQLBuilder.build(expr))
    } catch {
      case NonFatal(e) =>
        logWarning("Error occurs while compiling V2 expression to Snowflake", e)
        None
    }
  }

  private class SnowflakeSQLBuilder extends JDBCSQLBuilder {
    override def visitCast(expr: String, exprDataType: DataType, dataType: DataType): String = {
      if (exprDataType.isInstanceOf[FractionalType]) {
        if (dataType.isInstanceOf[TimestampType]) {
          throw new UnsupportedOperationException(
            "Cannot cast fractional types to timestamp type")
        }
        if (dataType.isInstanceOf[BooleanType]) {
          // Casting from floating-point types fails only when a column is
          // referenced, but succeeds for literal values.
          throw new UnsupportedOperationException(
            "Cannot cast fractional types to boolean type")
        }
      }
      super.visitCast(expr, exprDataType, dataType)
    }
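    // Illustrative effect (assumed): CAST(double_col AS TIMESTAMP) or
    // CAST(double_col AS BOOLEAN) throws here, so the cast stays in Spark,
    // while the same casts from an integral column still compile through the
    // generic builder and are pushed down.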
  }
}