Skip to content

Commit

Permalink
Fix: options inside of bigquery struct closes #1562
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed May 6, 2023
1 parent fb819f0 commit 7b09bff
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 22 deletions.
10 changes: 9 additions & 1 deletion sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ class Generator(Hive.Generator):
exp.Create: _create_sql,
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
Expand Down Expand Up @@ -222,5 +221,14 @@ def cast_sql(self, expression: exp.Cast) -> str:

return super(Hive.Generator, self).cast_sql(expression)

def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
return super().columndef_sql(
expression,
sep=": "
if isinstance(expression.parent, exp.DataType)
and expression.parent.is_type(exp.DataType.Type.STRUCT)
else sep,
)

class Tokenizer(Hive.Tokenizer):
HEX_STRINGS = [("X'", "'")]
4 changes: 0 additions & 4 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3097,10 +3097,6 @@ class PseudoType(Expression):
pass


class StructKwarg(Expression):
arg_types = {"this": True, "expression": True}


# WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...)
class SubqueryPredicate(Predicate):
pass
Expand Down
7 changes: 2 additions & 5 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,12 +534,12 @@ def columnposition_sql(self, expression: exp.ColumnPosition) -> str:
position = self.sql(expression, "position")
return f"{position}{this}"

def columndef_sql(self, expression: exp.ColumnDef) -> str:
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
kind = f" {kind}" if kind else ""
kind = f"{sep}{kind}" if kind else ""
constraints = f" {constraints}" if constraints else ""
position = self.sql(expression, "position")
position = f" {position}" if position else ""
Expand Down Expand Up @@ -1510,9 +1510,6 @@ def star_sql(self, expression: exp.Star) -> str:
replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else ""
return f"*{except_}{replace}"

def structkwarg_sql(self, expression: exp.StructKwarg) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"

def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}"
Expand Down
21 changes: 9 additions & 12 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,7 +1568,7 @@ def _parse_returns(self) -> exp.Expression:
value = self.expression(
exp.Schema,
this="TABLE",
expressions=self._parse_csv(self._parse_struct_kwargs),
expressions=self._parse_csv(self._parse_struct_types),
)
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
Expand Down Expand Up @@ -2802,7 +2802,7 @@ def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:

if self._match(TokenType.L_PAREN):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
expressions = self._parse_csv(self._parse_struct_types)
elif nested:
expressions = self._parse_csv(self._parse_types)
else:
Expand Down Expand Up @@ -2837,7 +2837,7 @@ def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
expressions = self._parse_csv(self._parse_struct_types)
else:
expressions = self._parse_csv(self._parse_types)

Expand Down Expand Up @@ -2895,16 +2895,10 @@ def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
prefix=prefix,
)

def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
index = self._index
this = self._parse_id_var()
def _parse_struct_types(self) -> t.Optional[exp.Expression]:
this = self._parse_type()
self._match(TokenType.COLON)
data_type = self._parse_types()

if not data_type:
self._retreat(index)
return self._parse_types()
return self.expression(exp.StructKwarg, this=this, expression=data_type)
return self._parse_column_def(this)

def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match(TokenType.AT_TIME_ZONE):
Expand Down Expand Up @@ -3178,6 +3172,9 @@ def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[e
return self.expression(exp.Schema, this=this, expressions=args)

def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
# column defs are not really columns, they're identifiers
if isinstance(this, exp.Column):
this = this.this
kind = self._parse_types()

if self._match_text_seq("FOR", "ORDINALITY"):
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def test_bigquery(self):
self.validate_identity(
"""CREATE TABLE x (a STRING OPTIONS (description='x')) OPTIONS (table_expiration_days=1)"""
)
self.validate_identity("""CREATE TABLE x (a STRUCT<b STRING OPTIONS (description='b')>)""")
self.validate_identity(
"SELECT * FROM (SELECT * FROM `t`) AS a UNPIVOT((c) FOR c_name IN (v1, v2))"
)
Expand Down

0 comments on commit 7b09bff

Please sign in to comment.