[SPARK-18572][SQL] Add a method listPartitionNames to `ExternalCatalog`

(Link to Jira issue: https://issues.apache.org/jira/browse/SPARK-18572)

## What changes were proposed in this pull request?

Currently Spark answers the `SHOW PARTITIONS` command by fetching all of the table's partition metadata from the external catalog and constructing partition names therefrom. The Hive client has a `getPartitionNames` method which is many times faster for this purpose, with the performance improvement scaling with the number of partitions in a table.
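As a point of reference, here is a minimal sketch of how the command is issued from a Spark application. The table name and the existing `SparkSession` named `spark` are illustrative assumptions, not part of this patch:

```scala
// Issue SHOW PARTITIONS and print each partition name.
// Assumes an existing SparkSession `spark` with Hive support and a partitioned table `table1`.
val partitionNames = spark.sql("SHOW PARTITIONS table1").collect().map(_.getString(0))
partitionNames.foreach(println) // e.g. "year=2016/month=12"
```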

To test the performance impact of this PR, I ran the `SHOW PARTITIONS` command on two Hive tables with large numbers of partitions. One table has ~17,800 partitions, and the other has ~95,000 partitions. For the purposes of this PR, I'll call the former `table1` and the latter `table2`. I ran five trials for each table, both without and with this patch. The results are as follows:

Spark at bdc8153, `SHOW PARTITIONS table1`, times in seconds:
7.901
3.983
4.018
4.331
4.261

Spark at bdc8153, `SHOW PARTITIONS table2`
(Timed out after 10 minutes with a `SocketTimeoutException`.)

Spark at this PR, `SHOW PARTITIONS table1`, times in seconds:
3.801
0.449
0.395
0.348
0.336

Spark at this PR, `SHOW PARTITIONS table2`, times in seconds:
5.184
1.63
1.474
1.519
1.41

Taking the best time from each set of trials, we get roughly a 12x performance improvement for the table with ~17,800 partitions and at least a 426x improvement for the table with ~95,000 partitions. More significantly, the latter command doesn't even complete with the current code in master.
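For anyone who wants to repeat this kind of measurement, here is a rough, hypothetical timing sketch (this is not the harness used for the numbers above; it assumes an existing `SparkSession` named `spark` and a partitioned table `table1`):

```scala
// Hypothetical sketch: time SHOW PARTITIONS over several trials and print each duration.
def timeShowPartitions(tableName: String): Double = {
  val start = System.nanoTime()
  spark.sql(s"SHOW PARTITIONS $tableName").collect()
  (System.nanoTime() - start) / 1e9 // elapsed time in seconds
}

(1 to 5).foreach(i => println(f"trial $i: ${timeShowPartitions("table1")}%.3f s"))
```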

This is actually a patch we've been using in-house at VideoAmp since Spark 1.1. It's made all the difference in the practical usability of our largest tables. Even for tables with only about 1,000 partitions, there's a performance improvement of about 2-3x.

## How was this patch tested?

I added a unit test to `VersionsSuite` which tests that the Hive client's `getPartitionNames` method returns the correct number of partitions.

Author: Michael Allman <[email protected]>

Closes #15998 from mallman/spark-18572-list_partition_names.
Michael Allman authored and cloud-fan committed Dec 6, 2016
1 parent 4af142f commit 772ddbe
Showing 12 changed files with 221 additions and 30 deletions.
@@ -189,15 +189,37 @@ abstract class ExternalCatalog {
table: String,
spec: TablePartitionSpec): Option[CatalogTablePartition]

/**
* List the names of all partitions that belong to the specified table, assuming it exists.
*
* For a table with partition columns p1, p2, p3, each partition name is formatted as
* `p1=v1/p2=v2/p3=v3`. Each partition column name and value is an escaped path name, and can be
* decoded with the `ExternalCatalogUtils.unescapePathName` method.
*
* The returned sequence is sorted as strings.
*
* A partial partition spec may optionally be provided to filter the partitions returned, as
* described in the `listPartitions` method.
*
* @param db database name
* @param table table name
* @param partialSpec partition spec
*/
def listPartitionNames(
db: String,
table: String,
partialSpec: Option[TablePartitionSpec] = None): Seq[String]

/**
* List the metadata of all partitions that belong to the specified table, assuming it exists.
*
* A partial partition spec may optionally be provided to filter the partitions returned.
* For instance, if there exist partitions (a='1', b='2'), (a='1', b='3') and (a='2', b='4'),
* then a partial spec of (a='1') will return the first two only.
*
* @param db database name
* @param table table name
* @param partialSpec partition spec
*/
def listPartitions(
db: String,
@@ -210,7 +232,7 @@ abstract class ExternalCatalog {
*
* @param db database name
* @param table table name
* @param predicates partition-pruning predicates
*/
def listPartitionsByFilter(
db: String,
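To make the name format documented above concrete, here is a small, self-contained sketch. The `escape` helper is a simplified, hypothetical stand-in for `ExternalCatalogUtils.escapePathName` (it only handles `%`, `=` and `/`, while the real method escapes a larger set of characters), so it is illustrative rather than a reimplementation:

```scala
// Sketch of the "p1=v1/p2=v2" partition-name format with path-style escaping.
object PartitionNameSketch {
  // Simplified stand-in for ExternalCatalogUtils.escapePathName.
  private def escape(s: String): String = s.flatMap {
    case c @ ('%' | '=' | '/') => "%%%02X".format(c.toInt) // e.g. '=' becomes "%3D"
    case c                     => c.toString
  }

  // Build the partition name from (column, value) pairs, preserving column order.
  def partitionName(spec: Seq[(String, String)]): String =
    spec.map { case (col, value) => escape(col) + "=" + escape(value) }.mkString("/")

  def main(args: Array[String]): Unit = {
    println(partitionName(Seq("a" -> "1", "b" -> "2")))  // a=1/b=2
    println(partitionName(Seq("a" -> "1", "b" -> "%="))) // a=1/b=%25%3D
  }
}
```

The second example mirrors the escaped value `%25%3D` that the new test cases further down assert on.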
@@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.StringUtils

@@ -488,6 +489,19 @@ class InMemoryCatalog(
}
}

override def listPartitionNames(
db: String,
table: String,
partialSpec: Option[TablePartitionSpec] = None): Seq[String] = synchronized {
val partitionColumnNames = getTable(db, table).partitionColumnNames

listPartitions(db, table, partialSpec).map { partition =>
partitionColumnNames.map { name =>
escapePathName(name) + "=" + escapePathName(partition.spec(name))
}.mkString("/")
}.sorted
}

override def listPartitions(
db: String,
table: String,
@@ -748,6 +748,26 @@ class SessionCatalog(
externalCatalog.getPartition(db, table, spec)
}

/**
* List the names of all partitions that belong to the specified table, assuming it exists.
*
* A partial partition spec may optionally be provided to filter the partitions returned.
* For instance, if there exist partitions (a='1', b='2'), (a='1', b='3') and (a='2', b='4'),
* then a partial spec of (a='1') will return the first two only.
*/
def listPartitionNames(
tableName: TableIdentifier,
partialSpec: Option[TablePartitionSpec] = None): Seq[String] = {
val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
val table = formatTableName(tableName.table)
requireDbExists(db)
requireTableExists(TableIdentifier(table, Option(db)))
partialSpec.foreach { spec =>
requirePartialMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName))
}
externalCatalog.listPartitionNames(db, table, partialSpec)
}

/**
* List the metadata of all partitions that belong to the specified table, assuming it exists.
*
@@ -762,6 +782,9 @@
val table = formatTableName(tableName.table)
requireDbExists(db)
requireTableExists(TableIdentifier(table, Option(db)))
partialSpec.foreach { spec =>
requirePartialMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName))
}
externalCatalog.listPartitions(db, table, partialSpec)
}

@@ -346,6 +346,31 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEach
assert(new Path(partitionLocation) == defaultPartitionLocation)
}

test("list partition names") {
val catalog = newBasicCatalog()
val newPart = CatalogTablePartition(Map("a" -> "1", "b" -> "%="), storageFormat)
catalog.createPartitions("db2", "tbl2", Seq(newPart), ignoreIfExists = false)

val partitionNames = catalog.listPartitionNames("db2", "tbl2")
assert(partitionNames == Seq("a=1/b=%25%3D", "a=1/b=2", "a=3/b=4"))
}

test("list partition names with partial partition spec") {
val catalog = newBasicCatalog()
val newPart = CatalogTablePartition(Map("a" -> "1", "b" -> "%="), storageFormat)
catalog.createPartitions("db2", "tbl2", Seq(newPart), ignoreIfExists = false)

val partitionNames1 = catalog.listPartitionNames("db2", "tbl2", Some(Map("a" -> "1")))
assert(partitionNames1 == Seq("a=1/b=%25%3D", "a=1/b=2"))

// Partial partition specs including "weird" partition values should use the unescaped values
val partitionNames2 = catalog.listPartitionNames("db2", "tbl2", Some(Map("b" -> "%=")))
assert(partitionNames2 == Seq("a=1/b=%25%3D"))

val partitionNames3 = catalog.listPartitionNames("db2", "tbl2", Some(Map("b" -> "%25%3D")))
assert(partitionNames3.isEmpty)
}

test("list partitions with partial partition spec") {
val catalog = newBasicCatalog()
val parts = catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "1")))
@@ -878,6 +878,31 @@ class SessionCatalogSuite extends SparkFunSuite {
"the partition spec (a, b) defined in table '`db2`.`tbl1`'"))
}

test("list partition names") {
val catalog = new SessionCatalog(newBasicCatalog())
val expectedPartitionNames = Seq("a=1/b=2", "a=3/b=4")
assert(catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2"))) ==
expectedPartitionNames)
// List partition names without explicitly specifying database
catalog.setCurrentDatabase("db2")
assert(catalog.listPartitionNames(TableIdentifier("tbl2")) == expectedPartitionNames)
}

test("list partition names with partial partition spec") {
val catalog = new SessionCatalog(newBasicCatalog())
assert(
catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))) ==
Seq("a=1/b=2"))
}

test("list partition names with invalid partial partition spec") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")),
Some(Map("unknown" -> "unknown")))
}
}

test("list partitions") {
val catalog = new SessionCatalog(newBasicCatalog())
assert(catalogPartitionsEqual(
@@ -887,6 +912,20 @@
assert(catalogPartitionsEqual(catalog.listPartitions(TableIdentifier("tbl2")), part1, part2))
}

test("list partitions with partial partition spec") {
val catalog = new SessionCatalog(newBasicCatalog())
assert(catalogPartitionsEqual(
catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))), part1))
}

test("list partitions with invalid partial partition spec") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
catalog.listPartitions(
TableIdentifier("tbl2", Some("db2")), Some(Map("unknown" -> "unknown")))
}
}

test("list partitions when database/table does not exist") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[NoSuchDatabaseException] {
@@ -729,13 +729,6 @@ case class ShowPartitionsCommand(
AttributeReference("partition", StringType, nullable = false)() :: Nil
}

private def getPartName(spec: TablePartitionSpec, partColNames: Seq[String]): String = {
partColNames.map { name =>
ExternalCatalogUtils.escapePathName(name) + "=" +
ExternalCatalogUtils.escapePathName(spec(name))
}.mkString(File.separator)
}

override def run(sparkSession: SparkSession): Seq[Row] = {
val catalog = sparkSession.sessionState.catalog
val table = catalog.getTableMetadata(tableName)
@@ -772,10 +765,7 @@ }
}
}

val partNames = catalog.listPartitions(tableName, spec).map { p =>
getPartName(p.spec, table.partitionColumnNames)
}

val partNames = catalog.listPartitionNames(tableName, spec)
partNames.map(Row(_))
}
}
@@ -161,8 +161,8 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
insert.copy(partition = parts.map(p => (p._1, None)), child = Project(projectList, query))


case i @ logical.InsertIntoTable(
l @ LogicalRelation(t: HadoopFsRelation, _, table), part, query, overwrite, false)
case logical.InsertIntoTable(
l @ LogicalRelation(t: HadoopFsRelation, _, table), _, query, overwrite, false)
if query.resolved && t.schema.sameType(query.schema) =>

// Sanity checks
@@ -192,11 +192,19 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty

val staticPartitionKeys: TablePartitionSpec = if (overwrite.enabled) {
overwrite.staticPartitionKeys.map { case (k, v) =>
(partitionSchema.map(_.name).find(_.equalsIgnoreCase(k)).get, v)
}
} else {
Map.empty
}

// When partitions are tracked by the catalog, compute all custom partition locations that
// may be relevant to the insertion job.
if (partitionsTrackedByCatalog) {
val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions(
l.catalogTable.get.identifier, Some(overwrite.staticPartitionKeys))
l.catalogTable.get.identifier, Some(staticPartitionKeys))
initialMatchingPartitions = matchingPartitions.map(_.spec)
customPartitionLocations = getCustomPartitionLocations(
t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions)
@@ -225,14 +233,6 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
t.location.refresh()
}

val staticPartitionKeys: TablePartitionSpec = if (overwrite.enabled) {
overwrite.staticPartitionKeys.map { case (k, v) =>
(partitionSchema.map(_.name).find(_.equalsIgnoreCase(k)).get, v)
}
} else {
Map.empty
}

val insertCmd = InsertIntoHadoopFsRelationCommand(
outputPath,
staticPartitionKeys,
@@ -244,13 +244,22 @@ object PartitioningUtils {

/**
* Given a partition path fragment, e.g. `fieldOne=1/fieldTwo=2`, returns a parsed spec
* for that fragment, e.g. `Map(("fieldOne", "1"), ("fieldTwo", "2"))`.
* for that fragment as a `TablePartitionSpec`, e.g. `Map(("fieldOne", "1"), ("fieldTwo", "2"))`.
*/
def parsePathFragment(pathFragment: String): TablePartitionSpec = {
parsePathFragmentAsSeq(pathFragment).toMap
}

/**
* Given a partition path fragment, e.g. `fieldOne=1/fieldTwo=2`, returns a parsed spec
* for that fragment as a `Seq[(String, String)]`, e.g.
* `Seq(("fieldOne", "1"), ("fieldTwo", "2"))`.
*/
def parsePathFragmentAsSeq(pathFragment: String): Seq[(String, String)] = {
pathFragment.split("/").map { kv =>
val pair = kv.split("=", 2)
(unescapePathName(pair(0)), unescapePathName(pair(1)))
}
}
}

/**
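Complementing the escaping sketch earlier, here is a small, self-contained illustration of parsing a path fragment back into (column, value) pairs, in the spirit of `parsePathFragmentAsSeq` above. The `unescape` helper is a simplified, hypothetical stand-in for `ExternalCatalogUtils.unescapePathName` and only understands `%XX` sequences:

```scala
// Sketch: parse "a=1/b=%25%3D" back into unescaped (column, value) pairs.
object ParsePathFragmentSketch {
  // Simplified stand-in for ExternalCatalogUtils.unescapePathName.
  private def unescape(s: String): String = {
    val out = new StringBuilder
    var i = 0
    while (i < s.length) {
      if (s.charAt(i) == '%' && i + 2 < s.length) {
        out.append(Integer.parseInt(s.substring(i + 1, i + 3), 16).toChar)
        i += 3
      } else {
        out.append(s.charAt(i))
        i += 1
      }
    }
    out.toString
  }

  def parse(fragment: String): Seq[(String, String)] =
    fragment.split("/").toList.map { kv =>
      val Array(k, v) = kv.split("=", 2)
      (unescape(k), unescape(v))
    }

  def main(args: Array[String]): Unit = {
    println(parse("a=1/b=%25%3D")) // List((a,1), (b,%=))
  }
}
```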
@@ -35,10 +35,12 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.HiveSerDe
import org.apache.spark.sql.internal.StaticSQLConf._
@@ -812,9 +814,21 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configuration)
spec.map { case (k, v) => k.toLowerCase -> v }
}

// Build a map from lower-cased partition column names to exact column names for a given table
private def buildLowerCasePartColNameMap(table: CatalogTable): Map[String, String] = {
val actualPartColNames = table.partitionColumnNames
actualPartColNames.map(colName => (colName.toLowerCase, colName)).toMap
}

// Hive metastore is not case preserving and the column names of the partition specification we
// get from the metastore are always lower cased. We should restore them w.r.t. the actual table
// partition columns.
private def restorePartitionSpec(
spec: TablePartitionSpec,
partColMap: Map[String, String]): TablePartitionSpec = {
spec.map { case (k, v) => partColMap(k.toLowerCase) -> v }
}

private def restorePartitionSpec(
spec: TablePartitionSpec,
partCols: Seq[String]): TablePartitionSpec = {
@@ -927,13 +941,32 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configuration)
/**
* Returns the partition names from hive metastore for a given table in a database.
*/
override def listPartitionNames(
db: String,
table: String,
partialSpec: Option[TablePartitionSpec] = None): Seq[String] = withClient {
val catalogTable = getTable(db, table)
val partColNameMap = buildLowerCasePartColNameMap(catalogTable).mapValues(escapePathName)
val clientPartitionNames =
client.getPartitionNames(catalogTable, partialSpec.map(lowerCasePartitionSpec))
clientPartitionNames.map { partName =>
val partSpec = PartitioningUtils.parsePathFragmentAsSeq(partName)
partSpec.map { case (partName, partValue) =>
partColNameMap(partName.toLowerCase) + "=" + escapePathName(partValue)
}.mkString("/")
}
}

/**
* Returns the partitions from hive metastore for a given table in a database.
*/
override def listPartitions(
db: String,
table: String,
partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = withClient {
val actualPartColNames = getTable(db, table).partitionColumnNames
val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table))
client.getPartitions(db, table, partialSpec.map(lowerCasePartitionSpec)).map { part =>
part.copy(spec = restorePartitionSpec(part.spec, actualPartColNames))
part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
}
}
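As a concrete, hypothetical illustration of the case-restoration step described in the comments above (the Hive metastore lower-cases partition column names, so names coming back from it are mapped to the table's actual column names), consider:

```scala
// Sketch: restore the table's original partition-column casing on a metastore-returned name.
// The column names and the returned fragment are made-up examples.
val actualPartCols = Seq("Year", "Month")
val lowerToActual  = actualPartCols.map(c => c.toLowerCase -> c).toMap

val fromMetastore = "year=2016/month=12" // as the metastore might return it
val restored = fromMetastore.split("/").map { kv =>
  val Array(k, v) = kv.split("=", 2)
  lowerToActual(k) + "=" + v
}.mkString("/")

println(restored) // Year=2016/Month=12
```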

@@ -954,10 +987,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configuration)
}

val partitionSchema = catalogTable.partitionSchema
val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table))

if (predicates.nonEmpty) {
val clientPrunedPartitions = client.getPartitionsByFilter(rawTable, predicates).map { part =>
part.copy(spec = restorePartitionSpec(part.spec, catalogTable.partitionColumnNames))
part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
}
val boundPredicate =
InterpretedPredicate.create(predicates.reduce(And).transform {
@@ -968,7 +1002,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configuration)
clientPrunedPartitions.filter { p => boundPredicate(p.toRow(partitionSchema)) }
} else {
client.getPartitions(catalogTable).map { part =>
part.copy(spec = restorePartitionSpec(part.spec, catalogTable.partitionColumnNames))
part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
}
}
}