Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat(tsql): improve support for transaction statements #1907

Merged
merged 20 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
c5ab2ff
Modify tsql BEGIN TRANSACTION to be BEGIN TRANSACTION
dmoore247 Jul 6, 2023
c6ea5ce
enhance tsql transaction mapping and stored proc begin... interim commit
dmoore247 Jul 6, 2023
474ea39
Make tsql BEGIN TRANSACTION parse in all it's flavors
dmoore247 Jul 7, 2023
dae6818
Enrich TSQL begin transaction, commit transaction, rollback transacti…
dmoore247 Jul 10, 2023
95d14b1
undo fat fingering of _parse_convert
dmoore247 Jul 11, 2023
c9e6ad2
deal with merge conflict on return_sql
dmoore247 Jul 11, 2023
6ba0901
re-write BEGIN TRANSACTION
dmoore247 Jul 11, 2023
a39ef65
Merge branch 'main' into transactions
dmoore247 Jul 11, 2023
a097ea7
remove extraineous CONVERT unit test
dmoore247 Jul 11, 2023
99b977b
Merge branch 'transactions' of https://github.com/dmoore247/sqlglot i…
dmoore247 Jul 11, 2023
ca3d6d2
run make check
dmoore247 Jul 11, 2023
9856cee
PR review comments update, still work in progress
dmoore247 Jul 11, 2023
cb07d9c
Remove BEGIN TRANSACTION keyword
dmoore247 Jul 11, 2023
4b929a0
Remove transaction parameter
dmoore247 Jul 11, 2023
064125a
Remove modes from Commit interface
dmoore247 Jul 11, 2023
bcc5263
move CREATE PROCEDURE keyword tests, remove TokenType.BEGIN from tsql…
dmoore247 Jul 11, 2023
3fcd22a
Remove redundant tokens in tsql KEYWORDS, remove description from met…
dmoore247 Jul 11, 2023
9579470
tighten up code for _parse_transaction per feedback
dmoore247 Jul 11, 2023
e92e7eb
Revert to matching on TokenType.BEGIN
dmoore247 Jul 11, 2023
19f9af9
Handle BEGIN as a command
dmoore247 Jul 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 111 additions & 3 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
if isinstance(expression, exp.NumberToStr)
else exp.Literal.string(
format_time(
expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING)
expression.text("format"),
t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING),
)
)
)
Expand Down Expand Up @@ -281,8 +282,16 @@ class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
HEX_STRINGS = [("0x", ""), ("0X", "")]

COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON}
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved

KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"COMMIT": TokenType.COMMIT,
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
"ROLLBACK": TokenType.ROLLBACK,
# "END": TokenType.END,
"SET": TokenType.SET,
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"DECLARE": TokenType.COMMAND,
Expand All @@ -304,6 +313,7 @@ class Tokenizer(tokens.Tokenizer):
"XML": TokenType.XML,
"OUTPUT": TokenType.RETURNING,
"SYSTEM_USER": TokenType.CURRENT_USER,
"WITH": TokenType.WITH,
}

# TSQL allows @, # to appear as a variable/identifier prefix
Expand All @@ -314,7 +324,9 @@ class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"CHARINDEX": lambda args: exp.StrPosition(
this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
Expand Down Expand Up @@ -358,13 +370,75 @@ class Parser(parser.Parser):
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
TokenType.END: lambda self: self._parse_command(),
TokenType.ALIAS: lambda self: self._parse_alias(),
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
}

LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True

CONCAT_NULL_OUTPUTS_STRING = True

def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback:
"""Applies to SQL Server and Azure SQL Database
COMMIT [ { TRAN | TRANSACTION }
[ transaction_name | @tran_name_variable ] ]
[ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ]
[ ; ]

ROLLBACK { TRAN | TRANSACTION }
[ transaction_name | @tran_name_variable
| savepoint_name | @savepoint_variable ]
[ ; ]

Returns:
exp.Commit | exp.Rollback: _description_
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
"""
rollback = False
if self._prev.token_type == TokenType.ROLLBACK:
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
rollback = True

transaction = None
if self._match_texts({"TRAN", "TRANSACTION"}):
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
transaction = self._prev.text
txn_name = self._parse_id_var()

durability = None
if self._match_text_seq("WITH", "(", "DELAYED_DURABILITY", "="):
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
if self._match_text_seq("OFF"):
durability = False
else:
self._match(TokenType.ON)
durability = True

self._match_r_paren()

if rollback:
return self.expression(exp.Rollback, this=txn_name, transaction=transaction)

return self.expression(
exp.Commit, this=txn_name, durability=durability, transaction=transaction
)

def _parse_transaction(self) -> exp.Transaction:
"""Syntax:
BEGIN { TRAN | TRANSACTION }
[ { transaction_name | @tran_name_variable }
[ WITH MARK [ 'description' ] ]
]
[ ; ]

Returns:
exp.Transaction: _description_
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
"""
mark = None

txn_name = self._parse_id_var()

if self._match_text_seq("WITH", "MARK"):
mark = self._parse_string()

return self.expression(exp.Transaction, this=txn_name, mark=mark)
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved

def _parse_system_time(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("FOR", "SYSTEM_TIME"):
return None
Expand Down Expand Up @@ -496,7 +570,9 @@ class Generator(generator.Generator):
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
"HASHBYTES",
exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
e.this,
),
exp.TimeToStr: _format_sql,
}
Expand Down Expand Up @@ -539,3 +615,35 @@ def returning_sql(self, expression: exp.Returning) -> str:
into = self.sql(expression, "into")
into = self.seg(f"INTO {into}") if into else ""
return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}"

