Skip to content

Commit

Permalink
Feature: datediff python executor
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Jun 30, 2023
1 parent 2911bbb commit 3800158
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
19 changes: 15 additions & 4 deletions sqlglot/executor/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,19 @@ def substring(this, start=None, length=None):
@null_if_any
def cast(this, to):
if to == exp.DataType.Type.DATE:
return datetime.date.fromisoformat(this)
if to == exp.DataType.Type.DATETIME:
return datetime.datetime.fromisoformat(this)
if isinstance(this, datetime.datetime):
return this.date()
if isinstance(this, datetime.date):
return this
if isinstance(this, str):
return datetime.date.fromisoformat(this)
if to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP):
if isinstance(this, datetime.datetime):
return this
if isinstance(this, datetime.date):
return datetime.datetime(this.year, this.month, this.day)
if isinstance(this, str):
return datetime.datetime.fromisoformat(this)
if to == exp.DataType.Type.BOOLEAN:
return bool(this)
if to in exp.DataType.TEXT_TYPES:
Expand All @@ -111,7 +121,7 @@ def cast(this, to):
return float(this)
if to in exp.DataType.NUMERIC_TYPES:
return int(this)
raise NotImplementedError(f"Casting to '{to}' not implemented.")
raise NotImplementedError(f"Casting {this} to '{to}' not implemented.")


def ordered(this, desc, nulls_first):
Expand Down Expand Up @@ -153,6 +163,7 @@ def interval(this, unit):
"CONCAT": null_if_any(lambda *args: "".join(args)),
"SAFECONCAT": null_if_any(lambda *args: "".join(str(arg) for arg in args)),
"CONCATWS": null_if_any(lambda this, *args: this.join(args)),
"DATEDIFF": null_if_any(lambda this, expression, *_: (this - expression).days),
"DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
"DIV": null_if_any(lambda e, this: e / this),
"DOT": null_if_any(lambda e, this: e[this]),
Expand Down
1 change: 1 addition & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ def test_scalar_functions(self):
("TIMESTRTOTIME('2022-01-01')", datetime.datetime(2022, 1, 1)),
("LEFT('12345', 3)", "123"),
("RIGHT('12345', 3)", "345"),
("DATEDIFF('2022-01-03'::date, '2022-01-01'::TIMESTAMP::DATE)", 2),
]:
with self.subTest(sql):
result = execute(f"SELECT {sql}")
Expand Down

0 comments on commit 3800158

Please sign in to comment.