From c5ab2ff0b59e5ab4adda7e1c0df573c25016bc65 Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Thu, 6 Jul 2023 15:48:52 -0400 Subject: [PATCH 01/18] Modify tsql BEGIN TRANSACTION to be BEGIN TRANSACTION --- sqlglot/dialects/tsql.py | 3 +++ tests/dialects/test_dialect.py | 1 + 2 files changed, 4 insertions(+) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 92bb755539..5b4bd913f2 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -532,3 +532,6 @@ def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str: table = expression.args.get("table") table = f"{table} " if table else "" return f"RETURNS {table}{self.sql(expression, 'this')}" + + def transaction_sql(self, expression: exp.Transaction) -> str: + return "BEGIN TRANSACTION" diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 05738cf34e..3ca47aba2b 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1512,6 +1512,7 @@ def test_transactions(self): "redshift": "BEGIN", "snowflake": "BEGIN", "sqlite": "BEGIN TRANSACTION", + "tsql": "BEGIN TRANSACTION", }, ) self.validate_all( From c6ea5ce891b9cf5c7b6b3bd9fc27e8be5f6bb53d Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Thu, 6 Jul 2023 19:58:48 -0400 Subject: [PATCH 02/18] enhance tsql transaction mapping and stored proc begin... interim commit --- sqlglot/dialects/tsql.py | 17 +++++++++++++++- sqlglot/parser.py | 2 +- tests/dialects/test_tsql.py | 40 +++++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 5b4bd913f2..f33f4659df 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -281,8 +281,16 @@ class Tokenizer(tokens.Tokenizer): QUOTES = ["'", '"'] HEX_STRINGS = [("0x", ""), ("0X", "")] + COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON} + KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "BEGIN": TokenType.COMMAND, + "BEGIN TRANSACTION": TokenType.BEGIN, + "COMMIT": TokenType.COMMIT, + "ROLLBACK": TokenType.ROLLBACK, + # "END": TokenType.END, + "SET": TokenType.SET, "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, "DECLARE": TokenType.COMMAND, @@ -534,4 +542,11 @@ def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str: return f"RETURNS {table}{self.sql(expression, 'this')}" def transaction_sql(self, expression: exp.Transaction) -> str: - return "BEGIN TRANSACTION" + this = expression.this + this = f" {this}" if this else "" + return f"BEGIN{this} TRANSACTION" + + def commit_sql(self, expression: exp.Commit) -> str: + this = expression.this + this = f" {this}" if this else "" + return f"COMMIT{this} TRANSACTION" diff --git a/sqlglot/parser.py b/sqlglot/parser.py index f7fd6ba6b4..b1d22babf6 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1205,7 +1205,7 @@ def extend_props(temp_props: t.Optional[exp.Properties]) -> None: extend_props(self._parse_properties()) self._match(TokenType.ALIAS) - begin = self._match(TokenType.BEGIN) + begin = self._match_text_seq("BEGIN") # _match(TokenType.BEGIN) return_ = self._match_text_seq("RETURN") expression = self._parse_statement() diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 3604c719dc..6e1c699c17 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -360,7 +360,46 @@ def test_types_bin(self): }, ) + def test_transaction(self): + # BEGIN { TRAN | TRANSACTION } + # [ { transaction_name | @tran_name_variable } + # [ WITH MARK [ 'description' ] ] + # ] + # [ ; ] + self.validate_identity("BEGIN TRANSACTION") + self.validate_all("BEGIN TRAN", write={"tsql": "BEGIN TRAN"}) + # self.validate_identity("BEGIN TRANSACTION transaction_name") + # self.validate_identity("BEGIN TRANSACTION @tran_name_variable") + # self.validate_identity("BEGIN TRANSACTION transaction_name WITH MARK 'description") + + # COMMIT [ { TRAN | TRANSACTION } [ transaction_name | @tran_name_variable ] ] [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] [ ; ] + + # self.validate_identity("COMMIT") + # self.validate_identity("COMMIT TRAN") + self.validate_identity("COMMIT TRANSACTION") + # self.validate_identity("COMMIT TRANSACTION transaction_name") + # self.validate_identity("COMMIT TRANSACTION @tran_name_variable") + # self.validate_identity("COMMIT TRANSACTION @tran_name_variable WITH DELAYED_DURABILITY = ON") + + # Applies to SQL Server and Azure SQL Database + # ROLLBACK { TRAN | TRANSACTION } + # [ transaction_name | @tran_name_variable + # | savepoint_name | @savepoint_variable ] + # [ ; ] + self.validate_identity("ROLLBACK") + # self.validate_identity("ROLLBACK TRAN") + # self.validate_identity("ROLLBACK TRANSACTION") + # self.validate_identity("ROLLBACK TRANSACTION transaction_name") + # self.validate_identity("ROLLBACK TRANSACTION @tran_name_variable") + # self.validate_identity("ROLLBACK TRANSACTION @tran_name_variable WITH DELAYED_DURABILITY = ON") + def test_udf(self): + self.validate_identity("BEGIN") + self.validate_identity("END") + self.validate_identity("SET XACT_ABORT ON") + self.validate_identity( + "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)" + ) self.validate_identity( "CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar" ) @@ -417,6 +456,7 @@ def test_udf(self): pretty=True, ) + def test_fullproc(self): sql = """ CREATE procedure [TRANSF].[SP_Merge_Sales_Real] @Loadid INTEGER From 474ea398836bbfd18b37b4af39696056f472cd15 Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Fri, 7 Jul 2023 01:21:33 -0400 Subject: [PATCH 03/18] Make tsql BEGIN TRANSACTION parse in all it's flavors --- sqlglot/dialects/tsql.py | 56 ++++++++++++++++++++++++++++++++++--- tests/dialects/test_tsql.py | 21 ++++++++++---- 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index f33f4659df..a336edb460 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -138,7 +138,8 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim if isinstance(expression, exp.NumberToStr) else exp.Literal.string( format_time( - expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING) + expression.text("format"), + t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING), ) ) ) @@ -311,6 +312,7 @@ class Tokenizer(tokens.Tokenizer): "VARCHAR(MAX)": TokenType.TEXT, "XML": TokenType.XML, "SYSTEM_USER": TokenType.CURRENT_USER, + "WITH": TokenType.WITH, } # TSQL allows @, # to appear as a variable/identifier prefix @@ -321,7 +323,9 @@ class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, "CHARINDEX": lambda args: exp.StrPosition( - this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), ), "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), @@ -372,6 +376,46 @@ class Parser(parser.Parser): CONCAT_NULL_OUTPUTS_STRING = True + def _parse_transaction(self) -> exp.Transaction: + this = None + if self._match_texts(self.TRANSACTION_KIND): + this = self._prev.text + + self._match_texts({"TRANSACTION", "WORK"}) + + modes = [] + while True: + mode = [] + while self._match_set( + ( + TokenType.PARAMETER, + TokenType.VAR, + TokenType.WITH, + TokenType.STRING, + TokenType.QUOTE, + ) + ): + # if self._prev.token_type == TokenType.PARAMETER: + # mode.append(f"{self._prev.text}{self._curr.text}") + if self._prev.token_type == TokenType.STRING: + mode.append(f"'{self._prev.text}'") + elif self._prev.token_type == TokenType.PARAMETER: + {} + elif ( + self._tokens[-2].token_type == TokenType.PARAMETER + and self._tokens[-1].token_type == TokenType.VAR + ): + mode.append(f"@{self._prev.text}") + else: + mode.append(self._prev.text) + + if mode: + modes.append(" ".join(mode)) + if not self._match(TokenType.BREAK): + break + + return self.expression(exp.Transaction, this=this, modes=modes) + def _parse_system_time(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("FOR", "SYSTEM_TIME"): return None @@ -502,7 +546,9 @@ class Generator(generator.Generator): exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), exp.SHA2: lambda self, e: self.func( - "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this + "HASHBYTES", + exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), + e.this, ), exp.TimeToStr: _format_sql, } @@ -544,7 +590,9 @@ def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str: def transaction_sql(self, expression: exp.Transaction) -> str: this = expression.this this = f" {this}" if this else "" - return f"BEGIN{this} TRANSACTION" + modes = expression.args.get("modes") + modes = f" {', '.join(modes)}" if modes else "" + return f"BEGIN{this} TRANSACTION{modes}" def commit_sql(self, expression: exp.Commit) -> str: this = expression.this diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 6e1c699c17..b7022f87cf 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -360,6 +360,15 @@ def test_types_bin(self): }, ) + def test_ddl(self): + sql = """CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24));""" + self.validate_all( + " ".join(sql.split()), + write={ + "tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIMESTAMP(4), d FLOAT(24))" + }, + ) + def test_transaction(self): # BEGIN { TRAN | TRANSACTION } # [ { transaction_name | @tran_name_variable } @@ -368,9 +377,9 @@ def test_transaction(self): # [ ; ] self.validate_identity("BEGIN TRANSACTION") self.validate_all("BEGIN TRAN", write={"tsql": "BEGIN TRAN"}) - # self.validate_identity("BEGIN TRANSACTION transaction_name") - # self.validate_identity("BEGIN TRANSACTION @tran_name_variable") - # self.validate_identity("BEGIN TRANSACTION transaction_name WITH MARK 'description") + self.validate_identity("BEGIN TRANSACTION transaction_name") + self.validate_identity("BEGIN TRANSACTION @tran_name_variable") + self.validate_identity("BEGIN TRANSACTION transaction_name WITH MARK 'description'") # COMMIT [ { TRAN | TRANSACTION } [ transaction_name | @tran_name_variable ] ] [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] [ ; ] @@ -849,7 +858,8 @@ def test_format(self): write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"}, ) self.validate_all( - "SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"} + "SELECT FORMAT(1234567, 'f')", + write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}, ) self.validate_all( "SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')", @@ -864,7 +874,8 @@ def test_format(self): write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"}, ) self.validate_all( - "SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"} + "SELECT FORMAT(num_col, 'c')", + write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}, ) def test_string(self): From dae681875a304634557537b93d6b5cc7cec9d9f7 Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Mon, 10 Jul 2023 19:54:55 -0400 Subject: [PATCH 04/18] Enrich TSQL begin transaction, commit transaction, rollback transaction per tsql specifications --- sqlglot/dialects/tsql.py | 68 +++++++++++++++++++++++++++++++++++-- sqlglot/expressions.py | 12 +++++-- sqlglot/parser.py | 3 +- tests/dialects/test_tsql.py | 32 +++++++++++------ 4 files changed, 97 insertions(+), 18 deletions(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index a336edb460..eeb126c9f4 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -369,6 +369,7 @@ class Parser(parser.Parser): STATEMENT_PARSERS = { **parser.Parser.STATEMENT_PARSERS, TokenType.END: lambda self: self._parse_command(), + TokenType.ALIAS: lambda self: self._parse_alias(), } LOG_BASE_FIRST = False @@ -376,6 +377,47 @@ class Parser(parser.Parser): CONCAT_NULL_OUTPUTS_STRING = True + def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: + """Applies to SQL Server and Azure SQL Database + COMMIT [ { TRAN | TRANSACTION } + [ transaction_name | @tran_name_variable ] ] + [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] + [ ; ] + + ROLLBACK { TRAN | TRANSACTION } + [ transaction_name | @tran_name_variable + | savepoint_name | @savepoint_variable ] + [ ; ] + + Returns: + exp.Commit | exp.Rollback: _description_ + """ + rollback = False + if self._prev.token_type == TokenType.ROLLBACK: + rollback = True + + transaction = None + if self._match_texts({"TRAN", "TRANSACTION"}): + transaction = self._prev.text + txn_name = self._parse_id_var() + + durability = None + if self._match_text_seq("WITH", "(", "DELAYED_DURABILITY", "="): + if self._match_text_seq("OFF"): + durability = False + else: + self._match(TokenType.ON) + durability = True + + self._match_r_paren() + + if rollback: + return self.expression(exp.Rollback, this=txn_name, transaction=transaction) + + return self.expression( + exp.Commit, this=txn_name, durability=durability, transaction=transaction + ) + def _parse_transaction(self) -> exp.Transaction: this = None if self._match_texts(self.TRANSACTION_KIND): @@ -459,7 +501,7 @@ def _parse_returns(self) -> exp.ReturnsProperty: returns.set("table", table) return returns - def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: + def gnvert(self, strict: bool) -> t.Optional[exp.Expression]: to = self._parse_types() self._match(TokenType.COMMA) this = self._parse_conjunction() @@ -594,7 +636,27 @@ def transaction_sql(self, expression: exp.Transaction) -> str: modes = f" {', '.join(modes)}" if modes else "" return f"BEGIN{this} TRANSACTION{modes}" + def _durability_sql(self, expression) -> str: + durability = expression.args.get("durability") + durability_sql = "" + if durability is not None: + if durability: + durability = "ON" + else: + durability = "OFF" + durability_sql = f" WITH (DELAYED_DURABILITY = {durability})" + return durability_sql + def commit_sql(self, expression: exp.Commit) -> str: - this = expression.this + this = self.sql(expression, "this") + this = f" {this}" if this else "" + transaction = expression.args.get("transaction") + transaction = f" {transaction}" if transaction else "" + return f"COMMIT{transaction}{this}{self._durability_sql(expression)}" + + def rollback_sql(self, expression: exp.Rollback) -> str: + this = self.sql(expression, "this") this = f" {this}" if this else "" - return f"COMMIT{this} TRANSACTION" + transaction = expression.args.get("transaction") + transaction = f" {transaction}" if transaction else "" + return f"ROLLBACK{transaction}{this}" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index fdf02c8184..1da01df8dc 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3464,11 +3464,17 @@ class Transaction(Expression): class Commit(Expression): - arg_types = {"chain": False} + arg_types = { + "chain": False, + "this": False, + "transaction": False, + "durability": False, + "modes": False, + } class Rollback(Expression): - arg_types = {"savepoint": False} + arg_types = {"savepoint": False, "this": False, "transaction": False} class AlterTable(Expression): @@ -3944,7 +3950,7 @@ def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case: class Cast(Func): - arg_types = {"this": True, "to": True, "format": False} + arg_types = {"this": True, "to": False, "format": False} @property def name(self) -> str: diff --git a/sqlglot/parser.py b/sqlglot/parser.py index b1d22babf6..7e199fb29d 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1205,7 +1205,8 @@ def extend_props(temp_props: t.Optional[exp.Properties]) -> None: extend_props(self._parse_properties()) self._match(TokenType.ALIAS) - begin = self._match_text_seq("BEGIN") # _match(TokenType.BEGIN) + begin = self._match_text_seq("BEGIN") + # begin = self._match(TokenType.BEGIN) return_ = self._match_text_seq("RETURN") expression = self._parse_statement() diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index b7022f87cf..59feea9d90 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -369,6 +369,9 @@ def test_ddl(self): }, ) + def test_convert(self): + self.validate_identity("CONVERT(INT, CONVERT(NUMERIC, '444.75'))") + def test_transaction(self): # BEGIN { TRAN | TRANSACTION } # [ { transaction_name | @tran_name_variable } @@ -376,31 +379,38 @@ def test_transaction(self): # ] # [ ; ] self.validate_identity("BEGIN TRANSACTION") - self.validate_all("BEGIN TRAN", write={"tsql": "BEGIN TRAN"}) + self.validate_identity("BEGIN TRAN") self.validate_identity("BEGIN TRANSACTION transaction_name") self.validate_identity("BEGIN TRANSACTION @tran_name_variable") self.validate_identity("BEGIN TRANSACTION transaction_name WITH MARK 'description'") + def test_commit(self): # COMMIT [ { TRAN | TRANSACTION } [ transaction_name | @tran_name_variable ] ] [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] [ ; ] - # self.validate_identity("COMMIT") - # self.validate_identity("COMMIT TRAN") + self.validate_identity("COMMIT") + self.validate_all("COMMIT TRAN") self.validate_identity("COMMIT TRANSACTION") - # self.validate_identity("COMMIT TRANSACTION transaction_name") - # self.validate_identity("COMMIT TRANSACTION @tran_name_variable") - # self.validate_identity("COMMIT TRANSACTION @tran_name_variable WITH DELAYED_DURABILITY = ON") + self.validate_identity("COMMIT TRANSACTION transaction_name") + self.validate_identity("COMMIT TRANSACTION @tran_name_variable") + + self.validate_identity( + "COMMIT TRANSACTION @tran_name_variable WITH (DELAYED_DURABILITY = ON)" + ) + self.validate_identity( + "COMMIT TRANSACTION transaction_name WITH (DELAYED_DURABILITY = OFF)" + ) + def test_rollback(self): # Applies to SQL Server and Azure SQL Database # ROLLBACK { TRAN | TRANSACTION } # [ transaction_name | @tran_name_variable # | savepoint_name | @savepoint_variable ] # [ ; ] self.validate_identity("ROLLBACK") - # self.validate_identity("ROLLBACK TRAN") - # self.validate_identity("ROLLBACK TRANSACTION") - # self.validate_identity("ROLLBACK TRANSACTION transaction_name") - # self.validate_identity("ROLLBACK TRANSACTION @tran_name_variable") - # self.validate_identity("ROLLBACK TRANSACTION @tran_name_variable WITH DELAYED_DURABILITY = ON") + self.validate_all("ROLLBACK TRAN") + self.validate_identity("ROLLBACK TRANSACTION") + self.validate_identity("ROLLBACK TRANSACTION transaction_name") + self.validate_identity("ROLLBACK TRANSACTION @tran_name_variable") def test_udf(self): self.validate_identity("BEGIN") From 95d14b132f91e048f7e7ade77155efab340f6b37 Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Mon, 10 Jul 2023 20:01:43 -0400 Subject: [PATCH 05/18] undo fat fingering of _parse_convert --- sqlglot/dialects/tsql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index eeb126c9f4..ff2fbc846c 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -501,7 +501,7 @@ def _parse_returns(self) -> exp.ReturnsProperty: returns.set("table", table) return returns - def gnvert(self, strict: bool) -> t.Optional[exp.Expression]: + def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: to = self._parse_types() self._match(TokenType.COMMA) this = self._parse_conjunction() From c9e6ad2318ed50e6d7bb3d920a9229dc20cbb377 Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Mon, 10 Jul 2023 20:04:36 -0400 Subject: [PATCH 06/18] deal with merge conflict on return_sql --- sqlglot/dialects/tsql.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index ff2fbc846c..a87a2cf836 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -628,6 +628,11 @@ def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str: table = expression.args.get("table") table = f"{table} " if table else "" return f"RETURNS {table}{self.sql(expression, 'this')}" + + def returning_sql(self, expression: exp.Returning) -> str: + into = self.sql(expression, "into") + into = self.seg(f"INTO {into}") if into else "" + return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}" def transaction_sql(self, expression: exp.Transaction) -> str: this = expression.this From 6ba0901a72400f5f727a636181ae63dbdf1631a3 Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Mon, 10 Jul 2023 20:58:03 -0400 Subject: [PATCH 07/18] re-write BEGIN TRANSACTION --- sqlglot/dialects/tsql.py | 66 +++++++++++++------------------------ sqlglot/expressions.py | 2 +- tests/dialects/test_tsql.py | 2 +- 3 files changed, 25 insertions(+), 45 deletions(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index a87a2cf836..86de4838cd 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -419,44 +419,24 @@ def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: ) def _parse_transaction(self) -> exp.Transaction: - this = None - if self._match_texts(self.TRANSACTION_KIND): - this = self._prev.text - - self._match_texts({"TRANSACTION", "WORK"}) - - modes = [] - while True: - mode = [] - while self._match_set( - ( - TokenType.PARAMETER, - TokenType.VAR, - TokenType.WITH, - TokenType.STRING, - TokenType.QUOTE, - ) - ): - # if self._prev.token_type == TokenType.PARAMETER: - # mode.append(f"{self._prev.text}{self._curr.text}") - if self._prev.token_type == TokenType.STRING: - mode.append(f"'{self._prev.text}'") - elif self._prev.token_type == TokenType.PARAMETER: - {} - elif ( - self._tokens[-2].token_type == TokenType.PARAMETER - and self._tokens[-1].token_type == TokenType.VAR - ): - mode.append(f"@{self._prev.text}") - else: - mode.append(self._prev.text) - - if mode: - modes.append(" ".join(mode)) - if not self._match(TokenType.BREAK): - break - - return self.expression(exp.Transaction, this=this, modes=modes) + """Syntax: + BEGIN { TRAN | TRANSACTION } + [ { transaction_name | @tran_name_variable } + [ WITH MARK [ 'description' ] ] + ] + [ ; ] + + Returns: + exp.Transaction: _description_ + """ + mark = None + + txn_name = self._parse_id_var() + + if self._match_text_seq("WITH", "MARK"): + mark = self._parse_string() + + return self.expression(exp.Transaction, this=txn_name, mark=mark) def _parse_system_time(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("FOR", "SYSTEM_TIME"): @@ -628,18 +608,18 @@ def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str: table = expression.args.get("table") table = f"{table} " if table else "" return f"RETURNS {table}{self.sql(expression, 'this')}" - + def returning_sql(self, expression: exp.Returning) -> str: into = self.sql(expression, "into") into = self.seg(f"INTO {into}") if into else "" return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}" def transaction_sql(self, expression: exp.Transaction) -> str: - this = expression.this + this = self.sql(expression, "this") this = f" {this}" if this else "" - modes = expression.args.get("modes") - modes = f" {', '.join(modes)}" if modes else "" - return f"BEGIN{this} TRANSACTION{modes}" + mark = expression.args.get("mark") + mark = f" WITH MARK {mark}" if mark else "" + return f"BEGIN TRANSACTION{this}{mark}" def _durability_sql(self, expression) -> str: durability = expression.args.get("durability") diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1da01df8dc..4126a3cd92 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3460,7 +3460,7 @@ class Command(Expression): class Transaction(Expression): - arg_types = {"this": False, "modes": False} + arg_types = {"this": False, "modes": False, "mark": False} class Commit(Expression): diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 59feea9d90..d6d457e24d 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -379,7 +379,7 @@ def test_transaction(self): # ] # [ ; ] self.validate_identity("BEGIN TRANSACTION") - self.validate_identity("BEGIN TRAN") + # self.validate_identity("BEGIN TRAN") self.validate_identity("BEGIN TRANSACTION transaction_name") self.validate_identity("BEGIN TRANSACTION @tran_name_variable") self.validate_identity("BEGIN TRANSACTION transaction_name WITH MARK 'description'") From a097ea71ae323fa5f1ae7a10bb8b110b462b9c6a Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Mon, 10 Jul 2023 21:26:03 -0400 Subject: [PATCH 08/18] remove extraineous CONVERT unit test --- tests/dialects/test_tsql.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index d6d457e24d..a45da0e96e 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -369,9 +369,6 @@ def test_ddl(self): }, ) - def test_convert(self): - self.validate_identity("CONVERT(INT, CONVERT(NUMERIC, '444.75'))") - def test_transaction(self): # BEGIN { TRAN | TRANSACTION } # [ { transaction_name | @tran_name_variable } From ca3d6d2bbb01cbb055ee2f4ac92b2ca4e8a4312f Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Mon, 10 Jul 2023 21:34:16 -0400 Subject: [PATCH 09/18] run make check --- sqlglot/dialects/tsql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 4bec7c1869..8e6289dca6 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -646,4 +646,4 @@ def rollback_sql(self, expression: exp.Rollback) -> str: this = f" {this}" if this else "" transaction = expression.args.get("transaction") transaction = f" {transaction}" if transaction else "" - return f"ROLLBACK{transaction}{this}" \ No newline at end of file + return f"ROLLBACK{transaction}{this}" From 9856cee428aab56e4fa11a16b10f195edc7706f2 Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Tue, 11 Jul 2023 14:58:33 -0400 Subject: [PATCH 10/18] PR review comments update, still work in progress --- sqlglot/dialects/tsql.py | 47 ++++++++++++++----------------------- sqlglot/expressions.py | 2 +- tests/dialects/test_tsql.py | 14 +++++------ 3 files changed, 26 insertions(+), 37 deletions(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 8e6289dca6..a90a945f67 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -282,15 +282,11 @@ class Tokenizer(tokens.Tokenizer): QUOTES = ["'", '"'] HEX_STRINGS = [("0x", ""), ("0X", "")] - COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON} - KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "BEGIN": TokenType.COMMAND, + # "BEGIN TRAN": TokenType.BEGIN, "BEGIN TRANSACTION": TokenType.BEGIN, - "COMMIT": TokenType.COMMIT, - "ROLLBACK": TokenType.ROLLBACK, - # "END": TokenType.END, "SET": TokenType.SET, "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, @@ -370,7 +366,6 @@ class Parser(parser.Parser): STATEMENT_PARSERS = { **parser.Parser.STATEMENT_PARSERS, TokenType.END: lambda self: self._parse_command(), - TokenType.ALIAS: lambda self: self._parse_alias(), } LOG_BASE_FIRST = False @@ -393,17 +388,15 @@ def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: Returns: exp.Commit | exp.Rollback: _description_ """ - rollback = False - if self._prev.token_type == TokenType.ROLLBACK: - rollback = True + rollback = self._prev.token_type == TokenType.ROLLBACK - transaction = None - if self._match_texts({"TRAN", "TRANSACTION"}): - transaction = self._prev.text + transaction = self._match_texts({"TRAN", "TRANSACTION"}) and self._prev.text txn_name = self._parse_id_var() durability = None - if self._match_text_seq("WITH", "(", "DELAYED_DURABILITY", "="): + if self._match_pair(TokenType.WITH, TokenType.L_PAREN): + self._match_text_seq("DELAYED_DURABILITY") + self._match(TokenType.EQ) if self._match_text_seq("OFF"): durability = False else: @@ -618,32 +611,28 @@ def returning_sql(self, expression: exp.Returning) -> str: def transaction_sql(self, expression: exp.Transaction) -> str: this = self.sql(expression, "this") + # this = None if this in "TRAN" else this this = f" {this}" if this else "" - mark = expression.args.get("mark") + mark = self.sql(expression, "mark") mark = f" WITH MARK {mark}" if mark else "" return f"BEGIN TRANSACTION{this}{mark}" - def _durability_sql(self, expression) -> str: - durability = expression.args.get("durability") - durability_sql = "" - if durability is not None: - if durability: - durability = "ON" - else: - durability = "OFF" - durability_sql = f" WITH (DELAYED_DURABILITY = {durability})" - return durability_sql - def commit_sql(self, expression: exp.Commit) -> str: this = self.sql(expression, "this") this = f" {this}" if this else "" - transaction = expression.args.get("transaction") + transaction = self.sql(expression, "transaction") transaction = f" {transaction}" if transaction else "" - return f"COMMIT{transaction}{this}{self._durability_sql(expression)}" + durability = expression.args.get("durability") + durability = ( + f" WITH (DELAYED_DURABILITY = {'ON' if durability else 'OFF'})" + if durability is not None + else "" + ) + return f"COMMIT TRANSACTION{this}{durability}" def rollback_sql(self, expression: exp.Rollback) -> str: this = self.sql(expression, "this") this = f" {this}" if this else "" - transaction = expression.args.get("transaction") + transaction = self.sql(expression, "transaction") transaction = f" {transaction}" if transaction else "" - return f"ROLLBACK{transaction}{this}" + return f"ROLLBACK TRANSACTION{this}" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 44a7061f5c..00ed7832f9 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3966,7 +3966,7 @@ def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case: class Cast(Func): - arg_types = {"this": True, "to": False, "format": False} + arg_types = {"this": True, "to": True, "format": False} @property def name(self) -> str: diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 776bab0523..54218cb86c 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -390,9 +390,8 @@ def test_types_bin(self): ) def test_ddl(self): - sql = """CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24));""" self.validate_all( - " ".join(sql.split()), + "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", write={ "tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIMESTAMP(4), d FLOAT(24))" }, @@ -405,7 +404,8 @@ def test_transaction(self): # ] # [ ; ] self.validate_identity("BEGIN TRANSACTION") - # self.validate_identity("BEGIN TRAN") + self.validate_identity("BEGIN TRAN") + # self.validate_all("BEGIN TRAN", write={"tsql":"BEGIN TRANSACTION"}) self.validate_identity("BEGIN TRANSACTION transaction_name") self.validate_identity("BEGIN TRANSACTION @tran_name_variable") self.validate_identity("BEGIN TRANSACTION transaction_name WITH MARK 'description'") @@ -413,8 +413,8 @@ def test_transaction(self): def test_commit(self): # COMMIT [ { TRAN | TRANSACTION } [ transaction_name | @tran_name_variable ] ] [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] [ ; ] - self.validate_identity("COMMIT") - self.validate_all("COMMIT TRAN") + self.validate_all("COMMIT", write={"tsql": "COMMIT TRANSACTION"}) + self.validate_all("COMMIT TRAN", write={"tsql": "COMMIT TRANSACTION"}) self.validate_identity("COMMIT TRANSACTION") self.validate_identity("COMMIT TRANSACTION transaction_name") self.validate_identity("COMMIT TRANSACTION @tran_name_variable") @@ -432,8 +432,8 @@ def test_rollback(self): # [ transaction_name | @tran_name_variable # | savepoint_name | @savepoint_variable ] # [ ; ] - self.validate_identity("ROLLBACK") - self.validate_all("ROLLBACK TRAN") + self.validate_all("ROLLBACK", write={"tsql": "ROLLBACK TRANSACTION"}) + self.validate_all("ROLLBACK TRAN", write={"tsql": "ROLLBACK TRANSACTION"}) self.validate_identity("ROLLBACK TRANSACTION") self.validate_identity("ROLLBACK TRANSACTION transaction_name") self.validate_identity("ROLLBACK TRANSACTION @tran_name_variable") From cb07d9cbf0d2bbb2dbf326b92b1182fd8f2fa8a2 Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Tue, 11 Jul 2023 15:04:10 -0400 Subject: [PATCH 11/18] Remove BEGIN TRANSACTION keyword --- sqlglot/dialects/tsql.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index a90a945f67..5b657db681 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -285,8 +285,6 @@ class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "BEGIN": TokenType.COMMAND, - # "BEGIN TRAN": TokenType.BEGIN, - "BEGIN TRANSACTION": TokenType.BEGIN, "SET": TokenType.SET, "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, From 4b929a02470c3454fac23767c025bb961b06efa4 Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Tue, 11 Jul 2023 16:20:51 -0400 Subject: [PATCH 12/18] Remove transaction parameter --- sqlglot/dialects/tsql.py | 13 +++---------- sqlglot/expressions.py | 3 +-- tests/dialects/test_tsql.py | 1 - 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 5b657db681..724c066529 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -388,7 +388,7 @@ def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: """ rollback = self._prev.token_type == TokenType.ROLLBACK - transaction = self._match_texts({"TRAN", "TRANSACTION"}) and self._prev.text + self._match_texts({"TRAN", "TRANSACTION"}) txn_name = self._parse_id_var() durability = None @@ -404,11 +404,9 @@ def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: self._match_r_paren() if rollback: - return self.expression(exp.Rollback, this=txn_name, transaction=transaction) + return self.expression(exp.Rollback, this=txn_name) - return self.expression( - exp.Commit, this=txn_name, durability=durability, transaction=transaction - ) + return self.expression(exp.Commit, this=txn_name, durability=durability) def _parse_transaction(self) -> exp.Transaction: """Syntax: @@ -609,7 +607,6 @@ def returning_sql(self, expression: exp.Returning) -> str: def transaction_sql(self, expression: exp.Transaction) -> str: this = self.sql(expression, "this") - # this = None if this in "TRAN" else this this = f" {this}" if this else "" mark = self.sql(expression, "mark") mark = f" WITH MARK {mark}" if mark else "" @@ -618,8 +615,6 @@ def transaction_sql(self, expression: exp.Transaction) -> str: def commit_sql(self, expression: exp.Commit) -> str: this = self.sql(expression, "this") this = f" {this}" if this else "" - transaction = self.sql(expression, "transaction") - transaction = f" {transaction}" if transaction else "" durability = expression.args.get("durability") durability = ( f" WITH (DELAYED_DURABILITY = {'ON' if durability else 'OFF'})" @@ -631,6 +626,4 @@ def commit_sql(self, expression: exp.Commit) -> str: def rollback_sql(self, expression: exp.Rollback) -> str: this = self.sql(expression, "this") this = f" {this}" if this else "" - transaction = self.sql(expression, "transaction") - transaction = f" {transaction}" if transaction else "" return f"ROLLBACK TRANSACTION{this}" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 00ed7832f9..1fbcca3f61 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3479,14 +3479,13 @@ class Commit(Expression): arg_types = { "chain": False, "this": False, - "transaction": False, "durability": False, "modes": False, } class Rollback(Expression): - arg_types = {"savepoint": False, "this": False, "transaction": False} + arg_types = {"savepoint": False, "this": False} class AlterTable(Expression): diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 54218cb86c..9ab8c4ac74 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -405,7 +405,6 @@ def test_transaction(self): # [ ; ] self.validate_identity("BEGIN TRANSACTION") self.validate_identity("BEGIN TRAN") - # self.validate_all("BEGIN TRAN", write={"tsql":"BEGIN TRANSACTION"}) self.validate_identity("BEGIN TRANSACTION transaction_name") self.validate_identity("BEGIN TRANSACTION @tran_name_variable") self.validate_identity("BEGIN TRANSACTION transaction_name WITH MARK 'description'") From 064125a8f1db345ef2a125de1f2c217ad1bca484 Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Tue, 11 Jul 2023 17:34:33 -0400 Subject: [PATCH 13/18] Remove modes from Commit interface --- sqlglot/expressions.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1fbcca3f61..e4edeed1cf 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3476,12 +3476,7 @@ class Transaction(Expression): class Commit(Expression): - arg_types = { - "chain": False, - "this": False, - "durability": False, - "modes": False, - } + arg_types = {"chain": False, "this": False, "durability": False} class Rollback(Expression): From bcc52630165918ba54e351b673a68eb38f849026 Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Tue, 11 Jul 2023 18:51:11 -0400 Subject: [PATCH 14/18] move CREATE PROCEDURE keyword tests, remove TokenType.BEGIN from tsql KEYWORDS list, however change signature of _parse_transaction to reflect returning a Begin expression --- sqlglot/dialects/tsql.py | 31 +++++++++++++++++++------------ sqlglot/parser.py | 3 +-- tests/dialects/test_tsql.py | 10 ++++++---- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 724c066529..6490306f7d 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -284,7 +284,7 @@ class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - "BEGIN": TokenType.COMMAND, + # "BEGIN": TokenType.COMMAND, "SET": TokenType.SET, "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, @@ -408,7 +408,7 @@ def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: return self.expression(exp.Commit, this=txn_name, durability=durability) - def _parse_transaction(self) -> exp.Transaction: + def _parse_transaction(self) -> exp.Transaction | exp.Command: """Syntax: BEGIN { TRAN | TRANSACTION } [ { transaction_name | @tran_name_variable } @@ -419,14 +419,16 @@ def _parse_transaction(self) -> exp.Transaction: Returns: exp.Transaction: _description_ """ - mark = None + if self._match_texts(("TRAN", "TRANSACTION")): + # we have a transaction and not a BEGIN + txn_name = self._parse_id_var() - txn_name = self._parse_id_var() - - if self._match_text_seq("WITH", "MARK"): - mark = self._parse_string() + mark = None + if self._match_text_seq("WITH", "MARK"): + mark = self._parse_string() - return self.expression(exp.Transaction, this=txn_name, mark=mark) + return self.expression(exp.Transaction, this=txn_name, mark=mark) + return self.expression(exp.Command, this="BEGIN") def _parse_system_time(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("FOR", "SYSTEM_TIME"): @@ -607,10 +609,15 @@ def returning_sql(self, expression: exp.Returning) -> str: def transaction_sql(self, expression: exp.Transaction) -> str: this = self.sql(expression, "this") - this = f" {this}" if this else "" - mark = self.sql(expression, "mark") - mark = f" WITH MARK {mark}" if mark else "" - return f"BEGIN TRANSACTION{this}{mark}" + if this in ["TRAN", "TRANSACTION"]: + mark = self.sql(expression, "mark") + mark = f" WITH MARK {mark}" if mark else "" + return f"BEGIN TRANSACTION{mark}" + else: + this = f" {this}" if this else "" + mark = self.sql(expression, "mark") + mark = f" WITH MARK {mark}" if mark else "" + return f"BEGIN TRANSACTION{this}{mark}" def commit_sql(self, expression: exp.Commit) -> str: this = self.sql(expression, "this") diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 38e35f0f25..68e060bcdb 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1206,7 +1206,6 @@ def extend_props(temp_props: t.Optional[exp.Properties]) -> None: self._match(TokenType.ALIAS) begin = self._match_text_seq("BEGIN") - # begin = self._match(TokenType.BEGIN) return_ = self._match_text_seq("RETURN") expression = self._parse_statement() @@ -4352,7 +4351,7 @@ def _parse_ddl_select(self) -> t.Optional[exp.Expression]: self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False)) ) - def _parse_transaction(self) -> exp.Transaction: + def _parse_transaction(self) -> exp.Transaction | exp.Command: this = None if self._match_texts(self.TRANSACTION_KIND): this = self._prev.text diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 9ab8c4ac74..5fa881b616 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -404,7 +404,7 @@ def test_transaction(self): # ] # [ ; ] self.validate_identity("BEGIN TRANSACTION") - self.validate_identity("BEGIN TRAN") + self.validate_all("BEGIN TRAN", write={"tsql": "BEGIN TRANSACTION"}) self.validate_identity("BEGIN TRANSACTION transaction_name") self.validate_identity("BEGIN TRANSACTION @tran_name_variable") self.validate_identity("BEGIN TRANSACTION transaction_name WITH MARK 'description'") @@ -438,9 +438,6 @@ def test_rollback(self): self.validate_identity("ROLLBACK TRANSACTION @tran_name_variable") def test_udf(self): - self.validate_identity("BEGIN") - self.validate_identity("END") - self.validate_identity("SET XACT_ABORT ON") self.validate_identity( "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)" ) @@ -500,6 +497,11 @@ def test_udf(self): pretty=True, ) + def test_procedure_keywords(self): + self.validate_identity("BEGIN") + self.validate_identity("END") + self.validate_identity("SET XACT_ABORT ON") + def test_fullproc(self): sql = """ CREATE procedure [TRANSF].[SP_Merge_Sales_Real] From 3fcd22a741cf8bb2bf4028f7021b82e12af1d9ee Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Tue, 11 Jul 2023 19:16:06 -0400 Subject: [PATCH 15/18] Remove redundant tokens in tsql KEYWORDS, remove description from method comments --- sqlglot/dialects/tsql.py | 39 +++++++++++++++------------------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 6490306f7d..d18647ce60 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -284,8 +284,6 @@ class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - # "BEGIN": TokenType.COMMAND, - "SET": TokenType.SET, "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, "DECLARE": TokenType.COMMAND, @@ -307,7 +305,6 @@ class Tokenizer(tokens.Tokenizer): "XML": TokenType.XML, "OUTPUT": TokenType.RETURNING, "SYSTEM_USER": TokenType.CURRENT_USER, - "WITH": TokenType.WITH, } # TSQL allows @, # to appear as a variable/identifier prefix @@ -373,18 +370,15 @@ class Parser(parser.Parser): def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: """Applies to SQL Server and Azure SQL Database - COMMIT [ { TRAN | TRANSACTION } - [ transaction_name | @tran_name_variable ] ] - [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] - [ ; ] - - ROLLBACK { TRAN | TRANSACTION } - [ transaction_name | @tran_name_variable - | savepoint_name | @savepoint_variable ] - [ ; ] - - Returns: - exp.Commit | exp.Rollback: _description_ + COMMIT [ { TRAN | TRANSACTION } + [ transaction_name | @tran_name_variable ] ] + [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] + [ ; ] + + ROLLBACK { TRAN | TRANSACTION } + [ transaction_name | @tran_name_variable + | savepoint_name | @savepoint_variable ] + [ ; ] """ rollback = self._prev.token_type == TokenType.ROLLBACK @@ -409,15 +403,12 @@ def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: return self.expression(exp.Commit, this=txn_name, durability=durability) def _parse_transaction(self) -> exp.Transaction | exp.Command: - """Syntax: - BEGIN { TRAN | TRANSACTION } - [ { transaction_name | @tran_name_variable } - [ WITH MARK [ 'description' ] ] - ] - [ ; ] - - Returns: - exp.Transaction: _description_ + """Applies to SQL Server and Azure SQL Database + BEGIN { TRAN | TRANSACTION } + [ { transaction_name | @tran_name_variable } + [ WITH MARK [ 'description' ] ] + ] + [ ; ] """ if self._match_texts(("TRAN", "TRANSACTION")): # we have a transaction and not a BEGIN From 9579470c3ededf15b27dc148831e302f2e84292d Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Tue, 11 Jul 2023 19:25:02 -0400 Subject: [PATCH 16/18] tighten up code for _parse_transaction per feedback --- sqlglot/dialects/tsql.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index d18647ce60..acbd3d23c1 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -412,13 +412,11 @@ def _parse_transaction(self) -> exp.Transaction | exp.Command: """ if self._match_texts(("TRAN", "TRANSACTION")): # we have a transaction and not a BEGIN - txn_name = self._parse_id_var() - - mark = None + transaction = self.expression(exp.Transaction, this=self._parse_id_var()) if self._match_text_seq("WITH", "MARK"): - mark = self._parse_string() + transaction.set("mark", self._parse_string()) + return transaction - return self.expression(exp.Transaction, this=txn_name, mark=mark) return self.expression(exp.Command, this="BEGIN") def _parse_system_time(self) -> t.Optional[exp.Expression]: From e92e7eb7cf28ec21a7326231120b9cd4b96c9b2e Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Tue, 11 Jul 2023 19:34:13 -0400 Subject: [PATCH 17/18] Revert to matching on TokenType.BEGIN --- sqlglot/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 68e060bcdb..166e3b8e66 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1205,7 +1205,7 @@ def extend_props(temp_props: t.Optional[exp.Properties]) -> None: extend_props(self._parse_properties()) self._match(TokenType.ALIAS) - begin = self._match_text_seq("BEGIN") + begin = self._match(TokenType.BEGIN) return_ = self._match_text_seq("RETURN") expression = self._parse_statement() From 19f9af938363b23388d54d6118261b7c7ea944b7 Mon Sep 17 00:00:00 2001 From: Douglas Moore Date: Tue, 11 Jul 2023 19:43:57 -0400 Subject: [PATCH 18/18] Handle BEGIN as a command --- sqlglot/dialects/tsql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index acbd3d23c1..efb49bd9a9 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -417,7 +417,7 @@ def _parse_transaction(self) -> exp.Transaction | exp.Command: transaction.set("mark", self._parse_string()) return transaction - return self.expression(exp.Command, this="BEGIN") + return self._parse_as_command(self._prev) def _parse_system_time(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("FOR", "SYSTEM_TIME"):