Force user-specified JDBC driver to take precedence.
JoshRosen committed Dec 30, 2015
1 parent be86268 commit 3554d68
Showing 7 changed files with 34 additions and 50 deletions.
docs/sql-programming-guide.md (1 addition, 3 deletions)

@@ -1895,9 +1895,7 @@ the Data Sources API. The following options are supported:
 <tr>
 <td><code>driver</code></td>
 <td>
-The class name of the JDBC driver needed to connect to this URL. This class will be loaded
-on the master and workers before running an JDBC commands to allow the driver to
-register itself with the JDBC subsystem.
+The class name of the JDBC driver to use to connect to this URL.
 </td>
 </tr>

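A hedged example of passing this option through the Data Sources API (the URL, table, and driver class below are placeholders, not part of this commit):

```scala
// Placeholder connection details; setting "driver" forces this exact class,
// rather than whichever registered driver first accepts the URL.
val jdbcDF = sqlContext.read.format("jdbc").options(Map(
  "url" -> "jdbc:postgresql://host:5432/db",
  "dbtable" -> "schema.tablename",
  "driver" -> "org.postgresql.Driver")).load()
```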
sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

@@ -275,7 +275,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
 }
 // connectionProperties should override settings in extraOptions
 props.putAll(connectionProperties)
-val conn = JdbcUtils.createConnection(url, props)
+val conn = JdbcUtils.createConnectionFactory(url, props)()
 
 try {
   var tableExists = JdbcUtils.tableExists(conn, url, table)
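Note the trailing `()` on the new call: `createConnectionFactory` returns a `() => Connection` thunk rather than an open connection, and the extra parentheses apply it immediately. A minimal sketch of the equivalent two-step form (`url` and `props` as in the surrounding method):

```scala
// Equivalent two-step reading of `JdbcUtils.createConnectionFactory(url, props)()`:
val factory: () => java.sql.Connection = JdbcUtils.createConnectionFactory(url, props)
val conn: java.sql.Connection = factory()  // applying the thunk opens the connection
```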
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala

@@ -31,15 +31,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
 sqlContext: SQLContext,
 parameters: Map[String, String]): BaseRelation = {
 val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
-val driver = parameters.getOrElse("driver", null)
 val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
 val partitionColumn = parameters.getOrElse("partitionColumn", null)
 val lowerBound = parameters.getOrElse("lowerBound", null)
 val upperBound = parameters.getOrElse("upperBound", null)
 val numPartitions = parameters.getOrElse("numPartitions", null)
 
-if (driver != null) DriverRegistry.register(driver)
-
 if (partitionColumn != null
   && (lowerBound == null || upperBound == null || numPartitions == null)) {
   sys.error("Partitioning incompletely specified")
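The check above requires all four partitioning options to be supplied together. A hedged example of a read that satisfies it (connection details and bounds are placeholders):

```scala
// Either supply all four partitioning options or none of them.
val partitionedDF = sqlContext.read.format("jdbc").options(Map(
  "url" -> "jdbc:postgresql://host:5432/db",
  "dbtable" -> "schema.tablename",
  "partitionColumn" -> "id",  // numeric column to split the table on
  "lowerBound" -> "1",        // lowest id used when computing partition strides
  "upperBound" -> "1000000",  // highest id used when computing partition strides
  "numPartitions" -> "10")).load()
```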
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala

@@ -51,10 +51,5 @@ object DriverRegistry extends Logging {
 }
 }
 }
-
-def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
-  case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
-  case driver => driver.getClass.getCanonicalName
-}
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
-import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp}
+import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
 import java.util.Properties
 
 import scala.util.control.NonFatal
@@ -41,7 +41,6 @@ private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Par
 override def index: Int = idx
 }
 
-
 private[sql] object JDBCRDD extends Logging {
 
 /**
@@ -120,7 +119,7 @@
  */
 def resolveTable(url: String, table: String, properties: Properties): StructType = {
   val dialect = JdbcDialects.get(url)
-  val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)()
+  val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)()
   try {
     val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0")
     try {
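`resolveTable` relies on the standard `WHERE 1=0` trick: the query matches no rows, so it is cheap to run, yet its result-set metadata still describes every column. A standalone sketch of that JDBC pattern (not the commit's code; assumes `table` is trusted input):

```scala
import java.sql.{Connection, ResultSetMetaData}

// Returns the column names of `table` without fetching any rows.
def columnNames(conn: Connection, table: String): Seq[String] = {
  val stmt = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0")
  try {
    val meta: ResultSetMetaData = stmt.executeQuery().getMetaData
    (1 to meta.getColumnCount).map(meta.getColumnName)
  } finally {
    stmt.close()
  }
}
```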
@@ -201,36 +200,13 @@
 case _ => null
 }
 
-/**
- * Given a driver string and an url, return a function that loads the
- * specified driver string then returns a connection to the JDBC url.
- * getConnector is run on the driver code, while the function it returns
- * is run on the executor.
- *
- * @param driver - The class name of the JDBC driver for the given url, or null if the class name
- * is not necessary.
- * @param url - The JDBC url to connect to.
- *
- * @return A function that loads the driver and connects to the url.
- */
-def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
-  () => {
-    try {
-      if (driver != null) DriverRegistry.register(driver)
-    } catch {
-      case e: ClassNotFoundException =>
-        logWarning(s"Couldn't find class $driver", e)
-    }
-    DriverManager.getConnection(url, properties)
-  }
-}
-
-
 /**
  * Build and return JDBCRDD from the given information.
  *
  * @param sc - Your SparkContext.
  * @param schema - The Catalyst schema of the underlying database table.
- * @param driver - The class name of the JDBC driver for the given url.
 * @param url - The JDBC url to connect to.
 * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use.
 * @param requiredColumns - The names of the columns to SELECT.
@@ -243,7 +219,6 @@
 def scanTable(
     sc: SparkContext,
     schema: StructType,
-    driver: String,
     url: String,
     properties: Properties,
     fqTable: String,
@@ -254,7 +229,7 @@
 val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
 new JDBCRDD(
   sc,
-  getConnector(driver, url, properties),
+  JdbcUtils.createConnectionFactory(url, properties),
   pruneSchema(schema, requiredColumns),
   fqTable,
   quotedColumns,
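Passing `JdbcUtils.createConnectionFactory(url, properties)` to the `JDBCRDD` constructor keeps the established pattern: only a `() => Connection` closure is shipped with each task, since `java.sql.Connection` itself is not serializable. A sketch of that pattern (illustrative names, not the commit's code):

```scala
import java.sql.Connection
import org.apache.spark.rdd.RDD

// Each executor opens and closes its own connection per partition; the driver
// serializes only the factory closure, never a live connection.
def foreachRowWithConnection[T](rdd: RDD[T], getConnection: () => Connection)
                               (f: (Connection, T) => Unit): Unit = {
  rdd.foreachPartition { rows =>
    val conn = getConnection() // runs on the executor
    try {
      rows.foreach(row => f(conn, row))
    } finally {
      conn.close()
    }
  }
}
```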
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala

@@ -91,12 +91,10 @@ private[sql] case class JDBCRelation(
 override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
 
 override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
-  val driver: String = DriverRegistry.getDriverClassName(url)
   // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
   JDBCRDD.scanTable(
     sqlContext.sparkContext,
     schema,
-    driver,
     url,
     properties,
     table,
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala

@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
-import java.sql.{Connection, PreparedStatement}
+import java.sql.{Connection, Driver, DriverManager, PreparedStatement}
 import java.util.Properties
 
+import scala.collection.JavaConverters._
 import scala.util.Try
 import scala.util.control.NonFatal
@@ -34,10 +35,31 @@ import org.apache.spark.sql.{DataFrame, Row}
 object JdbcUtils extends Logging {
 
 /**
- * Establishes a JDBC connection.
+ * Returns a factory for creating connections to the given JDBC URL.
+ *
+ * @param url the JDBC url to connect to.
+ * @param properties JDBC connection properties.
  */
-def createConnection(url: String, connectionProperties: Properties): Connection = {
-  JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)()
+def createConnectionFactory(url: String, properties: Properties): () => Connection = {
+  val userSpecifiedDriverClass = Option(properties.getProperty("driver"))
+  userSpecifiedDriverClass.foreach(DriverRegistry.register)
+  // Performing this part of the logic on the driver guards against the corner-case where the
+  // driver returned for a URL is different on the driver and executors due to classpath
+  // differences.
+  val driverClass: String = userSpecifiedDriverClass.getOrElse {
+    DriverManager.getDriver(url).getClass.getCanonicalName
+  }
+  () => {
+    userSpecifiedDriverClass.foreach(DriverRegistry.register)
+    val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
+      case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
+      case d if d.getClass.getCanonicalName == driverClass => d
+    }.getOrElse {
+      throw new IllegalStateException(
+        s"Did not find registered driver with class $driverClass")
+    }
+    driver.connect(url, properties)
+  }
 }
 
 /**
@@ -242,15 +264,14 @@
 df: DataFrame,
 url: String,
 table: String,
-properties: Properties = new Properties()) {
+properties: Properties) {
 val dialect = JdbcDialects.get(url)
 val nullTypes: Array[Int] = df.schema.fields.map { field =>
   getJdbcType(field.dataType, dialect).jdbcNullType
 }
 
 val rddSchema = df.schema
-val driver: String = DriverRegistry.getDriverClassName(url)
-val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
+val getConnection: () => Connection = createConnectionFactory(url, properties)
 val batchSize = properties.getProperty("batchsize", "1000").toInt
 df.foreachPartition { iterator =>
   savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
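With these changes, `createConnectionFactory` is the single entry point for opening JDBC connections. A hedged usage sketch (URL and driver class are placeholders; the call assumes code inside `org.apache.spark.sql`, since these are internal APIs):

```scala
import java.sql.Connection
import java.util.Properties

val props = new Properties()
// The user-specified driver takes precedence: this exact class is used even
// if another registered driver also accepts the URL.
props.setProperty("driver", "org.postgresql.Driver")

val getConnection: () => Connection =
  JdbcUtils.createConnectionFactory("jdbc:postgresql://host:5432/db", props)
val conn = getConnection()
try {
  // ... issue JDBC statements over conn ...
} finally {
  conn.close()
}
```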
