diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b77c2c0bb7..efb49bd9a9 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -138,7 +138,8 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim if isinstance(expression, exp.NumberToStr) else exp.Literal.string( format_time( - expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING) + expression.text("format"), + t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING), ) ) ) @@ -314,7 +315,9 @@ class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, "CHARINDEX": lambda args: exp.StrPosition( - this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), ), "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), @@ -365,6 +368,57 @@ class Parser(parser.Parser): CONCAT_NULL_OUTPUTS_STRING = True + def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: + """Applies to SQL Server and Azure SQL Database + COMMIT [ { TRAN | TRANSACTION } + [ transaction_name | @tran_name_variable ] ] + [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] + [ ; ] + + ROLLBACK { TRAN | TRANSACTION } + [ transaction_name | @tran_name_variable + | savepoint_name | @savepoint_variable ] + [ ; ] + """ + rollback = self._prev.token_type == TokenType.ROLLBACK + + self._match_texts({"TRAN", "TRANSACTION"}) + txn_name = self._parse_id_var() + + durability = None + if self._match_pair(TokenType.WITH, TokenType.L_PAREN): + self._match_text_seq("DELAYED_DURABILITY") + self._match(TokenType.EQ) + if self._match_text_seq("OFF"): + durability = False + else: + self._match(TokenType.ON) + durability = True + + self._match_r_paren() + + if rollback: + return self.expression(exp.Rollback, this=txn_name) + + return self.expression(exp.Commit, this=txn_name, durability=durability) + + def _parse_transaction(self) -> exp.Transaction | exp.Command: + """Applies to SQL Server and Azure SQL Database + BEGIN { TRAN | TRANSACTION } + [ { transaction_name | @tran_name_variable } + [ WITH MARK [ 'description' ] ] + ] + [ ; ] + """ + if self._match_texts(("TRAN", "TRANSACTION")): + # we have a transaction and not a BEGIN + transaction = self.expression(exp.Transaction, this=self._parse_id_var()) + if self._match_text_seq("WITH", "MARK"): + transaction.set("mark", self._parse_string()) + return transaction + + return self._parse_as_command(self._prev) + def _parse_system_time(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("FOR", "SYSTEM_TIME"): return None @@ -496,7 +550,9 @@ class Generator(generator.Generator): exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), exp.SHA2: lambda self, e: self.func( - "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this + "HASHBYTES", + exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), + e.this, ), exp.TimeToStr: _format_sql, } @@ -539,3 +595,31 @@ def returning_sql(self, expression: exp.Returning) -> str: into = self.sql(expression, "into") into = self.seg(f"INTO {into}") if into else "" return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}" + + def transaction_sql(self, expression: exp.Transaction) -> str: + this = self.sql(expression, "this") + if this in ["TRAN", "TRANSACTION"]: + mark = self.sql(expression, "mark") + mark = f" WITH MARK {mark}" if mark else "" + return f"BEGIN TRANSACTION{mark}" + else: + this = f" {this}" if this else "" + mark = self.sql(expression, "mark") + mark = f" WITH MARK {mark}" if mark else "" + return f"BEGIN TRANSACTION{this}{mark}" + + def commit_sql(self, expression: exp.Commit) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + durability = expression.args.get("durability") + durability = ( + f" WITH (DELAYED_DURABILITY = {'ON' if durability else 'OFF'})" + if durability is not None + else "" + ) + return f"COMMIT TRANSACTION{this}{durability}" + + def rollback_sql(self, expression: exp.Rollback) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + return f"ROLLBACK TRANSACTION{this}" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1efedc7b78..e4edeed1cf 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3472,15 +3472,15 @@ class Command(Expression): class Transaction(Expression): - arg_types = {"this": False, "modes": False} + arg_types = {"this": False, "modes": False, "mark": False} class Commit(Expression): - arg_types = {"chain": False} + arg_types = {"chain": False, "this": False, "durability": False} class Rollback(Expression): - arg_types = {"savepoint": False} + arg_types = {"savepoint": False, "this": False} class AlterTable(Expression): diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 508a273705..166e3b8e66 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -4351,7 +4351,7 @@ def _parse_ddl_select(self) -> t.Optional[exp.Expression]: self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False)) ) - def _parse_transaction(self) -> exp.Transaction: + def _parse_transaction(self) -> exp.Transaction | exp.Command: this = None if self._match_texts(self.TRANSACTION_KIND): this = self._prev.text diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 05738cf34e..3ca47aba2b 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1512,6 +1512,7 @@ def test_transactions(self): "redshift": "BEGIN", "snowflake": "BEGIN", "sqlite": "BEGIN TRANSACTION", + "tsql": "BEGIN TRANSACTION", }, ) self.validate_all( diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 065cdd09f9..5fa881b616 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -389,7 +389,58 @@ def test_types_bin(self): }, ) + def test_ddl(self): + self.validate_all( + "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", + write={ + "tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIMESTAMP(4), d FLOAT(24))" + }, + ) + + def test_transaction(self): + # BEGIN { TRAN | TRANSACTION } + # [ { transaction_name | @tran_name_variable } + # [ WITH MARK [ 'description' ] ] + # ] + # [ ; ] + self.validate_identity("BEGIN TRANSACTION") + self.validate_all("BEGIN TRAN", write={"tsql": "BEGIN TRANSACTION"}) + self.validate_identity("BEGIN TRANSACTION transaction_name") + self.validate_identity("BEGIN TRANSACTION @tran_name_variable") + self.validate_identity("BEGIN TRANSACTION transaction_name WITH MARK 'description'") + + def test_commit(self): + # COMMIT [ { TRAN | TRANSACTION } [ transaction_name | @tran_name_variable ] ] [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] [ ; ] + + self.validate_all("COMMIT", write={"tsql": "COMMIT TRANSACTION"}) + self.validate_all("COMMIT TRAN", write={"tsql": "COMMIT TRANSACTION"}) + self.validate_identity("COMMIT TRANSACTION") + self.validate_identity("COMMIT TRANSACTION transaction_name") + self.validate_identity("COMMIT TRANSACTION @tran_name_variable") + + self.validate_identity( + "COMMIT TRANSACTION @tran_name_variable WITH (DELAYED_DURABILITY = ON)" + ) + self.validate_identity( + "COMMIT TRANSACTION transaction_name WITH (DELAYED_DURABILITY = OFF)" + ) + + def test_rollback(self): + # Applies to SQL Server and Azure SQL Database + # ROLLBACK { TRAN | TRANSACTION } + # [ transaction_name | @tran_name_variable + # | savepoint_name | @savepoint_variable ] + # [ ; ] + self.validate_all("ROLLBACK", write={"tsql": "ROLLBACK TRANSACTION"}) + self.validate_all("ROLLBACK TRAN", write={"tsql": "ROLLBACK TRANSACTION"}) + self.validate_identity("ROLLBACK TRANSACTION") + self.validate_identity("ROLLBACK TRANSACTION transaction_name") + self.validate_identity("ROLLBACK TRANSACTION @tran_name_variable") + def test_udf(self): + self.validate_identity( + "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)" + ) self.validate_identity( "CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar" ) @@ -446,6 +497,12 @@ def test_udf(self): pretty=True, ) + def test_procedure_keywords(self): + self.validate_identity("BEGIN") + self.validate_identity("END") + self.validate_identity("SET XACT_ABORT ON") + + def test_fullproc(self): sql = """ CREATE procedure [TRANSF].[SP_Merge_Sales_Real] @Loadid INTEGER @@ -838,7 +895,8 @@ def test_format(self): write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"}, ) self.validate_all( - "SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"} + "SELECT FORMAT(1234567, 'f')", + write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}, ) self.validate_all( "SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')", @@ -853,7 +911,8 @@ def test_format(self): write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"}, ) self.validate_all( - "SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"} + "SELECT FORMAT(num_col, 'c')", + write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}, ) def test_string(self):