[SPARK-12579][SQL] Force user-specified JDBC driver to take precedence
Spark SQL's JDBC data source allows users to specify an explicit JDBC driver to load (using the `driver` argument), but in the current code it's possible that the user-specified driver will not be used when it comes time to actually create a JDBC connection.

In a nutshell, the problem is that you might have multiple JDBC drivers on the classpath that claim to be able to handle the same subprotocol, so simply registering the user-provided driver class with our `DriverRegistry` and JDBC's `DriverManager` is not sufficient to ensure that it's actually used when creating the JDBC connection.

This patch addresses this issue by first registering the user-specified driver with the DriverManager, then iterating over the driver manager's loaded drivers in order to obtain the correct driver and use it to create a connection (previously, we just called `DriverManager.getConnection()` directly).

If a user did not specify a JDBC driver to use, we call `DriverManager.getDriver` to determine the driver class to use and pass that class's name to the executors; this guards against corner-case bugs where the driver and executor JVMs have different sets of JDBC drivers on their classpaths (previously, there was a rare potential for `DriverManager.getConnection()` to use different drivers on the driver and executors if the user had not explicitly specified a JDBC driver class and the classpaths differed).
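For reference, the `driver` option is supplied through the standard data source options. A minimal sketch of the user-facing side, assuming the Spark 1.6 `DataFrameReader` API; the URL, table, and driver class below are placeholder values, not part of this patch:

```scala
// Minimal sketch: the URL, table, and driver class are illustrative placeholders.
// `sqlContext` is an existing org.apache.spark.sql.SQLContext.
val jdbcDF = sqlContext.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://dbhost:5432/mydb")
  .option("dbtable", "schema.tablename")
  .option("driver", "org.postgresql.Driver") // this class now takes precedence when connecting
  .load()
```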

This patch is inspired by a similar patch that I made to the `spark-redshift` library (databricks/spark-redshift#143), which contains its own modified fork of some of Spark's JDBC data source code (for cross-Spark-version compatibility reasons).

Author: Josh Rosen <[email protected]>

Closes #10519 from JoshRosen/jdbc-driver-precedence.

(cherry picked from commit 6c83d93)
Signed-off-by: Yin Huai <[email protected]>
JoshRosen authored and yhuai committed Jan 4, 2016
1 parent b5a1f56 commit 7f37c1e
Showing 7 changed files with 34 additions and 50 deletions.
4 changes: 1 addition & 3 deletions docs/sql-programming-guide.md
@@ -1895,9 +1895,7 @@ the Data Sources API. The following options are supported:
<tr>
<td><code>driver</code></td>
<td>
The class name of the JDBC driver needed to connect to this URL. This class will be loaded
on the master and workers before running an JDBC commands to allow the driver to
register itself with the JDBC subsystem.
The class name of the JDBC driver to use to connect to this URL.
</td>
</tr>

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -275,7 +275,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}
// connectionProperties should override settings in extraOptions
props.putAll(connectionProperties)
val conn = JdbcUtils.createConnection(url, props)
val conn = JdbcUtils.createConnectionFactory(url, props)()

try {
var tableExists = JdbcUtils.tableExists(conn, url, table)
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
@@ -31,15 +31,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
val driver = parameters.getOrElse("driver", null)
val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
val partitionColumn = parameters.getOrElse("partitionColumn", null)
val lowerBound = parameters.getOrElse("lowerBound", null)
val upperBound = parameters.getOrElse("upperBound", null)
val numPartitions = parameters.getOrElse("numPartitions", null)

if (driver != null) DriverRegistry.register(driver)

if (partitionColumn != null
&& (lowerBound == null || upperBound == null || numPartitions == null)) {
sys.error("Partitioning incompletely specified")
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
@@ -51,10 +51,5 @@ object DriverRegistry extends Logging {
}
}
}

def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
case driver => driver.getClass.getCanonicalName
}
}

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution.datasources.jdbc

import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp}
import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
import java.util.Properties

import org.apache.commons.lang3.StringUtils
@@ -39,7 +39,6 @@ private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Par
override def index: Int = idx
}


private[sql] object JDBCRDD extends Logging {

/**
@@ -118,7 +117,7 @@
*/
def resolveTable(url: String, table: String, properties: Properties): StructType = {
val dialect = JdbcDialects.get(url)
val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)()
val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)()
try {
val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0")
try {
@@ -170,36 +169,13 @@
new StructType(columns map { name => fieldMap(name) })
}

/**
* Given a driver string and an url, return a function that loads the
* specified driver string then returns a connection to the JDBC url.
* getConnector is run on the driver code, while the function it returns
* is run on the executor.
*
* @param driver - The class name of the JDBC driver for the given url, or null if the class name
* is not necessary.
* @param url - The JDBC url to connect to.
*
* @return A function that loads the driver and connects to the url.
*/
def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
() => {
try {
if (driver != null) DriverRegistry.register(driver)
} catch {
case e: ClassNotFoundException =>
logWarning(s"Couldn't find class $driver", e)
}
DriverManager.getConnection(url, properties)
}
}


/**
* Build and return JDBCRDD from the given information.
*
* @param sc - Your SparkContext.
* @param schema - The Catalyst schema of the underlying database table.
* @param driver - The class name of the JDBC driver for the given url.
* @param url - The JDBC url to connect to.
* @param fqTable - The fully-qualified table name (or paren'd SQL query) to use.
* @param requiredColumns - The names of the columns to SELECT.
@@ -212,7 +188,6 @@
def scanTable(
sc: SparkContext,
schema: StructType,
driver: String,
url: String,
properties: Properties,
fqTable: String,
Expand All @@ -223,7 +198,7 @@ private[sql] object JDBCRDD extends Logging {
val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
new JDBCRDD(
sc,
getConnector(driver, url, properties),
JdbcUtils.createConnectionFactory(url, properties),
pruneSchema(schema, requiredColumns),
fqTable,
quotedColumns,
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -91,12 +91,10 @@ private[sql] case class JDBCRelation(
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
val driver: String = DriverRegistry.getDriverClassName(url)
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sqlContext.sparkContext,
schema,
driver,
url,
properties,
table,
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -17,9 +17,10 @@

package org.apache.spark.sql.execution.datasources.jdbc

import java.sql.{Connection, PreparedStatement}
import java.sql.{Connection, Driver, DriverManager, PreparedStatement}
import java.util.Properties

import scala.collection.JavaConverters._
import scala.util.Try
import scala.util.control.NonFatal

@@ -34,10 +35,31 @@ import org.apache.spark.sql.{DataFrame, Row}
object JdbcUtils extends Logging {

/**
* Establishes a JDBC connection.
* Returns a factory for creating connections to the given JDBC URL.
*
* @param url the JDBC url to connect to.
* @param properties JDBC connection properties.
*/
def createConnection(url: String, connectionProperties: Properties): Connection = {
JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)()
def createConnectionFactory(url: String, properties: Properties): () => Connection = {
val userSpecifiedDriverClass = Option(properties.getProperty("driver"))
userSpecifiedDriverClass.foreach(DriverRegistry.register)
// Performing this part of the logic on the driver guards against the corner-case where the
// driver returned for a URL is different on the driver and executors due to classpath
// differences.
val driverClass: String = userSpecifiedDriverClass.getOrElse {
DriverManager.getDriver(url).getClass.getCanonicalName
}
() => {
userSpecifiedDriverClass.foreach(DriverRegistry.register)
val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
case d if d.getClass.getCanonicalName == driverClass => d
}.getOrElse {
throw new IllegalStateException(
s"Did not find registered driver with class $driverClass")
}
driver.connect(url, properties)
}
}

/**
Expand Down Expand Up @@ -242,15 +264,14 @@ object JdbcUtils extends Logging {
df: DataFrame,
url: String,
table: String,
properties: Properties = new Properties()) {
properties: Properties) {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
getJdbcType(field.dataType, dialect).jdbcNullType
}

val rddSchema = df.schema
val driver: String = DriverRegistry.getDriverClassName(url)
val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
val getConnection: () => Connection = createConnectionFactory(url, properties)
val batchSize = properties.getProperty("batchsize", "1000").toInt
df.foreachPartition { iterator =>
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
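A rough usage sketch of the new factory (not part of this commit; the URL and driver class are placeholders): the factory is constructed once on the driver, and each invocation looks up the registered driver by class name as in the hunk above, rather than delegating driver selection to `DriverManager.getConnection()`.

```scala
import java.util.Properties

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils

// Placeholder connection details for illustration only.
val props = new Properties()
props.setProperty("driver", "org.postgresql.Driver") // hypothetical user-specified driver
val getConnection = JdbcUtils.createConnectionFactory("jdbc:postgresql://dbhost:5432/mydb", props)

// Each call yields a fresh connection created by exactly the requested driver class.
val conn = getConnection()
try {
  conn.prepareStatement("SELECT 1").executeQuery()
} finally {
  conn.close()
}
```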