def transaction_sql(self, expression: exp.Transaction) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
mark = expression.args.get("mark")
mark = f" WITH MARK {mark}" if mark else ""
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
return f"BEGIN TRANSACTION{this}{mark}"

def _durability_sql(self, expression) -> str:
durability = expression.args.get("durability")
durability_sql = ""
if durability is not None:
if durability:
durability = "ON"
else:
durability = "OFF"
durability_sql = f" WITH (DELAYED_DURABILITY = {durability})"
return durability_sql
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved

def commit_sql(self, expression: exp.Commit) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
transaction = expression.args.get("transaction")
transaction = f" {transaction}" if transaction else ""
return f"COMMIT{transaction}{this}{self._durability_sql(expression)}"

def rollback_sql(self, expression: exp.Rollback) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
transaction = expression.args.get("transaction")
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
transaction = f" {transaction}" if transaction else ""
return f"ROLLBACK{transaction}{this}"
14 changes: 10 additions & 4 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3472,15 +3472,21 @@ class Command(Expression):


class Transaction(Expression):
arg_types = {"this": False, "modes": False}
arg_types = {"this": False, "modes": False, "mark": False}


class Commit(Expression):
arg_types = {"chain": False}
arg_types = {
"chain": False,
"this": False,
"transaction": False,
"durability": False,
"modes": False,
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
}


class Rollback(Expression):
arg_types = {"savepoint": False}
arg_types = {"savepoint": False, "this": False, "transaction": False}


class AlterTable(Expression):
Expand Down Expand Up @@ -3960,7 +3966,7 @@ def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:


class Cast(Func):
arg_types = {"this": True, "to": True, "format": False}
arg_types = {"this": True, "to": False, "format": False}
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved

@property
def name(self) -> str:
Expand Down
3 changes: 2 additions & 1 deletion sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,8 @@ def extend_props(temp_props: t.Optional[exp.Properties]) -> None:
extend_props(self._parse_properties())

self._match(TokenType.ALIAS)
begin = self._match(TokenType.BEGIN)
begin = self._match_text_seq("BEGIN")
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
# begin = self._match(TokenType.BEGIN)
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
return_ = self._match_text_seq("RETURN")
expression = self._parse_statement()

Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1512,6 +1512,7 @@ def test_transactions(self):
"redshift": "BEGIN",
"snowflake": "BEGIN",
"sqlite": "BEGIN TRANSACTION",
"tsql": "BEGIN TRANSACTION",
},
)
self.validate_all(
Expand Down
62 changes: 60 additions & 2 deletions tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,62 @@ def test_types_bin(self):
},
)

def test_ddl(self):
sql = """CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24));"""
self.validate_all(
" ".join(sql.split()),
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
write={
"tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIMESTAMP(4), d FLOAT(24))"
},
)

def test_transaction(self):
# BEGIN { TRAN | TRANSACTION }
# [ { transaction_name | @tran_name_variable }
# [ WITH MARK [ 'description' ] ]
# ]
# [ ; ]
self.validate_identity("BEGIN TRANSACTION")
# self.validate_identity("BEGIN TRAN")
self.validate_identity("BEGIN TRANSACTION transaction_name")
self.validate_identity("BEGIN TRANSACTION @tran_name_variable")
self.validate_identity("BEGIN TRANSACTION transaction_name WITH MARK 'description'")

def test_commit(self):
# COMMIT [ { TRAN | TRANSACTION } [ transaction_name | @tran_name_variable ] ] [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] [ ; ]

self.validate_identity("COMMIT")
self.validate_all("COMMIT TRAN")
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
self.validate_identity("COMMIT TRANSACTION")
self.validate_identity("COMMIT TRANSACTION transaction_name")
self.validate_identity("COMMIT TRANSACTION @tran_name_variable")

self.validate_identity(
"COMMIT TRANSACTION @tran_name_variable WITH (DELAYED_DURABILITY = ON)"
)
self.validate_identity(
"COMMIT TRANSACTION transaction_name WITH (DELAYED_DURABILITY = OFF)"
)

def test_rollback(self):
# Applies to SQL Server and Azure SQL Database
# ROLLBACK { TRAN | TRANSACTION }
# [ transaction_name | @tran_name_variable
# | savepoint_name | @savepoint_variable ]
# [ ; ]
self.validate_identity("ROLLBACK")
self.validate_all("ROLLBACK TRAN")
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
self.validate_identity("ROLLBACK TRANSACTION")
self.validate_identity("ROLLBACK TRANSACTION transaction_name")
self.validate_identity("ROLLBACK TRANSACTION @tran_name_variable")

def test_udf(self):
self.validate_identity("BEGIN")
self.validate_identity("END")
self.validate_identity("SET XACT_ABORT ON")
self.validate_identity(
"DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)"
)
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
self.validate_identity(
"CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar"
)
Expand Down Expand Up @@ -446,6 +501,7 @@ def test_udf(self):
pretty=True,
)

def test_fullproc(self):
dmoore247 marked this conversation as resolved.
Show resolved Hide resolved
sql = """
CREATE procedure [TRANSF].[SP_Merge_Sales_Real]
@Loadid INTEGER
Expand Down Expand Up @@ -838,7 +894,8 @@ def test_format(self):
write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"},
)
self.validate_all(
"SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}
"SELECT FORMAT(1234567, 'f')",
write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"},
)
self.validate_all(
"SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')",
Expand All @@ -853,7 +910,8 @@ def test_format(self):
write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"},
)
self.validate_all(
"SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}
"SELECT FORMAT(num_col, 'c')",
write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"},
)

def test_string(self):
Expand Down