Skip to content

Commit

Permalink
[SPARK-29819][SQL] Introduce an enum for interval units
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
In the PR, I propose an enumeration for interval units with the value `YEAR`, `MONTH`, `WEEK`, `DAY`, `HOUR`, `MINUTE`, `SECOND`, `MILLISECOND`, `MICROSECOND` and `NANOSECOND`.

### Why are the changes needed?
- This should prevent typos in interval unit names
- Stronger type checking of unit parameters.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
By existing test suites `ExpressionParserSuite` and `IntervalUtilsSuite`

Closes #26455 from MaxGekk/interval-unit-enum.

Authored-by: Maxim Gekk <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
MaxGekk authored and dongjoon-hyun committed Nov 10, 2019
1 parent 57b954e commit 7ddcb5b
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit
import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -2061,7 +2062,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
val u = unit.getText.toLowerCase(Locale.ROOT)
// Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/...
if (u.endsWith("s")) u.substring(0, u.length - 1) else u
}.toArray
}.map(IntervalUtils.IntervalUnit.withName).toArray

val values = ctx.intervalValue().asScala.map { value =>
if (value.STRING() != null) {
Expand Down Expand Up @@ -2097,17 +2098,17 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case ("year", "month") =>
IntervalUtils.fromYearMonthString(value)
case ("day", "hour") =>
IntervalUtils.fromDayTimeString(value, "day", "hour")
IntervalUtils.fromDayTimeString(value, IntervalUnit.DAY, IntervalUnit.HOUR)
case ("day", "minute") =>
IntervalUtils.fromDayTimeString(value, "day", "minute")
IntervalUtils.fromDayTimeString(value, IntervalUnit.DAY, IntervalUnit.MINUTE)
case ("day", "second") =>
IntervalUtils.fromDayTimeString(value, "day", "second")
IntervalUtils.fromDayTimeString(value, IntervalUnit.DAY, IntervalUnit.SECOND)
case ("hour", "minute") =>
IntervalUtils.fromDayTimeString(value, "hour", "minute")
IntervalUtils.fromDayTimeString(value, IntervalUnit.HOUR, IntervalUnit.MINUTE)
case ("hour", "second") =>
IntervalUtils.fromDayTimeString(value, "hour", "second")
IntervalUtils.fromDayTimeString(value, IntervalUnit.HOUR, IntervalUnit.SECOND)
case ("minute", "second") =>
IntervalUtils.fromDayTimeString(value, "minute", "second")
IntervalUtils.fromDayTimeString(value, IntervalUnit.MINUTE, IntervalUnit.SECOND)
case _ =>
throw new ParseException(s"Intervals FROM $from TO $to are not supported.", ctx)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

object IntervalUtils {

object IntervalUnit extends Enumeration {
type IntervalUnit = Value

val NANOSECOND = Value(0, "nanosecond")
val MICROSECOND = Value(1, "microsecond")
val MILLISECOND = Value(2, "millisecond")
val SECOND = Value(3, "second")
val MINUTE = Value(4, "minute")
val HOUR = Value(5, "hour")
val DAY = Value(6, "day")
val WEEK = Value(7, "week")
val MONTH = Value(8, "month")
val YEAR = Value(9, "year")
}
import IntervalUnit._

def getYears(interval: CalendarInterval): Int = {
interval.months / MONTHS_PER_YEAR
}
Expand Down Expand Up @@ -114,7 +130,7 @@ object IntervalUtils {
}

private def toLongWithRange(
fieldName: String,
fieldName: IntervalUnit,
s: String,
minValue: Long,
maxValue: Long): Long = {
Expand All @@ -136,8 +152,8 @@ object IntervalUtils {
require(input != null, "Interval year-month string must be not null")
def toInterval(yearStr: String, monthStr: String): CalendarInterval = {
try {
val years = toLongWithRange("year", yearStr, 0, Integer.MAX_VALUE).toInt
val months = toLongWithRange("month", monthStr, 0, 11).toInt
val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE).toInt
val months = toLongWithRange(MONTH, monthStr, 0, 11).toInt
val totalMonths = Math.addExact(Math.multiplyExact(years, 12), months)
new CalendarInterval(totalMonths, 0, 0)
} catch {
Expand All @@ -164,7 +180,7 @@ object IntervalUtils {
* adapted from HiveIntervalDayTime.valueOf
*/
def fromDayTimeString(s: String): CalendarInterval = {
fromDayTimeString(s, "day", "second")
fromDayTimeString(s, DAY, SECOND)
}

private val dayTimePattern =
Expand All @@ -179,7 +195,7 @@ object IntervalUtils {
* - HOUR TO (MINUTE|SECOND)
* - MINUTE TO SECOND
*/
def fromDayTimeString(input: String, from: String, to: String): CalendarInterval = {
def fromDayTimeString(input: String, from: IntervalUnit, to: IntervalUnit): CalendarInterval = {
require(input != null, "Interval day-time string must be not null")
assert(input.length == input.trim.length)
val m = dayTimePattern.pattern.matcher(input)
Expand All @@ -190,33 +206,33 @@ object IntervalUtils {
val days = if (m.group(2) == null) {
0
} else {
toLongWithRange("day", m.group(3), 0, Integer.MAX_VALUE).toInt
toLongWithRange(DAY, m.group(3), 0, Integer.MAX_VALUE).toInt
}
var hours: Long = 0L
var minutes: Long = 0L
var seconds: Long = 0L
if (m.group(5) != null || from == "minute") { // 'HH:mm:ss' or 'mm:ss minute'
hours = toLongWithRange("hour", m.group(5), 0, 23)
minutes = toLongWithRange("minute", m.group(6), 0, 59)
seconds = toLongWithRange("second", m.group(7), 0, 59)
if (m.group(5) != null || from == MINUTE) { // 'HH:mm:ss' or 'mm:ss minute'
hours = toLongWithRange(HOUR, m.group(5), 0, 23)
minutes = toLongWithRange(MINUTE, m.group(6), 0, 59)
seconds = toLongWithRange(SECOND, m.group(7), 0, 59)
} else if (m.group(8) != null) { // 'mm:ss.nn'
minutes = toLongWithRange("minute", m.group(6), 0, 59)
seconds = toLongWithRange("second", m.group(7), 0, 59)
minutes = toLongWithRange(MINUTE, m.group(6), 0, 59)
seconds = toLongWithRange(SECOND, m.group(7), 0, 59)
} else { // 'HH:mm'
hours = toLongWithRange("hour", m.group(6), 0, 23)
minutes = toLongWithRange("second", m.group(7), 0, 59)
hours = toLongWithRange(HOUR, m.group(6), 0, 23)
minutes = toLongWithRange(SECOND, m.group(7), 0, 59)
}
// Hive allow nanosecond precision interval
var secondsFraction = parseNanos(m.group(9), seconds < 0)
to match {
case "hour" =>
case HOUR =>
minutes = 0
seconds = 0
secondsFraction = 0
case "minute" =>
case MINUTE =>
seconds = 0
secondsFraction = 0
case "second" =>
case SECOND =>
// No-op
case _ =>
throw new IllegalArgumentException(
Expand All @@ -234,7 +250,7 @@ object IntervalUtils {
}
}

def fromUnitStrings(units: Array[String], values: Array[String]): CalendarInterval = {
def fromUnitStrings(units: Array[IntervalUnit], values: Array[String]): CalendarInterval = {
assert(units.length == values.length)
var months: Int = 0
var days: Int = 0
Expand All @@ -243,26 +259,26 @@ object IntervalUtils {
while (i < units.length) {
try {
units(i) match {
case "year" =>
case YEAR =>
months = Math.addExact(months, Math.multiplyExact(values(i).toInt, 12))
case "month" =>
case MONTH =>
months = Math.addExact(months, values(i).toInt)
case "week" =>
case WEEK =>
days = Math.addExact(days, Math.multiplyExact(values(i).toInt, 7))
case "day" =>
case DAY =>
days = Math.addExact(days, values(i).toInt)
case "hour" =>
case HOUR =>
val hoursUs = Math.multiplyExact(values(i).toLong, MICROS_PER_HOUR)
microseconds = Math.addExact(microseconds, hoursUs)
case "minute" =>
case MINUTE =>
val minutesUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MINUTE)
microseconds = Math.addExact(microseconds, minutesUs)
case "second" =>
case SECOND =>
microseconds = Math.addExact(microseconds, parseSecondNano(values(i)))
case "millisecond" =>
case MILLISECOND =>
val millisUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MILLIS)
microseconds = Math.addExact(microseconds, millisUs)
case "microsecond" =>
case MICROSECOND =>
microseconds = Math.addExact(microseconds, values(i).toLong)
}
} catch {
Expand All @@ -281,7 +297,7 @@ object IntervalUtils {
val alignedStr = if (nanosStr.length < maxNanosLen) {
(nanosStr + "000000000").substring(0, maxNanosLen)
} else nanosStr
val nanos = toLongWithRange("nanosecond", alignedStr, 0L, 999999999L)
val nanos = toLongWithRange(NANOSECOND, alignedStr, 0L, 999999999L)
val micros = nanos / NANOS_PER_MICROS
if (isNegative) -micros else micros
} else {
Expand All @@ -295,7 +311,7 @@ object IntervalUtils {
private def parseSecondNano(secondNano: String): Long = {
def parseSeconds(secondsStr: String): Long = {
toLongWithRange(
"second",
SECOND,
secondsStr,
Long.MinValue / MICROS_PER_SECOND,
Long.MaxValue / MICROS_PER_SECOND) * MICROS_PER_SECOND
Expand Down Expand Up @@ -419,15 +435,18 @@ object IntervalUtils {
END_UNIT_NAME = Value
}
private final val intervalStr = UTF8String.fromString("interval ")
private final val yearStr = UTF8String.fromString("year")
private final val monthStr = UTF8String.fromString("month")
private final val weekStr = UTF8String.fromString("week")
private final val dayStr = UTF8String.fromString("day")
private final val hourStr = UTF8String.fromString("hour")
private final val minuteStr = UTF8String.fromString("minute")
private final val secondStr = UTF8String.fromString("second")
private final val millisStr = UTF8String.fromString("millisecond")
private final val microsStr = UTF8String.fromString("microsecond")
private def unitToUtf8(unit: IntervalUnit): UTF8String = {
UTF8String.fromString(unit.toString)
}
private final val yearStr = unitToUtf8(YEAR)
private final val monthStr = unitToUtf8(MONTH)
private final val weekStr = unitToUtf8(WEEK)
private final val dayStr = unitToUtf8(DAY)
private final val hourStr = unitToUtf8(HOUR)
private final val minuteStr = unitToUtf8(MINUTE)
private final val secondStr = unitToUtf8(SECOND)
private final val millisStr = unitToUtf8(MILLISECOND)
private final val microsStr = unitToUtf8(MICROSECOND)

def stringToInterval(input: UTF8String): CalendarInterval = {
import ParseState._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
Expand Down Expand Up @@ -587,17 +588,17 @@ class ExpressionParserSuite extends AnalysisTest {
}

val intervalUnits = Seq(
"year",
"month",
"week",
"day",
"hour",
"minute",
"second",
"millisecond",
"microsecond")

def intervalLiteral(u: String, s: String): Literal = {
YEAR,
MONTH,
WEEK,
DAY,
HOUR,
MINUTE,
SECOND,
MILLISECOND,
MICROSECOND)

def intervalLiteral(u: IntervalUnit, s: String): Literal = {
Literal(IntervalUtils.fromUnitStrings(Array(u), Array(s)))
}

Expand Down Expand Up @@ -628,7 +629,7 @@ class ExpressionParserSuite extends AnalysisTest {
}

// Hive nanosecond notation.
checkIntervals("13.123456789 seconds", intervalLiteral("second", "13.123456789"))
checkIntervals("13.123456789 seconds", intervalLiteral(SECOND, "13.123456789"))
checkIntervals(
"-13.123456789 second",
Literal(new CalendarInterval(
Expand Down Expand Up @@ -699,7 +700,7 @@ class ExpressionParserSuite extends AnalysisTest {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
val aliases = defaultParser.parseExpression(intervalValue).collect {
case a @ Alias(_: Literal, name)
if intervalUnits.exists { unit => name.startsWith(unit) } => a
if intervalUnits.exists { unit => name.startsWith(unit.toString) } => a
}
assert(aliases.size === 1)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class IntervalUtilsSuite extends SparkFunSuite {
Expand Down Expand Up @@ -161,7 +162,7 @@ class IntervalUtilsSuite extends SparkFunSuite {
}

try {
fromDayTimeString("5 1:12:20", "hour", "microsecond")
fromDayTimeString("5 1:12:20", HOUR, MICROSECOND)
fail("Expected to throw an exception for the invalid convention type")
} catch {
case e: IllegalArgumentException =>
Expand Down

0 comments on commit 7ddcb5b

Please sign in to comment.