Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Support SQL INTERVAL #3146

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/daft-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ daft-plan = {path = "../daft-plan"}
once_cell = {workspace = true}
pyo3 = {workspace = true, optional = true}
sqlparser = {workspace = true}
regex.workspace = true
snafu.workspace = true

[dev-dependencies]
Expand Down
173 changes: 169 additions & 4 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ use daft_functions::numeric::{ceil::ceil, floor::floor};
use daft_plan::{LogicalPlanBuilder, LogicalPlanRef};
use sqlparser::{
ast::{
ArrayElemTypeDef, BinaryOperator, CastKind, Distinct, ExactNumberInfo, ExcludeSelectItem,
GroupByExpr, Ident, Query, SelectItem, Statement, StructField, Subscript, TableAlias,
TableWithJoins, TimezoneInfo, UnaryOperator, Value, WildcardAdditionalOptions,
ArrayElemTypeDef, BinaryOperator, CastKind, DateTimeField, Distinct, ExactNumberInfo,
ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, Statement, StructField,
Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value,
WildcardAdditionalOptions,
},
dialect::GenericDialect,
parser::{Parser, ParserOptions},
Expand Down Expand Up @@ -913,7 +914,171 @@ impl SQLPlanner {
SQLExpr::Map(_) => unsupported_sql_err!("MAP"),
SQLExpr::Subscript { expr, subscript } => self.plan_subscript(expr, subscript.as_ref()),
SQLExpr::Array(_) => unsupported_sql_err!("ARRAY"),
SQLExpr::Interval(_) => unsupported_sql_err!("INTERVAL"),
SQLExpr::Interval(interval) => {
use regex::Regex;

/// One parsed component of a SQL interval literal: a signed count paired
/// with its time unit (e.g. `-3` months). Produced by `parse_interval_string`
/// and consumed by `interval_parts_to_values`.
#[derive(Debug)]
struct IntervalPart {
    // Signed magnitude; negative counts (e.g. "-3 days") are accepted.
    count: i64,
    // Unit as parsed from the literal (Year, Month, Week, Day, ... Nanosecond).
    unit: DateTimeField,
}

// Local function to parse a multi-unit interval string (e.g. "1 year 2 months")
// into its component parts. Matching is case-insensitive on unit names and
// lenient about separators: every `<count> <unit>` pair found in the string is
// collected, in order of appearance. Returns an error if no pair is found.
fn parse_interval_string(expr: &str) -> Result<Vec<IntervalPart>, PlannerError> {
    // Strip surrounding whitespace and the single quotes of a SQL string literal.
    let expr = expr.trim().trim_matches('\'');

    let re = Regex::new(r"(-?\d+)\s*(year|years|month|months|day|days|hour|hours|minute|minutes|second|seconds|millisecond|milliseconds|microsecond|microseconds|nanosecond|nanoseconds|week|weeks)")
        .map_err(|e|PlannerError::invalid_operation(format!("Invalid regex pattern: {}", e)))?;

    let mut parts = Vec::new();

    for cap in re.captures_iter(expr) {
        // Capture group 1: the (optionally negative) count.
        let count: i64 = cap[1].parse().map_err(|e| {
            PlannerError::invalid_operation(format!("Invalid interval count: {e}"))
        })?;

        // Capture group 2: the unit name; singular and plural are both accepted.
        let unit = match &cap[2].to_lowercase()[..] {
            "year" | "years" => DateTimeField::Year,
            "month" | "months" => DateTimeField::Month,
            "week" | "weeks" => DateTimeField::Week(None),
            "day" | "days" => DateTimeField::Day,
            "hour" | "hours" => DateTimeField::Hour,
            "minute" | "minutes" => DateTimeField::Minute,
            "second" | "seconds" => DateTimeField::Second,
            "millisecond" | "milliseconds" => DateTimeField::Millisecond,
            "microsecond" | "microseconds" => DateTimeField::Microsecond,
            "nanosecond" | "nanoseconds" => DateTimeField::Nanosecond,
            _ => {
                // Unreachable in practice: the regex only matches the unit
                // names handled above. Kept as a defensive error.
                return Err(PlannerError::invalid_operation(format!(
                    "Invalid interval unit: {}",
                    &cap[2]
                )))
            }
        };

        parts.push(IntervalPart { count, unit });
    }

    // No `<count> <unit>` pair found at all — reject rather than silently
    // producing a zero-length interval.
    if parts.is_empty() {
        return Err(PlannerError::invalid_operation("Invalid interval format."));
    }

    Ok(parts)
}

// Local function to convert parts to interval values
fn interval_parts_to_values(parts: Vec<IntervalPart>) -> (i64, i64, i64) {
let mut total_months = 0i64;
let mut total_days = 0i64;
let mut total_nanos = 0i64;

for part in parts {
match part.unit {
DateTimeField::Year => total_months += 12 * part.count,
DateTimeField::Month => total_months += part.count,
DateTimeField::Week(_) => total_days += 7 * part.count,
DateTimeField::Day => total_days += part.count,
DateTimeField::Hour => total_nanos += part.count * 3_600_000_000_000,
DateTimeField::Minute => total_nanos += part.count * 60_000_000_000,
DateTimeField::Second => total_nanos += part.count * 1_000_000_000,
DateTimeField::Millisecond | DateTimeField::Milliseconds => {
total_nanos += part.count * 1_000_000;
}
DateTimeField::Microsecond | DateTimeField::Microseconds => {
total_nanos += part.count * 1_000;
}
DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
total_nanos += part.count;
}
_ => {}
}
}

(total_months, total_days, total_nanos)
}

match interval {
// If leading_field is specified, treat it as the old style single-unit interval
// e.g., INTERVAL '12' YEAR
sqlparser::ast::Interval {
value,
leading_field: Some(time_unit),
..
} => {
let expr = self.plan_expr(value)?;

let expr =
expr.as_literal()
.and_then(|lit| lit.as_str())
.ok_or_else(|| {
PlannerError::invalid_operation(
"Interval value must be a string",
)
})?;

let count = expr.parse::<i64>().map_err(|e| {
PlannerError::unsupported_sql(format!("Invalid interval count: {e}"))
})?;

let (months, days, nanoseconds) = match time_unit {
DateTimeField::Year => (12 * count, 0, 0),
DateTimeField::Month => (count, 0, 0),
DateTimeField::Week(_) => (0, 7 * count, 0),
DateTimeField::Day => (0, count, 0),
DateTimeField::Hour => (0, 0, count * 3_600_000_000_000),
DateTimeField::Minute => (0, 0, count * 60_000_000_000),
DateTimeField::Second => (0, 0, count * 1_000_000_000),
DateTimeField::Microsecond | DateTimeField::Microseconds => (0, 0, count * 1_000),
DateTimeField::Millisecond | DateTimeField::Milliseconds => (0, 0, count * 1_000_000),
DateTimeField::Nanosecond | DateTimeField::Nanoseconds => (0, 0, count),
_ => return Err(PlannerError::invalid_operation(format!(
"Invalid interval unit: {time_unit}. Expected one of: year, month, week, day, hour, minute, second, millisecond, microsecond, nanosecond"
))),
};

Ok(Arc::new(Expr::Literal(LiteralValue::Interval(
daft_core::datatypes::IntervalValue::new(
months as i32,
days as i32,
nanoseconds,
),
))))
}

// If no leading_field is specified, treat it as the new style multi-unit interval
// e.g., INTERVAL '12 years 3 months 7 days'
sqlparser::ast::Interval {
value,
leading_field: None,
..
} => {
let expr = self.plan_expr(value)?;

let expr =
expr.as_literal()
.and_then(|lit| lit.as_str())
.ok_or_else(|| {
PlannerError::invalid_operation(
"Interval value must be a string",
)
})?;

let parts = parse_interval_string(expr)
.map_err(|e| PlannerError::invalid_operation(e.to_string()))?;

let (months, days, nanoseconds) = interval_parts_to_values(parts);

Ok(Arc::new(Expr::Literal(LiteralValue::Interval(
daft_core::datatypes::IntervalValue::new(
months as i32,
days as i32,
nanoseconds,
),
))))
}
}
}
SQLExpr::MatchAgainst { .. } => unsupported_sql_err!("MATCH AGAINST"),
SQLExpr::Wildcard => unsupported_sql_err!("WILDCARD"),
SQLExpr::QualifiedWildcard(_) => unsupported_sql_err!("QUALIFIED WILDCARD"),
Expand Down
80 changes: 79 additions & 1 deletion tests/sql/test_exprs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import datetime

import pytest

import daft
from daft import col
from daft import col, interval
from daft.sql.sql import SQLCatalog


def test_nested():
Expand Down Expand Up @@ -135,3 +138,78 @@ def test_is_in_edge_cases():
# Test with mixed types in the IN list
with pytest.raises(Exception, match="All literals must have the same data type"):
daft.sql("SELECT * FROM df WHERE nums IN (1, '2', 3.0)").collect().to_pydict()


@pytest.mark.parametrize(
"date_values, ts_values, expected_intervals",
[
(
["2022-01-01", "2020-02-29", "2029-05-15"],
["2022-01-01 10:00:00", "2020-02-29 23:59:59", "2029-05-15 12:34:56"],
{
"date_add_day": [
datetime.date(2022, 1, 2),
datetime.date(2020, 3, 1),
datetime.date(2029, 5, 16),
],
"date_sub_month": [
datetime.date(2021, 12, 1),
datetime.date(2020, 1, 31),
datetime.date(2029, 4, 14),
],
"ts_sub_year": [
datetime.datetime(2021, 1, 1, 10),
datetime.datetime(2019, 2, 28, 23, 59, 59),
datetime.datetime(2028, 5, 15, 12, 34, 56),
],
"ts_add_hour": [
datetime.datetime(2022, 1, 1, 11, 0, 0),
datetime.datetime(2020, 3, 1, 0, 59, 59),
datetime.datetime(2029, 5, 15, 13, 34, 56),
],
"ts_sub_minute": [
datetime.datetime(2022, 1, 1, 9, 57, 21),
datetime.datetime(2020, 2, 29, 23, 57, 20),
datetime.datetime(2029, 5, 15, 12, 32, 17),
],
},
),
],
)
def test_interval_comparison(date_values, ts_values, expected_intervals):
    """SQL INTERVAL arithmetic must match the DataFrame ``interval()`` expression API.

    Builds a frame with a date column and a timestamp column, applies equivalent
    interval add/subtract operations through both the expression API and SQL
    (covering single-unit ``INTERVAL '1' day`` and multi-unit
    ``INTERVAL '1 minutes 99 second'`` forms), and checks both results against
    the precomputed expected values from the parametrize fixture.
    """
    # Create DataFrame with date and timestamp columns
    df = daft.from_pydict({"date": date_values, "ts": ts_values}).select(
        col("date").cast(daft.DataType.date()), col("ts").str.to_datetime("%Y-%m-%d %H:%M:%S")
    )
    catalog = SQLCatalog({"test": df})

    # Reference result computed via the expression API.
    expected_df = (
        df.select(
            (col("date") + interval(days=1)).alias("date_add_day"),
            (col("date") - interval(months=1)).alias("date_sub_month"),
            (col("ts") - interval(years=1, days=0)).alias("ts_sub_year"),
            (col("ts") + interval(hours=1)).alias("ts_add_hour"),
            (col("ts") - interval(minutes=1, seconds=99)).alias("ts_sub_minute"),
        )
        .collect()
        .to_pydict()
    )

    # Same operations expressed in SQL; mixes singular/plural unit spellings
    # deliberately to exercise the lenient interval parser.
    actual_sql = (
        daft.sql(
            """
        SELECT
            date + INTERVAL '1' day AS date_add_day,
            date - INTERVAL '1 months' AS date_sub_month,
            ts - INTERVAL '1 year 0 days' AS ts_sub_year,
            ts + INTERVAL '1' hour AS ts_add_hour,
            ts - INTERVAL '1 minutes 99 second' AS ts_sub_minute
        FROM test
        """,
            catalog=catalog,
        )
        .collect()
        .to_pydict()
    )

    assert expected_df == actual_sql == expected_intervals
Loading